FreeTensor
Loading...
Searching...
No Matches
propagate_defs_need_grad.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_PROPAGATE_DEFS_NEED_GRAD_H
2#define FREE_TENSOR_PROPAGATE_DEFS_NEED_GRAD_H
3
4#include <unordered_set>
5
7#include <container_utils.h>
8#include <visitor.h>
9
10namespace freetensor {
11
16class PropagateRequires : public SymbolTable<Visitor> {
18
19 const std::unordered_set<std::string> &requires_; // input var names
20 const std::unordered_set<std::string> &provides_; // output var names
21
22 std::unordered_set<ID> affectedDefs_; // all VarDef IDs
23
24 ID curTarget_; // VarDef ID of current var being written to
25
26 public:
27 PropagateRequires(const std::unordered_set<std::string> &_requires,
28 const std::unordered_set<std::string> &provides)
29 : requires_(_requires), provides_(provides) {}
30
31 const std::unordered_set<ID> &affectedDefs() const { return affectedDefs_; }
32
33 static std::unordered_set<ID>
35 const std::unordered_set<std::string> &_requires,
36 const std::unordered_set<std::string> &provides);
37
38 protected:
39 using BaseClass::visit;
40 void visitExpr(const Expr &e) override;
41 void visit(const Load &op) override;
42 void visit(const Store &op) override;
43 void visit(const ReduceTo &op) override;
44 void visit(const VarDef &op) override;
45};
46
51class PropagateProvides : public SymbolTable<Visitor> {
53
54 const std::unordered_set<std::string> &requires_; // input var names
55 const std::unordered_set<std::string> &provides_; // output var names
56
57 std::unordered_set<ID> affectedDefs_; // all VarDef IDs
58
59 ID curTarget_; // VarDef ID of current var being written to
60
61 public:
62 PropagateProvides(const std::unordered_set<std::string> &_requires,
63 const std::unordered_set<std::string> &provides)
64 : requires_(_requires), provides_(provides) {}
65
66 const std::unordered_set<ID> &affectedDefs() const { return affectedDefs_; }
67
68 static std::unordered_set<ID>
70 const std::unordered_set<std::string> &_requires,
71 const std::unordered_set<std::string> &provides);
72
73 protected:
74 using BaseClass::visit;
75 void visitExpr(const Expr &e) override;
76 void visit(const Load &op) override;
77 void visit(const Store &op) override;
78 void visit(const ReduceTo &op) override;
79 void visit(const VarDef &op) override;
80};
81
85inline std::unordered_set<ID>
87 const std::unordered_set<std::string> &_requires,
88 const std::unordered_set<std::string> &provides) {
89 return intersect(
90 PropagateProvides::propagateUntilConverge(op, _requires, provides),
91 PropagateRequires::propagateUntilConverge(op, _requires, provides));
92}
93
94} // namespace freetensor
95
96#endif // FREE_TENSOR_PROPAGATE_DEFS_NEED_GRAD_H
Definition: id.h:18
Definition: propagate_defs_need_grad.h:51
static std::unordered_set< ID > propagateUntilConverge(const Stmt &op, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides)
Definition: propagate_defs_need_grad.cc:118
const std::unordered_set< ID > & affectedDefs() const
Definition: propagate_defs_need_grad.h:66
void visit(const Load &op) override
Definition: propagate_defs_need_grad.cc:85
PropagateProvides(const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides)
Definition: propagate_defs_need_grad.h:62
void visitExpr(const Expr &e) override
Definition: propagate_defs_need_grad.cc:79
Definition: propagate_defs_need_grad.h:16
PropagateRequires(const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides)
Definition: propagate_defs_need_grad.h:27
void visitExpr(const Expr &e) override
Definition: propagate_defs_need_grad.cc:6
static std::unordered_set< ID > propagateUntilConverge(const Stmt &op, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides)
Definition: propagate_defs_need_grad.cc:44
const std::unordered_set< ID > & affectedDefs() const
Definition: propagate_defs_need_grad.h:31
void visit(const Load &op) override
Definition: propagate_defs_need_grad.cc:12
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
Definition: allocator.h:9
std::unordered_set< ID > propagateDefsNeedGrad(const Stmt &op, const std::unordered_set< std::string > &_requires, const std::unordered_set< std::string > &provides)
Definition: propagate_defs_need_grad.h:86
std::unordered_map< T, std::pair< V1, V2 >, Hash, KeyEqual > intersect(const std::unordered_map< T, V1, Hash, KeyEqual > &lhs, const std::unordered_map< T, V2, Hash, KeyEqual > &rhs)
Definition: container_utils.h:24