FreeTensor
Loading...
Searching...
No Matches
schedule.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_SCHEDULE_H
2#define FREE_TENSOR_SCHEDULE_H
3
4#include <functional>
5#include <unordered_map>
6
7#include <analyze/find_stmt.h>
9#include <driver/target.h>
10#include <func.h>
12#include <random.h>
13#include <schedule/as_matmul.h>
14#include <schedule/fission.h>
16#include <schedule/reorder.h>
18#include <schedule/var_split.h>
19#include <stmt.h>
20
21namespace freetensor {
22
/**
 * Side argument of `Schedule::moveTo`: whether the statement is moved to the
 * position before or after the destination statement `dst`
 */
enum class MoveToSide : int {
    Before, // underlying value 0
    After,  // underlying value 1
};
24
29 double time_, stddev_;
30};
31
32class Schedule {
33 struct Transaction {
34 Stmt ast_;
35 ScheduleLog logs_;
36
37 Transaction(const Stmt &ast, const ScheduleLog &logs)
38 : ast_(ast), logs_(logs) {}
39 };
40
41 Func func_;
43
44 std::vector<Transaction> openTrans_;
45
46 int verbose_ = 0;
47
48 Ref<MemoizedSchedules> memoized_;
49
52
53 private:
54 void setAst(const Stmt &ast);
55 void setLogs(const ScheduleLog &log);
56
66 static Stmt quickOptimizations(const Stmt &ast);
67
72 template <class F> auto futureSchedule(const F &sched) {
73 return [&](auto &&...args) {
74 auto ret = sched(ast(), std::forward<decltype(args)>(args)...);
75 if constexpr (std::convertible_to<decltype(ret), Stmt>) {
76 return quickOptimizations(ret);
77 } else { // pair(Stmt, other info)
78 return std::make_pair(quickOptimizations(ret.first),
79 ret.second);
80 }
81 };
82 }
83
92 template <class T> T appendLog(const T &_log) {
93 auto log = _log;
94 setLogs(memoized_->lookupOrCreate(logs().push(log)));
95 ASSERT(logs().top()->type() == log->type());
96 log = logs().top().as<typename decltype(log)::Object>();
97 log->run();
98 return log;
99 }
100
109 template <class T> auto applyLog(const T &log) {
110 auto ret = log->getResult();
111 if constexpr (std::convertible_to<decltype(ret), Stmt>) {
112 setAst(ret);
113 return;
114 } else { // pair(Stmt, other info)
115 setAst(ret.first);
116 return ret.second;
117 }
118 }
119
120 public:
121 Schedule() = default;
122 Schedule(const Stmt &ast, int verbose = 0);
123 Schedule(const Func &func, int verbose = 0)
124 : Schedule(func->body_, verbose) {
125 func_ = func;
126 }
127
128 // Copy by default, which means `Ref`s in a `Schedule` object is shared
129 Schedule(const Schedule &) = default;
130 Schedule &operator=(const Schedule &) = default;
131
142 Schedule fork() const { return *this; }
143
158 void beginTransaction();
159 void commitTransaction();
160 void abortTransaction();
166 Func func() const {
167 ASSERT(func_.isValid());
168 return makeFunc(func_->name_, func_->params_, func_->returns_, ast());
169 }
170
174 const Stmt &ast() const;
175
179 const ScheduleLog &logs() const;
180
184 int verbose() const { return verbose_; }
185
195 template <class T> std::vector<Stmt> findAll(const T &filter) const {
196 return findAllStmt(ast(), filter);
197 }
198
206 template <class T> std::vector<Stmt> findAtLeastOne(const T &filter) const {
207 auto ret = findAllStmt(ast(), filter);
208 if (ret.empty()) {
209 throw InvalidSchedule(ast(), "No statement found by filter");
210 }
211 return ret;
212 }
213
222 template <class T> Stmt find(const T &filter) const {
223 try {
224 return findStmt(ast(), filter);
225 } catch (const UnexpectedQueryResult &e) {
226 throw InvalidSchedule(ast(), e.what());
227 }
228 }
229
269 std::pair<ID, ID> split(const ID &id, int factor = -1, int nparts = -1,
270 int shift = 0);
271
286 void reorder(const std::vector<ID> &order,
288
304 ID merge(const ID &loop1, const ID &loop2);
305
322 std::vector<ID>
323 permute(const std::vector<ID> &loopsId,
324 const std::function<std::vector<Expr>(std::vector<Expr>)>
325 &transformFunc);
326
327 typedef std::unordered_map<ID, ID> IDMap;
357 std::pair<IDMap, IDMap> fission(const ID &loop, FissionSide side,
358 const ID &splitter,
359 bool allowEnlarge = true,
360 const std::string &suffix0 = ".0",
361 const std::string &suffix1 = ".1");
362
385 ID fuse(const ID &loop0, const ID &loop1, bool strict = false);
386 ID fuse(const ID &loop0, bool strict = false);
398 void swap(const std::vector<ID> &order);
399
427 void blend(const ID &loop);
428
463 std::tuple<ID, ID, std::string, ID>
464 cache(const ID &stmt, const std::string &var, MemType mtype);
465
495 std::tuple<ID, ID, std::string, ID>
496 cacheReduction(const ID &stmt, const std::string &var, MemType mtype);
497
517 void setMemType(const ID &def, MemType mtype);
518 void setMemType(const ID &def, MemType mtype, bool rejectIndirectAccess);
537 void varSplit(const ID &def, int dim, VarSplitMode mode, int factor = -1,
538 int nparts = -1);
539
546 void varMerge(const ID &def, int dim);
547
555 void varReorder(const ID &def, const std::vector<int> &order);
556
569 void varUnsqueeze(const ID &def, int dim);
570
583 void varSqueeze(const ID &def, int dim);
584
601 std::pair<ID, ID> moveTo(const ID &stmt, MoveToSide side, const ID &dst);
602
610 void inlining(const ID &def);
611
673 void parallelize(const ID &loop, const ParallelScope &parallel,
674 bool allowReduction = true);
675
689 void parallelizeAs(const ID &nest, const ID &reference, const ID &defId);
690
704 void unroll(const ID &loop, bool immediate = false);
705
717 void vectorize(const ID &loop);
718
759 void separateTail(bool noDuplicateVarDefs = false);
760
778 void asMatMul(const ID &loop, AsMatMulMode mode, const Ref<Target> &target,
779 MatMulBackend backend);
780 void asMatMul(const ID &loop, AsMatMulMode mode, const Ref<Target> &target);
781 void asMatMul(const ID &loop,
783
806 std::pair<ID, int> plutoFuse(const ID &loop0, const ID &loop1,
807 int nestLevel0 = 0, int nestLevel1 = 0,
808 int fusableOverlapThreshold = 1,
809 int fusableNonOverlapTolerance = 4,
810 bool doSimplify = true);
811
824 std::pair<ID, int> plutoPermute(const ID &loop, int nestLevel = 0,
825 bool doSimplify = true);
826
833 void autoSchedule(const Ref<Target> &target,
834 const Ref<RandTrace> &trace = nullptr);
835
841 void autoInline(const Ref<Target> &target);
842
848 void autoUseLib(const Ref<Target> &target);
849
855 void autoReorder(const Ref<Target> &target);
856
863 void autoSwap(const Ref<Target> &target);
864
870 void autoPluto(const Ref<Target> &target);
871
879 void autoFissionFuse(const Ref<Target> &target,
880 const Ref<RandTrace> &trace = nullptr);
881
887 void autoMemLayout(const Ref<Target> &target);
888
894 void autoParallelize(const Ref<Target> &target);
895
901 void autoSetMemType(const Ref<Target> &target);
902
908 void autoUnroll(const Ref<Target> &target);
909
910 std::vector<AutoScheduleTuneTrial> tuneAutoSchedule(
911 int nBatch, int batchSize, const Ref<Device> &device,
912 const std::vector<Ref<Array>> &args,
913 const std::unordered_map<std::string, Ref<Array>> &kws = {},
914 const std::regex &toLearn = std::regex{".*"});
915};
916
917} // namespace freetensor
918
919#endif // FREE_TENSOR_SCHEDULE_H
Definition: id.h:18
Definition: except.h:40
Definition: native_code.h:79
Definition: ref.h:24
Definition: schedule.h:32
void varReorder(const ID &def, const std::vector< int > &order)
Definition: var_reorder.cc:87
void varSqueeze(const ID &def, int dim)
Definition: var_squeeze.cc:66
std::vector< AutoScheduleTuneTrial > tuneAutoSchedule(int nBatch, int batchSize, const Ref< Device > &device, const std::vector< Ref< Array > > &args, const std::unordered_map< std::string, Ref< Array > > &kws={}, const std::regex &toLearn=std::regex{".*"})
Definition: schedule.cc:107
int verbose() const
Definition: schedule.h:184
void autoSetMemType(const Ref< Target > &target)
Definition: auto_set_mem_type.cc:217
void commitTransaction()
Definition: schedule.cc:55
void autoPluto(const Ref< Target > &target)
Definition: auto_pluto.cc:10
Schedule(const Func &func, int verbose=0)
Definition: schedule.h:123
std::pair< ID, int > plutoFuse(const ID &loop0, const ID &loop1, int nestLevel0=0, int nestLevel1=0, int fusableOverlapThreshold=1, int fusableNonOverlapTolerance=4, bool doSimplify=true)
Definition: pluto.cc:1249
void autoReorder(const Ref< Target > &target)
Definition: auto_reorder.cc:9
void autoParallelize(const Ref< Target > &target)
Definition: auto_parallelize.cc:312
void beginTransaction()
Definition: schedule.cc:53
void autoUnroll(const Ref< Target > &target)
Definition: auto_unroll.cc:6
void autoFissionFuse(const Ref< Target > &target, const Ref< RandTrace > &trace=nullptr)
Definition: auto_fission_fuse.cc:6
void unroll(const ID &loop, bool immediate=false)
Definition: unroll.cc:85
const Stmt & ast() const
Definition: schedule.cc:85
void setMemType(const ID &def, MemType mtype)
Definition: set_mem_type.cc:42
void autoSwap(const Ref< Target > &target)
Definition: auto_swap.cc:6
void varMerge(const ID &def, int dim)
Definition: var_merge.cc:78
std::pair< IDMap, IDMap > fission(const ID &loop, FissionSide side, const ID &splitter, bool allowEnlarge=true, const std::string &suffix0=".0", const std::string &suffix1=".1")
Definition: fission.cc:311
Stmt find(const T &filter) const
Definition: schedule.h:222
std::vector< Stmt > findAll(const T &filter) const
Definition: schedule.h:195
std::vector< ID > permute(const std::vector< ID > &loopsId, const std::function< std::vector< Expr >(std::vector< Expr >)> &transformFunc)
Definition: permute.cc:292
void separateTail(bool noDuplicateVarDefs=false)
Definition: separate_tail.cc:170
void autoMemLayout(const Ref< Target > &target)
Definition: auto_mem_layout.cc:8
void abortTransaction()
Definition: schedule.cc:78
void autoInline(const Ref< Target > &target)
Definition: auto_inline.cc:30
std::tuple< ID, ID, std::string, ID > cacheReduction(const ID &stmt, const std::string &var, MemType mtype)
Definition: cache.cc:341
std::vector< Stmt > findAtLeastOne(const T &filter) const
Definition: schedule.h:206
std::pair< ID, ID > moveTo(const ID &stmt, MoveToSide side, const ID &dst)
Definition: move_to.cc:5
void asMatMul(const ID &loop, AsMatMulMode mode, const Ref< Target > &target, MatMulBackend backend)
Definition: as_matmul.cc:460
void varSplit(const ID &def, int dim, VarSplitMode mode, int factor=-1, int nparts=-1)
Definition: var_split.cc:96
std::pair< ID, int > plutoPermute(const ID &loop, int nestLevel=0, bool doSimplify=true)
Definition: pluto.cc:1268
std::pair< ID, ID > split(const ID &id, int factor=-1, int nparts=-1, int shift=0)
Definition: split.cc:116
void inlining(const ID &def)
Definition: inlining.cc:189
ID merge(const ID &loop1, const ID &loop2)
Definition: merge.cc:135
void autoSchedule(const Ref< Target > &target, const Ref< RandTrace > &trace=nullptr)
Definition: schedule.cc:93
void vectorize(const ID &loop)
Definition: vectorize.cc:36
void varUnsqueeze(const ID &def, int dim)
Definition: var_unsqueeze.cc:60
Schedule(const Schedule &)=default
void reorder(const std::vector< ID > &order, ReorderMode mode=ReorderMode::PerfectOnly)
Definition: reorder.cc:278
std::unordered_map< ID, ID > IDMap
Definition: schedule.h:327
Func func() const
Definition: schedule.h:166
void parallelizeAs(const ID &nest, const ID &reference, const ID &defId)
Definition: parallelize_as.cc:296
Schedule fork() const
Definition: schedule.h:142
ID fuse(const ID &loop0, const ID &loop1, bool strict=false)
Definition: fuse.cc:304
Schedule & operator=(const Schedule &)=default
void blend(const ID &loop)
Definition: blend.cc:194
std::tuple< ID, ID, std::string, ID > cache(const ID &stmt, const std::string &var, MemType mtype)
Definition: cache.cc:326
void swap(const std::vector< ID > &order)
Definition: swap.cc:109
const ScheduleLog & logs() const
Definition: schedule.cc:88
void autoUseLib(const Ref< Target > &target)
Definition: auto_use_lib.cc:7
void parallelize(const ID &loop, const ParallelScope &parallel, bool allowReduction=true)
Definition: parallelize.cc:164
const T & top() const
Definition: shared_linked_list.h:31
Definition: except.h:112
#define ASSERT(expr)
Definition: except.h:152
Definition: allocator.h:9
AsMatMulMode
Definition: as_matmul.h:17
Stmt findStmt(const Stmt &ast, const ID &id)
Definition: find_stmt.cc:22
FissionSide
Definition: fission.h:15
std::vector< T > filter(const std::vector< T > &vec, const U &callback)
Definition: container_utils.h:131
MatMulBackend
Definition: stmt.h:465
std::variant< SerialScope, OpenMPScope, CUDAStreamScope, CUDAScope > ParallelScope
Definition: parallel_scope.h:73
VarSplitMode
Definition: var_split.h:8
ReorderMode
Definition: reorder.h:11
MoveToSide
Definition: schedule.h:23
std::vector< Stmt > findAllStmt(const Stmt &ast, const ID &id)
Definition: find_stmt.cc:32
Func makeFunc(const std::string &name, Tparams &&params, Treturns &&returns, Tbody &&body)
Definition: func.h:66
MemType
Definition: mem_type.h:14
PBFunc::Serialized func_
Definition: prop_one_time_use.cc:22
Definition: schedule.h:25
double stddev_
Definition: schedule.h:29
Ref< RandTrace > trace_
Definition: schedule.h:26
Func lowered_
Definition: schedule.h:27
NativeCode code_
Definition: schedule.h:28
double time_
Definition: schedule.h:29