FreeTensor
Loading...
Searching...
No Matches
output_intermediates.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_OUTPUT_INTERMEDIATES_H
2#define FREE_TENSOR_OUTPUT_INTERMEDIATES_H
3
4#include <unordered_map>
5#include <unordered_set>
6
9#include <mutator.h>
10
11namespace freetensor {
12
14
15class OutputIntermediates : public SymbolTable<Mutator> {
17
18 const std::unordered_map<StmtOrExprID, Expr> &versions_;
19 const std::unordered_map<ID, Expr> &totLens_;
20 const std::unordered_set<ID> &trivials_;
22 std::string varSuffix_;
23
24 std::unordered_map<ID, std::string> savedNames_;
25 std::unordered_set<ID> insertedStmts_;
26 std::unordered_map<ID, std::vector<Stmt>> toSave_;
27 ID curStmt_;
28
29 public:
30 OutputIntermediates(const std::unordered_map<StmtOrExprID, Expr> &versions,
31 const std::unordered_map<ID, Expr> &totLens,
32 const std::unordered_set<ID> &trivials,
34 const std::string &varSuffix)
35 : versions_(versions), totLens_(totLens), trivials_(trivials),
36 stage_(stage), varSuffix_(varSuffix) {}
37
38 const auto &savedNames() const { return savedNames_; }
39 const auto &insertedStmts() const { return insertedStmts_; }
40
41 private:
42 std::string savingName(const std::string &oldName) const;
43
44 protected:
45 using BaseClass::visit;
46 Stmt visitStmt(const Stmt &stmt) override;
47 Expr visit(const Load &op) override;
48 Stmt visit(const Store &op) override;
49 Stmt visit(const ReduceTo &op) override;
50 Stmt visit(const VarDef &op) override;
51};
52
115std::tuple<Stmt, std::unordered_map<ID, std::string>,
116 std::unordered_map<StmtOrExprID, Expr>, std::unordered_map<ID, Expr>,
117 std::unordered_set<ID>,
118 std::unordered_map<std::string, std::pair<std::string, Expr>>>
120 const Stmt &op,
121 const std::unordered_map<ID, std::unordered_set<ID>> &needVersions,
122 const std::unordered_map<StmtOrExprID, Derivative::LazyFullDerivative>
123 &derivatives,
125 const std::string &varSuffix = ".tape");
126
133 const Stmt &op, const std::unordered_set<ID> &intermediates,
135 const std::string &varSuffix = ".tape");
136
137} // namespace freetensor
138
139#endif // FREE_TENSOR_OUTPUT_INTERMEDIATES_H
Definition: id.h:18
Definition: output_intermediates.h:15
OutputIntermediates(const std::unordered_map< StmtOrExprID, Expr > &versions, const std::unordered_map< ID, Expr > &totLens, const std::unordered_set< ID > &trivials, OutputIntermediatesStage stage, const std::string &varSuffix)
Definition: output_intermediates.h:30
Stmt visitStmt(const Stmt &stmt) override
Definition: output_intermediates.cc:57
const auto & savedNames() const
Definition: output_intermediates.h:38
const auto & insertedStmts() const
Definition: output_intermediates.h:39
Expr visit(const Load &op) override
Definition: output_intermediates.cc:68
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
Definition: allocator.h:9
std::tuple< Stmt, std::unordered_map< ID, std::string >, std::unordered_map< StmtOrExprID, Expr >, std::unordered_map< ID, Expr >, std::unordered_set< ID >, std::unordered_map< std::string, std::pair< std::string, Expr > > > outputIntermediates(const Stmt &op, const std::unordered_map< ID, std::unordered_set< ID > > &needVersions, const std::unordered_map< StmtOrExprID, Derivative::LazyFullDerivative > &derivatives, OutputIntermediatesStage stage=OutputIntermediatesStage::Forward, const std::string &varSuffix=".tape")
Definition: output_intermediates.cc:163
Stmt outputAllIntermedaites(const Stmt &op, const std::unordered_set< ID > &intermediates, OutputIntermediatesStage stage=OutputIntermediatesStage::Forward, const std::string &varSuffix=".tape")
Definition: output_intermediates.cc:178
Ref< StmtNode > Stmt
Definition: ast.h:152
OutputIntermediatesStage
Definition: output_intermediates.h:13