FreeTensor
Loading...
Searching...
No Matches
bounds.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_BOUNDS_H
2#define FREE_TENSOR_BOUNDS_H
3
4#include <iostream>
5#include <optional>
6#include <unordered_set>
7
8#include <math/linear.h>
9#include <math/rational.h>
10
11namespace freetensor {
12
// Early `detail` namespace section. Its only content line is not present in
// this listing. NOTE(review): presumably a forward declaration of
// `reverseCmp`, which `lin2bounds` calls (qualified as `detail::reverseCmp`)
// before the definition further down in this header -- confirm.
13namespace detail {
15};
16
// Body of class UpperBound: an upper bound kept as a rational linear
// expression (`lin_`).
// NOTE(review): the class header and the `lin_` member declaration are not
// present in this listing.

// Expression form of the bound. Only the Expr constructor initializes it.
// NOTE(review): presumably filled in on demand by expr() otherwise -- confirm
// in bounds.cc.
18 Expr expr_;
// Optional set of names. NOTE(review): presumably a cache populated by
// allNames() -- confirm in bounds.cc.
19 std::optional<std::unordered_set<std::string>> allNames_;
21
22 public:
// Initializer list of the Expr constructor (its signature line is not present
// in this listing): stores `expr` as-is, and builds `lin_` as the single term
// `1 * deepCopy(expr)` with bias 0.
24 : expr_(expr), lin_{{{1, deepCopy(expr)}}, 0} {}
// Construct from an existing linear expression (copied).
25 UpperBound(const LinearExpr<Rational<int64_t>> &lin) : lin_(lin) {}
// Construct from an existing linear expression (moved in).
26 UpperBound(LinearExpr<Rational<int64_t>> &&lin) : lin_(std::move(lin)) {}
27
// Non-const accessors, declared here and defined out of line.
28 const Expr &expr();
29 const std::unordered_set<std::string> &allNames();
// Returns the stored linear expression by const reference.
30 const LinearExpr<Rational<int64_t>> &lin() const { return lin_; }
31
// Streams the bound by printing its linear expression.
32 friend std::ostream &operator<<(std::ostream &os, const UpperBound &b) {
33 return os << b.lin();
34 }
35};
36
// Body of class LowerBound: a lower bound kept as a rational linear
// expression (`lin_`).
// NOTE(review): the class header and the `lin_` member declaration are not
// present in this listing.

// Expression form of the bound. Only the Expr constructor initializes it.
// NOTE(review): presumably filled in on demand by expr() otherwise -- confirm
// in bounds.cc.
38 Expr expr_;
// Optional set of names. NOTE(review): presumably a cache populated by
// allNames() -- confirm in bounds.cc.
39 std::optional<std::unordered_set<std::string>> allNames_;
41
42 public:
// Construct from an expression: stores `expr` and builds `lin_` as the single
// term `1 * expr` with bias 0.
// NOTE(review): UpperBound's Expr constructor wraps the term in
// deepCopy(expr); this one shares `expr` directly -- confirm the difference
// is intentional.
43 LowerBound(const Expr &expr) : expr_(expr), lin_{{{1, expr}}, 0} {}
// Construct from an existing linear expression (copied).
44 LowerBound(const LinearExpr<Rational<int64_t>> &lin) : lin_(lin) {}
// Construct from an existing linear expression (moved in).
45 LowerBound(LinearExpr<Rational<int64_t>> &&lin) : lin_(std::move(lin)) {}
46
// Non-const accessors, declared here and defined out of line.
47 const Expr &expr();
48 const std::unordered_set<std::string> &allNames();
// Returns the stored linear expression by const reference.
49 const LinearExpr<Rational<int64_t>> &lin() const { return lin_; }
50
// Streams the bound by printing its linear expression.
51 friend std::ostream &operator<<(std::ostream &os, const LowerBound &u) {
52 return os << u.lin();
53 }
54};
55
// Arithmetic on bounds. Only declarations are given here; the definitions
// live outside this header.

// Sum of two bounds of the same kind, yielding a bound of that kind.
56UpperBound add(const UpperBound &b1, const UpperBound &b2);
57LowerBound add(const LowerBound &b1, const LowerBound &b2);
58
// Difference of two bounds of opposite kinds: upper - lower yields an upper
// bound, and lower - upper yields a lower bound.
59UpperBound sub(const UpperBound &b1, const LowerBound &b2);
60LowerBound sub(const LowerBound &b1, const UpperBound &b2);
61
62// We only handle multiplication by a constant. Otherwise, the extreme value of
63// `x * y` may not be attained at the extreme values of `x` and `y`.
64UpperBound mul(const UpperBound &b, int k);
65LowerBound mul(const LowerBound &b, int k);
66
67// We only handle division by a constant. Otherwise, the extreme value of
68// `x / y` may not be attained at the extreme values of `x` and `y`.
69UpperBound floorDiv(const UpperBound &b, int k);
70LowerBound floorDiv(const LowerBound &b, int k);
71UpperBound ceilDiv(const UpperBound &b, int k);
72LowerBound ceilDiv(const LowerBound &b, int k);
73
// Comparisons between an upper bound and a lower bound.
// NOTE(review): presumably true only when the relation (`<` / `<=`) is proven
// to hold for every value -- confirm against the definitions.
74bool alwaysLT(const UpperBound &b1, const LowerBound &b2);
75bool alwaysLE(const UpperBound &b1, const LowerBound &b2);
76
82template <class T>
83std::pair<std::optional<LowerBound>, std::optional<UpperBound>>
84lin2bounds(const LinearExpr<T> &_lin, ASTNodeType cmp, const Expr &x) {
85 typedef std::pair<std::optional<LowerBound>, std::optional<UpperBound>>
86 RetType;
87
88 // 1. Remove x from lin
89 // 2. Convert to a rational linear because we need to do division later
91 std::optional<Rational<int64_t>> selfK;
92 lin.bias_ = _lin.bias_;
93 lin.coeff_.reserve(_lin.coeff_.size() - 1);
94 for (auto &&[k, a] : _lin.coeff_) {
95 if (HashComparator()(a, x)) {
96 ASSERT(!selfK.has_value());
97 selfK = std::make_optional<Rational<int64_t>>(k);
98 } else {
99 lin.coeff_.emplace_back(Scale<Rational<int64_t>>{k, a});
100 }
101 }
102 if (!selfK.has_value() || *selfK == 0) {
103 return RetType(std::nullopt, std::nullopt);
104 }
105
106 // 3. Normalize according to selfK
107 // Now x is at the left side and the other items are at the right side
108 if (*selfK < 0) {
109 cmp = detail::reverseCmp(cmp);
110 }
111 lin.bias_ /= -*selfK;
112 for (auto &item : lin.coeff_) {
113 item.k_ /= -*selfK;
114 }
115
116 // 4. Construct the bounds according to cmp
117 // We normalize LT and GT to LE and GE according to the following:
118 // selfK * x < y <==> selfK * x <= y - 1 (as an integer expression)
119 // x < 1/selfK * y <==> x <= 1/selfK * y - 1/selfK
120 switch (cmp) {
121 case ASTNodeType::LE:
122 return RetType(std::nullopt, std::make_optional<UpperBound>(lin));
123 case ASTNodeType::LT:
124 return RetType(
125 std::nullopt,
126 std::make_optional<UpperBound>(LinearExpr<Rational<int64_t>>{
127 lin.coeff_, lin.bias_ - 1 / std::abs(*selfK)}));
128 case ASTNodeType::GE:
129 return RetType(std::make_optional<LowerBound>(lin), std::nullopt);
130 case ASTNodeType::GT:
131 return RetType(
132 std::make_optional<LowerBound>(LinearExpr<Rational<int64_t>>{
133 lin.coeff_, lin.bias_ + 1 / std::abs(*selfK)}),
134 std::nullopt);
135 case ASTNodeType::EQ:
136 return RetType(std::make_optional<LowerBound>(lin),
137 std::make_optional<UpperBound>(lin));
138 default:
139 return RetType(std::nullopt, std::nullopt);
140 }
141}
142
143namespace detail {
144
// Body of `reverseCmp`: returns the comparison operator obtained by swapping
// the two sides of a comparison. `lin2bounds` uses it when it divides by a
// negative coefficient.
// NOTE(review): the function's signature line is not present in this listing;
// the symbol index below records it as
// `ASTNodeType reverseCmp(ASTNodeType type)`.
146 switch (type) {
// Strict and non-strict orderings are mirrored.
147 case ASTNodeType::LT:
148 return ASTNodeType::GT;
149 case ASTNodeType::LE:
150 return ASTNodeType::GE;
151 case ASTNodeType::GT:
152 return ASTNodeType::LT;
153 case ASTNodeType::GE:
154 return ASTNodeType::LE;
// Equality and inequality are symmetric, so they are returned unchanged.
155 case ASTNodeType::EQ:
156 return ASTNodeType::EQ;
157 case ASTNodeType::NE:
158 return ASTNodeType::NE;
// Any other node type is rejected via ASSERT.
159 default:
160 ASSERT(false);
161 }
162}
163
164}; // namespace detail
165
166} // namespace freetensor
167
168#endif // FREE_TENSOR_BOUNDS_H
Definition: hash.h:67
Definition: bounds.h:37
LowerBound(const Expr &expr)
Definition: bounds.h:43
friend std::ostream & operator<<(std::ostream &os, const LowerBound &u)
Definition: bounds.h:51
LowerBound(const LinearExpr< Rational< int64_t > > &lin)
Definition: bounds.h:44
const LinearExpr< Rational< int64_t > > & lin() const
Definition: bounds.h:49
LowerBound(LinearExpr< Rational< int64_t > > &&lin)
Definition: bounds.h:45
Definition: bounds.h:17
friend std::ostream & operator<<(std::ostream &os, const UpperBound &b)
Definition: bounds.h:32
UpperBound(LinearExpr< Rational< int64_t > > &&lin)
Definition: bounds.h:26
const Expr & expr()
Definition: bounds.cc:131
UpperBound(const Expr &expr)
Definition: bounds.h:23
const LinearExpr< Rational< int64_t > > & lin() const
Definition: bounds.h:30
UpperBound(const LinearExpr< Rational< int64_t > > &lin)
Definition: bounds.h:25
#define ASSERT(expr)
Definition: except.h:152
ASTNodeType reverseCmp(ASTNodeType type)
Definition: bounds.h:145
Definition: allocator.h:9
std::pair< std::optional< LowerBound >, std::optional< UpperBound > > lin2bounds(const LinearExpr< T > &_lin, ASTNodeType cmp, const Expr &x)
Definition: bounds.h:84
Expr deepCopy(const Expr &op)
Definition: ast.cc:364
std::unordered_set< std::string > allNames(const AST &op, bool noRecurseIdx=false, bool noRecurseSubStmt=false)
Definition: all_uses.h:134
ASTNodeType
Definition: ast.h:20
STL namespace.
freetensor::Rational< T > abs(const freetensor::Rational< T > &x)
Definition: rational.h:85
Definition: linear.h:23
T bias_
Definition: linear.h:28
std::vector< Scale< T > > coeff_
Definition: linear.h:27
Definition: rational.h:9
Definition: linear.h:15