FreeTensor
Loading...
Searching...
No Matches
scalar_prop_const.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_SCALAR_PROP_CONST_H
2#define FREE_TENSOR_SCALAR_PROP_CONST_H
3
5#include <func.h>
6#include <hash.h>
7#include <mutator.h>
8#include <pass/const_fold.h>
9
10#include <map>
11
12namespace freetensor {
13
21class ScalarPropConst : public SymbolTable<ConstFold> {
22 protected:
24
25 struct ScalarIndex : public std::variant<int64_t, std::string> {
26 template <class... Args>
27 ScalarIndex(Args &&...args)
28 : std::variant<int64_t, std::string>(std::forward<Args>(args)...) {}
29
30 std::strong_ordering operator<=>(const ScalarIndex &other) const {
31 switch (index() * 2 + other.index()) {
32 case 0: // int int
33 return std::get<int64_t>(*this) <=> std::get<int64_t>(other);
34 case 1: // int str
35 return std::strong_ordering::less;
36 case 2: // str int
37 return std::strong_ordering::greater;
38 case 3: // str str
39 return std::get<std::string>(*this) <=>
40 std::get<std::string>(other);
41 default:
42 ASSERT(false);
43 }
44 }
45 };
46
51 std::vector<ScalarIndex> offset;
52
54 std::strong_ordering operator<=>(const ScalarIndices &other) const {
55 ASSERT(offset.size() == other.offset.size() &&
56 "Index count should be identical for same tensor");
57 return offset <=> other.offset;
58 }
59
60 bool operator==(const ScalarIndices &) const = default;
61 };
62
70 std::optional<ScalarIndices> tryToScalar(const std::vector<Expr> &exprs);
71
74 std::unordered_map<std::string, std::map<ScalarIndices, Expr>> constants_;
75
78 std::unordered_multimap<std::string, std::pair<std::string, ScalarIndices>>
80
81 void gen_constant(const std::string &name,
82 const std::optional<ScalarIndices> &indices,
83 const Expr &value);
84 void kill_iter_dep_entry(const std::string &name,
85 const ScalarIndices &indices);
86 void kill_constant(const std::string &name,
87 const std::optional<ScalarIndices> &indices);
88 void kill_iter(const std::string &it_var);
89
101 std::unordered_map<std::string, std::map<ScalarIndices, Expr>> other);
102
103 protected:
104 using BaseClass::visit;
105 Stmt visit(const Store &store_orig) override;
106 Stmt visit(const ReduceTo &op) override;
107 Expr visit(const Load &load_orig) override;
108 Stmt visit(const If &op) override;
109 Stmt visit(const VarDef &vd) override;
110 Stmt visit(const For &op) override;
111};
112
134Stmt scalarPropConst(const Stmt &op);
135
137
138} // namespace freetensor
139
140#endif // FREE_TENSOR_SCALAR_PROP_CONST_H
Definition: scalar_prop_const.h:21
bool intersect_constants_with(std::unordered_map< std::string, std::map< ScalarIndices, Expr > > other)
Intersect the currently recorded scalar constants with the provided map.
Definition: scalar_prop_const.cc:89
std::unordered_map< std::string, std::map< ScalarIndices, Expr > > constants_
Definition: scalar_prop_const.h:74
std::optional< ScalarIndices > tryToScalar(const std::vector< Expr > &exprs)
Try converting indices' AST nodes to constant indices.
Definition: scalar_prop_const.cc:17
void kill_iter_dep_entry(const std::string &name, const ScalarIndices &indices)
Definition: scalar_prop_const.cc:50
void gen_constant(const std::string &name, const std::optional< ScalarIndices > &indices, const Expr &value)
Definition: scalar_prop_const.cc:29
Stmt visit(const Store &store_orig) override
Store: kill the affected constant entries, and optionally gen a new one.
Definition: scalar_prop_const.cc:125
void kill_iter(const std::string &it_var)
Definition: scalar_prop_const.cc:78
std::unordered_multimap< std::string, std::pair< std::string, ScalarIndices > > iter_dep_constants_
Definition: scalar_prop_const.h:79
void kill_constant(const std::string &name, const std::optional< ScalarIndices > &indices)
Definition: scalar_prop_const.cc:64
SymbolTable< ConstFold > BaseClass
Definition: scalar_prop_const.h:23
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
Stmt scalarPropConst(const Stmt &op)
Definition: scalar_prop_const.cc:241
STL namespace.
Definition: scalar_prop_const.h:25
std::strong_ordering operator<=>(const ScalarIndex &other) const
Definition: scalar_prop_const.h:30
ScalarIndex(Args &&...args)
Definition: scalar_prop_const.h:27
Indices to a scalar, consisting of a sequence of constant offsets.
Definition: scalar_prop_const.h:50
bool operator==(const ScalarIndices &) const =default
std::strong_ordering operator<=>(const ScalarIndices &other) const
Support comparison so that ScalarIndices can be used as a std::map key.
Definition: scalar_prop_const.h:54
std::vector< ScalarIndex > offset
Definition: scalar_prop_const.h:51