@@ -16,80 +16,10 @@ namespace detail {
16
16
17
17
const auto props = device_profiling_enabled ? cl::sycl::property_list{cl::sycl::property::queue::enable_profiling ()} : cl::sycl::property_list{};
18
18
const auto handle_exceptions = cl::sycl::async_handler{[this ](cl::sycl::exception_list el) { this ->handle_async_exceptions (el); }};
19
- auto device = pick_device (cfg, user_device);
19
+ auto device = pick_device (cfg, user_device, cl::sycl::platform::get_platforms () );
20
20
sycl_queue = std::make_unique<cl::sycl::queue>(device, handle_exceptions, props);
21
21
}
22
22
23
- cl::sycl::device device_queue::pick_device (const config& cfg, cl::sycl::device* user_device) const {
24
- cl::sycl::device device;
25
- std::string how_selected = " automatically selected" ;
26
- if (user_device != nullptr ) {
27
- device = *user_device;
28
- how_selected = " specified by user" ;
29
- } else {
30
- const auto device_cfg = cfg.get_device_config ();
31
- if (device_cfg != std::nullopt) {
32
- how_selected = fmt::format (" set by CELERITY_DEVICES: platform {}, device {}" , device_cfg->platform_id , device_cfg->device_id );
33
- const auto platforms = cl::sycl::platform::get_platforms ();
34
- CELERITY_DEBUG (" {} platforms available" , platforms.size ());
35
- if (device_cfg->platform_id >= platforms.size ()) {
36
- throw std::runtime_error (fmt::format (" Invalid platform id {}: Only {} platforms available" , device_cfg->platform_id , platforms.size ()));
37
- }
38
- const auto devices = platforms[device_cfg->platform_id ].get_devices ();
39
- if (device_cfg->device_id >= devices.size ()) {
40
- throw std::runtime_error (fmt::format (
41
- " Invalid device id {}: Only {} devices available on platform {}" , device_cfg->device_id , devices.size (), device_cfg->platform_id ));
42
- }
43
- device = devices[device_cfg->device_id ];
44
- } else {
45
- const auto host_cfg = cfg.get_host_config ();
46
-
47
- const auto try_find_device_per_node = [&host_cfg, &device, &how_selected](cl::sycl::info::device_type type) {
48
- // Try to find a platform that can provide a unique device for each node.
49
- const auto platforms = cl::sycl::platform::get_platforms ();
50
- for (size_t i = 0 ; i < platforms.size (); ++i) {
51
- auto && platform = platforms[i];
52
- const auto devices = platform.get_devices (type);
53
- if (devices.size () >= host_cfg.node_count ) {
54
- how_selected = fmt::format (" automatically selected platform {}, device {}" , i, host_cfg.local_rank );
55
- device = devices[host_cfg.local_rank ];
56
- return true ;
57
- }
58
- }
59
- return false ;
60
- };
61
-
62
- const auto try_find_one_device = [&device](cl::sycl::info::device_type type) {
63
- const auto devices = cl::sycl::device::get_devices (type);
64
- if (!devices.empty ()) {
65
- device = devices[0 ];
66
- return true ;
67
- }
68
- return false ;
69
- };
70
-
71
- // Try to find a unique GPU per node.
72
- if (!try_find_device_per_node (cl::sycl::info::device_type::gpu)) {
73
- // Try to find a unique device (of any type) per node.
74
- if (try_find_device_per_node (cl::sycl::info::device_type::all)) {
75
- CELERITY_WARN (" No suitable platform found that can provide {} GPU devices, and CELERITY_DEVICES not set" , host_cfg.node_count );
76
- } else {
77
- CELERITY_WARN (" No suitable platform found that can provide {} devices, and CELERITY_DEVICES not set" , host_cfg.node_count );
78
- // Just use the first available device. Prefer GPUs, but settle for anything.
79
- if (!try_find_one_device (cl::sycl::info::device_type::gpu) && !try_find_one_device (cl::sycl::info::device_type::all)) {
80
- throw std::runtime_error (" Automatic device selection failed: No device available" );
81
- }
82
- }
83
- }
84
- }
85
- }
86
-
87
- const auto platform_name = device.get_platform ().get_info <cl::sycl::info::platform::name>();
88
- const auto device_name = device.get_info <cl::sycl::info::device::name>();
89
- CELERITY_INFO (" Using platform '{}', device '{}' ({})" , platform_name, device_name, how_selected);
90
-
91
- return device;
92
- }
93
23
94
24
void device_queue::handle_async_exceptions (cl::sycl::exception_list el) const {
95
25
for (auto & e : el) {
@@ -102,6 +32,5 @@ namespace detail {
102
32
}
103
33
}
104
34
105
-
106
35
} // namespace detail
107
36
} // namespace celerity
0 commit comments