FreeTensor
Loading...
Searching...
No Matches
comp_transient_bounds.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H
2#define FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H
3
4#include <type_traits>
5#include <unordered_set>
6
7#include <analyze/all_uses.h>
9#include <analyze/as_dnf.h>
11#include <container_utils.h>
12#include <hash.h>
13#include <math/bounds.h>
14#include <maybe_void.h>
15#include <stmt.h>
16
17namespace freetensor {
18
21 std::vector<Expr> lower_, upper_;
22};
23
25 public:
26 virtual TransientBound transient(const Expr &op) const = 0;
27 virtual const std::vector<Expr> &conds() const = 0;
28};
29
48template <class BaseClass>
49class CompTransientBounds : public BaseClass,
51 // Bounds related to certain expressions
52 // Bounds in transients_ has already been recursed with (*this)(...)
54
55 // Original bounds
56 std::vector<Expr> conds_;
57
58 public:
59 TransientBound transient(const Expr &op) const override {
60 if (transients_.count(op)) {
61 return transients_.at(op);
62 }
63 return {};
64 }
65
66 const std::vector<Expr> &conds() const override { return conds_; }
67
68 private:
69 void applyCond(const Expr &_cond,
70 const std::unordered_set<std::string> &bodyAllWrites) {
71 auto dnf = asDNF(_cond);
72
73 if (dnf.size() != 1) {
74 // Currently `transients_` cannot handle OR, leave it as-is to
75 // `conds_`. But ignore if the condition is marked as `unbound`
76 for (auto &&item : dnf) {
77 for (auto &&sub : item) {
78 if (sub->nodeType() == ASTNodeType::Unbound) {
79 return;
80 }
81 }
82 }
83 conds_.emplace_back(_cond);
84 return;
85 }
86
87 for (auto &&cond : dnf.front()) {
88 if (cond->nodeType() == ASTNodeType::Unbound) {
89 continue;
90 }
91
92 if (hasIntersect(allReads(cond), bodyAllWrites)) {
93 continue;
94 }
95
96 auto norm = linearComp(cond);
97 if (!norm.has_value()) {
98 continue;
99 }
100
101 auto &&[lin, type] = *norm;
102 if (!isInt(lin2expr(lin)->dtype())) {
103 continue;
104 }
105
106 for (auto &&[k, a] : lin.coeff_) {
107 if (a->nodeType() == ASTNodeType::Var ||
108 a->nodeType() == ASTNodeType::Load) {
109 auto [lower, upper] = lin2bounds(lin, type, a);
110 if (lower.has_value()) {
111 transients_[a].expr_ = a;
112 transients_[a].lower_.emplace_back(lower->expr());
113 }
114 if (upper.has_value()) {
115 transients_[a].expr_ = a;
116 transients_[a].upper_.emplace_back(upper->expr());
117 }
118 }
119 }
120
121 conds_.emplace_back(cond);
122 }
123 }
124
125 protected:
126 using BaseClass::visit; // Avoid hiding virtual functions
127
128 typename BaseClass::StmtRetType visit(const For &op) override {
129 MAYBE_VOID(begin, (*this)(op->begin_));
130 MAYBE_VOID(end, (*this)(op->end_));
131 MAYBE_VOID(step, (*this)(op->step_));
132 MAYBE_VOID(len, (*this)(op->len_));
133
134 auto var = makeVar(op->iter_);
135 if (transients_.count(var)) {
136 throw InvalidProgram(
137 "iterators with the same name in nested loops are not allowed");
138 }
139 auto oldCondsSize = conds_.size();
140 if (op->step_->nodeType() == ASTNodeType::IntConst) {
141 auto step = op->step_.as<IntConstNode>()->val_;
142 if (step > 0) {
143 transients_[var] = {
144 var, {op->begin_}, {makeSub(op->end_, makeIntConst(1))}};
145 conds_.emplace_back(makeGE(var, op->begin_));
146 conds_.emplace_back(makeLT(var, op->end_));
147 conds_.emplace_back(
148 makeEQ(makeMod(makeSub(var, op->begin_), op->step_),
149 makeIntConst(0)));
150 } else if (step < 0) {
151 transients_[var] = {
152 var, {makeAdd(op->end_, makeIntConst(1))}, {op->begin_}};
153 conds_.emplace_back(makeLE(var, op->begin_));
154 conds_.emplace_back(makeGT(var, op->end_));
155 conds_.emplace_back(
156 makeEQ(makeMod(makeSub(var, op->begin_), op->step_),
157 makeIntConst(0)));
158 } else {
159 transients_[var] = {var, {op->begin_}, {op->begin_}};
160 conds_.emplace_back(makeEQ(var, op->begin_));
161 }
162 }
163 if constexpr (std::is_base_of_v<SymbolTableInterface, BaseClass>) {
164 this->pushFor(op);
165 }
166 MAYBE_VOID(body, (*this)(op->body_));
167 if constexpr (std::is_base_of_v<SymbolTableInterface, BaseClass>) {
168 this->popFor(op);
169 }
170 conds_.resize(oldCondsSize);
171 transients_.erase(var);
172
173 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
174 auto property = Ref<ForProperty>::make()
175 ->withParallel(op->property_->parallel_)
176 ->withUnroll(op->property_->unroll_)
177 ->withVectorize(op->property_->vectorize_)
178 ->withNoDeps(op->property_->noDeps_)
179 ->withPreferLibs(op->property_->preferLibs_);
180 property->reductions_.reserve(op->property_->reductions_.size());
181 for (auto &&r : op->property_->reductions_) {
182 std::vector<Expr> begins, ends;
183 begins.reserve(r->begins_.size());
184 ends.reserve(r->ends_.size());
185 for (auto &&item : r->begins_) {
186 begins.emplace_back((*this)(item));
187 }
188 for (auto &&item : r->ends_) {
189 ends.emplace_back((*this)(item));
190 }
191 property->reductions_.emplace_back(
192 makeReductionItem(r->op_, r->var_, std::move(begins),
193 std::move(ends), r->syncFlush_));
194 }
195 return makeFor(op->iter_, std::move(begin), std::move(end),
196 std::move(step), std::move(len), std::move(property),
197 std::move(body), op->metadata(), op->id(),
198 op->debugBlame());
199 }
200 }
201
202 typename BaseClass::StmtRetType visit(const If &op) override {
203 MAYBE_VOID(cond, (*this)(op->cond_));
204
205 auto oldMap = transients_;
206 auto oldCondsSize = conds_.size();
207 applyCond(op->cond_, allWrites(op->thenCase_));
208 MAYBE_VOID(thenCase, (*this)(op->thenCase_));
209 transients_ = oldMap;
210 conds_.resize(oldCondsSize);
211
212 [[maybe_unused]] Stmt elseCase = nullptr;
213 if (op->elseCase_.isValid()) {
214 auto oldCondsSize = conds_.size();
215 applyCond(makeLNot(op->cond_), allWrites(op->elseCase_));
216 MAYBE_VOID_ASSIGN(elseCase, (*this)(op->elseCase_));
217 transients_ = oldMap;
218 conds_.resize(oldCondsSize);
219 }
220
221 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
222 return makeIf(std::move(cond), std::move(thenCase),
223 std::move(elseCase), op->metadata(), op->id(),
224 op->debugBlame());
225 }
226 }
227
228 typename BaseClass::StmtRetType visit(const Assert &op) override {
229 MAYBE_VOID(cond, (*this)(op->cond_));
230
231 auto oldMap = transients_;
232 auto oldCondsSize = conds_.size();
233 applyCond(op->cond_, allWrites(op->body_));
234 MAYBE_VOID(body, (*this)(op->body_));
235 transients_ = oldMap;
236 conds_.resize(oldCondsSize);
237
238 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
239 return makeAssert(std::move(cond), std::move(body), op->metadata(),
240 op->id(), op->debugBlame());
241 }
242 }
243
244 typename BaseClass::StmtRetType visit(const Assume &op) override {
245 MAYBE_VOID(cond, (*this)(op->cond_));
246
247 auto oldMap = transients_;
248 auto oldCondsSize = conds_.size();
249 applyCond(op->cond_, allWrites(op->body_));
250 MAYBE_VOID(body, (*this)(op->body_));
251 transients_ = oldMap;
252 conds_.resize(oldCondsSize);
253
254 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
255 return makeAssume(std::move(cond), std::move(body), op->metadata(),
256 op->id(), op->debugBlame());
257 }
258 }
259};
260
261} // namespace freetensor
262
263#endif // FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H
std::source_location debugBlame() const
Definition: ast.h:134
Definition: comp_transient_bounds.h:24
virtual TransientBound transient(const Expr &op) const =0
virtual const std::vector< Expr > & conds() const =0
Definition: comp_transient_bounds.h:50
BaseClass::StmtRetType visit(const Assume &op) override
Definition: comp_transient_bounds.h:244
TransientBound transient(const Expr &op) const override
Definition: comp_transient_bounds.h:59
BaseClass::StmtRetType visit(const For &op) override
Definition: comp_transient_bounds.h:128
const std::vector< Expr > & conds() const override
Definition: comp_transient_bounds.h:66
BaseClass::StmtRetType visit(const Assert &op) override
Definition: comp_transient_bounds.h:228
BaseClass::StmtRetType visit(const If &op) override
Definition: comp_transient_bounds.h:202
SubTree< ForProperty > property_
Definition: stmt.h:298
SubTree< ExprNode > begin_
Definition: stmt.h:294
SubTree< ExprNode > step_
Definition: stmt.h:296
SubTree< ExprNode > len_
Definition: stmt.h:297
SubTree< ExprNode > end_
Definition: stmt.h:295
std::string iter_
Definition: stmt.h:289
SubTree< StmtNode > body_
Definition: stmt.h:299
Definition: expr.h:93
Definition: except.h:83
static Ref make()
Definition: ref.h:105
bool isValid() const
Definition: ref.h:89
const Metadata & metadata() const
Definition: ast.h:233
ID id() const
Definition: ast.cc:362
#define MAYBE_VOID_ASSIGN(name, expr)
Definition: maybe_void.h:15
#define MAYBE_VOID(name, expr)
Definition: maybe_void.h:25
Definition: allocator.h:9
std::unordered_map< K, V, Hasher, HashComparator > ASTHashMap
Definition: hash.h:114
T lower(const T &_ast, const Ref< Target > &_target=nullptr, const std::unordered_set< std::string > &skipPasses={}, int verbose=0)
Definition: lower.h:53
Expr makeLT(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:368
std::pair< std::optional< LowerBound >, std::optional< UpperBound > > lin2bounds(const LinearExpr< T > &_lin, ASTNodeType cmp, const Expr &x)
Definition: bounds.h:84
Stmt makeAssume(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:427
Ref< ReductionItem > makeReductionItem(ReduceOp op, const std::string &var, Tbegins &&begins, Tends &&ends, bool syncFlush)
Definition: for_property.h:26
std::unordered_set< std::string > allWrites(const AST &op, bool noRecurseIdx=false, bool noRecurseSubStmt=false)
Definition: all_uses.h:100
Expr makeGT(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:396
Expr makeAdd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:174
UpperBound sub(const UpperBound &b1, const LowerBound &b2)
Definition: bounds.cc:200
Stmt makeAssert(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:394
Expr lin2expr(const LinearExpr< T > &lin)
Definition: linear.h:130
DNF asDNF(const Expr &expr)
Definition: as_dnf.cc:114
Expr makeMod(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:303
Expr makeLE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:382
Expr makeEQ(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:424
std::unordered_set< std::string > allReads(const AST &op, bool noRecurseIdx=false, bool noRecurseSubStmt=false)
Definition: all_uses.h:83
Stmt makeIf(Tcond &&cond, Tthen &&thenCase, Telse &&elseCase, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:354
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
bool hasIntersect(const std::unordered_set< T, Hash, KeyEqual > &lhs, const std::unordered_set< T, Hash, KeyEqual > &rhs)
Definition: container_utils.h:51
Expr makeGE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:410
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
Stmt makeFor(const std::string &iter, Tbegin &&begin, Tend &&end, Tstep &&step, Tlen &&len, Tproperty &&property, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:311
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
Expr makeLNot(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:490
bool isInt(BaseDataType dtype)
Definition: data_type.cc:27
std::optional< std::pair< LinearExpr< int64_t >, ASTNodeType > > linearComp(const Expr &expr)
Definition: analyze_linear.cc:56
Definition: comp_transient_bounds.h:19
std::vector< Expr > upper_
Definition: comp_transient_bounds.h:21
std::vector< Expr > lower_
Definition: comp_transient_bounds.h:21
Expr expr_
Definition: comp_transient_bounds.h:20