1#ifndef FREE_TENSOR_EXPR_H
2#define FREE_TENSOR_EXPR_H
21 std::vector<Expr>
children()
const override {
return {}; }
26makeAnyExpr(std::source_location loc = std::source_location::current()) {
28 a->setDebugBlame(loc);
37 std::vector<Expr>
children()
const override {
return {}; }
43 std::source_location loc = std::source_location::current()) {
47 v->setDebugBlame(loc);
62template <
class Tindices>
64 std::source_location loc = std::source_location::current()) {
68 l->indices_ = std::forward<Tindices>(indices);
69 l->loadType_ = loadType;
70 l->setDebugBlame(loc);
74makeLoad(
const std::string &var,
const std::vector<Expr> &indices,
76 std::source_location loc = std::source_location::current()) {
80 l->indices_ = indices;
81 l->loadType_ = loadType;
82 l->setDebugBlame(loc);
88 bool isConst()
const override {
return true; }
89 std::vector<Expr>
children()
const override {
return {}; }
103 std::source_location loc = std::source_location::current()) {
106 c->setDebugBlame(loc);
120 std::source_location loc = std::source_location::current()) {
123 c->setDebugBlame(loc);
137 std::source_location loc = std::source_location::current()) {
140 b->setDebugBlame(loc);
169 void inferDType()
override;
173template <
class T,
class U>
175 std::source_location loc = std::source_location::current()) {
177 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
178 a->setDebugBlame(loc);
183 void inferDType()
override;
187template <
class T,
class U>
189 std::source_location loc = std::source_location::current()) {
191 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
192 a->setDebugBlame(loc);
197 void inferDType()
override;
201template <
class T,
class U>
203 std::source_location loc = std::source_location::current()) {
205 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
206 a->setDebugBlame(loc);
214 void inferDType()
override;
218template <
class T,
class U>
220 std::source_location loc = std::source_location::current()) {
222 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
223 a->setDebugBlame(loc);
234 void inferDType()
override;
238template <
class T,
class U>
240 std::source_location loc = std::source_location::current()) {
242 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
243 a->setDebugBlame(loc);
254 void inferDType()
override;
258template <
class T,
class U>
260 std::source_location loc = std::source_location::current()) {
262 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
263 a->setDebugBlame(loc);
274 void inferDType()
override;
278template <
class T,
class U>
281 std::source_location loc = std::source_location::current()) {
283 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
284 a->setDebugBlame(loc);
298 void inferDType()
override;
302template <
class T,
class U>
304 std::source_location loc = std::source_location::current()) {
306 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
307 a->setDebugBlame(loc);
321 void inferDType()
override;
325template <
class T,
class U>
327 std::source_location loc = std::source_location::current()) {
329 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
330 a->setDebugBlame(loc);
335 void inferDType()
override;
339template <
class T,
class U>
341 std::source_location loc = std::source_location::current()) {
343 m->lhs_ = std::forward<T>(
lhs), m->rhs_ = std::forward<U>(
rhs);
344 m->setDebugBlame(loc);
349 void inferDType()
override;
353template <
class T,
class U>
355 std::source_location loc = std::source_location::current()) {
357 m->lhs_ = std::forward<T>(
lhs), m->rhs_ = std::forward<U>(
rhs);
358 m->setDebugBlame(loc);
363 void inferDType()
override;
367template <
class T,
class U>
369 std::source_location loc = std::source_location::current()) {
371 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
372 a->setDebugBlame(loc);
377 void inferDType()
override;
381template <
class T,
class U>
383 std::source_location loc = std::source_location::current()) {
385 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
386 a->setDebugBlame(loc);
391 void inferDType()
override;
395template <
class T,
class U>
397 std::source_location loc = std::source_location::current()) {
399 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
400 a->setDebugBlame(loc);
405 void inferDType()
override;
409template <
class T,
class U>
411 std::source_location loc = std::source_location::current()) {
413 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
414 a->setDebugBlame(loc);
419 void inferDType()
override;
423template <
class T,
class U>
425 std::source_location loc = std::source_location::current()) {
427 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
428 a->setDebugBlame(loc);
433 void inferDType()
override;
437template <
class T,
class U>
439 std::source_location loc = std::source_location::current()) {
441 a->lhs_ = std::forward<T>(
lhs), a->rhs_ = std::forward<U>(
rhs);
442 a->setDebugBlame(loc);
447 void inferDType()
override;
451template <
class T,
class U>
453 std::source_location loc = std::source_location::current()) {
455 l->lhs_ = std::forward<T>(
lhs), l->rhs_ = std::forward<U>(
rhs);
456 l->setDebugBlame(loc);
461 void inferDType()
override;
465template <
class T,
class U>
467 std::source_location loc = std::source_location::current()) {
469 l->lhs_ = std::forward<T>(
lhs), l->rhs_ = std::forward<U>(
rhs);
470 l->setDebugBlame(loc);
479 bool isUnary()
const override {
return true; }
485 void inferDType()
override;
491 std::source_location loc = std::source_location::current()) {
493 n->expr_ = std::forward<T>(expr);
494 n->setDebugBlame(loc);
499 void inferDType()
override;
505 std::source_location loc = std::source_location::current()) {
507 s->expr_ = std::forward<T>(expr);
508 s->setDebugBlame(loc);
513 void inferDType()
override;
519 std::source_location loc = std::source_location::current()) {
521 e->expr_ = std::forward<T>(expr);
522 e->setDebugBlame(loc);
527 void inferDType()
override;
533 std::source_location loc = std::source_location::current()) {
535 e->expr_ = std::forward<T>(expr);
536 e->setDebugBlame(loc);
541 void inferDType()
override;
547 std::source_location loc = std::source_location::current()) {
549 e->expr_ = std::forward<T>(expr);
550 e->setDebugBlame(loc);
555 void inferDType()
override;
561 std::source_location loc = std::source_location::current()) {
563 e->expr_ = std::forward<T>(expr);
564 e->setDebugBlame(loc);
569 void inferDType()
override;
575 std::source_location loc = std::source_location::current()) {
577 e->expr_ = std::forward<T>(expr);
578 e->setDebugBlame(loc);
583 void inferDType()
override;
589 std::source_location loc = std::source_location::current()) {
591 e->expr_ = std::forward<T>(expr);
592 e->setDebugBlame(loc);
597 void inferDType()
override;
603 std::source_location loc = std::source_location::current()) {
605 e->expr_ = std::forward<T>(expr);
606 e->setDebugBlame(loc);
611 void inferDType()
override;
617 std::source_location loc = std::source_location::current()) {
619 e->expr_ = std::forward<T>(expr);
620 e->setDebugBlame(loc);
625 void inferDType()
override;
631 std::source_location loc = std::source_location::current()) {
633 e->expr_ = std::forward<T>(expr);
634 e->setDebugBlame(loc);
639 void inferDType()
override;
645 std::source_location loc = std::source_location::current()) {
647 e->expr_ = std::forward<T>(expr);
648 e->setDebugBlame(loc);
653 void inferDType()
override;
659 std::source_location loc = std::source_location::current()) {
661 e->expr_ = std::forward<T>(expr);
662 e->setDebugBlame(loc);
680 void inferDType()
override;
686 std::source_location loc = std::source_location::current()) {
688 e->expr_ = std::forward<T>(expr);
689 e->setDebugBlame(loc);
706template <
class T,
class U,
class V>
708 std::source_location loc = std::source_location::current()) {
710 e->cond_ = std::forward<T>(cond);
711 e->thenCase_ = std::forward<U>(thenCase);
712 e->elseCase_ = std::forward<V>(elseCase);
713 e->setDebugBlame(loc);
729 std::source_location loc = std::source_location::current()) {
731 e->expr_ = std::forward<T>(expr);
732 e->destType_ = destType;
733 e->setDebugBlame(loc);
758 std::source_location loc = std::source_location::current()) {
761 i->params_ = std::forward<T>(
params);
762 i->retType_ = retType;
763 i->hasSideEffect_ = hasSideEffect;
764 i->setDebugBlame(loc);
769 DataType retType,
bool hasSideEffect,
770 std::source_location loc = std::source_location::current()) {
774 i->retType_ = retType;
775 i->hasSideEffect_ = hasSideEffect;
776 i->setDebugBlame(loc);
791template <
class Tindices>
795 std::source_location loc = std::source_location::current()) {
797 l->tapeName_ = tapeName;
798 l->indices_ = std::forward<Tindices>(indices);
799 l->loadType_ = loadType;
800 l->setDebugBlame(loc);
806 std::source_location loc = std::source_location::current()) {
808 l->tapeName_ = tapeName;
809 l->indices_ = indices;
810 l->loadType_ = loadType;
811 l->setDebugBlame(loc);
815template <
class T,
class U>
817 std::source_location loc = std::source_location::current()) {
820 return makeAdd(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
822 return makeSub(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
824 return makeMul(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
835 return makeMod(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
839 return makeMin(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
841 return makeMax(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
843 return makeLT(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
845 return makeLE(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
847 return makeGT(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
849 return makeGE(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
851 return makeEQ(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
853 return makeNE(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
857 return makeLOr(std::forward<T>(
lhs), std::forward<U>(
rhs), loc);
865 std::source_location loc = std::source_location::current()) {
868 return makeLNot(std::forward<T>(expr), loc);
870 return makeSqrt(std::forward<T>(expr), loc);
872 return makeExp(std::forward<T>(expr), loc);
874 return makeLn(std::forward<T>(expr), loc);
876 return makeSquare(std::forward<T>(expr), loc);
880 return makeSin(std::forward<T>(expr), loc);
882 return makeCos(std::forward<T>(expr), loc);
884 return makeTan(std::forward<T>(expr), loc);
886 return makeTanh(std::forward<T>(expr), loc);
888 return makeAbs(std::forward<T>(expr), loc);
890 return makeFloor(std::forward<T>(expr), loc);
892 return makeCeil(std::forward<T>(expr), loc);
#define DEFINE_NODE_TRAIT(name)
Definition: ast.h:106
void inferDType() override
Definition: expr.h:20
void compHash() override
Definition: expr.cc:12
DEFINE_NODE_TRAIT(AnyExpr)
std::vector< Expr > children() const override
Definition: expr.h:21
std::vector< Expr > children() const override
Definition: expr.h:149
virtual bool isCommutative() const =0
SubTree< ExprNode > lhs_
Definition: expr.h:146
SubTree< ExprNode > rhs_
Definition: expr.h:146
bool isBinary() const override
Definition: expr.h:148
DEFINE_NODE_TRAIT(BoolConst)
void inferDType() override
Definition: expr.cc:27
bool val_
Definition: expr.h:129
void compHash() override
Definition: expr.cc:17
void compHash() override
Definition: expr.cc:19
SubTree< ExprNode > expr_
Definition: expr.h:719
std::vector< Expr > children() const override
Definition: expr.h:723
void inferDType() override
Definition: expr.cc:64
DataType destType_
Definition: expr.h:720
void compHash() override
Definition: expr.cc:7
bool isCommutative() const override
Definition: expr.h:157
bool isConst() const override
Definition: expr.h:88
std::vector< Expr > children() const override
Definition: expr.h:89
Definition: data_type.h:106
void inferDType() override
Definition: expr.cc:26
void compHash() override
Definition: expr.cc:16
double val_
Definition: expr.h:112
DEFINE_NODE_TRAIT(FloatConst)
std::vector< Expr > children() const override
Definition: expr.h:700
SubTree< ExprNode > thenCase_
Definition: expr.h:696
SubTree< ExprNode > cond_
Definition: expr.h:695
void inferDType() override
Definition: expr.cc:63
DEFINE_NODE_TRAIT(IfExpr)
SubTree< ExprNode > elseCase_
Definition: expr.h:697
void compHash() override
Definition: expr.cc:18
int64_t val_
Definition: expr.h:95
DEFINE_NODE_TRAIT(IntConst)
void compHash() override
Definition: expr.cc:15
void inferDType() override
Definition: expr.cc:25
std::vector< Expr > children() const override
Definition: expr.h:751
std::string format_
Definition: expr.h:743
void inferDType() override
Definition: expr.cc:65
void compHash() override
Definition: expr.cc:20
SubTreeList< ExprNode > params_
Definition: expr.h:746
DEFINE_NODE_TRAIT(Intrinsic)
DataType retType_
Definition: expr.h:747
bool hasSideEffect_
Definition: expr.h:748
void inferDType() override
Definition: expr.cc:66
std::string tapeName_
Definition: expr.h:782
SubTreeList< ExprNode > indices_
Definition: expr.h:783
DEFINE_NODE_TRAIT(LoadAtVersion)
void compHash() override
Definition: expr.cc:21
DataType loadType_
Definition: expr.h:784
std::vector< Expr > children() const override
Definition: expr.h:787
void compHash() override
Definition: expr.cc:14
std::string var_
Definition: expr.h:53
SubTreeList< ExprNode > indices_
Definition: expr.h:54
void inferDType() override
Definition: expr.cc:24
std::vector< Expr > children() const override
Definition: expr.h:58
DataType loadType_
Definition: expr.h:55
bool isCommutative() const override
Definition: expr.h:164
void compHash() override
Definition: expr.cc:8
static Ref make()
Definition: ref.h:105
Definition: sub_tree.h:305
Definition: sub_tree.h:134
std::vector< Expr > children() const override
Definition: expr.h:480
SubTree< ExprNode > expr_
Definition: expr.h:476
bool isUnary() const override
Definition: expr.h:479
void compHash() override
Definition: expr.cc:11
std::string name_
Definition: expr.h:34
void inferDType() override
Definition: expr.cc:23
std::vector< Expr > children() const override
Definition: expr.h:37
void compHash() override
Definition: expr.cc:13
#define ASSERT(expr)
Definition: except.h:152
Definition: allocator.h:9
Expr makeCeil(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:658
Ref< TanNode > Tan
Definition: expr.h:600
Expr makeLoad(const std::string &var, Tindices &&indices, DataType loadType, std::source_location loc=std::source_location::current())
Definition: expr.h:63
Expr makeRemainder(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:326
Ref< RemainderNode > Remainder
Definition: expr.h:324
Ref< BoolConstNode > BoolConst
Definition: expr.h:134
Ref< RoundTowards0DivNode > RoundTowards0Div
Definition: expr.h:277
Ref< VarNode > Var
Definition: expr.h:40
auto && lhs
Definition: const_fold.cc:70
Ref< CeilDivNode > CeilDiv
Definition: expr.h:257
Expr makeSigmoid(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:560
Expr makeUnary(ASTNodeType nodeType, T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:864
Ref< MaxNode > Max
Definition: expr.h:352
Expr makeLT(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:368
Ref< ConstNode > Const
Definition: expr.h:91
Ref< LAndNode > LAnd
Definition: expr.h:450
Expr makeMin(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:340
Expr makeIntrinsic(const std::string &format, T &¶ms, DataType retType, bool hasSideEffect, std::source_location loc=std::source_location::current())
Definition: expr.h:756
PBSet params(T &&set)
Definition: presburger.h:1065
Expr makeSin(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:574
Expr makeLoadAtVersion(const std::string &tapeName, Tindices &&indices, const DataType loadType, std::source_location loc=std::source_location::current())
Definition: expr.h:793
Expr makeUnbound(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:685
Expr makeMul(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:202
Ref< LoadNode > Load
Definition: expr.h:61
Expr makeGT(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:396
Expr makeFloorDiv(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:239
Expr makeAdd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:174
Expr makeCeilDiv(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:259
Ref< BinaryExprNode > BinaryExpr
Definition: expr.h:152
Expr makeIfExpr(T &&cond, U &&thenCase, V &&elseCase, std::source_location loc=std::source_location::current())
Definition: expr.h:707
Ref< SubNode > Sub
Definition: expr.h:186
Ref< IfExprNode > IfExpr
Definition: expr.h:705
Ref< SquareNode > Square
Definition: expr.h:544
Ref< SigmoidNode > Sigmoid
Definition: expr.h:558
Ref< CastNode > Cast
Definition: expr.h:726
Ref< RealDivNode > RealDiv
Definition: expr.h:217
Expr makeMod(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:303
Ref< CeilNode > Ceil
Definition: expr.h:656
Expr makeBinary(ASTNodeType nodeType, T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:816
Expr makeLE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:382
Ref< LNotNode > LNot
Definition: expr.h:488
Ref< FloorNode > Floor
Definition: expr.h:642
Ref< GENode > GE
Definition: expr.h:408
Ref< GTNode > GT
Definition: expr.h:394
Ref< FloorDivNode > FloorDiv
Definition: expr.h:237
Expr makeAbs(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:630
Ref< UnaryExprNode > UnaryExpr
Definition: expr.h:482
Ref< TanhNode > Tanh
Definition: expr.h:614
Expr makeCos(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:588
Ref< MulNode > Mul
Definition: expr.h:200
Expr makeLn(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:532
Ref< LTNode > LT
Definition: expr.h:366
Expr makeTanh(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:616
Ref< NENode > NE
Definition: expr.h:436
Expr makeNE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:438
Ref< LoadAtVersionNode > LoadAtVersion
Definition: expr.h:790
Ref< IntrinsicNode > Intrinsic
Definition: expr.h:754
Expr makeEQ(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:424
Expr makeLOr(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:466
Expr makeSquare(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:546
Ref< MinNode > Min
Definition: expr.h:338
Expr makeLAnd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:452
Expr makeRealDiv(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:219
Ref< IntConstNode > IntConst
Definition: expr.h:100
Ref< AnyExprNode > AnyExpr
Definition: expr.h:24
Ref< AbsNode > Abs
Definition: expr.h:628
Expr makeCast(T &&expr, DataType destType, std::source_location loc=std::source_location::current())
Definition: expr.h:728
Ref< EQNode > EQ
Definition: expr.h:422
Expr makeAnyExpr(std::source_location loc=std::source_location::current())
Definition: expr.h:26
Expr makeTan(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:602
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Expr makeExp(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:518
auto auto && rhs
Definition: const_fold.cc:70
Expr makeGE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:410
Ref< SqrtNode > Sqrt
Definition: expr.h:502
Ref< NonCommutativeBinaryExprNode > NonCommutativeBinaryExpr
Definition: expr.h:166
Expr makeSqrt(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:504
Ref< AddNode > Add
Definition: expr.h:172
Expr makeFloor(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:644
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
Ref< ExprNode > Expr
Definition: ast.h:184
Expr makeBoolConst(bool val, std::source_location loc=std::source_location::current())
Definition: expr.h:136
Ref< UnboundNode > Unbound
Definition: expr.h:683
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
Expr makeLNot(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:490
Ref< ModNode > Mod
Definition: expr.h:301
Expr makeFloatConst(double val, std::source_location loc=std::source_location::current())
Definition: expr.h:119
Ref< LOrNode > LOr
Definition: expr.h:464
Ref< CosNode > Cos
Definition: expr.h:586
Ref< SinNode > Sin
Definition: expr.h:572
Expr makeRoundTowards0Div(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:279
Ref< LENode > LE
Definition: expr.h:380
Ref< CommutativeBinaryExprNode > CommutativeBinaryExpr
Definition: expr.h:159
Ref< FloatConstNode > FloatConst
Definition: expr.h:117
Ref< ExpNode > Exp
Definition: expr.h:516
ASTNodeType
Definition: ast.h:20
Ref< LnNode > Ln
Definition: expr.h:530
Expr makeMax(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:354
Definition: sub_tree.h:20