FreeTensor
Loading...
Searching...
No Matches
remove_writes.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_REMOVE_WRITES_H
2#define FREE_TENSOR_REMOVE_WRITES_H
3
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <func.h>
#include <mutator.h>
#include <visitor.h>
12
13namespace freetensor {
14
15class FindLoopInvariantWrites : public SymbolTable<Visitor> {
17
18 std::vector<For> loopStack_;
19 std::vector<If> ifStack_;
20 std::unordered_map<Store, std::tuple<VarDef, Expr, For>>
21 results_;
22 std::unordered_map<std::string, int> defDepth_;
23 const LoopVariExprMap &variantExpr_;
24 ID singleDefId_;
25
26 public:
28 const ID &singleDefId)
29 : variantExpr_(variantExpr), singleDefId_(singleDefId) {}
30
31 const std::unordered_map<Store, std::tuple<VarDef, Expr, For>> &
32 results() const {
33 return results_;
34 }
35
36 protected:
37 using BaseClass::visit;
38 void visit(const For &op) override;
39 void visit(const If &op) override;
40 void visit(const VarDef &op) override;
41 void visit(const Store &op) override;
42};
43
44class RemoveWrites : public Mutator {
45 const std::unordered_set<Stmt> &redundant_;
46 const std::unordered_map<Stmt, Stmt> &replacement_;
47
48 public:
49 RemoveWrites(const std::unordered_set<Stmt> &redundant,
50 const std::unordered_map<Stmt, Stmt> &replacement)
51 : redundant_(redundant), replacement_(replacement) {}
52
53 template <class T> Stmt doVisit(const T &op) {
54 if (redundant_.count(op)) {
55 return makeStmtSeq({}, op->metadata(), op->id());
56 } else if (replacement_.count(op)) {
57 return replacement_.at(op);
58 } else {
59 return Mutator::visit(op);
60 }
61 }
62
63 protected:
64 Stmt visit(const Store &op) override { return doVisit(op); }
65 Stmt visit(const ReduceTo &op) override { return doVisit(op); }
66 Stmt visit(const StmtSeq &op) override;
67 Stmt visit(const For &op) override;
68 Stmt visit(const If &op) override;
69};
70
105Stmt removeWrites(const Stmt &op, const ID &singleDefId = {});
106
108
109} // namespace freetensor
110
111#endif // FREE_TENSOR_REMOVE_WRITES_H
FindLoopInvariantWrites
Definition: remove_writes.h:15
void visit(const For &op) override
Definition: remove_writes.cc:75
FindLoopInvariantWrites(const LoopVariExprMap &variantExpr, const ID &singleDefId)
Definition: remove_writes.h:27
const std::unordered_map< Store, std::tuple< VarDef, Expr, For > > & results() const
Definition: remove_writes.h:32
Definition: id.h:18
Mutator
Definition: mutator.h:11
virtual Stmt visit(const Any &op)
Definition: mutator.h:39
RemoveWrites
Definition: remove_writes.h:44
Stmt visit(const Store &op) override
Definition: remove_writes.h:64
RemoveWrites(const std::unordered_set< Stmt > &redundant, const std::unordered_map< Stmt, Stmt > &replacement)
Definition: remove_writes.h:49
Stmt doVisit(const T &op)
Definition: remove_writes.h:53
Stmt visit(const ReduceTo &op) override
Definition: remove_writes.h:65
SymbolTable
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
std::unordered_map< StmtOrExprID, std::unordered_map< ID, LoopVariability > > LoopVariExprMap
Definition: find_loop_variance.h:26
Ref< StmtNode > Stmt
Definition: ast.h:152
Stmt makeStmtSeq(Tstmts &&stmts, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:51
Stmt removeWrites(const Stmt &op, const ID &singleDefId={})
Definition: remove_writes.cc:185