FreeTensor
Loading...
Searching...
No Matches
const_fold.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_CONST_FOLD_H
2#define FREE_TENSOR_CONST_FOLD_H
3
4#include <func.h>
5#include <mutator.h>
6
7namespace freetensor {
8
17class ConstFold : public Mutator {
/**
 * Unwrap a Const node and invoke `f` on its raw C++ value.
 *
 * Switches on c->nodeType() and calls `f` with the `val_` member of the
 * matching concrete node: IntConstNode (int64_t), FloatConstNode (double) or
 * BoolConstNode (bool). `f` is therefore instantiated once per value type, and
 * because the return type is deduced with `auto`, every instantiation of `f`
 * must return the same type (callers in this class pass generic lambdas that
 * all return Const).
 *
 * NOTE(review): the `case` labels guarding each `return` are not visible in
 * this rendered listing; presumably they are the IntConst / FloatConst /
 * BoolConst node types — confirm against the original source.
 *
 * @param c Constant node to unwrap.
 * @param f Generic callable applied to the unwrapped value.
 * @return Whatever `f` returns for the node's value.
 */
27 template <typename F> static auto dispatch(const Const &c, F f) {
28 switch (c->nodeType()) {
30 return f(c.as<IntConstNode>()->val_);
32 return f(c.as<FloatConstNode>()->val_);
34 return f(c.as<BoolConstNode>()->val_);
35 default:
36 ASSERT(false && "Unknown Const node"); // NOTE(review): no value is returned on this path; relies on ASSERT not returning when it fails — confirm in except.h
37 }
38 }
39
47 static Const wrap(const int &t) { return makeIntConst(t).as<ConstNode>(); }
48
49 static Const wrap(const int64_t &t) {
50 return makeIntConst(t).as<ConstNode>();
51 }
52
53 static Const wrap(const double &t) {
54 return makeFloatConst(t).as<ConstNode>();
55 }
56
57 static Const wrap(const bool &t) {
58 return makeBoolConst(t).as<ConstNode>();
59 }
72 template <typename F, typename FAlt>
73 Expr visitBinary(const BinaryExpr &op, F f, FAlt falt) {
74 auto lhs = (*this)(op->lhs_);
75 auto rhs = (*this)(op->rhs_);
76 if (lhs->isConst() && rhs->isConst()) {
77 return dispatch(lhs.as<ConstNode>(), [&](auto ll) {
78 return dispatch(rhs.as<ConstNode>(),
79 [&](auto rr) { return wrap(f(ll, rr)); });
80 });
81 } else {
82 return falt(lhs, rhs);
83 }
84 }
85
96 template <typename F, typename FAlt>
97 Expr visitUnary(const UnaryExpr &op, F f, FAlt falt) {
98 auto x = (*this)(op->expr_);
99 if (x->isConst()) {
100 return dispatch(x.as<ConstNode>(),
101 [&](auto xx) { return wrap(f(xx)); });
102 } else {
103 return falt(x);
104 }
105 }
106
107 protected:
115 static Const castType(DataType type, const Const &val) {
116 return dispatch(val, [type](auto v) {
117 switch (type.base()) {
118 case DataType::Int32:
119 case DataType::Int64:
120 return wrap(int64_t(v));
121 case DataType::Float16:
122 case DataType::Float32:
123 case DataType::Float64:
124 return wrap(double(v));
125 case DataType::Bool:
126 return wrap(bool(v));
127 default:
128 ASSERT(false && "Unrecognized variable type assigned");
129 }
130 });
131 }
132
133 protected:
134 using Mutator::visit;
135 Expr visit(const Add &op) override;
136 Expr visit(const Sub &op) override;
137 Expr visit(const Mul &op) override;
138 Expr visit(const RealDiv &op) override;
139 Expr visit(const FloorDiv &op) override;
140 Expr visit(const CeilDiv &op) override;
141 Expr visit(const RoundTowards0Div &op) override;
142 Expr visit(const Mod &op) override;
143 Expr visit(const Remainder &op) override;
144 Expr visit(const Min &op) override;
145 Expr visit(const Max &op) override;
146 Expr visit(const LT &op) override;
147 Expr visit(const LE &op) override;
148 Expr visit(const GT &op) override;
149 Expr visit(const GE &op) override;
150 Expr visit(const EQ &op) override;
151 Expr visit(const NE &op) override;
152 Expr visit(const LAnd &op) override;
153 Expr visit(const LOr &op) override;
154 Expr visit(const LNot &op) override;
155 Expr visit(const Sqrt &op) override;
156 Expr visit(const Exp &op) override;
157 Expr visit(const Square &op) override;
158 Expr visit(const Sigmoid &op) override;
159 Expr visit(const Sin &op) override;
160 Expr visit(const Cos &op) override;
161 Expr visit(const Tan &op) override;
162 Expr visit(const Tanh &op) override;
163 Expr visit(const Abs &op) override;
164 Expr visit(const Floor &op) override;
165 Expr visit(const Ceil &op) override;
166 Expr visit(const Unbound &op) override;
167 Expr visit(const Cast &op) override;
168 Expr visit(const IfExpr &op) override;
169};
170
177inline Stmt constFold(const Stmt &op) { return ConstFold()(op); }
178inline Expr constFold(const Expr &op) { return ConstFold()(op); }
179
180DEFINE_PASS_FOR_FUNC(constFold)
181
182} // namespace freetensor
183
184#endif // FREE_TENSOR_CONST_FOLD_H
Definition: expr.h:127
bool val_
Definition: expr.h:129
Definition: const_fold.h:17
Expr visit(const Square &op) override
Expr visit(const LE &op) override
Expr visit(const Min &op) override
Expr visit(const Sigmoid &op) override
Expr visit(const LAnd &op) override
Expr visit(const Cos &op) override
Expr visit(const Exp &op) override
Expr visit(const RealDiv &op) override
Expr visit(const NE &op) override
static Const castType(DataType type, const Const &val)
Cast the data type of a Const node.
Definition: const_fold.h:115
Expr visit(const Abs &op) override
Expr visit(const Floor &op) override
Expr visit(const GT &op) override
Expr visit(const Remainder &op) override
Expr visit(const Mod &op) override
Expr visit(const FloorDiv &op) override
Expr visit(const GE &op) override
Expr visit(const Mul &op) override
Expr visit(const Sub &op) override
Expr visit(const CeilDiv &op) override
Expr visit(const Sin &op) override
Expr visit(const EQ &op) override
Expr visit(const Max &op) override
Expr visit(const LNot &op) override
Expr visit(const Tan &op) override
Expr visit(const RoundTowards0Div &op) override
Expr visit(const Tanh &op) override
Expr visit(const LT &op) override
Expr visit(const Sqrt &op) override
Expr visit(const Add &op) override
Expr visit(const Ceil &op) override
Expr visit(const LOr &op) override
Definition: expr.h:86
Definition: data_type.h:106
const auto & base() const
Definition: data_type.h:132
Definition: expr.h:110
double val_
Definition: expr.h:112
Definition: expr.h:93
int64_t val_
Definition: expr.h:95
Definition: mutator.h:11
virtual Stmt visit(const Any &op)
Definition: mutator.h:39
Definition: ref.h:24
Ref< U > as() const
Definition: ref.h:83
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
Stmt constFold(const Stmt &op)
Definition: const_fold.h:177
auto && lhs
Definition: const_fold.cc:70
auto && rhs
Definition: const_fold.cc:70
Expr makeBoolConst(bool val, std::source_location loc=std::source_location::current())
Definition: expr.h:136
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
Expr makeFloatConst(double val, std::source_location loc=std::source_location::current())
Definition: expr.h:119