FreeTensor
Loading...
Searching...
No Matches
comp_access_bound.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_COMP_ACCESS_BOUND_H
2#define FREE_TENSOR_COMP_ACCESS_BOUND_H
3
4#include <memory>
5#include <unordered_map>
6#include <unordered_set>
7
10#include <math/bounds.h>
11#include <visitor.h>
12
13namespace freetensor {
14
// Access bound of one variable, as exposed by CompAccessBound::result() and
// returned by compAccessBound(). Each vector presumably holds one entry per
// dimension of the variable -- TODO confirm against comp_access_bound.cc.
// NOTE(review): this listing omits the struct header (source line 15,
// presumably `struct AccessBound {`) and source line 19 (`Expr cond_;` per
// the Doxygen index).
// NOTE(review): in the generated docs the member descriptions appear shifted
// by one member (upper_ shows "lower_bound(access)", len_ shows
// "upper_bound(access)", cond_ shows the len_ formula). This suggests the
// source uses trailing `///` comments, which Doxygen attaches to the NEXT
// member; trailing member docs should use `///<`.
16 std::vector<Expr> lower_; ///< lower_bound(access)
17 std::vector<Expr> upper_; ///< upper_bound(access)
18 std::vector<Expr> len_; ///< upper_bound(access) - lower_bound(access) + 1
20 // TODO: Ideally, len_ should be upper_bound(access_i - access_j) + 1, which
21 // supports shrinking a skewed variable, instead of upper_bound(access) -
22 // lower_bound(access). However, it brings too much burden on pass/simplify,
23 // so we do not choose it for now.
24};
25
31
32class FindMemType : public Visitor {
33 ID varDefId_;
34 MemType mtype_;
35
36 public:
37 FindMemType(const ID &varDefId) : varDefId_(varDefId) {}
38
39 MemType mtype() const { return mtype_; }
40
41 protected:
42 void visit(const VarDef &op) override;
43};
44
// Visitor that collects the accesses (Load / Store / ReduceTo, per the visit
// overrides below) to one variable -- the VarDef identified by `varDefId` --
// and summarizes them into an AccessBound, exposed through result().
// The traversal hooks are defined out of line (comp_access_bound.cc).
// NOTE(review): this listing omits source lines 46, 73, 88 and 98. From their
// use here and the Doxygen index, they presumably declare `BaseClass` (used by
// `using BaseClass::visit` below), the `mode_` member (initialized in the
// constructor), and the constructor parameter
// `CompAccessBoundMode mode = COMP_ACCESS_BOUND_ALL` -- TODO confirm.
45class CompAccessBound : public CompTransientBounds<SymbolTable<Visitor>> {
47
48 public:
// One recorded access to the target variable.
49 struct Access {
// Index expressions of the access, and the conditions attached to it
// (presumably the transient conditions in effect at the access site).
50 std::vector<Expr> indices_, conds_;
// Bound of each index expression; one entry per element of indices_, in order.
51 std::vector<Ref<CompUniqueBounds::Bound>> bounds_;
52
// Builds an Access whose per-index bounds come from `unique` and are then
// restricted via restrictScope() to the names in `names`.
53 Access(CompUniqueBounds &unique, const std::vector<Expr> &indices,
54 const std::vector<Expr> &conds,
55 const std::unordered_set<std::string> &names)
56 : indices_(indices), conds_(conds) {
57 for (auto &&idx : indices) {
58 bounds_.emplace_back(
59 unique.getBound(idx)->restrictScope(names));
60 }
61 }
62
// Same as above, but keeps the bounds from `unique` without restricting
// their scope.
63 Access(CompUniqueBounds &unique, const std::vector<Expr> &indices,
64 const std::vector<Expr> &conds)
65 : indices_(indices), conds_(conds) {
66 for (auto &&idx : indices) {
67 bounds_.emplace_back(unique.getBound(idx));
68 }
69 }
70 };
71
72 private:
74
75 // The variable to compute
76 ID varDefId_;
// Presumably the name of that variable, filled in when its VarDef is
// visited -- TODO confirm in comp_access_bound.cc.
77 std::string var_;
78 MemType mtype_;
79
80 // each access to the specific variable
81 std::vector<Access> access_;
82
83 // All names defined in the current scope
84 std::unordered_set<std::string> defs_;
// Presumably a snapshot of defs_ taken at each VarDef, keyed by variable
// name -- TODO confirm.
85 std::unordered_map<std::string, std::unordered_set<std::string>>
86 defsAtVarDef_;
87
// Mirrors compAccessBound's `includeTrivialBound` parameter; its exact effect
// is implemented in comp_access_bound.cc.
89 bool includeTrivialBound_;
90
// When filterSubTree_ is a valid ID, filtered_ starts false; otherwise it
// starts true (see the constructor). Presumably accesses are only collected
// once the traversal is inside the subtree with this ID -- TODO confirm.
91 ID filterSubTree_;
92 bool filtered_ = false;
93
94 AccessBound result_;
95
96 public:
// varDefId: ID of the VarDef of the variable to analyze.
// mtype: memory type of that variable.
// includeTrivialBound, filterSubTree: see the corresponding members above.
97 CompAccessBound(const ID &varDefId, MemType mtype,
99 bool includeTrivialBound = true,
100 const ID &filterSubTree = ID())
101 : varDefId_(varDefId), mtype_(mtype), mode_(mode),
102 includeTrivialBound_(includeTrivialBound),
103 filterSubTree_(filterSubTree) {
// No subtree filter given: behave as if already inside the filtered subtree.
104 if (!filterSubTree_.isValid()) {
105 filtered_ = true;
106 }
107 }
108
// The computed bound; read it after the visitor has been run.
109 const AccessBound &result() const { return result_; }
110
111 protected:
112 using BaseClass::visit;
// Traversal hooks, defined out of line.
113 void visitStmt(const Stmt &stmt) override;
114 void visit(const VarDef &op) override;
115 void visit(const Load &op) override;
116 void visit(const Store &op) override;
117 void visit(const ReduceTo &op) override;
118 void visit(const For &op) override;
119};
120
// Free-function entry point: presumably runs CompAccessBound over `op` and
// returns the AccessBound of the variable declared by the VarDef with ID
// `varDefId` (defined at comp_access_bound.cc:195).
// NOTE(review): source line 134 is missing from this listing. Per the Doxygen
// index, the full signature is
//   AccessBound compAccessBound(const Stmt &op, const ID &varDefId,
//                               CompAccessBoundMode mode = COMP_ACCESS_BOUND_ALL,
//                               bool includeTrivialBound = true,
//                               const ID &filterSubTree = ID());
// `mode` is a CompAccessBoundMode (an int) taking COMP_ACCESS_BOUND_READ,
// COMP_ACCESS_BOUND_WRITE or COMP_ACCESS_BOUND_ALL. It presumably selects which
// kinds of accesses are considered -- TODO confirm.
133AccessBound compAccessBound(const Stmt &op, const ID &varDefId,
135 bool includeTrivialBound = true,
136 const ID &filterSubTree = ID());
137
138} // namespace freetensor
139
140#endif // FREE_TENSOR_COMP_ACCESS_BOUND_H
Definition: comp_access_bound.h:45
CompAccessBound(const ID &varDefId, MemType mtype, CompAccessBoundMode mode=COMP_ACCESS_BOUND_ALL, bool includeTrivialBound=true, const ID &filterSubTree=ID())
Definition: comp_access_bound.h:97
void visit(const VarDef &op) override
Definition: comp_access_bound.cc:64
const AccessBound & result() const
Definition: comp_access_bound.h:109
void visitStmt(const Stmt &stmt) override
Definition: comp_access_bound.cc:48
Definition: comp_transient_bounds.h:50
BaseClass::StmtRetType visit(const For &op) override
Definition: comp_transient_bounds.h:128
const std::vector< Expr > & conds() const override
Definition: comp_transient_bounds.h:66
Definition: comp_unique_bounds.h:13
virtual Ref< Bound > getBound(const Expr &op)=0
Definition: comp_access_bound.h:32
MemType mtype() const
Definition: comp_access_bound.h:39
void visit(const VarDef &op) override
Definition: comp_access_bound.cc:41
FindMemType(const ID &varDefId)
Definition: comp_access_bound.h:37
Definition: id.h:18
bool isValid() const
Definition: id.h:33
const std::unordered_set< std::string > & names() const override
Definition: symbol_table.h:129
Definition: visitor.h:11
Definition: allocator.h:9
const CompAccessBoundMode COMP_ACCESS_BOUND_WRITE
Definition: comp_access_bound.h:28
const CompAccessBoundMode COMP_ACCESS_BOUND_READ
Definition: comp_access_bound.h:27
const CompAccessBoundMode COMP_ACCESS_BOUND_ALL
Definition: comp_access_bound.h:29
AccessBound compAccessBound(const Stmt &op, const ID &varDefId, CompAccessBoundMode mode=COMP_ACCESS_BOUND_ALL, bool includeTrivialBound=true, const ID &filterSubTree=ID())
Definition: comp_access_bound.cc:195
Ref< StmtNode > Stmt
Definition: ast.h:152
int CompAccessBoundMode
Definition: comp_access_bound.h:26
MemType
Definition: mem_type.h:14
Definition: comp_access_bound.h:15
std::vector< Expr > lower_
Definition: comp_access_bound.h:16
Expr cond_
upper_bound(access) - lower_bound(access) + 1
Definition: comp_access_bound.h:19
std::vector< Expr > len_
upper_bound(access)
Definition: comp_access_bound.h:18
std::vector< Expr > upper_
lower_bound(access)
Definition: comp_access_bound.h:17
Definition: comp_access_bound.h:49
Access(CompUniqueBounds &unique, const std::vector< Expr > &indices, const std::vector< Expr > &conds)
Definition: comp_access_bound.h:63
std::vector< Expr > conds_
Definition: comp_access_bound.h:50
Access(CompUniqueBounds &unique, const std::vector< Expr > &indices, const std::vector< Expr > &conds, const std::unordered_set< std::string > &names)
Definition: comp_access_bound.h:53
std::vector< Ref< CompUniqueBounds::Bound > > bounds_
Definition: comp_access_bound.h:51
std::vector< Expr > indices_
Definition: comp_access_bound.h:50