1#ifndef DETAIL_CODE_GEN_C_H
2#define DETAIL_CODE_GEN_C_H
19template <
class Stream>
20std::function<std::ostream &(std::ostream &)>
26#if defined(__clang__) && __clang_major__ < 17
27 return [=](std::ostream &os) -> std::ostream & {
29 return [=,
this](std::ostream &os) -> std::ostream & {
33 if (buf->tensor()->shape().empty()) {
38 os << gen(buf->tensor()->dtype()) <<
" &";
42 bool isRestricted =
true;
52 : isRestricted ?
"mdspan_r<"
57 os << gen(buf->tensor()->dtype()) <<
", extents<";
58 for (
auto &&[i, dim] : views::enumerate(buf->tensor()->shape())) {
59 os << (i > 0 ?
", " :
"");
61 os << dim.template as<IntConstNode>()->val_;
63 os <<
"dynamic_extent";
70template <
class Stream>
72 const std::function<
void()> &genRawPtr,
76 if (buf->tensor()->shape().empty()) {
80 this->os() <<
"((" << genMdPtrType(def, isConst) <<
")*((";
82 this->os() <<
"const ";
84 this->os() << gen(buf->tensor()->dtype()) <<
" *)(";
90 this->os() << genMdPtrType(def, isConst) <<
"((";
92 this->os() <<
"const ";
94 this->os() << gen(buf->tensor()->dtype()) <<
"*)(";
97 for (
auto &&dim : buf->tensor()->shape()) {
106template <
class Stream>
108 const std::vector<Expr> &indices) {
112 for (
auto &&index : indices) {
119 if (!def->
buffer_->tensor()->shape().empty()) {
122 for (
auto &&[i, index] : views::enumerate(indices)) {
123 this->os() << (i > 0 ?
", " :
"");
132 for (
auto &&stmt : op->
stmts_) {
146 auto &&tensor = op->
buffer_->tensor();
147 auto &&shape = tensor->shape();
154 while (source->viewOf_.has_value()) {
155 source = this->def(*source->viewOf_);
157 this->os() <<
"auto &&" << name <<
" = ";
158 genMdPtrDef(op,
mangle(source->name_) +
".data_handle()",
160 this->os() <<
";" << std::endl;
166 this->os() << gen(tensor->dtype()) <<
" " << name;
167 for (
auto &&dim : shape) {
172 this->os() <<
";" << std::endl;
174 auto paramPositions = ranges::to<std::vector>(
175 params_ | views::enumerate | views::filter([&](
auto &&pair) {
176 return pair.second.name_ == op->
name_;
179 auto returnPositions = ranges::to<std::vector>(
180 returns_ | views::enumerate | views::filter([&](
auto &&pair) {
181 return pair.second.name_ == op->
name_;
184 bool isParam = !paramPositions.empty();
185 bool isReturn = !returnPositions.empty();
186 if (!isParam && !isReturn) {
188 " used but not defined as a function's "
189 "parameters or return values");
193 if (paramPositions.size() > 1) {
197 int nthParam = paramPositions.front();
198 rawPtr =
"params[" + std::to_string(nthParam) +
"]";
202 "Only outputting variable can be as a return value");
207 int nthReturn = returnPositions.front();
208 rawPtr =
"returns[" + std::to_string(nthReturn) +
"]";
209 std::string shapePtr =
210 "retShapes[" + std::to_string(nthReturn) +
"]";
211 std::string dimPtr =
"retDims[" + std::to_string(nthReturn) +
"]";
212 this->os() <<
"if (" + rawPtr +
" == NULL) ";
214 this->genAlloc(op->
buffer_->tensor(), rawPtr, shapePtr, dimPtr);
219 switch (op->
buffer_->mtype()) {
233 " can only be Input");
235 for (
auto &&dim : shape) {
238 " can only have a constant size");
242 this->os() << gen(tensor->dtype()) <<
" " << name <<
" = *(("
243 << gen(tensor->dtype()) <<
"*)" << rawPtr <<
");"
246 for (
size_t i = 0, iEnd = shape.size(); i < iEnd; i++) {
247 this->os() <<
"__ByValArray<";
249 this->os() << gen(tensor->dtype());
250 for (
auto it = shape.rbegin(); it != shape.rend(); it++) {
251 this->os() <<
", " << (*it).as<
IntConstNode>()->val_ <<
">";
253 this->os() <<
" " << name <<
";" << std::endl;
254 std::vector<int> idx(shape.size(), 0);
255 std::function<void(
size_t,
int)> f = [&](
size_t i,
int offset) {
256 if (i == shape.size()) {
260 this->os() <<
"[" << x <<
"]";
263 <<
" = ((" << gen(tensor->dtype()) <<
"*)" << rawPtr
264 <<
")[" << offset <<
"];" << std::endl;
267 for (
int j = 0, jEnd = shape[i].as<IntConstNode>()->val_;
270 f(i + 1, offset * jEnd + j);
280 this->os() <<
"auto &&" << name <<
" = ";
282 this->os() <<
";" << std::endl;
292 this->markUseIter(op->name_);
293 this->os() <<
mangle(op->name_);
294 BaseClass::visit(op);
298 this->markUse(op->
var_);
304 this->os() <<
";" << std::endl;
308 this->markUse(op->var_);
311 auto &&def = BaseClass::def(op->var_);
312 auto &&tensor = def->
buffer_->tensor();
313 auto &&shape = tensor->shape();
314 auto &&dtype = tensor->dtype();
318 this->os() <<
mangle(op->var_) <<
"_opt = ";
319 genMdPtrDef(def, [&]() {
320 this->os() <<
"new " << gen(dtype) <<
"[";
321 for (
auto i = 0lu; i < shape.size(); ++i) {
330 this->os() <<
";" << std::endl;
339 auto &&name =
mangle(op->var_);
341 this->os() <<
"auto " << name <<
"_ptr = " << name <<
".data_handle();"
344 this->os() << name <<
"_opt.drop();" << std::endl;
346 this->os() << name <<
"_opt = std::nullopt;" << std::endl;
348 this->os() <<
"delete[] " << name <<
"_ptr;" << std::endl;
352 this->markUse(op->var_);
357 this->markUse(op->
var_);
361 auto genAddr = [&]() { this->genScalar(op); };
362 auto genExpr = [&]() { (*this)(op->
expr_); };
366 genAddr(), this->os() <<
" += ", genExpr();
369 genAddr(), this->os() <<
" *= ", genExpr();
372 genAddr(), this->os()
374 << this->gen(this->buffer(op->
var_)->tensor()->dtype())
376 genAddr(), this->os() <<
", ", genExpr(), this->os() <<
")";
379 genAddr(), this->os()
381 << this->gen(this->buffer(op->
var_)->tensor()->dtype())
383 genAddr(), this->os() <<
", ", genExpr(), this->os() <<
")";
386 genAddr(), this->os() <<
" &= (bool)(", genExpr(), this->os() <<
")";
389 genAddr(), this->os() <<
" |= (bool)(", genExpr(), this->os() <<
")";
395 this->os() <<
";" << std::endl;
399 this->os() << std::to_string(op->val_);
403 if (std::isnan(op->val_)) {
405 }
else if (op->val_ == INFINITY) {
406 this->os() <<
"INFINITY";
407 }
else if (op->val_ == -INFINITY) {
408 this->os() <<
"-INFINITY";
410 this->os() << std::hexfloat << op->val_
416 this->os() << std::to_string(op->val_);
452 this->os() <<
"(float(";
454 this->os() <<
") / float(";
460template <
class Stream>
470 this->os() <<
"floorDiv<" << this->gen(op->dtype()) <<
">(";
478 this->os() <<
"ceilDiv<" << this->gen(op->dtype()) <<
">(";
486 this->os() <<
"runtime_mod(";
502 this->os() <<
"std::min<" << this->gen(op->dtype()) <<
">(";
510 this->os() <<
"std::max<" << this->gen(op->dtype()) <<
">(";
528 this->os() <<
" <= ";
544 this->os() <<
" >= ";
552 this->os() <<
" == ";
560 this->os() <<
" != ";
568 this->os() <<
" && ";
576 this->os() <<
" || ";
587 this->os() <<
"sqrt(";
593 this->os() <<
"exp(";
599 this->os() <<
"log(";
605 this->os() <<
"runtime_square(";
611 this->os() <<
"runtime_sigmoid(";
617 this->os() <<
"std::sin(";
623 this->os() <<
"std::cos(";
629 this->os() <<
"std::tan(";
635 this->os() <<
"std::tanh(";
641 this->os() <<
"std::abs(";
647 this->os() <<
"std::floor(";
653 this->os() <<
"std::ceil(";
662 (*this)(op->thenCase_);
664 (*this)(op->elseCase_);
669 this->os() << gen(op->destType_) <<
"(";
678 this->os() <<
"for (int " <<
mangle(op->
iter_) <<
" = ";
684 this->markDefIter(op);
686 this->markUndefIter(op);
691 this->os() <<
"for (int " << iterCnt <<
" = 0; " << iterCnt <<
" < ";
693 this->os() <<
"; " << iterCnt <<
"++) ";
696 this->os() <<
"int " <<
mangle(op->
iter_) <<
" = ";
698 this->os() <<
" + " << iterCnt <<
" * ";
700 this->os() <<
";" << std::endl;
701 this->markDefIter(op);
703 this->markUndefIter(op);
710 this->os() <<
"if (";
714 (*this)(op->thenCase_);
718 this->os() <<
"else ";
720 (*this)(op->elseCase_);
727 this->os() <<
"assert(";
729 this->os() <<
");" << std::endl;
738 size_t i = 0, j = 0,
n = op->format_.length();
740 if (op->format_[j] ==
'%') {
741 if (j + 1 <
n && op->format_[j + 1] ==
'%') {
745 (*this)(op->params_.at(i++));
749 this->os() << op->format_[j];
760 this->os() <<
";" << std::endl;
763template <
class Stream>
765 switch (dtype.
base()) {
772 "float16 arithmetics on CPU is supported via emulation and comes "
773 "with a performance cost, which is only for compatibility purpose. "
774 "If you intend to do float32 computation on float16 variables, "
775 "please convert them explicitly. Please ignore this warning if you "
776 "are only allocating buffers and not performing arithmetics.");
777 return "half_float::half";
786 FT_MSG << dtype <<
" is not supported by this codegen backend");
Definition: code_gen_c.h:14
virtual void visit(const StmtSeq &op) override
Definition: code_gen_c.h:131
virtual void genMdPtrDef(const VarDef &def, const std::function< void()> &genRawPtr, bool isConst=false)
Definition: code_gen_c.h:71
virtual void genScalar(const VarDef &def, const std::vector< Expr > &indices)
Definition: code_gen_c.h:107
virtual std::string gen(const DataType &dtype)
Definition: code_gen_c.h:764
static bool debugRuntimeCheck()
Definition: config.h:96
Definition: data_type.h:106
static constexpr auto Float32
Definition: data_type.h:126
const auto & base() const
Definition: data_type.h:132
static constexpr auto Int32
Definition: data_type.h:128
static constexpr auto Int64
Definition: data_type.h:129
static constexpr auto Float16
Definition: data_type.h:125
static constexpr auto Float64
Definition: data_type.h:127
static constexpr auto Bool
Definition: data_type.h:123
SubTree< ExprNode > begin_
Definition: stmt.h:294
SubTree< ExprNode > step_
Definition: stmt.h:296
SubTree< ExprNode > len_
Definition: stmt.h:297
SubTree< ExprNode > end_
Definition: stmt.h:295
std::string iter_
Definition: stmt.h:289
SubTree< StmtNode > body_
Definition: stmt.h:299
int64_t val_
Definition: expr.h:95
std::string var_
Definition: stmt.h:231
ReduceOp op_
Definition: stmt.h:233
SubTree< ExprNode > expr_
Definition: stmt.h:234
bool isValid() const
Definition: ref.h:89
Ref< U > as() const
Definition: ref.h:83
SubTreeList< StmtNode > stmts_
Definition: stmt.h:44
std::string var_
Definition: stmt.h:134
SubTree< ExprNode > expr_
Definition: stmt.h:136
SubTree< StmtNode > body_
Definition: stmt.h:101
SubTree< Buffer > buffer_
Definition: stmt.h:86
std::optional< std::string > viewOf_
Definition: stmt.h:99
std::string name_
Definition: stmt.h:85
#define ASSERT(expr)
Definition: except.h:152
#define WARNING(msg)
Definition: except.h:146
#define FT_MSG
Definition: except.h:23
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152
std::string mangle(const std::string &name)
Definition: mangle.cc:8
bool isFloat(BaseDataType dtype)
Definition: data_type.cc:38
bool isInputting(AccessType atype)
Definition: access_type.h:99
bool isWritable(AccessType atype)
Definition: access_type.h:87
std::vector< Stmt > findAllStmt(const Stmt &ast, const ID &id)
Definition: find_stmt.cc:32
bool isOutputting(AccessType atype)
Definition: access_type.h:111