FreeTensor
Loading...
Searching...
No Matches
ast.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_AST_H
2#define FREE_TENSOR_AST_H
3
4#include <atomic>
5#include <functional>
6#include <iostream>
7#include <optional>
8#include <source_location>
9#include <string>
10
11#include <id.h>
12#include <metadata.h>
13#include <ref.h>
14#include <serialize/to_string.h>
15#include <sub_tree.h>
16#include <type/data_type.h>
17
18namespace freetensor {
19
// Discriminator for the concrete kind of every AST node. Each node class reports
// its value through ASTNode::nodeType() (generated by DEFINE_NODE_TRAIT).
// NOTE(review): this copy looks incomplete -- cross-references elsewhere name
// ReduceTo, IntConst, FloatConst, BoolConst, FloorDiv, RoundTowards0Div,
// Remainder, Intrinsic, MarkVersion and LoadAtVersion nodes that have no
// enumerator below; confirm against the canonical header.
20enum class ASTNodeType : int {
// Presumably wildcard kinds used for pattern matching
21 // Matching
22 Any,
23 AnyExpr,
24
25 // Function
26 Func,
27
28 // Memory Access
29 Store,
31 Load,
32 Alloc,
33 Free,
34
35 // Structural statements
36 StmtSeq,
37 VarDef,
38 For,
39 If,
40 Assert,
41 Assume,
42
43 // Calls to external libs
44 MatMul,
45
46 // Other statements
47 Eval,
48
49 // Values
50 Var,
54
55 // Binary ops
56 Add,
57 Sub,
58 Mul,
59 RealDiv,
61 CeilDiv,
63 Mod,
65 Min,
66 Max,
67 LT,
68 LE,
69 GT,
70 GE,
71 EQ,
72 NE,
73 LAnd,
74 LOr,
75
76 // Unary ops
77 LNot,
78 Sqrt,
79 Exp,
80 Ln,
81 Square,
82 Sigmoid,
83 Sin,
84 Cos,
85 Tan,
86 Tanh,
87 Abs,
88 Floor,
89 Ceil,
90 Unbound,
91
92 // Other expressions
93 IfExpr,
94 Cast,
96
97 // For custom gradient only
// NOTE(review): no enumerators follow this heading in this copy.
100};
101
102std::ostream &operator<<(std::ostream &os, ASTNodeType type);
103
104#define DEFINE_NODE_ACCESS(name) DEFINE_AST_PART_ACCESS(name##Node)
105
106#define DEFINE_NODE_TRAIT(name) \
107 DEFINE_NODE_ACCESS(name) \
108 public: \
109 virtual ASTNodeType nodeType() const override { return ASTNodeType::name; }
110
// Abstract base of every AST node (functions, statements and expressions).
// Subclasses identify themselves via nodeType() and the isFunc()/isStmt()/
// isExpr() kind tests below.
118class ASTNode : public ASTPart {
// Source location recorded by setDebugBlame(); only stored when
// FT_DEBUG_BLAME_AST is defined.
119#ifdef FT_DEBUG_BLAME_AST
120 std::source_location debugBlame_;
121#endif
122
123 public:
124 virtual ~ASTNode() {}
// Concrete node kind; implemented by DEFINE_NODE_TRAIT in each node class.
125 virtual ASTNodeType nodeType() const = 0;
126
127 bool isAST() const override { return true; }
// Kind tests, false by default; ExprNode and StmtNode override
// isExpr()/isStmt() to return true.
128 virtual bool isFunc() const { return false; }
129 virtual bool isStmt() const { return false; }
130 virtual bool isExpr() const { return false; }
131
// Enclosing AST node (defined out of line). NOTE(review): presumably null for a
// root node -- confirm in ast.cc.
132 Ref<ASTNode> parentAST() const;
133
// Location recorded by setDebugBlame() when FT_DEBUG_BLAME_AST is defined.
// Otherwise nothing is recorded and the returned value carries no information.
134 std::source_location debugBlame() const {
135#ifdef FT_DEBUG_BLAME_AST
136 return debugBlame_;
137#else
138 return std::source_location::current(); // Tracking disabled: arbitrary value (this call site)
139#endif
140 }
// Records `loc` for debugBlame(); a no-op unless FT_DEBUG_BLAME_AST is defined.
141 void setDebugBlame(std::source_location loc) {
142#ifdef FT_DEBUG_BLAME_AST
143 debugBlame_ = loc;
144#endif
145 }
146
148};
150
151class StmtNode;
153
// Abstract base of every expression node. Concrete expressions expose their
// sub-expressions through children() and compute their type in inferDType().
157class ExprNode : public ASTNode {
158 protected:
// Cached data type. NOTE(review): presumably filled lazily by dtype() (via
// inferDType()) and cleared by resetDType(); both are defined out of line.
159 std::optional<DataType> dtype_;
160
161 public:
162 bool isExpr() const override { return true; }
163
// Cheap kind tests, false by default. Presumably overridden by constant, binary
// and unary expression classes respectively -- confirm in expr.h.
164 virtual bool isConst() const { return false; }
165 virtual bool isBinary() const { return false; }
166 virtual bool isUnary() const { return false; }
167
// Sub-expressions of this node (presumably its direct operands).
168 virtual std::vector<Ref<ExprNode>> children() const = 0;
169
172
// Modification hook from ASTPart: invalidates the cached data type.
// NOTE(review): this copy of the body shows only resetDType(); confirm whether
// the base-class hook is also meant to be invoked.
173 void modifiedHook() override {
175 resetDType();
176 }
177
// Data type of this expression (see dtype_ for the caching assumption).
178 DataType dtype();
// Presumably clears dtype_ -- defined out of line.
179 void resetDType();
// Computes this node's data type; each concrete expression implements it.
180 virtual void inferDType() = 0;
181
183};
185
// NOTE(review): the class head of StmtOrExprID (and its default constructor and
// (Expr, ID) constructor signature) appear to be missing from this copy.
// Identifies either a statement (by ID) or an expression together with the ID
// of the statement it belongs to.
194 ID stmtId_;
195 Expr expr_;
196
197 public:
199
// Refers to the statement `stmtId` itself; expr_ stays default-constructed.
// NOTE(review): not explicit, so an ID converts implicitly -- confirm intended.
200 StmtOrExprID(const ID &stmtId) : stmtId_(stmtId) {}
201
203 : stmtId_(stmtId), expr_(expr) {}
204
// Refers to `expr` inside the statement `parent`; only parent->id() is kept.
205 template <std::convertible_to<Stmt> T>
206 StmtOrExprID(const Expr &expr, T &&parent) : stmtId_(parent->id()) {
207 expr_ = expr;
208 }
209
210 const ID &stmtId() const { return stmtId_; }
211 const Expr &expr() const { return expr_; }
212
// Valid iff the statement ID is valid; the expression part is not checked.
213 bool isValid() const { return stmtId_.isValid(); }
214
// Printing, equality and hashing are defined out of line; the std::hash
// specialization is befriended so it can access the private members.
215 friend std::ostream &operator<<(std::ostream &os, const StmtOrExprID &id);
216 friend bool operator==(const StmtOrExprID &lhs, const StmtOrExprID &rhs);
217 friend struct ::std::hash<StmtOrExprID>;
218};
219
// Abstract base of every statement node. Each statement carries an ID and a
// Metadata record.
223class StmtNode : public ASTNode {
// ID is granted access to the private id_ below.
224 friend ID;
225
226 ID id_;
227 Metadata metadata_;
228
229 public:
// Assigns this statement's ID; with no argument the default is ID::make().
230 void setId(const ID &id = ID::make());
// This statement's ID (defined out of line).
231 ID id() const;
232
// Read-only and mutable access to the attached metadata.
233 const Metadata &metadata() const { return metadata_; }
234 Metadata &metadata() { return metadata_; }
235
236 bool isStmt() const override { return true; }
237
// Whether this is a control-flow statement; false by default. Presumably
// overridden by nodes such as For/If -- confirm in stmt.h.
242 virtual bool isCtrlFlow() const { return false; }
243
// Child statements; the default implementation returns an empty list.
244 virtual std::vector<Ref<StmtNode>> children() const { return {}; }
245
// NOTE(review): the return-type line of parentStmtByFilter (Ref<StmtNode>) and
// several parent/control-flow/leaf navigation declarations appear to be
// missing from this copy.
256 parentStmtByFilter(const std::function<bool(const Stmt &)> &filter) const;
// Neighbouring statements (presumably siblings in the enclosing sequence).
257 Ref<StmtNode> prevStmt() const;
258 Ref<StmtNode> nextStmt() const;
277 Ref<StmtNode> prevStmtInDFSPostOrder() const; // may return child
278 Ref<StmtNode> nextStmtInDFSPreOrder() const; // may return child
279
// Ancestor whose ID equals `lookup`. NOTE(review): presumably null when there
// is none -- confirm in ast.cc.
283 Ref<StmtNode> ancestorById(const ID &lookup) const;
284
// Tree-relationship queries against `other`; exact semantics in ast.cc.
288 bool isAncestorOf(const Stmt &other) const;
289
293 bool isBefore(const Stmt &other) const;
294
296};
297
298Expr deepCopy(const Expr &op);
299Stmt deepCopy(const Stmt &op);
300
301AST lcaAST(const AST &lhs, const AST &rhs);
302Expr lcaExpr(const Expr &lhs, const Expr &rhs);
303Stmt lcaStmt(const Stmt &lhs, const Stmt &rhs);
304
313template <typename... Srcs>
314 requires(std::convertible_to<Srcs, Stmt> && ...)
315auto makeMetadata(const std::string &op, Srcs &&...sourceStmts) {
316 auto metadataFrom = [](const Stmt &s) -> Metadata {
317 if (s->metadata().isValid())
318 return s->metadata();
319 else
320 return makeMetadata(s->id());
321 };
322 return makeMetadata(op,
323 std::vector<Metadata>{metadataFrom(sourceStmts)...});
324}
325
326} // namespace freetensor
327
328namespace std {
329
330template <> struct hash<freetensor::StmtOrExprID> {
331 size_t operator()(const freetensor::StmtOrExprID &id) const;
332};
333
334} // namespace std
335
336#endif // FREE_TENSOR_AST_H
Definition: ast.h:118
virtual bool isExpr() const
Definition: ast.h:130
virtual bool isFunc() const
Definition: ast.h:128
void setDebugBlame(std::source_location loc)
Definition: ast.h:141
virtual ASTNodeType nodeType() const =0
virtual ~ASTNode()
Definition: ast.h:124
Ref< ASTNode > parentAST() const
Definition: ast.cc:76
bool isAST() const override
Definition: ast.h:127
std::source_location debugBlame() const
Definition: ast.h:134
virtual bool isStmt() const
Definition: ast.h:129
Definition: sub_tree.h:50
virtual void modifiedHook()
Definition: sub_tree.h:117
Definition: data_type.h:106
Definition: ast.h:157
virtual bool isConst() const
Definition: ast.h:164
virtual void inferDType()=0
void modifiedHook() override
Definition: ast.h:173
void resetDType()
Definition: ast.cc:336
std::optional< DataType > dtype_
Definition: ast.h:159
DataType dtype()
Definition: ast.cc:329
virtual bool isUnary() const
Definition: ast.h:166
Ref< StmtNode > parentStmt() const
Definition: ast.cc:94
Ref< ExprNode > parentExpr() const
Definition: ast.cc:85
bool isExpr() const override
Definition: ast.h:162
virtual std::vector< Ref< ExprNode > > children() const =0
virtual bool isBinary() const
Definition: ast.h:165
Definition: id.h:18
static ID make()
Definition: id.h:30
bool isValid() const
Definition: id.h:33
Definition: ref.h:24
Definition: ast.h:223
Ref< StmtNode > ancestorById(const ID &lookup) const
Definition: ast.cc:279
Ref< StmtNode > parentStmt() const
Definition: ast.cc:103
const Metadata & metadata() const
Definition: ast.h:233
Ref< StmtNode > parentStmtByFilter(const std::function< bool(const Stmt &)> &filter) const
Definition: ast.cc:112
virtual std::vector< Ref< StmtNode > > children() const
Definition: ast.h:244
virtual bool isCtrlFlow() const
Definition: ast.h:242
void setId(const ID &id=ID::make())
Definition: ast.cc:356
Ref< StmtNode > prevStmtInDFSPostOrder() const
Definition: ast.cc:241
Metadata & metadata()
Definition: ast.h:234
bool isBefore(const Stmt &other) const
Definition: ast.cc:296
Ref< StmtNode > prevStmt() const
Definition: ast.cc:122
bool isAncestorOf(const Stmt &other) const
Definition: ast.cc:288
Ref< StmtNode > prevInCtrlFlow() const
Definition: ast.cc:155
bool isStmt() const override
Definition: ast.h:236
Ref< StmtNode > nextStmtInDFSPreOrder() const
Definition: ast.cc:260
Ref< StmtNode > prevLeafStmtInDFSOrder() const
Definition: ast.cc:195
Ref< StmtNode > nextLeafStmtInDFSOrder() const
Definition: ast.cc:218
Ref< StmtNode > nextInCtrlFlow() const
Definition: ast.cc:175
Ref< StmtNode > parentCtrlFlow() const
Definition: ast.cc:146
ID id() const
Definition: ast.cc:362
Ref< StmtNode > nextStmt() const
Definition: ast.cc:134
Definition: ast.h:193
const ID & stmtId() const
Definition: ast.h:210
friend std::ostream & operator<<(std::ostream &os, const StmtOrExprID &id)
Definition: ast.cc:345
StmtOrExprID(const ID &stmtId)
Definition: ast.h:200
bool isValid() const
Definition: ast.h:213
const Expr & expr() const
Definition: ast.h:211
StmtOrExprID(const Expr &expr, T &&parent)
Definition: ast.h:206
StmtOrExprID()
Definition: ast.h:198
friend bool operator==(const StmtOrExprID &lhs, const StmtOrExprID &rhs)
Definition: ast.cc:352
StmtOrExprID(const Expr &expr, const ID &stmtId)
Definition: ast.h:202
Definition: allocator.h:9
Ref< TanNode > Tan
Definition: expr.h:600
Ref< AssumeNode > Assume
Definition: stmt.h:425
Ref< RemainderNode > Remainder
Definition: expr.h:324
Ref< FreeNode > Free
Definition: stmt.h:207
Ref< BoolConstNode > BoolConst
Definition: expr.h:134
Ref< RoundTowards0DivNode > RoundTowards0Div
Definition: expr.h:277
AST lcaAST(const AST &lhs, const AST &rhs)
Definition: ast.cc:367
Ref< VarNode > Var
Definition: expr.h:40
auto && lhs
Definition: const_fold.cc:70
Ref< VarDefNode > VarDef
Definition: stmt.h:107
auto makeMetadata(const std::string &op, Srcs &&...sourceStmts)
Definition: ast.h:315
Ref< CeilDivNode > CeilDiv
Definition: expr.h:257
Ref< MaxNode > Max
Definition: expr.h:352
Ref< LAndNode > LAnd
Definition: expr.h:450
Ref< MarkVersionNode > MarkVersion
Definition: stmt.h:579
Ref< LoadNode > Load
Definition: expr.h:61
Ref< StoreNode > Store
Definition: stmt.h:140
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< SubNode > Sub
Definition: expr.h:186
Ref< IfNode > If
Definition: stmt.h:352
Ref< IfExprNode > IfExpr
Definition: expr.h:705
Ref< SquareNode > Square
Definition: expr.h:544
Ref< SigmoidNode > Sigmoid
Definition: expr.h:558
Ref< CastNode > Cast
Definition: expr.h:726
Ref< ForNode > For
Definition: stmt.h:308
Ref< RealDivNode > RealDiv
Definition: expr.h:217
std::vector< T > filter(const std::vector< T > &vec, const U &callback)
Definition: container_utils.h:131
Ref< EvalNode > Eval
Definition: stmt.h:450
Ref< CeilNode > Ceil
Definition: expr.h:656
Ref< ASTNode > AST
Definition: ast.h:149
Ref< LNotNode > LNot
Definition: expr.h:488
Ref< FloorNode > Floor
Definition: expr.h:642
Ref< GENode > GE
Definition: expr.h:408
Ref< GTNode > GT
Definition: expr.h:394
Expr lcaExpr(const Expr &lhs, const Expr &rhs)
Definition: ast.cc:375
Ref< FloorDivNode > FloorDiv
Definition: expr.h:237
Ref< TanhNode > Tanh
Definition: expr.h:614
Stmt lcaStmt(const Stmt &lhs, const Stmt &rhs)
Definition: ast.cc:383
Ref< MulNode > Mul
Definition: expr.h:200
Ref< LTNode > LT
Definition: expr.h:366
Ref< NENode > NE
Definition: expr.h:436
Expr deepCopy(const Expr &op)
Definition: ast.cc:364
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< LoadAtVersionNode > LoadAtVersion
Definition: expr.h:790
Ref< IntrinsicNode > Intrinsic
Definition: expr.h:754
Ref< MinNode > Min
Definition: expr.h:338
Ref< IntConstNode > IntConst
Definition: expr.h:100
Ref< AnyExprNode > AnyExpr
Definition: expr.h:24
Ref< AbsNode > Abs
Definition: expr.h:628
Ref< EQNode > EQ
Definition: expr.h:422
auto auto && rhs
Definition: const_fold.cc:70
Ref< FuncNode > Func
Definition: func.h:64
Ref< SqrtNode > Sqrt
Definition: expr.h:502
Ref< AddNode > Add
Definition: expr.h:172
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Ref< ExprNode > Expr
Definition: ast.h:184
Ref< UnboundNode > Unbound
Definition: expr.h:683
Ref< ModNode > Mod
Definition: expr.h:301
Ref< AssertNode > Assert
Definition: stmt.h:392
Ref< AllocNode > Alloc
Definition: stmt.h:182
Ref< LOrNode > LOr
Definition: expr.h:464
Ref< MatMulNode > MatMul
Definition: stmt.h:533
Ref< CosNode > Cos
Definition: expr.h:586
Ref< SinNode > Sin
Definition: expr.h:572
Ref< LENode > LE
Definition: expr.h:380
Ref< AnyNode > Any
Definition: stmt.h:27
Ref< FloatConstNode > FloatConst
Definition: expr.h:117
Ref< StmtSeqNode > StmtSeq
Definition: stmt.h:49
Ref< ExpNode > Exp
Definition: expr.h:516
ASTNodeType
Definition: ast.h:20
Ref< LnNode > Ln
Definition: expr.h:530
STL namespace.