FreeTensor
Loading...
Searching...
No Matches
make_sync.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_GPU_MAKE_SYNC_H
2#define FREE_TENSOR_GPU_MAKE_SYNC_H
3
4#ifdef FT_WITH_CUDA
5
6#include <optional>
7#include <unordered_map>
8#include <unordered_set>
9
11#include <driver/target.h>
12#include <func.h>
13#include <math/bounds.h>
14#include <mutator.h>
15#include <visitor.h>
16
17namespace freetensor {
18
19namespace gpu {
20
21struct ThreadInfo {
22 For loop_;
23 bool inWarp_;
24};
25
26class FindAllThreads : public Visitor {
27 int warpSize_;
28 std::optional<int> thx_ = 1;
29 std::optional<int> thy_ = 1;
30 std::optional<int> thz_ = 1;
31 std::unordered_map<ID, ThreadInfo> results_;
32
33 public:
34 FindAllThreads(const Ref<GPUTarget> &target)
35 : warpSize_(target->warpSize()) {}
36
37 const std::unordered_map<ID, ThreadInfo> &results() const {
38 return results_;
39 }
40
41 protected:
42 void visit(const For &op) override;
43};
44
45class CopyParts : public Mutator {
46 Expr cond_;
47 const std::vector<Stmt> &splitters_;
48 std::unordered_set<Stmt> fullParts_;
49
50 public:
51 CopyParts(const Expr &cond, const std::vector<Stmt> &splitters)
52 : cond_(cond), splitters_(splitters) {}
53
54 protected:
55 Stmt visitStmt(const Stmt &op) override;
56 Stmt visit(const For &op) override;
57 Stmt visit(const If &op) override;
58 Stmt visit(const Assert &op) override;
59 Stmt visit(const VarDef &op) override;
60 Stmt visit(const StmtSeq &op) override;
61};
62
63struct CrossThreadDep {
64 Stmt later_, earlier_, lcaStmt_, lcaLoop_;
65 bool inWarp_;
66 bool visiting_ = false, synced_ = false, syncedOnlyInBranch_ = false;
67
68 CrossThreadDep(const Stmt &later, const Stmt &earlier, const Stmt &lcaStmt,
69 const Stmt &lcaLoop, bool inWarp)
70 : later_(later), earlier_(earlier), lcaStmt_(lcaStmt),
71 lcaLoop_(lcaLoop), inWarp_(inWarp) {}
72};
73
74class MakeSync : public Mutator {
75 typedef Mutator BaseClass;
76
77 Stmt root_;
78 const std::unordered_map<ID, ThreadInfo> &loop2thread_;
79 std::vector<CrossThreadDep> deps_;
80 std::unordered_map<ID, std::pair<Stmt, bool /* isSyncWarp */>>
81 syncBeforeFor_, syncBeforeIf_, syncBeforeLib_;
82 std::unordered_map<ID, std::vector<Stmt>> branchSplittersThen_,
83 branchSplittersElse_;
84 LoopVariExprMap variantExprs_;
85
86 public:
87 MakeSync(const Stmt root,
88 const std::unordered_map<ID, ThreadInfo> &loop2thread,
89 std::vector<CrossThreadDep> &&deps, LoopVariExprMap &&variantExprs)
90 : root_(root), loop2thread_(loop2thread), deps_(std::move(deps)),
91 variantExprs_(std::move(variantExprs)) {}
92
93 private:
122 static Stmt makeSyncThreads();
123 static Stmt makeSyncWarp();
184 void markSyncForSplitting(const Stmt &stmtInTree, const Stmt &sync,
185 bool isSyncWarp);
186
187 protected:
188 Stmt visitStmt(const Stmt &op) override;
189 Stmt visit(const For &op) override;
190 Stmt visit(const If &op) override;
191 Stmt visit(const MatMul &op) override;
192};
193
/**
 * Declaration of the pass entry point; the definition is not in this header
 *
 * NOTE(review): presumably returns `op` with the required thread / warp
 * synchronization statements inserted for the given GPU target; confirm
 * against the definition
 *
 * @param op : Statement to process
 * @param target : GPU target
 * @return : A statement
 */
194Stmt makeSync(const Stmt &op, const Ref<GPUTarget> &target);
195
// NOTE(review): DEFINE_PASS_FOR_FUNC comes from func.h; presumably it generates
// a Func-level wrapper around makeSync. Its expansion is not visible here
196DEFINE_PASS_FOR_FUNC(makeSync)
197
198} // namespace gpu
199
200} // namespace freetensor
201
202#endif // FT_WITH_CUDA
203
204#endif // FREE_TENSOR_GPU_MAKE_SYNC_H
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Expr cond_
Definition: invert_stmts.cc:58
Definition: allocator.h:9
std::unordered_map< StmtOrExprID, std::unordered_map< ID, LoopVariability > > LoopVariExprMap
Definition: find_loop_variance.h:26
Ref< ForNode > For
Definition: stmt.h:308
Stmt lcaStmt(const Stmt &lhs, const Stmt &rhs)
Definition: ast.cc:383
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< ExprNode > Expr
Definition: ast.h:184
STL namespace.