FreeTensor
Loading...
Searching...
No Matches
fuse.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_FUSE_H
2#define FREE_TENSOR_FUSE_H
3
4#include <mutator.h>
5#include <visitor.h>
6
7namespace freetensor {
8
11 std::vector<Stmt> scopes_; // inner to outer
12};
13
/**
 * Direction selector associated with LoopInScopes lookups.
 *
 * NOTE(review): the name suggests it chooses which end of a scope list a
 * loop search starts from; the consuming code is not visible here — TODO
 * confirm in fuse.cc. Enumerator values are spelled out so the numeric
 * encoding (Front = 0, Back = 1) is explicit.
 */
enum class FindLoopInScopesDirection : int {
    Front = 0,
    Back = 1,
};
15
16class FuseFor : public Mutator {
17 Stmt root_;
18 ID id0_, id1_, fused_;
19 std::string iter0_, iter1_, newIter_;
20 ID beforeId_, afterId_;
21 Expr begin0_, begin1_, step0_, step1_;
22 bool strict_, inLoop0_ = false, inLoop1_ = false;
23
24 public:
25 FuseFor(const Stmt &root, const ID &id0, const ID &id1,
26 const std::string &newIter, bool strict)
27 : root_(root), id0_(id0), id1_(id1), newIter_(newIter),
28 strict_(strict) {}
29
30 const ID &fused() const { return fused_; }
31 const ID &beforeId() const { return beforeId_; }
32 const ID &afterId() const { return afterId_; }
33
34 protected:
35 Expr visit(const Var &op) override;
36 Stmt visit(const For &op) override;
37 Stmt visit(const StmtSeq &op) override;
38};
39
41 ID id0_, id1_;
42 LoopInScopes loop0InScopes_, loop1InScopes_;
43
44 public:
45 CheckFuseAccessible(const ID &id0, const ID &id1) : id0_(id0), id1_(id1) {}
46
47 const LoopInScopes &loop0() const { return loop0InScopes_; }
48 const LoopInScopes &loop1() const { return loop1InScopes_; }
49
50 void check(const Stmt &ast);
51
52 protected:
// Visitor hook for statement sequences; defined out of line in fuse.cc.
// NOTE(review): presumably where loop0InScopes_/loop1InScopes_ get filled
// for id0_/id1_ — TODO confirm against the implementation.
53 void visit(const StmtSeq &op) override;
54};
55
56std::pair<Stmt, ID> fuse(const Stmt &ast, const ID &loop0, const ID &loop1,
57 bool strict);
58
59} // namespace freetensor
60
61#endif // FREE_TENSOR_FUSE_H
CheckFuseAccessible
Definition: fuse.h:40
const LoopInScopes & loop1() const
Definition: fuse.h:48
void check(const Stmt &ast)
Definition: fuse.cc:242
void visit(const StmtSeq &op) override
Definition: fuse.cc:217
const LoopInScopes & loop0() const
Definition: fuse.h:47
CheckFuseAccessible(const ID &id0, const ID &id1)
Definition: fuse.h:45
FuseFor
Definition: fuse.h:16
FuseFor(const Stmt &root, const ID &id0, const ID &id1, const std::string &newIter, bool strict)
Definition: fuse.h:25
const ID & beforeId() const
Definition: fuse.h:31
const ID & afterId() const
Definition: fuse.h:32
const ID & fused() const
Definition: fuse.h:30
Expr visit(const Var &op) override
Definition: fuse.cc:101
Definition: id.h:18
Definition: mutator.h:11
Definition: visitor.h:11
Definition: allocator.h:9
FindLoopInScopesDirection
Definition: fuse.h:14
Ref< StmtNode > Stmt
Definition: ast.h:152
std::pair< Stmt, ID > fuse(const Stmt &ast, const ID &loop0, const ID &loop1, bool strict)
Definition: fuse.cc:250
LoopInScopes
Definition: fuse.h:9
For loop_
Definition: fuse.h:10
std::vector< Stmt > scopes_
Definition: fuse.h:11