FreeTensor
Loading...
Searching...
No Matches
simplex_buffers.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_GPU_SIMPLEX_BUFFERS_H
2#define FREE_TENSOR_GPU_SIMPLEX_BUFFERS_H
3
4#ifdef FT_WITH_CUDA
5
6#include <unordered_set>
7
10#include <container_utils.h>
11#include <func.h>
12#include <hash.h>
13#include <mutator.h>
14#include <pass/replace_iter.h>
15#include <visitor.h>
16
17namespace freetensor {
18
19namespace gpu {
20
21struct SimplexOffset {
22 ASTHashSet<Expr> offset_;
23};
24
25class FindSimplexOffset : public SymbolTable<Visitor> {
26 typedef SymbolTable<Visitor> BaseClass;
27
28 ID defId_;
29 std::unordered_map<ID, std::vector<Ref<SimplexOffset>>>
30 offsets_; // def ID -> [offset for each index]
31 AnalyzeLinear analyzeLinear_;
32
33 public:
34 FindSimplexOffset(const ID &defId = ID()) : defId_(defId) {}
35
36 const std::unordered_map<ID, std::vector<Ref<SimplexOffset>>> &
37 offsets() const {
38 return offsets_;
39 }
40
41 private:
42 Ref<SimplexOffset>
43 getSimplexOffset(const std::unordered_set<ParallelScope> &filter,
44 const Expr &expr);
45
46 template <class T> void visitMemAcc(const T &op) {
47 BaseClass::visit(op);
48
49 auto mtype = buffer(op->var_)->mtype();
50 if (mtype != MemType::GPULocal && mtype != MemType::GPUShared &&
51 mtype != MemType::GPUWarp) {
52 return;
53 }
54
55 auto &&defId = def(op->var_)->id();
56 if (defId_.isValid() && defId_ != defId) {
57 return;
58 }
59
60 std::vector<Ref<SimplexOffset>> thisOffsets;
61 for (auto &&idx : op->indices_) {
62 Ref<SimplexOffset> offset;
63 if (mtype == MemType::GPUShared || mtype == MemType::GPUWarp) {
64 offset =
65 getSimplexOffset({blockIdxX, blockIdxY, blockIdxZ}, idx);
66 } else {
67 offset = getSimplexOffset({threadIdxX, threadIdxY, threadIdxZ,
69 idx);
70 }
71 thisOffsets.emplace_back(offset);
72 }
73
74 if (!offsets_.count(defId)) {
75 offsets_[defId] = thisOffsets;
76 } else {
77 ASSERT(offsets_.at(defId).size() == thisOffsets.size());
78 for (auto &&[old, cur] :
79 views::zip(offsets_.at(defId), thisOffsets)) {
80 if (old.isValid() &&
81 (!cur.isValid() || old->offset_ != cur->offset_)) {
82 old = nullptr;
83 }
84 }
85 }
86 }
87
88 protected:
89 using BaseClass::visit;
90 void visit(const Load &op) override { visitMemAcc(op); }
91 void visit(const Store &op) override { visitMemAcc(op); }
92 void visit(const ReduceTo &op) override { visitMemAcc(op); }
93};
94
95class ApplySimplexOffset : public SymbolTable<Mutator> {
96 typedef SymbolTable<Mutator> BaseClass;
97
98 const std::unordered_map<ID, std::vector<Ref<SimplexOffset>>>
99 &offsets_; // def ID -> [offset for each index]
100 std::unordered_map<std::string, Expr> para2var_;
101
102 public:
103 ApplySimplexOffset(
104 const std::unordered_map<ID, std::vector<Ref<SimplexOffset>>> &offsets)
105 : offsets_(offsets) {}
106
107 private:
108 template <class T> T visitMemAcc(const T &_op) {
109 auto __op = BaseClass::visit(_op);
110 ASSERT(__op->nodeType() == _op->nodeType());
111 auto op = __op.template as<typename T::Object>();
112
113 auto &&defId = def(op->var_)->id();
114 if (offsets_.count(defId)) {
115 auto &&offset = offsets_.at(defId);
116 ASSERT(offset.size() == op->indices_.size());
117 for (auto &&[off, idx] : views::zip(offset, op->indices_)) {
118 if (off.isValid()) {
119 for (auto &&expr : off->offset_) {
120 idx = makeSub(idx, ReplaceIter{para2var_}(expr));
121 }
122 }
123 }
124 }
125 return op;
126 }
127
128 protected:
129 using BaseClass::visit;
130 Stmt visit(const For &op) override;
131 Expr visit(const Load &op) override { return visitMemAcc(op); }
132 Stmt visit(const Store &op) override { return visitMemAcc(op); }
133 Stmt visit(const ReduceTo &op) override { return visitMemAcc(op); }
134};
135
/**
 * Entry point of the simplex-buffers pass. Declared here, defined out of line
 *
 * @param op : The statement to process
 * @param defId : Optional definition ID, defaulting to a default-constructed
 * `ID`. NOTE(review): presumably restricts the pass to that one buffer
 * definition, as the `defId` of `FindSimplexOffset` does; confirm against the
 * implementation
 * @return The resulting statement
 */
145Stmt simplexBuffers(const Stmt &op, const ID &defId = ID());
146
// Invokes the `DEFINE_PASS_FOR_FUNC` macro from func.h on `simplexBuffers`.
// NOTE(review): presumably this generates a `Func`-level overload; confirm
147DEFINE_PASS_FOR_FUNC(simplexBuffers)
148
149} // namespace gpu
150
151} // namespace freetensor
152
153#endif // FT_WITH_CUDA
154
155#endif // FREE_TENSOR_GPU_SIMPLEX_BUFFERS_H
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
constexpr ParallelScope threadIdxY
Definition: parallel_scope.h:115
constexpr ParallelScope blockIdxZ
Definition: parallel_scope.h:119
constexpr ParallelScope threadIdxZ
Definition: parallel_scope.h:116
Ref< StmtNode > Stmt
Definition: ast.h:152
constexpr ParallelScope blockIdxX
Definition: parallel_scope.h:117
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Ref< ExprNode > Expr
Definition: ast.h:184
constexpr ParallelScope threadIdxX
Definition: parallel_scope.h:114
constexpr ParallelScope blockIdxY
Definition: parallel_scope.h:118