3
3
#include < memory>
4
4
5
5
#include < CL/sycl.hpp>
6
+ #include < type_traits>
7
+ #include < variant>
6
8
7
9
#include " config.h"
8
10
#include " workaround.h"
9
11
10
12
namespace celerity {
11
13
namespace detail {
12
14
15
+ struct auto_select_device {};
16
+ using device_selector = std::function<int (const sycl::device&)>;
17
+ using device_or_selector = std::variant<auto_select_device, sycl::device, device_selector>;
18
+
13
19
class task ;
14
20
15
21
/* *
@@ -21,9 +27,9 @@ namespace detail {
21
27
* @brief Initializes the @p device_queue, selecting an appropriate device in the process.
22
28
*
23
29
* @param cfg The configuration is used to select the appropriate SYCL device.
24
- * @param user_device Optionally a device can be provided, which will take precedence over any configuration.
30
+ * @param user_device_or_selector Optionally a device ( which will take precedence over any configuration) or a device selector can be provided .
25
31
*/
26
- void init (const config& cfg, cl::sycl::device* user_device );
32
+ void init (const config& cfg, const device_or_selector& user_device_or_selector );
27
33
28
34
/* *
29
35
* @brief Executes the kernel associated with task @p ctsk over the chunk @p chnk.
@@ -62,12 +68,111 @@ namespace detail {
62
68
void handle_async_exceptions (cl::sycl::exception_list el) const ;
63
69
};
64
70
71
+ // Try to find a platform that can provide a unique device for each node using a device selector.
72
+ template <typename DeviceT, typename PlatformT, typename SelectorT>
73
+ bool try_find_device_per_node (
74
+ std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
75
+ std::vector<std::tuple<DeviceT, size_t >> devices_with_platform_idx;
76
+ for (size_t i = 0 ; i < platforms.size (); ++i) {
77
+ auto && platform = platforms[i];
78
+ for (auto device : platform.get_devices ()) {
79
+ if (selector (device) == -1 ) { continue ; }
80
+ devices_with_platform_idx.emplace_back (device, i);
81
+ }
82
+ }
83
+
84
+ std::stable_sort (devices_with_platform_idx.begin (), devices_with_platform_idx.end (),
85
+ [selector](const auto & a, const auto & b) { return selector (std::get<0 >(a)) > selector (std::get<0 >(b)); });
86
+ bool same_platform = true ;
87
+ bool same_device_type = true ;
88
+ if (devices_with_platform_idx.size () >= host_cfg.node_count ) {
89
+ auto [device_from_platform, idx] = devices_with_platform_idx[0 ];
90
+ const auto platform = device_from_platform.get_platform ();
91
+ const auto device_type = device_from_platform.template get_info <sycl::info::device::device_type>();
92
+
93
+ for (size_t i = 1 ; i < host_cfg.node_count ; ++i) {
94
+ auto [device_from_platform, idx] = devices_with_platform_idx[i];
95
+ if (device_from_platform.get_platform () != platform) { same_platform = false ; }
96
+ if (device_from_platform.template get_info <sycl::info::device::device_type>() != device_type) { same_device_type = false ; }
97
+ }
98
+
99
+ if (!same_platform || !same_device_type) { CELERITY_WARN (" Selected devices are of different type and/or do not belong to the same platform" ); }
100
+
101
+ auto [selected_device_from_platform, selected_idx] = devices_with_platform_idx[host_cfg.local_rank ];
102
+ how_selected = fmt::format (" device selector specified: platform {}, device {}" , selected_idx, host_cfg.local_rank );
103
+ device = selected_device_from_platform;
104
+ return true ;
105
+ }
106
+
107
+ return false ;
108
+ }
109
+
110
+ // Try to find a platform that can provide a unique device for each node.
65
111
template <typename DeviceT, typename PlatformT>
66
- DeviceT pick_device (const config& cfg, DeviceT* user_device, const std::vector<PlatformT>& platforms) {
112
+ bool try_find_device_per_node (
113
+ std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
114
+ for (size_t i = 0 ; i < platforms.size (); ++i) {
115
+ auto && platform = platforms[i];
116
+ std::vector<DeviceT> platform_devices;
117
+
118
+ platform_devices = platform.get_devices (type);
119
+ if (platform_devices.size () >= host_cfg.node_count ) {
120
+ how_selected = fmt::format (" automatically selected platform {}, device {}" , i, host_cfg.local_rank );
121
+ device = platform_devices[host_cfg.local_rank ];
122
+ return true ;
123
+ }
124
+ }
125
+
126
+ return false ;
127
+ }
128
+
129
+ template <typename DeviceT, typename PlatformT, typename SelectorT>
130
+ bool try_find_one_device (
131
+ std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
132
+ std::vector<DeviceT> platform_devices;
133
+ for (auto & p : platforms) {
134
+ auto p_devices = p.get_devices ();
135
+ platform_devices.insert (platform_devices.end (), p_devices.begin (), p_devices.end ());
136
+ }
137
+
138
+ std::stable_sort (platform_devices.begin (), platform_devices.end (), [selector](const auto & a, const auto & b) { return selector (a) > selector (b); });
139
+ if (!platform_devices.empty ()) {
140
+ if (selector (platform_devices[0 ]) == -1 ) { return false ; }
141
+ device = platform_devices[0 ];
142
+ return true ;
143
+ }
144
+
145
+ return false ;
146
+ };
147
+
148
+ template <typename DeviceT, typename PlatformT>
149
+ bool try_find_one_device (
150
+ std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
151
+ for (auto & p : platforms) {
152
+ for (auto & d : p.get_devices (type)) {
153
+ device = d;
154
+ return true ;
155
+ }
156
+ }
157
+
158
+ return false ;
159
+ };
160
+
161
+
162
+ template <typename DevicePtrOrSelector, typename PlatformT>
163
+ auto pick_device (const config& cfg, const DevicePtrOrSelector& user_device_or_selector, const std::vector<PlatformT>& platforms) {
164
+ using DeviceT = typename decltype (std::declval<PlatformT&>().get_devices ())::value_type;
165
+
166
+ constexpr bool user_device_provided = std::is_same_v<DevicePtrOrSelector, DeviceT>;
167
+ constexpr bool device_selector_provided = std::is_invocable_r_v<int , DevicePtrOrSelector, DeviceT>;
168
+ constexpr bool auto_select = std::is_same_v<auto_select_device, DevicePtrOrSelector>;
169
+ static_assert (
170
+ user_device_provided ^ device_selector_provided ^ auto_select, " pick_device requires either a device, a selector, or the auto_select_device tag" );
171
+
67
172
DeviceT device;
68
173
std::string how_selected = " automatically selected" ;
69
- if (user_device != nullptr ) {
70
- device = *user_device ;
174
+ if constexpr (user_device_provided ) {
175
+ device = user_device_or_selector ;
71
176
how_selected = " specified by user" ;
72
177
} else {
73
178
const auto device_cfg = cfg.get_device_config ();
@@ -86,48 +191,37 @@ namespace detail {
86
191
} else {
87
192
const auto host_cfg = cfg.get_host_config ();
88
193
89
- const auto try_find_device_per_node = [&host_cfg, &device, &how_selected, &platforms](cl::sycl::info::device_type type) {
90
- // Try to find a platform that can provide a unique device for each node.
91
- for (size_t i = 0 ; i < platforms.size (); ++i) {
92
- auto && platform = platforms[i];
93
- const auto devices = platform.get_devices (type);
94
- if (devices.size () >= host_cfg.node_count ) {
95
- how_selected = fmt::format (" automatically selected platform {}, device {}" , i, host_cfg.local_rank );
96
- device = devices[host_cfg.local_rank ];
97
- return true ;
98
- }
99
- }
100
- return false ;
101
- };
102
-
103
- const auto try_find_one_device = [&device, &platforms](cl::sycl::info::device_type type) {
104
- for (auto & p : platforms) {
105
- for (auto & d : p.get_devices (type)) {
106
- device = d;
107
- return true ;
194
+ if constexpr (!device_selector_provided) {
195
+ // Try to find a unique GPU per node.
196
+ if (!try_find_device_per_node (how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)) {
197
+ if (try_find_device_per_node (how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
198
+ CELERITY_WARN (" No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set" , host_cfg.node_count );
199
+ } else {
200
+ CELERITY_WARN (" No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set" , host_cfg.node_count );
201
+ // Just use the first available device. Prefer GPUs, but settle for anything.
202
+ if (!try_find_one_device (how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)
203
+ && !try_find_one_device (how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
204
+ throw std::runtime_error (" Automatic device selection failed: No device available" );
205
+ }
108
206
}
109
207
}
110
- return false ;
111
- };
112
-
113
- // Try to find a unique GPU per node.
114
- if (!try_find_device_per_node (cl::sycl::info::device_type::gpu)) {
115
- // Try to find a unique device (of any type) per node.
116
- if (try_find_device_per_node (cl::sycl::info::device_type::all)) {
117
- CELERITY_WARN (" No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set" , host_cfg.node_count );
118
- } else {
119
- CELERITY_WARN (" No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set" , host_cfg.node_count );
120
- // Just use the first available device. Prefer GPUs, but settle for anything.
121
- if (!try_find_one_device (cl::sycl::info::device_type::gpu) && !try_find_one_device (cl::sycl::info::device_type::all)) {
122
- throw std::runtime_error (" Automatic device selection failed: No device available" );
208
+ } else {
209
+ // Try to find a unique device per node using a selector.
210
+ if (!try_find_device_per_node (how_selected, device, platforms, host_cfg, user_device_or_selector)) {
211
+ CELERITY_WARN (" No suitable platform found that can provide {} devices that match the specified device selector, and "
212
+ " CELERITY_DEVICES not set" ,
213
+ host_cfg.node_count );
214
+ // Use the first available device according to the selector, but fails if no such device is found.
215
+ if (!try_find_one_device (how_selected, device, platforms, host_cfg, user_device_or_selector)) {
216
+ throw std::runtime_error (" Device selection with device selector failed: No device available" );
123
217
}
124
218
}
125
219
}
126
220
}
127
221
}
128
222
129
- const auto platform_name = device.get_platform ().template get_info <cl:: sycl::info::platform::name>();
130
- const auto device_name = device.template get_info <cl:: sycl::info::device::name>();
223
+ const auto platform_name = device.get_platform ().template get_info <sycl::info::platform::name>();
224
+ const auto device_name = device.template get_info <sycl::info::device::name>();
131
225
CELERITY_INFO (" Using platform '{}', device '{}' ({})" , platform_name, device_name, how_selected);
132
226
133
227
return device;
0 commit comments