FreeTensor
Loading...
Searching...
No Matches
code_gen_c.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_CODE_GEN_C_H
2#define FREE_TENSOR_CODE_GEN_C_H
3
4#include <functional>
5#include <unordered_map>
6#include <unordered_set>
7#include <vector>
8
9#include <codegen/code_gen.h>
11
12namespace freetensor {
13
14template <class Stream> class CodeGenC : public CodeGen<Stream> {
16
17 const std::vector<FuncParam> &params_;
18 const std::vector<FuncRet> &returns_;
19
20 public:
21 CodeGenC(const std::vector<FuncParam> &params,
22 const std::vector<FuncRet> &returns)
23 : params_(params), returns_(returns) {}
24
25 virtual std::string gen(const DataType &dtype);
26
27 protected:
28 virtual void genAlloc(const Ref<Tensor> &tensor, const std::string &rawPtr,
29 const std::string &shapePtr,
30 const std::string &dimPtr) = 0;
31
32 // Generate a pointer to an multi-dimensional array
33 virtual std::function<std::ostream &(std::ostream &)>
34 genMdPtrType(const VarDef &def, bool isConst = false);
35 virtual void genMdPtrDef(const VarDef &def,
36 const std::function<void()> &genRawPtr,
37 bool isConst = false);
38 void genMdPtrDef(const VarDef &def, const std::string &rawPtr,
39 bool isConst = false) {
40 this->genMdPtrDef(
41 def, [&]() { this->os() << rawPtr; }, isConst);
42 }
43
44 // Generate the access to a scalar or an element of an array
45 virtual void genScalar(const VarDef &def, const std::vector<Expr> &indices);
46 template <class T> void genScalar(const T &op) {
47 genScalar(this->def(op->var_), op->indices_);
48 }
49
50 virtual void visit(const StmtSeq &op) override;
51 virtual void visit(const VarDef &op) override;
52 virtual void visit(const Var &op) override;
53 virtual void visit(const Store &op) override;
54 virtual void visit(const Alloc &op) override;
55 virtual void visit(const Free &op) override;
56 virtual void visit(const Load &op) override;
57 virtual void visit(const ReduceTo &op) override;
58 virtual void visit(const IntConst &op) override;
59 virtual void visit(const FloatConst &op) override;
60 virtual void visit(const BoolConst &op) override;
61 virtual void visit(const Add &op) override;
62 virtual void visit(const Sub &op) override;
63 virtual void visit(const Mul &op) override;
64 virtual void visit(const RealDiv &op) override;
65 virtual void visit(const FloorDiv &op) override;
66 virtual void visit(const CeilDiv &op) override;
67 virtual void visit(const RoundTowards0Div &op) override;
68 virtual void visit(const Mod &op) override;
69 virtual void visit(const Remainder &op) override;
70 virtual void visit(const Min &op) override;
71 virtual void visit(const Max &op) override;
72 virtual void visit(const LT &op) override;
73 virtual void visit(const LE &op) override;
74 virtual void visit(const GT &op) override;
75 virtual void visit(const GE &op) override;
76 virtual void visit(const EQ &op) override;
77 virtual void visit(const NE &op) override;
78 virtual void visit(const LAnd &op) override;
79 virtual void visit(const LOr &op) override;
80 virtual void visit(const LNot &op) override;
81 virtual void visit(const Sqrt &op) override;
82 virtual void visit(const Exp &op) override;
83 virtual void visit(const Ln &op) override;
84 virtual void visit(const Square &op) override;
85 virtual void visit(const Sigmoid &op) override;
86 virtual void visit(const Sin &op) override;
87 virtual void visit(const Cos &op) override;
88 virtual void visit(const Tan &op) override;
89 virtual void visit(const Tanh &op) override;
90 virtual void visit(const Abs &op) override;
91 virtual void visit(const Floor &op) override;
92 virtual void visit(const Ceil &op) override;
93 virtual void visit(const IfExpr &op) override;
94 virtual void visit(const Cast &op) override;
95 virtual void visit(const For &op) override;
96 virtual void visit(const If &op) override;
97 virtual void visit(const Assert &op) override;
98 virtual void visit(const Intrinsic &op) override;
99 virtual void visit(const Eval &op) override;
100};
101
102} // namespace freetensor
103
104#endif // FREE_TENSOR_CODE_GEN_C_H
Definition: code_gen_c.h:14
virtual void visit(const StmtSeq &op) override
Definition: code_gen_c.h:131
void genScalar(const T &op)
Definition: code_gen_c.h:46
CodeGenC(const std::vector< FuncParam > &params, const std::vector< FuncRet > &returns)
Definition: code_gen_c.h:21
virtual void genAlloc(const Ref< Tensor > &tensor, const std::string &rawPtr, const std::string &shapePtr, const std::string &dimPtr)=0
virtual void genMdPtrDef(const VarDef &def, const std::function< void()> &genRawPtr, bool isConst=false)
Definition: code_gen_c.h:71
virtual std::function< std::ostream &(std::ostream &)> genMdPtrType(const VarDef &def, bool isConst=false)
Definition: code_gen_c.h:21
virtual void genScalar(const VarDef &def, const std::vector< Expr > &indices)
Definition: code_gen_c.h:107
virtual std::string gen(const DataType &dtype)
Definition: code_gen_c.h:764
void genMdPtrDef(const VarDef &def, const std::string &rawPtr, bool isConst=false)
Definition: code_gen_c.h:38
Definition: code_gen.h:28
std::ostream & os()
Definition: code_gen.h:87
Definition: data_type.h:106
Definition: ref.h:24
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
Definition: allocator.h:9
PBSet params(T &&set)
Definition: presburger.h:1065