Skip to content

Commit 556b6f2

Browse files
almightyvatspsalz
authored andcommitted
Add support for passing device selector to distr_queue ctor
... and runtime::init
1 parent 2a56b50 commit 556b6f2

File tree

11 files changed

+730
-243
lines changed

11 files changed

+730
-243
lines changed

include/celerity.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef RUNTIME_INCLUDE_ENTRY_CELERITY
22
#define RUNTIME_INCLUDE_ENTRY_CELERITY
33

4+
#include "device_queue.h"
45
#include "runtime.h"
56

67
#include "accessor.h"
@@ -15,14 +16,24 @@ namespace runtime {
1516
/**
1617
* @brief Initializes the Celerity runtime.
1718
*/
18-
inline void init(int* argc, char** argv[]) { detail::runtime::init(argc, argv, nullptr); }
19+
inline void init(int* argc, char** argv[]) { detail::runtime::init(argc, argv, detail::auto_select_device{}); }
1920

2021
/**
2122
* @brief Initializes the Celerity runtime and instructs it to use a particular device.
2223
*
2324
* @param device The device to be used on the current node. This can vary between nodes.
2425
*/
25-
inline void init(int* argc, char** argv[], cl::sycl::device& device) { detail::runtime::init(argc, argv, &device); }
26+
[[deprecated("Use the overload with device selector instead, this will be removed in future release")]] inline void init(
27+
int* argc, char** argv[], sycl::device& device) {
28+
detail::runtime::init(argc, argv, device);
29+
}
30+
31+
/**
32+
* @brief Initializes the Celerity runtime and instructs it to use a particular device.
33+
*
34+
* @param device_selector The device selector to be used on the current node. This can vary between nodes.
35+
*/
36+
inline void init(int* argc, char** argv[], const detail::device_selector& device_selector) { detail::runtime::init(argc, argv, device_selector); }
2637
} // namespace runtime
2738
} // namespace celerity
2839

include/config.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ namespace detail {
1919
};
2020

2121
class config {
22+
friend struct config_testspy;
23+
2224
public:
2325
/**
2426
* Initializes the @p config by parsing environment variables and passed arguments.
@@ -48,7 +50,6 @@ namespace detail {
4850
std::optional<device_config> device_cfg;
4951
std::optional<bool> enable_device_profiling;
5052
size_t graph_print_max_verts = 200;
51-
friend struct config_testspy;
5253
};
5354

5455
} // namespace detail

include/device_queue.h

Lines changed: 133 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
#include <memory>
44

55
#include <CL/sycl.hpp>
6+
#include <type_traits>
7+
#include <variant>
68

79
#include "config.h"
810
#include "workaround.h"
911

1012
namespace celerity {
1113
namespace detail {
1214

15+
struct auto_select_device {};
16+
using device_selector = std::function<int(const sycl::device&)>;
17+
using device_or_selector = std::variant<auto_select_device, sycl::device, device_selector>;
18+
1319
class task;
1420

1521
/**
@@ -21,9 +27,9 @@ namespace detail {
2127
* @brief Initializes the @p device_queue, selecting an appropriate device in the process.
2228
*
2329
* @param cfg The configuration is used to select the appropriate SYCL device.
24-
* @param user_device Optionally a device can be provided, which will take precedence over any configuration.
30+
* @param user_device_or_selector Optionally a device (which will take precedence over any configuration) or a device selector can be provided.
2531
*/
26-
void init(const config& cfg, cl::sycl::device* user_device);
32+
void init(const config& cfg, const device_or_selector& user_device_or_selector);
2733

2834
/**
2935
* @brief Executes the kernel associated with task @p ctsk over the chunk @p chnk.
@@ -62,12 +68,111 @@ namespace detail {
6268
void handle_async_exceptions(cl::sycl::exception_list el) const;
6369
};
6470

71+
// Try to find a platform that can provide a unique device for each node using a device selector.
72+
template <typename DeviceT, typename PlatformT, typename SelectorT>
73+
bool try_find_device_per_node(
74+
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
75+
std::vector<std::tuple<DeviceT, size_t>> devices_with_platform_idx;
76+
for(size_t i = 0; i < platforms.size(); ++i) {
77+
auto&& platform = platforms[i];
78+
for(auto device : platform.get_devices()) {
79+
if(selector(device) == -1) { continue; }
80+
devices_with_platform_idx.emplace_back(device, i);
81+
}
82+
}
83+
84+
std::stable_sort(devices_with_platform_idx.begin(), devices_with_platform_idx.end(),
85+
[selector](const auto& a, const auto& b) { return selector(std::get<0>(a)) > selector(std::get<0>(b)); });
86+
bool same_platform = true;
87+
bool same_device_type = true;
88+
if(devices_with_platform_idx.size() >= host_cfg.node_count) {
89+
auto [device_from_platform, idx] = devices_with_platform_idx[0];
90+
const auto platform = device_from_platform.get_platform();
91+
const auto device_type = device_from_platform.template get_info<sycl::info::device::device_type>();
92+
93+
for(size_t i = 1; i < host_cfg.node_count; ++i) {
94+
auto [device_from_platform, idx] = devices_with_platform_idx[i];
95+
if(device_from_platform.get_platform() != platform) { same_platform = false; }
96+
if(device_from_platform.template get_info<sycl::info::device::device_type>() != device_type) { same_device_type = false; }
97+
}
98+
99+
if(!same_platform || !same_device_type) { CELERITY_WARN("Selected devices are of different type and/or do not belong to the same platform"); }
100+
101+
auto [selected_device_from_platform, selected_idx] = devices_with_platform_idx[host_cfg.local_rank];
102+
how_selected = fmt::format("device selector specified: platform {}, device {}", selected_idx, host_cfg.local_rank);
103+
device = selected_device_from_platform;
104+
return true;
105+
}
106+
107+
return false;
108+
}
109+
110+
// Try to find a platform that can provide a unique device for each node.
65111
template <typename DeviceT, typename PlatformT>
66-
DeviceT pick_device(const config& cfg, DeviceT* user_device, const std::vector<PlatformT>& platforms) {
112+
bool try_find_device_per_node(
113+
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
114+
for(size_t i = 0; i < platforms.size(); ++i) {
115+
auto&& platform = platforms[i];
116+
std::vector<DeviceT> platform_devices;
117+
118+
platform_devices = platform.get_devices(type);
119+
if(platform_devices.size() >= host_cfg.node_count) {
120+
how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
121+
device = platform_devices[host_cfg.local_rank];
122+
return true;
123+
}
124+
}
125+
126+
return false;
127+
}
128+
129+
template <typename DeviceT, typename PlatformT, typename SelectorT>
130+
bool try_find_one_device(
131+
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, SelectorT selector) {
132+
std::vector<DeviceT> platform_devices;
133+
for(auto& p : platforms) {
134+
auto p_devices = p.get_devices();
135+
platform_devices.insert(platform_devices.end(), p_devices.begin(), p_devices.end());
136+
}
137+
138+
std::stable_sort(platform_devices.begin(), platform_devices.end(), [selector](const auto& a, const auto& b) { return selector(a) > selector(b); });
139+
if(!platform_devices.empty()) {
140+
if(selector(platform_devices[0]) == -1) { return false; }
141+
device = platform_devices[0];
142+
return true;
143+
}
144+
145+
return false;
146+
};
147+
148+
template <typename DeviceT, typename PlatformT>
149+
bool try_find_one_device(
150+
std::string& how_selected, DeviceT& device, const std::vector<PlatformT>& platforms, const host_config& host_cfg, sycl::info::device_type type) {
151+
for(auto& p : platforms) {
152+
for(auto& d : p.get_devices(type)) {
153+
device = d;
154+
return true;
155+
}
156+
}
157+
158+
return false;
159+
};
160+
161+
162+
template <typename DevicePtrOrSelector, typename PlatformT>
163+
auto pick_device(const config& cfg, const DevicePtrOrSelector& user_device_or_selector, const std::vector<PlatformT>& platforms) {
164+
using DeviceT = typename decltype(std::declval<PlatformT&>().get_devices())::value_type;
165+
166+
constexpr bool user_device_provided = std::is_same_v<DevicePtrOrSelector, DeviceT>;
167+
constexpr bool device_selector_provided = std::is_invocable_r_v<int, DevicePtrOrSelector, DeviceT>;
168+
constexpr bool auto_select = std::is_same_v<auto_select_device, DevicePtrOrSelector>;
169+
static_assert(
170+
user_device_provided ^ device_selector_provided ^ auto_select, "pick_device requires either a device, a selector, or the auto_select_device tag");
171+
67172
DeviceT device;
68173
std::string how_selected = "automatically selected";
69-
if(user_device != nullptr) {
70-
device = *user_device;
174+
if constexpr(user_device_provided) {
175+
device = user_device_or_selector;
71176
how_selected = "specified by user";
72177
} else {
73178
const auto device_cfg = cfg.get_device_config();
@@ -86,48 +191,37 @@ namespace detail {
86191
} else {
87192
const auto host_cfg = cfg.get_host_config();
88193

89-
const auto try_find_device_per_node = [&host_cfg, &device, &how_selected, &platforms](cl::sycl::info::device_type type) {
90-
// Try to find a platform that can provide a unique device for each node.
91-
for(size_t i = 0; i < platforms.size(); ++i) {
92-
auto&& platform = platforms[i];
93-
const auto devices = platform.get_devices(type);
94-
if(devices.size() >= host_cfg.node_count) {
95-
how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
96-
device = devices[host_cfg.local_rank];
97-
return true;
98-
}
99-
}
100-
return false;
101-
};
102-
103-
const auto try_find_one_device = [&device, &platforms](cl::sycl::info::device_type type) {
104-
for(auto& p : platforms) {
105-
for(auto& d : p.get_devices(type)) {
106-
device = d;
107-
return true;
194+
if constexpr(!device_selector_provided) {
195+
// Try to find a unique GPU per node.
196+
if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)) {
197+
if(try_find_device_per_node(how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
198+
CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
199+
} else {
200+
CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
201+
// Just use the first available device. Prefer GPUs, but settle for anything.
202+
if(!try_find_one_device(how_selected, device, platforms, host_cfg, sycl::info::device_type::gpu)
203+
&& !try_find_one_device(how_selected, device, platforms, host_cfg, sycl::info::device_type::all)) {
204+
throw std::runtime_error("Automatic device selection failed: No device available");
205+
}
108206
}
109207
}
110-
return false;
111-
};
112-
113-
// Try to find a unique GPU per node.
114-
if(!try_find_device_per_node(cl::sycl::info::device_type::gpu)) {
115-
// Try to find a unique device (of any type) per node.
116-
if(try_find_device_per_node(cl::sycl::info::device_type::all)) {
117-
CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
118-
} else {
119-
CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
120-
// Just use the first available device. Prefer GPUs, but settle for anything.
121-
if(!try_find_one_device(cl::sycl::info::device_type::gpu) && !try_find_one_device(cl::sycl::info::device_type::all)) {
122-
throw std::runtime_error("Automatic device selection failed: No device available");
208+
} else {
209+
// Try to find a unique device per node using a selector.
210+
if(!try_find_device_per_node(how_selected, device, platforms, host_cfg, user_device_or_selector)) {
211+
CELERITY_WARN("No suitable platform found that can provide {} devices that match the specified device selector, and "
212+
"CELERITY_DEVICES not set",
213+
host_cfg.node_count);
214+
// Use the first available device according to the selector, but fails if no such device is found.
215+
if(!try_find_one_device(how_selected, device, platforms, host_cfg, user_device_or_selector)) {
216+
throw std::runtime_error("Device selection with device selector failed: No device available");
123217
}
124218
}
125219
}
126220
}
127221
}
128222

129-
const auto platform_name = device.get_platform().template get_info<cl::sycl::info::platform::name>();
130-
const auto device_name = device.template get_info<cl::sycl::info::device::name>();
223+
const auto platform_name = device.get_platform().template get_info<sycl::info::platform::name>();
224+
const auto device_name = device.template get_info<sycl::info::device::name>();
131225
CELERITY_INFO("Using platform '{}', device '{}' ({})", platform_name, device_name, how_selected);
132226

133227
return device;

include/distr_queue.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <memory>
44
#include <type_traits>
55

6+
#include "device_queue.h"
67
#include "runtime.h"
78
#include "task_manager.h"
89

@@ -25,10 +26,19 @@ inline constexpr allow_by_ref_t allow_by_ref{};
2526

2627
class distr_queue {
2728
public:
28-
distr_queue() { init(nullptr); }
29-
distr_queue(cl::sycl::device& device) {
29+
distr_queue() { init(detail::auto_select_device{}); }
30+
31+
[[deprecated("Use the overload with device selector instead, this will be removed in future release")]] distr_queue(cl::sycl::device& device) {
3032
if(detail::runtime::is_initialized()) { throw std::runtime_error("Passing explicit device not possible, runtime has already been initialized."); }
31-
init(&device);
33+
init(device);
34+
}
35+
36+
template <typename DeviceSelector>
37+
distr_queue(const DeviceSelector& device_selector) {
38+
if(detail::runtime::is_initialized()) {
39+
throw std::runtime_error("Passing explicit device selector not possible, runtime has already been initialized.");
40+
}
41+
init(device_selector);
3242
}
3343

3444
distr_queue(const distr_queue&) = default;
@@ -77,8 +87,8 @@ class distr_queue {
7787
private:
7888
std::shared_ptr<detail::distr_queue_tracker> tracker;
7989

80-
void init(cl::sycl::device* user_device) {
81-
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr, user_device); }
90+
void init(detail::device_or_selector device_or_selector) {
91+
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr, device_or_selector); }
8292
try {
8393
detail::runtime::get_instance().startup();
8494
} catch(detail::runtime_already_started_error&) {

include/runtime.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ namespace detail {
4141

4242
public:
4343
/**
44-
* @param user_device This optional device can be provided by the user, overriding any other device selection strategy.
44+
* @param user_device_or_selector This optional device (overriding any other device selection strategy) or device selector can be provided by the user.
4545
*/
46-
static void init(int* argc, char** argv[], cl::sycl::device* user_device = nullptr);
46+
static void init(int* argc, char** argv[], device_or_selector user_device_or_selector = auto_select_device{});
47+
4748
static bool is_initialized() { return instance != nullptr; }
4849
static runtime& get_instance();
4950

@@ -117,7 +118,7 @@ namespace detail {
117118
};
118119
std::deque<flush_handle> active_flushes;
119120

120-
runtime(int* argc, char** argv[], cl::sycl::device* user_device = nullptr);
121+
runtime(int* argc, char** argv[], device_or_selector user_device_or_selector);
121122
runtime(const runtime&) = delete;
122123
runtime(runtime&&) = delete;
123124

src/device_queue.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88
namespace celerity {
99
namespace detail {
1010

11-
void device_queue::init(const config& cfg, cl::sycl::device* user_device) {
11+
void device_queue::init(const config& cfg, const device_or_selector& user_device_or_selector) {
1212
assert(sycl_queue == nullptr);
1313
const auto profiling_cfg = cfg.get_enable_device_profiling();
1414
device_profiling_enabled = profiling_cfg != std::nullopt && *profiling_cfg;
1515
if(device_profiling_enabled) { CELERITY_INFO("Device profiling enabled."); }
1616

1717
const auto props = device_profiling_enabled ? cl::sycl::property_list{cl::sycl::property::queue::enable_profiling()} : cl::sycl::property_list{};
1818
const auto handle_exceptions = cl::sycl::async_handler{[this](cl::sycl::exception_list el) { this->handle_async_exceptions(el); }};
19-
auto device = pick_device(cfg, user_device, cl::sycl::platform::get_platforms());
19+
20+
auto device = std::visit(
21+
[&cfg](const auto& value) { return ::celerity::detail::pick_device(cfg, value, cl::sycl::platform::get_platforms()); }, user_device_or_selector);
2022
sycl_queue = std::make_unique<cl::sycl::queue>(device, handle_exceptions, props);
2123
}
2224

23-
2425
void device_queue::handle_async_exceptions(cl::sycl::exception_list el) const {
2526
for(auto& e : el) {
2627
try {

src/runtime.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ namespace detail {
4848
mpi_finalized = true;
4949
}
5050

51-
void runtime::init(int* argc, char** argv[], cl::sycl::device* user_device) {
51+
void runtime::init(int* argc, char** argv[], device_or_selector user_device_or_selector) {
5252
assert(!instance);
53-
instance = std::unique_ptr<runtime>(new runtime(argc, argv, user_device));
53+
instance = std::unique_ptr<runtime>(new runtime(argc, argv, user_device_or_selector));
5454
}
5555

5656
runtime& runtime::get_instance() {
@@ -91,7 +91,7 @@ namespace detail {
9191
#endif
9292
}
9393

94-
runtime::runtime(int* argc, char** argv[], cl::sycl::device* user_device) {
94+
runtime::runtime(int* argc, char** argv[], device_or_selector user_device_or_selector) {
9595
if(test_mode) {
9696
assert(test_active && "initializing the runtime from a test without a runtime_fixture");
9797
} else {
@@ -145,7 +145,7 @@ namespace detail {
145145

146146
CELERITY_INFO(
147147
"Celerity runtime version {} running on {}. PID = {}, build type = {}", get_version_string(), get_sycl_version(), get_pid(), get_build_type());
148-
d_queue->init(*cfg, user_device);
148+
d_queue->init(*cfg, user_device_or_selector);
149149
}
150150

151151
runtime::~runtime() {

0 commit comments

Comments
 (0)