1#ifndef FREE_TENSOR_GPU_SIMPLEX_BUFFERS_H
2#define FREE_TENSOR_GPU_SIMPLEX_BUFFERS_H
6#include <unordered_set>
22 ASTHashSet<Expr> offset_;
25class FindSimplexOffset :
public SymbolTable<Visitor> {
26 typedef SymbolTable<Visitor> BaseClass;
29 std::unordered_map<ID, std::vector<Ref<SimplexOffset>>>
31 AnalyzeLinear analyzeLinear_;
34 FindSimplexOffset(
const ID &defId = ID()) : defId_(defId) {}
36 const std::unordered_map<ID, std::vector<Ref<SimplexOffset>>> &
43 getSimplexOffset(
const std::unordered_set<ParallelScope> &filter,
46 template <
class T>
void visitMemAcc(
const T &op) {
49 auto mtype = buffer(op->var_)->mtype();
50 if (mtype != MemType::GPULocal && mtype != MemType::GPUShared &&
51 mtype != MemType::GPUWarp) {
55 auto &&defId = def(op->var_)->id();
56 if (defId_.isValid() && defId_ != defId) {
60 std::vector<Ref<SimplexOffset>> thisOffsets;
61 for (
auto &&idx : op->indices_) {
62 Ref<SimplexOffset> offset;
63 if (mtype == MemType::GPUShared || mtype == MemType::GPUWarp) {
71 thisOffsets.emplace_back(offset);
74 if (!offsets_.count(defId)) {
75 offsets_[defId] = thisOffsets;
77 ASSERT(offsets_.at(defId).size() == thisOffsets.size());
78 for (
auto &&[old, cur] :
79 views::zip(offsets_.at(defId), thisOffsets)) {
81 (!cur.isValid() || old->offset_ != cur->offset_)) {
89 using BaseClass::visit;
90 void visit(
const Load &op)
override { visitMemAcc(op); }
91 void visit(
const Store &op)
override { visitMemAcc(op); }
92 void visit(
const ReduceTo &op)
override { visitMemAcc(op); }
95class ApplySimplexOffset :
public SymbolTable<Mutator> {
96 typedef SymbolTable<Mutator> BaseClass;
98 const std::unordered_map<ID, std::vector<Ref<SimplexOffset>>>
100 std::unordered_map<std::string, Expr> para2var_;
104 const std::unordered_map<ID, std::vector<Ref<SimplexOffset>>> &offsets)
105 : offsets_(offsets) {}
108 template <
class T> T visitMemAcc(
const T &_op) {
109 auto __op = BaseClass::visit(_op);
110 ASSERT(__op->nodeType() == _op->nodeType());
111 auto op = __op.template as<typename T::Object>();
113 auto &&defId = def(op->var_)->id();
114 if (offsets_.count(defId)) {
115 auto &&offset = offsets_.at(defId);
116 ASSERT(offset.size() == op->indices_.size());
117 for (
auto &&[off, idx] : views::zip(offset, op->indices_)) {
119 for (
auto &&expr : off->offset_) {
120 idx =
makeSub(idx, ReplaceIter{para2var_}(expr));
129 using BaseClass::visit;
130 Stmt visit(
const For &op)
override;
131 Expr visit(
const Load &op)
override {
return visitMemAcc(op); }
132 Stmt visit(
const Store &op)
override {
return visitMemAcc(op); }
133 Stmt visit(
const ReduceTo &op)
override {
return visitMemAcc(op); }
145Stmt simplexBuffers(
const Stmt &op,
const ID &defId = ID());
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
constexpr ParallelScope threadIdxY
Definition: parallel_scope.h:115
constexpr ParallelScope blockIdxZ
Definition: parallel_scope.h:119
constexpr ParallelScope threadIdxZ
Definition: parallel_scope.h:116
Ref< StmtNode > Stmt
Definition: ast.h:152
constexpr ParallelScope blockIdxX
Definition: parallel_scope.h:117
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
Ref< ExprNode > Expr
Definition: ast.h:184
constexpr ParallelScope threadIdxX
Definition: parallel_scope.h:114
constexpr ParallelScope blockIdxY
Definition: parallel_scope.h:118