FreeTensor
Loading...
Searching...
No Matches
symbol_table.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_SYMBOL_TABLE_H
2#define FREE_TENSOR_SYMBOL_TABLE_H
3
4#include <type_traits>
5#include <unordered_map>
6#include <unordered_set>
7
8#include <maybe_void.h>
9#include <stmt.h>
10
11namespace freetensor {
12
14 public:
15 virtual const std::unordered_set<std::string> &names() const = 0;
16 virtual const std::unordered_map<std::string, VarDef> &defs() const = 0;
17 virtual const std::unordered_map<std::string, For> &loops() const = 0;
18
19 virtual bool hasDef(const std::string &name) const = 0;
20 virtual const VarDef &def(const std::string &name) const = 0;
21 virtual Ref<Buffer> buffer(const std::string &name) const = 0;
22
23 virtual bool hasLoop(const std::string &name) const = 0;
24 virtual const For &loop(const std::string &name) const = 0;
25
26 virtual void pushDef(const VarDef &op) = 0;
27 virtual void popDef(const VarDef &op) = 0;
28
29 virtual void pushFor(const For &op) = 0;
30 virtual void popFor(const For &op) = 0;
31};
32
34 std::unordered_map<std::string, VarDef> defs_;
35 std::unordered_map<std::string, For> loops_;
36 std::unordered_set<std::string> names_;
37
38 public:
39 const std::unordered_set<std::string> &names() const override {
40 return names_;
41 }
42 const std::unordered_map<std::string, VarDef> &defs() const override {
43 return defs_;
44 }
45 const std::unordered_map<std::string, For> &loops() const override {
46 return loops_;
47 }
48
49 bool hasDef(const std::string &name) const override {
50 return defs_.count(name);
51 }
52
53 const VarDef &def(const std::string &name) const override {
54 if (auto it = defs_.find(name); it != defs_.end()) {
55 return it->second;
56 } else {
57 throw SymbolNotFound("There is no VarDef with name `" + name +
58 "` in the current scope");
59 }
60 }
61
62 Ref<Buffer> buffer(const std::string &name) const override {
63 return def(name)->buffer_;
64 }
65
66 virtual bool hasLoop(const std::string &name) const override {
67 return loops_.count(name);
68 }
69
70 virtual const For &loop(const std::string &name) const override {
71 if (auto it = loops_.find(name); it != loops_.end()) {
72 return it->second;
73 } else {
74 throw SymbolNotFound("There is no For with iterator named `" +
75 name + "` in the current scope");
76 }
77 }
78
79 void pushDef(const VarDef &op) override {
80 if (names_.count(op->name_)) {
81 throw InvalidProgram("Nested VarDef with the same name \"" +
82 op->name_ + "\" is not allowed");
83 }
84 defs_[op->name_] = op;
85 names_.insert(op->name_);
86 }
87
88 void popDef(const VarDef &op) override {
89 defs_.erase(op->name_);
90 names_.erase(op->name_);
91 }
92
93 void pushFor(const For &op) override {
94 if (names_.count(op->iter_)) {
95 throw InvalidProgram("Nested For with the same iterator \"" +
96 op->iter_ + "\" is not allowed");
97 }
98 loops_[op->iter_] = op;
99 names_.insert(op->iter_);
100 }
101
102 void popFor(const For &op) override {
103 loops_.erase(op->iter_);
104 names_.erase(op->iter_);
105 }
106};
107
121template <class BaseClass>
122class SymbolTable : public BaseClass, public SymbolTableInterface {
123 SymbolTableData impl_;
124
125 public:
126 template <class... T>
127 SymbolTable(T &&...args) : BaseClass(std::forward<T>(args)...) {}
128
129 const std::unordered_set<std::string> &names() const override {
130 return impl_.names();
131 }
132 const std::unordered_map<std::string, VarDef> &defs() const override {
133 return impl_.defs();
134 }
135 const std::unordered_map<std::string, For> &loops() const override {
136 return impl_.loops();
137 }
138
139 bool hasDef(const std::string &name) const override {
140 return impl_.hasDef(name);
141 }
142 const VarDef &def(const std::string &name) const override {
143 return impl_.def(name);
144 }
145 Ref<Buffer> buffer(const std::string &name) const override {
146 return impl_.buffer(name);
147 }
148
149 bool hasLoop(const std::string &name) const override {
150 return impl_.hasLoop(name);
151 }
152 const For &loop(const std::string &name) const override {
153 return impl_.loop(name);
154 }
155
156 void pushDef(const VarDef &op) override { impl_.pushDef(op); }
157 void popDef(const VarDef &op) override { impl_.popDef(op); }
158
159 void pushFor(const For &op) override { impl_.pushFor(op); }
160 void popFor(const For &op) override { impl_.popFor(op); }
161
162 const SymbolTableData &symbolTableSnapshot() const { return impl_; }
163
164 protected:
165 using BaseClass::visit;
166
167 typename BaseClass::StmtRetType visit(const VarDef &op) override {
168 if constexpr (std::is_same_v<typename BaseClass::StmtRetType, void>) {
169 for (auto &&dim : op->buffer_->tensor()->shape()) {
170 (*this)(dim);
171 }
172
173 pushDef(op);
174 (*this)(op->body_);
175 popDef(op);
176 } else {
177 std::vector<Expr> shape;
178 shape.reserve(op->buffer_->tensor()->shape().size());
179 for (auto &&dim : op->buffer_->tensor()->shape()) {
180 shape.emplace_back((*this)(dim));
181 }
182 Ref<Tensor> t =
183 makeTensor(std::move(shape), op->buffer_->tensor()->dtype());
184 Ref<Buffer> b = makeBuffer(std::move(t), op->buffer_->atype(),
185 op->buffer_->mtype());
186
187 pushDef(op);
188 auto body = (*this)(op->body_);
189 popDef(op);
190
191 return makeVarDef(op->name_, std::move(b), op->viewOf_,
192 std::move(body), op->pinned_, op->metadata(),
193 op->id(), op->debugBlame());
194 }
195 }
196
197 typename BaseClass::StmtRetType visit(const For &op) override {
198 MAYBE_VOID(begin, (*this)(op->begin_));
199 MAYBE_VOID(end, (*this)(op->end_));
200 MAYBE_VOID(step, (*this)(op->step_));
201 MAYBE_VOID(len, (*this)(op->len_));
202
203 Ref<ForProperty> property;
204 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
205 property = Ref<ForProperty>::make()
206 ->withParallel(op->property_->parallel_)
207 ->withUnroll(op->property_->unroll_)
208 ->withVectorize(op->property_->vectorize_)
209 ->withNoDeps(op->property_->noDeps_)
210 ->withPreferLibs(op->property_->preferLibs_);
211 property->reductions_.reserve(op->property_->reductions_.size());
212 for (auto &&r : op->property_->reductions_) {
213 std::vector<Expr> begins, ends;
214 begins.reserve(r->begins_.size());
215 ends.reserve(r->ends_.size());
216 for (auto &&item : r->begins_) {
217 begins.emplace_back((*this)(item));
218 }
219 for (auto &&item : r->ends_) {
220 ends.emplace_back((*this)(item));
221 }
222 property->reductions_.emplace_back(
223 makeReductionItem(r->op_, r->var_, std::move(begins),
224 std::move(ends), r->syncFlush_));
225 }
226 }
227
228 pushFor(op);
229 MAYBE_VOID(body, (*this)(op->body_));
230 popFor(op);
231
232 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
233 return makeFor(op->iter_, std::move(begin), std::move(end),
234 std::move(step), std::move(len), std::move(property),
235 std::move(body), op->metadata(), op->id(),
236 op->debugBlame());
237 }
238 }
239};
240
241} // namespace freetensor
242
243#endif // FREE_TENSOR_SYMBOL_TABLE_H
std::source_location debugBlame() const
Definition: ast.h:134
SubTree< ForProperty > property_
Definition: stmt.h:298
SubTree< ExprNode > begin_
Definition: stmt.h:294
SubTree< ExprNode > step_
Definition: stmt.h:296
SubTree< ExprNode > len_
Definition: stmt.h:297
SubTree< ExprNode > end_
Definition: stmt.h:295
std::string iter_
Definition: stmt.h:289
SubTree< StmtNode > body_
Definition: stmt.h:299
Definition: except.h:83
static Ref make()
Definition: ref.h:105
const Metadata & metadata() const
Definition: ast.h:233
ID id() const
Definition: ast.cc:362
Definition: except.h:90
Definition: symbol_table.h:33
void pushDef(const VarDef &op) override
Definition: symbol_table.h:79
const std::unordered_map< std::string, VarDef > & defs() const override
Definition: symbol_table.h:42
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:62
const std::unordered_map< std::string, For > & loops() const override
Definition: symbol_table.h:45
void pushFor(const For &op) override
Definition: symbol_table.h:93
bool hasDef(const std::string &name) const override
Definition: symbol_table.h:49
virtual const For & loop(const std::string &name) const override
Definition: symbol_table.h:70
void popFor(const For &op) override
Definition: symbol_table.h:102
const std::unordered_set< std::string > & names() const override
Definition: symbol_table.h:39
virtual bool hasLoop(const std::string &name) const override
Definition: symbol_table.h:66
void popDef(const VarDef &op) override
Definition: symbol_table.h:88
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:53
Definition: symbol_table.h:13
virtual void pushFor(const For &op)=0
virtual void pushDef(const VarDef &op)=0
virtual Ref< Buffer > buffer(const std::string &name) const =0
virtual const For & loop(const std::string &name) const =0
virtual void popDef(const VarDef &op)=0
virtual const std::unordered_set< std::string > & names() const =0
virtual const std::unordered_map< std::string, VarDef > & defs() const =0
virtual const std::unordered_map< std::string, For > & loops() const =0
virtual void popFor(const For &op)=0
virtual bool hasLoop(const std::string &name) const =0
virtual const VarDef & def(const std::string &name) const =0
virtual bool hasDef(const std::string &name) const =0
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
const std::unordered_map< std::string, For > & loops() const override
Definition: symbol_table.h:135
SymbolTable(T &&...args)
Definition: symbol_table.h:127
const SymbolTableData & symbolTableSnapshot() const
Definition: symbol_table.h:162
const std::unordered_set< std::string > & names() const override
Definition: symbol_table.h:129
const For & loop(const std::string &name) const override
Definition: symbol_table.h:152
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
void popDef(const VarDef &op) override
Definition: symbol_table.h:157
void pushDef(const VarDef &op) override
Definition: symbol_table.h:156
bool hasDef(const std::string &name) const override
Definition: symbol_table.h:139
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:145
void pushFor(const For &op) override
Definition: symbol_table.h:159
BaseClass::StmtRetType visit(const For &op) override
Definition: symbol_table.h:197
void popFor(const For &op) override
Definition: symbol_table.h:160
const std::unordered_map< std::string, VarDef > & defs() const override
Definition: symbol_table.h:132
bool hasLoop(const std::string &name) const override
Definition: symbol_table.h:149
bool pinned_
Definition: stmt.h:102
SubTree< StmtNode > body_
Definition: stmt.h:101
SubTree< Buffer > buffer_
Definition: stmt.h:86
std::optional< std::string > viewOf_
Definition: stmt.h:99
std::string name_
Definition: stmt.h:85
#define MAYBE_VOID(name, expr)
Definition: maybe_void.h:25
Definition: allocator.h:9
Ref< ReductionItem > makeReductionItem(ReduceOp op, const std::string &var, Tbegins &&begins, Tends &&ends, bool syncFlush)
Definition: for_property.h:26
Ref< Buffer > makeBuffer(T &&tensor, AccessType atype, MemType mtype)
Definition: buffer.h:32
Stmt makeVarDef(const std::string &name, Tbuffer &&buffer, const std::optional< std::string > &viewOf, Tbody &&body, bool pinned, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:109
Stmt makeFor(const std::string &iter, Tbegin &&begin, Tend &&end, Tstep &&step, Tlen &&len, Tproperty &&property, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:311
Ref< Tensor > makeTensor(T &&shape, DataType dtype)
Definition: tensor.h:36
STL namespace.