FreeTensor
Loading...
Searching...
No Matches
hash.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_HASH_H
2#define FREE_TENSOR_HASH_H
3
4#include <unordered_map>
5#include <unordered_set>
6
7#include <expr.h>
8#include <hash_combine.h>
9#include <stmt.h>
10
11namespace freetensor {
12
13class Hasher {
14 static constexpr size_t P = 2147483647; // % P
15 static constexpr size_t K1 = 179424673, B1 = 275604541;
16 // (node type * K1 + B1) % P
17 static constexpr size_t K2 = 373587883, B2 = 472882027;
18 // ((current hash + non-permutable factor) * K2 + B2) % P
19 // or
20 // (current hash + permutable factor) % P
21 static constexpr size_t K3 = 573259391, B3 = 674506081;
22 // (finally * K3 + B3) % P
23
24 public:
25 static size_t compHash(const Tensor &t);
26 static size_t compHash(const Buffer &b);
27 static size_t compHash(const ReductionItem &r);
28 static size_t compHash(const ForProperty &p);
29 static size_t compHash(const CutlassMicroKernelProperty &p);
30
31 // stmt
32 static size_t compHash(const AnyNode &op);
33 static size_t compHash(const StmtSeqNode &op);
34 static size_t compHash(const VarDefNode &op);
35 static size_t compHash(const StoreNode &op);
36 static size_t compHash(const AllocNode &op);
37 static size_t compHash(const FreeNode &op);
38 static size_t compHash(const ReduceToNode &op);
39 static size_t compHash(const ForNode &op);
40 static size_t compHash(const IfNode &op);
41 static size_t compHash(const AssertNode &op);
42 static size_t compHash(const AssumeNode &op);
43 static size_t compHash(const EvalNode &op);
44 static size_t compHash(const MatMulNode &op);
45 static size_t compHash(const MarkVersionNode &op);
46
47 // expr
48 static size_t compHash(const CommutativeBinaryExprNode &op);
49 static size_t compHash(const NonCommutativeBinaryExprNode &op);
50 static size_t compHash(const UnaryExprNode &op);
51 static size_t compHash(const AnyExprNode &op);
52 static size_t compHash(const VarNode &op);
53 static size_t compHash(const LoadNode &op);
54 static size_t compHash(const IntConstNode &op);
55 static size_t compHash(const FloatConstNode &op);
56 static size_t compHash(const BoolConstNode &op);
57 static size_t compHash(const IfExprNode &op);
58 static size_t compHash(const CastNode &op);
59 static size_t compHash(const IntrinsicNode &op);
60 static size_t compHash(const LoadAtVersionNode &op);
61
62 size_t operator()(const Ref<ASTPart> &op) const {
63 return op.isValid() ? op->hash() : P;
64 }
65};
66
68 private:
69 // stmt
70 bool compare(const Any &lhs, const Any &rhs) const;
71 bool compare(const StmtSeq &lhs, const StmtSeq &rhs) const;
72 bool compare(const VarDef &lhs, const VarDef &rhs) const;
73 bool compare(const Store &lhs, const Store &rhs) const;
74 bool compare(const Alloc &lhs, const Alloc &rhs) const;
75 bool compare(const Free &lhs, const Free &rhs) const;
76 bool compare(const ReduceTo &lhs, const ReduceTo &rhs) const;
77 bool compare(const For &lhs, const For &rhs) const;
78 bool compare(const If &lhs, const If &rhs) const;
79 bool compare(const Assert &lhs, const Assert &rhs) const;
80 bool compare(const Assume &lhs, const Assume &rhs) const;
81 bool compare(const Eval &lhs, const Eval &rhs) const;
82 bool compare(const MatMul &lhs, const MatMul &rhs) const;
83 bool compare(const MarkVersion &lhs, const MarkVersion &rhs) const;
84
85 // expr
86 bool compare(const CommutativeBinaryExpr &lhs,
87 const CommutativeBinaryExpr &rhs) const;
88 bool compare(const NonCommutativeBinaryExpr &lhs,
89 const NonCommutativeBinaryExpr &rhs) const;
90 bool compare(const UnaryExpr &lhs, const UnaryExpr &rhs) const;
91 bool compare(const Var &lhs, const Var &rhs) const;
92 bool compare(const IntConst &lhs, const IntConst &rhs) const;
93 bool compare(const FloatConst &lhs, const FloatConst &rhs) const;
94 bool compare(const BoolConst &lhs, const BoolConst &rhs) const;
95 bool compare(const Load &lhs, const Load &rhs) const;
96 bool compare(const IfExpr &lhs, const IfExpr &rhs) const;
97 bool compare(const Cast &lhs, const Cast &rhs) const;
98 bool compare(const Intrinsic &lhs, const Intrinsic &rhs) const;
99 bool compare(const LoadAtVersion &lhs, const LoadAtVersion &rhs) const;
100
101 public:
102 bool operator()(const Ref<Tensor> &lhs, const Ref<Tensor> &rhs) const;
103 bool operator()(const Ref<Buffer> &lhs, const Ref<Buffer> &rhs) const;
105 const Ref<ReductionItem> &rhs) const;
106 bool operator()(const Ref<ForProperty> &lhs,
107 const Ref<ForProperty> &rhs) const;
110 bool operator()(const AST &lhs, const AST &rhs) const;
111};
112
113template <class K, class V>
114using ASTHashMap = std::unordered_map<K, V, Hasher, HashComparator>;
115
116template <class K>
117using ASTHashSet = std::unordered_set<K, Hasher, HashComparator>;
118
119// Default operator== on std::unordered_map and std::unordered_set does not work
120// when we use custom KeyEqual, so we need to define our own. See
121// https://stackoverflow.com/questions/36167764/can-not-compare-stdunorded-set-with-custom-keyequal
122
123template <class K, class V>
124inline bool operator==(const ASTHashMap<K, V> &lhs,
125 const ASTHashMap<K, V> &rhs) {
126 if (lhs.size() != rhs.size()) {
127 return false;
128 }
129 for (auto &&[k, v] : lhs) {
130 if (!rhs.count(k) || rhs.at(k) != v) {
131 return false;
132 }
133 }
134 return true;
135}
136
137template <class K>
138inline bool operator==(const ASTHashSet<K> &lhs, const ASTHashSet<K> &rhs) {
139 if (lhs.size() != rhs.size()) {
140 return false;
141 }
142 for (auto &&k : lhs) {
143 if (!rhs.count(k)) {
144 return false;
145 }
146 }
147 return true;
148}
149
150} // namespace freetensor
151
152namespace std {
153
154template <class T, class U> class hash<std::pair<T, U>> {
155 std::hash<T> hashT_;
156 std::hash<U> hashU_;
157
158 public:
159 size_t operator()(const std::pair<T, U> &pair) const {
160 return freetensor::hashCombine(hashT_(pair.first), hashU_(pair.second));
161 }
162};
163
164template <class... Ts> class hash<std::tuple<Ts...>> {
165 private:
166 template <class T> static size_t hashManyImpl(const T &first) {
167 return std::hash<T>()(first);
168 }
169
170 template <class T, class... Args>
171 static size_t hashManyImpl(const T &first, const Args &...others) {
172 return freetensor::hashCombine(std::hash<T>()(first),
173 hashManyImpl(others...));
174 }
175
176 // std::apply needs an unoverloaded invocable
177 static size_t hashMany(const Ts &...args) { return hashManyImpl(args...); }
178
179 public:
180 size_t operator()(const std::tuple<Ts...> &t) const {
181 return std::apply(hashMany, t);
182 }
183};
184
185template <class T> class hash<std::vector<T>> {
186 std::hash<T> hash_;
187
188 public:
189 size_t operator()(const std::vector<T> &vec) const {
190 // Encode the size first, to distinguish between one vector and several
191 // consecutive or nested vectors
192 size_t h = std::hash<size_t>()(vec.size());
193
194 for (auto &&item : vec) {
195 h = freetensor::hashCombine(h, hash_(item));
196 }
197 return h;
198 }
199};
200
201} // namespace std
202
203#endif // FREE_TENSOR_HASH_H
Definition: stmt.h:176
Definition: expr.h:17
Definition: stmt.h:22
Definition: stmt.h:383
Definition: stmt.h:417
Definition: expr.h:127
Definition: buffer.h:11
Definition: expr.h:717
Definition: stmt.h:444
Definition: expr.h:110
Definition: stmt.h:287
Definition: stmt.h:201
Definition: hash.h:67
bool operator()(const Ref< Tensor > &lhs, const Ref< Tensor > &rhs) const
Definition: hash.cc:575
Definition: hash.h:13
size_t operator()(const Ref< ASTPart > &op) const
Definition: hash.h:62
static size_t compHash(const Tensor &t)
Definition: hash.cc:25
Definition: expr.h:693
Definition: stmt.h:333
Definition: expr.h:93
Definition: expr.h:740
Definition: expr.h:780
Definition: expr.h:51
Definition: stmt.h:573
Definition: stmt.h:501
Definition: stmt.h:229
Definition: ref.h:24
bool isValid() const
Definition: ref.h:89
Definition: stmt.h:42
Definition: stmt.h:132
Definition: tensor.h:13
Definition: expr.h:474
Definition: stmt.h:83
Definition: expr.h:32
size_t operator()(const std::pair< T, U > &pair) const
Definition: hash.h:159
size_t operator()(const std::tuple< Ts... > &t) const
Definition: hash.h:180
size_t operator()(const std::vector< T > &vec) const
Definition: hash.h:189
Definition: allocator.h:9
std::unordered_map< K, V, Hasher, HashComparator > ASTHashMap
Definition: hash.h:114
auto && lhs
Definition: const_fold.cc:70
bool operator==(const Allocator< T > &lhs, const Allocator< T > &rhs)
Definition: allocator.h:100
auto auto && rhs
Definition: const_fold.cc:70
size_t hashCombine(size_t seed, size_t other)
Definition: hash_combine.cc:5
std::unordered_set< K, Hasher, HashComparator > ASTHashSet
Definition: hash.h:117
STL namespace.
Definition: cutlass_micro_kernel_property.h:9
Definition: for_property.h:38
Definition: for_property.h:11