FreeTensor
Loading...
Searching...
No Matches
data_type.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_DATA_TYPE_H
2#define FREE_TENSOR_DATA_TYPE_H
3
4#include <array>
5#include <functional>
6
7#include <container_utils.h>
8#include <except.h>
10
11namespace freetensor {
12
13enum class BaseDataType : size_t {
14 Void = 0, // Returns nothing. It is a Unit Type
15 Float16,
16 Float32,
17 Float64,
18 Int32,
19 Int64,
20 Bool,
21 Custom,
22 Never, // Never returns. It is the Bottom Type
23 // ------
25};
26
// Lowercase name of every BaseDataType, indexed by the enumerator's
// underlying value. Used by operator<< and parseBaseDataType.
//
// `inline` (C++17) makes this a single object shared by all translation
// units. A plain namespace-scope `constexpr` variable has internal linkage, so
// the inline functions in this header that odr-use it (via `.at()`) would
// otherwise refer to a different object in every translation unit.
inline constexpr std::array baseDataTypeNames = {
    "void", "float16", "float32", "float64", "int32",
    "int64", "bool",   "custom",  "never",
};
31static_assert(baseDataTypeNames.size() == (size_t)BaseDataType::NumTypes);
32
33namespace detail {
34
35template <typename T, T... i>
36constexpr auto createAllBaseDataTypes(std::integer_sequence<T, i...>) {
37 return std::array{(BaseDataType)i...};
38}
39
40} // namespace detail
41
43 std::make_index_sequence<(size_t)BaseDataType::NumTypes>{});
44
45inline std::ostream &operator<<(std::ostream &os, BaseDataType dtype) {
46 return os << baseDataTypeNames.at((size_t)dtype);
47}
48
49inline BaseDataType parseBaseDataType(const std::string &_str) {
50 auto &&str = tolower(_str);
51 for (auto &&[i, s] : views::enumerate(baseDataTypeNames)) {
52 if (s == str) {
53 return (BaseDataType)i;
54 }
55 }
56 ERROR(FT_MSG << "Unrecognized base data type \"" << _str
57 << "\". Candidates are (case-insensitive): "
58 << (baseDataTypeNames | join(", ")));
59}
60
61enum class SignDataType : size_t {
62 Any = 0,
63 GT0,
64 GE0,
65 LT0,
66 LE0,
67 NE0,
68 EQ0, // EQ0 is only for "0" literals. No need to type-inference EQ0 because
69 // we have const_fold
70 Never, // Bottom type
71 // ------
73};
74
// Textual suffix of every SignDataType, indexed by the enumerator's underlying
// value: "" for Any, a comparison against 0 for the refinements, and "{}" for
// the bottom type Never. Used by operator<< and parseSignDataType.
//
// `inline` (C++17) makes this a single object shared by all translation
// units. Without it, each TU would get its own internal-linkage copy, and the
// inline functions in this header that odr-use it would refer to different
// objects.
inline constexpr std::array signDataTypeNames = {
    "", ">0", ">=0", "<0", "<=0", "!=0", "==0", "{}",
};
78static_assert(signDataTypeNames.size() == (size_t)SignDataType::NumTypes);
79
80namespace detail {
81
82template <typename T, T... i>
83constexpr auto createAllSignDataTypes(std::integer_sequence<T, i...>) {
84 return std::array{(SignDataType)i...};
85}
86
87} // namespace detail
88
90 std::make_index_sequence<(size_t)SignDataType::NumTypes>{});
91
92inline std::ostream &operator<<(std::ostream &os, SignDataType dtype) {
93 return os << signDataTypeNames.at((size_t)dtype);
94}
95
96inline SignDataType parseSignDataType(const std::string &str) {
97 for (auto &&[i, s] : views::enumerate(signDataTypeNames)) {
98 if (s == str) {
99 return (SignDataType)i;
100 }
101 }
102 ERROR(FT_MSG << "Unrecognized sign data type \"" << str
103 << "\". Candidates are: " << (signDataTypeNames | join(", ")));
104}
105
106class DataType {
107 BaseDataType base_;
108 SignDataType sign_;
109
110 public:
111 DataType() {} // Construct without initialization
113 : base_(base), sign_(sign) {}
114
115 // Expose BaseDataType::* to DataType::*
116 //
117 // TODO: Use the following line after GCC 12.3. GCC is buggy with `using
118 // enum` before 12.3 (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=103081)
119 //
120 // using enum BaseDataType;
121 //
122 // and remove the following lines
123 constexpr static auto Bool = BaseDataType::Bool;
124 constexpr static auto Custom = BaseDataType::Custom;
125 constexpr static auto Float16 = BaseDataType::Float16;
126 constexpr static auto Float32 = BaseDataType::Float32;
127 constexpr static auto Float64 = BaseDataType::Float64;
128 constexpr static auto Int32 = BaseDataType::Int32;
129 constexpr static auto Int64 = BaseDataType::Int64;
130 constexpr static auto Void = BaseDataType::Void;
131
132 const auto &base() const { return base_; }
133 const auto &sign() const { return sign_; }
134
135 friend bool operator==(const DataType &, const DataType &) = default;
136};
137
138inline std::ostream &operator<<(std::ostream &os, const DataType &dtype) {
139 return os << dtype.base() << dtype.sign();
140}
141
142inline DataType parseDType(const std::string &str) {
143 auto split = str.find_first_of("<>=!");
144 if (split == std::string::npos) {
145 split = str.length();
146 }
147 auto base = parseBaseDataType(str.substr(0, split));
148 auto sign = parseSignDataType(str.substr(split));
149 return DataType{base, sign};
150}
151
152size_t sizeOf(BaseDataType dtype);
153inline size_t sizeOf(const DataType &dtype) { return sizeOf(dtype.base()); }
154
// The following functions test properties of a type. NOTE: All properties hold
// for the bottom type `Never`, because $\forall x \in \emptyset : P(x)$ is
// always true for any $P$
158
159bool isInt(BaseDataType dtype);
160inline bool isInt(const DataType &dtype) { return isInt(dtype.base()); }
161
162bool isFloat(BaseDataType dtype);
163inline bool isFloat(const DataType &dtype) { return isFloat(dtype.base()); }
164
165inline bool isNumber(BaseDataType dtype) {
166 return isInt(dtype) || isFloat(dtype);
167}
168inline bool isNumber(const DataType &dtype) { return isNumber(dtype.base()); }
169
170bool isBool(BaseDataType dtype);
171inline bool isBool(const DataType &dtype) { return isBool(dtype.base()); }
172
173bool isGT0(SignDataType dtype);
174inline bool isGT0(const DataType &dtype) { return isGT0(dtype.sign()); }
175
176bool isGE0(SignDataType dtype);
177inline bool isGE0(const DataType &dtype) { return isGE0(dtype.sign()); }
178
179bool isLT0(SignDataType dtype);
180inline bool isLT0(const DataType &dtype) { return isLT0(dtype.sign()); }
181
182bool isLE0(SignDataType dtype);
183inline bool isLE0(const DataType &dtype) { return isLE0(dtype.sign()); }
184
185bool isNE0(SignDataType dtype);
186inline bool isNE0(const DataType &dtype) { return isNE0(dtype.sign()); }
187
188bool isEQ0(SignDataType dtype);
189inline bool isEQ0(const DataType &dtype) { return isNE0(dtype.sign()); }
190
204inline DataType upCast(const DataType &lhs, const DataType &rhs) {
205 return {upCast(lhs.base(), rhs.base()), upCast(lhs.sign(), rhs.sign())};
206}
218inline DataType downCast(const DataType &lhs, const DataType &rhs) {
219 return {downCast(lhs.base(), rhs.base()), downCast(lhs.sign(), rhs.sign())};
220}
223} // namespace freetensor
224
225namespace std {
226
227template <> class hash<freetensor::DataType> {
228 std::hash<size_t> h_;
229
230 public:
231 size_t operator()(const freetensor::DataType &dtype) const;
232};
233
234} // namespace std
235
236#endif // FREE_TENSOR_DATA_TYPE
Definition: data_type.h:106
static constexpr auto Float32
Definition: data_type.h:126
const auto & sign() const
Definition: data_type.h:133
const auto & base() const
Definition: data_type.h:132
static constexpr auto Int32
Definition: data_type.h:128
DataType(BaseDataType base, SignDataType sign=SignDataType::Any)
Definition: data_type.h:112
DataType()
Definition: data_type.h:111
friend bool operator==(const DataType &, const DataType &)=default
static constexpr auto Int64
Definition: data_type.h:129
static constexpr auto Float16
Definition: data_type.h:125
static constexpr auto Custom
Definition: data_type.h:124
static constexpr auto Float64
Definition: data_type.h:127
static constexpr auto Bool
Definition: data_type.h:123
static constexpr auto Void
Definition: data_type.h:130
Definition: ref.h:24
#define ERROR(msg)
Definition: except.h:141
#define FT_MSG
Definition: except.h:23
constexpr auto createAllSignDataTypes(std::integer_sequence< T, i... >)
Definition: data_type.h:83
constexpr auto createAllBaseDataTypes(std::integer_sequence< T, i... >)
Definition: data_type.h:36
Definition: allocator.h:9
bool isLT0(SignDataType dtype)
Definition: data_type.cc:82
bool isGT0(SignDataType dtype)
Definition: data_type.cc:60
SignDataType
Definition: data_type.h:61
auto && lhs
Definition: const_fold.cc:70
SignDataType parseSignDataType(const std::string &str)
Definition: data_type.h:96
constexpr auto allBaseDataTypes
Definition: data_type.h:42
bool isEQ0(SignDataType dtype)
Definition: data_type.cc:116
constexpr auto allSignDataTypes
Definition: data_type.h:89
std::string tolower(const std::string &s)
Definition: container_utils.h:142
BaseDataType downCast(BaseDataType lhs, BaseDataType rhs)
Definition: data_type.cc:176
bool isGE0(SignDataType dtype)
Definition: data_type.cc:70
bool isLE0(SignDataType dtype)
Definition: data_type.cc:92
BaseDataType parseBaseDataType(const std::string &_str)
Definition: data_type.h:49
constexpr std::array signDataTypeNames
Definition: data_type.h:75
bool isBool(BaseDataType dtype)
Definition: data_type.cc:50
bool isFloat(BaseDataType dtype)
Definition: data_type.cc:38
std::string join(Container &&c, const std::string &splitter)
Definition: container_utils.h:194
constexpr std::array baseDataTypeNames
Definition: data_type.h:27
auto auto && rhs
Definition: const_fold.cc:70
BaseDataType upCast(BaseDataType lhs, BaseDataType rhs)
Definition: data_type.cc:126
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
bool isNE0(SignDataType dtype)
Definition: data_type.cc:104
DataType parseDType(const std::string &str)
Definition: data_type.h:142
bool isInt(BaseDataType dtype)
Definition: data_type.cc:27
BaseDataType
Definition: data_type.h:13
bool isNumber(BaseDataType dtype)
Definition: data_type.h:165
size_t sizeOf(BaseDataType dtype)
Definition: data_type.cc:6
std::pair< Stmt, std::pair< ID, ID > > split(const Stmt &ast, const ID &id, int factor, int nparts, int shift=0)
Definition: split.cc:90
STL namespace.