FreeTensor
Loading...
Searching...
No Matches
parallel_scope.h
Go to the documentation of this file.
1#ifndef FREE_TENSOR_PARALLEL_SCOPE_H
2#define FREE_TENSOR_PARALLEL_SCOPE_H
3
4#include <iostream>
5#include <string>
6#include <variant>
7
8#include <container_utils.h>
9#include <except.h>
10#include <hash_combine.h>
11#include <serialize/to_string.h>
12
13namespace freetensor {
14
/// Stateless tag type for the serial scope.
struct SerialScope {};

/// SerialScope has no members, so any two instances compare equal.
inline bool operator==(const SerialScope &, const SerialScope &) {
    return true;
}
19
/// Stateless tag type for the OpenMP scope.
struct OpenMPScope {};

/// OpenMPScope has no members, so any two instances compare equal.
inline bool operator==(const OpenMPScope &, const OpenMPScope &) {
    return true;
}

/// Writes the textual name of the scope ("openmp") to `os` and returns `os`.
inline std::ostream &operator<<(std::ostream &os, const OpenMPScope &) {
    os << "openmp";
    return os;
}
27
/// Stateless tag type for the CUDA stream scope.
struct CUDAStreamScope {};

/// CUDAStreamScope has no members, so any two instances compare equal.
inline bool operator==(const CUDAStreamScope &lhs, const CUDAStreamScope &rhs) {
    return true;
}

/// Writes the textual name of the scope ("cudastream") to `os` and returns
/// `os`.
inline std::ostream &operator<<(std::ostream &os,
                                const CUDAStreamScope &parallel) {
    return os << "cudastream";
}
36
/// A CUDA parallel scope, identified by a level (block or thread) and a
/// dimension (x, y or z).
///
/// The members are declared in the order `level_`, `dim_` so that aggregate
/// initialization reads `CUDAScope{level, dim}`.
struct CUDAScope {
    enum Level { Block, Thread } level_;
    enum Dim { X, Y, Z } dim_;
};

/// Two CUDA scopes are equal when both their level and their dimension match.
inline bool operator==(const CUDAScope &lhs, const CUDAScope &rhs) {
    return lhs.level_ == rhs.level_ && lhs.dim_ == rhs.dim_;
}
44inline std::ostream &operator<<(std::ostream &os, const CUDAScope &parallel) {
45 switch (parallel.level_) {
47 os << "blockIdx";
48 break;
50 os << "threadIdx";
51 break;
52 default:
53 ASSERT(false);
54 }
55 switch (parallel.dim_) {
57 os << ".x";
58 break;
60 os << ".y";
61 break;
63 os << ".z";
64 break;
65 default:
66 ASSERT(false);
67 }
68 return os;
69}
70
71// The first type is default
72typedef std::variant<SerialScope, OpenMPScope, CUDAStreamScope, CUDAScope>
74
75inline std::ostream &operator<<(std::ostream &os,
76 const ParallelScope &parallel) {
77 if (std::holds_alternative<SerialScope>(parallel)) {
78 return os;
79 } else if (std::holds_alternative<OpenMPScope>(parallel)) {
80 return os << std::get<OpenMPScope>(parallel);
81 } else if (std::holds_alternative<CUDAScope>(parallel)) {
82 return os << std::get<CUDAScope>(parallel);
83 } else if (std::holds_alternative<CUDAStreamScope>(parallel)) {
84 return os << std::get<CUDAStreamScope>(parallel);
85 } else {
86 ASSERT(false);
87 }
88}
89
90inline ParallelScope parseParallelScope(const std::string &_str) {
91 auto &&str = tolower(_str);
92 if (auto scope = SerialScope{}; str == tolower(toString(scope))) {
93 return scope;
94 }
95 if (auto scope = OpenMPScope{}; str == tolower(toString(scope))) {
96 return scope;
97 }
98 if (auto scope = CUDAStreamScope{}; str == tolower(toString(scope))) {
99 return scope;
100 }
101 for (auto &&level : {CUDAScope::Block, CUDAScope::Thread}) {
102 for (auto &&dim : {CUDAScope::X, CUDAScope::Y, CUDAScope::Z}) {
103 if (auto scope = CUDAScope{level, dim};
104 str == tolower(toString(scope))) {
105 return scope;
106 }
107 }
108 }
109 ERROR("Unrecognized parallel scope " + _str);
110}
111
113
120
121} // namespace freetensor
122
123namespace std {
124
125template <> struct hash<freetensor::SerialScope> {
126 size_t operator()(const freetensor::SerialScope &) { return 0; }
127};
128
129template <> struct hash<freetensor::OpenMPScope> {
130 size_t operator()(const freetensor::OpenMPScope &) { return 0; }
131};
132
133template <> struct hash<freetensor::CUDAStreamScope> {
134 size_t operator()(const freetensor::CUDAStreamScope &) { return 0; }
135};
136
137template <> struct hash<freetensor::CUDAScope> {
138 size_t operator()(const freetensor::CUDAScope &parallel) {
139 return freetensor::hashCombine(std::hash<int>()((int)parallel.level_),
140 std::hash<int>()((int)parallel.dim_));
141 }
142};
143
144} // namespace std
145
146#endif // FREE_TENSOR_PARALLEL_SCOPE_H
#define ASSERT(expr)
Definition: except.h:152
#define ERROR(msg)
Definition: except.h:141
Definition: allocator.h:9
auto && lhs
Definition: const_fold.cc:70
constexpr ParallelScope threadIdxY
Definition: parallel_scope.h:115
constexpr ParallelScope blockIdxZ
Definition: parallel_scope.h:119
std::string tolower(const std::string &s)
Definition: container_utils.h:142
std::string toString(const AST &op)
Definition: print_ast.cc:784
bool operator==(const Allocator< T > &lhs, const Allocator< T > &rhs)
Definition: allocator.h:100
constexpr ParallelScope threadIdxZ
Definition: parallel_scope.h:116
std::variant< SerialScope, OpenMPScope, CUDAStreamScope, CUDAScope > ParallelScope
Definition: parallel_scope.h:73
ParallelScope parseParallelScope(const std::string &_str)
Definition: parallel_scope.h:90
constexpr ParallelScope blockIdxX
Definition: parallel_scope.h:117
constexpr ParallelScope serialScope
Definition: parallel_scope.h:112
auto auto && rhs
Definition: const_fold.cc:70
size_t hashCombine(size_t seed, size_t other)
Definition: hash_combine.cc:5
std::ostream & operator<<(std::ostream &os, const Dependence &dep)
Definition: deps.cc:1404
constexpr ParallelScope threadIdxX
Definition: parallel_scope.h:114
constexpr ParallelScope blockIdxY
Definition: parallel_scope.h:118
STL namespace.
Definition: parallel_scope.h:37
Level
Definition: parallel_scope.h:38
@ Thread
Definition: parallel_scope.h:38
@ Block
Definition: parallel_scope.h:38
enum freetensor::CUDAScope::Level level_
enum freetensor::CUDAScope::Dim dim_
Dim
Definition: parallel_scope.h:39
@ X
Definition: parallel_scope.h:39
@ Y
Definition: parallel_scope.h:39
@ Z
Definition: parallel_scope.h:39
Definition: parallel_scope.h:28
Definition: parallel_scope.h:20
Definition: parallel_scope.h:15
size_t operator()(const freetensor::CUDAScope &parallel)
Definition: parallel_scope.h:138
size_t operator()(const freetensor::CUDAStreamScope &)
Definition: parallel_scope.h:134
size_t operator()(const freetensor::OpenMPScope &)
Definition: parallel_scope.h:130
size_t operator()(const freetensor::SerialScope &)
Definition: parallel_scope.h:126