FreeTensor
Loading...
Searching...
No Matches
lower_vector.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_GPU_LOWER_VECTOR_H
2#define FREE_TENSOR_GPU_LOWER_VECTOR_H
3
4#ifdef FT_WITH_CUDA
5
6#include <unordered_map>
7
9#include <func.h>
10#include <pass/z3_simplify.h>
11
12namespace freetensor {
13
14namespace gpu {
15
16class LowerVector : public Z3SimplifyWithSymbolTable {
17 typedef Z3Simplify BaseClass;
18
19 static constexpr int VEC_LEN[] = {4, 2};
20
21 Var var_;
22 Expr begin_;
23 int vecLen_, isIndex_ = 0;
24 bool simplifyOnly_ = false;
25
26 AnalyzeLinear analyzeLinear_;
27
28 private:
29 std::string vecType(DataType dtype) const;
30 bool hasVectorIndices(const std::vector<Expr> &indices,
31 const std::vector<Expr> &shape);
32 std::vector<Expr> getIndices(const std::vector<Expr> &indices);
33
34 protected:
35 using BaseClass::visit;
36
37 Stmt visit(const For &op) override;
38 Expr visit(const Var &op) override;
39 Expr visit(const Load &op) override;
40 Stmt visit(const Store &op) override;
41 Stmt visit(const ReduceTo &op) override;
42};
43
// Free-function entry point declared by this header: takes a Stmt and returns
// a Stmt.
// NOTE(review): presumably runs LowerVector over `op`; the definition is not
// in this header, so confirm against the implementation file.
44Stmt lowerVector(const Stmt &op);
45
// Applies the DEFINE_PASS_FOR_FUNC macro (defined in func.h) to lowerVector.
// NOTE(review): presumably this generates a Func-level wrapper of the pass;
// confirm against the macro definition.
46DEFINE_PASS_FOR_FUNC(lowerVector)
47
48} // namespace gpu
49
50} // namespace freetensor
51
52#endif // FT_WITH_CUDA
53
54#endif // FREE_TENSOR_GPU_LOWER_VECTOR_H
#define DEFINE_PASS_FOR_FUNC(pass)
Definition: func.h:88
Definition: allocator.h:9
Ref< VarNode > Var
Definition: expr.h:40
Ref< StmtNode > Stmt
Definition: ast.h:152
Ref< ExprNode > Expr
Definition: ast.h:184