FreeTensor
Loading...
Searching...
No Matches
code_gen_c.h
Go to the documentation of this file.
1#ifndef DETAIL_CODE_GEN_C_H
2#define DETAIL_CODE_GEN_C_H
3
4#include <algorithm>
5#include <cmath>
6#include <functional>
7#include <vector>
8
9#include <analyze/find_stmt.h>
10#include <codegen/code_gen_c.h>
11#include <config.h>
12#include <container_utils.h>
13#include <serialize/mangle.h>
14
15#include "code_gen.h"
16
17namespace freetensor {
18
19template <class Stream>
20std::function<std::ostream &(std::ostream &)>
21CodeGenC<Stream>::genMdPtrType(const VarDef &def, bool isConst) {
22 // NOTE: `[=]` implicitly capturing `this` is deprecated in C++20. Using
23 // `[=]` will trigger a warning in GCC (because of deprecation), but using
24 // `[=, this]` will trigger a warning in Clang<17 (because it will think
25 // `this` is duplicated).
26#if defined(__clang__) && __clang_major__ < 17
27 return [=](std::ostream &os) -> std::ostream & {
28#else
29 return [=, this](std::ostream &os) -> std::ostream & {
30#endif
31 auto &&buf = def->buffer_;
32
33 if (buf->tensor()->shape().empty()) {
34 // Use reference for scalars
35 if (isConst) {
36 os << "const ";
37 }
38 os << gen(buf->tensor()->dtype()) << " &";
39 return os;
40 }
41
42 bool isRestricted = true;
43 if (def->viewOf_.has_value() ||
44 !findAllStmt(def, [&](const Stmt &inner) {
45 return inner->nodeType() == ASTNodeType::VarDef &&
46 inner.as<VarDefNode>()->viewOf_ == def->name_;
47 }).empty()) {
48 isRestricted = false;
49 }
51 os << (Config::debugRuntimeCheck() ? "mdspan_dbg<"
52 : isRestricted ? "mdspan_r<"
53 : "mdspan<");
54 if (isConst) {
55 os << "const ";
56 }
57 os << gen(buf->tensor()->dtype()) << ", extents<";
58 for (auto &&[i, dim] : views::enumerate(buf->tensor()->shape())) {
59 os << (i > 0 ? ", " : "");
60 if (dim->nodeType() == ASTNodeType::IntConst) {
61 os << dim.template as<IntConstNode>()->val_;
62 } else {
63 os << "dynamic_extent";
64 }
65 }
66 return os << ">>";
67 };
70template <class Stream>
72 const std::function<void()> &genRawPtr,
73 bool isConst) {
74 auto &&buf = def->buffer_;
76 if (buf->tensor()->shape().empty()) {
77 // Use reference for scalars
78 // e.g.
79 // ((const int32_t &)*((const int32_t *)(...)))
80 this->os() << "((" << genMdPtrType(def, isConst) << ")*((";
81 if (isConst) {
82 this->os() << "const ";
83 }
84 this->os() << gen(buf->tensor()->dtype()) << " *)(";
85 genRawPtr();
86 this->os() << ")))";
87 return;
88 }
90 this->os() << genMdPtrType(def, isConst) << "((";
91 if (isConst) {
92 this->os() << "const ";
93 }
94 this->os() << gen(buf->tensor()->dtype()) << "*)(";
95 genRawPtr();
96 this->os() << ")";
97 for (auto &&dim : buf->tensor()->shape()) {
98 if (dim->nodeType() != ASTNodeType::IntConst) {
99 this->os() << ", ";
100 (*this)(dim);
101 }
102 }
103 this->os() << ")";
104}
105
106template <class Stream>
108 const std::vector<Expr> &indices) {
109 if (def->buffer_->mtype() == MemType::ByValue) {
110 // __ByValArray
111 this->os() << mangle(def->name_);
112 for (auto &&index : indices) {
113 this->os() << "[";
114 (*this)(index);
115 this->os() << "]";
116 }
117 } else {
118 this->os() << mangle(def->name_);
119 if (!def->buffer_->tensor()->shape().empty()) {
120 // TODO: Switch bracket after C++23
121 this->os() << "(";
122 for (auto &&[i, index] : views::enumerate(indices)) {
123 this->os() << (i > 0 ? ", " : "");
124 (*this)(index);
125 }
126 this->os() << ")";
127 }
128 }
129}
130
131template <class Stream> void CodeGenC<Stream>::visit(const StmtSeq &op) {
132 for (auto &&stmt : op->stmts_) {
133 if (stmt->nodeType() == ASTNodeType::VarDef) {
134 this->makeIndent();
135 this->beginBlock();
136 (*this)(stmt);
137 this->endBlock();
138 } else {
139 (*this)(stmt);
140 }
141 }
142}
143
144template <class Stream> void CodeGenC<Stream>::visit(const VarDef &op) {
145 this->makeIndent();
146 auto &&tensor = op->buffer_->tensor();
147 auto &&shape = tensor->shape();
148 auto name = mangle(op->name_);
149
150 if (op->viewOf_.has_value()) {
151 // e.g.
152 // auto &&x = mdspan_r<const float, extents<5, 5>>(y.data_handle());
153 auto source = op;
154 while (source->viewOf_.has_value()) {
155 source = this->def(*source->viewOf_);
156 }
157 this->os() << "auto &&" << name << " = ";
158 genMdPtrDef(op, mangle(source->name_) + ".data_handle()",
159 !isWritable(source->buffer_->atype()));
160 this->os() << ";" << std::endl;
161
162 } else if (!isInputting(op->buffer_->atype()) &&
163 !isOutputting(op->buffer_->atype())) {
164 // e.g. 1. float x;
165 // 2. float x[5][5][5];
166 this->os() << gen(tensor->dtype()) << " " << name;
167 for (auto &&dim : shape) {
168 this->os() << "[";
169 (*this)(dim);
170 this->os() << "]";
171 }
172 this->os() << ";" << std::endl;
173 } else {
174 auto paramPositions = ranges::to<std::vector>(
175 params_ | views::enumerate | views::filter([&](auto &&pair) {
176 return pair.second.name_ == op->name_;
177 }) |
178 views::keys);
179 auto returnPositions = ranges::to<std::vector>(
180 returns_ | views::enumerate | views::filter([&](auto &&pair) {
181 return pair.second.name_ == op->name_;
182 }) |
183 views::keys);
184 bool isParam = !paramPositions.empty();
185 bool isReturn = !returnPositions.empty();
186 if (!isParam && !isReturn) {
187 throw InvalidProgram("I/O variable " + op->name_ +
188 " used but not defined as a function's "
189 "parameters or return values");
190 }
191 std::string rawPtr;
192 if (isParam) {
193 if (paramPositions.size() > 1) {
194 throw InvalidProgram("Parameter '" + op->name_ +
195 "' is duplicated");
196 }
197 int nthParam = paramPositions.front();
198 rawPtr = "params[" + std::to_string(nthParam) + "]";
199 } else {
200 if (!isOutputting(op->buffer_->atype())) {
201 throw InvalidProgram(
202 "Only outputting variable can be as a return value");
203 }
204 // If there are multiple position with the same name, we only fill
205 // the first position. Driver::collectReturns will only collect the
206 // first
207 int nthReturn = returnPositions.front();
208 rawPtr = "returns[" + std::to_string(nthReturn) + "]";
209 std::string shapePtr =
210 "retShapes[" + std::to_string(nthReturn) + "]";
211 std::string dimPtr = "retDims[" + std::to_string(nthReturn) + "]";
212 this->os() << "if (" + rawPtr + " == NULL) ";
213 this->beginBlock();
214 this->genAlloc(op->buffer_->tensor(), rawPtr, shapePtr, dimPtr);
215 this->endBlock();
216 this->makeIndent();
217 }
218
219 switch (op->buffer_->mtype()) {
220 case MemType::ByValue:
221 // e.g. (1)
222 // float x;
223 // x = *((float*)params[0]);
224
225 // e.g. (2)
226 // __ByValArray<__ByValArray<float, 2>, 2> x;
227 // x[0][0] = *((float*)params[0])[0];
228 // x[0][1] = *((float*)params[0])[1];
229 // x[1][0] = *((float*)params[0])[2];
230 // x[1][1] = *((float*)params[0])[3];
231 if (op->buffer_->atype() != AccessType::Input) {
232 throw InvalidProgram("ByValue typed var " + op->name_ +
233 " can only be Input");
234 }
235 for (auto &&dim : shape) {
236 if (dim->nodeType() != ASTNodeType::IntConst) {
237 throw InvalidProgram("ByValue typed var " + op->name_ +
238 " can only have a constant size");
239 }
240 }
241 if (shape.empty()) {
242 this->os() << gen(tensor->dtype()) << " " << name << " = *(("
243 << gen(tensor->dtype()) << "*)" << rawPtr << ");"
244 << std::endl;
245 } else {
246 for (size_t i = 0, iEnd = shape.size(); i < iEnd; i++) {
247 this->os() << "__ByValArray<";
248 }
249 this->os() << gen(tensor->dtype());
250 for (auto it = shape.rbegin(); it != shape.rend(); it++) {
251 this->os() << ", " << (*it).as<IntConstNode>()->val_ << ">";
252 }
253 this->os() << " " << name << ";" << std::endl;
254 std::vector<int> idx(shape.size(), 0);
255 std::function<void(size_t, int)> f = [&](size_t i, int offset) {
256 if (i == shape.size()) {
257 this->makeIndent();
258 this->os() << name;
259 for (int x : idx) {
260 this->os() << "[" << x << "]";
261 }
262 this->os()
263 << " = ((" << gen(tensor->dtype()) << "*)" << rawPtr
264 << ")[" << offset << "];" << std::endl;
265 return;
266 }
267 for (int j = 0, jEnd = shape[i].as<IntConstNode>()->val_;
268 j < jEnd; j++) {
269 idx[i] = j;
270 f(i + 1, offset * jEnd + j);
271 }
272 };
273 f(0, 0);
274 }
275 break;
276
277 default:
278 // e.g.
279 // auto &&x = mdspan_r<const float, extents<5, 5>>(params[0]);
280 this->os() << "auto &&" << name << " = ";
281 genMdPtrDef(op, rawPtr, !isWritable(op->buffer_->atype()));
282 this->os() << ";" << std::endl;
283 }
284 }
285
286 this->markDef(op);
287 (*this)(op->body_);
288 this->markUndef(op);
289}
290
291template <class Stream> void CodeGenC<Stream>::visit(const Var &op) {
292 this->markUseIter(op->name_);
293 this->os() << mangle(op->name_);
294 BaseClass::visit(op);
295}
296
297template <class Stream> void CodeGenC<Stream>::visit(const Store &op) {
298 this->markUse(op->var_);
299
300 this->makeIndent();
301 this->genScalar(op);
302 this->os() << " = ";
303 (*this)(op->expr_);
304 this->os() << ";" << std::endl;
305}
306
307template <class Stream> void CodeGenC<Stream>::visit(const Alloc &op) {
308 this->markUse(op->var_);
309 this->makeIndent();
310
311 auto &&def = BaseClass::def(op->var_);
312 auto &&tensor = def->buffer_->tensor();
313 auto &&shape = tensor->shape();
314 auto &&dtype = tensor->dtype();
315
316 // e.g.
317 // x_opt = mdspan_r<int, extents<5, 5>>(new int[n*m*l]);
318 this->os() << mangle(op->var_) << "_opt = ";
319 genMdPtrDef(def, [&]() {
320 this->os() << "new " << gen(dtype) << "[";
321 for (auto i = 0lu; i < shape.size(); ++i) {
322 if (i != 0lu)
323 this->os() << "*";
324 this->os() << "(";
325 (*this)(shape[i]);
326 this->os() << ")";
327 }
328 this->os() << "]";
329 });
330 this->os() << ";" << std::endl;
331}
332
333template <class Stream> void CodeGenC<Stream>::visit(const Free &op) {
334
335 // e.g. auto x_ptr = x.data_handle();
336 // x_opt.drop();
337 // x_opt = std::nullopt;
338 // delete[] x_ptr;
339 auto &&name = mangle(op->var_);
340 this->makeIndent();
341 this->os() << "auto " << name << "_ptr = " << name << ".data_handle();"
342 << std::endl;
343 this->makeIndent();
344 this->os() << name << "_opt.drop();" << std::endl;
345 this->makeIndent();
346 this->os() << name << "_opt = std::nullopt;" << std::endl;
347 this->makeIndent();
348 this->os() << "delete[] " << name << "_ptr;" << std::endl;
349}
350
351template <class Stream> void CodeGenC<Stream>::visit(const Load &op) {
352 this->markUse(op->var_);
353 this->genScalar(op);
354}
355
356template <class Stream> void CodeGenC<Stream>::visit(const ReduceTo &op) {
357 this->markUse(op->var_);
358
359 this->makeIndent();
360
361 auto genAddr = [&]() { this->genScalar(op); };
362 auto genExpr = [&]() { (*this)(op->expr_); };
363
364 switch (op->op_) {
365 case ReduceOp::Add:
366 genAddr(), this->os() << " += ", genExpr();
367 break;
368 case ReduceOp::Mul:
369 genAddr(), this->os() << " *= ", genExpr();
370 break;
371 case ReduceOp::Min:
372 genAddr(), this->os()
373 << " = std::min<"
374 << this->gen(this->buffer(op->var_)->tensor()->dtype())
375 << ">(";
376 genAddr(), this->os() << ", ", genExpr(), this->os() << ")";
377 break;
378 case ReduceOp::Max:
379 genAddr(), this->os()
380 << " = std::max<"
381 << this->gen(this->buffer(op->var_)->tensor()->dtype())
382 << ">(";
383 genAddr(), this->os() << ", ", genExpr(), this->os() << ")";
384 break;
385 case ReduceOp::LAnd:
386 genAddr(), this->os() << " &= (bool)(", genExpr(), this->os() << ")";
387 break;
388 case ReduceOp::LOr:
389 genAddr(), this->os() << " |= (bool)(", genExpr(), this->os() << ")";
390 break;
391 default:
392 ASSERT(false);
393 }
394
395 this->os() << ";" << std::endl;
396}
397
398template <class Stream> void CodeGenC<Stream>::visit(const IntConst &op) {
399 this->os() << std::to_string(op->val_);
400}
401
402template <class Stream> void CodeGenC<Stream>::visit(const FloatConst &op) {
403 if (std::isnan(op->val_)) {
404 throw InvalidProgram("NaN literal in the program");
405 } else if (op->val_ == INFINITY) {
406 this->os() << "INFINITY";
407 } else if (op->val_ == -INFINITY) {
408 this->os() << "-INFINITY";
409 } else {
410 this->os() << std::hexfloat << op->val_
411 << "f"; // FIXME: Determine the actual type
412 }
413}
414
415template <class Stream> void CodeGenC<Stream>::visit(const BoolConst &op) {
416 this->os() << std::to_string(op->val_);
417}
418
419template <class Stream> void CodeGenC<Stream>::visit(const Add &op) {
420 this->os() << "(";
421 (*this)(op->lhs_);
422 this->os() << " + ";
423 (*this)(op->rhs_);
424 this->os() << ")";
425}
426
427template <class Stream> void CodeGenC<Stream>::visit(const Sub &op) {
428 this->os() << "(";
429 (*this)(op->lhs_);
430 this->os() << " - ";
431 (*this)(op->rhs_);
432 this->os() << ")";
433}
434
435template <class Stream> void CodeGenC<Stream>::visit(const Mul &op) {
436 this->os() << "(";
437 (*this)(op->lhs_);
438 this->os() << " * ";
439 (*this)(op->rhs_);
440 this->os() << ")";
441}
442
443template <class Stream> void CodeGenC<Stream>::visit(const RealDiv &op) {
444 if (isFloat(op->lhs_->dtype()) || isFloat(op->rhs_->dtype())) {
445 this->os() << "(";
446 (*this)(op->lhs_);
447 this->os() << " / ";
448 (*this)(op->rhs_);
449 this->os() << ")";
450 } else {
451 // TODO: Use double?
452 this->os() << "(float(";
453 (*this)(op->lhs_);
454 this->os() << ") / float(";
455 (*this)(op->rhs_);
456 this->os() << "))";
457 }
458}
459
460template <class Stream>
462 this->os() << "(";
463 (*this)(op->lhs_);
464 this->os() << " / ";
465 (*this)(op->rhs_);
466 this->os() << ")";
467}
468
469template <class Stream> void CodeGenC<Stream>::visit(const FloorDiv &op) {
470 this->os() << "floorDiv<" << this->gen(op->dtype()) << ">(";
471 (*this)(op->lhs_);
472 this->os() << ", ";
473 (*this)(op->rhs_);
474 this->os() << ")";
475}
476
477template <class Stream> void CodeGenC<Stream>::visit(const CeilDiv &op) {
478 this->os() << "ceilDiv<" << this->gen(op->dtype()) << ">(";
479 (*this)(op->lhs_);
480 this->os() << ", ";
481 (*this)(op->rhs_);
482 this->os() << ")";
483}
484
485template <class Stream> void CodeGenC<Stream>::visit(const Mod &op) {
486 this->os() << "runtime_mod(";
487 (*this)(op->lhs_);
488 this->os() << ", ";
489 (*this)(op->rhs_);
490 this->os() << ")";
491}
492
493template <class Stream> void CodeGenC<Stream>::visit(const Remainder &op) {
494 this->os() << "(";
495 (*this)(op->lhs_);
496 this->os() << " % ";
497 (*this)(op->rhs_);
498 this->os() << ")";
499}
500
501template <class Stream> void CodeGenC<Stream>::visit(const Min &op) {
502 this->os() << "std::min<" << this->gen(op->dtype()) << ">(";
503 (*this)(op->lhs_);
504 this->os() << ", ";
505 (*this)(op->rhs_);
506 this->os() << ")";
507}
508
509template <class Stream> void CodeGenC<Stream>::visit(const Max &op) {
510 this->os() << "std::max<" << this->gen(op->dtype()) << ">(";
511 (*this)(op->lhs_);
512 this->os() << ", ";
513 (*this)(op->rhs_);
514 this->os() << ")";
515}
516
517template <class Stream> void CodeGenC<Stream>::visit(const LT &op) {
518 this->os() << "(";
519 (*this)(op->lhs_);
520 this->os() << " < ";
521 (*this)(op->rhs_);
522 this->os() << ")";
523}
524
525template <class Stream> void CodeGenC<Stream>::visit(const LE &op) {
526 this->os() << "(";
527 (*this)(op->lhs_);
528 this->os() << " <= ";
529 (*this)(op->rhs_);
530 this->os() << ")";
531}
532
533template <class Stream> void CodeGenC<Stream>::visit(const GT &op) {
534 this->os() << "(";
535 (*this)(op->lhs_);
536 this->os() << " > ";
537 (*this)(op->rhs_);
538 this->os() << ")";
539}
540
541template <class Stream> void CodeGenC<Stream>::visit(const GE &op) {
542 this->os() << "(";
543 (*this)(op->lhs_);
544 this->os() << " >= ";
545 (*this)(op->rhs_);
546 this->os() << ")";
547}
548
549template <class Stream> void CodeGenC<Stream>::visit(const EQ &op) {
550 this->os() << "(";
551 (*this)(op->lhs_);
552 this->os() << " == ";
553 (*this)(op->rhs_);
554 this->os() << ")";
555}
556
557template <class Stream> void CodeGenC<Stream>::visit(const NE &op) {
558 this->os() << "(";
559 (*this)(op->lhs_);
560 this->os() << " != ";
561 (*this)(op->rhs_);
562 this->os() << ")";
563}
564
565template <class Stream> void CodeGenC<Stream>::visit(const LAnd &op) {
566 this->os() << "(";
567 (*this)(op->lhs_);
568 this->os() << " && ";
569 (*this)(op->rhs_);
570 this->os() << ")";
571}
572
573template <class Stream> void CodeGenC<Stream>::visit(const LOr &op) {
574 this->os() << "(";
575 (*this)(op->lhs_);
576 this->os() << " || ";
577 (*this)(op->rhs_);
578 this->os() << ")";
579}
580
581template <class Stream> void CodeGenC<Stream>::visit(const LNot &op) {
582 this->os() << "!";
583 (*this)(op->expr_);
584}
585
586template <class Stream> void CodeGenC<Stream>::visit(const Sqrt &op) {
587 this->os() << "sqrt(";
588 (*this)(op->expr_);
589 this->os() << ")";
590}
591
592template <class Stream> void CodeGenC<Stream>::visit(const Exp &op) {
593 this->os() << "exp(";
594 (*this)(op->expr_);
595 this->os() << ")";
596}
597
598template <class Stream> void CodeGenC<Stream>::visit(const Ln &op) {
599 this->os() << "log(";
600 (*this)(op->expr_);
601 this->os() << ")";
602}
603
604template <class Stream> void CodeGenC<Stream>::visit(const Square &op) {
605 this->os() << "runtime_square(";
606 (*this)(op->expr_);
607 this->os() << ")";
608}
609
610template <class Stream> void CodeGenC<Stream>::visit(const Sigmoid &op) {
611 this->os() << "runtime_sigmoid(";
612 (*this)(op->expr_);
613 this->os() << ")";
614}
615
616template <class Stream> void CodeGenC<Stream>::visit(const Sin &op) {
617 this->os() << "std::sin(";
618 (*this)(op->expr_);
619 this->os() << ")";
620}
621
622template <class Stream> void CodeGenC<Stream>::visit(const Cos &op) {
623 this->os() << "std::cos(";
624 (*this)(op->expr_);
625 this->os() << ")";
626}
627
628template <class Stream> void CodeGenC<Stream>::visit(const Tan &op) {
629 this->os() << "std::tan(";
630 (*this)(op->expr_);
631 this->os() << ")";
632}
633
634template <class Stream> void CodeGenC<Stream>::visit(const Tanh &op) {
635 this->os() << "std::tanh(";
636 (*this)(op->expr_);
637 this->os() << ")";
638}
639
640template <class Stream> void CodeGenC<Stream>::visit(const Abs &op) {
641 this->os() << "std::abs(";
642 (*this)(op->expr_);
643 this->os() << ")";
644}
645
646template <class Stream> void CodeGenC<Stream>::visit(const Floor &op) {
647 this->os() << "std::floor(";
648 (*this)(op->expr_);
649 this->os() << ")";
650}
651
652template <class Stream> void CodeGenC<Stream>::visit(const Ceil &op) {
653 this->os() << "std::ceil(";
654 (*this)(op->expr_);
655 this->os() << ")";
656}
657
658template <class Stream> void CodeGenC<Stream>::visit(const IfExpr &op) {
659 this->os() << "(";
660 (*this)(op->cond_);
661 this->os() << " ? ";
662 (*this)(op->thenCase_);
663 this->os() << " : ";
664 (*this)(op->elseCase_);
665 this->os() << ")";
666}
667
668template <class Stream> void CodeGenC<Stream>::visit(const Cast &op) {
669 this->os() << gen(op->destType_) << "(";
670 (*this)(op->expr_);
671 this->os() << ")";
672}
673
674template <class Stream> void CodeGenC<Stream>::visit(const For &op) {
675 if (op->step_->nodeType() == ASTNodeType::IntConst &&
676 op->step_.as<IntConstNode>()->val_ == 1) {
677 this->makeIndent();
678 this->os() << "for (int " << mangle(op->iter_) << " = ";
679 (*this)(op->begin_);
680 this->os() << "; " << mangle(op->iter_) << " < ";
681 (*this)(op->end_);
682 this->os() << "; " << mangle(op->iter_) << "++) ";
683 this->beginBlock();
684 this->markDefIter(op);
685 (*this)(op->body_);
686 this->markUndefIter(op);
687 this->endBlock();
688 } else {
689 auto iterCnt = mangle(op->iter_ + ".cnt");
690 this->makeIndent();
691 this->os() << "for (int " << iterCnt << " = 0; " << iterCnt << " < ";
692 (*this)(op->len_);
693 this->os() << "; " << iterCnt << "++) ";
694 this->beginBlock();
695 this->makeIndent();
696 this->os() << "int " << mangle(op->iter_) << " = ";
697 (*this)(op->begin_);
698 this->os() << " + " << iterCnt << " * ";
699 (*this)(op->step_);
700 this->os() << ";" << std::endl;
701 this->markDefIter(op);
702 (*this)(op->body_);
703 this->markUndefIter(op);
704 this->endBlock();
705 }
706}
707
708template <class Stream> void CodeGenC<Stream>::visit(const If &op) {
709 this->makeIndent();
710 this->os() << "if (";
711 (*this)(op->cond_);
712 this->os() << ") ";
713 this->beginBlock();
714 (*this)(op->thenCase_);
715 this->endBlock();
716 if (op->elseCase_.isValid()) {
717 this->makeIndent();
718 this->os() << "else ";
719 this->beginBlock();
720 (*this)(op->elseCase_);
721 this->endBlock();
722 }
723}
724
725template <class Stream> void CodeGenC<Stream>::visit(const Assert &op) {
726 this->makeIndent();
727 this->os() << "assert(";
728 (*this)(op->cond_);
729 this->os() << ");" << std::endl;
730 (*this)(op->body_);
731}
732
733template <class Stream> void CodeGenC<Stream>::visit(const Intrinsic &op) {
734 bool parentIsEval =
735 op->parent().as<ASTNode>()->nodeType() != ASTNodeType::Eval;
736 if (parentIsEval)
737 this->os() << "(";
738 size_t i = 0, j = 0, n = op->format_.length();
739 while (j < n) {
740 if (op->format_[j] == '%') {
741 if (j + 1 < n && op->format_[j + 1] == '%') {
742 this->os() << '%';
743 j += 2;
744 } else {
745 (*this)(op->params_.at(i++));
746 j++;
747 }
748 } else {
749 this->os() << op->format_[j];
750 j++;
751 }
752 }
753 if (parentIsEval)
754 this->os() << ")";
755}
756
757template <class Stream> void CodeGenC<Stream>::visit(const Eval &op) {
758 this->makeIndent();
759 (*this)(op->expr_);
760 this->os() << ";" << std::endl;
761}
762
763template <class Stream>
764std::string CodeGenC<Stream>::gen(const DataType &dtype) {
765 switch (dtype.base()) {
767 return "double";
769 return "float";
771 WARNING(
772 "float16 arithmetics on CPU is supported via emulation and comes "
773 "with a performance cost, which is only for compatibility purpose. "
774 "If you intend to do float32 computation on float16 variables, "
775 "please convert them explicitly. Please ignore this warning if you "
776 "are only allocating buffers and not performing arithmetics.");
777 return "half_float::half"; // From 3rd-party/half
778 case DataType::Int64:
779 return "int64_t";
780 case DataType::Int32:
781 return "int32_t";
782 case DataType::Bool:
783 return "bool";
784 default:
785 throw InvalidProgram(
786 FT_MSG << dtype << " is not supported by this codegen backend");
787 }
788}
789
790} // namespace freetensor
791
792#endif // DETAIL_CODE_GEN_C_H
Definition: ast.h:118
Definition: code_gen_c.h:14
virtual void visit(const StmtSeq &op) override
Definition: code_gen_c.h:131
virtual void genMdPtrDef(const VarDef &def, const std::function< void()> &genRawPtr, bool isConst=false)
Definition: code_gen_c.h:71
virtual void genScalar(const VarDef &def, const std::vector< Expr > &indices)
Definition: code_gen_c.h:107
virtual std::string gen(const DataType &dtype)
Definition: code_gen_c.h:764
static bool debugRuntimeCheck()
Definition: config.h:96
Definition: data_type.h:106
static constexpr auto Float32
Definition: data_type.h:126
const auto & base() const
Definition: data_type.h:132
static constexpr auto Int32
Definition: data_type.h:128
static constexpr auto Int64
Definition: data_type.h:129
static constexpr auto Float16
Definition: data_type.h:125
static constexpr auto Float64
Definition: data_type.h:127
static constexpr auto Bool
Definition: data_type.h:123
SubTree< ExprNode > begin_
Definition: stmt.h:294
SubTree< ExprNode > step_
Definition: stmt.h:296
SubTree< ExprNode > len_
Definition: stmt.h:297
SubTree< ExprNode > end_
Definition: stmt.h:295
std::string iter_
Definition: stmt.h:289
SubTree< StmtNode > body_
Definition: stmt.h:299
Definition: expr.h:93
int64_t val_
Definition: expr.h:95
Definition: except.h:83
std::string var_
Definition: stmt.h:231
ReduceOp op_
Definition: stmt.h:233
SubTree< ExprNode > expr_
Definition: stmt.h:234
bool isValid() const
Definition: ref.h:89
Ref< U > as() const
Definition: ref.h:83
SubTreeList< StmtNode > stmts_
Definition: stmt.h:44
std::string var_
Definition: stmt.h:134
SubTree< ExprNode > expr_
Definition: stmt.h:136
Definition: stmt.h:83
SubTree< StmtNode > body_
Definition: stmt.h:101
SubTree< Buffer > buffer_
Definition: stmt.h:86
std::optional< std::string > viewOf_
Definition: stmt.h:99
std::string name_
Definition: stmt.h:85
#define ASSERT(expr)
Definition: except.h:152
#define WARNING(msg)
Definition: except.h:146
#define FT_MSG
Definition: except.h:23
int n
Definition: metadata.cc:15
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152
std::string mangle(const std::string &name)
Definition: mangle.cc:8
bool isFloat(BaseDataType dtype)
Definition: data_type.cc:38
bool isInputting(AccessType atype)
Definition: access_type.h:99
bool isWritable(AccessType atype)
Definition: access_type.h:87
std::vector< Stmt > findAllStmt(const Stmt &ast, const ID &id)
Definition: find_stmt.cc:32
bool isOutputting(AccessType atype)
Definition: access_type.h:111