1#ifndef FREE_TENSOR_COUNT_CONTIG_ACCESS_LOOPS_H
2#define FREE_TENSOR_COUNT_CONTIG_ACCESS_LOOPS_H
4#include <unordered_map>
17 std::unordered_map<ID, std::pair<int64_t, int>>
24 const std::unordered_map<ID, std::pair<int64_t, int>> &
counts()
const {
29 int64_t getStaticSize(
const std::string &var) {
31 for (
auto &&dim :
buffer(var)->tensor()->shape()) {
41 void countContigVars(std::unordered_map<std::string, int> *cnt,
42 const Expr &expr,
const Expr &modP =
nullptr) {
44 for (
auto &&[_k, a] : analyzeLinear_.
result().at(expr).coeff_) {
50 auto p = _p.as<IntConstNode>()->val_;
56 if (k == 1 || k == -1) {
62 switch (a->nodeType()) {
64 (*cnt)[a.template as<VarNode>()->name_] += repeat_;
68 countContigVars(cnt, a.as<BinaryExprNode>()->lhs_,
69 a.as<BinaryExprNode>()->rhs_);
76 template <
class T>
void visitMemAccess(
const T &op) {
78 auto size = getStaticSize(op->var_);
79 if (size != -1 && size < 128) {
84 if (!op->indices_.empty()) {
85 std::unordered_map<std::string, int> cnt;
86 countContigVars(&cnt, op->indices_.back());
87 for (
auto &&[v, c] : cnt) {
88 counts_[
loop(v)->
id()].first += c;
96 void visit(
const Load &op)
override { visitMemAccess(op); }
97 void visit(
const Store &op)
override { visitMemAccess(op); }
virtual ASTNodeType nodeType() const =0
Definition: analyze_linear.h:14
const std::unordered_map< AST, LinearExpr< int64_t > > & result() const
Definition: analyze_linear.h:18
Definition: count_contig_access_loops.h:14
void visit(const Store &op) override
Definition: count_contig_access_loops.h:97
void visit(const Load &op) override
Definition: count_contig_access_loops.h:96
const std::unordered_map< ID, std::pair< int64_t, int > > & counts() const
Definition: count_contig_access_loops.h:24
void visit(const MatMul &op) override
Definition: count_contig_access_loops.h:99
void visit(const ReduceTo &op) override
Definition: count_contig_access_loops.h:98
void visit(const For &op) override
Definition: count_contig_access_loops.cc:5
ID id() const
Definition: ast.cc:362
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
const For & loop(const std::string &name) const override
Definition: symbol_table.h:152
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:145
Definition: allocator.h:9
Stmt constFold(const Stmt &op)
Definition: const_fold.h:177
Ref< ForNode > For
Definition: stmt.h:308
auto mod(IntegralExceptBool auto a, IntegralExceptBool auto b)
Definition: utils.h:32
Ref< ExprNode > Expr
Definition: ast.h:184