1#ifndef FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H
2#define FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H
5#include <unordered_set>
27 virtual const std::vector<Expr> &
conds()
const = 0;
48template <
class BaseClass>
56 std::vector<Expr> conds_;
60 if (transients_.count(op)) {
61 return transients_.at(op);
66 const std::vector<Expr> &
conds()
const override {
return conds_; }
69 void applyCond(
const Expr &_cond,
70 const std::unordered_set<std::string> &bodyAllWrites) {
71 auto dnf =
asDNF(_cond);
73 if (dnf.size() != 1) {
76 for (
auto &&item : dnf) {
77 for (
auto &&
sub : item) {
83 conds_.emplace_back(_cond);
87 for (
auto &&cond : dnf.front()) {
97 if (!norm.has_value()) {
101 auto &&[lin, type] = *norm;
106 for (
auto &&[k, a] : lin.coeff_) {
110 if (
lower.has_value()) {
111 transients_[a].expr_ = a;
112 transients_[a].lower_.emplace_back(
lower->expr());
114 if (upper.has_value()) {
115 transients_[a].expr_ = a;
116 transients_[a].upper_.emplace_back(upper->expr());
121 conds_.emplace_back(cond);
126 using BaseClass::visit;
128 typename BaseClass::StmtRetType
visit(
const For &op)
override {
135 if (transients_.count(var)) {
137 "iterators with the same name in nested loops are not allowed");
139 auto oldCondsSize = conds_.size();
150 }
else if (step < 0) {
163 if constexpr (std::is_base_of_v<SymbolTableInterface, BaseClass>) {
167 if constexpr (std::is_base_of_v<SymbolTableInterface, BaseClass>) {
170 conds_.resize(oldCondsSize);
171 transients_.erase(var);
173 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
177 ->withVectorize(op->
property_->vectorize_)
179 ->withPreferLibs(op->
property_->preferLibs_);
180 property->reductions_.reserve(op->
property_->reductions_.size());
181 for (
auto &&r : op->
property_->reductions_) {
182 std::vector<Expr> begins, ends;
183 begins.reserve(r->begins_.size());
184 ends.reserve(r->ends_.size());
185 for (
auto &&item : r->begins_) {
186 begins.emplace_back((*
this)(item));
188 for (
auto &&item : r->ends_) {
189 ends.emplace_back((*
this)(item));
191 property->reductions_.emplace_back(
193 std::move(ends), r->syncFlush_));
195 return makeFor(op->
iter_, std::move(begin), std::move(end),
196 std::move(step), std::move(len), std::move(property),
202 typename BaseClass::StmtRetType
visit(
const If &op)
override {
205 auto oldMap = transients_;
206 auto oldCondsSize = conds_.size();
207 applyCond(op->cond_,
allWrites(op->thenCase_));
209 transients_ = oldMap;
210 conds_.resize(oldCondsSize);
212 [[maybe_unused]]
Stmt elseCase =
nullptr;
214 auto oldCondsSize = conds_.size();
217 transients_ = oldMap;
218 conds_.resize(oldCondsSize);
221 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
222 return makeIf(std::move(cond), std::move(thenCase),
223 std::move(elseCase), op->metadata(), op->id(),
228 typename BaseClass::StmtRetType
visit(
const Assert &op)
override {
231 auto oldMap = transients_;
232 auto oldCondsSize = conds_.size();
233 applyCond(op->cond_,
allWrites(op->body_));
235 transients_ = oldMap;
236 conds_.resize(oldCondsSize);
238 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
239 return makeAssert(std::move(cond), std::move(body), op->metadata(),
240 op->id(), op->debugBlame());
244 typename BaseClass::StmtRetType
visit(
const Assume &op)
override {
247 auto oldMap = transients_;
248 auto oldCondsSize = conds_.size();
249 applyCond(op->cond_,
allWrites(op->body_));
251 transients_ = oldMap;
252 conds_.resize(oldCondsSize);
254 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
255 return makeAssume(std::move(cond), std::move(body), op->metadata(),
256 op->id(), op->debugBlame());
std::source_location debugBlame() const
Definition: ast.h:134
Definition: comp_transient_bounds.h:24
virtual TransientBound transient(const Expr &op) const =0
virtual const std::vector< Expr > & conds() const =0
Definition: comp_transient_bounds.h:50
BaseClass::StmtRetType visit(const Assume &op) override
Definition: comp_transient_bounds.h:244
TransientBound transient(const Expr &op) const override
Definition: comp_transient_bounds.h:59
BaseClass::StmtRetType visit(const For &op) override
Definition: comp_transient_bounds.h:128
const std::vector< Expr > & conds() const override
Definition: comp_transient_bounds.h:66
BaseClass::StmtRetType visit(const Assert &op) override
Definition: comp_transient_bounds.h:228
BaseClass::StmtRetType visit(const If &op) override
Definition: comp_transient_bounds.h:202
SubTree< ForProperty > property_
Definition: stmt.h:298
SubTree< ExprNode > begin_
Definition: stmt.h:294
SubTree< ExprNode > step_
Definition: stmt.h:296
SubTree< ExprNode > len_
Definition: stmt.h:297
SubTree< ExprNode > end_
Definition: stmt.h:295
std::string iter_
Definition: stmt.h:289
SubTree< StmtNode > body_
Definition: stmt.h:299
static Ref make()
Definition: ref.h:105
bool isValid() const
Definition: ref.h:89
const Metadata & metadata() const
Definition: ast.h:233
ID id() const
Definition: ast.cc:362
#define MAYBE_VOID_ASSIGN(name, expr)
Definition: maybe_void.h:15
#define MAYBE_VOID(name, expr)
Definition: maybe_void.h:25
Definition: allocator.h:9
std::unordered_map< K, V, Hasher, HashComparator > ASTHashMap
Definition: hash.h:114
T lower(const T &_ast, const Ref< Target > &_target=nullptr, const std::unordered_set< std::string > &skipPasses={}, int verbose=0)
Definition: lower.h:53
Expr makeLT(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:368
std::pair< std::optional< LowerBound >, std::optional< UpperBound > > lin2bounds(const LinearExpr< T > &_lin, ASTNodeType cmp, const Expr &x)
Definition: bounds.h:84
Stmt makeAssume(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:427
Ref< ReductionItem > makeReductionItem(ReduceOp op, const std::string &var, Tbegins &&begins, Tends &&ends, bool syncFlush)
Definition: for_property.h:26
std::unordered_set< std::string > allWrites(const AST &op, bool noRecurseIdx=false, bool noRecurseSubStmt=false)
Definition: all_uses.h:100
Expr makeGT(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:396
Expr makeAdd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:174
UpperBound sub(const UpperBound &b1, const LowerBound &b2)
Definition: bounds.cc:200
Stmt makeAssert(Tcond &&cond, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:394
Expr lin2expr(const LinearExpr< T > &lin)
Definition: linear.h:130
DNF asDNF(const Expr &expr)
Definition: as_dnf.cc:114
Expr makeMod(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:303
Expr makeLE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:382
Expr makeEQ(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:424
std::unordered_set< std::string > allReads(const AST &op, bool noRecurseIdx=false, bool noRecurseSubStmt=false)
Definition: all_uses.h:83
Stmt makeIf(Tcond &&cond, Tthen &&thenCase, Telse &&elseCase, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:354
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
bool hasIntersect(const std::unordered_set< T, Hash, KeyEqual > &lhs, const std::unordered_set< T, Hash, KeyEqual > &rhs)
Definition: container_utils.h:51
Expr makeGE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:410
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
Stmt makeFor(const std::string &iter, Tbegin &&begin, Tend &&end, Tstep &&step, Tlen &&len, Tproperty &&property, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:311
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
Expr makeLNot(T &&expr, std::source_location loc=std::source_location::current())
Definition: expr.h:490
bool isInt(BaseDataType dtype)
Definition: data_type.cc:27
std::optional< std::pair< LinearExpr< int64_t >, ASTNodeType > > linearComp(const Expr &expr)
Definition: analyze_linear.cc:56
Definition: comp_transient_bounds.h:19
std::vector< Expr > upper_
Definition: comp_transient_bounds.h:21
std::vector< Expr > lower_
Definition: comp_transient_bounds.h:21
Expr expr_
Definition: comp_transient_bounds.h:20