FreeTensor
Loading...
Searching...
No Matches
fission.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_FISSION_H
2#define FREE_TENSOR_FISSION_H
3
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <container_utils.h>
#include <mutator.h>
12
13namespace freetensor {
14
/// Which side of the splitter statement a loop fission cuts at.
enum class FissionSide : int { Before, After };

/// Print a FissionSide: FissionSide::Before prints "before"; every other
/// value prints "after".
inline std::ostream &operator<<(std::ostream &os, FissionSide side) {
    const char *text = "after";
    if (side == FissionSide::Before) {
        text = "before";
    }
    return os << text;
}
19
20class AddDimToVar : public SymbolTable<Mutator> {
22
23 // VarDef ID -> [for ID], from outer to inner
24 std::unordered_map<ID, std::vector<ID>> toAdd_;
25 // for ID -> For
26 std::unordered_map<ID, For> forMap_;
27
28 public:
29 AddDimToVar(const std::unordered_map<ID, std::vector<ID>> &toAdd)
30 : toAdd_(toAdd) {}
31
32 private:
33 template <class T> T doAdd(T op) {
34 if (toAdd_.count(def(op->var_)->id())) {
35 for (auto &&loop : views::reverse(toAdd_.at(def(op->var_)->id()))) {
36 op->indices_.insert(
37 op->indices_.begin(),
38 makeFloorDiv(makeSub(makeVar(forMap_.at(loop)->iter_),
39 forMap_.at(loop)->begin_),
40 forMap_.at(loop)->step_));
41 }
42 }
43 return op;
44 }
45
46 protected:
47 Stmt visit(const For &op) override;
48 Stmt visit(const VarDef &op) override;
49 Stmt visit(const Store &op) override;
50 Stmt visit(const ReduceTo &op) override;
51 Expr visit(const Load &op) override;
52};
53
54class FissionFor : public Mutator {
55 ID loop_;
56 ID before_, after_;
57 std::optional<std::string> op0_, op1_;
58 std::unordered_map<ID, ID> ids0_, ids1_;
59 std::unordered_set<std::string> varUses_;
60 bool inside_ = false, isPart0_ = true, anyInside_ = false, isAfter_ = false;
61
62 public:
63 FissionFor(const ID &loop, const ID &before, const ID &after,
64 const std::string &suffix0, const std::string &suffix1)
65 : loop_(loop), before_(before), after_(after),
66 op0_(suffix0.empty() ? std::nullopt
67 : std::optional{"fission" + suffix0}),
68 op1_(suffix1.empty() ? std::nullopt
69 : std::optional{"fission" + suffix1}) {}
70
71 const std::unordered_map<ID, ID> &ids0() const { return ids0_; }
72 const std::unordered_map<ID, ID> &ids1() const { return ids1_; }
73
74 private:
75 void markNewId(const Stmt &op, bool isPart0);
76
77 bool inPart() const {
78 return inside_ && ((isPart0_ && !isAfter_) || (!isPart0_ && isAfter_));
79 }
80
81 protected:
82 Stmt visitStmt(const Stmt &op) override;
83 Stmt visit(const For &op) override;
84 Stmt visit(const StmtSeq &op) override;
85 Stmt visit(const VarDef &op) override;
86 Stmt visit(const Store &op) override;
87 Expr visit(const Load &op) override;
88 Stmt visit(const ReduceTo &op) override;
89 Stmt visit(const If &op) override;
90 Stmt visit(const Assert &op) override;
91};
92
93std::pair<Stmt,
94 std::pair<std::unordered_map<ID, ID>, std::unordered_map<ID, ID>>>
95fission(const Stmt &ast, const ID &loop, FissionSide side, const ID &splitter,
96 bool allowEnlarge, const std::string &suffix0,
97 const std::string &suffix1);
98
99} // namespace freetensor
100
101#endif // FREE_TENSOR_FISSION_H
Definition: fission.h:20
AddDimToVar(const std::unordered_map< ID, std::vector< ID > > &toAdd)
Definition: fission.h:29
Stmt visit(const For &op) override
Definition: fission.cc:13
Definition: fission.h:54
FissionFor(const ID &loop, const ID &before, const ID &after, const std::string &suffix0, const std::string &suffix1)
Definition: fission.h:63
Stmt visitStmt(const Stmt &op) override
Definition: fission.cc:51
const std::unordered_map< ID, ID > & ids1() const
Definition: fission.h:72
const std::unordered_map< ID, ID > & ids0() const
Definition: fission.h:71
Stmt visit(const For &op) override
Definition: fission.cc:85
Definition: id.h:18
Definition: mutator.h:11
ID id() const
Definition: ast.cc:362
Definition: symbol_table.h:122
const For & loop(const std::string &name) const override
Definition: symbol_table.h:152
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
Definition: allocator.h:9
Ref< VarDefNode > VarDef
Definition: stmt.h:107
FissionSide
Definition: fission.h:15
Ref< LoadNode > Load
Definition: expr.h:61
Ref< StoreNode > Store
Definition: stmt.h:140
Expr makeFloorDiv(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:239
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< IfNode > If
Definition: stmt.h:352
Ref< ForNode > For
Definition: stmt.h:308
Ref< StmtNode > Stmt
Definition: ast.h:152
std::pair< Stmt, std::pair< std::unordered_map< ID, ID >, std::unordered_map< ID, ID > > > fission(const Stmt &ast, const ID &loop, FissionSide side, const ID &splitter, bool allowEnlarge, const std::string &suffix0, const std::string &suffix1)
Definition: fission.cc:182
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Ref< ExprNode > Expr
Definition: ast.h:184
Ref< AssertNode > Assert
Definition: stmt.h:392
Ref< StmtSeqNode > StmtSeq
Definition: stmt.h:49
STL namespace.