FreeTensor
Loading...
Searching...
No Matches
linear.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_LINEAR_H
2#define FREE_TENSOR_LINEAR_H
3
4#include <algorithm>
5#include <iostream>
6
7#include <analyze/all_uses.h>
8#include <hash.h>
9
10namespace freetensor {
11
15template <class T> struct Scale {
16 T k_;
18};
19
23template <class T> struct LinearExpr {
24 // std::unordered_map can not guarantee ASTs generated from two identical
25 // `LinearExpr`s are the same, but std::map is too slow. So, we are using
26 // std::vector and sort each factor by its hash
27 std::vector<Scale<T>> coeff_;
28 T bias_ = 0;
29
30 bool isConst() const { return coeff_.empty(); }
31
32 std::unordered_set<std::string> allNames() const {
33 std::unordered_set<std::string> names;
34 for (auto &&[k, a] : coeff_) {
35 for (auto &&name : ::freetensor::allNames(a)) {
36 names.insert(name);
37 }
38 }
39 return names;
40 }
41};
42
43template <class T>
45 LinearExpr<T> ret;
46 auto m = lhs.coeff_.size(), n = rhs.coeff_.size();
47 ret.coeff_.reserve(m + n);
48 for (size_t p = 0, q = 0; p < m || q < n;) {
49 if (q == n ||
50 (p < m && lhs.coeff_[p].a_->hash() < rhs.coeff_[q].a_->hash())) {
51 ret.coeff_.emplace_back(lhs.coeff_[p++]);
52 } else if (p == m || (q < n && lhs.coeff_[p].a_->hash() >
53 rhs.coeff_[q].a_->hash())) {
54 ret.coeff_.emplace_back(rhs.coeff_[q++]);
55 } else {
56 Scale<T> s{lhs.coeff_[p].k_ + rhs.coeff_[q].k_, lhs.coeff_[p].a_};
57 p++, q++;
58 if (s.k_ != 0) {
59 ret.coeff_.emplace_back(s);
60 }
61 }
62 }
63 ret.bias_ = lhs.bias_ + rhs.bias_;
64 return ret;
65}
66
67template <class T>
69 LinearExpr<T> ret;
70 auto m = lhs.coeff_.size(), n = rhs.coeff_.size();
71 ret.coeff_.reserve(m + n);
72 for (size_t p = 0, q = 0; p < m || q < n;) {
73 if (q == n ||
74 (p < m && lhs.coeff_[p].a_->hash() < rhs.coeff_[q].a_->hash())) {
75 ret.coeff_.emplace_back(lhs.coeff_[p++]);
76 } else if (p == m || (q < n && lhs.coeff_[p].a_->hash() >
77 rhs.coeff_[q].a_->hash())) {
78 ret.coeff_.emplace_back(rhs.coeff_[q++]);
79 ret.coeff_.back().k_ = -ret.coeff_.back().k_;
80 } else {
81 Scale<T> s{lhs.coeff_[p].k_ - rhs.coeff_[q].k_, lhs.coeff_[p].a_};
82 p++, q++;
83 if (s.k_ != 0) {
84 ret.coeff_.emplace_back(s);
85 }
86 }
87 }
88 ret.bias_ = lhs.bias_ - rhs.bias_;
89 return ret;
90}
91
92template <class T> LinearExpr<T> mul(const LinearExpr<T> &lin, const T &k) {
93 if (k == 0) {
94 return LinearExpr<T>{{}, 0};
95 }
96 LinearExpr<T> ret;
97 ret.coeff_.reserve(lin.coeff_.size());
98 for (auto &&item : lin.coeff_) {
99 ret.coeff_.emplace_back(Scale<T>{item.k_ * k, item.a_});
100 }
101 ret.bias_ = lin.bias_ * k;
102 return ret;
103}
104
105template <class T>
107 if (lhs.coeff_.size() == rhs.coeff_.size()) {
108 for (size_t i = 0, iEnd = lhs.coeff_.size(); i < iEnd; i++) {
109 if (lhs.coeff_[i].k_ != rhs.coeff_[i].k_) {
110 return false;
111 }
112 if (!HashComparator()(lhs.coeff_[i].a_, rhs.coeff_[i].a_)) {
113 return false;
114 }
115 }
116 return true;
117 }
118 return false;
119}
120
128template <class T>
129 requires std::integral<T> || std::floating_point<T>
131 Expr b = makeIntConst(lin.bias_);
132
133 for (auto &&item : lin.coeff_) {
134 auto k = item.k_;
135 auto a = deepCopy(item.a_);
136
137 if (k == 0) {
138 continue;
139 }
140 Expr x;
141 if (a->nodeType() == ASTNodeType::IntConst) {
142 x = makeIntConst(k * a.template as<IntConstNode>()->val_);
143 } else if (k == 1) {
144 x = a;
145 } else {
146 x = makeMul(makeIntConst(k), a);
147 }
148
149 if (x->nodeType() == ASTNodeType::IntConst &&
151 x = makeIntConst(x.as<IntConstNode>()->val_ +
152 b.as<IntConstNode>()->val_);
153 } else if (b->nodeType() == ASTNodeType::IntConst &&
154 b.as<IntConstNode>()->val_ == 0) {
155 // do nothing
156 } else {
157 x = makeAdd(x, b);
158 }
159
160 b = std::move(x);
161 }
162
163 return b;
164}
165
166template <class T>
168 return hasIdenticalCoeff(lhs, rhs) && lhs.bias_ == rhs.bias_;
169}
170
171template <class T>
172std::ostream &operator<<(std::ostream &os, const LinearExpr<T> &lin) {
173 for (auto &&[k, a] : lin.coeff_) {
174 os << k << " * " << a << " + ";
175 }
176 os << lin.bias_;
177 return os;
178}
179
180} // namespace freetensor
181
182#endif // FREE_TENSOR_LINEAR_H
virtual ASTNodeType nodeType() const =0
Definition: hash.h:67
Definition: expr.h:93
int64_t val_
Definition: expr.h:95
Ref< U > as() const
Definition: ref.h:83
int n
Definition: metadata.cc:15
Definition: allocator.h:9
auto && lhs
Definition: const_fold.cc:70
Expr makeMul(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:202
Expr makeAdd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:174
UpperBound sub(const UpperBound &b1, const LowerBound &b2)
Definition: bounds.cc:200
Expr lin2expr(const LinearExpr< T > &lin)
Definition: linear.h:130
bool operator==(const Allocator< T > &lhs, const Allocator< T > &rhs)
Definition: allocator.h:100
Expr deepCopy(const Expr &op)
Definition: ast.cc:364
UpperBound mul(const UpperBound &b, int k)
Definition: bounds.cc:207
bool hasIdenticalCoeff(const LinearExpr< T > &lhs, const LinearExpr< T > &rhs)
Definition: linear.h:106
std::unordered_set< std::string > allNames(const AST &op, bool noRecurseIdx=false, bool noRecurseSubStmt=false)
Definition: all_uses.h:134
auto auto && rhs
Definition: const_fold.cc:70
UpperBound add(const UpperBound &b1, const UpperBound &b2)
Definition: bounds.cc:193
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
Definition: linear.h:23
std::unordered_set< std::string > allNames() const
Definition: linear.h:32
T bias_
Definition: linear.h:28
bool isConst() const
Definition: linear.h:30
std::vector< Scale< T > > coeff_
Definition: linear.h:27
Definition: linear.h:15
T k_
Definition: linear.h:16
Expr a_
Definition: linear.h:17