FreeTensor
Loading...
Searching...
No Matches
count_contig_access_loops.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_COUNT_CONTIG_ACCESS_LOOPS_H
2#define FREE_TENSOR_COUNT_CONTIG_ACCESS_LOOPS_H
3
4#include <unordered_map>
5
8#include <math/utils.h>
9#include <pass/const_fold.h>
10#include <visitor.h>
11
12namespace freetensor {
13
// CountContigAccessLoops: a SymbolTable<Visitor> pass that, for each Load,
// Store and ReduceTo, analyzes the LAST index of the access as a linear
// expression and credits every loop iterator whose coefficient is +1 or -1
// (or congruent to +1/-1 modulo a constant divisor, see countContigVars).
// Results are exposed through counts(), keyed by the loop's ID.
// NOTE(review): presumably the last index is treated as the contiguous
// dimension (row-major layout) -- confirm.
// NOTE(review): the line numbers embedded in this listing skip 15, 49, 63,
// 66 and 77; those source lines are missing here, so the code is incomplete
// as shown. Line 15 presumably declares BaseClass (used at line 94).
14class CountContigAccessLoops : public SymbolTable<Visitor> {
16
// Per-loop accumulator, keyed by the loop's ID. Only .first (the count) is
// updated in this header; .second and depth_ are not written here and are
// presumably maintained by visit(For), defined in
// count_contig_access_loops.cc -- confirm.
17 std::unordered_map<ID, std::pair<int64_t, int>>
18 counts_; // for ID -> (count, -depth)
// Reused linear-expression analyzer; its result is looked up per expression
// in countContigVars.
19 AnalyzeLinear analyzeLinear_;
// Amount added per qualifying access; starts at 1 and is not modified in
// this header. NOTE(review): presumably scaled by visit(For) -- confirm in
// the .cc file.
20 int64_t repeat_ = 1;
21 int depth_ = 0;
22
23 public:
// Read-only access to the accumulated (count, -depth) pairs, keyed by loop ID.
24 const std::unordered_map<ID, std::pair<int64_t, int>> &counts() const {
25 return counts_;
26 }
27
28 private:
// Returns the product of `var`'s shape dimensions when every dimension is an
// IntConst, or -1 as soon as any dimension is not a constant.
// NOTE(review): the product is not checked for int64_t overflow.
29 int64_t getStaticSize(const std::string &var) {
30 int64_t ret = 1;
31 for (auto &&dim : buffer(var)->tensor()->shape()) {
32 if (dim->nodeType() == ASTNodeType::IntConst) {
33 ret *= dim.as<IntConstNode>()->val_;
34 } else {
35 return -1;
36 }
37 }
38 return ret;
39 }
40
// Adds `repeat_` into (*cnt)[name] for every Var term of `expr` whose
// coefficient k, as computed by AnalyzeLinear, qualifies:
//   - k == 1 or k == -1; or
//   - `modP` is set, constant-folds to an IntConst p, and
//     mod(k, p) equals mod(1, p) or mod(-1, p).
// A qualifying term handled by the second switch case (label missing from
// this listing; presumably ASTNodeType::Mod, per the Remainder TODO) is
// recursed into with its lhs_ as the expression and its rhs_ as the modulus.
// All other term kinds are ignored.
// NOTE(review): *cnt stores `int` while repeat_ is int64_t, so `+= repeat_`
// narrows and cannot represent totals above INT_MAX; consider int64_t values
// here and for `cnt` in visitMemAccess.
41 void countContigVars(std::unordered_map<std::string, int> *cnt,
42 const Expr &expr, const Expr &modP = nullptr) {
43 analyzeLinear_(expr);
44 for (auto &&[_k, a] : analyzeLinear_.result().at(expr).coeff_) {
45 int64_t k = _k;
46 if (modP.isValid()) {
47 // TODO: Dynamic p: (p - 1) === -1, mod p
48 if (auto _p = constFold(modP);
// NOTE(review): source line 49 (this if's condition) is missing from the
// listing; presumably it checks that _p is an IntConst, since line 50 casts
// it with as<IntConstNode>().
50 auto p = _p.as<IntConstNode>()->val_;
51 if (mod(k, p) == mod(1, p) || mod(k, p) == mod(-1, p)) {
52 goto ok;
53 }
54 }
55 }
// Reached when modP is unset, or when the modular test above did not jump
// to `ok`.
56 if (k == 1 || k == -1) {
57 goto ok;
58 }
// Non-qualifying coefficient: skip this term.
59 continue;
60
// The coefficient qualifies; dispatch on the term's node type.
61 ok:
62 switch (a->nodeType()) {
// NOTE(review): source line 63 (a case label) is missing from the listing;
// the VarNode cast below suggests ASTNodeType::Var.
64 (*cnt)[a.template as<VarNode>()->name_] += repeat_;
65 break;
// NOTE(review): source line 66 (a case label) is missing from the listing;
// presumably ASTNodeType::Mod (see the TODO below).
67 // TODO: ASTNodeType::Remainder
68 countContigVars(cnt, a.as<BinaryExprNode>()->lhs_,
69 a.as<BinaryExprNode>()->rhs_);
70 break;
71 default:;
72 }
73 }
74 }
75
// Shared handler for Load / Store / ReduceTo. Skips buffers whose size is
// fully constant and below 128 elements. Otherwise it runs countContigVars
// on the last index and adds each result to counts_[loop(v)->id()].first,
// where loop(v) is SymbolTable's lookup of the For named `v`. Accesses with
// no indices contribute nothing.
76 template <class T> void visitMemAccess(const T &op) {
// NOTE(review): source line 77 is missing from this listing; presumably it
// visits the node's children (e.g. BaseClass::visit(op)) -- confirm, since
// otherwise Loads nested inside this node's expressions would not be visited.
78 auto size = getStaticSize(op->var_);
79 if (size != -1 && size < 128) {
80 // We don't count too small vars here because they are likely
81 // registers
82 return;
83 }
84 if (!op->indices_.empty()) {
85 std::unordered_map<std::string, int> cnt;
86 countContigVars(&cnt, op->indices_.back());
87 for (auto &&[v, c] : cnt) {
88 counts_[loop(v)->id()].first += c;
89 }
90 }
91 }
92
93 protected:
// Keep the remaining base-class visit overloads visible next to these
// overrides.
94 using BaseClass::visit;
// Defined out of line in count_contig_access_loops.cc.
95 void visit(const For &op) override;
96 void visit(const Load &op) override { visitMemAccess(op); }
97 void visit(const Store &op) override { visitMemAccess(op); }
98 void visit(const ReduceTo &op) override { visitMemAccess(op); }
// MatMul: empty body, so its children are not visited by this pass.
99 void visit(const MatMul &op) override {} // do nothing
100};
101
102} // namespace freetensor
103
104#endif // FREE_TENSOR_COUNT_CONTIG_ACCESS_LOOPS_H
virtual ASTNodeType nodeType() const =0
Definition: analyze_linear.h:14
const std::unordered_map< AST, LinearExpr< int64_t > > & result() const
Definition: analyze_linear.h:18
Definition: count_contig_access_loops.h:14
void visit(const Store &op) override
Definition: count_contig_access_loops.h:97
void visit(const Load &op) override
Definition: count_contig_access_loops.h:96
const std::unordered_map< ID, std::pair< int64_t, int > > & counts() const
Definition: count_contig_access_loops.h:24
void visit(const MatMul &op) override
Definition: count_contig_access_loops.h:99
void visit(const ReduceTo &op) override
Definition: count_contig_access_loops.h:98
void visit(const For &op) override
Definition: count_contig_access_loops.cc:5
Definition: expr.h:93
Definition: ref.h:24
ID id() const
Definition: ast.cc:362
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
const For & loop(const std::string &name) const override
Definition: symbol_table.h:152
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:145
Definition: allocator.h:9
Stmt constFold(const Stmt &op)
Definition: const_fold.h:177
Ref< ForNode > For
Definition: stmt.h:308
auto mod(IntegralExceptBool auto a, IntegralExceptBool auto b)
Definition: utils.h:32
Ref< ExprNode > Expr
Definition: ast.h:184