1#ifndef FREE_TENSOR_CODE_GEN_CUDA_H
2#define FREE_TENSOR_CODE_GEN_CUDA_H
6#include <unordered_map>
7#include <unordered_set>
// Per-stream state for CUDA code generation; extends CodeGenStream.
// (Only part of this struct is visible in this listing.)
15struct CodeGenCUDAStream :
public CodeGenStream {
// Maps each ParallelScope to an Expr.
// NOTE(review): presumably the launch dimension (extent) bound to that scope — confirm.
16 std::unordered_map<ParallelScope, Expr> threadDim_;
// Code generator derived from CodeGenC, using CodeGenCUDAStream as its
// stream type and holding a GPUTarget.
// NOTE(review): presumably emits CUDA source for GPU targets — confirm.
20class CodeGenCUDA :
public CodeGenC<CodeGenCUDAStream> {
22 typedef CodeGenCUDAStream Stream;
25 Ref<GPUTarget> target_;
26 std::string kernelPrefix_;
31 std::unordered_set<Stmt> streamScopes_;
32 bool inMatmul_ =
false;
33 std::vector<std::string> neededMicroKernels_;
36 CodeGenCUDA(
const std::vector<FuncParam> ¶ms,
37 const std::vector<FuncRet> &returns,
38 const Ref<GPUTarget> &target,
const std::string &kernelPrefix)
39 : CodeGenC(
params, returns), target_(target),
40 kernelPrefix_(kernelPrefix) {}
42 using CodeGenC<CodeGenCUDAStream>::genMdPtrType;
43 using CodeGenC<CodeGenCUDAStream>::genMdPtrDef;
44 std::function<std::ostream &(std::ostream &)>
45 genMdPtrType(
const VarDef &def,
bool isConst =
false)
override;
46 void genMdPtrDef(
const VarDef &def,
const std::function<
void()> &genRawPtr,
47 bool isConst =
false)
override;
49 Expr globalSize()
const {
return globalSize_; }
51 std::string gen(
const DataType &dtype)
override;
53 const auto &neededMicroKernels()
const {
return neededMicroKernels_; }
56 bool inKernel()
const;
58 void exprOr1(
const std::unordered_map<ParallelScope, Expr> &dict,
59 const ParallelScope &key);
61 void enterKernel(
const Stmt &body);
70 bool canRunInKernel(
const Stmt &stmt);
73 void genAlloc(
const Ref<Tensor> &tensor,
const std::string &rawPtr,
74 const std::string &shapePtr,
75 const std::string &dimPtr)
override;
77 using CodeGenC::genScalar;
78 void genScalar(
const VarDef &def,
79 const std::vector<Expr> &indices)
override;
81 using CodeGenC<CodeGenCUDAStream>::visit;
82 void visitStmt(
const Stmt &stmt)
override;
83 void visit(
const Min &op)
override;
84 void visit(
const Max &op)
override;
85 void visit(
const Sqrt &op)
override;
86 void visit(
const Exp &op)
override;
87 void visit(
const Ln &op)
override;
88 void visit(
const Sin &op)
override;
89 void visit(
const Cos &op)
override;
90 void visit(
const Tan &op)
override;
91 void visit(
const Tanh &op)
override;
92 void visit(
const Abs &op)
override;
93 void visit(
const Floor &op)
override;
94 void visit(
const Ceil &op)
override;
95 void visit(
const Cast &op)
override;
96 void visit(
const ReduceTo &op)
override;
97 void visit(
const Var &op)
override;
98 void visit(
const For &op)
override;
99 void visit(
const VarDef &op)
override;
100 void visit(
const MatMul &op)
override;
101 void visit(
const Load &op)
override;
102 void visit(
const Store &op)
override;
103 void visit(
const Alloc &op)
override;
104 void visit(
const Free &op)
override;
112NativeCode codeGenCUDA(
const Func &func,
const Ref<Target> &target);
Definition: allocator.h:9
PBSet params(T &&set)
Definition: presburger.h:1065
Ref< FuncNode > Func
Definition: func.h:64
Ref< ExprNode > Expr
Definition: ast.h:184
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102