1#ifndef FREE_TENSOR_GPU_NORMALIZE_THREADS_H
2#define FREE_TENSOR_GPU_NORMALIZE_THREADS_H
6#include <unordered_map>
16class NormalizeThreads :
public Mutator {
18 std::unordered_map<std::string, std::string> varMap_;
19 std::unordered_map<ParallelScope, int>
22 std::unordered_map<ParallelScope, std::vector<ID>> loops_;
23 bool inKernel_ =
false;
26 NormalizeThreads(
const Stmt &root) : root_(root) {}
29 Stmt makeParallelScopes(
const Stmt &body);
31 Stmt doVisitFor(
const For &op);
32 Stmt doVisitStmt(
const Stmt &op);
35 Expr visit(
const Var &op)
override;
36 Stmt visit(
const For &op)
override;
37 Stmt visit(
const VarDef &op)
override;
38 Stmt visit(
const Store &op)
override;
39 Stmt visit(
const ReduceTo &op)
override;
40 Stmt visit(
const Eval &op)
override;
43class ShrinkNormalizedThreads :
public ShrinkFor {
44 typedef ShrinkFor BaseClass;
46 std::unordered_set<For> openLoopsInKernel_;
47 bool inKernel_ =
false;
50 bool filterLoop(
const For &op)
override;
52 std::unordered_set<std::string>
53 filterNames(
const std::unordered_set<std::string> &names)
override;
56 using BaseClass::visit;
57 Stmt visit(
const For &op)
override;
67Stmt normalizeThreads(
const Stmt &op);
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< ExprNode > Expr
Definition: ast.h:184