FreeTensor
Loading...
Searching...
No Matches
lower_parallel_reduction.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_GPU_LOWER_PARALLEL_REDUCTION_H
2#define FREE_TENSOR_GPU_LOWER_PARALLEL_REDUCTION_H
3
4#ifdef FT_WITH_CUDA
5
6#include <unordered_map>
7#include <unordered_set>
8
11#include <func.h>
12#include <mutator.h>
13
14namespace freetensor {
15
16namespace gpu {
17
18class InsertWorkspaces : public SymbolTable<Mutator> {
19 typedef SymbolTable<Mutator> BaseClass;
20
21 std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
22 ws2red_; // workspace ID -> (loop iter name, reduction info)
23 std::vector<For> loopStack_;
24 std::unordered_set<std::string> handledVars_;
25 bool converged_ = true;
26
27 public:
28 const auto &ws2red() const { return ws2red_; }
29 bool converged() const { return converged_; }
30
31 private:
32 std::vector<std::pair<For, int>> reducedBy(const ReduceTo &op);
33
34 protected:
35 using BaseClass::visit;
36 Stmt visit(const VarDef &op) override;
37 Stmt visit(const For &op) override;
38 Stmt visit(const ReduceTo &op) override;
39};
40
41class InsertBinaryReduction : public SymbolTable<Mutator> {
42 typedef SymbolTable<Mutator> BaseClass;
43
44 const std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
45 &ws2red_; // workspace ID -> (loop iter name, reduction info)
46 std::unordered_map<ID, ID>
47 ws2scope_; // workspace ID -> scope that actually do the computation,
48 // excluding initialization, binary reduction and flushing
49
50 public:
51 InsertBinaryReduction(
52 const std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
53 &ws2red)
54 : ws2red_(ws2red) {}
55
56 const auto &ws2scope() const { return ws2scope_; }
57
58 private:
59 template <class T> T visitMemAcc(const T &_op) {
60 auto __op = BaseClass::visit(_op);
61 ASSERT(__op->nodeType() == _op->nodeType());
62 auto op = __op.template as<typename T::Object>();
63 if (auto it = ws2red_.find(def(op->var_)->id()); it != ws2red_.end()) {
64 auto &&l = loop(it->second.first);
65 auto nth = makeSub(makeVar(l->iter_), l->begin_);
66 op->indices_.insert(op->indices_.begin(), nth);
67 }
68 return op;
69 }
70
71 protected:
72 using BaseClass::visit;
73 Stmt visit(const VarDef &op) override;
74 Stmt visit(const Store &op) override { return visitMemAcc(op); }
75 Stmt visit(const ReduceTo &op) override { return visitMemAcc(op); }
76 Expr visit(const Load &op) override { return visitMemAcc(op); }
77};
78
79class CorrectInterThreadDependence
80 : public CompTransientBounds<SymbolTable<Mutator>> {
81 typedef CompTransientBounds<SymbolTable<Mutator>> BaseClass;
82
83 const std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
84 &ws2red_; // workspace ID -> (loop iter name, reduction info)
85
86 std::unordered_map<ID, std::vector<VarDef>> loop2ws_;
87
88 public:
89 CorrectInterThreadDependence(
90 const std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
91 &ws2red)
92 : ws2red_(ws2red) {}
93
94 protected:
95 using BaseClass::visit;
96 Stmt visit(const VarDef &op) override;
97 Stmt visit(const For &op) override;
98};
99
101
102DEFINE_PASS_FOR_FUNC(lowerParallelReduction)
103
104} // namespace gpu
105
106} // namespace freetensor
107
108#endif // FT_WITH_CUDA
109
110#endif // FREE_TENSOR_GPU_LOWER_PARALLEL_REDUCTION_H
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Stmt lowerParallelReduction(const Stmt &op)
Definition: lower_parallel_reduction.cc:206
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
Ref< ExprNode > Expr
Definition: ast.h:184