1#ifndef FREE_TENSOR_GPU_LOWER_PARALLEL_REDUCTION_H
2#define FREE_TENSOR_GPU_LOWER_PARALLEL_REDUCTION_H
6#include <unordered_map>
7#include <unordered_set>
// Mutator built on SymbolTable<Mutator>. It overrides visit() for VarDef, For
// and ReduceTo, and exposes two read-only accessors: ws2red() and converged().
//
// NOTE(review): the integers at the start of lines here look like listing line
// numbers fused into the code, and the numbering has gaps (19 -> 21 -> 23,
// 25 -> 28, ...). Lines appear to be missing (access specifiers, at least one
// member declarator, the closing brace) -- restore from the real header before
// compiling.
18class InsertWorkspaces :
public SymbolTable<Mutator> {
// Shorthand for the base class; used below to re-expose its visit() overloads.
19 typedef SymbolTable<Mutator> BaseClass;
// Map from an ID to a (string, ReductionItem) pair. The declarator is not
// visible here; presumably this is `ws2red_`, which ws2red() returns -- TODO
// confirm.
21 std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
// For nodes kept as a stack; presumably the loops enclosing the current
// statement -- verify against visit(const For &).
23 std::vector<For> loopStack_;
// Set of variable names; what "handled" means is defined in the .cc file.
24 std::unordered_set<std::string> handledVars_;
// Flag returned by converged(); initialised to true.
25 bool converged_ =
true;
// Read-only access to ws2red_ (returned by const reference).
28 const auto &ws2red()
const {
return ws2red_; }
// Returns the current value of converged_.
29 bool converged()
const {
return converged_; }
// Takes a ReduceTo node and returns a list of (For, int) pairs. The meaning of
// the int is not visible in this header -- see the definition.
32 std::vector<std::pair<For, int>> reducedBy(
const ReduceTo &op);
// Keep the base-class visit() overloads visible alongside the overrides below.
35 using BaseClass::visit;
36 Stmt visit(
const VarDef &op)
override;
37 Stmt visit(
const For &op)
override;
38 Stmt visit(
const ReduceTo &op)
override;
// Mutator built on SymbolTable<Mutator>. It routes Store, ReduceTo and Load
// through a shared helper (visitMemAcc) and declares a VarDef override that is
// defined elsewhere.
//
// NOTE(review): the integers at the start of lines here look like listing line
// numbers fused into the code, and the numbering has gaps (44 -> 46 -> 51,
// 52 -> 56, 64 -> 66 -> 72). Member declarators, the rest of the constructor,
// the definition of `nth` and the closing braces are not visible -- restore
// from the real header before compiling.
41class InsertBinaryReduction :
public SymbolTable<Mutator> {
// Shorthand for the base class; used for BaseClass::visit below.
42 typedef SymbolTable<Mutator> BaseClass;
// Const map from an ID to a (string, ReductionItem) pair. Declarator not
// visible; presumably `ws2red_`, which visitMemAcc reads -- TODO confirm.
44 const std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
// Map from ID to ID. Declarator not visible; presumably `ws2scope_`, which
// ws2scope() returns -- TODO confirm.
46 std::unordered_map<ID, ID>
// Constructor taking a map of the same type as the member above; the
// parameter name and initialiser list are not visible here.
51 InsertBinaryReduction(
52 const std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
// Read-only access to ws2scope_ (returned by const reference).
56 const auto &ws2scope()
const {
return ws2scope_; }
// Shared handler for memory-access nodes (used for Store, ReduceTo and Load
// below). It visits the node through the base class, then, if the ID of the
// accessed variable's definition is a key in ws2red_, inserts one extra index
// at the front of the node's index list.
59 template <
class T> T visitMemAcc(
const T &_op) {
60 auto __op = BaseClass::visit(_op);
// The base-class visit must hand back the same kind of node it was given.
61 ASSERT(__op->nodeType() == _op->nodeType());
62 auto op = __op.template as<typename T::Object>();
// def(...) and loop(...) are not declared in this class; presumably they are
// SymbolTable lookups by name -- verify.
63 if (
auto it = ws2red_.find(def(op->var_)->id()); it != ws2red_.end()) {
// The string half of the mapped pair is used as the key for loop().
64 auto &&l = loop(it->second.first);
// NOTE(review): `nth` is defined on a line that is missing from this view
// (listing line 65); its value cannot be stated from here.
66 op->indices_.insert(op->indices_.begin(), nth);
// Keep the base-class visit() overloads visible alongside the overrides below.
72 using BaseClass::visit;
73 Stmt visit(
const VarDef &op)
override;
// Store, ReduceTo and Load all forward to visitMemAcc.
74 Stmt visit(
const Store &op)
override {
return visitMemAcc(op); }
75 Stmt visit(
const ReduceTo &op)
override {
return visitMemAcc(op); }
76 Expr visit(
const Load &op)
override {
return visitMemAcc(op); }
// Mutator built on CompTransientBounds<SymbolTable<Mutator>>. It declares
// visit() overrides for VarDef and For; both are defined elsewhere.
//
// NOTE(review): the integers at the start of lines here look like listing line
// numbers fused into the code, and the numbering has gaps (81 -> 83 -> 86,
// 90 -> 95). A member declarator, the rest of the constructor and the closing
// brace are not visible -- restore from the real header before compiling.
79class CorrectInterThreadDependence
80 :
public CompTransientBounds<SymbolTable<Mutator>> {
// Shorthand for the base class; used to re-expose its visit() overloads.
81 typedef CompTransientBounds<SymbolTable<Mutator>> BaseClass;
// Const map from an ID to a (string, ReductionItem) pair. The declarator is
// not visible; presumably named `ws2red_` like the sibling classes -- TODO
// confirm.
83 const std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
// Map from an ID to a list of VarDef nodes; going by the name, the key is
// presumably a loop's ID -- verify against the .cc file.
86 std::unordered_map<ID, std::vector<VarDef>> loop2ws_;
// Constructor taking a map of the same type as the member above; the
// parameter name and initialiser list are not visible here.
89 CorrectInterThreadDependence(
90 const std::unordered_map<ID, std::pair<std::string, Ref<ReductionItem>>>
// Keep the base-class visit() overloads visible alongside the overrides below.
95 using BaseClass::visit;
96 Stmt visit(
const VarDef &op)
override;
97 Stmt visit(
const For &op)
override;
// NOTE(review): everything from here down looks like cross-reference tooltip
// text from a rendered source listing (a symbol signature followed by
// "Definition: <file>:<line>"), not part of this header. The two #define lines
// have empty bodies, so if compiled as-is they would turn ASSERT(...) and
// DEFINE_PASS_FOR_FUNC(...) into no-ops. Presumably this tail should be
// dropped when the header is restored -- confirm against the original file.
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Stmt lowerParallelReduction(const Stmt &op)
Definition: lower_parallel_reduction.cc:206
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
Ref< ExprNode > Expr
Definition: ast.h:184