1#ifndef FREE_TENSOR_AST_H
2#define FREE_TENSOR_AST_H
8#include <source_location>
104#define DEFINE_NODE_ACCESS(name) DEFINE_AST_PART_ACCESS(name##Node)
106#define DEFINE_NODE_TRAIT(name) \
107 DEFINE_NODE_ACCESS(name) \
109 virtual ASTNodeType nodeType() const override { return ASTNodeType::name; }
119#ifdef FT_DEBUG_BLAME_AST
120 std::source_location debugBlame_;
127 bool isAST()
const override {
return true; }
128 virtual bool isFunc()
const {
return false; }
129 virtual bool isStmt()
const {
return false; }
130 virtual bool isExpr()
const {
return false; }
135#ifdef FT_DEBUG_BLAME_AST
138 return std::source_location::current();
142#ifdef FT_DEBUG_BLAME_AST
162 bool isExpr()
const override {
return true; }
164 virtual bool isConst()
const {
return false; }
166 virtual bool isUnary()
const {
return false; }
168 virtual std::vector<Ref<ExprNode>>
children()
const = 0;
205 template <std::convertible_to<Stmt> T>
217 friend struct ::std::hash<StmtOrExprID>;
236 bool isStmt()
const override {
return true; }
244 virtual std::vector<Ref<StmtNode>>
children()
const {
return {}; }
313template <
typename... Srcs>
314 requires(std::convertible_to<Srcs, Stmt> && ...)
317 if (s->metadata().isValid())
318 return s->metadata();
323 std::vector<Metadata>{metadataFrom(sourceStmts)...});
virtual bool isExpr() const
Definition: ast.h:130
virtual bool isFunc() const
Definition: ast.h:128
void setDebugBlame(std::source_location loc)
Definition: ast.h:141
virtual ASTNodeType nodeType() const =0
virtual ~ASTNode()
Definition: ast.h:124
Ref< ASTNode > parentAST() const
Definition: ast.cc:76
bool isAST() const override
Definition: ast.h:127
std::source_location debugBlame() const
Definition: ast.h:134
virtual bool isStmt() const
Definition: ast.h:129
Definition: sub_tree.h:50
virtual void modifiedHook()
Definition: sub_tree.h:117
Definition: data_type.h:106
virtual bool isConst() const
Definition: ast.h:164
virtual void inferDType()=0
void modifiedHook() override
Definition: ast.h:173
void resetDType()
Definition: ast.cc:336
std::optional< DataType > dtype_
Definition: ast.h:159
DataType dtype()
Definition: ast.cc:329
virtual bool isUnary() const
Definition: ast.h:166
Ref< StmtNode > parentStmt() const
Definition: ast.cc:94
Ref< ExprNode > parentExpr() const
Definition: ast.cc:85
bool isExpr() const override
Definition: ast.h:162
virtual std::vector< Ref< ExprNode > > children() const =0
virtual bool isBinary() const
Definition: ast.h:165
static ID make()
Definition: id.h:30
bool isValid() const
Definition: id.h:33
Ref< StmtNode > ancestorById(const ID &lookup) const
Definition: ast.cc:279
Ref< StmtNode > parentStmt() const
Definition: ast.cc:103
const Metadata & metadata() const
Definition: ast.h:233
Ref< StmtNode > parentStmtByFilter(const std::function< bool(const Stmt &)> &filter) const
Definition: ast.cc:112
virtual std::vector< Ref< StmtNode > > children() const
Definition: ast.h:244
virtual bool isCtrlFlow() const
Definition: ast.h:242
void setId(const ID &id=ID::make())
Definition: ast.cc:356
Ref< StmtNode > prevStmtInDFSPostOrder() const
Definition: ast.cc:241
Metadata & metadata()
Definition: ast.h:234
bool isBefore(const Stmt &other) const
Definition: ast.cc:296
Ref< StmtNode > prevStmt() const
Definition: ast.cc:122
bool isAncestorOf(const Stmt &other) const
Definition: ast.cc:288
Ref< StmtNode > prevInCtrlFlow() const
Definition: ast.cc:155
bool isStmt() const override
Definition: ast.h:236
Ref< StmtNode > nextStmtInDFSPreOrder() const
Definition: ast.cc:260
Ref< StmtNode > prevLeafStmtInDFSOrder() const
Definition: ast.cc:195
Ref< StmtNode > nextLeafStmtInDFSOrder() const
Definition: ast.cc:218
Ref< StmtNode > nextInCtrlFlow() const
Definition: ast.cc:175
Ref< StmtNode > parentCtrlFlow() const
Definition: ast.cc:146
ID id() const
Definition: ast.cc:362
Ref< StmtNode > nextStmt() const
Definition: ast.cc:134
const ID & stmtId() const
Definition: ast.h:210
friend std::ostream & operator<<(std::ostream &os, const StmtOrExprID &id)
Definition: ast.cc:345
StmtOrExprID(const ID &stmtId)
Definition: ast.h:200
bool isValid() const
Definition: ast.h:213
const Expr & expr() const
Definition: ast.h:211
StmtOrExprID(const Expr &expr, T &&parent)
Definition: ast.h:206
StmtOrExprID()
Definition: ast.h:198
friend bool operator==(const StmtOrExprID &lhs, const StmtOrExprID &rhs)
Definition: ast.cc:352
StmtOrExprID(const Expr &expr, const ID &stmtId)
Definition: ast.h:202
Definition: allocator.h:9
Ref< TanNode > Tan
Definition: expr.h:600
Ref< AssumeNode > Assume
Definition: stmt.h:425
Ref< RemainderNode > Remainder
Definition: expr.h:324
Ref< FreeNode > Free
Definition: stmt.h:207
Ref< BoolConstNode > BoolConst
Definition: expr.h:134
Ref< RoundTowards0DivNode > RoundTowards0Div
Definition: expr.h:277
AST lcaAST(const AST &lhs, const AST &rhs)
Definition: ast.cc:367
Ref< VarNode > Var
Definition: expr.h:40
auto && lhs
Definition: const_fold.cc:70
Ref< VarDefNode > VarDef
Definition: stmt.h:107
auto makeMetadata(const std::string &op, Srcs &&...sourceStmts)
Definition: ast.h:315
Ref< CeilDivNode > CeilDiv
Definition: expr.h:257
Ref< MaxNode > Max
Definition: expr.h:352
Ref< LAndNode > LAnd
Definition: expr.h:450
Ref< MarkVersionNode > MarkVersion
Definition: stmt.h:579
Ref< LoadNode > Load
Definition: expr.h:61
Ref< StoreNode > Store
Definition: stmt.h:140
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< SubNode > Sub
Definition: expr.h:186
Ref< IfNode > If
Definition: stmt.h:352
Ref< IfExprNode > IfExpr
Definition: expr.h:705
Ref< SquareNode > Square
Definition: expr.h:544
Ref< SigmoidNode > Sigmoid
Definition: expr.h:558
Ref< CastNode > Cast
Definition: expr.h:726
Ref< ForNode > For
Definition: stmt.h:308
Ref< RealDivNode > RealDiv
Definition: expr.h:217
std::vector< T > filter(const std::vector< T > &vec, const U &callback)
Definition: container_utils.h:131
Ref< EvalNode > Eval
Definition: stmt.h:450
Ref< CeilNode > Ceil
Definition: expr.h:656
Ref< ASTNode > AST
Definition: ast.h:149
Ref< LNotNode > LNot
Definition: expr.h:488
Ref< FloorNode > Floor
Definition: expr.h:642
Ref< GENode > GE
Definition: expr.h:408
Ref< GTNode > GT
Definition: expr.h:394
Expr lcaExpr(const Expr &lhs, const Expr &rhs)
Definition: ast.cc:375
Ref< FloorDivNode > FloorDiv
Definition: expr.h:237
Ref< TanhNode > Tanh
Definition: expr.h:614
Stmt lcaStmt(const Stmt &lhs, const Stmt &rhs)
Definition: ast.cc:383
Ref< MulNode > Mul
Definition: expr.h:200
Ref< LTNode > LT
Definition: expr.h:366
Ref< NENode > NE
Definition: expr.h:436
Expr deepCopy(const Expr &op)
Definition: ast.cc:364
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< LoadAtVersionNode > LoadAtVersion
Definition: expr.h:790
Ref< IntrinsicNode > Intrinsic
Definition: expr.h:754
Ref< MinNode > Min
Definition: expr.h:338
Ref< IntConstNode > IntConst
Definition: expr.h:100
Ref< AnyExprNode > AnyExpr
Definition: expr.h:24
Ref< AbsNode > Abs
Definition: expr.h:628
Ref< EQNode > EQ
Definition: expr.h:422
auto auto && rhs
Definition: const_fold.cc:70
Ref< FuncNode > Func
Definition: func.h:64
Ref< SqrtNode > Sqrt
Definition: expr.h:502
Ref< AddNode > Add
Definition: expr.h:172
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Ref< ExprNode > Expr
Definition: ast.h:184
Ref< UnboundNode > Unbound
Definition: expr.h:683
Ref< ModNode > Mod
Definition: expr.h:301
Ref< AssertNode > Assert
Definition: stmt.h:392
Ref< AllocNode > Alloc
Definition: stmt.h:182
Ref< LOrNode > LOr
Definition: expr.h:464
Ref< MatMulNode > MatMul
Definition: stmt.h:533
Ref< CosNode > Cos
Definition: expr.h:586
Ref< SinNode > Sin
Definition: expr.h:572
Ref< LENode > LE
Definition: expr.h:380
Ref< AnyNode > Any
Definition: stmt.h:27
Ref< FloatConstNode > FloatConst
Definition: expr.h:117
Ref< StmtSeqNode > StmtSeq
Definition: stmt.h:49
Ref< ExpNode > Exp
Definition: expr.h:516
ASTNodeType
Definition: ast.h:20
Ref< LnNode > Ln
Definition: expr.h:530