FreeTensor
Loading...
Searching...
No Matches
replace_by_saved.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_REPLACE_BY_SAVED_H
2#define FREE_TENSOR_REPLACE_BY_SAVED_H
3
4#include <optional>
5
8#include <mutator.h>
9
10namespace freetensor {
11
27class ReplaceBySaved : public Mutator {
28 const SymbolTableInterface &symbolTable_;
29 const std::unordered_map<ID, std::string> &intermediatesMap_;
30 const std::unordered_map<StmtOrExprID, Expr> &versions_;
31 ID rootStmtID_;
32 std::optional<InvertFromStore> invertFromStore_;
33 bool isGrad_ = false;
34
35 public:
37 const SymbolTableInterface &symbolTable,
38 const std::unordered_map<ID, std::string> &intermediatesMap,
39 const std::unordered_map<StmtOrExprID, Expr> &versions,
40 const ID &rootStmtID,
41 const std::optional<InvertFromStore> &invertFromStore = std::nullopt)
42 : symbolTable_(symbolTable), intermediatesMap_(intermediatesMap),
43 versions_(versions), rootStmtID_(rootStmtID),
44 invertFromStore_(invertFromStore) {}
45
46 // Replace recomputing expressions
47 auto recomp(const auto &op) {
48 isGrad_ = false;
49 return (*this)(op);
50 }
51
52 // Replace gradient expressions
53 auto grad(const auto &op) {
54 isGrad_ = true;
55 return (*this)(op);
56 }
57
58 private:
59 // Disabled. Use `ReplcaeBySaved::recomp` or `RepalceBySaved::grad` instaed
60 using Mutator::operator();
61
62 protected:
63 Expr visitExpr(const Expr &expr) override;
64 Expr visit(const Load &op) override;
65};
66
67} // namespace freetensor
68
69#endif // FREE_TENSOR_REPLACE_BY_SAVED_H
Definition: id.h:18
Definition: mutator.h:11
Definition: replace_by_saved.h:27
Expr visitExpr(const Expr &expr) override
Definition: replace_by_saved.cc:6
auto grad(const auto &op)
Definition: replace_by_saved.h:53
ReplaceBySaved(const SymbolTableInterface &symbolTable, const std::unordered_map< ID, std::string > &intermediatesMap, const std::unordered_map< StmtOrExprID, Expr > &versions, const ID &rootStmtID, const std::optional< InvertFromStore > &invertFromStore=std::nullopt)
Definition: replace_by_saved.h:36
Expr visit(const Load &op) override
Definition: replace_by_saved.cc:28
auto recomp(const auto &op)
Definition: replace_by_saved.h:47
Definition: symbol_table.h:13
Definition: allocator.h:9