1#ifndef FREE_TENSOR_MUTATOR_H
2#define FREE_TENSOR_MUTATOR_H
42 std::vector<Stmt> stmts;
43 stmts.reserve(op->
stmts_.size());
44 for (
auto &&stmt : op->
stmts_) {
45 stmts.emplace_back((*
this)(stmt));
52 std::vector<Expr> shape;
53 shape.reserve(op->
buffer_->tensor()->shape().size());
54 for (
auto &&dim : op->
buffer_->tensor()->shape()) {
55 shape.emplace_back((*
this)(dim));
67 return makeVar(op->name_, op->debugBlame());
71 std::vector<Expr> indices;
72 indices.reserve(op->
indices_.size());
74 indices.emplace_back((*
this)(index));
76 auto &&expr = (*this)(op->
expr_);
77 return makeStore(op->
var_, std::move(indices), std::move(expr),
82 return makeAlloc(op->var_, op->metadata(), op->id(), op->debugBlame());
86 return makeFree(op->var_, op->metadata(), op->id(), op->debugBlame());
90 std::vector<Expr> indices;
91 indices.reserve(op->indices_.size());
92 for (
auto &&index : op->indices_) {
93 indices.emplace_back((*
this)(index));
95 return makeLoad(op->var_, std::move(indices), op->loadType_,
100 std::vector<Expr> indices;
101 indices.reserve(op->
indices_.size());
103 indices.emplace_back((*
this)(index));
105 auto &&expr = (*this)(op->
expr_);
128 return makeAdd((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
132 return makeSub((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
136 return makeMul((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
140 return makeRealDiv((*
this)(op->lhs_), (*
this)(op->rhs_),
145 return makeFloorDiv((*
this)(op->lhs_), (*
this)(op->rhs_),
150 return makeCeilDiv((*
this)(op->lhs_), (*
this)(op->rhs_),
160 return makeMod((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
169 return makeMin((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
173 return makeMax((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
177 return makeLT((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
181 return makeLE((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
185 return makeGT((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
189 return makeGE((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
193 return makeEQ((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
197 return makeNE((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
201 return makeLAnd((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
205 return makeLOr((*
this)(op->lhs_), (*
this)(op->rhs_), op->debugBlame());
209 return makeLNot((*
this)(op->expr_), op->debugBlame());
213 return makeSqrt((*
this)(op->expr_), op->debugBlame());
217 return makeExp((*
this)(op->expr_), op->debugBlame());
221 return makeLn((*
this)(op->expr_), op->debugBlame());
225 return makeSquare((*
this)(op->expr_), op->debugBlame());
229 return makeSigmoid((*
this)(op->expr_), op->debugBlame());
233 return makeSin((*
this)(op->expr_), op->debugBlame());
237 return makeCos((*
this)(op->expr_), op->debugBlame());
241 return makeTan((*
this)(op->expr_), op->debugBlame());
245 return makeTanh((*
this)(op->expr_), op->debugBlame());
249 return makeAbs((*
this)(op->expr_), op->debugBlame());
253 return makeFloor((*
this)(op->expr_), op->debugBlame());
257 return makeCeil((*
this)(op->expr_), op->debugBlame());
261 return makeUnbound((*
this)(op->expr_), op->debugBlame());
265 auto begin = (*this)(op->
begin_);
266 auto end = (*this)(op->
end_);
267 auto step = (*this)(op->
step_);
268 auto len = (*this)(op->
len_);
272 ->withVectorize(op->
property_->vectorize_)
274 ->withPreferLibs(op->
property_->preferLibs_);
275 property->reductions_.reserve(op->
property_->reductions_.size());
276 for (
auto &&r : op->
property_->reductions_) {
277 std::vector<Expr> begins, ends;
278 begins.reserve(r->begins_.size());
279 ends.reserve(r->ends_.size());
280 for (
auto &&item : r->begins_) {
281 begins.emplace_back((*
this)(item));
283 for (
auto &&item : r->ends_) {
284 ends.emplace_back((*
this)(item));
286 property->reductions_.emplace_back(
288 std::move(ends), r->syncFlush_));
290 auto body = (*this)(op->
body_);
291 return makeFor(op->
iter_, std::move(begin), std::move(end),
292 std::move(step), std::move(len), std::move(property),
298 auto cond = (*this)(op->cond_);
299 auto thenCase = (*this)(op->thenCase_);
301 op->elseCase_.
isValid() ? (*this)(op->elseCase_) :
nullptr;
302 return makeIf(std::move(cond), std::move(thenCase), std::move(elseCase),
303 op->metadata(), op->id(), op->debugBlame());
307 return makeAssert((*
this)(op->cond_), (*
this)(op->body_),
308 op->metadata(), op->id(), op->debugBlame());
312 return makeAssume((*
this)(op->cond_), (*
this)(op->body_),
313 op->metadata(), op->id(), op->debugBlame());
317 return makeIfExpr((*
this)(op->cond_), (*
this)(op->thenCase_),
318 (*
this)(op->elseCase_), op->debugBlame());
322 return makeCast((*
this)(op->expr_), op->destType_, op->debugBlame());
327 params.reserve(op->params_.size());
328 for (
auto &¶m : op->params_) {
329 params.emplace_back((*
this)(param));
332 op->hasSideEffect_, op->debugBlame());
336 return makeEval((*
this)(op->expr_), op->metadata(), op->id(),
342 if (op->cutlassMicroKernelProperty_.
isValid()) {
344 op->cutlassMicroKernelProperty_->nWarpBatch_,
345 op->cutlassMicroKernelProperty_->nWarpM_,
346 op->cutlassMicroKernelProperty_->nWarpN_,
347 (*
this)(op->cutlassMicroKernelProperty_->warpIdBatch_),
348 (*
this)(op->cutlassMicroKernelProperty_->warpIdM_),
349 (*
this)(op->cutlassMicroKernelProperty_->warpIdN_),
350 (*
this)(op->cutlassMicroKernelProperty_->laneId_));
353 op->backend_, std::move(cutlassMicroKernelProperty),
354 (*
this)(op->a_), (*
this)(op->b_), (*
this)(op->c_),
355 (*
this)(op->alpha_), (*
this)(op->beta_), (*
this)(op->m_),
356 (*
this)(op->k_), (*
this)(op->n_), (*
this)(op->lda_),
357 (*
this)(op->ldb_), (*
this)(op->ldc_), (*
this)(op->stridea_),
358 (*
this)(op->strideb_), (*
this)(op->stridec_),
359 (*
this)(op->batchSize_), op->aIsRowMajor_, op->bIsRowMajor_,
360 op->cIsRowMajor_, (*
this)(op->equivalent_), op->metadata(),
361 op->id(), op->debugBlame());
366 op->id(), op->debugBlame());
370 std::vector<Expr> indices;
371 indices.reserve(op->indices_.size());
372 for (
auto &&index : op->indices_) {
373 indices.emplace_back((*
this)(index));
376 op->loadType_, op->debugBlame());
std::source_location debugBlame() const
Definition: ast.h:134
SubTree< ForProperty > property_
Definition: stmt.h:298
SubTree< ExprNode > begin_
Definition: stmt.h:294
SubTree< ExprNode > step_
Definition: stmt.h:296
SubTree< ExprNode > len_
Definition: stmt.h:297
SubTree< ExprNode > end_
Definition: stmt.h:295
std::string iter_
Definition: stmt.h:289
SubTree< StmtNode > body_
Definition: stmt.h:299
virtual Expr visit(const Exp &op)
Definition: mutator.h:216
virtual Expr visit(const Sigmoid &op)
Definition: mutator.h:228
virtual Expr visit(const Var &op)
Definition: mutator.h:66
virtual Stmt visit(const Eval &op)
Definition: mutator.h:335
virtual Expr visit(const RoundTowards0Div &op)
Definition: mutator.h:154
virtual Expr visit(const LAnd &op)
Definition: mutator.h:200
virtual Expr visit(const BoolConst &op)
Definition: mutator.h:123
virtual Expr visit(const Sqrt &op)
Definition: mutator.h:212
virtual Stmt visit(const If &op)
Definition: mutator.h:297
virtual Expr visit(const Min &op)
Definition: mutator.h:168
virtual Stmt visit(const Store &op)
Definition: mutator.h:70
virtual Stmt visitStmt(const Stmt &op)
Definition: mutator.cc:64
virtual Stmt visit(const Alloc &op)
Definition: mutator.h:81
virtual Expr visit(const Mul &op)
Definition: mutator.h:135
virtual Expr visit(const LOr &op)
Definition: mutator.h:204
virtual Stmt visit(const VarDef &op)
Definition: mutator.h:51
virtual Expr visit(const NE &op)
Definition: mutator.h:196
virtual Expr visit(const LNot &op)
Definition: mutator.h:208
virtual Expr visit(const GE &op)
Definition: mutator.h:188
virtual Expr visit(const EQ &op)
Definition: mutator.h:192
virtual Expr visit(const Remainder &op)
Definition: mutator.h:163
virtual Stmt visit(const Free &op)
Definition: mutator.h:85
virtual Expr visit(const Square &op)
Definition: mutator.h:224
virtual Stmt visit(const ReduceTo &op)
Definition: mutator.h:99
virtual Expr visit(const Max &op)
Definition: mutator.h:172
virtual Expr visit(const Intrinsic &op)
Definition: mutator.h:325
virtual Expr visit(const AnyExpr &op)
Definition: mutator.h:111
virtual Expr visit(const Cast &op)
Definition: mutator.h:321
virtual Stmt visit(const Assume &op)
Definition: mutator.h:311
virtual Stmt visit(const MatMul &op)
Definition: mutator.h:340
virtual Expr visit(const GT &op)
Definition: mutator.h:184
virtual Expr visitExpr(const Expr &op)
Definition: mutator.cc:6
virtual Expr visit(const FloorDiv &op)
Definition: mutator.h:144
virtual Expr visit(const Sub &op)
Definition: mutator.h:131
virtual Expr visit(const Cos &op)
Definition: mutator.h:236
virtual Expr visit(const IfExpr &op)
Definition: mutator.h:316
virtual Expr visit(const Tan &op)
Definition: mutator.h:240
virtual Stmt visit(const Any &op)
Definition: mutator.h:39
virtual Expr visit(const Abs &op)
Definition: mutator.h:248
virtual Expr visit(const RealDiv &op)
Definition: mutator.h:139
virtual Expr visit(const IntConst &op)
Definition: mutator.h:115
virtual Expr visit(const Ln &op)
Definition: mutator.h:220
virtual Expr visit(const LE &op)
Definition: mutator.h:180
virtual Expr visit(const Load &op)
Definition: mutator.h:89
virtual Expr visit(const Ceil &op)
Definition: mutator.h:256
virtual Expr visit(const LT &op)
Definition: mutator.h:176
virtual Stmt visit(const Assert &op)
Definition: mutator.h:306
virtual Stmt visit(const For &op)
Definition: mutator.h:264
virtual Stmt visit(const MarkVersion &op)
Definition: mutator.h:364
virtual Expr visit(const LoadAtVersion &op)
Definition: mutator.h:369
virtual Expr visit(const Add &op)
Definition: mutator.h:127
virtual Expr visit(const Floor &op)
Definition: mutator.h:252
virtual ~Mutator()
Definition: mutator.h:16
virtual Expr visit(const Mod &op)
Definition: mutator.h:159
virtual Stmt visit(const StmtSeq &op)
Definition: mutator.h:41
virtual Expr visit(const Unbound &op)
Definition: mutator.h:260
virtual Expr visit(const CeilDiv &op)
Definition: mutator.h:149
virtual Expr visit(const Tanh &op)
Definition: mutator.h:244
Stmt StmtRetType
Definition: mutator.h:14
Expr ExprRetType
Definition: mutator.h:13
virtual Expr visit(const FloatConst &op)
Definition: mutator.h:119
virtual Stmt operator()(const Stmt &op) final
Definition: mutator.cc:91
virtual Expr visit(const Sin &op)
Definition: mutator.h:232
std::string var_
Definition: stmt.h:231
SubTreeList< ExprNode > indices_
Definition: stmt.h:232
ReduceOp op_
Definition: stmt.h:233
SubTree< ExprNode > expr_
Definition: stmt.h:234
bool sync_
Definition: stmt.h:243
static Ref make()
Definition: ref.h:105
bool isValid() const
Definition: ref.h:89
const Metadata & metadata() const
Definition: ast.h:233
ID id() const
Definition: ast.cc:362
SubTreeList< StmtNode > stmts_
Definition: stmt.h:44
std::string var_
Definition: stmt.h:134
SubTree< ExprNode > expr_
Definition: stmt.h:136
SubTreeList< ExprNode > indices_
Definition: stmt.h:135
bool pinned_
Definition: stmt.h:102
SubTree< StmtNode > body_
Definition: stmt.h:101
SubTree< Buffer > buffer_
Definition: stmt.h:86
std::optional< std::string > viewOf_
Definition: stmt.h:99
std::string name_
Definition: stmt.h:85
Definition: allocator.h:9
Expr makeCeil(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:658
Expr makeLoad(const std::string &var, Tindices &&indices, DataType loadType, std::source_location loc=std::source_location::current())
Definition: expr.h:63
Expr makeRemainder(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:326
Expr makeSigmoid(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:560
Expr makeLT(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:368
Stmt makeEval(T &&expr, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:452
Expr makeMin(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:340
Expr makeIntrinsic(const std::string &format, T &¶ms, DataType retType, bool hasSideEffect, std::source_location loc=std::source_location::current())
Definition: expr.h:756
PBSet params(T &&set)
Definition: presburger.h:1065
Stmt makeAssume(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:427
Expr makeSin(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:574
Ref< ReductionItem > makeReductionItem(ReduceOp op, const std::string &var, Tbegins &&begins, Tends &&ends, bool syncFlush)
Definition: for_property.h:26
Expr makeLoadAtVersion(const std::string &tapeName, Tindices &&indices, const DataType loadType, std::source_location loc=std::source_location::current())
Definition: expr.h:793
Expr makeUnbound(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:685
Expr makeMul(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:202
Ref< Buffer > makeBuffer(T &&tensor, AccessType atype, MemType mtype)
Definition: buffer.h:32
Expr makeGT(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:396
Expr makeFloorDiv(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:239
Expr makeAdd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:174
Expr makeCeilDiv(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:259
Expr makeIfExpr(T &&cond, U &&thenCase, V &&elseCase, std::source_location loc=std::source_location::current())
Definition: expr.h:707
Stmt makeAssert(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:394
Expr makeMod(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:303
Expr makeLE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:382
Stmt makeStore(const std::string &var, Tindices &&indices, Texpr &&expr, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:142
Expr makeAbs(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:630
Expr makeCos(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:588
Stmt makeAny(std::source_location loc=std::source_location::current())
Definition: stmt.h:29
Expr makeLn(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:532
Expr makeTanh(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:616
Stmt makeReduceTo(const std::string &var, Tindices &&indices, ReduceOp op, Texpr &&expr, bool sync, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:250
Expr makeNE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:438
Stmt makeStmtSeq(Tstmts &&stmts, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:51
Stmt makeMatMul(MatMulBackend backend, const Ref< CutlassMicroKernelProperty > &cutlassMicroKernelProperty, const Expr &a, const Expr &b, const Expr &c, const Expr &alpha, const Expr &beta, const Expr &m, const Expr &k, const Expr &n, const Expr &lda, const Expr &ldb, const Expr &ldc, const Expr &stridea, const Expr &strideb, const Expr &stridec, const Expr &batchSize, bool aIsRowMajor, bool bIsRowMajor, bool cIsRowMajor, const Stmt &equivalent, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:535
Expr makeEQ(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:424
Expr makeLOr(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:466
Expr makeSquare(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:546
Expr makeLAnd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:452
Expr makeRealDiv(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:219
Stmt makeMarkVersion(const std::string &tapeName, const std::string &var, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:581
Stmt makeIf(Tcond &&cond, Tthen &&thenCase, Telse &&elseCase, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:354
Expr makeCast(T &&expr, DataType destType, std::source_location loc=std::source_location::current())
Definition: expr.h:728
Expr makeAnyExpr(std::source_location loc=std::source_location::current())
Definition: expr.h:26
Stmt makeVarDef(const std::string &name, Tbuffer &&buffer, const std::optional< std::string > &viewOf, Tbody &&body, bool pinned, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:109
Expr makeTan(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:602
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Expr makeExp(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:518
Expr makeGE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:410
Expr makeSqrt(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:504
Expr makeFloor(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:644
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
Stmt makeFor(const std::string &iter, Tbegin &&begin, Tend &&end, Tstep &&step, Tlen &&len, Tproperty &&property, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:311
Expr makeBoolConst(bool val, std::source_location loc=std::source_location::current())
Definition: expr.h:136
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
Expr makeLNot(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:490
Expr makeFloatConst(double val, std::source_location loc=std::source_location::current())
Definition: expr.h:119
Stmt makeFree(const std::string &var, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:209
Expr makeRoundTowards0Div(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:279
Stmt makeAlloc(const std::string &var, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:184
Ref< Tensor > makeTensor(T &&shape, DataType dtype)
Definition: tensor.h:36
Expr makeMax(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:354