FreeTensor
Loading...
Searching...
No Matches
schedule_log.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_SCHEDULE_LOG_H
2#define FREE_TENSOR_SCHEDULE_LOG_H
3
4#include <array>
5#include <exception>
6#include <iostream>
7#include <mutex>
8#include <variant>
9
10#include <ast.h>
11#include <hash.h>
12#include <serialize/to_string.h>
13#include <shared_linked_list.h>
14
15namespace freetensor {
16
// Identifies which kind of schedule a ScheduleLogItem represents (it is the
// return type of ScheduleLogItem::type()).
//
// The integer value of an enumerator is used as an index into
// scheduleTypeNames by operator<<, so the declaration order here has to match
// the order of that table.
//
// NOTE(review): scheduleTypeNames lists 25 names and the static_assert after
// it refers to ScheduleType::NumTypes, but only 11 enumerators are visible in
// this listing (several source lines are absent) — confirm the remaining
// enumerators and NumTypes against the original header.
17enum class ScheduleType : int {
18 Split = 0,
19 Reorder,
20 Merge,
21 Fission,
22 Fuse,
23 Swap,
24 Blend,
25 Cache,
33 Inline,
36 Unroll,
40 Permute,
43 // ------
45};
46
/// Printable name of every schedule type.
///
/// Entries are looked up by the integer value of a ScheduleType enumerator, so
/// their order must follow the enumerator order.
///
/// Declared `inline` so that every translation unit including this header
/// refers to one single array. A plain namespace-scope `constexpr` variable has
/// internal linkage, and odr-using such a variable from an inline function
/// defined in a header (as `operator<<` for ScheduleType does via `.at()`)
/// makes the definitions of that function differ between translation units.
inline constexpr std::array scheduleTypeNames = {
    "split",         "reorder",       "merge",
    "fission",       "fuse",          "swap",
    "blend",         "cache",         "cache_reduction",
    "set_mem_type",  "var_split",     "var_merge",
    "var_reorder",   "var_unsqueeze", "var_squeeze",
    "inline",        "parallelize",   "parallelize_as",
    "unroll",        "vectorize",     "separate_tail",
    "as_matmul",     "permute",       "pluto_fuse",
    "pluto_permute",
};
// Compile-time guard: the name table must contain exactly
// ScheduleType::NumTypes entries, so the table and the enum cannot drift apart
// unnoticed
58static_assert(scheduleTypeNames.size() == (size_t)ScheduleType::NumTypes);
59
60inline std::ostream &operator<<(std::ostream &os, ScheduleType type) {
61 return os << scheduleTypeNames.at((size_t)type);
62}
63
75 public:
// Virtual so that items can be destroyed through a base-class handle
76 virtual ~ScheduleLogItem() {}
// Kind of schedule this item stands for
77 virtual ScheduleType type() const = 0;
// Textual form of the item; ScheduleLogItemImpl prints the type name followed
// by the stored parameters in parentheses
78 virtual std::string toString() const = 0;
// Like toString(), but ScheduleLogItemImpl prints the result of
// getMetadataFromPack() on the parameters instead of the raw parameters
79 virtual std::string toPrettyString() const = 0;
// Hash value of the item; used by ScheduleLogItemHash
80 virtual size_t hash() const = 0;
// Whether `other` denotes the same schedule; used by ScheduleLogItemEqual
81 virtual bool equals(const ScheduleLogItem &other) const = 0;
// Perform the schedule. ScheduleLogItemImpl only acts while no outcome has
// been recorded yet and captures any exception thrown
82 virtual void run() = 0;
// AST associated with the outcome. ScheduleLogItemImpl invokes ERROR(...) when
// called before an outcome exists and returns nullptr when the outcome is an
// exception
83 virtual Stmt resultAST() const = 0;
84};
85
86using IDMetadataPack = std::pair<ID, Metadata>;
87
88inline std::ostream &operator<<(std::ostream &os, const IDMetadataPack &pack) {
89 return os << pack.first << "(" << pack.second << ")";
90}
91
92template <typename... Args>
93auto getIDFromPack(const std::tuple<Args...> &args) {
94 auto f = [&](auto &&arg) {
95 if constexpr (std::is_same_v<std::decay_t<decltype(arg)>,
97 return arg.first;
98 else if constexpr (std::is_same_v<std::decay_t<decltype(arg)>,
99 std::vector<IDMetadataPack>>) {
100 auto ids =
101 arg | views::transform([&](auto pack) { return pack.first; });
102 return std::vector<ID>(ids.begin(), ids.end());
103 } else
104 return arg;
105 };
106 return std::apply(
107 [&](auto &&...args) { return std::make_tuple(f(args)...); }, args);
108}
109
110template <typename... Args>
111auto getMetadataFromPack(const std::tuple<Args...> &args) {
112 auto f = [&](auto &&arg) {
113 if constexpr (std::is_same_v<std::decay_t<decltype(arg)>,
115 return arg.second;
116 else if constexpr (std::is_same_v<std::decay_t<decltype(arg)>,
117 std::vector<IDMetadataPack>>) {
118 auto metas =
119 arg | views::transform([&](auto pack) { return pack.second; });
120 return std::vector<Metadata>(metas.begin(), metas.end());
121 } else
122 return arg;
123 };
124 return std::apply(
125 [&](auto &&...args) { return std::make_tuple(f(args)...); }, args);
126}
127
128template <typename... Args>
129auto getPackFromID(auto schedule, const std::tuple<Args...> &args) {
130 auto f = [&](auto &&arg) {
131 if constexpr (std::is_same_v<std::decay_t<decltype(arg)>, ID>)
132 return IDMetadataPack{arg, schedule->find(arg)->metadata()};
133 else if constexpr (std::is_same_v<std::decay_t<decltype(arg)>,
134 std::vector<ID>>) {
135 auto packs =
136 arg | views::transform([&](auto id) {
137 return IDMetadataPack{id, schedule->find(id)->metadata()};
138 });
139 return std::vector<IDMetadataPack>(packs.begin(), packs.end());
140 } else
141 return arg;
142 };
143 return std::apply(
144 [&](auto &&...args) { return std::make_tuple(f(args)...); }, args);
145}
146
153template <ScheduleType TYPE, class _Invocable, class _Params, class _Result>
155 protected:
156 // Types defined for subclasses
157 typedef _Invocable Invocable;
158 typedef _Params Params;
159 typedef _Result Result;
160
163 std::variant<std::nullopt_t, Result, std::exception_ptr> result_ =
164 std::nullopt;
165 std::mutex lock_;
166
167 public:
168 ScheduleLogItemImpl(const Invocable &doSchedule, const Params &params)
169 : doSchedule_(doSchedule), params_(params) {}
170
171 ScheduleType type() const override { return TYPE; }
172
173 std::string toString() const override {
174 std::ostringstream os;
175 os << std::boolalpha << type() << '(' << params_ << ')';
176 return os.str();
177 }
178
179 std::string toPrettyString() const override {
180 std::ostringstream os;
181 os << std::boolalpha << type() << '(' << getMetadataFromPack(params_)
182 << ')';
183 return os.str();
184 }
185
186 size_t hash() const override {
187 auto idParams = getIDFromPack(params_);
188 return std::hash<decltype(idParams)>()(idParams);
189 }
190
191 bool equals(const ScheduleLogItem &other) const override {
192 if (other.type() != type()) {
193 return false;
194 }
195 return ((const ScheduleLogItemImpl &)other).params_ == params_;
196 }
197
// Perform the schedule unless an outcome has already been recorded.
//
// NOTE(review): the statement inside `try` (source line 205) is not visible in
// this listing; presumably it invokes doSchedule_ and stores the outcome in
// result_ — confirm against the original header.
201 void run() override {
// Serialize concurrent callers of run()
202 std::lock_guard<std::mutex> guard(lock_);
// result_ still holding std::nullopt means no value or exception is recorded
203 if (std::holds_alternative<std::nullopt_t>(result_)) {
204 try {
206 } catch (...) {
// Keep the failure instead of propagating it; getResult() rethrows it
207 result_ = std::current_exception();
208 }
209 }
210 }
211
216 if (std::holds_alternative<std::nullopt_t>(result_)) {
217 ERROR("BUG: The schedule log is not run yet");
218 } else if (std::holds_alternative<Result>(result_)) {
219 return std::get<Result>(result_);
220 } else {
221 ASSERT(std::holds_alternative<std::exception_ptr>(result_));
222 std::rethrow_exception(std::get<std::exception_ptr>(result_));
223 }
224 }
225
226 Stmt resultAST() const override final {
227 if (std::holds_alternative<std::nullopt_t>(result_)) {
228 ERROR("The schedule log is not run yet");
229 } else if (std::holds_alternative<Result>(result_)) {
230 Result result = std::get<Result>(result_);
231 if constexpr (std::derived_from<Result, Stmt>) {
232 return result;
233 } else {
234 return std::get<0>(result);
235 }
236 } else {
237 return nullptr;
238 }
239 }
240};
241
242inline std::ostream &operator<<(std::ostream &os, const ScheduleLogItem &log) {
243 return os << log.toString();
244}
245
248 const Ref<ScheduleLogItem> &rhs) const {
249 return lhs->equals(*rhs);
250 }
251};
252
254 bool operator()(const Ref<ScheduleLogItem> &item) const {
255 return item->hash();
256 }
257};
258
259typedef SharedLinkedList<Ref<ScheduleLogItem>, ScheduleLogItemHash,
260 ScheduleLogItemEqual>
262
// Build a log item for one schedule call.
//
// TYPE        Enumerator name of ScheduleType (pasted after `ScheduleType::`
//             and into the name of the generated class)
// FUNC        Passed through futureSchedule(FUNC) to obtain the callable that
//             is stored in the item
// __VA_ARGS__ Arguments of the schedule, gathered with std::make_tuple
//
// The expansion is an immediately-invoked lambda that captures `this`, so the
// macro can only be used inside a non-static member function; `this` is also
// what getPackFromID() uses to look up the metadata of each ID. The stored
// parameters are the ID/metadata packs, while the Result type is deduced from
// applying the callable to the original (ID-only) arguments.
//
// The expression evaluates to a Ref to a locally defined subclass of the
// matching ScheduleLogItemImpl specialization.
267#define MAKE_SCHEDULE_LOG(TYPE, FUNC, ...) \
268 ([this](const auto &func, const auto &_params) { \
269 auto params = getPackFromID(this, _params); \
270 /* decay is required: we must not store a reference */ \
271 typedef ScheduleLogItemImpl< \
272 ScheduleType::TYPE, std::decay_t<decltype(func)>, \
273 std::decay_t<decltype(params)>, \
274 std::decay_t<decltype(std::apply(func, _params))>> \
275 BaseClass; \
276 class ScheduleLogItem##TYPE : public BaseClass { \
277 public: \
278 ScheduleLogItem##TYPE(const typename BaseClass::Invocable &f, \
279 const typename BaseClass::Params &p) \
280 : BaseClass(f, p) {} \
281 }; \
282 return Ref<ScheduleLogItem##TYPE>::make(func, params); \
283 })(futureSchedule(FUNC), std::make_tuple(__VA_ARGS__))
284
285} // namespace freetensor
286
287#endif // FREE_TENSOR_SCHEDULE_LOG_H
Definition: as_matmul.h:56
Definition: id.h:18
Definition: parallelize.h:11
Definition: permute.cc:15
Definition: reorder.h:39
Definition: schedule_log.h:154
std::mutex lock_
Definition: schedule_log.h:165
size_t hash() const override
Definition: schedule_log.h:186
bool equals(const ScheduleLogItem &other) const override
Definition: schedule_log.h:191
Stmt resultAST() const override final
Definition: schedule_log.h:226
void run() override
Definition: schedule_log.h:201
_Params Params
Definition: schedule_log.h:158
_Result Result
Definition: schedule_log.h:159
std::variant< std::nullopt_t, Result, std::exception_ptr > result_
Definition: schedule_log.h:163
std::string toPrettyString() const override
Definition: schedule_log.h:179
ScheduleType type() const override
Definition: schedule_log.h:171
Result getResult() const
Definition: schedule_log.h:215
Invocable doSchedule_
Definition: schedule_log.h:161
_Invocable Invocable
Definition: schedule_log.h:157
ScheduleLogItemImpl(const Invocable &doSchedule, const Params &params)
Definition: schedule_log.h:168
Params params_
Definition: schedule_log.h:162
std::string toString() const override
Definition: schedule_log.h:173
Definition: schedule_log.h:74
virtual std::string toPrettyString() const =0
virtual std::string toString() const =0
virtual Stmt resultAST() const =0
virtual size_t hash() const =0
virtual bool equals(const ScheduleLogItem &other) const =0
virtual ~ScheduleLogItem()
Definition: schedule_log.h:76
virtual ScheduleType type() const =0
Definition: separate_tail.h:53
Definition: set_mem_type.h:52
Definition: swap.h:11
Definition: var_merge.h:8
Definition: var_reorder.h:11
Definition: var_split.h:10
Definition: vectorize.h:8
#define ASSERT(expr)
Definition: except.h:152
#define ERROR(msg)
Definition: except.h:141
Definition: allocator.h:9
auto getMetadataFromPack(const std::tuple< Args... > &args)
Definition: schedule_log.h:111
auto && lhs
Definition: const_fold.cc:70
auto getIDFromPack(const std::tuple< Args... > &args)
Definition: schedule_log.h:93
PBSet params(T &&set)
Definition: presburger.h:1065
SharedLinkedList< Ref< ScheduleLogItem >, ScheduleLogItemHash, ScheduleLogItemEqual > ScheduleLog
Definition: schedule_log.h:261
constexpr std::array scheduleTypeNames
Definition: schedule_log.h:47
std::pair< ID, Metadata > IDMetadataPack
Definition: schedule_log.h:86
auto getPackFromID(auto schedule, const std::tuple< Args... > &args)
Definition: schedule_log.h:129
auto auto && rhs
Definition: const_fold.cc:70
ScheduleType
Definition: schedule_log.h:17
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
Definition: schedule_log.h:246
bool operator()(const Ref< ScheduleLogItem > &lhs, const Ref< ScheduleLogItem > &rhs) const
Definition: schedule_log.h:247
Definition: schedule_log.h:253
bool operator()(const Ref< ScheduleLogItem > &item) const
Definition: schedule_log.h:254