1#ifndef FREE_TENSOR_DRIVER_H
2#define FREE_TENSOR_DRIVER_H
5#include <unordered_map>
11#include <../runtime/cpu_context.h>
13#include <../runtime/gpu_context.h>
19 void *dlHandle_ =
nullptr;
20 void (*func_)(
void ** ,
void ** ,
21 size_t ** ,
size_t * ,
25 std::vector<Ref<Array>> args_;
26 std::vector<void *> rawArgs_,
29 std::vector<size_t *> retShapes_;
30 std::vector<size_t> retDims_;
31 std::unordered_map<std::string, size_t> name2param_;
34 std::unique_ptr<Context> ctx_;
36 std::vector<std::string> cxxFlags_;
38 bool verbose_ =
false;
56 const std::vector<std::string> &cxxFlags = {},
bool verbose =
false);
58 const std::vector<std::string> &cxxFlags = {},
bool verbose =
false)
67 for (
void *retVal : rawRets_) {
68 if (retVal !=
nullptr) {
69 WARNING(
"Return values must be collected, or there will be "
83 const std::unordered_map<std::string,
Ref<Array>> &kws = {});
107 std::pair<double, double>
time(
int rounds = 10,
int warmups = 3);
const Ref< Device > & device() const
Definition: driver.h:111
void setArgs(const std::vector< Ref< Array > > &args, const std::unordered_map< std::string, Ref< Array > > &kws={})
Definition: driver.cc:345
void setArgs(const std::unordered_map< std::string, Ref< Array > > &kws)
Definition: driver.h:84
~Driver()
Definition: driver.h:66
void unload()
Definition: driver.cc:557
void sync()
Definition: driver.cc:444
std::vector< Ref< Array > > collectReturns()
Definition: driver.cc:446
Driver(const Driver &)=delete
Driver & operator=(const Driver &)=delete
Driver & operator=(Driver &&)=delete
std::pair< double, double > time(int rounds=10, int warmups=3)
Definition: driver.cc:497
Driver(const NativeCode &nativeCode, const Ref< Device > &device, const std::vector< std::string > &cxxFlags={}, bool verbose=false)
Definition: driver.h:57
void run()
Definition: driver.cc:434
Definition: native_code.h:79
#define WARNING(msg)
Definition: except.h:146
Definition: allocator.h:9