1#ifndef FREE_TENSOR_GPU_MULTIPLEX_BUFFERS_H
2#define FREE_TENSOR_GPU_MULTIPLEX_BUFFERS_H
6#include <unordered_map>
7#include <unordered_set>
20class FindParallelLoops :
public Visitor {
21 Ref<GPUTarget> target_;
23 std::vector<For> loops_, stack_;
24 std::unordered_map<ID, std::unordered_map<ID, bool>>
28 FindParallelLoops(
const Ref<GPUTarget> &target,
const ID &defId = ID())
29 : target_(target), defId_(defId) {}
31 const std::vector<For> &loops()
const {
return loops_; }
32 const auto &affecting()
const {
return affecting_; }
35 void visit(
const For &op)
override;
36 void visit(
const VarDef &op)
override;
39class MultiplexMutator :
public SymbolTable<Mutator> {
40 typedef SymbolTable<Mutator> BaseClass;
42 std::vector<For> stack_;
43 std::unordered_map<std::string, int> defPos_;
44 const std::unordered_map<ID, std::unordered_set<ID>>
49 const std::unordered_map<ID, std::unordered_set<ID>> &affecting)
50 : affecting_(affecting) {}
53 template <
class T> T alterAccess(
const T &op) {
54 if (!defPos_.count(op->var_)) {
57 if (affecting_.count(def(op->var_)->id())) {
58 auto &&aff = affecting_.at(def(op->var_)->id());
59 int pos = defPos_.at(op->var_);
60 for (
int i = pos - 1; i >= 0; i--) {
61 if (aff.count(stack_[i]->id())) {
62 auto &indices = op->indices_;
63 indices.insert(indices.begin(),
makeVar(stack_[i]->
iter_));
71 Stmt visit(
const For &op)
override;
72 Stmt visit(
const VarDef &op)
override;
73 Expr visit(
const Load &op)
override;
74 Stmt visit(
const Store &op)
override;
75 Stmt visit(
const ReduceTo &op)
override;
/// Entry point of the buffer-multiplexing pass (declared here; the
/// implementation is not in this header).
///
/// @param op     Statement tree the pass is applied to.
/// @param target GPU target description. This is the same `Ref<GPUTarget>`
///               type that `FindParallelLoops` takes in its constructor.
/// @param defId  Optional VarDef ID; defaults to a default-constructed `ID()`.
///               NOTE(review): presumably an empty ID means "process every
///               VarDef" and a non-empty one restricts the pass to that single
///               definition. Confirm against the implementation.
/// @return A `Stmt`, presumably the rewritten tree. TODO confirm.
///
/// NOTE(review): inferred from the helper classes in this header, not verified.
/// The pass presumably first collects, per VarDef, the parallel loops that
/// affect it (`FindParallelLoops::affecting()`). It then has `MultiplexMutator`
/// prepend each affecting enclosing loop's iterator as an extra leading index
/// on every Load/Store/ReduceTo of that buffer (`alterAccess`). Confirm in the
/// corresponding .cc file.
86Stmt multiplexBuffers(
const Stmt &op,
const Ref<GPUTarget> &target,
87 const ID &defId = ID());
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
std::vector< IterAxis > iter_
Definition: invert_stmts.cc:56
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152
Expr makeVar(const std::string &name, std::source_location loc=std::source_location::current())
Definition: expr.h:42
Ref< ExprNode > Expr
Definition: ast.h:184