Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,10 @@ set(SOURCES
src/graph_serializer.cc
src/grid.cc
src/instruction_graph_generator.cc
src/mpi_communicator.cc
src/print_graph.cc
src/recorders.cc
src/receive_arbiter.cc
src/runtime.cc
src/scheduler.cc
src/split.cc
Expand Down
1 change: 1 addition & 0 deletions ci/run-system-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ NUM_NODES=("$@")

# Names of the system tests that the loop below iterates over.
SYSTEM_TESTS=(
"distr_tests"
"mpi_tests"
)

for e in "${!SYSTEM_TESTS[@]}"; do
Expand Down
59 changes: 59 additions & 0 deletions include/async_event.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#pragma once

#include <cassert>
#include <memory>
#include <type_traits>


namespace celerity::detail {

/// Polymorphic interface implemented by every concrete backend of `async_event`.
///
/// Instances can neither be copied nor moved.
class async_event_impl {
  public:
	async_event_impl() = default;
	virtual ~async_event_impl() = default;

	// Copy and move operations are disabled for the entire hierarchy.
	async_event_impl(const async_event_impl&) = delete;
	async_event_impl& operator=(const async_event_impl&) = delete;
	async_event_impl(async_event_impl&&) = delete;
	async_event_impl& operator=(async_event_impl&&) = delete;

	/// Reports whether the underlying operation has finished.
	///
	/// Contract for implementers: completion is sticky, i.e. once `true` has been returned, every later call must return `true` as well. Polling is expected
	/// to be cheap, and the operation has to make progress in the background even while nobody is polling.
	virtual bool is_complete() const = 0;
};

/// `async_event` implementation that is immediately complete. Used to report synchronous completion of some operations within an otherwise asynchronous
/// context.
class complete_event final : public async_event_impl {
public:
complete_event() = default;
/// Unconditionally reports completion.
bool is_complete() const override { return true; }
};

/// Type-erased event signalling completion of events at the executor layer. These may wrap SYCL events, asynchronous MPI requests, or similar.
class [[nodiscard]] async_event {
public:
/// Creates an empty event that holds no implementation. `is_complete()` must not be called on such an event (it asserts a non-null implementation).
async_event() = default;
/// Takes ownership of `impl`.
/// NOTE(review): this constructor is not `explicit`, so a `std::unique_ptr<async_event_impl>` converts implicitly to `async_event` — confirm this is intended.
async_event(std::unique_ptr<async_event_impl> impl) noexcept : m_impl(std::move(impl)) {}

/// Polls the underlying event operation to check if it has completed. This function is cheap to call repeatedly.
/// Precondition: the event is not empty, i.e. it was constructed from a non-null implementation.
bool is_complete() const {
assert(m_impl != nullptr);
return m_impl->is_complete();
}

private:
// Owned implementation; null for a default-constructed event.
std::unique_ptr<async_event_impl> m_impl;
};

/// Factory that constructs an `Event` (which must derive from `async_event_impl`) from the forwarded constructor arguments and wraps it in an `async_event`.
template <typename Event, typename... CtorParams>
async_event make_async_event(CtorParams&&... ctor_args) {
	static_assert(std::is_base_of_v<async_event_impl, Event>);
	// Build the concrete event first, already type-erased to its base class, then hand ownership over to the wrapper.
	std::unique_ptr<async_event_impl> impl = std::make_unique<Event>(std::forward<CtorParams>(ctor_args)...);
	return async_event(std::move(impl));
}

/// Convenience factory for an `async_event` backed by a `complete_event`, i.e. an event that already reports completion.
inline async_event make_complete_event() {
	return async_event(std::make_unique<complete_event>());
}

} // namespace celerity::detail
89 changes: 89 additions & 0 deletions include/communicator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
#pragma once

#include "async_event.h"
#include "pilot.h"
#include "utils.h"

namespace celerity::detail {

/// Interface for peer-to-peer and collective communication across nodes to be implemented for MPI or similar system APIs.
///
/// Celerity maintains one root communicator which can be cloned collectively, and the same communicator instance in this "cloning tree" must participate in
/// corresponding operations on each node. Communicator instances themselves are not thread-safe, but if there are multiple (cloned) instances, each may be used
/// from their own thread.
///
/// Peer-to-peer operations (send/receive/poll) can be arbitrarily re-ordered by the communicator, but collectives will
/// always be executed precisely in the order they are submitted.
class communicator {
public:
/// Addresses a 1/2/3-dimensional subrange of a type-erased (buffer) allocation to be sent from or received into.
struct stride {
// Presumably the extent of the whole allocation that the `base` pointer of a send / receive refers to — TODO confirm.
range<3> allocation_range;
// Presumably the region (offset + range) inside the allocation that is actually transferred — TODO confirm.
subrange<3> transfer;
// Size of a single element, presumably in bytes — TODO confirm. Defaults to 1.
size_t element_size = 1;

// Two strides are equal iff all three members compare equal.
friend bool operator==(const stride& lhs, const stride& rhs) {
return lhs.allocation_range == rhs.allocation_range && lhs.transfer == rhs.transfer && lhs.element_size == rhs.element_size;
}
friend bool operator!=(const stride& lhs, const stride& rhs) { return !(lhs == rhs); }
};

communicator() = default;
// Communicators can neither be copied nor moved.
communicator(const communicator&) = delete;
communicator(communicator&&) = delete;
communicator& operator=(const communicator&) = delete;
communicator& operator=(communicator&&) = delete;

/// Communicator destruction is a collective operation like `collective_barrier`.
///
/// The user must ensure that any asynchronous operation is already complete when the destructor runs.
virtual ~communicator() = default;

/// Returns the number of nodes (processes) that are part of this communicator.
virtual size_t get_num_nodes() const = 0;

/// Returns the 0-based id of the local node in the communicator.
virtual node_id get_local_node_id() const = 0;

/// Asynchronously sends a pilot message, returning without acknowledgement from the receiver. The pilot is copied internally and the reference does not
/// need to remain live after the function returns.
virtual void send_outbound_pilot(const outbound_pilot& pilot) = 0;

/// Returns all inbound pilots received on this communicator since the last invocation of the same function. Never blocks.
[[nodiscard]] virtual std::vector<inbound_pilot> poll_inbound_pilots() = 0;

/// Begins sending strided data (that was previously announced using an outbound_pilot) to the specified node. The `base` allocation must remain live until
/// the returned event completes, and no element inside `stride` must be written to during that time.
[[nodiscard]] virtual async_event send_payload(node_id to, message_id msgid, const void* base, const stride& stride) = 0;

/// Begins receiving strided data (which was previously announced using an inbound_pilot) from the specified node. The `base` allocation must remain live
/// until the returned event completes, and no element inside `stride` must be written to during that time.
///
/// NOTE(review): this wording mirrors `send_payload`. For a receive, reading the target elements before completion is presumably unsafe as well — confirm
/// and extend the requirement to "read from or written to" if so.
[[nodiscard]] virtual async_event receive_payload(node_id from, message_id msgid, void* base, const stride& stride) = 0;

/// Creates a new communicator that is fully concurrent to this one, and which has its own "namespace" for peer-to-peer and collective operations.
///
/// Must be ordered identically to all other collective operations on this communicator across all nodes.
virtual std::unique_ptr<communicator> collective_clone() = 0;

/// Blocks until all nodes in this communicator have called `collective_barrier()`.
///
/// Must be ordered identically to all other collective operations on this communicator across all nodes.
virtual void collective_barrier() = 0;
};

} // namespace celerity::detail

/// Hash support for `communicator::stride`, needed because `mpi_communicator` caches strided datatypes keyed by stride.
template <>
struct std::hash<celerity::detail::communicator::stride> {
	size_t operator()(const celerity::detail::communicator::stride& stride) const {
		namespace utils = celerity::detail::utils;

		size_t seed = 0;
		// Mix in the three per-dimension quantities, one dimension after the other.
		for(int dim = 0; dim < 3; ++dim) {
			utils::hash_combine(seed, stride.allocation_range[dim]);
			utils::hash_combine(seed, stride.transfer.offset[dim]);
			utils::hash_combine(seed, stride.transfer.range[dim]);
		}
		// The element size is folded in last.
		utils::hash_combine(seed, stride.element_size);
		return seed;
	}
};
3 changes: 1 addition & 2 deletions include/launcher.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "async_event.h"
#include "grid.h"
#include "host_queue.h"

Expand All @@ -11,8 +12,6 @@

namespace celerity::detail {

struct async_event {}; // [IDAG placeholder]

/// Callable that receives a SYCL handler, the execution range as a 3D box, and a list of reduction pointers.
using device_kernel_launcher = std::function<void(sycl::handler& sycl_cgh, const box<3>& execution_range, const std::vector<void*>& reduction_ptrs)>;
/// Callable that receives a host queue, the execution range as a 3D box, and an MPI communicator, and returns an `async_event`.
using host_task_launcher = std::function<async_event(host_queue& q, const box<3>& execution_range, MPI_Comm mpi_comm)>;
/// Holds either a device kernel launcher or a host task launcher.
using command_group_launcher = std::variant<device_kernel_launcher, host_task_launcher>;
Expand Down
72 changes: 72 additions & 0 deletions include/mpi_communicator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#pragma once

#include "communicator.h"

#include <memory>
#include <unordered_map>
#include <vector>

#include <mpi.h>

namespace celerity::detail {

/// Empty tag type that selects the cloning constructor of `mpi_communicator`.
struct collective_clone_from_tag {};

/// The tag value to pass to that constructor.
inline constexpr collective_clone_from_tag collective_clone_from{};

/// MPI implementation of the `communicator` interface.
///
/// Wraps an `MPI_Comm`, manages strided MPI datatypes for sends / receives and optionally maintains an inbound / outbound queue of pilot messages.
class mpi_communicator final : public communicator {
public:
/// Creates a new `mpi_communicator` by cloning the given `MPI_Comm`, which must not be `MPI_COMM_NULL`.
/// The `tag` argument carries no data; pass `collective_clone_from`.
explicit mpi_communicator(collective_clone_from_tag tag, MPI_Comm mpi_comm);

// Neither copyable nor movable, like the `communicator` base class.
mpi_communicator(const mpi_communicator&) = delete;
mpi_communicator(mpi_communicator&&) = delete;
mpi_communicator& operator=(const mpi_communicator&) = delete;
mpi_communicator& operator=(mpi_communicator&&) = delete;
~mpi_communicator() override;

size_t get_num_nodes() const override;
node_id get_local_node_id() const override;

void send_outbound_pilot(const outbound_pilot& pilot) override;
[[nodiscard]] std::vector<inbound_pilot> poll_inbound_pilots() override;

[[nodiscard]] async_event send_payload(node_id to, message_id msgid, const void* base, const stride& stride) override;
[[nodiscard]] async_event receive_payload(node_id from, message_id msgid, void* base, const stride& stride) override;

[[nodiscard]] std::unique_ptr<communicator> collective_clone() override;
void collective_barrier() override;

/// Returns the underlying `MPI_Comm`. The result is never `MPI_COMM_NULL`.
MPI_Comm get_native() const { return m_mpi_comm; }

private:
// Grants a test helper access to the private members below.
friend struct mpi_communicator_testspy;

/// Deleter for `unique_datatype`; defined out of line.
struct datatype_deleter {
void operator()(MPI_Datatype dtype) const;
};
/// Owning handle for an `MPI_Datatype`.
/// NOTE(review): this alias only yields a usable `std::unique_ptr` if `MPI_Datatype` is a pointer type; with an MPI implementation that uses integer
/// handles, the stored pointer type would not match the deleter's parameter — confirm which MPI implementations need to be supported.
using unique_datatype = std::unique_ptr<std::remove_pointer_t<MPI_Datatype>, datatype_deleter>;

/// Keeps a stable pointer to a `pilot_message` alive during an asynchronous pilot send / receive operation.
struct in_flight_pilot {
std::unique_ptr<pilot_message> message;
MPI_Request request = MPI_REQUEST_NULL;
};

MPI_Comm m_mpi_comm = MPI_COMM_NULL;

in_flight_pilot m_inbound_pilot; ///< continually Irecv'd into after the first call to poll_inbound_pilots()
std::vector<in_flight_pilot> m_outbound_pilots;

// Datatype caches, keyed by byte size and by stride respectively.
std::unordered_map<size_t, unique_datatype> m_scalar_type_cache;
std::unordered_map<stride, unique_datatype> m_array_type_cache;

// Presumably look up (and create on first use) the entries of the two caches above; the definitions are not visible here — TODO confirm.
MPI_Datatype get_scalar_type(size_t bytes);
MPI_Datatype get_array_type(const stride& stride);
};

} // namespace celerity::detail
1 change: 1 addition & 0 deletions include/mpi_support.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ constexpr int TAG_CMD = 0;
constexpr int TAG_DATA_TRANSFER = 1;
constexpr int TAG_TELEMETRY = 2;
constexpr int TAG_PRINT_GRAPH = 3;
// NOTE(review): presumably the MPI tag used for traffic issued by `mpi_communicator` — confirm against its implementation.
constexpr int TAG_COMMUNICATOR = 4;

class data_type {
public:
Expand Down
6 changes: 6 additions & 0 deletions include/pilot.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,10 @@ struct outbound_pilot {
pilot_message message;
};

/// A pilot message as packaged on the receiver side.
struct inbound_pilot {
// Presumably the node the pilot was received from. NOTE(review): the `-1` default looks like an "unset" sentinel — confirm against the definition of
// `node_id`.
node_id from = -1;
// The pilot message itself.
pilot_message message;
};

} // namespace celerity::detail
Loading