1#ifndef FREE_TENSOR_DERIVATIVE_H
2#define FREE_TENSOR_DERIVATIVE_H
5#include <unordered_map>
6#include <unordered_set>
46 std::optional<InvertFromStore> invertFromStore_;
47 std::optional<bool> usingStore_;
49 std::optional<std::unordered_set<std::string>> reads_;
56 const std::optional<InvertFromStore>
57 &invertFromStore = std::nullopt)
58 : symbolTable_(symbolTable), mathExpr_(
mathExpr),
59 rootStmtID_(rootStmtID), invertFromStore_(invertFromStore) {}
63 ASSERT(other.rootStmtID_ == rootStmtID_);
64 mathExpr_ =
makeAdd(mathExpr_, other.mathExpr_);
65 usingStore_ = std::nullopt;
79 const std::unordered_set<std::string> &
reads();
84 replaceExpr(
const std::unordered_map<ID, std::string> &intermediatesMap,
85 const std::unordered_map<StmtOrExprID, Expr> &versions,
91 genReplaced(
const std::unordered_map<ID, std::string> &intermediatesMap,
92 const std::unordered_map<StmtOrExprID, Expr> &versions) {
93 return replaceExpr(intermediatesMap, versions, mathExpr_);
112 std::vector<std::pair<Load, LazyPartialDerivative>> partials_;
115 std::optional<bool> usingStore_;
116 std::optional<std::unordered_set<std::string>> reads_;
118 std::exception_ptr error_;
123 void setError(
const std::exception_ptr &error) { error_ = error; }
127 const std::unordered_set<std::string> &
reads();
130 genGrads(
const std::unordered_map<ID, std::string> &intermediatesMap,
131 const std::unordered_map<StmtOrExprID, Expr> &versions,
132 const std::unordered_map<std::string, std::string> &gradNames,
137 std::unordered_map<StmtOrExprID, LazyFullDerivative> derivatives_;
138 std::unordered_map<Expr, LazyPartialDerivative> partials_;
140 std::optional<InvertFromStore> invertFromStore_;
147 void setPartial(
const Expr &expr,
const Expr &partial);
172 void visit(
const Ln &op)
override;
Definition: derivative.h:109
const std::unordered_set< std::string > & reads()
Definition: derivative.cc:65
void setError(const std::exception_ptr &error)
Definition: derivative.h:123
bool usingStore()
Definition: derivative.cc:51
void addPartial(const Load &x, const LazyPartialDerivative &partial)
Definition: derivative.cc:39
std::vector< Stmt > genGrads(const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions, const std::unordered_map< std::string, std::string > &gradNames, const Expr &gradY)
Definition: derivative.cc:78
Definition: derivative.h:38
LazyPartialDerivative()
Definition: derivative.h:52
Expr replaceExpr(const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions, const Expr &expr)
Definition: derivative.cc:31
bool usingStore()
True if using y for y = f(x)'s derivative.
Definition: derivative.cc:8
const Expr & mathExpr() const
Definition: derivative.h:70
const std::unordered_set< std::string > & reads()
Definition: derivative.cc:20
void merge(const LazyPartialDerivative &other)
+= another partial derivative
Definition: derivative.h:62
LazyPartialDerivative(const SymbolTableData &symbolTable, const Expr &mathExpr, const ID &rootStmtID, const std::optional< InvertFromStore > &invertFromStore=std::nullopt)
Definition: derivative.h:54
Expr genReplaced(const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions)
Definition: derivative.h:91
Definition: derivative.h:31
const auto & derivatives() const
Definition: derivative.h:150
void visit(const Store &op) override
Definition: derivative.cc:136
void visitExpr(const Expr &expr) override
Definition: derivative.cc:118
Definition: symbol_table.h:33
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
#define ASSERT(expr)
Definition: except.h:152
Definition: allocator.h:9
Expr makeAdd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:174