1#ifndef FREE_TENSOR_STMT_H
2#define FREE_TENSOR_STMT_H
29makeAny(std::source_location loc = std::source_location::current()) {
31 a->setDebugBlame(loc);
50template <
class Tstmts>
53 std::source_location loc = std::source_location::current()) {
55 s->metadata() = metadata;
57 s->stmts_ = std::forward<Tstmts>(stmts);
58 s->setDebugBlame(loc);
63 const Metadata &metadata =
nullptr,
const ID &
id = {},
64 std::source_location loc = std::source_location::current()) {
66 s->metadata() = metadata;
69 s->setDebugBlame(loc);
108template <
class Tbuffer,
class Tbody>
110 const std::optional<std::string> &viewOf, Tbody &&body,
111 bool pinned,
const Metadata &metadata =
nullptr,
113 std::source_location loc = std::source_location::current()) {
116 d->metadata() = metadata;
119 d->buffer_ = std::forward<Tbuffer>(buffer);
121 d->body_ = std::forward<Tbody>(body);
123 d->setDebugBlame(loc);
141template <
class Tindices,
class Texpr>
143 const Metadata &metadata =
nullptr,
const ID &
id = {},
144 std::source_location loc = std::source_location::current()) {
147 s->metadata() = metadata;
150 s->indices_ = std::forward<Tindices>(indices);
151 s->expr_ = std::forward<Texpr>(expr);
152 s->setDebugBlame(loc);
155template <
class Texpr>
157 Texpr &&expr,
const Metadata &metadata =
nullptr,
159 std::source_location loc = std::source_location::current()) {
162 s->metadata() = metadata;
165 s->indices_ = indices;
166 s->expr_ = std::forward<Texpr>(expr);
167 s->setDebugBlame(loc);
186 std::source_location loc = std::source_location::current()) {
189 a->metadata() = metadata;
192 a->setDebugBlame(loc);
211 std::source_location loc = std::source_location::current()) {
214 f->metadata() = metadata;
217 f->setDebugBlame(loc);
249template <
class Tindices,
class Texpr>
251 Texpr &&expr,
bool sync,
const Metadata &metadata =
nullptr,
253 std::source_location loc = std::source_location::current()) {
256 a->metadata() = metadata;
259 a->indices_ = std::forward<Tindices>(indices);
261 a->expr_ = std::forward<Texpr>(expr);
263 a->setDebugBlame(loc);
266template <
class Texpr>
268 ReduceOp op, Texpr &&expr,
bool sync,
269 const Metadata &metadata =
nullptr,
const ID &
id = {},
270 std::source_location loc = std::source_location::current()) {
273 a->metadata() = metadata;
276 a->indices_ = indices;
278 a->expr_ = std::forward<Texpr>(expr);
280 a->setDebugBlame(loc);
309template <
class Tbegin,
class Tend,
class Tstep,
class Tlen,
class Tbody,
311Stmt makeFor(
const std::string &iter, Tbegin &&begin, Tend &&end, Tstep &&step,
312 Tlen &&len, Tproperty &&property, Tbody &&body,
313 const Metadata &metadata =
nullptr,
const ID &
id = {},
314 std::source_location loc = std::source_location::current()) {
317 f->metadata() = metadata;
320 f->begin_ = std::forward<Tbegin>(begin);
321 f->end_ = std::forward<Tend>(end);
322 f->step_ = std::forward<Tstep>(step);
323 f->len_ = std::forward<Tlen>(len);
324 f->property_ = std::forward<Tproperty>(property);
325 f->body_ = std::forward<Tbody>(body);
326 f->setDebugBlame(loc);
353template <
class Tcond,
class Tthen,
class Telse = std::
nullptr_t>
355 const Metadata &metadata =
nullptr,
const ID &
id = {},
356 std::source_location loc = std::source_location::current()) {
358 i->metadata() = metadata;
360 i->cond_ = std::forward<Tcond>(cond);
361 i->thenCase_ = std::forward<Tthen>(thenCase);
362 i->elseCase_ = std::forward<Telse>(elseCase);
363 i->setDebugBlame(loc);
366template <
class Tcond,
class Tthen,
class Telse = std::
nullptr_t>
369 std::source_location loc = std::source_location::current()) {
370 return makeIf(cond, thenCase,
nullptr, metadata,
id, loc);
393template <
class Tcond,
class Tbody>
396 std::source_location loc = std::source_location::current()) {
398 a->metadata() = metadata;
400 a->cond_ = std::forward<Tcond>(cond);
401 a->body_ = std::forward<Tbody>(body);
402 a->setDebugBlame(loc);
426template <
class Tcond,
class Tbody>
429 std::source_location loc = std::source_location::current()) {
431 a->metadata() = metadata;
433 a->cond_ = std::forward<Tcond>(cond);
434 a->body_ = std::forward<Tbody>(body);
435 a->setDebugBlame(loc);
453 std::source_location loc = std::source_location::current()) {
455 e->metadata() = metadata;
457 e->expr_ = std::forward<T>(expr);
458 e->setDebugBlame(loc);
478 "mkl",
"cublas",
"cutlass",
"cutlass-micro-block",
"cutlass-micro-thread",
493 ERROR(
FT_MSG <<
"Unrecognized MatMul backend \"" << _str
494 <<
"\". Candidates are (case-insensitive): "
540 const Expr &stridea,
const Expr &strideb,
const Expr &stridec,
541 const Expr &batchSize,
bool aIsRowMajor,
bool bIsRowMajor,
542 bool cIsRowMajor,
const Stmt &equivalent,
543 const Metadata &metadata =
nullptr,
const ID &
id = {},
544 std::source_location loc = std::source_location::current()) {
546 s->metadata() = metadata;
548 s->backend_ = backend;
549 s->cutlassMicroKernelProperty_ = cutlassMicroKernelProperty;
561 s->stridea_ = stridea;
562 s->strideb_ = strideb;
563 s->stridec_ = stridec;
564 s->batchSize_ = batchSize;
565 s->aIsRowMajor_ = aIsRowMajor;
566 s->bIsRowMajor_ = bIsRowMajor;
567 s->cIsRowMajor_ = cIsRowMajor;
568 s->equivalent_ = equivalent;
582 const Metadata &metadata =
nullptr,
const ID &
id = {},
583 std::source_location loc = std::source_location::current()) {
585 s->metadata() = metadata;
587 s->tapeName_ = tapeName;
589 s->setDebugBlame(loc);
#define DEFINE_NODE_TRAIT(name)
Definition: ast.h:106
void setDebugBlame(std::source_location loc)
Definition: ast.h:141
std::string var_
Definition: stmt.h:178
void compHash() override
Definition: stmt.cc:10
void compHash() override
Definition: stmt.cc:6
void compHash() override
Definition: stmt.cc:15
SubTree< StmtNode > body_
Definition: stmt.h:386
SubTree< ExprNode > cond_
Definition: stmt.h:385
DEFINE_NODE_TRAIT(Assert)
std::vector< Stmt > children() const override
Definition: stmt.h:388
bool isCtrlFlow() const override
Definition: stmt.h:387
void compHash() override
Definition: stmt.cc:16
std::vector< Stmt > children() const override
Definition: stmt.h:421
DEFINE_NODE_TRAIT(Assume)
SubTree< StmtNode > body_
Definition: stmt.h:420
SubTree< ExprNode > cond_
Definition: stmt.h:419
SubTree< ExprNode > expr_
Definition: stmt.h:446
void compHash() override
Definition: stmt.cc:17
void compHash() override
Definition: stmt.cc:13
SubTree< ForProperty > property_
Definition: stmt.h:298
SubTree< ExprNode > begin_
Definition: stmt.h:294
SubTree< ExprNode > step_
Definition: stmt.h:296
std::vector< Stmt > children() const override
Definition: stmt.h:302
SubTree< ExprNode > len_
Definition: stmt.h:297
bool isCtrlFlow() const override
Definition: stmt.h:301
SubTree< ExprNode > end_
Definition: stmt.h:295
std::string iter_
Definition: stmt.h:289
SubTree< StmtNode > body_
Definition: stmt.h:299
void compHash() override
Definition: stmt.cc:11
std::string var_
Definition: stmt.h:203
SubTree< ExprNode > cond_
Definition: stmt.h:335
std::vector< Stmt > children() const override
Definition: stmt.h:340
SubTree< StmtNode > thenCase_
Definition: stmt.h:336
bool isCtrlFlow() const override
Definition: stmt.h:339
void compHash() override
Definition: stmt.cc:14
SubTree< StmtNode, NullPolicy::Nullable > elseCase_
Definition: stmt.h:337
std::string tapeName_
Definition: stmt.h:575
DEFINE_NODE_TRAIT(MarkVersion)
std::string var_
Definition: stmt.h:575
void compHash() override
Definition: stmt.cc:19
bool aIsRowMajor_
Definition: stmt.h:526
MatMulBackend backend_
Definition: stmt.h:503
std::vector< Stmt > children() const override
Definition: stmt.h:529
SubTree< ExprNode > ldb_
Definition: stmt.h:520
SubTree< ExprNode > beta_
Definition: stmt.h:515
SubTree< ExprNode > alpha_
Definition: stmt.h:514
SubTree< ExprNode > m_
Definition: stmt.h:516
SubTree< ExprNode > stridec_
Definition: stmt.h:524
SubTree< ExprNode > batchSize_
Definition: stmt.h:525
SubTree< ExprNode > k_
Definition: stmt.h:517
SubTree< ExprNode > lda_
Definition: stmt.h:519
DEFINE_NODE_TRAIT(MatMul)
SubTree< ExprNode > n_
Definition: stmt.h:518
SubTree< ExprNode > ldc_
Definition: stmt.h:521
SubTree< StmtNode > equivalent_
Definition: stmt.h:527
bool bIsRowMajor_
Definition: stmt.h:526
SubTree< CutlassMicroKernelProperty, NullPolicy::Nullable > cutlassMicroKernelProperty_
Definition: stmt.h:505
bool cIsRowMajor_
Definition: stmt.h:526
void compHash() override
Definition: stmt.cc:18
SubTree< ExprNode > c_
Definition: stmt.h:513
SubTree< ExprNode > a_
Definition: stmt.h:511
SubTree< ExprNode > b_
Definition: stmt.h:512
SubTree< ExprNode > strideb_
Definition: stmt.h:523
SubTree< ExprNode > stridea_
Definition: stmt.h:522
void compHash() override
Definition: stmt.cc:12
std::string var_
Definition: stmt.h:231
SubTreeList< ExprNode > indices_
Definition: stmt.h:232
ReduceOp op_
Definition: stmt.h:233
SubTree< ExprNode > expr_
Definition: stmt.h:234
bool sync_
Definition: stmt.h:243
static Ref make()
Definition: ref.h:105
SubTreeList< StmtNode > stmts_
Definition: stmt.h:44
std::vector< Stmt > children() const override
Definition: stmt.h:45
DEFINE_NODE_TRAIT(StmtSeq)
void compHash() override
Definition: stmt.cc:7
std::string var_
Definition: stmt.h:134
void compHash() override
Definition: stmt.cc:9
SubTree< ExprNode > expr_
Definition: stmt.h:136
SubTreeList< ExprNode > indices_
Definition: stmt.h:135
Definition: sub_tree.h:305
Definition: sub_tree.h:134
bool pinned_
Definition: stmt.h:102
SubTree< StmtNode > body_
Definition: stmt.h:101
void compHash() override
Definition: stmt.cc:8
SubTree< Buffer > buffer_
Definition: stmt.h:86
std::vector< Stmt > children() const override
If pinned, SinkVar and ShrinkVar will not alter this node.
Definition: stmt.h:103
std::optional< std::string > viewOf_
Definition: stmt.h:99
std::string name_
Definition: stmt.h:85
DEFINE_NODE_TRAIT(VarDef)
#define ASSERT(expr)
Definition: except.h:152
#define ERROR(msg)
Definition: except.h:141
#define FT_MSG
Definition: except.h:23
Definition: allocator.h:9
Ref< AssumeNode > Assume
Definition: stmt.h:425
Ref< FreeNode > Free
Definition: stmt.h:207
Ref< VarDefNode > VarDef
Definition: stmt.h:107
Stmt makeEval(T &&expr, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:452
Stmt makeAssume(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:427
Ref< MarkVersionNode > MarkVersion
Definition: stmt.h:579
std::string tolower(const std::string &s)
Definition: container_utils.h:142
Ref< StoreNode > Store
Definition: stmt.h:140
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< IfNode > If
Definition: stmt.h:352
Stmt makeAssert(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:394
Ref< ForNode > For
Definition: stmt.h:308
MatMulBackend
Definition: stmt.h:465
Ref< EvalNode > Eval
Definition: stmt.h:450
constexpr std::array matMulBackendNames
Definition: stmt.h:477
Stmt makeStore(const std::string &var, Tindices &&indices, Texpr &&expr, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:142
Stmt makeAny(std::source_location loc=std::source_location::current())
Definition: stmt.h:29
Stmt makeReduceTo(const std::string &var, Tindices &&indices, ReduceOp op, Texpr &&expr, bool sync, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:250
Ref< StmtNode > Stmt
Definition: ast.h:152
Stmt makeStmtSeq(Tstmts &&stmts, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:51
Stmt makeMatMul(MatMulBackend backend, const Ref< CutlassMicroKernelProperty > &cutlassMicroKernelProperty, const Expr &a, const Expr &b, const Expr &c, const Expr &alpha, const Expr &beta, const Expr &m, const Expr &k, const Expr &n, const Expr &lda, const Expr &ldb, const Expr &ldc, const Expr &stridea, const Expr &strideb, const Expr &stridec, const Expr &batchSize, bool aIsRowMajor, bool bIsRowMajor, bool cIsRowMajor, const Stmt &equivalent, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:535
Stmt makeMarkVersion(const std::string &tapeName, const std::string &var, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:581
Stmt makeIf(Tcond &&cond, Tthen &&thenCase, Telse &&elseCase, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:354
ReduceOp
Definition: reduce_op.h:30
std::string join(Container &&c, const std::string &splitter)
Definition: container_utils.h:194
Stmt makeVarDef(const std::string &name, Tbuffer &&buffer, const std::optional< std::string > &viewOf, Tbody &&body, bool pinned, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:109
constexpr std::array baseDataTypeNames
Definition: data_type.h:27
Stmt makeFor(const std::string &iter, Tbegin &&begin, Tend &&end, Tstep &&step, Tlen &&len, Tproperty &&property, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:311
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
MatMulBackend parseMatMulBackend(const std::string &_str)
Definition: stmt.h:486
Ref< AssertNode > Assert
Definition: stmt.h:392
Ref< AllocNode > Alloc
Definition: stmt.h:182
Ref< MatMulNode > MatMul
Definition: stmt.h:533
Stmt makeFree(const std::string &var, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:209
Stmt makeAlloc(const std::string &var, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:184
Ref< AnyNode > Any
Definition: stmt.h:27
Ref< StmtSeqNode > StmtSeq
Definition: stmt.h:49
Definition: sub_tree.h:20