FreeTensor
Loading...
Searching...
No Matches
lower.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_LOWER_H
2#define FREE_TENSOR_LOWER_H
3
4#include <unordered_set>
5
7#include <config.h>
8#include <driver/target.h>
10#include <pass/float_simplify.h>
13#include <pass/gpu/make_sync.h>
20#include <pass/make_reduction.h>
26#include <pass/remove_writes.h>
28#include <pass/shrink_for.h>
29#include <pass/shrink_var.h>
30#include <pass/simplify.h>
31#include <pass/sink_var.h>
34#include <pass/z3_simplify.h>
35
36namespace freetensor {
37
52template <class T>
53T lower(const T &_ast, const Ref<Target> &_target = nullptr,
54 const std::unordered_set<std::string> &skipPasses = {},
55 int verbose = 0) {
56
57 auto target = _target.isValid() ? _target : Config::defaultTarget();
58
59 auto maybePrint = [&](const std::string &name, const T &ast) -> T {
60 if (verbose >= 2) {
61 logger() << "AST after " << name << " is:" << std::endl
62 << ast << std::endl;
63 }
64 return ast;
65 };
66
67#define FIRST_OF(x, ...) (x)
68#define APPLY(name, pass, ...) \
69 skipPasses.count(name) ? FIRST_OF(__VA_ARGS__) \
70 : maybePrint(name, pass(__VA_ARGS__))
71
72 // NOTE: The following passes enables each other: some optimizations can be
73 // done in pass A only after we do pass B first. Thus the order of the
74 // passes matters. If you found some program that cannot be optimized by the
75 // current order, add it to `test/20.pass/test_lower.py` and adjust the
76 // order.
77 //
78 // We only focus on programs having a real use, because there is no one
79 // order that fits all. A seemingly possible solution is to run all the
80 // passes iteratively until convergence, but the passes are slow and it may
81 // require a number of iterations proportional to the program size to
82 // converge. Such a progam can be
83 //
84 // ```
85 // if (1 == 1) {
86 // a = 1
87 // }
88 // if (a == 1) {
89 // b = 1
90 // }
91 // if (b == 1) {
92 // c = 1
93 // }
94 // ```
95 //
96 // where it needs `simplify` to remove the `if`s, and `prop_const` to fill
97 // the varaible into the `if`s' conditions. We consider it more important to
98 // compile a program than to make it optimal, so we are not going to fully
99 // optimize it.
100
101 T ast = _ast;
102 ast = clearMarkVersion(ast);
103 ast = APPLY("make_reduction", makeReduction, ast);
104 ast = APPLY("scalar_prop_const", scalarPropConst, ast);
105 ast = APPLY("remove_dead_var", removeDeadVar, ast);
106 ast = APPLY("simplify", simplify,
107 ast); // first time before propagations for indices
108 ast = APPLY("remove_writes", removeWrites, ast);
109 ast = APPLY("prop_one_time_use", propOneTimeUse, ast);
110 ast = APPLY("float_simplify", floatSimplify, ast); // After propOneTimeUse
111 ast = APPLY("z3_simplify", z3Simplify, ast);
112 ast = APPLY("simplify", simplify,
113 ast); // next time after propagations for propagated values
114 ast = APPLY("move_out_first_or_last_iter", moveOutFirstOrLastIter, ast);
115 ast = APPLY("sink_var", sinkVar, ast);
116 ast = APPLY("shrink_var", shrinkVar, ast);
117 ast = APPLY("merge_and_hoist_if", mergeAndHoistIf, ast);
118 ast = APPLY("tensor_prop_const", tensorPropConst, ast);
119 ast = APPLY("remove_dead_var", removeDeadVar,
120 ast); // After remove_writes and prop_const
121 ast = APPLY("remove_cyclic_assign", removeCyclicAssign,
122 ast); // After remove_writes and remove_dead_var
123 ast = APPLY("make_parallel_reduction", makeParallelReduction, ast, target);
124 ast = APPLY("shrink_for", shrinkFor,
125 ast); // After remove_writes and make_parallel_reduction
126
127 switch (target->type()) {
128#ifdef FT_WITH_CUDA
129 case TargetType::GPU: {
130 auto t = target.as<GPUTarget>();
131 ast = APPLY("gpu_lower_parallel_reduction", gpu::lowerParallelReduction,
132 ast); // Before gpu_nromalize_threads
133 ast = APPLY("gpu_multiplex_buffers", gpu::multiplexBuffers, ast, t);
134 ast = APPLY("gpu_simplex_buffers", gpu::simplexBuffers, ast);
135 ast = APPLY("gpu_normalize_threads", gpu::normalizeThreads,
136 ast); // After gpu_multiplex_buffers
137 ast = APPLY("gpu_normalize_var_in_kernel", gpu::normalizeVarInKernel,
138 ast);
139 ast = APPLY("make_heap_alloc", makeHeapAlloc, ast);
140 ast = APPLY("gpu_make_sync", gpu::makeSync, ast,
141 t); // After gpu_normalize_threads
142 ast = APPLY("gpu_lower_vector", gpu::lowerVector, ast);
143 ast = APPLY("use_builtin_div", useBuiltinDiv, ast);
144 break;
145 }
146#endif // FT_WITH_CUDA
147
148 case TargetType::CPU:
149 ast = APPLY("cpu_lower_parallel_reduction", cpu::lowerParallelReduction,
150 ast);
151 ast = APPLY("make_heap_alloc", makeHeapAlloc, ast);
152 ast = APPLY("use_builtin_div", useBuiltinDiv, ast);
153 break;
154
155 default:
156 ASSERT(false);
157 }
158
159#undef FIRST_OF
160#undef APPLY
161
162 if (verbose >= 1) {
163 logger() << "The lowered AST is:" << std::endl << ast << std::endl;
164 }
165
166 return ast;
167}
168
169} // namespace freetensor
170
171#endif // FREE_TENSOR_LOWER_H
static Ref< Target > defaultTarget()
Definition: config.h:146
Definition: ref.h:24
#define ASSERT(expr)
Definition: except.h:152
#define APPLY(name, pass,...)
Stmt lowerParallelReduction(const Stmt &op)
Definition: lower_parallel_reduction.cc:206
Stmt normalizeVarInKernel(const Stmt &s)
Definition: normalize_var_in_kernel.cc:121
Definition: allocator.h:9
Stmt removeDeadVar(const Stmt &op)
Definition: remove_dead_var.cc:124
Stmt useBuiltinDiv(const Stmt &op)
Definition: use_builtin_div.cc:95
T lower(const T &_ast, const Ref< Target > &_target=nullptr, const std::unordered_set< std::string > &skipPasses={}, int verbose=0)
Definition: lower.h:53
Stmt shrinkFor(const Stmt &op, const ID &subAST=ID(), bool doSimplify=true, bool unordered=false)
Definition: shrink_for.cc:396
Stmt simplify(const Stmt &op)
Definition: simplify.cc:1036
Stmt tensorPropConst(const Stmt &op, const ID &bothInSubAST=ID(), const ID &eitherInSubAST=ID())
Definition: tensor_prop_const.cc:25
Stmt sinkVar(const Stmt &op, const std::optional< std::unordered_set< ID > > &toSink=std::nullopt, const std::function< bool(const Stmt &)> &scopeFilter=nullptr)
Definition: sink_var.cc:182
Stmt z3Simplify(const Stmt &op)
Definition: z3_simplify.cc:602
Logger logger()
Definition: logger.h:60
Stmt mergeAndHoistIf(const Stmt &op)
Definition: merge_and_hoist_if.cc:104
Stmt shrinkVar(const Stmt &op)
Definition: shrink_var.cc:102
Stmt makeHeapAlloc(const Stmt &op)
Definition: make_heap_alloc.cc:127
Stmt clearMarkVersion(const Stmt &op)
Definition: clear_mark_version.h:15
Stmt removeWrites(const Stmt &op, const ID &singleDefId={})
Definition: remove_writes.cc:185
Stmt moveOutFirstOrLastIter(const Stmt &op)
Definition: move_out_first_or_last_iter.h:35
Stmt scalarPropConst(const Stmt &op)
Definition: scalar_prop_const.cc:241
Stmt propOneTimeUse(const Stmt &op, const ID &subAST=ID())
Definition: prop_one_time_use.cc:57
Stmt floatSimplify(const Stmt &op)
Definition: float_simplify.cc:492
Stmt removeCyclicAssign(const Stmt &op)
Definition: remove_cyclic_assign.cc:6
Stmt makeReduction(const Stmt &op, const std::unordered_set< ReduceOp > &types, bool canonicalOnly=false)
Definition: make_reduction.h:38
Stmt makeParallelReduction(const Stmt &op, const Ref< Target > &target)
Definition: make_parallel_reduction.cc:376