1#ifndef FREE_TENSOR_CONST_FOLD_H
2#define FREE_TENSOR_CONST_FOLD_H
27 template <
typename F>
static auto dispatch(
const Const &c, F f) {
28 switch (c->nodeType()) {
36 ASSERT(
false &&
"Unknown Const node");
49 static Const wrap(
const int64_t &t) {
53 static Const wrap(
const double &t) {
57 static Const wrap(
const bool &t) {
72 template <
typename F,
typename FAlt>
74 auto lhs = (*this)(op->lhs_);
75 auto rhs = (*this)(op->rhs_);
76 if (
lhs->isConst() &&
rhs->isConst()) {
78 return dispatch(rhs.as<ConstNode>(),
79 [&](auto rr) { return wrap(f(ll, rr)); });
96 template <
typename F,
typename FAlt>
98 auto x = (*this)(op->expr_);
101 [&](
auto xx) { return wrap(f(xx)); });
116 return dispatch(val, [type](
auto v) {
117 switch (type.
base()) {
118 case DataType::Int32:
119 case DataType::Int64:
120 return wrap(int64_t(v));
121 case DataType::Float16:
122 case DataType::Float32:
123 case DataType::Float64:
124 return wrap(double(v));
126 return wrap(bool(v));
128 ASSERT(false &&
"Unrecognized variable type assigned");
167 Expr visit(
const Cast &op)
override;
bool val_
Definition: expr.h:129
Definition: const_fold.h:17
Expr visit(const Square &op) override
Expr visit(const LE &op) override
Expr visit(const Min &op) override
Expr visit(const Sigmoid &op) override
Expr visit(const LAnd &op) override
Expr visit(const Cos &op) override
Expr visit(const Exp &op) override
Expr visit(const RealDiv &op) override
Expr visit(const NE &op) override
static Const castType(DataType type, const Const &val)
Cast the data type of a Const node.
Definition: const_fold.h:115
Expr visit(const Abs &op) override
Expr visit(const Floor &op) override
Expr visit(const GT &op) override
Expr visit(const Remainder &op) override
Expr visit(const Mod &op) override
Expr visit(const FloorDiv &op) override
Expr visit(const GE &op) override
Expr visit(const Mul &op) override
Expr visit(const Sub &op) override
Expr visit(const CeilDiv &op) override
Expr visit(const Sin &op) override
Expr visit(const EQ &op) override
Expr visit(const Max &op) override
Expr visit(const LNot &op) override
Expr visit(const Tan &op) override
Expr visit(const RoundTowards0Div &op) override
Expr visit(const Tanh &op) override
Expr visit(const LT &op) override
Expr visit(const Sqrt &op) override
Expr visit(const Add &op) override
Expr visit(const Ceil &op) override
Expr visit(const LOr &op) override
Definition: data_type.h:106
const auto & base() const
Definition: data_type.h:132
double val_
Definition: expr.h:112
int64_t val_
Definition: expr.h:95
virtual Stmt visit(const Any &op)
Definition: mutator.h:39
Ref< U > as() const
Definition: ref.h:83
#define ASSERT(expr)
Definition: except.h:152
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
Stmt constFold(const Stmt &op)
Definition: const_fold.h:177
auto && lhs
Definition: const_fold.cc:70
auto auto && rhs
Definition: const_fold.cc:70
Expr makeBoolConst(bool val, std::source_location loc=std::source_location::current())
Definition: expr.h:136
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
Expr makeFloatConst(double val, std::source_location loc=std::source_location::current())
Definition: expr.h:119