1#ifndef FREE_TENSOR_DEPS_H
2#define FREE_TENSOR_DEPS_H
8#include <unordered_map>
9#include <unordered_set>
51 std::vector<std::pair<Expr, ID>>
59 void addCoord(
int defAxis,
auto &&iter,
auto &&access,
auto &&conds) {
61 iter_ = std::forward<decltype(iter)>(iter);
62 access_ = std::forward<decltype(access)>(access);
63 conds_ = std::forward<decltype(conds)>(conds);
68 std::unordered_map<std::string, std::vector<ID>>
77 const std::unordered_map<std::string, std::vector<ID>> &
results()
const {
109 std::vector<IterAxis> cur_;
110 std::vector<std::pair<Expr, ID>>
114 std::vector<Ref<AccessPoint>> reads_, writes_;
117 std::unordered_map<ID, std::vector<IterAxis>> scope2coord_;
119 std::vector<ID> allScopes_;
126 using BaseClass::operator();
128 void pushCond(
const Expr &cond,
const ID &baseStmtId) {
129 conds_.emplace_back(cond, baseStmtId);
132 void popCond() { conds_.pop_back(); }
141 void removeTrivialScopeFromAccesses(
144 void removeTrivialScopeFromScopes(std::vector<ID>::iterator begin,
145 std::vector<ID>::iterator end);
153 const auto &
reads()
const {
return reads_; }
154 const auto &
writes()
const {
return writes_; }
158 template <
class T>
void visitStoreLike(
const T &op) {
161 bool isThisVarDef =
false;
163 if (
def(op->var_)->
id() == vardef_) {
166 for (
auto source =
def(op->var_); source->
viewOf_.has_value();) {
167 source =
def(*source->viewOf_);
168 if (source->id() == vardef_) {
179 std::vector<Expr> exprs;
188 exprs = std::vector<Expr>(
189 viewOf->
buffer_->tensor()->shape().size(),
193 exprs = op->indices_;
198 if (accFilter_ ==
nullptr || accFilter_(*ap)) {
203 cur_.back().iter_.as<IntConstNode>()->val_ + 1);
207 ap->addCoord(defAxis_, cur_, std::move(exprs), conds_);
208 writes_.emplace_back(ap);
216 void visit(
const If &op)
override;
221 void visit(
const MatMul &op)
override { (*this)(op->equivalent_); }
241typedef std::vector<std::pair<NodeIDOrParallelScope, DepDirection>>
FindDepsDir;
290 std::vector<Ref<AccessPoint>> readsAsEarlier_, readsAsLater_,
291 writesAsEarlier_, writesAsLater_;
292 const std::unordered_map<ID, std::vector<IterAxis>> &scope2coord_;
293 const std::unordered_map<std::string, std::vector<ID>>
297 const std::vector<FindDepsDir> &direction_;
304 const bool ignoreReductionWAW_;
305 const bool eraseOutsideVarDef_;
306 const bool noProjectOutPrivateAxis_;
308 std::vector<std::function<void()>> tasks_;
314 const std::unordered_map<
ID, std::vector<IterAxis>> &scope2coord,
315 const std::unordered_map<std::string, std::vector<ID>> &noDepsLists,
317 const std::vector<FindDepsDir> &direction,
321 bool ignoreReductionWAW,
bool eraseOutsideVarDef,
322 bool noProjectOutPrivateAxis)
323 : scope2coord_(scope2coord), noDepsLists_(noDepsLists),
324 variantExpr_(variantExpr), direction_(direction), found_(found),
325 earlierFilter_(earlierFilter), laterFilter_(laterFilter),
326 filter_(
filter), mode_(mode), depType_(depType),
327 ignoreReductionWAW_(ignoreReductionWAW),
328 eraseOutsideVarDef_(eraseOutsideVarDef),
329 noProjectOutPrivateAxis_(noProjectOutPrivateAxis) {
333 (earlierFilter_ ==
nullptr || earlierFilter_(*acc));
338 (laterFilter_ ==
nullptr || laterFilter_(*acc));
342 return earlierFilter_ ==
nullptr || earlierFilter_(*acc);
346 return laterFilter_ ==
nullptr || laterFilter_(*acc);
352 const std::vector<std::function<void()>> &
tasks()
const {
return tasks_; }
355 static std::string
makeIterList(
const std::vector<IterAxis> &list,
int n);
356 static std::string
makeNegIterMap(
const std::vector<IterAxis> &list,
int n);
357 static std::string
makeNdList(
const std::string &name,
int n);
359 std::pair<std::string , std::string >>
367 const std::string &extSuffix,
370 bool eraseOutsideVarDef);
374 int iterDim,
int accDim,
const std::string &extSuffix,
378 externals, noNeedToBeVars, eraseOutsideVarDef_);
381 PBMap makeEqForBothOps(
const Ref<PBCtx> &presburger,
382 const std::vector<std::pair<int, int>> &coord,
384 PBMap makeIneqBetweenOps(
const Ref<PBCtx> &presburger,
DepDirection mode,
385 int iterId,
int iterDim)
const;
387 PBMap makeSerialToAll(
const Ref<PBCtx> &presburger,
int iterDim,
388 const std::vector<IterAxis> &point)
const;
389 static int countSerial(
const std::vector<IterAxis> &point);
391 PBMap makeExternalEq(
const Ref<PBCtx> &presburger,
int iterDim,
392 const std::string &ext1,
const std::string &ext2);
394 PBMap makeConstraintOfSingleLoop(
const Ref<PBCtx> &presburger,
398 PBMap makeConstraintOfParallelScope(
const Ref<PBCtx> &presburger,
401 const AccessPoint &later,
402 const AccessPoint &earlier);
418 PBMap makeEraseVarDefConstraint(
const Ref<PBCtx> &presburger,
419 const Ref<AccessPoint> &point,
int iterDim);
424 PBMap makeNoDepsConstraint(
const Ref<PBCtx> &presburger,
425 const std::string &var,
int iterDim);
441 PBMap makeExternalVarConstraint(
const Ref<PBCtx> &presburger,
442 const Ref<AccessPoint> &later,
443 const Ref<AccessPoint> &earlier,
463 PBMap projectOutPrivateAxis(
const Ref<PBCtx> &presburger,
int iterDim,
465 void projectOutPrivateAxis(
const Ref<PBCtx> &presburger,
466 const Ref<AccessPoint> &point,
467 const std::vector<Ref<AccessPoint>> &otherList,
468 std::vector<PBMap> &otherMapList,
int iterDim);
469 int numCommonDims(
const Ref<AccessPoint> &p1,
const Ref<AccessPoint> &p2);
471 void checkAgainstCond(
const Ref<AccessPoint> &later,
472 const Ref<AccessPoint> &earlier,
const PBMap &depAll,
473 const PBMap &nearest,
const PBMap &laterMap,
474 const PBMap &earlierMap,
const PBMap &extConstraint,
477 static const std::string &getVar(
const AST &op);
487 checkDepLatestEarlier(
const Ref<AccessPoint> &later,
488 const std::vector<Ref<AccessPoint>> &earlierList);
490 checkDepLatestEarlierImpl(
const Ref<PBCtx> &presburger,
491 const Ref<AccessPoint> &later,
492 const std::vector<Ref<AccessPoint>> &earlierList);
501 void checkDepEarliestLater(
const std::vector<Ref<AccessPoint>> &laterList,
502 const Ref<AccessPoint> &earlier);
504 checkDepEarliestLaterImpl(
const Ref<PBCtx> &presburger,
505 const std::vector<Ref<AccessPoint>> &laterList,
506 const Ref<AccessPoint> &earlier);
518 std::vector<FindDepsDir> direction_ = {{}};
523 std::function<void(
const ID &,
524 const std::unordered_map<
ID, std::vector<IterAxis>> &)>
525 scope2CoordCallback_ =
nullptr;
526 bool ignoreReductionWAW_ =
true;
527 bool eraseOutsideVarDef_ =
true;
528 bool noProjectOutPrivateAxis_ =
false;
605 ret.accFilter_ = ret.accFilter_ ==
nullptr
609 return f0(acc) && f1(acc);
626 ret.earlierFilter_ ==
nullptr
628 : [f0 = ret.earlierFilter_, f1 = f](
const AccessPoint &acc) {
629 return f0(acc) && f1(acc);
645 ret.laterFilter_ ==
nullptr
647 : [f0 = ret.laterFilter_, f1 = f](
const AccessPoint &acc) {
648 return f0(acc) && f1(acc);
665 ret.filter_ ==
nullptr
667 : [f0 = ret.filter_, f1 = f](
const AccessPoint &later,
669 return f0(later, earlier) && f1(later, earlier);
691 ret.ignoreReductionWAW_ = flag;
702 ret.eraseOutsideVarDef_ = flag;
715 ret.noProjectOutPrivateAxis_ = flag;
721 const ID &,
const std::unordered_map<
ID, std::vector<IterAxis>> &)>
724 ret.scope2CoordCallback_ = callback;
735 const std::function<
void(
const Dependence &)> &found) {
758std::ostream &
operator<<(std::ostream &os,
const Dependence &dep);
void genTasks()
Definition: deps.cc:1269
AnalyzeDeps(const std::vector< Ref< AccessPoint > > &reads, const std::vector< Ref< AccessPoint > > &writes, const std::unordered_map< ID, std::vector< IterAxis > > &scope2coord, const std::unordered_map< std::string, std::vector< ID > > &noDepsLists, Lazy< LoopVariExprMap > &variantExpr, const std::vector< FindDepsDir > &direction, const FindDepsCallback &found, FindDepsMode mode, DepType depType, const FindDepsAccPtFilter &earlierFilter, const FindDepsAccPtFilter &laterFilter, const FindDepsFilter &filter, bool ignoreReductionWAW, bool eraseOutsideVarDef, bool noProjectOutPrivateAxis)
Definition: deps.h:311
static std::vector< std::pair< std::string, std::string > > makeAccList(GenPBExpr &genPBExpr, const std::vector< Expr > &list, GenPBExpr::VarMap &externals)
Definition: deps.cc:315
static std::string makeCond(GenPBExpr &genPBExpr, GenPBExpr::VarMap &externals, bool eraseOutsideVarDef, const AccessPoint &ap)
Definition: deps.cc:346
static std::string makeNegIterMap(const std::vector< IterAxis > &list, int n)
Definition: deps.cc:296
const std::vector< std::function< void()> > & tasks() const
Definition: deps.h:352
static std::string makeNdList(const std::string &name, int n)
Definition: deps.cc:475
static std::string makeIterList(const std::vector< IterAxis > &list, int n)
Definition: deps.cc:274
static PBMap makeAccMapStatic(const Ref< PBCtx > &presburger, const AccessPoint &p, int iterDim, int accDim, const std::string &extSuffix, GenPBExpr::VarMap &externals, const ASTHashSet< Expr > &noNeedToBeVars, bool eraseOutsideVarDef)
Definition: deps.cc:437
static constexpr auto Int32
Definition: data_type.h:128
void visit(const VarDef &op) override
Definition: deps.cc:90
void visit(const MatMul &op) override
Definition: deps.h:221
void visit(const Store &op) override
Definition: deps.h:218
void doFind(const Stmt &root)
Definition: deps.cc:30
void visit(const ReduceTo &op) override
Definition: deps.h:219
const auto & writes() const
Definition: deps.h:154
const auto & reads() const
Definition: deps.h:153
const auto & scope2coord() const
Definition: deps.h:155
const std::unordered_map< std::string, std::vector< ID > > & results() const
Definition: deps.h:77
void visit(const For &op) override
Definition: deps.cc:19
FindDeps noProjectOutPrivateAxis(bool flag)
Definition: deps.h:713
FindDeps eraseOutsideVarDef(bool flag)
Definition: deps.h:700
bool exists(const Stmt &op)
Definition: deps.cc:1392
FindDeps filterEarlier(const FindDepsAccPtFilter &f)
Definition: deps.h:623
FindDeps scope2CoordCallback(std::function< void(const ID &, const std::unordered_map< ID, std::vector< IterAxis > > &)> callback)
Definition: deps.h:719
FindDeps type(DepType t)
Definition: deps.h:560
FindDeps mode(FindDepsMode m)
Definition: deps.h:549
FindDeps filterSubAST(const ID &subAST)
Definition: deps.h:677
FindDeps direction(const std::vector< FindDepsDir > &d)
Definition: deps.h:580
FindDeps filterAccess(const FindDepsAccFilter &f)
Definition: deps.h:603
FindDeps filterLater(const FindDepsAccPtFilter &f)
Definition: deps.h:642
FindDeps filterAccess(const std::function< bool(const AccessPointBase &)> &f)
Definition: deps.h:600
FindDeps ignoreReductionWAW(bool flag)
Definition: deps.h:689
FindDeps filter(const FindDepsFilter &f)
Definition: deps.h:661
void operator()(const Stmt &op, const std::function< void(const Dependence &)> &found)
Definition: deps.h:734
Definition: gen_pb_expr.h:34
ASTHashMap< Expr, std::string > VarMap
Definition: gen_pb_expr.h:37
Definition: presburger.h:75
static Ref make()
Definition: ref.h:105
bool isValid() const
Definition: ref.h:89
Ref< StmtNode > ancestorById(const ID &lookup) const
Definition: ast.cc:279
ID id() const
Definition: ast.cc:362
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
Definition: sync_func.h:16
const Stmt & curStmt() const
Definition: track_stmt.h:31
SubTree< Buffer > buffer_
Definition: stmt.h:86
std::optional< std::string > viewOf_
Definition: stmt.h:99
Definition: allocator.h:9
detail::TaggedSyncFunc< std::remove_reference_t< T > > syncFunc(T &&f)
Definition: sync_func.h:76
std::function< bool(const AccessPoint &later, const AccessPoint &earlier)> FindDepsFilter
Definition: deps.h:89
int DepType
Definition: deps.h:91
Ref< VarDefNode > VarDef
Definition: stmt.h:107
Expr makeIntrinsic(const std::string &format, T &¶ms, DataType retType, bool hasSideEffect, std::source_location loc=std::source_location::current())
Definition: expr.h:756
std::function< bool(const AccessPoint &)> FindDepsAccPtFilter
Definition: deps.h:86
SyncFunc< void(const Dependence &)> FindDepsCallback
Definition: deps.h:273
Ref< IfNode > If
Definition: stmt.h:352
Ref< ForNode > For
Definition: stmt.h:308
std::vector< T > filter(const std::vector< T > &vec, const U &callback)
Definition: container_utils.h:131
std::vector< std::pair< NodeIDOrParallelScope, DepDirection > > FindDepsDir
Definition: deps.h:241
Ref< ASTNode > AST
Definition: ast.h:149
std::variant< SerialScope, OpenMPScope, CUDAStreamScope, CUDAScope > ParallelScope
Definition: parallel_scope.h:73
const DepType DEP_RAW
Definition: deps.h:94
DepDirection
Definition: deps.h:224
const DepType DEP_WAR
Definition: deps.h:93
FindDepsMode
Definition: deps.h:275
const DepType DEP_ALL
Definition: deps.h:95
constexpr ParallelScope serialScope
Definition: parallel_scope.h:112
SyncFunc< bool(const AccessPointBase &)> FindDepsAccFilter
Definition: deps.h:85
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
Ref< AssertNode > Assert
Definition: stmt.h:392
std::unordered_set< K, Hasher, HashComparator > ASTHashSet
Definition: hash.h:117
detail::TaggedUnsyncFunc< std::remove_reference_t< T > > unsyncFunc(T &&f)
Definition: sync_func.h:84
Ref< StmtSeqNode > StmtSeq
Definition: stmt.h:49
const DepType DEP_WAW
Definition: deps.h:92
Ref< Buffer > buffer_
Definition: deps.h:40
AccessPointBase(const AST &op, const Stmt &stmt, const VarDef &def, const Ref< Buffer > &buffer)
Definition: deps.h:42
VarDef def_
Definition: deps.h:39
Stmt stmt_
Definition: deps.h:38
AST op_
Definition: deps.h:37
AccessPoint(const AST &op, const Stmt &stmt, const VarDef &def, const Ref< Buffer > &buffer)
Definition: deps.h:55
std::vector< Expr > access_
The temporal location of the access.
Definition: deps.h:50
void addCoord(int defAxis, auto &&iter, auto &&access, auto &&conds)
Definition: deps.h:59
std::vector< IterAxis > iter_
The position of the VarDef.
Definition: deps.h:49
int defAxis_
Definition: deps.h:48
std::vector< std::pair< Expr, ID > > conds_
The spatial location of the access.
Definition: deps.h:52
const AST & later() const
Definition: deps.h:262
const FindDepsDir & dir_
Definition: deps.h:246
const std::string & var_
Direction vector filtering out this dependence.
Definition: deps.h:247
PBMap earlierIter2Idx_
Definition: deps.h:255
const VarDef & def() const
Definition: deps.h:264
PBMap laterIter2Idx_
Definition: deps.h:255
ID defId() const
Definition: deps.h:265
PBMap later2EarlierIter_
Definition: deps.h:254
const AccessPoint & later_
Definition: deps.h:248
const AST & earlier() const
Definition: deps.h:263
int iterDim_
Definition: deps.h:249
PBMap extConstraint_
Definition: deps.h:258
PBMap later2EarlierIterAllPossible_
Definition: deps.h:257
PBMap extraCheck(PBMap dep, const NodeIDOrParallelScope &nodeOrParallel, const DepDirection &dir) const
Definition: deps.cc:1312
AnalyzeDeps & self_
Definition: deps.h:259
const AccessPoint & earlier_
Definition: deps.h:248
bool negStep_
Definition: deps.h:29
IterAxis(Expr iter, const ParallelScope ¶llel=serialScope, bool negStep=false)
Definition: deps.h:31
ParallelScope parallel_
Definition: deps.h:28
Expr iter_
Definition: deps.h:27
ID id_
Definition: deps.h:232
ParallelScope parallel_
Definition: deps.h:233
NodeIDOrParallelScope(const ParallelScope ¶llel)
Definition: deps.h:237
bool isNode_
Definition: deps.h:234
NodeIDOrParallelScope(const ID &id)
Definition: deps.h:236