FreeTensor
Loading...
Searching...
No Matches
grad.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_GRAD_H
2#define FREE_TENSOR_GRAD_H
3
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <autograd/user_grad.h>
#include <func.h>
#include <mutator.h>
#include <visitor.h>
16
17namespace freetensor {
18
19class InsertUserGrad : public Mutator {
20 const SymbolTableInterface &symbolTable_;
21 const std::unordered_map<ID, std::string> &intermediatesMap_;
22 const std::unordered_map<std::string, std::pair<std::string, Expr>>
23 &userVersions_;
24 const std::unordered_map<std::string, std::string>
25 &gradNames_; // x -> dy/dx
26
27 std::unordered_set<std::string> localVarDefNames_;
28
29 public:
31 const SymbolTableInterface &symbolTable,
32 const std::unordered_map<ID, std::string> &intermediatesMap,
33 const std::unordered_map<std::string, std::pair<std::string, Expr>>
34 &userVersions,
35 const std::unordered_map<std::string, std::string> &gradNames)
36 : symbolTable_(symbolTable), intermediatesMap_(intermediatesMap),
37 userVersions_(userVersions), gradNames_(gradNames) {}
38
39 protected:
40 Expr visit(const LoadAtVersion &op) override;
41 Stmt visit(const Store &op) override;
42 Stmt visit(const ReduceTo &op) override;
43 Stmt visit(const VarDef &op) override;
44};
45
46template <class BaseClass> class RenewIDs : public BaseClass {
47 protected:
48 Stmt visitStmt(const Stmt &s) override {
49 auto ret = BaseClass::visitStmt(s);
50 ret->setId();
51 return ret;
52 }
53};
54
55class Grad : public RenewIDs<SymbolTable<Mutator>> {
56 // Because a statement can be both recomputed and computed for gradient, and
57 // we may even recompute a statement several times, all IDs must be renewed,
58 // even for recomputation
60
61 std::unordered_map<StmtOrExprID, Derivative::LazyFullDerivative>
62 &derivatives_; // Mutable for lazy operations
63 const std::unordered_set<std::string> &requires_;
64 const std::unordered_set<std::string> &provides_;
65 const std::unordered_set<ID> &tapes_;
66 const std::unordered_set<ID> &defsNeedGrad_;
67 const std::unordered_map<ID, std::string>
68 &intermediatesMap_; // All saved variables, including in forward stage
69 // (tapes) and backward stage (during recomputation)
70 const std::unordered_map<StmtOrExprID, Expr> &versions_;
71 const std::unordered_map<std::string, std::pair<std::string, Expr>>
72 &userVersions_;
73 const std::unordered_map<ID, Expr> &totLens_;
74 const std::unordered_set<ID> &saveLocalStmts_;
75 const std::unordered_set<Stmt> &notSingleWrite_;
76 bool resetProvidedGrad_;
77 const std::unordered_map<ID, InversionInfo> &inverseStmts_;
78 std::vector<RangeToUserGrad> userGrads_; // mutable
79
80 std::unordered_map<std::string, std::string> requireGrads_; // var name map
81 std::unordered_map<std::string, std::string> provideGrads_; // var name map
82
83 std::unordered_map<std::string, std::string> gradNames_; // x -> dy/dx
84 std::unordered_map<Expr, Expr> equLoads_;
85 std::unordered_map<std::string, std::unordered_set<Stmt>>
86 recomputed_; // var name -> set{stmt}
87 bool isRecompute_ = false;
88
89 std::unordered_set<ID> inverselyUpdated_; // {VarDef IDs}
90
91 std::optional<RangeToUserGrad> userGradOpen_;
92 ID userGradInsertPos_;
93
94 private:
103 ReplaceBySaved getReplacer(const Stmt &stmt,
104 const Store &alreadyStored = nullptr) const;
105
106 Stmt doVisitStmt(const Stmt &s);
107
108 public:
109 Grad(std::unordered_map<StmtOrExprID, Derivative::LazyFullDerivative>
110 &derivatives,
111 const std::unordered_set<std::string> &_requires,
112 const std::unordered_set<std::string> &provides,
113 const std::unordered_set<ID> &tapes,
114 const std::unordered_set<ID> &defsNeedGrad,
115 const std::unordered_map<ID, std::string> &intermediatesMap,
116 const std::unordered_map<StmtOrExprID, Expr> &versions,
117 const std::unordered_map<std::string, std::pair<std::string, Expr>>
118 &userVersions,
119 const std::unordered_map<ID, Expr> &totLens,
120 const std::unordered_set<ID> &saveLocalStmts,
121 const std::unordered_set<Stmt> &notSingleWrite, bool resetProvidedGrad,
122 const std::unordered_map<ID, InversionInfo> &inverseStmts,
123 const std::vector<RangeToUserGrad> &userGrads)
124 : derivatives_(derivatives), requires_(_requires), provides_(provides),
125 tapes_(tapes), defsNeedGrad_(defsNeedGrad),
126 intermediatesMap_(intermediatesMap), versions_(versions),
127 userVersions_(userVersions), totLens_(totLens),
128 saveLocalStmts_(saveLocalStmts), notSingleWrite_(notSingleWrite),
129 resetProvidedGrad_(resetProvidedGrad), inverseStmts_(inverseStmts),
130 userGrads_(userGrads) {}
131
132 const std::unordered_map<std::string, std::string> &requireGrads() const {
133 return requireGrads_;
134 }
135 const std::unordered_map<std::string, std::string> &provideGrads() const {
136 return provideGrads_;
137 }
138
139 protected:
140 Stmt visitStmt(const Stmt &s) override;
141 Stmt visit(const StmtSeq &op) override;
142 Stmt visit(const For &op) override;
143 Stmt visit(const If &op) override;
144 Stmt visit(const Assert &op) override;
145 Stmt visit(const VarDef &op) override;
146 Stmt visit(const Store &op) override;
147 Stmt visit(const ReduceTo &op) override;
148};
149
186std::tuple<Stmt, Stmt, std::unordered_map<std::string, std::string>,
187 std::unordered_map<std::string, std::string>,
188 std::unordered_map<ID, std::string>>
189gradBody(const Stmt &op, const std::unordered_set<std::string> &_requires,
190 const std::unordered_set<std::string> &provides,
191 const TapeStrategy &tapes, bool resetProvidedGrad = true,
192 bool invert = false,
193 const std::vector<StmtSetToUserGrad> &userGrads = {});
194
195std::tuple<Func, Func, std::unordered_map<std::string, std::string>,
196 std::unordered_map<std::string, std::string>>
197gradFuncInplace(const Func &func,
198 const std::unordered_set<std::string> &_requires,
199 const std::unordered_set<std::string> &provides,
200 const TapeStrategy &tapes, bool tapeInClosure = true,
201 bool resetProvidedGrad = true, bool invert = false,
202 const std::vector<StmtSetToUserGrad> &userGrads = {});
203
204std::tuple<Func, Func, std::unordered_map<std::string, std::string>,
205 std::unordered_map<std::string, std::string>>
206gradFuncOutOfPlace(const Func &func,
207 const std::unordered_set<std::string> &_requires,
208 const std::unordered_set<std::string> &provides,
209 const TapeStrategy &tapes, bool tapeInClosure = true,
210 bool resetProvidedGrad = true, bool invert = false,
211 const std::vector<StmtSetToUserGrad> &userGrads = {});
214} // namespace freetensor
215
216#endif // FREE_TENSOR_GRAD_H
Definition: grad.h:55
const std::unordered_map< std::string, std::string > & requireGrads() const
Definition: grad.h:132
Grad(std::unordered_map< StmtOrExprID, Derivative::LazyFullDerivative > &derivatives, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides, const std::unordered_set< ID > &tapes, const std::unordered_set< ID > &defsNeedGrad, const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions, const std::unordered_map< std::string, std::pair< std::string, Expr > > &userVersions, const std::unordered_map< ID, Expr > &totLens, const std::unordered_set< ID > &saveLocalStmts, const std::unordered_set< Stmt > &notSingleWrite, bool resetProvidedGrad, const std::unordered_map< ID, InversionInfo > &inverseStmts, const std::vector< RangeToUserGrad > &userGrads)
Definition: grad.h:109
const std::unordered_map< std::string, std::string > & provideGrads() const
Definition: grad.h:135
Stmt visitStmt(const Stmt &s) override
Definition: grad.cc:193
Stmt visit(const StmtSeq &op) override
Definition: grad.cc:231
Definition: id.h:18
Definition: grad.h:19
Expr visit(const LoadAtVersion &op) override
Definition: grad.cc:31
InsertUserGrad(const SymbolTableInterface &symbolTable, const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< std::string, std::pair< std::string, Expr > > &userVersions, const std::unordered_map< std::string, std::string > &gradNames)
Definition: grad.h:30
Definition: mutator.h:11
Definition: grad.h:46
Stmt visitStmt(const Stmt &s) override
Definition: grad.h:48
Definition: replace_by_saved.h:27
Definition: symbol_table.h:13
Definition: allocator.h:9
std::tuple< Func, Func, std::unordered_map< std::string, std::string >, std::unordered_map< std::string, std::string > > gradFuncInplace(const Func &func, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides, const TapeStrategy &tapes, bool tapeInClosure=true, bool resetProvidedGrad=true, bool invert=false, const std::vector< StmtSetToUserGrad > &userGrads={})
Definition: grad.cc:825
std::tuple< Stmt, Stmt, std::unordered_map< std::string, std::string >, std::unordered_map< std::string, std::string >, std::unordered_map< ID, std::string > > gradBody(const Stmt &op, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides, const TapeStrategy &tapes, bool resetProvidedGrad=true, bool invert=false, const std::vector< StmtSetToUserGrad > &userGrads={})
Definition: grad.cc:593
std::tuple< Func, Func, std::unordered_map< std::string, std::string >, std::unordered_map< std::string, std::string > > gradFuncOutOfPlace(const Func &func, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides, const TapeStrategy &tapes, bool tapeInClosure=true, bool resetProvidedGrad=true, bool invert=false, const std::vector< StmtSetToUserGrad > &userGrads={})
Definition: grad.cc:837
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< FuncNode > Func
Definition: func.h:64