1#ifndef FREE_TENSOR_GPU_MAKE_SYNC_H
2#define FREE_TENSOR_GPU_MAKE_SYNC_H
7#include <unordered_map>
8#include <unordered_set>
// FindAllThreads: a Visitor subclass that holds a map from ID to ThreadInfo
// (results_) and offers a const accessor for a map of that type (results()).
// NOTE(review): the numeric prefixes are line numbers carried over from the
// original listing; the gaps between them mean some declarations used below
// (e.g. warpSize_) are not visible in this excerpt.
26class FindAllThreads :
public Visitor {
// Three optional ints, each initialised to 1.
// NOTE(review): presumably the thread extents along x / y / z -- confirm
// against the implementation file.
28 std::optional<int> thx_ = 1;
29 std::optional<int> thy_ = 1;
30 std::optional<int> thz_ = 1;
// ThreadInfo records keyed by ID.
31 std::unordered_map<ID, ThreadInfo> results_;
// Stores target->warpSize() in warpSize_. `target` is dereferenced
// unconditionally, so callers must pass a valid GPUTarget reference.
34 FindAllThreads(
const Ref<GPUTarget> &target)
35 : warpSize_(target->warpSize()) {}
// Const accessor returning a reference to an ID -> ThreadInfo map.
// NOTE(review): the body is not visible here; presumably returns results_.
37 const std::unordered_map<ID, ThreadInfo> &results()
const {
// Visitor hook overridden for For nodes; only declared here.
42 void visit(
const For &op)
override;
// CopyParts: a Mutator subclass constructed from a condition expression and a
// list of "splitter" statements.
// NOTE(review): numeric prefixes are line numbers from the original listing;
// the declaration of cond_ falls in a gap and is not visible in this excerpt.
45class CopyParts :
public Mutator {
// Held by reference, not copied: the vector passed to the constructor must
// outlive this object.
47 const std::vector<Stmt> &splitters_;
// Set of statements; how it is populated is not visible in this header.
48 std::unordered_set<Stmt> fullParts_;
// Stores `cond` in cond_ and binds splitters_ to the caller's vector.
51 CopyParts(
const Expr &cond,
const std::vector<Stmt> &splitters)
52 :
cond_(cond), splitters_(splitters) {}
// Mutator hooks overridden by this class (declarations only): the generic
// statement entry point plus For, If, Assert, VarDef and StmtSeq nodes.
55 Stmt visitStmt(
const Stmt &op)
override;
56 Stmt visit(
const For &op)
override;
57 Stmt visit(
const If &op)
override;
58 Stmt visit(
const Assert &op)
override;
59 Stmt visit(
const VarDef &op)
override;
60 Stmt visit(
const StmtSeq &op)
override;
// CrossThreadDep: plain record tying together four statements and a few
// boolean status flags.
// NOTE(review): numeric prefixes are line numbers from the original listing;
// the declaration of inWarp_ falls in a gap and is not visible in this excerpt.
63struct CrossThreadDep {
// The two statements of the dependence (later_, earlier_) and two further
// statements lcaStmt_ / lcaLoop_.
// NOTE(review): "lca" presumably means lowest common ancestor -- confirm.
64 Stmt later_, earlier_, lcaStmt_, lcaLoop_;
// Status flags; all start as false and are not touched by the constructor.
66 bool visiting_ =
false, synced_ =
false, syncedOnlyInBranch_ =
false;
// Copies the four statements and the inWarp flag into the matching members.
68 CrossThreadDep(
const Stmt &later,
const Stmt &earlier,
const Stmt &lcaStmt,
69 const Stmt &lcaLoop,
bool inWarp)
70 : later_(later), earlier_(earlier), lcaStmt_(
lcaStmt),
71 lcaLoop_(lcaLoop), inWarp_(inWarp) {}
// MakeSync: a Mutator subclass built from a root statement, a loop-ID ->
// ThreadInfo map, a list of CrossThreadDep records and a LoopVariExprMap.
// NOTE(review): numeric prefixes are line numbers from the original listing;
// large gaps (91 -> 122 -> 184) mean most of the class is not visible here,
// including the declarations of root_ and variantExprs_.
74class MakeSync :
public Mutator {
// Alias for the base class.
75 typedef Mutator BaseClass;
// Held by reference, not copied: the map passed to the constructor must
// outlive this object.
78 const std::unordered_map<ID, ThreadInfo> &loop2thread_;
// Dependence records, moved in by the constructor.
79 std::vector<CrossThreadDep> deps_;
// Three maps from ID to a (Stmt, bool) pair.
// NOTE(review): names suggest a sync statement to place before a For / If /
// library-call node; the meaning of the bool is not visible here -- confirm.
80 std::unordered_map<ID, std::pair<
Stmt,
bool >>
81 syncBeforeFor_, syncBeforeIf_, syncBeforeLib_;
// Map from ID to a list of statements; this declaration continues past the
// comma with further names that are not visible in this excerpt.
82 std::unordered_map<ID, std::vector<Stmt>> branchSplittersThen_,
// Constructor: `root` is taken by value and copied into root_, `loop2thread`
// is bound by reference, and `deps` / `variantExprs` are rvalue references
// whose contents are moved into the members (callers give up those objects).
87 MakeSync(
const Stmt root,
88 const std::unordered_map<ID, ThreadInfo> &loop2thread,
89 std::vector<CrossThreadDep> &&deps, LoopVariExprMap &&variantExprs)
90 : root_(root), loop2thread_(loop2thread), deps_(
std::move(deps)),
91 variantExprs_(
std::move(variantExprs)) {}
// Static factories returning a Stmt each; no instance state is involved.
// NOTE(review): presumably build block-wide and warp-wide barrier statements
// respectively -- confirm in the implementation file.
122 static Stmt makeSyncThreads();
123 static Stmt makeSyncWarp();
// Takes a statement in the tree and a sync statement; the remaining
// parameters of this declaration are not visible in this excerpt.
184 void markSyncForSplitting(
const Stmt &stmtInTree,
const Stmt &sync,
// Mutator hooks overridden by this class (declarations only): the generic
// statement entry point plus For, If and MatMul nodes.
188 Stmt visitStmt(
const Stmt &op)
override;
189 Stmt visit(
const For &op)
override;
190 Stmt visit(
const If &op)
override;
191 Stmt visit(
const MatMul &op)
override;
// Free-function entry point: takes a statement and a GPU target and returns a
// statement. Declaration only; the definition is not part of this header.
// NOTE(review): presumably returns `op` with synchronization statements
// inserted -- confirm in the implementation file.
194Stmt makeSync(
const Stmt &op,
const Ref<GPUTarget> &target);
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Expr cond_
Definition: invert_stmts.cc:58
Definition: allocator.h:9
std::unordered_map< StmtOrExprID, std::unordered_map< ID, LoopVariability > > LoopVariExprMap
Definition: find_loop_variance.h:26
Ref< ForNode > For
Definition: stmt.h:308
Stmt lcaStmt(const Stmt &lhs, const Stmt &rhs)
Definition: ast.cc:383
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< ExprNode > Expr
Definition: ast.h:184