FreeTensor
Loading...
Searching...
No Matches
Classes | Public Member Functions | Protected Member Functions | List of all members
freetensor::Derivative Class Reference

#include <derivative.h>

Inheritance diagram for freetensor::Derivative:
Inheritance graph
[legend]
Collaboration diagram for freetensor::Derivative:
Collaboration graph
[legend]

Classes

class  LazyFullDerivative
 
class  LazyPartialDerivative
 

Public Member Functions

const auto & derivatives () const
 
- Public Member Functions inherited from freetensor::SymbolTable< Visitor >
 SymbolTable (T &&...args)
 
const std::unordered_set< std::string > & names () const override
 
const std::unordered_map< std::string, VarDef > & defs () const override
 
const std::unordered_map< std::string, For > & loops () const override
 
bool hasDef (const std::string &name) const override
 
const VarDef & def (const std::string &name) const override
 
Ref< Buffer > buffer (const std::string &name) const override
 
bool hasLoop (const std::string &name) const override
 
const For & loop (const std::string &name) const override
 
void pushDef (const VarDef &op) override
 
void popDef (const VarDef &op) override
 
void pushFor (const For &op) override
 
void popFor (const For &op) override
 
const SymbolTableData & symbolTableSnapshot () const
 
- Public Member Functions inherited from freetensor::Visitor
virtual ~Visitor ()
 
virtual void operator() (const AST &op) final
 
virtual const std::unordered_set< std::string > & names () const =0
 
virtual const std::unordered_map< std::string, VarDef > & defs () const =0
 
virtual const std::unordered_map< std::string, For > & loops () const =0
 
virtual bool hasDef (const std::string &name) const =0
 
virtual const VarDef & def (const std::string &name) const =0
 
virtual Ref< Buffer > buffer (const std::string &name) const =0
 
virtual bool hasLoop (const std::string &name) const =0
 
virtual const For & loop (const std::string &name) const =0
 
virtual void pushDef (const VarDef &op)=0
 
virtual void popDef (const VarDef &op)=0
 
virtual void pushFor (const For &op)=0
 
virtual void popFor (const For &op)=0
 

Protected Member Functions

void visitExpr (const Expr &expr) override
 
void visit (const Store &op) override
 
void visit (const Load &op) override
 
void visit (const Add &op) override
 
void visit (const Sub &op) override
 
void visit (const Mul &op) override
 
void visit (const RealDiv &op) override
 
void visit (const Min &op) override
 
void visit (const Max &op) override
 
void visit (const IfExpr &op) override
 
void visit (const Sqrt &op) override
 
void visit (const Exp &op) override
 
void visit (const Ln &op) override
 
void visit (const Square &op) override
 
void visit (const Sigmoid &op) override
 
void visit (const Sin &op) override
 
void visit (const Cos &op) override
 
void visit (const Tan &op) override
 
void visit (const Tanh &op) override
 
void visit (const Abs &op) override
 
void visit (const Cast &op) override
 
void visit (const Intrinsic &op) override
 
- Protected Member Functions inherited from freetensor::SymbolTable< Visitor >
BaseClass::StmtRetType visit (const VarDef &op) override
 
BaseClass::StmtRetType visit (const For &op) override
 
- Protected Member Functions inherited from freetensor::Visitor
virtual void visitExpr (const Expr &op)
 
virtual void visitStmt (const Stmt &op)
 
virtual void visit (const Any &op)
 
virtual void visit (const AnyExpr &op)
 
virtual void visit (const Func &op)
 
virtual void visit (const StmtSeq &op)
 
virtual void visit (const VarDef &op)
 
virtual void visit (const Var &op)
 
virtual void visit (const Store &op)
 
virtual void visit (const Alloc &op)
 
virtual void visit (const Free &op)
 
virtual void visit (const Load &op)
 
virtual void visit (const ReduceTo &op)
 
virtual void visit (const IntConst &op)
 
virtual void visit (const FloatConst &op)
 
virtual void visit (const BoolConst &op)
 
virtual void visit (const Add &op)
 
virtual void visit (const Sub &op)
 
virtual void visit (const Mul &op)
 
virtual void visit (const RealDiv &op)
 
virtual void visit (const FloorDiv &op)
 
virtual void visit (const CeilDiv &op)
 
virtual void visit (const RoundTowards0Div &op)
 
virtual void visit (const Mod &op)
 
virtual void visit (const Remainder &op)
 
virtual void visit (const Min &op)
 
virtual void visit (const Max &op)
 
virtual void visit (const LT &op)
 
virtual void visit (const LE &op)
 
virtual void visit (const GT &op)
 
virtual void visit (const GE &op)
 
virtual void visit (const EQ &op)
 
virtual void visit (const NE &op)
 
virtual void visit (const LAnd &op)
 
virtual void visit (const LOr &op)
 
virtual void visit (const LNot &op)
 
virtual void visit (const Sqrt &op)
 
virtual void visit (const Exp &op)
 
virtual void visit (const Ln &op)
 
virtual void visit (const Square &op)
 
virtual void visit (const Sigmoid &op)
 
virtual void visit (const Sin &op)
 
virtual void visit (const Cos &op)
 
virtual void visit (const Tan &op)
 
virtual void visit (const Tanh &op)
 
virtual void visit (const Abs &op)
 
virtual void visit (const Floor &op)
 
virtual void visit (const Ceil &op)
 
virtual void visit (const Unbound &op)
 
virtual void visit (const For &op)
 
virtual void visit (const If &op)
 
virtual void visit (const Assert &op)
 
virtual void visit (const Assume &op)
 
virtual void visit (const IfExpr &op)
 
virtual void visit (const Cast &op)
 
virtual void visit (const Intrinsic &op)
 
virtual void visit (const Eval &op)
 
virtual void visit (const MatMul &op)
 
virtual void visit (const MarkVersion &op)
 
virtual void visit (const LoadAtVersion &op)
 

Additional Inherited Members

- Public Types inherited from freetensor::Visitor
typedef void ExprRetType
 
typedef void StmtRetType
 

Detailed Description

Find the derivative of each expression.

Gradients will be updated by multiplying the derivatives.

Derivative has two phases. In the first phase, derivative expressions are built according to mathematical principles, but the variables they load may not exist in an actual backward pass. The result is stored in Derivative::LazyFullDerivative.

The result of the first phase can be used to decide which variables have to be saved to tape or recomputed.

In the second phase, given the information of tape and recomputation, the variables in derivative expressions are corrected.

Member Function Documentation

◆ derivatives()

const auto & freetensor::Derivative::derivatives ( ) const
inline

◆ visit() [1/21]

void freetensor::Derivative::visit ( const Abs & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [2/21]

void freetensor::Derivative::visit ( const Add & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [3/21]

void freetensor::Derivative::visit ( const Cast & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [4/21]

void freetensor::Derivative::visit ( const Cos & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [5/21]

void freetensor::Derivative::visit ( const Exp & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [6/21]

void freetensor::Derivative::visit ( const IfExpr & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [7/21]

void freetensor::Derivative::visit ( const Intrinsic & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [8/21]

void freetensor::Derivative::visit ( const Ln & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [9/21]

void freetensor::Derivative::visit ( const Load & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [10/21]

void freetensor::Derivative::visit ( const Max & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [11/21]

void freetensor::Derivative::visit ( const Min & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [12/21]

void freetensor::Derivative::visit ( const Mul & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [13/21]

void freetensor::Derivative::visit ( const RealDiv & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [14/21]

void freetensor::Derivative::visit ( const Sigmoid & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [15/21]

void freetensor::Derivative::visit ( const Sin & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [16/21]

void freetensor::Derivative::visit ( const Sqrt & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [17/21]

void freetensor::Derivative::visit ( const Square & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [18/21]

void freetensor::Derivative::visit ( const Store & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [19/21]

void freetensor::Derivative::visit ( const Sub & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [20/21]

void freetensor::Derivative::visit ( const Tan & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visit() [21/21]

void freetensor::Derivative::visit ( const Tanh & op)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.

◆ visitExpr()

void freetensor::Derivative::visitExpr ( const Expr & expr)
overrideprotectedvirtual

Reimplemented from freetensor::Visitor.


The documentation for this class was generated from the following files: