1#ifndef FREE_TENSOR_SCHEDULE_H
2#define FREE_TENSOR_SCHEDULE_H
5#include <unordered_map>
44 std::vector<Transaction> openTrans_;
72 template <
class F>
auto futureSchedule(
const F &sched) {
73 return [&](
auto &&...args) {
74 auto ret = sched(
ast(), std::forward<
decltype(args)>(args)...);
75 if constexpr (std::convertible_to<
decltype(ret),
Stmt>) {
76 return quickOptimizations(ret);
78 return std::make_pair(quickOptimizations(ret.first),
92 template <
class T> T appendLog(
const T &_log) {
94 setLogs(memoized_->lookupOrCreate(
logs().push(log)));
96 log =
logs().
top().as<
typename decltype(log)::Object>();
109 template <
class T>
auto applyLog(
const T &log) {
110 auto ret = log->getResult();
111 if constexpr (std::convertible_to<
decltype(ret),
Stmt>) {
269 std::pair<ID, ID>
split(
const ID &
id,
int factor = -1,
int nparts = -1,
286 void reorder(
const std::vector<ID> &order,
323 permute(
const std::vector<ID> &loopsId,
324 const std::function<std::vector<Expr>(std::vector<Expr>)>
327 typedef std::unordered_map<ID, ID>
IDMap;
359 bool allowEnlarge =
true,
360 const std::string &suffix0 =
".0",
361 const std::string &suffix1 =
".1");
385 ID fuse(
const ID &loop0,
const ID &loop1,
bool strict =
false);
386 ID fuse(
const ID &loop0,
bool strict =
false);
398 void swap(
const std::vector<ID> &order);
463 std::tuple<ID, ID, std::string, ID>
495 std::tuple<ID, ID, std::string, ID>
555 void varReorder(
const ID &def,
const std::vector<int> &order);
674 bool allowReduction =
true);
704 void unroll(
const ID &loop,
bool immediate =
false);
806 std::pair<ID, int>
plutoFuse(
const ID &loop0,
const ID &loop1,
807 int nestLevel0 = 0,
int nestLevel1 = 0,
808 int fusableOverlapThreshold = 1,
809 int fusableNonOverlapTolerance = 4,
810 bool doSimplify =
true);
824 std::pair<ID, int>
plutoPermute(
const ID &loop,
int nestLevel = 0,
825 bool doSimplify =
true);
911 int nBatch,
int batchSize,
const Ref<Device> &device,
913 const std::unordered_map<std::string,
Ref<Array>> &kws = {},
914 const std::regex &toLearn = std::regex{
".*"});
Definition: native_code.h:79
Definition: schedule.h:32
void varReorder(const ID &def, const std::vector< int > &order)
Definition: var_reorder.cc:87
void varSqueeze(const ID &def, int dim)
Definition: var_squeeze.cc:66
std::vector< AutoScheduleTuneTrial > tuneAutoSchedule(int nBatch, int batchSize, const Ref< Device > &device, const std::vector< Ref< Array > > &args, const std::unordered_map< std::string, Ref< Array > > &kws={}, const std::regex &toLearn=std::regex{".*"})
Definition: schedule.cc:107
int verbose() const
Definition: schedule.h:184
void autoSetMemType(const Ref< Target > &target)
Definition: auto_set_mem_type.cc:217
void commitTransaction()
Definition: schedule.cc:55
void autoPluto(const Ref< Target > &target)
Definition: auto_pluto.cc:10
Schedule(const Func &func, int verbose=0)
Definition: schedule.h:123
std::pair< ID, int > plutoFuse(const ID &loop0, const ID &loop1, int nestLevel0=0, int nestLevel1=0, int fusableOverlapThreshold=1, int fusableNonOverlapTolerance=4, bool doSimplify=true)
Definition: pluto.cc:1249
void autoReorder(const Ref< Target > &target)
Definition: auto_reorder.cc:9
void autoParallelize(const Ref< Target > &target)
Definition: auto_parallelize.cc:312
void beginTransaction()
Definition: schedule.cc:53
void autoUnroll(const Ref< Target > &target)
Definition: auto_unroll.cc:6
void autoFissionFuse(const Ref< Target > &target, const Ref< RandTrace > &trace=nullptr)
Definition: auto_fission_fuse.cc:6
void unroll(const ID &loop, bool immediate=false)
Definition: unroll.cc:85
const Stmt & ast() const
Definition: schedule.cc:85
void setMemType(const ID &def, MemType mtype)
Definition: set_mem_type.cc:42
void autoSwap(const Ref< Target > &target)
Definition: auto_swap.cc:6
void varMerge(const ID &def, int dim)
Definition: var_merge.cc:78
std::pair< IDMap, IDMap > fission(const ID &loop, FissionSide side, const ID &splitter, bool allowEnlarge=true, const std::string &suffix0=".0", const std::string &suffix1=".1")
Definition: fission.cc:311
Stmt find(const T &filter) const
Definition: schedule.h:222
std::vector< Stmt > findAll(const T &filter) const
Definition: schedule.h:195
std::vector< ID > permute(const std::vector< ID > &loopsId, const std::function< std::vector< Expr >(std::vector< Expr >)> &transformFunc)
Definition: permute.cc:292
void separateTail(bool noDuplicateVarDefs=false)
Definition: separate_tail.cc:170
void autoMemLayout(const Ref< Target > &target)
Definition: auto_mem_layout.cc:8
void abortTransaction()
Definition: schedule.cc:78
void autoInline(const Ref< Target > &target)
Definition: auto_inline.cc:30
std::tuple< ID, ID, std::string, ID > cacheReduction(const ID &stmt, const std::string &var, MemType mtype)
Definition: cache.cc:341
std::vector< Stmt > findAtLeastOne(const T &filter) const
Definition: schedule.h:206
std::pair< ID, ID > moveTo(const ID &stmt, MoveToSide side, const ID &dst)
Definition: move_to.cc:5
void asMatMul(const ID &loop, AsMatMulMode mode, const Ref< Target > &target, MatMulBackend backend)
Definition: as_matmul.cc:460
void varSplit(const ID &def, int dim, VarSplitMode mode, int factor=-1, int nparts=-1)
Definition: var_split.cc:96
std::pair< ID, int > plutoPermute(const ID &loop, int nestLevel=0, bool doSimplify=true)
Definition: pluto.cc:1268
std::pair< ID, ID > split(const ID &id, int factor=-1, int nparts=-1, int shift=0)
Definition: split.cc:116
void inlining(const ID &def)
Definition: inlining.cc:189
ID merge(const ID &loop1, const ID &loop2)
Definition: merge.cc:135
void autoSchedule(const Ref< Target > &target, const Ref< RandTrace > &trace=nullptr)
Definition: schedule.cc:93
void vectorize(const ID &loop)
Definition: vectorize.cc:36
void varUnsqueeze(const ID &def, int dim)
Definition: var_unsqueeze.cc:60
Schedule(const Schedule &)=default
void reorder(const std::vector< ID > &order, ReorderMode mode=ReorderMode::PerfectOnly)
Definition: reorder.cc:278
std::unordered_map< ID, ID > IDMap
Definition: schedule.h:327
Func func() const
Definition: schedule.h:166
void parallelizeAs(const ID &nest, const ID &reference, const ID &defId)
Definition: parallelize_as.cc:296
Schedule fork() const
Definition: schedule.h:142
ID fuse(const ID &loop0, const ID &loop1, bool strict=false)
Definition: fuse.cc:304
Schedule & operator=(const Schedule &)=default
void blend(const ID &loop)
Definition: blend.cc:194
std::tuple< ID, ID, std::string, ID > cache(const ID &stmt, const std::string &var, MemType mtype)
Definition: cache.cc:326
void swap(const std::vector< ID > &order)
Definition: swap.cc:109
const ScheduleLog & logs() const
Definition: schedule.cc:88
void autoUseLib(const Ref< Target > &target)
Definition: auto_use_lib.cc:7
void parallelize(const ID &loop, const ParallelScope ¶llel, bool allowReduction=true)
Definition: parallelize.cc:164
const T & top() const
Definition: shared_linked_list.h:31
#define ASSERT(expr)
Definition: except.h:152
Definition: allocator.h:9
AsMatMulMode
Definition: as_matmul.h:17
Stmt findStmt(const Stmt &ast, const ID &id)
Definition: find_stmt.cc:22
FissionSide
Definition: fission.h:15
std::vector< T > filter(const std::vector< T > &vec, const U &callback)
Definition: container_utils.h:131
MatMulBackend
Definition: stmt.h:465
std::variant< SerialScope, OpenMPScope, CUDAStreamScope, CUDAScope > ParallelScope
Definition: parallel_scope.h:73
VarSplitMode
Definition: var_split.h:8
ReorderMode
Definition: reorder.h:11
MoveToSide
Definition: schedule.h:23
std::vector< Stmt > findAllStmt(const Stmt &ast, const ID &id)
Definition: find_stmt.cc:32
Func makeFunc(const std::string &name, Tparams &¶ms, Treturns &&returns, Tbody &&body)
Definition: func.h:66
MemType
Definition: mem_type.h:14
PBFunc::Serialized func_
Definition: prop_one_time_use.cc:22
Definition: schedule.h:25
double stddev_
Definition: schedule.h:29
Ref< RandTrace > trace_
Definition: schedule.h:26
Func lowered_
Definition: schedule.h:27
NativeCode code_
Definition: schedule.h:28
double time_
Definition: schedule.h:29