1#ifndef FREE_TENSOR_GRAD_H
2#define FREE_TENSOR_GRAD_H
4#include <unordered_map>
5#include <unordered_set>
21 const std::unordered_map<ID, std::string> &intermediatesMap_;
22 const std::unordered_map<std::string, std::pair<std::string, Expr>>
24 const std::unordered_map<std::string, std::string>
27 std::unordered_set<std::string> localVarDefNames_;
32 const std::unordered_map<ID, std::string> &intermediatesMap,
33 const std::unordered_map<std::string, std::pair<std::string, Expr>>
35 const std::unordered_map<std::string, std::string> &gradNames)
36 : symbolTable_(symbolTable), intermediatesMap_(intermediatesMap),
37 userVersions_(userVersions), gradNames_(gradNames) {}
46template <
class BaseClass>
class RenewIDs :
public BaseClass {
49 auto ret = BaseClass::visitStmt(s);
61 std::unordered_map<StmtOrExprID, Derivative::LazyFullDerivative>
63 const std::unordered_set<std::string> &requires_;
64 const std::unordered_set<std::string> &provides_;
65 const std::unordered_set<ID> &tapes_;
66 const std::unordered_set<ID> &defsNeedGrad_;
67 const std::unordered_map<ID, std::string>
70 const std::unordered_map<StmtOrExprID, Expr> &versions_;
71 const std::unordered_map<std::string, std::pair<std::string, Expr>>
73 const std::unordered_map<ID, Expr> &totLens_;
74 const std::unordered_set<ID> &saveLocalStmts_;
75 const std::unordered_set<Stmt> ¬SingleWrite_;
76 bool resetProvidedGrad_;
77 const std::unordered_map<ID, InversionInfo> &inverseStmts_;
78 std::vector<RangeToUserGrad> userGrads_;
80 std::unordered_map<std::string, std::string> requireGrads_;
81 std::unordered_map<std::string, std::string> provideGrads_;
83 std::unordered_map<std::string, std::string> gradNames_;
84 std::unordered_map<Expr, Expr> equLoads_;
85 std::unordered_map<std::string, std::unordered_set<Stmt>>
87 bool isRecompute_ =
false;
89 std::unordered_set<ID> inverselyUpdated_;
91 std::optional<RangeToUserGrad> userGradOpen_;
92 ID userGradInsertPos_;
104 const Store &alreadyStored =
nullptr)
const;
109 Grad(std::unordered_map<StmtOrExprID, Derivative::LazyFullDerivative>
111 const std::unordered_set<std::string> &_requires,
112 const std::unordered_set<std::string> &provides,
113 const std::unordered_set<ID> &tapes,
114 const std::unordered_set<ID> &defsNeedGrad,
115 const std::unordered_map<ID, std::string> &intermediatesMap,
116 const std::unordered_map<StmtOrExprID, Expr> &versions,
117 const std::unordered_map<std::string, std::pair<std::string, Expr>>
119 const std::unordered_map<ID, Expr> &totLens,
120 const std::unordered_set<ID> &saveLocalStmts,
121 const std::unordered_set<Stmt> ¬SingleWrite,
bool resetProvidedGrad,
122 const std::unordered_map<ID, InversionInfo> &inverseStmts,
123 const std::vector<RangeToUserGrad> &userGrads)
124 : derivatives_(derivatives), requires_(_requires), provides_(provides),
125 tapes_(tapes), defsNeedGrad_(defsNeedGrad),
126 intermediatesMap_(intermediatesMap), versions_(versions),
127 userVersions_(userVersions), totLens_(totLens),
128 saveLocalStmts_(saveLocalStmts), notSingleWrite_(notSingleWrite),
129 resetProvidedGrad_(resetProvidedGrad), inverseStmts_(inverseStmts),
130 userGrads_(userGrads) {}
132 const std::unordered_map<std::string, std::string> &
requireGrads()
const {
133 return requireGrads_;
135 const std::unordered_map<std::string, std::string> &
provideGrads()
const {
136 return provideGrads_;
186std::tuple<Stmt, Stmt, std::unordered_map<std::string, std::string>,
187 std::unordered_map<std::string, std::string>,
188 std::unordered_map<ID, std::string>>
189gradBody(
const Stmt &op,
const std::unordered_set<std::string> &_requires,
190 const std::unordered_set<std::string> &provides,
191 const TapeStrategy &tapes,
bool resetProvidedGrad =
true,
193 const std::vector<StmtSetToUserGrad> &userGrads = {});
195std::tuple<Func, Func, std::unordered_map<std::string, std::string>,
196 std::unordered_map<std::string, std::string>>
198 const std::unordered_set<std::string> &_requires,
199 const std::unordered_set<std::string> &provides,
200 const TapeStrategy &tapes,
bool tapeInClosure =
true,
201 bool resetProvidedGrad =
true,
bool invert =
false,
202 const std::vector<StmtSetToUserGrad> &userGrads = {});
204std::tuple<Func, Func, std::unordered_map<std::string, std::string>,
205 std::unordered_map<std::string, std::string>>
207 const std::unordered_set<std::string> &_requires,
208 const std::unordered_set<std::string> &provides,
209 const TapeStrategy &tapes,
bool tapeInClosure =
true,
210 bool resetProvidedGrad =
true,
bool invert =
false,
211 const std::vector<StmtSetToUserGrad> &userGrads = {});
const std::unordered_map< std::string, std::string > & requireGrads() const
Definition: grad.h:132
Grad(std::unordered_map< StmtOrExprID, Derivative::LazyFullDerivative > &derivatives, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides, const std::unordered_set< ID > &tapes, const std::unordered_set< ID > &defsNeedGrad, const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions, const std::unordered_map< std::string, std::pair< std::string, Expr > > &userVersions, const std::unordered_map< ID, Expr > &totLens, const std::unordered_set< ID > &saveLocalStmts, const std::unordered_set< Stmt > &notSingleWrite, bool resetProvidedGrad, const std::unordered_map< ID, InversionInfo > &inverseStmts, const std::vector< RangeToUserGrad > &userGrads)
Definition: grad.h:109
const std::unordered_map< std::string, std::string > & provideGrads() const
Definition: grad.h:135
Stmt visitStmt(const Stmt &s) override
Definition: grad.cc:193
Stmt visit(const StmtSeq &op) override
Definition: grad.cc:231
Expr visit(const LoadAtVersion &op) override
Definition: grad.cc:31
InsertUserGrad(const SymbolTableInterface &symbolTable, const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< std::string, std::pair< std::string, Expr > > &userVersions, const std::unordered_map< std::string, std::string > &gradNames)
Definition: grad.h:30
Stmt visitStmt(const Stmt &s) override
Definition: grad.h:48
Definition: replace_by_saved.h:27
Definition: symbol_table.h:13
Definition: allocator.h:9
std::tuple< Func, Func, std::unordered_map< std::string, std::string >, std::unordered_map< std::string, std::string > > gradFuncInplace(const Func &func, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides, const TapeStrategy &tapes, bool tapeInClosure=true, bool resetProvidedGrad=true, bool invert=false, const std::vector< StmtSetToUserGrad > &userGrads={})
Definition: grad.cc:825
std::tuple< Stmt, Stmt, std::unordered_map< std::string, std::string >, std::unordered_map< std::string, std::string >, std::unordered_map< ID, std::string > > gradBody(const Stmt &op, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides, const TapeStrategy &tapes, bool resetProvidedGrad=true, bool invert=false, const std::vector< StmtSetToUserGrad > &userGrads={})
Definition: grad.cc:593
std::tuple< Func, Func, std::unordered_map< std::string, std::string >, std::unordered_map< std::string, std::string > > gradFuncOutOfPlace(const Func &func, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides, const TapeStrategy &tapes, bool tapeInClosure=true, bool resetProvidedGrad=true, bool invert=false, const std::vector< StmtSetToUserGrad > &userGrads={})
Definition: grad.cc:837
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< FuncNode > Func
Definition: func.h:64