FreeTensor
Loading...
Searching...
No Matches
multiplex_buffers.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_GPU_MULTIPLEX_BUFFERS_H
2#define FREE_TENSOR_GPU_MULTIPLEX_BUFFERS_H
3
4#ifdef FT_WITH_CUDA
5
6#include <unordered_map>
7#include <unordered_set>
8#include <vector>
9
11#include <driver/target.h>
12#include <func.h>
13#include <mutator.h>
14#include <visitor.h>
15
16namespace freetensor {
17
18namespace gpu {
19
20class FindParallelLoops : public Visitor {
21 Ref<GPUTarget> target_;
22 ID defId_;
23 std::vector<For> loops_, stack_;
24 std::unordered_map<ID, std::unordered_map<ID, bool>>
25 affecting_; // VarDef ID -> {For ID -> can communicate?}
26
27 public:
28 FindParallelLoops(const Ref<GPUTarget> &target, const ID &defId = ID())
29 : target_(target), defId_(defId) {}
30
31 const std::vector<For> &loops() const { return loops_; }
32 const auto &affecting() const { return affecting_; }
33
34 protected:
35 void visit(const For &op) override;
36 void visit(const VarDef &op) override;
37};
38
39class MultiplexMutator : public SymbolTable<Mutator> {
40 typedef SymbolTable<Mutator> BaseClass;
41
42 std::vector<For> stack_;
43 std::unordered_map<std::string, int> defPos_;
44 const std::unordered_map<ID, std::unordered_set<ID>>
45 &affecting_; // VarDef ID -> For ID
46
47 public:
48 MultiplexMutator(
49 const std::unordered_map<ID, std::unordered_set<ID>> &affecting)
50 : affecting_(affecting) {}
51
52 private:
53 template <class T> T alterAccess(const T &op) {
54 if (!defPos_.count(op->var_)) {
55 return op;
56 }
57 if (affecting_.count(def(op->var_)->id())) {
58 auto &&aff = affecting_.at(def(op->var_)->id());
59 int pos = defPos_.at(op->var_);
60 for (int i = pos - 1; i >= 0; i--) {
61 if (aff.count(stack_[i]->id())) {
62 auto &indices = op->indices_;
63 indices.insert(indices.begin(), makeVar(stack_[i]->iter_));
64 }
65 }
66 }
67 return op;
68 }
69
70 protected:
71 Stmt visit(const For &op) override;
72 Stmt visit(const VarDef &op) override;
73 Expr visit(const Load &op) override;
74 Stmt visit(const Store &op) override;
75 Stmt visit(const ReduceTo &op) override;
76};
77
86Stmt multiplexBuffers(const Stmt &op, const Ref<GPUTarget> &target,
87 const ID &defId = ID());
88
89DEFINE_PASS_FOR_FUNC(multiplexBuffers)
90
91} // namespace gpu
92
93} // namespace freetensor
94
95#endif // FT_WITH_CUDA
96
97#endif // FREE_TENSOR_GPU_MULTIPLEX_BUFFERS_H
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
std::vector< IterAxis > iter_
Definition: invert_stmts.cc:56
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
Ref< ExprNode > Expr
Definition: ast.h:184