FreeTensor
Loading...
Searching...
No Matches
shrink_var.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_SHRINK_VAR_H
2#define FREE_TENSOR_SHRINK_VAR_H
3
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include <container_utils.h>
#include <func.h>
#include <mutator.h>
10
11namespace freetensor {
12
20class ShrinkVar : public Mutator {
21 // Bound considering the old shape. Used for preventing make the shape even
22 // larger after shrinking
23 const std::unordered_map<ID, AccessBound> &newRangeWithShape_;
24
25 // Bound without considering the old shape. Used for preventing redundant
26 // guards for maybe-unsafe user code
27 const std::unordered_map<ID, AccessBound> &newRangeWithoutShape_;
28
29 bool guardReads_;
30
31 std::unordered_map<std::string, std::vector<Expr>> lowerWithShape_,
32 upperWithShape_;
33 std::unordered_map<std::string, std::vector<Expr>> lowerWithoutShape_,
34 upperWithoutShape_;
35 std::unordered_map<ID, Expr> guards_;
36
37 public:
38 ShrinkVar(const std::unordered_map<ID, AccessBound> &newRangeWithShape,
39 const std::unordered_map<ID, AccessBound> &newRangeWithoutShape,
40 bool guardReads = false)
41 : newRangeWithShape_(newRangeWithShape),
42 newRangeWithoutShape_(newRangeWithoutShape), guardReads_(guardReads) {
43 }
44
45 private:
46 template <class T> T modifyAccess(const T &op) {
47 if (lowerWithoutShape_.count(op->var_)) {
48 auto &&offset = lowerWithoutShape_.at(op->var_);
49 ASSERT(offset.size() == op->indices_.size());
50 for (auto &&[idx, off] : views::zip(op->indices_, offset)) {
51 if (off.isValid()) {
52 idx = makeSub(idx, off);
53 }
54 }
55 }
56 return op;
57 }
58
59 template <class T> void addGuard(const T &oldOp, const T &op) {
60 // We add check w.r.t oldOp because it is simplier, which brings less
61 // redundancy to pass/simplify
62 Expr guard;
63 if (upperWithoutShape_.count(op->var_)) {
64 auto &&upper = upperWithoutShape_.at(op->var_);
65 ASSERT(upper.size() == op->indices_.size());
66 for (auto &&[idx, u] : views::zip(oldOp->indices_, upper)) {
67 if (u.isValid()) {
68 guard = guard.isValid() ? makeLAnd(guard, makeLE(idx, u))
69 : makeLE(idx, u);
70 }
71 }
72 }
73 if (lowerWithoutShape_.count(op->var_)) {
74 auto &&lower = lowerWithoutShape_.at(op->var_);
75 ASSERT(lower.size() == op->indices_.size());
76 for (auto &&[idx, l] : views::zip(oldOp->indices_, lower)) {
77 if (l.isValid()) {
78 guard = guard.isValid() ? makeLAnd(guard, makeGE(idx, l))
79 : makeGE(idx, l);
80 }
81 }
82 }
83 if (guard.isValid()) {
84 Stmt s;
85 if constexpr (std::is_base_of_v<StmtNode, typename T::Object>) {
86 s = oldOp;
87 } else if constexpr (std::is_base_of_v<ExprNode,
88 typename T::Object>) {
89 s = oldOp->parentStmt();
90 } else {
91 ASSERT(false);
92 }
93 guards_[s->id()] = guards_[s->id()].isValid()
94 ? makeLAnd(guards_[s->id()], guard)
95 : guard;
96 }
97 }
98
99 protected:
100 Stmt visitStmt(const Stmt &s) override;
101 Stmt visit(const VarDef &op) override;
102 Expr visit(const Load &op) override;
103 Stmt visit(const Store &op) override;
104 Stmt visit(const ReduceTo &op) override;
105};
106
117Stmt shrinkVar(const Stmt &op);
118Stmt shrinkSingleVar(const Stmt &op, const ID &varDefId);
122
123} // namespace freetensor
124
125#endif // FREE_TENSOR_SHRINK_VAR_H
Definition: mutator.h:11
Definition: shrink_var.h:20
Stmt visitStmt(const Stmt &s) override
Definition: shrink_var.cc:12
Stmt visit(const VarDef &op) override
Definition: shrink_var.cc:20
ShrinkVar(const std::unordered_map< ID, AccessBound > &newRangeWithShape, const std::unordered_map< ID, AccessBound > &newRangeWithoutShape, bool guardReads=false)
Definition: shrink_var.h:38
Ref< StmtNode > parentStmt() const
Definition: ast.cc:103
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
Ref< VarDefNode > VarDef
Definition: stmt.h:107
T lower(const T &_ast, const Ref< Target > &_target=nullptr, const std::unordered_set< std::string > &skipPasses={}, int verbose=0)
Definition: lower.h:53
Ref< LoadNode > Load
Definition: expr.h:61
Ref< StoreNode > Store
Definition: stmt.h:140
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Expr makeLE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:382
Stmt shrinkVar(const Stmt &op)
Definition: shrink_var.cc:102
Ref< StmtNode > Stmt
Definition: ast.h:152
Stmt shrinkSingleVar(const Stmt &op, const ID &varDefId)
Definition: shrink_var.cc:128
Expr makeLAnd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:452
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Expr makeGE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:410
Ref< ExprNode > Expr
Definition: ast.h:184