FreeTensor
Loading...
Searching...
No Matches
func.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_FUNC_H
2#define FREE_TENSOR_FUNC_H
3
4#include <iostream>
5#include <string>
6#include <unordered_map>
7#include <utility>
8#include <vector>
9
10#include <ast.h>
11#include <buffer.h>
12#include <driver/array.h>
13#include <stmt.h>
14#include <tensor.h>
15
16namespace freetensor {
17
18struct FuncParam {
19 std::string name_;
22
23 bool isInClosure() const { return closure_.isValid(); }
24
25 FuncParam(const std::string &name, const Ref<Ref<Array>> &closure,
26 bool updateClosure)
27 : name_(name), closure_(closure), updateClosure_(updateClosure) {}
28};
29
30std::ostream &operator<<(std::ostream &os, const FuncParam &p);
31
32struct FuncRet {
33 std::string name_;
37
38 bool isInClosure() const { return closure_.isValid(); }
39
40 FuncRet(const std::string &name, DataType dtype,
41 const Ref<Ref<Array>> &closure, bool returnClosure)
42 : name_(name), dtype_(dtype), closure_(closure),
43 returnClosure_(returnClosure) {}
44};
45
46std::ostream &operator<<(std::ostream &os, const FuncRet &r);
47
/// AST node for a function: a name, a list of parameters, a list of return
/// values, and a statement body (`body_`, a `SubTree<StmtNode>`).
/// Built via `makeFunc`.
48class FuncNode : public ASTNode {
49 public:
/// Name of the function
50 std::string name_;
/// Parameters of the function
51 std::vector<FuncParam> params_;
/// Return values of the function
52 std::vector<FuncRet>
53 returns_; // NOTE: multiple items in `returns_` may share the same name.
54 // In this case, one variable should be returned to multiple
55 // positions
// NOTE(review): func.h:56 (`SubTree<StmtNode> body_`, per its Definition
// entry on this page) is missing from this listing, as is func.h:62, whose
// content cannot be recovered from this page.
57
/// Identifies this AST node as a function
58 bool isFunc() const override { return true; }
59
/// Hashing is not implemented for functions yet: this always fails the
/// ASSERT
60 void compHash() override { ASSERT(false); } // TODO
61
63};
65template <class Tbody, class Tparams, class Treturns, class Tclosure>
66Func makeFunc(const std::string &name, Tparams &&params, Treturns &&returns,
67 Tbody &&body) {
68 Func f = Func::make();
69 f->name_ = name;
70 f->params_ = std::forward<Tparams>(params);
71 f->returns_ = std::forward<Treturns>(returns);
72 f->body_ = std::forward<Tbody>(body);
73 return f;
74}
75template <class Tbody>
76Func makeFunc(const std::string &name, const std::vector<FuncParam> &params,
77 const std::vector<FuncRet> &returns, Tbody &&body) {
78 Func f = Func::make();
79 f->name_ = name;
80 f->params_ = params;
81 f->returns_ = returns;
82 f->body_ = std::forward<Tbody>(body);
83 return f;
84}
85
/// Copy a Func (defined out of line). Presumably copies the body AST as well,
/// like the `deepCopy(const Expr &)` overload — TODO confirm against the
/// definition.
86Func deepCopy(const Func &func);
87
/**
 * Lift a statement-level pass to a Func-level pass
 *
 * `DEFINE_PASS_FOR_FUNC(p)` defines a variadic overload
 * `p(const Func &func, args...)`. The overload runs `p(func->body_, args...)`,
 * with every extra argument perfect-forwarded. It then returns a new Func
 * (built by makeFunc) that keeps `func`'s name, parameters and return values
 * around the transformed body.
 *
 * The transformed body is held in a forwarding reference and re-forwarded
 * with its own declared type. That way, makeFunc sees exactly the value
 * category the pass produced.
 */
#define DEFINE_PASS_FOR_FUNC(pass)                                             \
    template <typename... PassArgs>                                            \
    Func pass(const Func &func, PassArgs &&...passArgs) {                      \
        auto &&transformedBody =                                               \
            pass(func->body_, std::forward<PassArgs>(passArgs)...);            \
        return makeFunc(                                                       \
            func->name_, func->params_, func->returns_,                        \
            std::forward<decltype(transformedBody)>(transformedBody));         \
    }
93
94} // namespace freetensor
95
96#endif // FREE_TENSOR_FUNC_H
Definition: ast.h:118
Definition: data_type.h:106
Definition: func.h:48
std::vector< FuncParam > params_
Definition: func.h:51
void compHash() override
Definition: func.h:60
std::string name_
Definition: func.h:50
bool isFunc() const override
Definition: func.h:58
std::vector< FuncRet > returns_
Definition: func.h:53
SubTree< StmtNode > body_
Definition: func.h:56
Definition: ref.h:24
static Ref make()
Definition: ref.h:105
Definition: sub_tree.h:134
#define ASSERT(expr)
Definition: except.h:152
Definition: allocator.h:9
PBSet params(T &&set)
Definition: presburger.h:1065
Expr deepCopy(const Expr &op)
Definition: ast.cc:364
Ref< FuncNode > Func
Definition: func.h:64
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Func makeFunc(const std::string &name, Tparams &&params, Treturns &&returns, Tbody &&body)
Definition: func.h:66
Definition: sub_tree.h:20
Definition: func.h:18
bool isInClosure() const
Whether this parameter is bound to data, i.e. whether closure_ is valid.
Definition: func.h:23
bool updateClosure_
Accept user input even if there is a closure.
Definition: func.h:21
Ref< Ref< Array > > closure_
Data bound to this parameter.
Definition: func.h:20
std::string name_
Definition: func.h:19
FuncParam(const std::string &name, const Ref< Ref< Array > > &closure, bool updateClosure)
Definition: func.h:25
Definition: func.h:32
FuncRet(const std::string &name, DataType dtype, const Ref< Ref< Array > > &closure, bool returnClosure)
Definition: func.h:40
DataType dtype_
Definition: func.h:34
Ref< Ref< Array > > closure_
Data bound to this return value.
Definition: func.h:35
bool returnClosure_
Return even if there is a closure.
Definition: func.h:36
bool isInClosure() const
Whether this return value is bound to data, i.e. whether closure_ is valid.
Definition: func.h:38
std::string name_
Definition: func.h:33