FreeTensor
Loading...
Searching...
No Matches
var_reorder.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_VAR_REORDER_H
2#define FREE_TENSOR_VAR_REORDER_H
3
4#include <algorithm>
5
7#include <mutator.h>
8
9namespace freetensor {
10
11class VarReorder : public SymbolTable<Mutator> {
13
14 ID def_;
15 std::string var_;
16 std::vector<int> order_;
17 bool forceReorderInMatMul_;
18 bool found_ = false;
19
20 public:
21 VarReorder(const ID &def, const std::vector<int> &order,
22 bool forceReorderInMatMul)
23 : def_(def), order_(order),
24 forceReorderInMatMul_(forceReorderInMatMul) {
25 std::vector<int> numbers;
26 numbers.reserve(order.size());
27 for (int i = 0, n = order.size(); i < n; i++) {
28 numbers.emplace_back(i);
29 }
30 if (!std::is_permutation(order.begin(), order.end(), numbers.begin())) {
31 throw InvalidSchedule("The new order should be a permutation of "
32 "the existing dimensions");
33 }
34 }
35
36 bool found() const { return found_; }
37
38 private:
39 template <class T> T reorderMemAcc(const T &op) {
40 if (op->var_ == var_) {
41 std::vector<Expr> indices;
42 indices.reserve(order_.size());
43 if (order_.size() != op->indices_.size()) {
44 throw InvalidSchedule("Number of dimensions in the order does "
45 "not match the variable");
46 }
47 for (auto &&nth : order_) {
48 indices.emplace_back(op->indices_[nth]);
49 }
50 op->indices_ = std::move(indices);
51 }
52 return op;
53 }
54
55 protected:
56 using BaseClass::visit;
57 Stmt visit(const VarDef &op) override;
58 Stmt visit(const Store &op) override;
59 Stmt visit(const ReduceTo &op) override;
60 Expr visit(const Load &op) override;
61 Stmt visit(const MatMul &op) override;
62};
63
64Stmt varReorderImpl(const Stmt &ast, const ID &def,
65 const std::vector<int> &order,
66 bool forceReorderInMatMul = false);
67
68Stmt varReorder(const Stmt &ast, const ID &def, const std::vector<int> &order);
69
70} // namespace freetensor
71
72#endif // FREE_TENSOR_VAR_REORDER_H
Definition: id.h:18
Definition: except.h:40
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
Definition: var_reorder.h:11
bool found() const
Definition: var_reorder.h:36
Stmt visit(const VarDef &op) override
Definition: var_reorder.cc:7
VarReorder(const ID &def, const std::vector< int > &order, bool forceReorderInMatMul)
Definition: var_reorder.h:21
int n
Definition: metadata.cc:15
Definition: allocator.h:9
Ref< VarDefNode > VarDef
Definition: stmt.h:107
Ref< LoadNode > Load
Definition: expr.h:61
Ref< StoreNode > Store
Definition: stmt.h:140
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< ExprNode > Expr
Definition: ast.h:184
Stmt varReorder(const Stmt &ast, const ID &def, const std::vector< int > &order)
Definition: var_reorder.cc:82
Ref< MatMulNode > MatMul
Definition: stmt.h:533
Stmt varReorderImpl(const Stmt &ast, const ID &def, const std::vector< int > &order, bool forceReorderInMatMul=false)
Definition: var_reorder.cc:72