FreeTensor
Loading...
Searching...
No Matches
frontend_var.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_FRONTEND_VAR
2#define FREE_TENSOR_FRONTEND_VAR
3
4#include <iostream>
5
6#include <container_utils.h>
7#include <debug.h>
8#include <expr.h>
9#include <pass/const_fold.h>
10#include <serialize/to_string.h>
11#include <stmt.h>
12
13namespace freetensor {
14
/// Kind of a FrontendVarIdx entry.
///
/// `Single` is a scalar index (stored with length 1); `Slice` is a range
/// described by start / stop / len expressions.
enum class FrontendVarIdxType : int {
    Single = 0,
    Slice = 1,
};
16
19 Expr start_, stop_, len_;
20
21 public:
22 FrontendVarIdxType type() const { return type_; }
23
24 const Expr &single() const {
26 return start_;
27 }
28
29 const Expr &start() const {
31 return start_;
32 }
33
34 const Expr &stop() const {
36 return stop_;
37 }
38
39 const Expr &len() const { return len_; }
40
44 ret.start_ = constFold(single);
45 ret.len_ = makeIntConst(1);
46 return ret;
47 }
48
/// Build a Slice index from `start`, `stop` and `len`, any of which may be
/// null (invalid).
///
/// Each null argument is derived from the others and constant-folded:
///   start = stop - len,   stop = start + len,   len = stop - start.
/// Arguments that are supplied are stored as-is (not constant-folded).
/// NOTE(review): the checks that enough arguments are valid for these
/// derivations live on lines not visible in this view — confirm in the header.
49 static FrontendVarIdx fromSlice(const Expr &start, const Expr &stop,
50 const Expr &len = nullptr) {
52 ret.type_ = FrontendVarIdxType::Slice;
53 if (start.isValid()) {
54 ret.start_ = start;
55 } else {
// start omitted: derive it as stop - len
58 ret.start_ = constFold(makeSub(stop, len));
59 }
60 if (stop.isValid()) {
61 ret.stop_ = stop;
62 } else {
// stop omitted: derive it as start + len. This reads the `start`
// argument (not ret.start_), so it relies on start having been supplied.
65 ret.stop_ = constFold(makeAdd(start, len));
66 }
67 if (len.isValid()) {
68 ret.len_ = len;
69 } else {
// len omitted: derive it as stop - start from the arguments
72 ret.len_ = constFold(makeSub(stop, start));
73 }
74 return ret;
75 }
76};
77
78inline std::ostream &operator<<(std::ostream &os, const FrontendVarIdx &idx) {
79 if (idx.type() == FrontendVarIdxType::Single) {
80 return os << idx.single();
81 } else {
82 return os << "(" << idx.start() << ", " << idx.stop() << ")";
83 }
84}
85
87 std::string name_;
88 std::vector<Expr> fullShape_;
89 DataType dtype_;
90 MemType mtype_;
91 std::vector<FrontendVarIdx> indices_;
92 bool isLoadAtVersion_;
93
94 public:
95 FrontendVar(const std::string &name, const std::vector<Expr> &fullShape,
97 const std::vector<FrontendVarIdx> &indices,
98 bool isLoadAtVersion = false)
99 : name_(name), fullShape_(fullShape), dtype_(dtype), mtype_(mtype),
100 indices_(indices), isLoadAtVersion_(isLoadAtVersion) {}
101
102 const std::string &name() const { return name_; }
103
107 const std::vector<Expr> &fullShape() const { return fullShape_; }
108
109 DataType dtype() const { return dtype_; }
110 MemType mtype() const { return mtype_; }
111
115 int ndim() const;
116
117 const std::vector<FrontendVarIdx> &indices() const { return indices_; }
118
123 Expr shape(const Expr &idx) const;
124 std::vector<Expr> shape() const;
127 Expr asLoad() const;
128 Stmt asStore(const Metadata &metadata, const Expr &value) const;
129 Stmt asReduceTo(ReduceOp op, const Metadata &metadata, const Expr &value,
130 bool atomic = false) const;
131
132 std::vector<FrontendVarIdx>
133 chainIndices(const std::vector<FrontendVarIdx> &next) const;
134};
135
136inline std::ostream &operator<<(std::ostream &os, const FrontendVar &var) {
137 os << var.name() << "[";
138 for (auto &&[i, idx] : views::enumerate(var.indices())) {
139 os << (i == 0 ? "" : ", ") << idx;
140 }
141 return os << "]";
142}
143
// Returns a set of names gathered from the expressions of `idx` (its single
// index, or its start/stop/len when it is a slice).
// NOTE(review): presumably the names of variables read by those expressions,
// mirroring allReads(const AST &) in all_uses.h; the definition is not in this
// file — confirm in frontend_var.cc.
144std::unordered_set<std::string> allReads(const FrontendVarIdx &idx);
145
146} // namespace freetensor
147
148#endif // FREE_TENSOR_FRONTEND_VAR
Definition: data_type.h:106
Definition: frontend_var.h:17
const Expr & start() const
Definition: frontend_var.h:29
const Expr & single() const
Definition: frontend_var.h:24
const Expr & stop() const
Definition: frontend_var.h:34
static FrontendVarIdx fromSingle(const Expr &single)
Definition: frontend_var.h:41
static FrontendVarIdx fromSlice(const Expr &start, const Expr &stop, const Expr &len=nullptr)
Definition: frontend_var.h:49
const Expr & len() const
Definition: frontend_var.h:39
FrontendVarIdxType type() const
Definition: frontend_var.h:22
Definition: frontend_var.h:86
DataType dtype() const
Definition: frontend_var.h:109
const std::vector< FrontendVarIdx > & indices() const
Definition: frontend_var.h:117
const std::vector< Expr > & fullShape() const
Definition: frontend_var.h:107
Stmt asStore(const Metadata &metadata, const Expr &value) const
Definition: frontend_var.cc:87
const std::string & name() const
Definition: frontend_var.h:102
MemType mtype() const
Definition: frontend_var.h:110
Stmt asReduceTo(ReduceOp op, const Metadata &metadata, const Expr &value, bool atomic=false) const
Definition: frontend_var.cc:108
std::vector< FrontendVarIdx > chainIndices(const std::vector< FrontendVarIdx > &next) const
Definition: frontend_var.cc:126
FrontendVar(const std::string &name, const std::vector< Expr > &fullShape, DataType dtype, MemType mtype, const std::vector< FrontendVarIdx > &indices, bool isLoadAtVersion=false)
Definition: frontend_var.h:95
int ndim() const
Definition: frontend_var.cc:8
std::vector< Expr > shape() const
Definition: frontend_var.cc:50
Expr asLoad() const
Definition: frontend_var.cc:67
bool isValid() const
Definition: ref.h:89
#define ASSERT(expr)
Definition: except.h:152
Definition: allocator.h:9
Stmt constFold(const Stmt &op)
Definition: const_fold.h:177
FrontendVarIdxType
Definition: frontend_var.h:15
Expr makeAdd(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:174
std::unordered_set< std::string > allReads(const AST &op, bool noRecurseIdx=false, bool noRecurseSubStmt=false)
Definition: all_uses.h:83
ReduceOp
Definition: reduce_op.h:30
Expr makeSub(T &&lhs, U &&rhs, std::source_location loc=std::source_location::current())
Definition: expr.h:188
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Expr makeIntConst(int64_t val, std::source_location loc=std::source_location::current())
Definition: expr.h:102
MemType
Definition: mem_type.h:14