FreeTensor
Loading...
Searching...
No Matches
driver.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_DRIVER_H
2#define FREE_TENSOR_DRIVER_H
3
4#include <string>
5#include <unordered_map>
6#include <vector>
7
9#include <driver/array.h>
10
11#include <../runtime/cpu_context.h>
12#ifdef FT_WITH_CUDA
13#include <../runtime/gpu_context.h>
14#endif
15
16namespace freetensor {
17
// Driver: runtime wrapper around one NativeCode program. Judging from its
// interface, the program is built and loaded in-process (buildAndLoad(),
// unload()), given arguments via setArgs(), executed via run() / time(),
// and its results are fetched via collectReturns(). The non-inline members
// are implemented outside this header (driver.cc, per the Doxygen index).
//
// NOTE(review): this listing was extracted from a Doxygen HTML source view.
// Every line still carries its driver.h line number as a prefix, and these
// driver.h lines are absent: 28, 44-53, 62, 64-66, 90-94 and 99-106. The
// gaps right before the constructors, sync() and time() were most likely
// doc comments that Doxygen strips; lines 62 and 66 were code (see the
// notes at those spots). Restore this class from upstream driver.h rather
// than compiling this text.
// NOTE(review): std::unique_ptr and std::pair are used below, but neither
// <memory> nor <utility> is among the visible includes; they presumably
// arrive transitively. Consider including them directly.
18class Driver {
// Handle to the loaded program. The move-constructor comment below singles
// it out as the member needing care on move; presumably a dlopen() handle
// that unload() releases -- TODO confirm in driver.cc.
19 void *dlHandle_ = nullptr;
// Entry point of the loaded program (nullptr until set, presumably by
// buildAndLoad()). Its arguments are named by the inline comments: params,
// retRaw, retShapes, retDims, ctx.
20 void (*func_)(void ** /* params */, void ** /* retRaw */,
21 size_t ** /* retShapes */, size_t * /* retDims */,
22 void * /* ctx */) = nullptr;
23
// The program this driver was constructed from.
24 NativeCode nativeCode_;
// Presumably keeps the Arrays passed to setArgs() alive while their raw
// pointers are held in rawArgs_ -- TODO confirm.
25 std::vector<Ref<Array>> args_;
// Raw pointers exchanged with func_ (presumably its params / retRaw
// arguments). The destructor warns for every rawRets_ entry that is still
// non-null when the driver is destroyed.
26 std::vector<void *> rawArgs_,
27 rawRets_;
// Per-return shape pointers and ranks; the element types match func_'s
// retShapes (size_t **) and retDims (size_t *) parameters.
29 std::vector<size_t *> retShapes_;
30 std::vector<size_t> retDims_;
// Parameter name -> index, presumably used to place keyword arguments
// given to setArgs(kws).
31 std::unordered_map<std::string, size_t> name2param_;
// dev_ is what device() returns; hostDev_ presumably stores the
// constructor's hostDevice argument.
32 Ref<Device> dev_, hostDev_;
33
// Runtime context, presumably what is passed as func_'s `ctx` argument.
34 std::unique_ptr<Context> ctx_;
35
// Presumably the constructor's cxxFlags, used when building the program.
36 std::vector<std::string> cxxFlags_;
37
// Presumably the constructor's verbose flag.
38 bool verbose_ = false;
39
// Redundant (class members are already private by default) but marks the
// helper section.
40 private:
// Presumably builds nativeCode_ (with cxxFlags_) and loads the result,
// setting dlHandle_ and func_ -- TODO confirm; not defined in this header.
41 void buildAndLoad();
42
43 public:
// Construct a driver for `nativeCode` with an explicit host device.
// @param nativeCode Program to build and run.
// @param device     Device the program runs on (presumably stored in dev_).
// @param hostDevice Host device; see the overload below for the default.
// @param cxxFlags   Extra flags, presumably for the C++ compiler.
// @param verbose    Presumably enables extra diagnostic output.
54 Driver(const NativeCode &nativeCode, const Ref<Device> &device,
55 const Ref<Device> &hostDevice,
56 const std::vector<std::string> &cxxFlags = {}, bool verbose = false);
// Overload without a host device: delegates to the constructor above and
// uses `device` itself as the host device when it is a CPU device.
57 Driver(const NativeCode &nativeCode, const Ref<Device> &device,
58 const std::vector<std::string> &cxxFlags = {}, bool verbose = false)
59 : Driver(nativeCode, device,
60 device->type() == TargetType::CPU
61 ? device
// NOTE(review): driver.h line 62 is missing from this extraction. It held
// the `: ...,` false branch of this conditional, i.e. the host device used
// when `device` is not a CPU device. Restore it from upstream.
63 cxxFlags, verbose) {}
// NOTE(review): driver.h lines 64-66 are missing. Line 66 is the destructor
// header `~Driver() {` (the Doxygen index lists "~Driver() Definition:
// driver.h:66"); lines 67-74 below are its body. Destruction emits one
// WARNING per rawRets_ entry that is still non-null (per the message,
// uncollected return values leak), then calls unload().
67 for (void *retVal : rawRets_) {
68 if (retVal != nullptr) {
69 WARNING("Return values must be collected, or there will be "
70 "memory leaks");
71 }
72 }
73 unload();
74 }
75
// Copying and moving are disabled: the destructor calls unload(), and the
// note on the move constructor points at dlHandle_.
76 Driver(const Driver &) = delete;
77 Driver &operator=(const Driver &) = delete;
78
79 Driver(Driver &&) = delete; // If we need it, pay attention to `dlHandle_`
80 Driver &operator=(Driver &&) = delete;
81
// Bind the program's arguments: `args` by position, `kws` by parameter
// name (defined at driver.cc:345).
82 void setArgs(const std::vector<Ref<Array>> &args,
83 const std::unordered_map<std::string, Ref<Array>> &kws = {});
// Keyword-only form; forwards with an empty positional list.
84 void setArgs(const std::unordered_map<std::string, Ref<Array>> &kws) {
85 setArgs({}, kws);
86 }
87
// Run the program (driver.cc:434); presumably calls func_ with the bound
// arguments.
88 void run();
89
// NOTE(review): the doc comment (driver.h lines 90-94) was stripped.
// Presumably waits for work queued on the device to finish -- confirm in
// driver.cc:444.
95 void sync();
96
// Fetch the program's return values as Arrays (driver.cc:446). Per the
// destructor's warning, return values that are never collected leak.
97 std::vector<Ref<Array>> collectReturns();
98
// Time the program (driver.cc:497). Presumably runs `warmups` untimed
// rounds, then measures over `rounds` timed ones. NOTE(review): the doc
// comment describing the returned pair (driver.h lines 99-106) was
// stripped; check driver.cc for what each double means and its units.
107 std::pair<double, double> time(int rounds = 10, int warmups = 3);
108
// Unload the program (driver.cc:557); also called by the destructor.
109 void unload();
110
// Device the program runs on (dev_).
111 const Ref<Device> &device() const { return dev_; }
112};
113
114} // namespace freetensor
115
116#endif // FREE_TENSOR_DRIVER_H
freetensor::Driver
Definition: driver.h:18
const Ref< Device > & device() const
Definition: driver.h:111
void setArgs(const std::vector< Ref< Array > > &args, const std::unordered_map< std::string, Ref< Array > > &kws={})
Definition: driver.cc:345
Driver(Driver &&)=delete
void setArgs(const std::unordered_map< std::string, Ref< Array > > &kws)
Definition: driver.h:84
~Driver()
Definition: driver.h:66
void unload()
Definition: driver.cc:557
void sync()
Definition: driver.cc:444
std::vector< Ref< Array > > collectReturns()
Definition: driver.cc:446
Driver(const Driver &)=delete
Driver & operator=(const Driver &)=delete
Driver & operator=(Driver &&)=delete
std::pair< double, double > time(int rounds=10, int warmups=3)
Definition: driver.cc:497
Driver(const NativeCode &nativeCode, const Ref< Device > &device, const std::vector< std::string > &cxxFlags={}, bool verbose=false)
Definition: driver.h:57
void run()
Definition: driver.cc:434
Definition: native_code.h:79
Definition: ref.h:24
#define WARNING(msg)
Definition: except.h:146
Definition: allocator.h:9