FreeTensor
Loading...
Searching...
No Matches
as_matmul.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_MAKE_MATMUL_H
2#define FREE_TENSOR_MAKE_MATMUL_H
3
4#include <sstream>
5#include <unordered_map>
6#include <unordered_set>
7
11#include <container_utils.h>
12#include <hash.h>
13#include <mutator.h>
14
15namespace freetensor {
16
17enum class AsMatMulMode : int {
21};
22inline std::ostream &operator<<(std::ostream &os, AsMatMulMode mode) {
23 switch (mode) {
25 return os << "keep_mem_layout";
27 return os << "try_var_reorder";
29 return os << "try_transpose";
30 default:
31 ASSERT(false);
32 }
33}
34
37 std::vector<int> order_;
38
39 NeedVarReorder(const ID &vardef, const std::vector<int> &order,
40 const std::string &msg)
41 : Error(genErrMsg(vardef, order, msg)), vardef_(vardef), order_(order) {
42 }
43
44 private:
45 static std::string genErrMsg(const ID &vardef,
46 const std::vector<int> &order,
47 const std::string &msg) {
48 std::ostringstream os;
49 os << msg << ". Consider retrying after `var_reorder`ing " << vardef
50 << " to order [" << order
51 << "], or retrying with a different `mode` of `as_matmul`";
52 return os.str();
53 }
54};
55
56class AsMatMul : public SymbolTable<Mutator> {
58
59 ID loop_;
60 MatMulBackend backend_;
61
62 int nestCnt_ = 0;
63 std::unordered_map<std::string, int> iterMap_; // iter var -> nest cnt
64 std::unordered_set<std::string> outerDefs_;
65 std::vector<VarDef> innerDefs_;
66 std::vector<int> orderInit_;
67
68 bool foundInit_ = false, foundLeaf_ = false, inside_ = false;
69 Expr a_, b_, c_, initC_, m_, k_, n_, lda_, stridea_, ldb_, strideb_, ldc_,
70 stridec_, batchSize_;
71 bool aIsRowMajor_, bIsRowMajor_, cIsRowMajor_;
72
73 AnalyzeLinear analyzeLinear_;
74
75 ID resultId_;
76
77 // Public matching details
78 std::vector<bool> dimsABatch_, dimsBBatch_, dimsCBatch_, dimsAM_, dimsAK_,
79 dimsBK_, dimsBN_, dimsCM_, dimsCN_;
80 ID defIdA_, defIdB_, defIdC_;
81
82 public:
83 AsMatMul(const ID &loop, MatMulBackend backend)
84 : loop_(loop), backend_(backend) {}
85
86 bool done() const { return resultId_.isValid(); }
87 const ID &resultId() const { return resultId_; }
88
89 const auto &dimsABatch() const { return dimsABatch_; }
90 const auto &dimsBBatch() const { return dimsBBatch_; }
91 const auto &dimsCBatch() const { return dimsCBatch_; }
92 const auto &dimsAM() const { return dimsAM_; }
93 const auto &dimsAK() const { return dimsAK_; }
94 const auto &dimsBK() const { return dimsBK_; }
95 const auto &dimsBN() const { return dimsBN_; }
96 const auto &dimsCM() const { return dimsCM_; }
97 const auto &dimsCN() const { return dimsCN_; }
98 const auto &defIdA() const { return defIdA_; }
99 const auto &defIdB() const { return defIdB_; }
100 const auto &defIdC() const { return defIdC_; }
101
102 private:
103 const LinearExpr<int64_t> &analyzeLinear(const Expr &expr);
104
105 template <class T>
106 std::tuple<std::vector<bool>, std::vector<int>, Expr>
107 findIterUsedAndBaseAddr(const T &acc) {
108 std::vector<bool> usedBy(nestCnt_, false);
109 std::vector<int> order;
110 Expr baseAddr = makeLoad(acc->var_, acc->indices_,
111 buffer(acc->var_)->tensor()->dtype());
112 for (auto &&[idx, dimLen, baseIdx] :
113 views::zip(acc->indices_, buffer(acc->var_)->tensor()->shape(),
114 baseAddr.as<LoadNode>()->indices_)) {
115 auto &&lin = analyzeLinear(idx);
116 if (lin.coeff_.size() != 1 ||
117 std::abs(lin.coeff_.front().k_) != 1 ||
118 lin.coeff_.front().a_->nodeType() != ASTNodeType::Var) {
119 if (!checkAllDefined(outerDefs_, idx)) {
120 throw InvalidSchedule("Indices of " + acc->var_ +
121 " should be plain loop iterators");
122 }
123 continue; // not a dim in matmul
124 }
125 Var var = lin.coeff_.front().a_.template as<VarNode>();
126 if (!iterMap_.count(var->name_)) {
127 continue; // not a dim in matmul
128 } else {
129 baseIdx = makeIntConst(0);
130 }
131 int loopLevel = iterMap_.at(var->name_);
132 if (!HashComparator()(loop(var->name_)->len_, dimLen)) {
133 throw InvalidSchedule(
134 FT_MSG << "Iterator " << var->name_ << " of " << acc->var_
135 << " should loop over the entire range (" << dimLen
136 << "), instead of " << loop(var->name_)->len_);
137 }
138 usedBy[loopLevel] = true;
139 order.emplace_back(loopLevel);
140 }
141 return std::make_tuple(usedBy, order, baseAddr);
142 }
143
144 template <class T>
145 std::vector<bool> findDimsUsed(const T &acc,
146 const std::vector<bool> &loopsUsed) {
147 std::vector<bool> dimsUsed(acc->indices_.size(), false);
148 for (auto &&[dimUsed, idx, dimLen] :
149 views::zip(dimsUsed, acc->indices_,
150 buffer(acc->var_)->tensor()->shape())) {
151 auto &&lin = analyzeLinear(idx);
152 dimUsed = true;
153 if (lin.coeff_.size() != 1 ||
154 std::abs(lin.coeff_.front().k_) != 1 ||
155 lin.coeff_.front().a_->nodeType() != ASTNodeType::Var) {
156 dimUsed = false;
157 } else {
158 Var var = lin.coeff_.front().a_.template as<VarNode>();
159 if (!iterMap_.count(var->name_) ||
160 !loopsUsed[iterMap_.at(var->name_)]) {
161 dimUsed = false;
162 }
163 }
164 }
165 return dimsUsed;
166 }
167
168 template <class T>
169 std::pair<Expr, Expr> findLenAndStride(const T &acc,
170 const std::vector<bool> &dimsIn) {
171 Expr len, stride;
172 bool lastDimIn = false;
173 Expr lastInDim;
174 for (auto &&[thisDimIn, idx, dimLen] : views::zip(
175 dimsIn, acc->indices_, buffer(acc->var_)->tensor()->shape())) {
176 if (thisDimIn) {
177 if (lastInDim.isValid()) {
178 if (!lastDimIn) {
179 throw InvalidSchedule(
180 FT_MSG << "Dimensions " << lastInDim << " and "
181 << idx << " should be contiguous");
182 }
183 }
184 len = len.isValid() ? makeMul(len, dimLen) : (Expr)dimLen;
185 lastInDim = idx;
186 } else {
187 if (len.isValid()) {
188 stride = stride.isValid() ? makeMul(stride, dimLen)
189 : (Expr)dimLen;
190 }
191 }
192 lastDimIn = thisDimIn;
193 }
194 len = len.isValid() ? len : makeIntConst(1);
195 stride = stride.isValid() ? stride : makeIntConst(1);
196 return std::make_pair(len, stride);
197 }
198
199 void checkSameOrderOrRetry(const ID &idA, const std::vector<int> &orderA,
200 const std::vector<bool> &filterA, const ID &idB,
201 const std::vector<int> &orderB,
202 const std::vector<bool> &filterB,
203 const std::string &message);
204 void checkSameOrderNoRetry(const ID &idA, const std::vector<int> &orderA,
205 const std::vector<bool> &filterA, const ID &idB,
206 const std::vector<int> &orderB,
207 const std::vector<bool> &filterB,
208 const std::string &message);
209
210 void retryReorderingBack(const ID &id, const std::vector<bool> &filter,
211 const std::string &message);
212 void retryReorderingFront(const ID &id, const std::vector<bool> &filter,
213 const std::string &message);
214
215 protected:
216 Stmt visitStmt(const Stmt &op) override;
217 Stmt visit(const For &op) override;
218 Stmt visit(const ReduceTo &op) override;
219 Stmt visit(const Store &op) override;
220 Stmt visit(const VarDef &op) override;
221};
222
223Stmt asMatMul(const Stmt &ast, const ID &loop, MatMulBackend backend);
224
225} // namespace freetensor
226
227#endif // FREE_TENSOR_MAKE_MATMUL_H
Definition: analyze_linear.h:14
Definition: as_matmul.h:56
Stmt visit(const For &op) override
Definition: as_matmul.cc:177
const auto & dimsBBatch() const
Definition: as_matmul.h:90
const auto & defIdC() const
Definition: as_matmul.h:100
const auto & dimsCM() const
Definition: as_matmul.h:96
const auto & defIdB() const
Definition: as_matmul.h:99
const auto & dimsBN() const
Definition: as_matmul.h:95
const auto & dimsABatch() const
Definition: as_matmul.h:89
Stmt visitStmt(const Stmt &op) override
Definition: as_matmul.cc:165
AsMatMul(const ID &loop, MatMulBackend backend)
Definition: as_matmul.h:83
const auto & dimsCN() const
Definition: as_matmul.h:97
const auto & dimsBK() const
Definition: as_matmul.h:94
const auto & dimsAK() const
Definition: as_matmul.h:93
const auto & dimsCBatch() const
Definition: as_matmul.h:91
const auto & dimsAM() const
Definition: as_matmul.h:92
const ID & resultId() const
Definition: as_matmul.h:87
const auto & defIdA() const
Definition: as_matmul.h:98
bool done() const
Definition: as_matmul.h:86
Definition: except.h:25
SubTree< ExprNode > len_
Definition: stmt.h:297
Definition: id.h:18
bool isValid() const
Definition: id.h:33
Definition: except.h:40
Definition: expr.h:51
SubTreeList< ExprNode > indices_
Definition: expr.h:54
Ref< U > as() const
Definition: ref.h:83
Definition: symbol_table.h:122
const For & loop(const std::string &name) const override
Definition: symbol_table.h:152
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:145
#define ASSERT(expr)
Definition: except.h:152
#define FT_MSG
Definition: except.h:23
Definition: allocator.h:9
Expr makeLoad(const std::string &var, Tindices &&indices, DataType loadType, std::source_location loc=std::source_location::current())
Definition: expr.h:63
AsMatMulMode
Definition: as_matmul.h:17
Ref< VarNode > Var
Definition: expr.h:40
bool checkAllDefined(const std::unordered_set< std::string > &defs, const std::unordered_set< std::string > &namesInOp)
Definition: check_all_defined.h:11
Ref< VarDefNode > VarDef
Definition: stmt.h:107
Expr makeMul(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:202
Ref< StoreNode > Store
Definition: stmt.h:140
Stmt asMatMul(const Stmt &ast, const ID &loop, MatMulBackend backend)
Definition: as_matmul.cc:442
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< ForNode > For
Definition: stmt.h:308
std::vector< T > filter(const std::vector< T > &vec, const U &callback)
Definition: container_utils.h:131
MatMulBackend
Definition: stmt.h:465
Ref< StmtNode > Stmt
Definition: ast.h:152
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Ref< ExprNode > Expr
Definition: ast.h:184
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
freetensor::Rational< T > abs(const freetensor::Rational< T > &x)
Definition: rational.h:85
Definition: linear.h:23
Definition: as_matmul.h:35
ID vardef_
Definition: as_matmul.h:36
NeedVarReorder(const ID &vardef, const std::vector< int > &order, const std::string &msg)
Definition: as_matmul.h:39
std::vector< int > order_
Definition: as_matmul.h:37