// FreeTensor — code_gen_cuda.h
// Source listing of the CUDA code-generator header (extracted from the
// project's Doxygen documentation).
1#ifndef FREE_TENSOR_CODE_GEN_CUDA_H
2#define FREE_TENSOR_CODE_GEN_CUDA_H
3
4#ifdef FT_WITH_CUDA
5
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <codegen/native_code.h>
#include <func.h>
12
13namespace freetensor {
14
15struct CodeGenCUDAStream : public CodeGenStream {
16 std::unordered_map<ParallelScope, Expr> threadDim_;
17 Expr sharedSize_ = makeIntConst(0);
18};
19
20class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {
21 public:
22 typedef CodeGenCUDAStream Stream;
23
24 private:
25 Ref<GPUTarget> target_;
26 std::string kernelPrefix_;
27 int nKernel_ = 0;
28 Expr sharedStackTop_ = makeIntConst(0);
29 Expr globalStackTop_ = makeIntConst(0);
30 Expr globalSize_ = makeIntConst(0);
31 std::unordered_set<Stmt> streamScopes_;
32 bool inMatmul_ = false;
33 std::vector<std::string> neededMicroKernels_;
34
35 public:
36 CodeGenCUDA(const std::vector<FuncParam> &params,
37 const std::vector<FuncRet> &returns,
38 const Ref<GPUTarget> &target, const std::string &kernelPrefix)
39 : CodeGenC(params, returns), target_(target),
40 kernelPrefix_(kernelPrefix) {}
41
42 using CodeGenC<CodeGenCUDAStream>::genMdPtrType;
43 using CodeGenC<CodeGenCUDAStream>::genMdPtrDef;
44 std::function<std::ostream &(std::ostream &)>
45 genMdPtrType(const VarDef &def, bool isConst = false) override;
46 void genMdPtrDef(const VarDef &def, const std::function<void()> &genRawPtr,
47 bool isConst = false) override;
48
49 Expr globalSize() const { return globalSize_; }
50
51 std::string gen(const DataType &dtype) override;
52
53 const auto &neededMicroKernels() const { return neededMicroKernels_; }
54
55 private:
56 bool inKernel() const;
57
58 void exprOr1(const std::unordered_map<ParallelScope, Expr> &dict,
59 const ParallelScope &key);
60
61 void enterKernel(const Stmt &body);
62
70 bool canRunInKernel(const Stmt &stmt);
71
72 protected:
73 void genAlloc(const Ref<Tensor> &tensor, const std::string &rawPtr,
74 const std::string &shapePtr,
75 const std::string &dimPtr) override;
76
77 using CodeGenC::genScalar;
78 void genScalar(const VarDef &def,
79 const std::vector<Expr> &indices) override;
80
81 using CodeGenC<CodeGenCUDAStream>::visit;
82 void visitStmt(const Stmt &stmt) override;
83 void visit(const Min &op) override;
84 void visit(const Max &op) override;
85 void visit(const Sqrt &op) override;
86 void visit(const Exp &op) override;
87 void visit(const Ln &op) override;
88 void visit(const Sin &op) override;
89 void visit(const Cos &op) override;
90 void visit(const Tan &op) override;
91 void visit(const Tanh &op) override;
92 void visit(const Abs &op) override;
93 void visit(const Floor &op) override;
94 void visit(const Ceil &op) override;
95 void visit(const Cast &op) override;
96 void visit(const ReduceTo &op) override;
97 void visit(const Var &op) override;
98 void visit(const For &op) override;
99 void visit(const VarDef &op) override;
100 void visit(const MatMul &op) override;
101 void visit(const Load &op) override;
102 void visit(const Store &op) override;
103 void visit(const Alloc &op) override;
104 void visit(const Free &op) override;
105};
106
112NativeCode codeGenCUDA(const Func &func, const Ref<Target> &target);
113
114} // namespace freetensor
115
116#endif // FT_WITH_CUDA
117
118#endif // FREE_TENSOR_CODE_GEN_CUDA_H
// Cross-references from the documentation listing:
//   (unnamed symbol)                      — defined at allocator.h:9
//   PBSet params(T &&set)                 — defined at presburger.h:1065
//   Ref<FuncNode> Func                    — defined at func.h:64
//   Ref<ExprNode> Expr                    — defined at ast.h:184
//   Expr makeIntConst(int64_t val,
//       std::source_location loc = std::source_location::current())
//                                         — defined at expr.h:102