FreeTensor
Loading...
Searching...
No Matches
ref.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_REF_H
2#define FREE_TENSOR_REF_H
3
#include <concepts>   // std::derived_from
#include <cstddef>    // std::nullptr_t, std::size_t
#include <functional> // hash
#include <memory>
#include <type_traits>

#include <allocator.h>
#include <except.h>
10
11namespace freetensor {
12
13class EnableSelfBase;
14template <class T> class EnableSelf;
15
24template <class T> class Ref {
25 template <class U> friend class Ref;
26 template <class U> friend class Weak;
27 template <class U> friend class EnableSelf;
28
29 std::shared_ptr<T> ptr_;
30
31 private:
32 Ref(std::shared_ptr<T> &&ptr) : ptr_(std::move(ptr)) { updateSelf(); }
33
37 void updateSelf() {
38 if constexpr (std::is_base_of_v<EnableSelfBase, T>) {
39 if (ptr_ != nullptr) {
40 auto &self =
41 std::static_pointer_cast<EnableSelf<typename T::Self>>(ptr_)
42 ->self_;
43 // We only need to set `self` when we point a `Ref` to it for
44 // the first time. Later when we retrieve `Ref` from
45 // `Weak::lock`, we are still entering this `updateSelf`, but we
46 // need to set nothing (it is already set). Otherwise, there may
47 // be some false alarm from some data rase detectors, when we
48 // retrieve from `Weak::lock` simultenously from multiple
49 // threads
50 if (!self.isValid()) {
51 self = *this;
52 }
53 }
54 }
55 }
56
57 public:
58 typedef T Object;
59
60 Ref() = default;
61 Ref(std::nullptr_t) : Ref() {}
62 Ref(const Ref &) = default;
63 Ref(Ref &&) = default;
64
67 Ref(T *ptr) : ptr_(ptr) { updateSelf(); }
68
72 template <std::derived_from<T> U>
73 Ref(const Ref<U> &other) : ptr_(std::static_pointer_cast<T>(other.ptr_)) {}
74
75 template <std::derived_from<T> U> Ref &operator=(const Ref<U> &other) {
76 ptr_ = std::static_pointer_cast<T>(other.ptr_);
77 return *this;
78 }
79
80 Ref &operator=(const Ref &) = default;
81 Ref &operator=(Ref &&) = default;
82
83 template <class U> Ref<U> as() const {
84 Ref<U> ret;
85 ret.ptr_ = std::static_pointer_cast<U>(ptr_);
86 return ret;
87 }
88
89 bool isValid() const { return ptr_ != nullptr; }
90
91 T &operator*() const {
92 ASSERT(isValid());
93 return *ptr_;
94 }
95
96 T *operator->() const {
97 ASSERT(isValid());
98 return ptr_.get();
99 }
100
101 T *get() const {
102 return ptr_.get(); // maybe called from PyBind11, don't assert isValid()
103 }
104
105 static Ref make() { return Ref(std::allocate_shared<T>(Allocator<T>())); }
106 static Ref make(T &&x) {
107 return Ref(std::allocate_shared<T>(Allocator<T>(), std::move(x)));
108 }
109 static Ref make(const T &x) {
110 return Ref(std::allocate_shared<T>(Allocator<T>(), x));
111 }
112 template <class... Args> static Ref make(Args &&...args) {
113 return Ref(std::allocate_shared<T>(Allocator<T>(),
114 std::forward<Args>(args)...));
115 }
116
117 friend bool operator==(const Ref &lhs, const Ref &rhs) {
118 return lhs.ptr_ == rhs.ptr_;
119 }
120 friend auto operator<=>(const Ref &lhs, const Ref &rhs) {
121 return lhs.ptr_ <=> rhs.ptr_;
122 }
123};
124
125template <class T> class Weak {
126 std::weak_ptr<T> ptr_;
127 bool notNull_ = false;
128
129 public:
130 Weak() {}
131 Weak(std::nullptr_t) {}
132
133 template <std::derived_from<T> U>
134 Weak(const Ref<U> &ref) : ptr_(ref.ptr_), notNull_(ref.isValid()) {}
135
140 bool isValid() const { return notNull_; }
141
142 Ref<T> lock() const { return Ref<T>(ptr_.lock()); }
143};
144
146
154template <class T> class EnableSelf : public EnableSelfBase {
155 template <class U> friend class Ref;
156
157 Weak<T> self_;
158
159 public:
160 typedef T Self;
161
162 Ref<T> self() const {
163 auto ret = self_.lock();
164 if (!ret.isValid()) {
165 ERROR(
166 "BUG: This class is not managed by Ref. Are you trying to get "
167 "the Ref in a constructor even before a Ref is constructed?");
168 }
169 return ret;
170 };
171};
172
173} // namespace freetensor
174
175namespace std {
176
177template <class T> struct hash<freetensor::Ref<T>> {
178 hash<T *> hash_;
179 size_t operator()(const freetensor::Ref<T> &ref) const {
180 return hash_(ref.get());
181 }
182};
183
184} // namespace std
185
186#endif // FREE_TENSOR_REF_H
Definition: allocator.h:66
Definition: ref.h:145
Definition: ref.h:154
T Self
Definition: ref.h:160
Ref< T > self() const
Definition: ref.h:162
Definition: ref.h:24
static Ref make(T &&x)
Definition: ref.h:106
static Ref make()
Definition: ref.h:105
Ref(Ref &&)=default
static Ref make(const T &x)
Definition: ref.h:109
T & operator*() const
Definition: ref.h:91
Ref(T *ptr)
Definition: ref.h:67
friend bool operator==(const Ref &lhs, const Ref &rhs)
Definition: ref.h:117
bool isValid() const
Definition: ref.h:89
T * operator->() const
Definition: ref.h:96
Ref< U > as() const
Definition: ref.h:83
T Object
Definition: ref.h:58
static Ref make(Args &&...args)
Definition: ref.h:112
Ref & operator=(const Ref &)=default
Ref & operator=(Ref &&)=default
Ref(const Ref &)=default
friend auto operator<=>(const Ref &lhs, const Ref &rhs)
Definition: ref.h:120
Ref(const Ref< U > &other)
Definition: ref.h:73
Ref & operator=(const Ref< U > &other)
Definition: ref.h:75
friend class Ref
Definition: ref.h:25
Ref(std::nullptr_t)
Definition: ref.h:61
T * get() const
Definition: ref.h:101
Definition: ref.h:125
Weak()
Definition: ref.h:130
Ref< T > lock() const
Definition: ref.h:142
Weak(std::nullptr_t)
Definition: ref.h:131
bool isValid() const
Definition: ref.h:140
Weak(const Ref< U > &ref)
Definition: ref.h:134
#define ASSERT(expr)
Definition: except.h:152
#define ERROR(msg)
Definition: except.h:141
Definition: allocator.h:9
auto && lhs
Definition: const_fold.cc:70
auto auto && rhs
Definition: const_fold.cc:70
STL namespace.
hash< T * > hash_
Definition: ref.h:178
size_t operator()(const freetensor::Ref< T > &ref) const
Definition: ref.h:179