FreeTensor
Loading...
Searching...
No Matches
invert_stmts.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_INVERT_STMTS_H
2#define FREE_TENSOR_INVERT_STMTS_H
3
4#include <unordered_map>
5#include <unordered_set>
6
9#include <visitor.h>
10
11namespace freetensor {
12
15 Expr cond_; // null for always invert
16};
17
// FindInvertibles: a visitor built on SymbolTable<Visitor> that records
// invertible statements as a map from the original statement's ID to an
// inverse statement (invertibles_). The map is presumably filled by
// visit(const ReduceTo &), the only visit overload overridden here.
// NOTE(review): this listing skips source line 22. The
// `using BaseClass::visit;` below needs a BaseClass alias that is presumably
// declared on that line -- verify against the real header.
21class FindInvertibles : public SymbolTable<Visitor> {
23
24 // Original ID -> Inverse Statement
25 std::unordered_map<ID, Stmt> invertibles_;
26
27 public:
// Read-only view of the collected ID -> inverse-statement map (returned by
// const reference, so no copy is made).
28 const auto &invertibles() const { return invertibles_; }
29
30 protected:
// Bring the base class's visit overloads into scope so that the override
// below does not hide them.
31 using BaseClass::visit;
32 // TODO 1: Support invertibles in Store.
33 // TODO 2: We now always invert first, and then compute gradient. Things
34 // will go wrong after we support Store, where we may use y for gradient.
// Only ReduceTo is handled. The body is in invert_stmts.cc (line 158,
// according to the Doxygen cross-reference on this page).
35 void visit(const ReduceTo &op) override;
36};
37
// invertStmts: declared here and defined in invert_stmts.cc (line 196,
// according to the Doxygen cross-reference on this page).
// Returns a tuple of (Stmt, map from ID to InversionInfo). The returned Stmt
// is presumably `op` rewritten with inversions applied -- TODO confirm
// against the definition.
// `idsNeeded` and `derivatives` are non-const pointers, so the callee may
// read and/or modify the caller's maps. Check the definition to see which.
// NOTE(review): the original doc comment (source lines 38-50) is missing
// from this listing.
51std::tuple<Stmt, std::unordered_map<ID, InversionInfo>>
52invertStmts(const Stmt &op,
53 std::unordered_map<ID, std::unordered_set<ID>> *idsNeeded,
54 std::unordered_map<StmtOrExprID, Derivative::LazyFullDerivative>
55 *derivatives);
56
57} // namespace freetensor
58
59#endif // FREE_TENSOR_INVERT_STMTS_H
class FindInvertibles
Definition: invert_stmts.h:21
void visit(const ReduceTo &op) override
Definition: invert_stmts.cc:158
const auto & invertibles() const
Definition: invert_stmts.h:28
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152
std::tuple< Stmt, std::unordered_map< ID, InversionInfo > > invertStmts(const Stmt &op, std::unordered_map< ID, std::unordered_set< ID > > *idsNeeded, std::unordered_map< StmtOrExprID, Derivative::LazyFullDerivative > *derivatives)
Definition: invert_stmts.cc:196
Definition: invert_stmts.h:13
Stmt inv_
Definition: invert_stmts.h:14
Expr cond_
Definition: invert_stmts.h:15