1#ifndef FREE_TENSOR_FISSION_H
2#define FREE_TENSOR_FISSION_H
6#include <unordered_map>
7#include <unordered_set>
24 std::unordered_map<ID, std::vector<ID>> toAdd_;
26 std::unordered_map<ID, For> forMap_;
33 template <
class T> T doAdd(T op) {
34 if (toAdd_.count(
def(op->var_)->
id())) {
35 for (
auto &&
loop : views::reverse(toAdd_.at(
def(op->var_)->
id()))) {
39 forMap_.at(
loop)->begin_),
40 forMap_.at(
loop)->step_));
57 std::optional<std::string> op0_, op1_;
58 std::unordered_map<ID, ID> ids0_, ids1_;
59 std::unordered_set<std::string> varUses_;
60 bool inside_ =
false, isPart0_ =
true, anyInside_ =
false, isAfter_ =
false;
64 const std::string &suffix0,
const std::string &suffix1)
65 : loop_(loop), before_(before), after_(after),
66 op0_(suffix0.empty() ?
std::nullopt
67 :
std::optional{
"fission" + suffix0}),
68 op1_(suffix1.empty() ?
std::nullopt
69 :
std::optional{
"fission" + suffix1}) {}
71 const std::unordered_map<ID, ID> &
ids0()
const {
return ids0_; }
72 const std::unordered_map<ID, ID> &
ids1()
const {
return ids1_; }
75 void markNewId(
const Stmt &op,
bool isPart0);
78 return inside_ && ((isPart0_ && !isAfter_) || (!isPart0_ && isAfter_));
94 std::pair<std::unordered_map<ID, ID>, std::unordered_map<ID, ID>>>
96 bool allowEnlarge,
const std::string &suffix0,
97 const std::string &suffix1);
AddDimToVar(const std::unordered_map< ID, std::vector< ID > > &toAdd)
Definition: fission.h:29
Stmt visit(const For &op) override
Definition: fission.cc:13
FissionFor(const ID &loop, const ID &before, const ID &after, const std::string &suffix0, const std::string &suffix1)
Definition: fission.h:63
Stmt visitStmt(const Stmt &op) override
Definition: fission.cc:51
const std::unordered_map< ID, ID > & ids1() const
Definition: fission.h:72
const std::unordered_map< ID, ID > & ids0() const
Definition: fission.h:71
Stmt visit(const For &op) override
Definition: fission.cc:85
ID id() const
Definition: ast.cc:362
Definition: symbol_table.h:122
const For & loop(const std::string &name) const override
Definition: symbol_table.h:152
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
Definition: allocator.h:9
Ref< VarDefNode > VarDef
Definition: stmt.h:107
FissionSide
Definition: fission.h:15
Ref< LoadNode > Load
Definition: expr.h:61
Ref< StoreNode > Store
Definition: stmt.h:140
Expr makeFloorDiv(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:239
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< IfNode > If
Definition: stmt.h:352
Ref< ForNode > For
Definition: stmt.h:308
Ref< StmtNode > Stmt
Definition: ast.h:152
std::pair< Stmt, std::pair< std::unordered_map< ID, ID >, std::unordered_map< ID, ID > > > fission(const Stmt &ast, const ID &loop, FissionSide side, const ID &splitter, bool allowEnlarge, const std::string &suffix0, const std::string &suffix1)
Definition: fission.cc:182
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Ref< ExprNode > Expr
Definition: ast.h:184
Ref< AssertNode > Assert
Definition: stmt.h:392
Ref< StmtSeqNode > StmtSeq
Definition: stmt.h:49