FreeTensor
Loading...
Searching...
No Matches
make_parallel_reduction.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_MAKE_PARLLEL_REDUCTION_H
2#define FREE_TENSOR_MAKE_PARLLEL_REDUCTION_H
3
4#include <memory>
5#include <unordered_map>
6#include <unordered_set>
7
12#include <driver/target.h>
13#include <func.h>
14#include <mutator.h>
15#include <visitor.h>
16
17namespace freetensor {
18
20 ParallelScope type_; // parallel type
21 std::vector<ID> outerLoops_; // outer loop ID
22};
23
// Visitor that builds a map from loop ID to a ParallelInfo record (see
// `results_`). The collection itself happens in the out-of-line
// `visit(const For &)` override, which is not visible in this header.
24class FindAllParallel : public Visitor {
25 // Loop ID -> ParallelInfo
26 std::unordered_map<ID, ParallelInfo> results_;
27
// Stack of loop IDs maintained during traversal.
// NOTE(review): presumably the IDs of the loops enclosing the node currently
// being visited -- confirm against the implementation of visit().
28 std::vector<ID> loopStack_;
29
30 public:
// Read-only access to the collected map. The returned reference refers to
// this visitor's own member, so it must not be used after the visitor is
// destroyed.
31 const std::unordered_map<ID, ParallelInfo> &results() const {
32 return results_;
33 }
34
35 protected:
// Hook invoked for each For node; defined out of line.
36 void visit(const For &op) override;
37};
38
40 std::unordered_map<ID, std::vector<For>>
41 results_; // ReduceTo ID -> [For], from inner to outer
42 std::vector<For> loopStack_;
43
44 public:
45 const std::unordered_map<ID, std::vector<For>> &results() const {
46 return results_;
47 }
48
49 protected:
50 void visit(const For &op) override;
51 void visit(const ReduceTo &op) override;
52};
53
61 : public CompTransientBounds<SymbolTable<Mutator>> {
63
// One reduction entry; lists of these are stored per ID in the
// `forReductions_` member of the enclosing class.
64 struct ReductionItemFactors {
// Reduction operator of this entry.
65 ReduceOp op_;
// NOTE(review): presumably the name of the variable being reduced into --
// confirm against the code that fills this struct.
66 std::string var_;
// Bounds indexed first by dimension, then by access (see trailing comment).
67 std::vector<std::vector<Ref<CompUniqueBounds::Bound>>>
68 bound_; // [dim][access]
// NOTE(review): the exact meaning of this flag is not visible in this
// header -- check where it is set and consumed.
69 bool syncFlush_;
70 };
71
72 const std::unordered_map<ID, std::unordered_set<ID>>
73 &toAlter_; // ReduceTo ID -> Racing For ID
74 const LoopVariExprMap &variantMap_;
75
76 // ReduceTo IDs. For all reductions in `toAlter`, we first try to lower them
77 // as loop-carried reductions. If impossible, we then insert them to this
78 // map, which is passed to `MakeSyncReduction`.
79 std::unordered_set<ID> toUseSync_;
80
81 std::unordered_map<ID, ParallelScope> paraScopes_; // For Id -> parallel
82 std::unordered_map<ID, std::vector<ReductionItemFactors>> forReductions_;
83 std::unordered_map<ID, std::unordered_set<std::string>>
84 scopeDefined_; // For ID -> definitions at that scope
85
86 std::vector<ID> paraLoopStack_;
87
88 private:
89 bool needSync(const ReduceTo &op, const ID &loopId);
90
91 public:
93 const std::unordered_map<ID, std::unordered_set<ID>> &toAlter,
94 const LoopVariExprMap &variantMap)
95 : toAlter_(toAlter), variantMap_(variantMap) {}
96
97 const auto &toUseSync() const { return toUseSync_; }
98
99 protected:
100 using BaseClass::visit;
101 Stmt visit(const ReduceTo &op) override;
102 Stmt visit(const For &op) override;
103};
104
// Mutator (layered on SymbolTable) that handles the ReduceTo statements whose
// IDs are listed in `toUseSync_`. Per the comment on
// MakeLoopCarriedReduction::toUseSync_, those are the reductions that could
// not be lowered as loop-carried reductions.
//
// Every constructor argument is stored as a reference member, so the objects
// passed in must outlive this mutator.
109class MakeSyncReduction : public SymbolTable<Mutator> {
111
// IDs of the ReduceTo statements this pass is asked to handle.
112 const std::unordered_set<ID> &toUseSync_;
113 const std::unordered_map<ID, std::vector<For>>
114 &serialOverRed_; // ReduceTo ID -> [For], from inner to outer
115 const LoopVariExprMap &variantMap_;
116
// The two branches declare the same member; only the attribute differs.
117#if defined(__GNUC__) && !defined(__clang__)
118 // GCC<12 does not support [[maybe_unused]] on member vars
119 const Ref<Target> &target_;
120#else
121 [[maybe_unused]] /* used only if FT_WITH_CUDA */ const Ref<Target> &target_;
122#endif
123
// Per-reduction bookkeeping, stored per loop in `cacheSync_`.
// NOTE(review): field meanings below are inferred from their names only --
// confirm against the .cc file.
124 struct SyncCacheInfo {
// The ReduceTo node this record was created for (presumably the node
// before rewriting).
125 ReduceTo oldNode_;
126 std::vector<Expr> newShape_, newTargetIndices_;
// One flag per dimension; presumably whether that dimension is kept.
127 std::vector<bool> preserveDim_;
128 };
129 std::unordered_map<ID,
130 std::vector<SyncCacheInfo>>
131 cacheSync_; // loop ID -> [SyncCacheInfo]
132
// Starts at 1. NOTE(review): presumably updated while traversing GPU thread
// loops -- confirm in the implementation.
133 int64_t gpuThreadDim_ = 1;
134
135 private:
// Predicate over a data type and a shape; defined out of line.
// NOTE(review): judging by the name, it decides whether a buffer of this
// dtype/shape may be placed in GPU local memory -- confirm the criteria.
164 bool canResideInGPULocal(DataType dtype,
165 const std::vector<Expr> &shape) const;
166
// Maps a memory type plus dtype/shape to a memory type; defined out of
// line. NOTE(review): selection rules are not visible here.
167 MemType localMType(MemType mtype, DataType dtype,
168 const std::vector<Expr> &shape) const;
169
170 public:
// Constructor parameters: each one is bound directly to the reference member
// of the same name (with a trailing underscore); nothing is copied.
172 const std::unordered_set<ID> &toUseSync,
173 const std::unordered_map<ID, std::vector<For>> &serialOverRed,
174 const LoopVariExprMap &variantMap, const Ref<Target> &target)
175 : toUseSync_(toUseSync), serialOverRed_(serialOverRed),
176 variantMap_(variantMap), target_(target) {}
177
178 protected:
// Keep the base-class visit overloads visible alongside the two overrides
// declared below.
179 using BaseClass::visit;
180 Stmt visit(const ReduceTo &op) override;
181 Stmt visit(const For &op) override;
182};
183
193Stmt makeParallelReduction(const Stmt &op, const Ref<Target> &target);
194
196
197} // namespace freetensor
198
199#endif // FREE_TENSOR_MAKE_PARLLEL_REDUCTION_H
Definition: comp_transient_bounds.h:50
BaseClass::StmtRetType visit(const For &op) override
Definition: comp_transient_bounds.h:128
Definition: data_type.h:106
Definition: make_parallel_reduction.h:24
const std::unordered_map< ID, ParallelInfo > & results() const
Definition: make_parallel_reduction.h:31
void visit(const For &op) override
Definition: make_parallel_reduction.cc:29
Definition: make_parallel_reduction.h:39
const std::unordered_map< ID, std::vector< For > > & results() const
Definition: make_parallel_reduction.h:45
void visit(const For &op) override
Definition: make_parallel_reduction.cc:39
Definition: id.h:18
Definition: make_parallel_reduction.h:61
const auto & toUseSync() const
Definition: make_parallel_reduction.h:97
MakeLoopCarriedReduction(const std::unordered_map< ID, std::unordered_set< ID > > &toAlter, const LoopVariExprMap &variantMap)
Definition: make_parallel_reduction.h:92
Stmt visit(const ReduceTo &op) override
Definition: make_parallel_reduction.cc:71
Definition: make_parallel_reduction.h:109
Stmt visit(const ReduceTo &op) override
Definition: make_parallel_reduction.cc:229
MakeSyncReduction(const std::unordered_set< ID > &toUseSync, const std::unordered_map< ID, std::vector< For > > &serialOverRed, const LoopVariExprMap &variantMap, const Ref< Target > &target)
Definition: make_parallel_reduction.h:171
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
Definition: visitor.h:11
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
std::unordered_map< StmtOrExprID, std::unordered_map< ID, LoopVariability > > LoopVariExprMap
Definition: find_loop_variance.h:26
std::variant< SerialScope, OpenMPScope, CUDAStreamScope, CUDAScope > ParallelScope
Definition: parallel_scope.h:73
Ref< StmtNode > Stmt
Definition: ast.h:152
ReduceOp
Definition: reduce_op.h:30
Stmt makeParallelReduction(const Stmt &op, const Ref< Target > &target)
Definition: make_parallel_reduction.cc:376
MemType
Definition: mem_type.h:14
Definition: make_parallel_reduction.h:19
std::vector< ID > outerLoops_
Definition: make_parallel_reduction.h:21
ParallelScope type_
Definition: make_parallel_reduction.h:20