FreeTensor
Loading...
Searching...
No Matches
find_loop_variance.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_FIND_LOOP_VARIANCE_H
2#define FREE_TENSOR_FIND_LOOP_VARIANCE_H
3
4#include <unordered_map>
5#include <vector>
6
9#include <visitor.h>
10
11namespace freetensor {
12
/**
 * Variability of an item (expression or variable) with respect to one loop.
 *
 * Only the non-invariant states have an enumerator: when a loop has no entry
 * in a variability map, the item is treated as invariant for that loop.
 */
enum class LoopVariability : int {
    Unknown = 0, ///< Variability has not been determined
    Variant = 1, ///< The item varies with the loop
};
18
19typedef std::unordered_map<std::string, std::unordered_map<ID, LoopVariability>>
21typedef std::unordered_map<ID /* of VarDef */,
22 std::unordered_map<ID, LoopVariability>>
24typedef std::unordered_map<StmtOrExprID /* of Expr */,
25 std::unordered_map<ID, LoopVariability>>
27
// First pass of the loop-variance analysis.
//
// FindLoopVariance's constructor runs `InitExprVari{exprInfo_}(root)` before
// anything else, so this visitor writes the initial per-expression
// variability into the caller-owned map it is constructed with.
// NOTE(review): the exact initial value given to each expression is decided
// in find_loop_variance.cc (visitExpr / visit(For)); confirm there.
31class InitExprVari : public TrackStmt<Visitor> {
33
// Caller-owned output map, held by reference: it must outlive this visitor.
34 LoopVariExprMap &exprInfo_;
// IDs of the For loops enclosing the current node (presumably maintained by
// visit(For) — confirm in the .cc).
35 std::vector<ID> loopStack_;
36
37 public:
// NOTE(review): this single-argument constructor is not `explicit`, so an
// lvalue LoopVariExprMap converts implicitly to an InitExprVari.
38 InitExprVari(LoopVariExprMap &exprInfo) : exprInfo_(exprInfo) {}
39
40 protected:
// Presumably records the initial variability of each visited expression.
41 void visitExpr(const Expr &expr) override;
// Presumably pushes/pops the loop ID around the loop body.
42 void visit(const For &op) override;
43};
44
// Propagates loop variability from the writes of a single variable into the
// variable's entry in a name-keyed LoopVariTransVarMap.
//
// For every Store/ReduceTo whose target is `var_`, the variability of the
// written value, of every index, and of every condition currently on
// `condStack_` is met into `varInfo_[var_]`.
57class MarkStores : public TrackStmt<Visitor> {
59
// Name of the variable whose writes are tracked.
// NOTE(review): held by reference — passing a temporary std::string to the
// constructor would leave this dangling.
60 const std::string &var_;
// Conditions of the enclosing branches (shared with the caller; presumably
// pushed/popped by visit(If) — confirm in the .cc).
61 std::vector<StmtOrExprID> &condStack_;
// Output: name-keyed variable variability, updated in place.
62 LoopVariTransVarMap &varInfo_;
// Input: per-expression variability, read-only here.
63 const LoopVariExprMap &exprInfo_;
64
65 public:
// Binds the shared state and resets `var`'s entry: erasing it means
// "invariant for every loop", since absence encodes invariance.
66 MarkStores(const std::string &var, std::vector<StmtOrExprID> &condStack,
67 LoopVariTransVarMap &varInfo, const LoopVariExprMap &exprInfo)
68 : var_(var), condStack_(condStack), varInfo_(varInfo),
69 exprInfo_(exprInfo) {
70 varInfo_.erase(var_); // Invariant
71 }
72
73 private:
74 // to = from meet to
// (The meet operation itself is defined in find_loop_variance.cc.)
75 void meetTo(const Expr &from, const std::string &to);
76 void meetTo(const StmtOrExprID &from, const std::string &to);
77
// Shared handling of Store and ReduceTo: for a write to `var_`, fold the
// variability of the value, each index and each enclosing condition into it.
78 template <class T> void visitMemWrite(const T &op) {
80 if (op->var_ == var_) {
81 meetTo(op->expr_, op->var_);
82 for (auto &&index : op->indices_) {
83 meetTo(index, op->var_);
84 }
85 for (auto &&cond : condStack_) {
86 meetTo(cond, op->var_);
87 }
88 }
89 }
90
91 protected:
92 void visit(const For &op) override;
93 void visit(const If &op) override;
94 void visit(const Store &op) override { visitMemWrite(op); }
// Besides the normal write handling, a reduction marks its target as Variant
// for every For loop between the ReduceTo and the VarDef of that variable
// (presumably because the reduced value accumulates across iterations).
// NOTE(review): this parent walk runs for every ReduceTo, not only those
// whose target is `var_`, so entries of other variables are written too.
// The marks equal what that variable's own pass would set, so it looks
// harmless — confirm this is intended.
95 void visit(const ReduceTo &op) override {
96 visitMemWrite(op);
97 for (auto p = op->parentStmt(); p.isValid(); p = p->parentStmt()) {
98 if (p->nodeType() == ASTNodeType::VarDef &&
99 p.as<VarDefNode>()->name_ == op->var_) {
// Reached the definition of the reduced variable: outer loops
// are outside its scope.
100 break;
101 }
102 if (p->nodeType() == ASTNodeType::For) {
103 varInfo_[op->var_][p->id()] = LoopVariability::Variant;
104 }
105 }
106 }
107};
108
// Main visitor of the loop-variance analysis: computes, for expressions and
// for variables, their LoopVariability with respect to enclosing For loops.
//
// The constructor only seeds `exprInfo_` via InitExprVari. The analysis
// itself runs when this visitor is applied to the tree (presumably by
// findLoopVariance, possibly repeatedly while unknownCnt() changes —
// confirm in find_loop_variance.cc).
109class FindLoopVariance : public SymbolTable<TrackStmt<Visitor>> {
111
// IDs of the For loops enclosing the current node.
112 std::vector<ID> loopStack_;
// Conditions of enclosing branches (presumably shared with MarkStores).
113 std::vector<StmtOrExprID> condStack_;
// Variable variability keyed by variable name (working state).
114 LoopVariTransVarMap varInfo_;
// Variable variability keyed by VarDef ID; this is what varInfo() exposes.
115 LoopVariUniqVarMap uniqVarInfo_;
// Per-expression variability; seeded by InitExprVari in the constructor.
116 LoopVariExprMap exprInfo_;
117
118 public:
// NOTE(review): this single-argument constructor is not `explicit`, so a
// Stmt converts implicitly to a FindLoopVariance (running InitExprVari over
// the whole tree). Consider marking it `explicit`.
119 FindLoopVariance(const Stmt &root) { InitExprVari{exprInfo_}(root); }
120
121 const LoopVariExprMap &exprInfo() const { return exprInfo_; }
// Note: returns the VarDef-ID-keyed map (uniqVarInfo_), not the name-keyed
// varInfo_ member.
122 const LoopVariUniqVarMap &varInfo() const { return uniqVarInfo_; }
123
// Presumably counts the entries currently marked LoopVariability::Unknown;
// defined in find_loop_variance.cc.
124 int unknownCnt() const;
125
126 private:
127 // to = from
128 void copyInfo(const Expr &from, const Expr &to);
129 // to = from meet to
130 void meetTo(const Expr &from, const Expr &to);
131
// Helpers for expression kinds (not Visitor overrides).
132 void visitConst(const Const &op);
133 void visitBinOp(const BinaryExpr &op);
134 void visitUnaryOp(const UnaryExpr &op);
135
136 protected:
137 void visit(const For &op) override;
138 void visit(const If &op) override;
139 void visit(const VarDef &op) override;
140
141 void visitExpr(const Expr &op) override;
142 void visit(const Var &op) override;
143 void visit(const Load &op) override;
144 void visit(const IfExpr &op) override;
145 void visit(const Cast &op) override;
146 void visit(const Intrinsic &op) override;
147};
148
149bool isVariant(const LoopVariExprMap &exprInfo, const StmtOrExprID &expr,
150 const ID &loop);
151bool isVariant(const LoopVariUniqVarMap &varInfo, const VarDef &def,
152 const ID &loop);
153bool isVariant(const LoopVariUniqVarMap &varInfo, const ID &defId,
154 const ID &loop);
155
186std::pair<LoopVariExprMap, LoopVariUniqVarMap> findLoopVariance(const Stmt &op);
187
188} // namespace freetensor
189
190#endif // FREE_TENSOR_FIND_LOOP_VARIANCE_H
Definition: find_loop_variance.h:109
void visitExpr(const Expr &op) override
Definition: find_loop_variance.cc:153
const LoopVariExprMap & exprInfo() const
Definition: find_loop_variance.h:121
int unknownCnt() const
Definition: find_loop_variance.cc:82
FindLoopVariance(const Stmt &root)
Definition: find_loop_variance.h:119
const LoopVariUniqVarMap & varInfo() const
Definition: find_loop_variance.h:122
void visit(const For &op) override
Definition: find_loop_variance.cc:104
Definition: id.h:18
Definition: find_loop_variance.h:31
void visit(const For &op) override
Definition: find_loop_variance.cc:47
void visitExpr(const Expr &expr) override
Definition: find_loop_variance.cc:40
InitExprVari(LoopVariExprMap &exprInfo)
Definition: find_loop_variance.h:38
Definition: find_loop_variance.h:57
void visit(const For &op) override
Definition: find_loop_variance.cc:61
MarkStores(const std::string &var, std::vector< StmtOrExprID > &condStack, LoopVariTransVarMap &varInfo, const LoopVariExprMap &exprInfo)
Definition: find_loop_variance.h:66
void visit(const ReduceTo &op) override
Definition: find_loop_variance.h:95
void visit(const Store &op) override
Definition: find_loop_variance.h:94
std::string var_
Definition: stmt.h:231
bool isValid() const
Definition: ref.h:89
Ref< StmtNode > parentStmt() const
Definition: ast.cc:103
Definition: ast.h:193
Definition: symbol_table.h:122
Definition: track_stmt.h:24
Definition: stmt.h:83
std::string name_
Definition: stmt.h:85
virtual void visit(const Any &op)
Definition: visitor.h:36
Definition: allocator.h:9
bool isVariant(const LoopVariExprMap &exprInfo, const StmtOrExprID &expr, const ID &loop)
Definition: find_loop_variance.cc:225
std::unordered_map< ID, std::unordered_map< ID, LoopVariability > > LoopVariUniqVarMap
Definition: find_loop_variance.h:23
Ref< VarDefNode > VarDef
Definition: stmt.h:107
std::unordered_map< StmtOrExprID, std::unordered_map< ID, LoopVariability > > LoopVariExprMap
Definition: find_loop_variance.h:26
std::unordered_map< std::string, std::unordered_map< ID, LoopVariability > > LoopVariTransVarMap
Definition: find_loop_variance.h:20
Ref< IfNode > If
Definition: stmt.h:352
Ref< ForNode > For
Definition: stmt.h:308
LoopVariability
Definition: find_loop_variance.h:13
Ref< StmtNode > Stmt
Definition: ast.h:152
std::pair< LoopVariExprMap, LoopVariUniqVarMap > findLoopVariance(const Stmt &op)
Definition: find_loop_variance.cc:241