1#ifndef FREE_TENSOR_VAR_REORDER_H
2#define FREE_TENSOR_VAR_REORDER_H
16 std::vector<int> order_;
17 bool forceReorderInMatMul_;
22 bool forceReorderInMatMul)
23 : def_(
def), order_(order),
24 forceReorderInMatMul_(forceReorderInMatMul) {
25 std::vector<int> numbers;
26 numbers.reserve(order.size());
27 for (
int i = 0,
n = order.size(); i <
n; i++) {
28 numbers.emplace_back(i);
30 if (!std::is_permutation(order.begin(), order.end(), numbers.begin())) {
32 "the existing dimensions");
36 bool found()
const {
return found_; }
39 template <
class T> T reorderMemAcc(
const T &op) {
40 if (op->var_ == var_) {
41 std::vector<Expr> indices;
42 indices.reserve(order_.size());
43 if (order_.size() != op->indices_.size()) {
45 "not match the variable");
47 for (
auto &&nth : order_) {
48 indices.emplace_back(op->indices_[nth]);
50 op->indices_ = std::move(indices);
65 const std::vector<int> &order,
66 bool forceReorderInMatMul =
false);
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
Definition: var_reorder.h:11
bool found() const
Definition: var_reorder.h:36
Stmt visit(const VarDef &op) override
Definition: var_reorder.cc:7
VarReorder(const ID &def, const std::vector< int > &order, bool forceReorderInMatMul)
Definition: var_reorder.h:21
Definition: allocator.h:9
Ref< VarDefNode > VarDef
Definition: stmt.h:107
Ref< LoadNode > Load
Definition: expr.h:61
Ref< StoreNode > Store
Definition: stmt.h:140
Ref< ReduceToNode > ReduceTo
Definition: stmt.h:248
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< ExprNode > Expr
Definition: ast.h:184
Stmt varReorder(const Stmt &ast, const ID &def, const std::vector< int > &order)
Definition: var_reorder.cc:82
Ref< MatMulNode > MatMul
Definition: stmt.h:533
Stmt varReorderImpl(const Stmt &ast, const ID &def, const std::vector< int > &order, bool forceReorderInMatMul=false)
Definition: var_reorder.cc:72