1#ifndef FREE_TENSOR_MAKE_MATMUL_H
2#define FREE_TENSOR_MAKE_MATMUL_H
5#include <unordered_map>
6#include <unordered_set>
25 return os <<
"keep_mem_layout";
27 return os <<
"try_var_reorder";
29 return os <<
"try_transpose";
40 const std::string &msg)
45 static std::string genErrMsg(
const ID &vardef,
46 const std::vector<int> &order,
47 const std::string &msg) {
48 std::ostringstream os;
49 os << msg <<
". Consider retrying after `var_reorder`ing " << vardef
50 <<
" to order [" << order
51 <<
"], or retrying with a different `mode` of `as_matmul`";
63 std::unordered_map<std::string, int> iterMap_;
64 std::unordered_set<std::string> outerDefs_;
65 std::vector<VarDef> innerDefs_;
66 std::vector<int> orderInit_;
68 bool foundInit_ =
false, foundLeaf_ =
false, inside_ =
false;
69 Expr a_, b_, c_, initC_, m_, k_, n_, lda_, stridea_, ldb_, strideb_, ldc_,
71 bool aIsRowMajor_, bIsRowMajor_, cIsRowMajor_;
78 std::vector<bool> dimsABatch_, dimsBBatch_, dimsCBatch_, dimsAM_, dimsAK_,
79 dimsBK_, dimsBN_, dimsCM_, dimsCN_;
80 ID defIdA_, defIdB_, defIdC_;
84 : loop_(
loop), backend_(backend) {}
92 const auto &
dimsAM()
const {
return dimsAM_; }
93 const auto &
dimsAK()
const {
return dimsAK_; }
94 const auto &
dimsBK()
const {
return dimsBK_; }
95 const auto &
dimsBN()
const {
return dimsBN_; }
96 const auto &
dimsCM()
const {
return dimsCM_; }
97 const auto &
dimsCN()
const {
return dimsCN_; }
98 const auto &
defIdA()
const {
return defIdA_; }
99 const auto &
defIdB()
const {
return defIdB_; }
100 const auto &
defIdC()
const {
return defIdC_; }
106 std::tuple<std::vector<bool>, std::vector<int>,
Expr>
107 findIterUsedAndBaseAddr(
const T &acc) {
108 std::vector<bool> usedBy(nestCnt_,
false);
109 std::vector<int> order;
111 buffer(acc->var_)->tensor()->dtype());
112 for (
auto &&[idx, dimLen, baseIdx] :
113 views::zip(acc->indices_,
buffer(acc->var_)->tensor()->shape(),
115 auto &&lin = analyzeLinear(idx);
116 if (lin.coeff_.size() != 1 ||
117 std::abs(lin.coeff_.front().k_) != 1 ||
121 " should be plain loop iterators");
125 Var var = lin.coeff_.front().a_.template as<VarNode>();
126 if (!iterMap_.count(var->name_)) {
131 int loopLevel = iterMap_.at(var->name_);
132 if (!HashComparator()(
loop(var->name_)->
len_, dimLen)) {
133 throw InvalidSchedule(
134 FT_MSG <<
"Iterator " << var->name_ <<
" of " << acc->var_
135 <<
" should loop over the entire range (" << dimLen
136 <<
"), instead of " <<
loop(var->name_)->
len_);
138 usedBy[loopLevel] =
true;
139 order.emplace_back(loopLevel);
141 return std::make_tuple(usedBy, order, baseAddr);
145 std::vector<bool> findDimsUsed(
const T &acc,
146 const std::vector<bool> &loopsUsed) {
147 std::vector<bool> dimsUsed(acc->indices_.size(),
false);
148 for (
auto &&[dimUsed, idx, dimLen] :
149 views::zip(dimsUsed, acc->indices_,
150 buffer(acc->var_)->tensor()->shape())) {
151 auto &&lin = analyzeLinear(idx);
153 if (lin.coeff_.size() != 1 ||
154 std::abs(lin.coeff_.front().k_) != 1 ||
158 Var var = lin.coeff_.front().a_.template as<VarNode>();
159 if (!iterMap_.count(var->name_) ||
160 !loopsUsed[iterMap_.at(var->name_)]) {
169 std::pair<Expr, Expr> findLenAndStride(
const T &acc,
170 const std::vector<bool> &dimsIn) {
172 bool lastDimIn =
false;
174 for (
auto &&[thisDimIn, idx, dimLen] : views::zip(
175 dimsIn, acc->indices_,
buffer(acc->var_)->tensor()->shape())) {
177 if (lastInDim.isValid()) {
179 throw InvalidSchedule(
180 FT_MSG <<
"Dimensions " << lastInDim <<
" and "
181 << idx <<
" should be contiguous");
184 len = len.isValid() ?
makeMul(len, dimLen) : (
Expr)dimLen;
188 stride = stride.isValid() ?
makeMul(stride, dimLen)
192 lastDimIn = thisDimIn;
196 return std::make_pair(len, stride);
199 void checkSameOrderOrRetry(
const ID &idA,
const std::vector<int> &orderA,
200 const std::vector<bool> &filterA,
const ID &idB,
201 const std::vector<int> &orderB,
202 const std::vector<bool> &filterB,
203 const std::string &message);
204 void checkSameOrderNoRetry(
const ID &idA,
const std::vector<int> &orderA,
205 const std::vector<bool> &filterA,
const ID &idB,
206 const std::vector<int> &orderB,
207 const std::vector<bool> &filterB,
208 const std::string &message);
210 void retryReorderingBack(
const ID &
id,
const std::vector<bool> &
filter,
211 const std::string &message);
212 void retryReorderingFront(
const ID &
id,
const std::vector<bool> &
filter,
213 const std::string &message);
Definition: analyze_linear.h:14
Definition: as_matmul.h:56
Stmt visit(const For &op) override
Definition: as_matmul.cc:177
const auto & dimsBBatch() const
Definition: as_matmul.h:90
const auto & defIdC() const
Definition: as_matmul.h:100
const auto & dimsCM() const
Definition: as_matmul.h:96
const auto & defIdB() const
Definition: as_matmul.h:99
const auto & dimsBN() const
Definition: as_matmul.h:95
const auto & dimsABatch() const
Definition: as_matmul.h:89
Stmt visitStmt(const Stmt &op) override
Definition: as_matmul.cc:165
AsMatMul(const ID &loop, MatMulBackend backend)
Definition: as_matmul.h:83
const auto & dimsCN() const
Definition: as_matmul.h:97
const auto & dimsBK() const
Definition: as_matmul.h:94
const auto & dimsAK() const
Definition: as_matmul.h:93
const auto & dimsCBatch() const
Definition: as_matmul.h:91
const auto & dimsAM() const
Definition: as_matmul.h:92
const ID & resultId() const
Definition: as_matmul.h:87
const auto & defIdA() const
Definition: as_matmul.h:98
bool done() const
Definition: as_matmul.h:86
SubTree< ExprNode > len_
Definition: stmt.h:297
bool isValid() const
Definition: id.h:33
SubTreeList< ExprNode > indices_
Definition: expr.h:54
Ref< U > as() const
Definition: ref.h:83
Definition: symbol_table.h:122
const For & loop(const std::string &name) const override
Definition: symbol_table.h:152
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:145
#define ASSERT(expr)
Definition: except.h:152
#define FT_MSG
Definition: except.h:23
Definition: allocator.h:9
Expr makeLoad(const std::string &var, Tindices &&indices, DataType loadType, std::source_location loc=std::source_location::current())
Definition: expr.h:63
AsMatMulMode
Definition: as_matmul.h:17
Ref< VarNode > Var
Definition: expr.h:40
bool checkAllDefined(const std::unordered_set< std::string > &defs, const std::unordered_set< std::string > &namesInOp)
Definition: check_all_defined.h:11
Ref< VarDefNode > VarDef
Definition: stmt.h:107
Expr makeMul(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:202
Ref< StoreNode > Store
Definition: stmt.h:140
Stmt asMatMul(const Stmt &ast, const ID &loop, MatMulBackend backend)
Definition: as_matmul.cc:442
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< ForNode > For
Definition: stmt.h:308
std::vector< T > filter(const std::vector< T > &vec, const U &callback)
Definition: container_utils.h:131
MatMulBackend
Definition: stmt.h:465
Ref< StmtNode > Stmt
Definition: ast.h:152
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Ref< ExprNode > Expr
Definition: ast.h:184
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
freetensor::Rational< T > abs(const freetensor::Rational< T > &x)
Definition: rational.h:85
Definition: as_matmul.h:35
ID vardef_
Definition: as_matmul.h:36
NeedVarReorder(const ID &vardef, const std::vector< int > &order, const std::string &msg)
Definition: as_matmul.h:39
std::vector< int > order_
Definition: as_matmul.h:37