1#ifndef FREE_TENSOR_SHRINK_VAR_H
2#define FREE_TENSOR_SHRINK_VAR_H
4#include <unordered_map>
23 const std::unordered_map<ID, AccessBound> &newRangeWithShape_;
27 const std::unordered_map<ID, AccessBound> &newRangeWithoutShape_;
31 std::unordered_map<std::string, std::vector<Expr>> lowerWithShape_,
33 std::unordered_map<std::string, std::vector<Expr>> lowerWithoutShape_,
35 std::unordered_map<ID, Expr> guards_;
38 ShrinkVar(
const std::unordered_map<ID, AccessBound> &newRangeWithShape,
39 const std::unordered_map<ID, AccessBound> &newRangeWithoutShape,
40 bool guardReads =
false)
41 : newRangeWithShape_(newRangeWithShape),
42 newRangeWithoutShape_(newRangeWithoutShape), guardReads_(guardReads) {
46 template <
class T> T modifyAccess(
const T &op) {
47 if (lowerWithoutShape_.count(op->var_)) {
48 auto &&offset = lowerWithoutShape_.at(op->var_);
49 ASSERT(offset.size() == op->indices_.size());
50 for (
auto &&[idx, off] : views::zip(op->indices_, offset)) {
59 template <
class T>
void addGuard(
const T &oldOp,
const T &op) {
63 if (upperWithoutShape_.count(op->var_)) {
64 auto &&upper = upperWithoutShape_.at(op->var_);
65 ASSERT(upper.size() == op->indices_.size());
66 for (
auto &&[idx, u] : views::zip(oldOp->indices_, upper)) {
73 if (lowerWithoutShape_.count(op->var_)) {
74 auto &&
lower = lowerWithoutShape_.at(op->var_);
76 for (
auto &&[idx, l] : views::zip(oldOp->indices_,
lower)) {
83 if (guard.isValid()) {
85 if constexpr (std::is_base_of_v<StmtNode, typename T::Object>) {
87 }
else if constexpr (std::is_base_of_v<ExprNode,
88 typename T::Object>) {
93 guards_[s->id()] = guards_[s->id()].isValid()
Definition: shrink_var.h:20
Stmt visitStmt(const Stmt &s) override
Definition: shrink_var.cc:12
Stmt visit(const VarDef &op) override
Definition: shrink_var.cc:20
ShrinkVar(const std::unordered_map< ID, AccessBound > &newRangeWithShape, const std::unordered_map< ID, AccessBound > &newRangeWithoutShape, bool guardReads=false)
Definition: shrink_var.h:38
Ref< StmtNode > parentStmt() const
Definition: ast.cc:103
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
Ref< VarDefNode > VarDef
Definition: stmt.h:107
T lower(const T &_ast, const Ref< Target > &_target=nullptr, const std::unordered_set< std::string > &skipPasses={}, int verbose=0)
Definition: lower.h:53
Ref< LoadNode > Load
Definition: expr.h:61
Ref< StoreNode > Store
Definition: stmt.h:140
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Expr makeLE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:382
Stmt shrinkVar(const Stmt &op)
Definition: shrink_var.cc:102
Ref< StmtNode > Stmt
Definition: ast.h:152
Stmt shrinkSingleVar(const Stmt &op, const ID &varDefId)
Definition: shrink_var.cc:128
Expr makeLAnd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:452
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Expr makeGE(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:410
Ref< ExprNode > Expr
Definition: ast.h:184