|
FreeTensor
|
#include <derivative.h>


Classes | |
| class | LazyFullDerivative |
| class | LazyPartialDerivative |
Public Member Functions | |
| const auto & | derivatives () const |
Public Member Functions inherited from freetensor::SymbolTable< Visitor > | |
| SymbolTable (T &&...args) | |
| const std::unordered_set< std::string > & | names () const override |
| const std::unordered_map< std::string, VarDef > & | defs () const override |
| const std::unordered_map< std::string, For > & | loops () const override |
| bool | hasDef (const std::string &name) const override |
| const VarDef & | def (const std::string &name) const override |
| Ref< Buffer > | buffer (const std::string &name) const override |
| bool | hasLoop (const std::string &name) const override |
| const For & | loop (const std::string &name) const override |
| void | pushDef (const VarDef &op) override |
| void | popDef (const VarDef &op) override |
| void | pushFor (const For &op) override |
| void | popFor (const For &op) override |
| const SymbolTableData & | symbolTableSnapshot () const |
Public Member Functions inherited from freetensor::Visitor | |
| virtual | ~Visitor () |
| virtual void | operator() (const AST &op) final |
| virtual const std::unordered_set< std::string > & | names () const =0 |
| virtual const std::unordered_map< std::string, VarDef > & | defs () const =0 |
| virtual const std::unordered_map< std::string, For > & | loops () const =0 |
| virtual bool | hasDef (const std::string &name) const =0 |
| virtual const VarDef & | def (const std::string &name) const =0 |
| virtual Ref< Buffer > | buffer (const std::string &name) const =0 |
| virtual bool | hasLoop (const std::string &name) const =0 |
| virtual const For & | loop (const std::string &name) const =0 |
| virtual void | pushDef (const VarDef &op)=0 |
| virtual void | popDef (const VarDef &op)=0 |
| virtual void | pushFor (const For &op)=0 |
| virtual void | popFor (const For &op)=0 |
Protected Member Functions | |
| void | visitExpr (const Expr &expr) override |
| void | visit (const Store &op) override |
| void | visit (const Load &op) override |
| void | visit (const Add &op) override |
| void | visit (const Sub &op) override |
| void | visit (const Mul &op) override |
| void | visit (const RealDiv &op) override |
| void | visit (const Min &op) override |
| void | visit (const Max &op) override |
| void | visit (const IfExpr &op) override |
| void | visit (const Sqrt &op) override |
| void | visit (const Exp &op) override |
| void | visit (const Ln &op) override |
| void | visit (const Square &op) override |
| void | visit (const Sigmoid &op) override |
| void | visit (const Sin &op) override |
| void | visit (const Cos &op) override |
| void | visit (const Tan &op) override |
| void | visit (const Tanh &op) override |
| void | visit (const Abs &op) override |
| void | visit (const Cast &op) override |
| void | visit (const Intrinsic &op) override |
Protected Member Functions inherited from freetensor::SymbolTable< Visitor > | |
| BaseClass::StmtRetType | visit (const VarDef &op) override |
| BaseClass::StmtRetType | visit (const For &op) override |
Protected Member Functions inherited from freetensor::Visitor | |
| virtual void | visitExpr (const Expr &op) |
| virtual void | visitStmt (const Stmt &op) |
| virtual void | visit (const Any &op) |
| virtual void | visit (const AnyExpr &op) |
| virtual void | visit (const Func &op) |
| virtual void | visit (const StmtSeq &op) |
| virtual void | visit (const VarDef &op) |
| virtual void | visit (const Var &op) |
| virtual void | visit (const Store &op) |
| virtual void | visit (const Alloc &op) |
| virtual void | visit (const Free &op) |
| virtual void | visit (const Load &op) |
| virtual void | visit (const ReduceTo &op) |
| virtual void | visit (const IntConst &op) |
| virtual void | visit (const FloatConst &op) |
| virtual void | visit (const BoolConst &op) |
| virtual void | visit (const Add &op) |
| virtual void | visit (const Sub &op) |
| virtual void | visit (const Mul &op) |
| virtual void | visit (const RealDiv &op) |
| virtual void | visit (const FloorDiv &op) |
| virtual void | visit (const CeilDiv &op) |
| virtual void | visit (const RoundTowards0Div &op) |
| virtual void | visit (const Mod &op) |
| virtual void | visit (const Remainder &op) |
| virtual void | visit (const Min &op) |
| virtual void | visit (const Max &op) |
| virtual void | visit (const LT &op) |
| virtual void | visit (const LE &op) |
| virtual void | visit (const GT &op) |
| virtual void | visit (const GE &op) |
| virtual void | visit (const EQ &op) |
| virtual void | visit (const NE &op) |
| virtual void | visit (const LAnd &op) |
| virtual void | visit (const LOr &op) |
| virtual void | visit (const LNot &op) |
| virtual void | visit (const Sqrt &op) |
| virtual void | visit (const Exp &op) |
| virtual void | visit (const Ln &op) |
| virtual void | visit (const Square &op) |
| virtual void | visit (const Sigmoid &op) |
| virtual void | visit (const Sin &op) |
| virtual void | visit (const Cos &op) |
| virtual void | visit (const Tan &op) |
| virtual void | visit (const Tanh &op) |
| virtual void | visit (const Abs &op) |
| virtual void | visit (const Floor &op) |
| virtual void | visit (const Ceil &op) |
| virtual void | visit (const Unbound &op) |
| virtual void | visit (const For &op) |
| virtual void | visit (const If &op) |
| virtual void | visit (const Assert &op) |
| virtual void | visit (const Assume &op) |
| virtual void | visit (const IfExpr &op) |
| virtual void | visit (const Cast &op) |
| virtual void | visit (const Intrinsic &op) |
| virtual void | visit (const Eval &op) |
| virtual void | visit (const MatMul &op) |
| virtual void | visit (const MarkVersion &op) |
| virtual void | visit (const LoadAtVersion &op) |
Additional Inherited Members | |
Public Types inherited from freetensor::Visitor | |
| typedef void | ExprRetType |
| typedef void | StmtRetType |
Find the derivative of each expression.
Gradients will be updated by multiplying the derivatives.
Derivative has two phases. In the first phase, derivative expressions are built according to mathematical principles, but the variables they load may not exist in an actual backward pass. The result is stored in Derivative::LazyFullDerivative.
The result of the first phase can be used to decide which variables have to be saved to tape or recomputed.
In the second phase, given the information of tape and recomputation, the variables in derivative expressions are corrected.
|
inline |
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.
|
override protected virtual |
Reimplemented from freetensor::Visitor.