1#ifndef FREE_TENSOR_SYMBOL_TABLE_H
2#define FREE_TENSOR_SYMBOL_TABLE_H
5#include <unordered_map>
6#include <unordered_set>
15 virtual const std::unordered_set<std::string> &
names()
const = 0;
16 virtual const std::unordered_map<std::string, VarDef> &
defs()
const = 0;
17 virtual const std::unordered_map<std::string, For> &
loops()
const = 0;
19 virtual bool hasDef(
const std::string &name)
const = 0;
20 virtual const VarDef &
def(
const std::string &name)
const = 0;
23 virtual bool hasLoop(
const std::string &name)
const = 0;
24 virtual const For &
loop(
const std::string &name)
const = 0;
34 std::unordered_map<std::string, VarDef> defs_;
35 std::unordered_map<std::string, For> loops_;
36 std::unordered_set<std::string> names_;
39 const std::unordered_set<std::string> &
names()
const override {
42 const std::unordered_map<std::string, VarDef> &
defs()
const override {
45 const std::unordered_map<std::string, For> &
loops()
const override {
49 bool hasDef(
const std::string &name)
const override {
50 return defs_.count(name);
53 const VarDef &
def(
const std::string &name)
const override {
54 if (
auto it = defs_.find(name); it != defs_.end()) {
58 "` in the current scope");
66 virtual bool hasLoop(
const std::string &name)
const override {
67 return loops_.count(name);
70 virtual const For &
loop(
const std::string &name)
const override {
71 if (
auto it = loops_.find(name); it != loops_.end()) {
75 name +
"` in the current scope");
80 if (names_.count(op->
name_)) {
82 op->
name_ +
"\" is not allowed");
84 defs_[op->
name_] = op;
85 names_.insert(op->
name_);
89 defs_.erase(op->
name_);
90 names_.erase(op->
name_);
94 if (names_.count(op->
iter_)) {
96 op->
iter_ +
"\" is not allowed");
98 loops_[op->
iter_] = op;
99 names_.insert(op->
iter_);
103 loops_.erase(op->
iter_);
104 names_.erase(op->
iter_);
121template <
class BaseClass>
126 template <
class... T>
129 const std::unordered_set<std::string> &
names()
const override {
130 return impl_.
names();
132 const std::unordered_map<std::string, VarDef> &
defs()
const override {
135 const std::unordered_map<std::string, For> &
loops()
const override {
136 return impl_.
loops();
139 bool hasDef(
const std::string &name)
const override {
140 return impl_.
hasDef(name);
142 const VarDef &
def(
const std::string &name)
const override {
143 return impl_.
def(name);
146 return impl_.
buffer(name);
149 bool hasLoop(
const std::string &name)
const override {
152 const For &
loop(
const std::string &name)
const override {
153 return impl_.
loop(name);
165 using BaseClass::visit;
167 typename BaseClass::StmtRetType
visit(
const VarDef &op)
override {
168 if constexpr (std::is_same_v<typename BaseClass::StmtRetType, void>) {
169 for (
auto &&dim : op->
buffer_->tensor()->shape()) {
177 std::vector<Expr> shape;
178 shape.reserve(op->
buffer_->tensor()->shape().size());
179 for (
auto &&dim : op->
buffer_->tensor()->shape()) {
180 shape.emplace_back((*
this)(dim));
188 auto body = (*this)(op->
body_);
197 typename BaseClass::StmtRetType
visit(
const For &op)
override {
204 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
208 ->withVectorize(op->
property_->vectorize_)
210 ->withPreferLibs(op->
property_->preferLibs_);
211 property->reductions_.reserve(op->
property_->reductions_.size());
212 for (
auto &&r : op->
property_->reductions_) {
213 std::vector<Expr> begins, ends;
214 begins.reserve(r->begins_.size());
215 ends.reserve(r->ends_.size());
216 for (
auto &&item : r->begins_) {
217 begins.emplace_back((*
this)(item));
219 for (
auto &&item : r->ends_) {
220 ends.emplace_back((*
this)(item));
222 property->reductions_.emplace_back(
224 std::move(ends), r->syncFlush_));
232 if constexpr (!std::is_same_v<typename BaseClass::StmtRetType, void>) {
233 return makeFor(op->
iter_, std::move(begin), std::move(end),
234 std::move(step), std::move(len), std::move(property),
std::source_location debugBlame() const
Definition: ast.h:134
SubTree< ForProperty > property_
Definition: stmt.h:298
SubTree< ExprNode > begin_
Definition: stmt.h:294
SubTree< ExprNode > step_
Definition: stmt.h:296
SubTree< ExprNode > len_
Definition: stmt.h:297
SubTree< ExprNode > end_
Definition: stmt.h:295
std::string iter_
Definition: stmt.h:289
SubTree< StmtNode > body_
Definition: stmt.h:299
static Ref make()
Definition: ref.h:105
const Metadata & metadata() const
Definition: ast.h:233
ID id() const
Definition: ast.cc:362
Definition: symbol_table.h:33
void pushDef(const VarDef &op) override
Definition: symbol_table.h:79
const std::unordered_map< std::string, VarDef > & defs() const override
Definition: symbol_table.h:42
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:62
const std::unordered_map< std::string, For > & loops() const override
Definition: symbol_table.h:45
void pushFor(const For &op) override
Definition: symbol_table.h:93
bool hasDef(const std::string &name) const override
Definition: symbol_table.h:49
virtual const For & loop(const std::string &name) const override
Definition: symbol_table.h:70
void popFor(const For &op) override
Definition: symbol_table.h:102
const std::unordered_set< std::string > & names() const override
Definition: symbol_table.h:39
virtual bool hasLoop(const std::string &name) const override
Definition: symbol_table.h:66
void popDef(const VarDef &op) override
Definition: symbol_table.h:88
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:53
Definition: symbol_table.h:13
virtual void pushFor(const For &op)=0
virtual void pushDef(const VarDef &op)=0
virtual Ref< Buffer > buffer(const std::string &name) const =0
virtual const For & loop(const std::string &name) const =0
virtual void popDef(const VarDef &op)=0
virtual const std::unordered_set< std::string > & names() const =0
virtual const std::unordered_map< std::string, VarDef > & defs() const =0
virtual const std::unordered_map< std::string, For > & loops() const =0
virtual void popFor(const For &op)=0
virtual bool hasLoop(const std::string &name) const =0
virtual const VarDef & def(const std::string &name) const =0
virtual bool hasDef(const std::string &name) const =0
Definition: symbol_table.h:122
BaseClass::StmtRetType visit(const VarDef &op) override
Definition: symbol_table.h:167
const std::unordered_map< std::string, For > & loops() const override
Definition: symbol_table.h:135
SymbolTable(T &&...args)
Definition: symbol_table.h:127
const SymbolTableData & symbolTableSnapshot() const
Definition: symbol_table.h:162
const std::unordered_set< std::string > & names() const override
Definition: symbol_table.h:129
const For & loop(const std::string &name) const override
Definition: symbol_table.h:152
const VarDef & def(const std::string &name) const override
Definition: symbol_table.h:142
void popDef(const VarDef &op) override
Definition: symbol_table.h:157
void pushDef(const VarDef &op) override
Definition: symbol_table.h:156
bool hasDef(const std::string &name) const override
Definition: symbol_table.h:139
Ref< Buffer > buffer(const std::string &name) const override
Definition: symbol_table.h:145
void pushFor(const For &op) override
Definition: symbol_table.h:159
BaseClass::StmtRetType visit(const For &op) override
Definition: symbol_table.h:197
void popFor(const For &op) override
Definition: symbol_table.h:160
const std::unordered_map< std::string, VarDef > & defs() const override
Definition: symbol_table.h:132
bool hasLoop(const std::string &name) const override
Definition: symbol_table.h:149
bool pinned_
Definition: stmt.h:102
SubTree< StmtNode > body_
Definition: stmt.h:101
SubTree< Buffer > buffer_
Definition: stmt.h:86
std::optional< std::string > viewOf_
Definition: stmt.h:99
std::string name_
Definition: stmt.h:85
#define MAYBE_VOID(name, expr)
Definition: maybe_void.h:25
Definition: allocator.h:9
Ref< ReductionItem > makeReductionItem(ReduceOp op, const std::string &var, Tbegins &&begins, Tends &&ends, bool syncFlush)
Definition: for_property.h:26
Ref< Buffer > makeBuffer(T &&tensor, AccessType atype, MemType mtype)
Definition: buffer.h:32
Stmt makeVarDef(const std::string &name, Tbuffer &&buffer, const std::optional< std::string > &viewOf, Tbody &&body, bool pinned, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:109
Stmt makeFor(const std::string &iter, Tbegin &&begin, Tend &&end, Tstep &&step, Tlen &&len, Tproperty &&property, Tbody &&body, const Metadata &metadata=nullptr, const ID &id={}, std::source_location loc=std::source_location::current())
Definition: stmt.h:311
Ref< Tensor > makeTensor(T &&shape, DataType dtype)
Definition: tensor.h:36