FreeTensor
Loading...
Searching...
No Matches
blend.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_BLEND_H
2#define FREE_TENSOR_BLEND_H
3
4#include <unordered_map>
5
7#include <mutator.h>
8#include <visitor.h>
9
10namespace freetensor {
11
13 ID loop_;
14 std::vector<ID> scopes_;
15 bool inLoop_ = false;
16 bool found_ = false;
17
18 public:
19 FindAllScopesInside(const ID &loop) : loop_(loop) {}
20
21 const std::vector<ID> &scopes() const { return scopes_; }
22
23 bool found() const { return found_; }
24
25 protected:
26 void visit(const For &op) override;
27 void visit(const StmtSeq &op) override;
28};
29
30class BlendPass : public Mutator {
31 ID loop_;
32 bool inLoop_ = false;
33 std::string iter_;
34 Expr begin_, step_;
35 int len_ = 0, curIter_ = 0;
36 std::vector<Stmt> envStack_;
37 std::vector<VarDef> defs_;
38 std::unordered_map<std::string, std::pair<Expr, Expr>> offset_;
39 const LoopVariExprMap &exprVari_;
40 const LoopVariUniqVarMap &varVari_;
41
42 public:
43 BlendPass(const ID &loop, const LoopVariExprMap &exprVari,
44 const LoopVariUniqVarMap &varVari)
45 : loop_(loop), exprVari_(exprVari), varVari_(varVari) {}
46
47 private:
48 template <class T> Stmt visitLeafStmt(const T &op) {
49 if (inLoop_) {
50 std::vector<Stmt> stmts;
51 for (curIter_ = 0; curIter_ < len_; curIter_++) {
52 auto stmt = Mutator::visit(op);
53 if (stmt->nodeType() == ASTNodeType::Store) {
54 stmt = visitMemAccess(stmt.template as<StoreNode>());
55 } else if (stmt->nodeType() == ASTNodeType::ReduceTo) {
56 stmt = visitMemAccess(stmt.template as<ReduceToNode>());
57 }
58 stmt->metadata() =
59 makeMetadata("blend." + std::to_string(curIter_), stmt);
60
61 for (auto it = envStack_.rbegin(); it != envStack_.rend();
62 it++) {
63 switch ((*it)->nodeType()) {
64 case ASTNodeType::For: {
65 auto env = it->as<ForNode>();
66 stmt = makeFor(env->iter_, (*this)(env->begin_),
67 (*this)(env->end_), (*this)(env->step_),
68 (*this)(env->len_), env->property_,
69 std::move(stmt));
70 break;
71 }
72 case ASTNodeType::If: {
73 auto env = it->as<IfNode>();
74 stmt = makeIf((*this)(env->cond_), std::move(stmt));
75 break;
76 }
78 auto env = it->as<AssertNode>();
79 stmt = makeAssert((*this)(env->cond_), std::move(stmt));
80 break;
81 }
82 default:
83 ASSERT(false);
84 }
85 }
86 stmts.emplace_back(std::move(stmt));
87 }
88 return makeStmtSeq(std::move(stmts));
89 } else {
90 return Mutator::visit(op);
91 }
92 }
93
94 template <class T> T visitMemAccess(const T &op) {
95 if (inLoop_) {
96 for (auto &&def : defs_) {
97 if (def->name_ == op->var_) {
98 op->var_ += "." + std::to_string(curIter_);
99 }
100 }
101 }
102 return op;
103 }
104
105 protected:
106 Stmt visit(const Store &op) override { return visitLeafStmt(op); }
107 Stmt visit(const ReduceTo &op) override { return visitLeafStmt(op); }
108 Stmt visit(const Eval &op) override { return visitLeafStmt(op); }
109 Stmt visit(const For &op) override;
110 Stmt visit(const If &op) override;
111 Stmt visit(const Assert &op) override;
112 Stmt visit(const VarDef &op) override;
113 Expr visit(const Var &op) override;
114 Expr visit(const Load &op) override;
115};
116
117Stmt blend(const Stmt &ast, const ID &loop);
118
119} // namespace freetensor
120
121#endif // FREE_TENSOR_BLEND_H
class BlendPass
Definition: blend.h:30
Stmt visit(const ReduceTo &op) override
Definition: blend.h:107
BlendPass(const ID &loop, const LoopVariExprMap &exprVari, const LoopVariUniqVarMap &varVari)
Definition: blend.h:43
Stmt visit(const Eval &op) override
Definition: blend.h:108
Stmt visit(const Store &op) override
Definition: blend.h:106
class FindAllScopesInside
Definition: blend.h:12
void visit(const For &op) override
Definition: blend.cc:9
bool found() const
Definition: blend.h:23
FindAllScopesInside(const ID &loop)
Definition: blend.h:19
const std::vector< ID > & scopes() const
Definition: blend.h:21
Definition: id.h:18
Definition: mutator.h:11
virtual Stmt visit(const Any &op)
Definition: mutator.h:39
Definition: visitor.h:11
#define ASSERT(expr)
Definition: except.h:152
Definition: allocator.h:9
std::unordered_map< ID, std::unordered_map< ID, LoopVariability > > LoopVariUniqVarMap
Definition: find_loop_variance.h:23
auto makeMetadata(const std::string &op, Srcs &&...sourceStmts)
Definition: ast.h:315
std::unordered_map< StmtOrExprID, std::unordered_map< ID, LoopVariability > > LoopVariExprMap
Definition: find_loop_variance.h:26
Stmt makeAssert(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:394
Ref< StmtNode > Stmt
Definition: ast.h:152
Stmt makeStmtSeq(Tstmts &&stmts, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:51
Stmt makeIf(Tcond &&cond, Tthen &&thenCase, Telse &&elseCase, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:354
Stmt makeFor(const std::string &iter, Tbegin &&begin, Tend &&end, Tstep &&step, Tlen &&len, Tproperty &&property, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:311
Stmt blend(const Stmt &ast, const ID &loop)
Definition: blend.cc:166