Skip to content

Commit 2a56b50

Browse files
almightyvats authored and psalz committed
Add tests for existing device selection logic
1 parent 63d5ffc commit 2a56b50

File tree

6 files changed

+464
-74
lines changed

6 files changed

+464
-74
lines changed

include/config.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ namespace detail {
4848
std::optional<device_config> device_cfg;
4949
std::optional<bool> enable_device_profiling;
5050
size_t graph_print_max_verts = 200;
51+
friend struct config_testspy;
5152
};
5253

5354
} // namespace detail

include/device_queue.h

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,79 @@ namespace detail {
5959
std::unique_ptr<cl::sycl::queue> sycl_queue;
6060
bool device_profiling_enabled = false;
6161

62-
cl::sycl::device pick_device(const config& cfg, cl::sycl::device* user_device) const;
6362
void handle_async_exceptions(cl::sycl::exception_list el) const;
6463
};
6564

65+
/**
 * Chooses the device to run on and logs the choice.
 *
 * Selection order:
 *  1. If `user_device` is non-null, a copy of `*user_device` is used.
 *  2. Otherwise, if `cfg.get_device_config()` holds a value, its `platform_id` indexes `platforms`
 *     and its `device_id` indexes that platform's `get_devices()`.
 *  3. Otherwise a device is selected automatically from `platforms` using `cfg.get_host_config()`.
 *
 * @tparam DeviceT   Device type. Must be default-constructible and assignable, and provide
 *                   `get_platform()` and `get_info<...>()`.
 * @tparam PlatformT Platform type. Must provide `get_devices()` and `get_devices(device_type)`.
 *                   NOTE(review): the two types are presumably template parameters so that tests
 *                   can substitute mocks for the SYCL classes - confirm against the test suite.
 * @param cfg          Configuration queried for the device config and host config.
 * @param user_device  Optional explicitly requested device; may be nullptr.
 * @param platforms    Candidate platforms to select from.
 * @return The selected device.
 * @throws std::runtime_error if the configured platform or device id is out of range, or if
 *         automatic selection finds no device at all.
 */
template <typename DeviceT, typename PlatformT>
66+
DeviceT pick_device(const config& cfg, DeviceT* user_device, const std::vector<PlatformT>& platforms) {
67+
DeviceT device;
68+
// Description of the selection path; only used in the final log message.
std::string how_selected = "automatically selected";
69+
if(user_device != nullptr) {
70+
device = *user_device;
71+
how_selected = "specified by user";
72+
} else {
73+
const auto device_cfg = cfg.get_device_config();
74+
if(device_cfg != std::nullopt) {
75+
how_selected = fmt::format("set by CELERITY_DEVICES: platform {}, device {}", device_cfg->platform_id, device_cfg->device_id);
76+
CELERITY_DEBUG("{} platforms available", platforms.size());
77+
// Validate both ids before indexing so that a bad configuration yields a descriptive error.
if(device_cfg->platform_id >= platforms.size()) {
78+
throw std::runtime_error(fmt::format("Invalid platform id {}: Only {} platforms available", device_cfg->platform_id, platforms.size()));
79+
}
80+
const auto devices = platforms[device_cfg->platform_id].get_devices();
81+
if(device_cfg->device_id >= devices.size()) {
82+
throw std::runtime_error(fmt::format(
83+
"Invalid device id {}: Only {} devices available on platform {}", device_cfg->device_id, devices.size(), device_cfg->platform_id));
84+
}
85+
device = devices[device_cfg->device_id];
86+
} else {
87+
const auto host_cfg = cfg.get_host_config();
88+
89+
// Picks devices[local_rank] from the first platform offering at least node_count devices of `type`.
// Returns true (and sets `device` / `how_selected`) on success, false if no platform qualifies.
// NOTE(review): indexing with local_rank assumes local_rank < node_count - confirm in host_config.
const auto try_find_device_per_node = [&host_cfg, &device, &how_selected, &platforms](cl::sycl::info::device_type type) {
90+
// Try to find a platform that can provide a unique device for each node.
91+
for(size_t i = 0; i < platforms.size(); ++i) {
92+
auto&& platform = platforms[i];
93+
const auto devices = platform.get_devices(type);
94+
if(devices.size() >= host_cfg.node_count) {
95+
how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
96+
device = devices[host_cfg.local_rank];
97+
return true;
98+
}
99+
}
100+
return false;
101+
};
102+
103+
// Takes the very first device of `type` found on any platform, in platform order.
// Leaves `how_selected` at its default value. Returns false if no platform has such a device.
const auto try_find_one_device = [&device, &platforms](cl::sycl::info::device_type type) {
104+
for(auto& p : platforms) {
105+
for(auto& d : p.get_devices(type)) {
106+
device = d;
107+
return true;
108+
}
109+
}
110+
return false;
111+
};
112+
113+
// Try to find a unique GPU per node.
114+
if(!try_find_device_per_node(cl::sycl::info::device_type::gpu)) {
115+
// Try to find a unique device (of any type) per node.
116+
if(try_find_device_per_node(cl::sycl::info::device_type::all)) {
117+
CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
118+
} else {
119+
CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
120+
// Just use the first available device. Prefer GPUs, but settle for anything.
121+
if(!try_find_one_device(cl::sycl::info::device_type::gpu) && !try_find_one_device(cl::sycl::info::device_type::all)) {
122+
throw std::runtime_error("Automatic device selection failed: No device available");
123+
}
124+
}
125+
}
126+
}
127+
}
128+
129+
// `.template` is needed because `device` has a dependent type inside this function template.
const auto platform_name = device.get_platform().template get_info<cl::sycl::info::platform::name>();
130+
const auto device_name = device.template get_info<cl::sycl::info::device::name>();
131+
CELERITY_INFO("Using platform '{}', device '{}' ({})", platform_name, device_name, how_selected);
132+
133+
return device;
134+
}
135+
66136
} // namespace detail
67137
} // namespace celerity

src/config.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,5 @@ namespace detail {
192192
}
193193
}
194194
}
195-
196195
} // namespace detail
197196
} // namespace celerity

src/device_queue.cc

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,80 +16,10 @@ namespace detail {
1616

1717
const auto props = device_profiling_enabled ? cl::sycl::property_list{cl::sycl::property::queue::enable_profiling()} : cl::sycl::property_list{};
1818
const auto handle_exceptions = cl::sycl::async_handler{[this](cl::sycl::exception_list el) { this->handle_async_exceptions(el); }};
19-
auto device = pick_device(cfg, user_device);
19+
auto device = pick_device(cfg, user_device, cl::sycl::platform::get_platforms());
2020
sycl_queue = std::make_unique<cl::sycl::queue>(device, handle_exceptions, props);
2121
}
2222

23-
// Pre-commit member-function version of pick_device; every line below is a deletion in this diff.
// It follows the same selection order as the pick_device function template added in
// include/device_queue.h (user-specified device, then the configured platform/device ids, then
// automatic selection), but it queries cl::sycl::platform::get_platforms() itself rather than
// receiving the platform list as a parameter.
// Throws std::runtime_error on an out-of-range platform/device id or when no device is available.
cl::sycl::device device_queue::pick_device(const config& cfg, cl::sycl::device* user_device) const {
24-
cl::sycl::device device;
25-
std::string how_selected = "automatically selected";
26-
if(user_device != nullptr) {
27-
device = *user_device;
28-
how_selected = "specified by user";
29-
} else {
30-
const auto device_cfg = cfg.get_device_config();
31-
if(device_cfg != std::nullopt) {
32-
how_selected = fmt::format("set by CELERITY_DEVICES: platform {}, device {}", device_cfg->platform_id, device_cfg->device_id);
33-
const auto platforms = cl::sycl::platform::get_platforms();
34-
CELERITY_DEBUG("{} platforms available", platforms.size());
35-
if(device_cfg->platform_id >= platforms.size()) {
36-
throw std::runtime_error(fmt::format("Invalid platform id {}: Only {} platforms available", device_cfg->platform_id, platforms.size()));
37-
}
38-
const auto devices = platforms[device_cfg->platform_id].get_devices();
39-
if(device_cfg->device_id >= devices.size()) {
40-
throw std::runtime_error(fmt::format(
41-
"Invalid device id {}: Only {} devices available on platform {}", device_cfg->device_id, devices.size(), device_cfg->platform_id));
42-
}
43-
device = devices[device_cfg->device_id];
44-
} else {
45-
const auto host_cfg = cfg.get_host_config();
46-
47-
const auto try_find_device_per_node = [&host_cfg, &device, &how_selected](cl::sycl::info::device_type type) {
48-
// Try to find a platform that can provide a unique device for each node.
49-
const auto platforms = cl::sycl::platform::get_platforms();
50-
for(size_t i = 0; i < platforms.size(); ++i) {
51-
auto&& platform = platforms[i];
52-
const auto devices = platform.get_devices(type);
53-
if(devices.size() >= host_cfg.node_count) {
54-
how_selected = fmt::format("automatically selected platform {}, device {}", i, host_cfg.local_rank);
55-
device = devices[host_cfg.local_rank];
56-
return true;
57-
}
58-
}
59-
return false;
60-
};
61-
62-
// Unlike the per-node lookup above, this asks cl::sycl::device::get_devices(type) directly
// and takes the first entry of the returned list.
const auto try_find_one_device = [&device](cl::sycl::info::device_type type) {
63-
const auto devices = cl::sycl::device::get_devices(type);
64-
if(!devices.empty()) {
65-
device = devices[0];
66-
return true;
67-
}
68-
return false;
69-
};
70-
71-
// Try to find a unique GPU per node.
72-
if(!try_find_device_per_node(cl::sycl::info::device_type::gpu)) {
73-
// Try to find a unique device (of any type) per node.
74-
if(try_find_device_per_node(cl::sycl::info::device_type::all)) {
75-
CELERITY_WARN("No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set", host_cfg.node_count);
76-
} else {
77-
CELERITY_WARN("No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set", host_cfg.node_count);
78-
// Just use the first available device. Prefer GPUs, but settle for anything.
79-
if(!try_find_one_device(cl::sycl::info::device_type::gpu) && !try_find_one_device(cl::sycl::info::device_type::all)) {
80-
throw std::runtime_error("Automatic device selection failed: No device available");
81-
}
82-
}
83-
}
84-
}
85-
}
86-
87-
const auto platform_name = device.get_platform().get_info<cl::sycl::info::platform::name>();
88-
const auto device_name = device.get_info<cl::sycl::info::device::name>();
89-
CELERITY_INFO("Using platform '{}', device '{}' ({})", platform_name, device_name, how_selected);
90-
91-
return device;
92-
}
9323

9424
void device_queue::handle_async_exceptions(cl::sycl::exception_list el) const {
9525
for(auto& e : el) {
@@ -102,6 +32,5 @@ namespace detail {
10232
}
10333
}
10434

105-
10635
} // namespace detail
10736
} // namespace celerity

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ set(TEST_TARGETS
3737
runtime_deprecation_tests
3838
sycl_tests
3939
task_graph_tests
40+
device_selection_tests
4041
)
4142

4243
add_library(test_main test_main.cc)

0 commit comments

Comments
 (0)