FreeTensor
Loading...
Searching...
No Matches
analyze_version.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_ANALYZE_VERSION_H
2#define FREE_TENSOR_ANALYZE_VERSION_H
3
4#include <unordered_map>
5#include <unordered_set>
6
10#include <autograd/derivative.h>
11#include <visitor.h>
12
13namespace freetensor {
14
15class CountScopeLen : public CompTransientBounds<SymbolTable<Visitor>> {
17
18 ID def_;
19 std::string var_;
20 const std::unordered_set<ID> &affectingScopes_; // For IDs
21 const std::unordered_set<ID> &needVersions_; // Store or ReduceTo IDs
22 const std::unordered_set<ID> &fakeStoreIds_;
23
24 // Total number of versions in a sub-tree
25 std::unordered_map<Stmt, Expr> scopeLen_;
26
27 public:
28 CountScopeLen(const ID &def, const std::unordered_set<ID> &affectingScopes,
29 const std::unordered_set<ID> &needVersions,
30 const std::unordered_set<ID> &fakeStoreIds)
31 : def_(def), affectingScopes_(affectingScopes),
32 needVersions_(needVersions), fakeStoreIds_(fakeStoreIds) {}
33
34 const std::unordered_map<Stmt, Expr> &scopeLen() const { return scopeLen_; }
35
36 protected:
37 using BaseClass::visit;
38 void visit(const Store &op) override;
39 void visit(const ReduceTo &op) override;
40 void visit(const VarDef &op) override;
41 void visit(const For &op) override;
42 void visit(const StmtSeq &op) override;
43 void visit(const If &op) override;
44 void visit(const Assert &op) override;
45};
46
47class AnalyzeVersion : public TrackStmt<Visitor> {
49
50 ID def_;
51 std::string var_;
52 const std::unordered_set<ID> &affectingScopes_; // For IDs
53 const std::unordered_set<ID> &needVersions_; // Store or ReduceTo IDs
54 const std::unordered_set<ID> &fakeStoreIds_;
55 const std::unordered_map<Stmt, Expr> &scopeLen_;
56 Expr totLen_;
57 std::unordered_map<StmtOrExprID, Expr> &versions_;
58 std::unordered_map<std::string, std::pair<std::string, Expr>>
59 &userVersions_;
60 std::string tapeName_;
61 Expr offset_ = makeIntConst(0);
62
63 public:
64 AnalyzeVersion(const ID &def, const std::unordered_set<ID> &affectingScopes,
65 const std::unordered_set<ID> &needVersions,
66 const std::unordered_set<ID> &fakeStoreIds,
67 const std::unordered_map<Stmt, Expr> &scopeLen,
68 const Expr &totLen,
69 std::unordered_map<StmtOrExprID, Expr> &versions,
70 std::unordered_map<std::string, std::pair<std::string, Expr>>
71 &userVersions)
72 : def_(def), affectingScopes_(affectingScopes),
73 needVersions_(needVersions), fakeStoreIds_(fakeStoreIds),
74 scopeLen_(scopeLen), totLen_(totLen), versions_(versions),
75 userVersions_(userVersions) {}
76
77 const std::string &tapeName() const { return tapeName_; }
78
79 protected:
80 void visit(const Load &op) override;
81 void visit(const MarkVersion &op) override;
82 void visit(const Store &op) override;
83 void visit(const ReduceTo &op) override;
84 void visit(const VarDef &op) override;
85 void visit(const For &op) override;
86 void visit(const StmtSeq &op) override;
87};
88
93class SetUserVersionsForInputs : public SymbolTable<Visitor> {
95
96 std::unordered_map<std::string, std::pair<std::string, Expr>>
97 &userVersions_;
98
99 public:
101 std::unordered_map<std::string, std::pair<std::string, Expr>>
102 &userVersions)
103 : userVersions_(userVersions) {}
104
105 protected:
106 using BaseClass::visit;
107 void visit(const MarkVersion &op) override;
108};
109
139std::tuple<std::unordered_map<StmtOrExprID, Expr>, std::unordered_map<ID, Expr>,
140 std::unordered_set<ID>,
141 std::unordered_map<std::string, std::pair<std::string, Expr>>>
143 const Stmt &op,
144 const std::unordered_map<ID, std::unordered_set<ID>> &needVersions,
145 const std::unordered_map<StmtOrExprID, Derivative::LazyFullDerivative>
146 &derivatives,
147 bool localVersionsOnly);
148
149} // namespace freetensor
150
151#endif // FREE_TENSOR_ANALYZE_VERSION_H
freetensor::AnalyzeVersion
Definition: analyze_version.h:47
void visit(const Load &op) override
Definition: analyze_version.cc:155
const std::string & tapeName() const
Definition: analyze_version.h:77
AnalyzeVersion(const ID &def, const std::unordered_set< ID > &affectingScopes, const std::unordered_set< ID > &needVersions, const std::unordered_set< ID > &fakeStoreIds, const std::unordered_map< Stmt, Expr > &scopeLen, const Expr &totLen, std::unordered_map< StmtOrExprID, Expr > &versions, std::unordered_map< std::string, std::pair< std::string, Expr > > &userVersions)
Definition: analyze_version.h:64
Definition: comp_transient_bounds.h:50
BaseClass::StmtRetType visit(const For &op) override
Definition: comp_transient_bounds.h:128
freetensor::CountScopeLen
Definition: analyze_version.h:15
CountScopeLen(const ID &def, const std::unordered_set< ID > &affectingScopes, const std::unordered_set< ID > &needVersions, const std::unordered_set< ID > &fakeStoreIds)
Definition: analyze_version.h:28
const std::unordered_map< Stmt, Expr > & scopeLen() const
Definition: analyze_version.h:34
void visit(const Store &op) override
Definition: analyze_version.cc:53
Definition: id.h:18
freetensor::SetUserVersionsForInputs
Definition: analyze_version.h:93
SetUserVersionsForInputs(std::unordered_map< std::string, std::pair< std::string, Expr > > &userVersions)
Definition: analyze_version.h:100
void visit(const MarkVersion &op) override
Definition: analyze_version.cc:238
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
Definition: track_stmt.h:24
Definition: allocator.h:9
std::tuple< std::unordered_map< StmtOrExprID, Expr >, std::unordered_map< ID, Expr >, std::unordered_set< ID >, std::unordered_map< std::string, std::pair< std::string, Expr > > > analyzeVersion(const Stmt &op, const std::unordered_map< ID, std::unordered_set< ID > > &needVersions, const std::unordered_map< StmtOrExprID, Derivative::LazyFullDerivative > &derivatives, bool localVersionsOnly)
Definition: analyze_version.cc:248
Ref< StmtNode > Stmt
Definition: ast.h:152
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102