1#ifndef FREE_TENSOR_LOWER_H
2#define FREE_TENSOR_LOWER_H
4#include <unordered_set>
54 const std::unordered_set<std::string> &skipPasses = {},
59 auto maybePrint = [&](
const std::string &name,
const T &ast) -> T {
61 logger() <<
"AST after " << name <<
" is:" << std::endl
67#define FIRST_OF(x, ...) (x)
68#define APPLY(name, pass, ...) \
69 skipPasses.count(name) ? FIRST_OF(__VA_ARGS__) \
70 : maybePrint(name, pass(__VA_ARGS__))
127 switch (target->type()) {
130 auto t = target.as<GPUTarget>();
131 ast =
APPLY(
"gpu_lower_parallel_reduction", gpu::lowerParallelReduction,
133 ast =
APPLY(
"gpu_multiplex_buffers", gpu::multiplexBuffers, ast, t);
134 ast =
APPLY(
"gpu_simplex_buffers", gpu::simplexBuffers, ast);
135 ast =
APPLY(
"gpu_normalize_threads", gpu::normalizeThreads,
140 ast =
APPLY(
"gpu_make_sync", gpu::makeSync, ast,
142 ast =
APPLY(
"gpu_lower_vector", gpu::lowerVector, ast);
163 logger() <<
"The lowered AST is:" << std::endl << ast << std::endl;
static Ref< Target > defaultTarget()
Definition: config.h:146
#define ASSERT(expr)
Definition: except.h:152
#define APPLY(name, pass,...)
Stmt lowerParallelReduction(const Stmt &op)
Definition: lower_parallel_reduction.cc:206
Stmt normalizeVarInKernel(const Stmt &s)
Definition: normalize_var_in_kernel.cc:121
Definition: allocator.h:9
Stmt removeDeadVar(const Stmt &op)
Definition: remove_dead_var.cc:124
Stmt useBuiltinDiv(const Stmt &op)
Definition: use_builtin_div.cc:95
T lower(const T &_ast, const Ref< Target > &_target=nullptr, const std::unordered_set< std::string > &skipPasses={}, int verbose=0)
Definition: lower.h:53
Stmt shrinkFor(const Stmt &op, const ID &subAST=ID(), bool doSimplify=true, bool unordered=false)
Definition: shrink_for.cc:396
Stmt simplify(const Stmt &op)
Definition: simplify.cc:1036
Stmt tensorPropConst(const Stmt &op, const ID &bothInSubAST=ID(), const ID &eitherInSubAST=ID())
Definition: tensor_prop_const.cc:25
Stmt sinkVar(const Stmt &op, const std::optional< std::unordered_set< ID > > &toSink=std::nullopt, const std::function< bool(const Stmt &)> &scopeFilter=nullptr)
Definition: sink_var.cc:182
Stmt z3Simplify(const Stmt &op)
Definition: z3_simplify.cc:602
Logger logger()
Definition: logger.h:60
Stmt mergeAndHoistIf(const Stmt &op)
Definition: merge_and_hoist_if.cc:104
Stmt shrinkVar(const Stmt &op)
Definition: shrink_var.cc:102
Stmt makeHeapAlloc(const Stmt &op)
Definition: make_heap_alloc.cc:127
Stmt clearMarkVersion(const Stmt &op)
Definition: clear_mark_version.h:15
Stmt removeWrites(const Stmt &op, const ID &singleDefId={})
Definition: remove_writes.cc:185
Stmt moveOutFirstOrLastIter(const Stmt &op)
Definition: move_out_first_or_last_iter.h:35
Stmt scalarPropConst(const Stmt &op)
Definition: scalar_prop_const.cc:241
Stmt propOneTimeUse(const Stmt &op, const ID &subAST=ID())
Definition: prop_one_time_use.cc:57
Stmt floatSimplify(const Stmt &op)
Definition: float_simplify.cc:492
Stmt removeCyclicAssign(const Stmt &op)
Definition: remove_cyclic_assign.cc:6
Stmt makeReduction(const Stmt &op, const std::unordered_set< ReduceOp > &types, bool canonicalOnly=false)
Definition: make_reduction.h:38
Stmt makeParallelReduction(const Stmt &op, const Ref< Target > &target)
Definition: make_parallel_reduction.cc:376