FreeTensor
Loading...
Searching...
No Matches
derivative.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_DERIVATIVE_H
2#define FREE_TENSOR_DERIVATIVE_H
3
4#include <optional>
5#include <unordered_map>
6#include <unordered_set>
7
10#include <mutator.h>
11#include <visitor.h>
12
13namespace freetensor {
14
31class Derivative : public SymbolTable<Visitor> {
33
34 public:
39 SymbolTableData symbolTable_;
40 Expr mathExpr_;
41 ID rootStmtID_;
42
43 // Additional info for rewriting the derivative of `y = f(x)` in terms of
44 // `y`. The rewriting has to be done lazily, because the version of `y` is
45 // different from the version of `x`
46 std::optional<InvertFromStore> invertFromStore_;
47 std::optional<bool> usingStore_;
48
49 std::optional<std::unordered_set<std::string>> reads_;
50
51 public:
52 LazyPartialDerivative() {} // Unintialized
53
55 const Expr &mathExpr, const ID &rootStmtID,
56 const std::optional<InvertFromStore>
57 &invertFromStore = std::nullopt)
58 : symbolTable_(symbolTable), mathExpr_(mathExpr),
59 rootStmtID_(rootStmtID), invertFromStore_(invertFromStore) {}
60
62 void merge(const LazyPartialDerivative &other) {
63 ASSERT(other.rootStmtID_ == rootStmtID_);
64 mathExpr_ = makeAdd(mathExpr_, other.mathExpr_);
65 usingStore_ = std::nullopt;
66 }
67
70 const Expr &mathExpr() const { return mathExpr_; }
71
73 bool usingStore();
74
79 const std::unordered_set<std::string> &reads();
80
83 Expr
84 replaceExpr(const std::unordered_map<ID, std::string> &intermediatesMap,
85 const std::unordered_map<StmtOrExprID, Expr> &versions,
86 const Expr &expr);
87
90 Expr
91 genReplaced(const std::unordered_map<ID, std::string> &intermediatesMap,
92 const std::unordered_map<StmtOrExprID, Expr> &versions) {
93 return replaceExpr(intermediatesMap, versions, mathExpr_);
94 }
95 };
96
110 // (x, dy/dx). Instead of using a map, an ordered list is used to ensure
111 // that the statements generated by `genGrads` are consistently ordered
112 std::vector<std::pair<Load, LazyPartialDerivative>> partials_;
113
114 // Union of all partials. These are lazily cached variables
115 std::optional<bool> usingStore_;
116 std::optional<std::unordered_set<std::string>> reads_;
117
118 std::exception_ptr error_;
119
120 public:
121 void addPartial(const Load &x, const LazyPartialDerivative &partial);
122
123 void setError(const std::exception_ptr &error) { error_ = error; }
124
125 bool usingStore();
126
127 const std::unordered_set<std::string> &reads();
128
129 std::vector<Stmt>
130 genGrads(const std::unordered_map<ID, std::string> &intermediatesMap,
131 const std::unordered_map<StmtOrExprID, Expr> &versions,
132 const std::unordered_map<std::string, std::string> &gradNames,
133 const Expr &gradY);
134 };
135
136 private:
137 std::unordered_map<StmtOrExprID, LazyFullDerivative> derivatives_;
138 std::unordered_map<Expr, LazyPartialDerivative> partials_;
139 StmtOrExprID rootExpr_;
140 std::optional<InvertFromStore> invertFromStore_;
141
147 void setPartial(const Expr &expr, const Expr &partial);
148
149 public:
150 const auto &derivatives() const { return derivatives_; }
151
152 protected:
153 using BaseClass::visit;
154
155 void visitExpr(const Expr &expr) override;
156
157 // If we have `y = f(x)` as a `Store` node, we can use `y` in the
158 // derivative. Note that this does not apply to `ReduceTo` nodes: in
159 // `y += exp(x)`, there is no variable equal to `exp(x)`
160 void visit(const Store &op) override;
161
162 void visit(const Load &op) override;
163 void visit(const Add &op) override;
164 void visit(const Sub &op) override;
165 void visit(const Mul &op) override;
166 void visit(const RealDiv &op) override;
167 void visit(const Min &op) override;
168 void visit(const Max &op) override;
169 void visit(const IfExpr &op) override;
170 void visit(const Sqrt &op) override;
171 void visit(const Exp &op) override;
172 void visit(const Ln &op) override;
173 void visit(const Square &op) override;
174 void visit(const Sigmoid &op) override;
175 void visit(const Sin &op) override;
176 void visit(const Cos &op) override;
177 void visit(const Tan &op) override;
178 void visit(const Tanh &op) override;
179 void visit(const Abs &op) override;
180 void visit(const Cast &op) override;
181 void visit(const Intrinsic &op) override;
182};
183
184} // namespace freetensor
185
186#endif // FREE_TENSOR_DERIVATIVE_H
Definition: derivative.h:109
const std::unordered_set< std::string > & reads()
Definition: derivative.cc:65
void setError(const std::exception_ptr &error)
Definition: derivative.h:123
bool usingStore()
Definition: derivative.cc:51
void addPartial(const Load &x, const LazyPartialDerivative &partial)
Definition: derivative.cc:39
std::vector< Stmt > genGrads(const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions, const std::unordered_map< std::string, std::string > &gradNames, const Expr &gradY)
Definition: derivative.cc:78
LazyPartialDerivative()
Definition: derivative.h:52
Expr replaceExpr(const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions, const Expr &expr)
Definition: derivative.cc:31
bool usingStore()
True if using y for y = f(x)'s derivative.
Definition: derivative.cc:8
const Expr & mathExpr() const
Definition: derivative.h:70
const std::unordered_set< std::string > & reads()
Definition: derivative.cc:20
void merge(const LazyPartialDerivative &other)
+= another partial derivative
Definition: derivative.h:62
LazyPartialDerivative(const SymbolTableData &symbolTable, const Expr &mathExpr, const ID &rootStmtID, const std::optional< InvertFromStore > &invertFromStore=std::nullopt)
Definition: derivative.h:54
Expr genReplaced(const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions)
Definition: derivative.h:91
Definition: derivative.h:31
const auto & derivatives() const
Definition: derivative.h:150
void visit(const Store &op) override
Definition: derivative.cc:136
void visitExpr(const Expr &expr) override
Definition: derivative.cc:118
Definition: id.h:18
Definition: ast.h:193
Definition: symbol_table.h:33
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
#define ASSERT(expr)
Definition: except.h:152
Definition: allocator.h:9
Expr makeAdd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:174