FreeTensor
Loading...
Searching...
No Matches
normalize_thread_dims.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_GPU_NORMALIZE_THREAD_DIMS_H
2#define FREE_TENSOR_GPU_NORMALIZE_THREAD_DIMS_H
3
4#ifdef FT_WITH_CUDA
5
6#include <unordered_set>
7
10#include <mutator.h>
11
12namespace freetensor {
13
14namespace gpu {
15
16class NormalizeThreadDims : public CompTransientBounds<SymbolTable<Mutator>> {
17 typedef CompTransientBounds<SymbolTable<Mutator>> BaseClass;
18
19 std::unordered_set<For> openLoopsInKernel_;
20 bool inKernel_ = false;
21
22 private:
26 bool isLegalLen(const Expr &expr);
27 bool isLegalLen(const std::unordered_set<std::string> &names);
28
29 protected:
30 using BaseClass::visit;
31 Stmt visit(const For &op) override;
32};
33
40inline Stmt normalizeThreadDims(const Stmt &ast) {
41 return NormalizeThreadDims{}(ast);
42}
43
44} // namespace gpu
45
46} // namespace freetensor
47
48#endif // FT_WITH_CUDA
49
50#endif // FREE_TENSOR_GPU_NORMALIZE_THREAD_DIMS_H
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152