diff --git a/CMakeLists.txt b/CMakeLists.txt index f1a88af5c..776d2b34f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,8 +222,10 @@ set(SOURCES src/graph_serializer.cc src/grid.cc src/instruction_graph_generator.cc + src/mpi_communicator.cc src/print_graph.cc src/recorders.cc + src/receive_arbiter.cc src/runtime.cc src/scheduler.cc src/split.cc diff --git a/ci/run-system-tests.sh b/ci/run-system-tests.sh index 8c8d6892b..93281bb0a 100755 --- a/ci/run-system-tests.sh +++ b/ci/run-system-tests.sh @@ -13,6 +13,7 @@ NUM_NODES=("$@") SYSTEM_TESTS=( "distr_tests" + "mpi_tests" ) for e in "${!SYSTEM_TESTS[@]}"; do diff --git a/include/async_event.h b/include/async_event.h new file mode 100644 index 000000000..7eb5617df --- /dev/null +++ b/include/async_event.h @@ -0,0 +1,59 @@ +#pragma once + +#include +#include +#include + + +namespace celerity::detail { + +/// Abstract base class for `async_event` implementations. +class async_event_impl { + public: + async_event_impl() = default; + async_event_impl(const async_event_impl&) = delete; + async_event_impl(async_event_impl&&) = delete; + async_event_impl& operator=(const async_event_impl&) = delete; + async_event_impl& operator=(async_event_impl&&) = delete; + virtual ~async_event_impl() = default; + + /// If this function returns true once, the implementation must guarantee that it will always do so in the future. + /// The event is expected to be cheap to poll repeatedly, and the operation must proceed in the background even while not being polled. + virtual bool is_complete() const = 0; +}; + +/// `async_event` implementation that is immediately complete. Used to report synchronous completion of some operations within an otherwise asynchronous +/// context. +class complete_event final : public async_event_impl { + public: + complete_event() = default; + bool is_complete() const override { return true; } +}; + +/// Type-erased event signalling completion of events at the executor layer. These may wrap SYCL events, asynchronous MPI requests, or similar. +class [[nodiscard]] async_event { + public: + async_event() = default; + async_event(std::unique_ptr impl) noexcept : m_impl(std::move(impl)) {} + + /// Polls the underlying event operation to check if it has completed. This function is cheap to call repeatedly. + bool is_complete() const { + assert(m_impl != nullptr); + return m_impl->is_complete(); + } + + private: + std::unique_ptr m_impl; +}; + +/// Shortcut to create an `async_event` using an `async_event_impl`-derived type `Event`. +template +async_event make_async_event(CtorParams&&... ctor_args) { + static_assert(std::is_base_of_v); + return async_event(std::make_unique(std::forward(ctor_args)...)); +} + +/// Shortcut to create an `async_event(complete_event)`. +inline async_event make_complete_event() { return make_async_event(); } + +} // namespace celerity::detail diff --git a/include/communicator.h b/include/communicator.h new file mode 100644 index 000000000..5a45b123d --- /dev/null +++ b/include/communicator.h @@ -0,0 +1,89 @@ +#pragma once + +#include "async_event.h" +#include "pilot.h" +#include "utils.h" + +namespace celerity::detail { + +/// Interface for peer-to-peer and collective communication across nodes to be implemented for MPI or similar system APIs. +/// +/// Celerity maintains one root communicator which can be cloned collectively, and the same communicator instance in this "cloning tree" must participate in +/// corresponding operations on each node. 
Communicator instances themselves are not thread-safe, but if there are multiple (cloned) instances, each may be used +/// from their own thread. +/// +/// Peer-to-peer operations (send/receive/poll) can be arbitrarily re-ordered by the communicator, but collectives will +/// always be executed precisely in the order they are submitted. +class communicator { + public: + /// Addresses a 1/2/3-dimensional subrange of a type-erased (buffer) allocation to be sent from or received into. + struct stride { + range<3> allocation_range; + subrange<3> transfer; + size_t element_size = 1; + + friend bool operator==(const stride& lhs, const stride& rhs) { + return lhs.allocation_range == rhs.allocation_range && lhs.transfer == rhs.transfer && lhs.element_size == rhs.element_size; + } + friend bool operator!=(const stride& lhs, const stride& rhs) { return !(lhs == rhs); } + }; + + communicator() = default; + communicator(const communicator&) = delete; + communicator(communicator&&) = delete; + communicator& operator=(const communicator&) = delete; + communicator& operator=(communicator&&) = delete; + + /// Communicator destruction is a collective operation like `collective_barrier`. + /// + /// The user must ensure that any asynchronous operation is already complete when the destructor runs. + virtual ~communicator() = default; + + /// Returns the number of nodes (processes) that are part of this communicator. + virtual size_t get_num_nodes() const = 0; + + /// Returns the 0-based id of the local node in the communicator. + virtual node_id get_local_node_id() const = 0; + + /// Asynchronously sends a pilot message, returning without acknowledgement from the receiver. The pilot is copied internally and the reference does not + /// need to remain live after the function returns. + virtual void send_outbound_pilot(const outbound_pilot& pilot) = 0; + + /// Returns all inbound pilots received on this communicator since the last invocation of the same function. Never blocks. + [[nodiscard]] virtual std::vector poll_inbound_pilots() = 0; + + /// Begins sending strided data (that was previously announced using an outbound_pilot) to the specified node. The `base` allocation must remain live until + /// the returned event completes, and no element inside `stride` must be written to during that time. + [[nodiscard]] virtual async_event send_payload(node_id to, message_id msgid, const void* base, const stride& stride) = 0; + + /// Begins receiving strided data (which was previously announced using an inbound_pilot) from the specified node. The `base` allocation must remain live + /// until the returned event completes, and no element inside `stride` must be written to during that time. + [[nodiscard]] virtual async_event receive_payload(node_id from, message_id msgid, void* base, const stride& stride) = 0; + + /// Creates a new communicator that is fully concurrent to this one, and which has its own "namespace" for peer-to-peer and collective operations. + /// + /// Must be ordered identically to all other collective operations on this communicator across all nodes. + virtual std::unique_ptr collective_clone() = 0; + + /// Blocks until all nodes in this communicator have called `collective_barrier()`. + /// + /// Must be ordered identically to all other collective operations on this communicator across all nodes. + virtual void collective_barrier() = 0; +}; + +} // namespace celerity::detail + +/// Required for caching strided datatypes in `mpi_communicator`. 
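+/// A minimal usage sketch (illustrative only - the real consumer is the datatype cache inside `mpi_communicator`, and
+/// `some_datatype` is a placeholder):
+///   std::unordered_map<celerity::detail::communicator::stride, MPI_Datatype> cache;
+///   cache.emplace(stride, some_datatype); // keys are hashed via this specialization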
+template <> +struct std::hash { + size_t operator()(const celerity::detail::communicator::stride& stride) const { + size_t h = 0; + for(int d = 0; d < 3; ++d) { + celerity::detail::utils::hash_combine(h, stride.allocation_range[d]); + celerity::detail::utils::hash_combine(h, stride.transfer.offset[d]); + celerity::detail::utils::hash_combine(h, stride.transfer.range[d]); + } + celerity::detail::utils::hash_combine(h, stride.element_size); + return h; + } +}; diff --git a/include/launcher.h b/include/launcher.h index 0f5283c6f..eeb667d49 100644 --- a/include/launcher.h +++ b/include/launcher.h @@ -1,5 +1,6 @@ #pragma once +#include "async_event.h" #include "grid.h" #include "host_queue.h" @@ -11,8 +12,6 @@ namespace celerity::detail { -struct async_event {}; // [IDAG placeholder] - using device_kernel_launcher = std::function& execution_range, const std::vector& reduction_ptrs)>; using host_task_launcher = std::function& execution_range, MPI_Comm mpi_comm)>; using command_group_launcher = std::variant; diff --git a/include/mpi_communicator.h b/include/mpi_communicator.h new file mode 100644 index 000000000..35dd8031d --- /dev/null +++ b/include/mpi_communicator.h @@ -0,0 +1,72 @@ +#pragma once + +#include "communicator.h" + +#include +#include +#include + +#include + +namespace celerity::detail { + +/// Constructor tag for mpi_communicator +struct collective_clone_from_tag { +} inline constexpr collective_clone_from{}; + +/// MPI implementation of the `communicator` interface. +/// +/// Wraps an `MPI_Comm`, manages strided MPI datatypes for sends / receives and optionally maintains an inbound / outbound queue of pilot messages. +class mpi_communicator final : public communicator { + public: + /// Creates a new `mpi_communicator` by cloning the given `MPI_Comm`, which must not be `MPI_COMM_NULL`. + explicit mpi_communicator(collective_clone_from_tag tag, MPI_Comm mpi_comm); + + mpi_communicator(const mpi_communicator&) = delete; + mpi_communicator(mpi_communicator&&) = delete; + mpi_communicator& operator=(const mpi_communicator&) = delete; + mpi_communicator& operator=(mpi_communicator&&) = delete; + ~mpi_communicator() override; + + size_t get_num_nodes() const override; + node_id get_local_node_id() const override; + + void send_outbound_pilot(const outbound_pilot& pilot) override; + [[nodiscard]] std::vector poll_inbound_pilots() override; + + [[nodiscard]] async_event send_payload(node_id to, message_id msgid, const void* base, const stride& stride) override; + [[nodiscard]] async_event receive_payload(node_id from, message_id msgid, void* base, const stride& stride) override; + + [[nodiscard]] std::unique_ptr collective_clone() override; + void collective_barrier() override; + + /// Returns the underlying `MPI_Comm`. The result is never `MPI_COMM_NULL`. + MPI_Comm get_native() const { return m_mpi_comm; } + + private: + friend struct mpi_communicator_testspy; + + struct datatype_deleter { + void operator()(MPI_Datatype dtype) const; + }; + using unique_datatype = std::unique_ptr, datatype_deleter>; + + /// Keeps a stable pointer to a `pilot_message` alive during an asynchronous pilot send / receive operation. 
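+	/// The indirection is needed because MPI reads / writes the message buffer asynchronously while the surrounding
+	/// `std::vector<in_flight_pilot>` (or the single inbound slot) may be moved or reallocated; the heap allocation keeps the
+	/// message address stable for the duration of the MPI_Isend / MPI_Irecv.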
+ struct in_flight_pilot { + std::unique_ptr message; + MPI_Request request = MPI_REQUEST_NULL; + }; + + MPI_Comm m_mpi_comm = MPI_COMM_NULL; + + in_flight_pilot m_inbound_pilot; ///< continually Irecv'd into after the first call to poll_inbound_pilots() + std::vector m_outbound_pilots; + + std::unordered_map m_scalar_type_cache; + std::unordered_map m_array_type_cache; + + MPI_Datatype get_scalar_type(size_t bytes); + MPI_Datatype get_array_type(const stride& stride); +}; + +} // namespace celerity::detail diff --git a/include/mpi_support.h b/include/mpi_support.h index 4b44a59c8..7a5aca092 100644 --- a/include/mpi_support.h +++ b/include/mpi_support.h @@ -8,6 +8,7 @@ constexpr int TAG_CMD = 0; constexpr int TAG_DATA_TRANSFER = 1; constexpr int TAG_TELEMETRY = 2; constexpr int TAG_PRINT_GRAPH = 3; +constexpr int TAG_COMMUNICATOR = 4; class data_type { public: diff --git a/include/pilot.h b/include/pilot.h index ca50bee58..16c29166d 100644 --- a/include/pilot.h +++ b/include/pilot.h @@ -20,4 +20,10 @@ struct outbound_pilot { pilot_message message; }; +/// A pilot message as packaged on the receiver side. +struct inbound_pilot { + node_id from = -1; + pilot_message message; +}; + } // namespace celerity::detail diff --git a/include/receive_arbiter.h b/include/receive_arbiter.h new file mode 100644 index 000000000..630ec3d9e --- /dev/null +++ b/include/receive_arbiter.h @@ -0,0 +1,149 @@ +#pragma once + +#include "communicator.h" +#include "pilot.h" + +#include +#include + +namespace celerity::detail::receive_arbiter_detail { + +/// A single box received by the communicator, as described earlier by an inbound pilot. +struct incoming_region_fragment { + detail::box<3> box; + async_event communication; ///< async communicator event for receiving this fragment +}; + +/// State for a single incomplete `receive` operation or a `begin_split_receive` / `await_split_receive_subregion` tree. +struct region_request { + void* allocation; + box<3> allocated_box; + region<3> incomplete_region; + std::vector incoming_fragments; + + region_request(region<3> requested_region, void* const allocation, const box<3>& allocated_bounding_box) + : allocation(allocation), allocated_box(allocated_bounding_box), incomplete_region(std::move(requested_region)) {} + bool do_complete(); +}; + +/// A single chunk in a `gather_request` that is currently being received by the communicator. +struct incoming_gather_chunk { + async_event communication; ///< async communicator event for receiving this chunk +}; + +/// State for a single incomplete `gather_receive` operation. +struct gather_request { + void* allocation; + size_t chunk_size; ///< in bytes + size_t num_incomplete_chunks; ///< number of chunks that are currently being received or for which we have not seen a pilot yet + std::vector incoming_chunks; ///< chunks that are currently being received + + gather_request(void* const allocation, const size_t chunk_size, const size_t num_total_chunks) + : allocation(allocation), chunk_size(chunk_size), num_incomplete_chunks(num_total_chunks) {} + bool do_complete(); +}; + +// shared_ptrs for pointer stability (referenced by receive_arbiter::event) +using stable_region_request = std::shared_ptr; +using stable_gather_request = std::shared_ptr; + +/// A transfer that is only known through inbound pilots so far, but no `receive` / `begin_split_receive` has been issued so far. 
+struct unassigned_transfer {
+	std::vector<inbound_pilot> pilots;
+	bool do_complete();
+};
+
+/// A (non-gather) transfer that has been mentioned in one or more calls to `receive` / `begin_split_receive`. Note that there may be multiple disjoint
+/// receives mapping to the same `transfer_id` as long as their regions are pairwise disconnected.
+struct multi_region_transfer {
+	size_t elem_size; ///< in bytes
+	std::vector<stable_region_request> active_requests; ///< all `receive`s and `begin_split_receive`s active for this transfer id.
+	std::vector<inbound_pilot> unassigned_pilots; ///< all inbound pilots that do not map to any `active_request`.
+
+	explicit multi_region_transfer(const size_t elem_size) : elem_size(elem_size) {}
+	explicit multi_region_transfer(const size_t elem_size, std::vector<inbound_pilot>&& unassigned_pilots)
+	    : elem_size(elem_size), unassigned_pilots(std::move(unassigned_pilots)) {}
+	bool do_complete();
+};
+
+/// A transfer originating through `gather_receive`. It is fully described by a single `gather_request`.
+struct gather_transfer {
+	stable_gather_request request;
+	bool do_complete();
+};
+
+/// Depending on the order of inputs, transfers may start out as unassigned and will be replaced by either `multi_region_transfer`s or `gather_transfer`s
+/// once explicit calls to the respective receive arbiter functions are made.
+using transfer = std::variant<unassigned_transfer, multi_region_transfer, gather_transfer>;
+
+} // namespace celerity::detail::receive_arbiter_detail
+
+namespace celerity::detail {
+
+/// Matches receive instructions to inbound pilots and triggers in-place payload receives on the communicator.
+///
+/// For scalability reasons, distributed command graph generation only yields exact destinations and buffer sub-ranges for push commands, while await-pushes do
+/// not carry such information - they just denote the full region to be received. Sender nodes later communicate the exact ranges to the receiver at
+/// execution time via pilot messages that are generated alongside the instruction graph.
+///
+/// The receive_arbiter's job is to match these inbound pilots to receive instructions generated from await-push commands to issue in-place receives (i.e.
+/// `MPI_Recv`) of the data into an appropriate host allocation. Since these inputs may arrive in arbitrary order, it maintains a separate state machine for
+/// each `transfer_id` to drive all operations that eventually result in completing an `async_event` for each receive instruction.
+class receive_arbiter {
+  public:
+	/// `receive_arbiter` will use `comm` to poll for inbound pilots and issue payload-receives.
+	explicit receive_arbiter(communicator& comm);
+
+	receive_arbiter(const receive_arbiter&) = delete;
+	receive_arbiter(receive_arbiter&&) = default;
+	receive_arbiter& operator=(const receive_arbiter&) = delete;
+	receive_arbiter& operator=(receive_arbiter&&) = default;
+	~receive_arbiter();
+
+	/// Receive a buffer region associated with a single transfer id `trid` into an existing `allocation` with size `allocated_box.size() * elem_size`. The
+	/// `request` region must be fully contained in `allocated_box`, and the caller must ensure that the communicator will not receive an inbound pilot that
+	/// intersects `request` without being fully contained in it. The returned `async_event` will complete once the receive is complete.
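+	///
+	/// A hedged usage sketch (names are illustrative): the returned event is polled while driving the arbiter, e.g.
+	///   auto evt = arbiter.receive(trid, region, host_ptr, allocated_box, sizeof(float));
+	///   while(!evt.is_complete()) { arbiter.poll_communicator(); }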
+ [[nodiscard]] async_event receive(const transfer_id& trid, const region<3>& request, void* allocation, const box<3>& allocated_box, size_t elem_size); + + /// Begin the reception of a buffer region into an existing allocation similar to `receive`, but do not await its completion with a single `async_event`. + /// Instead, the caller must follow up with calls to `await_split_receive_subregion` to the same `transfer_id` whose request regions do not necessarily have + /// to be disjoint, but whose union must be equal to the original `request`. + void begin_split_receive(const transfer_id& trid, const region<3>& request, void* allocation, const box<3>& allocated_box, size_t elem_size); + + /// To be called after `begin_split_receive` to await receiving a `subregion` of the original request. Subregions passed to different invocations of this + /// function may overlap, but must not exceed the original request. If the entire split-receive has finished already, this will return an instantly complete + /// event. + [[nodiscard]] async_event await_split_receive_subregion(const transfer_id& trid, const region<3>& subregion); + + /// Receive a contiguous chunk of data from every peer node, placing the results in `allocation[node_chunk_size * node_id]`. The location reserved for the + /// local node is not written to and may be concurrently accessed while this operation is in progress. If a peer node announces that it will not contribute + /// to this transfer by sending an empty-box pilot, its location will also remain unmodified. + /// + /// This feature is a temporary solution until we implement inter-node reductions through inter-node collectives. + [[nodiscard]] async_event gather_receive(const transfer_id& trid, void* allocation, size_t node_chunk_size); + + /// Polls the communicator for inbound pilots and advances the state of all ongoing receive operations. This is expected to be called in a loop + /// unconditionally. + void poll_communicator(); + + private: + communicator* m_comm; + size_t m_num_nodes; + + /// State machines for all `transfer_id`s that were mentioned in an inbound pilot or call to one of the receive functions. Once a transfer is complete, it + /// is cleared from `m_transfers`, but `multi_region_transfer`s can be re-created if there later appears another pair of inbound pilots and `receive`s for + /// the same transfer id that did not temporally overlap with the original ones. + std::unordered_map m_transfers; + + /// Initiates a new `region_request` for which the caller can construct events to await either the entire region or sub-regions. + receive_arbiter_detail::stable_region_request& initiate_region_request( + const transfer_id& trid, const region<3>& request, void* allocation, const box<3>& allocated_box, size_t elem_size); + + /// Updates the state of an active `region_request` from receiving an inbound pilot. + void handle_region_request_pilot(receive_arbiter_detail::region_request& rr, const inbound_pilot& pilot, size_t elem_size); + + /// Updates the state of an active `gather_request` from receiving an inbound pilot. + void handle_gather_request_pilot(receive_arbiter_detail::gather_request& gr, const inbound_pilot& pilot); +}; + +} // namespace celerity::detail diff --git a/include/utils.h b/include/utils.h index f62b44d13..1e23805e0 100644 --- a/include/utils.h +++ b/include/utils.h @@ -14,6 +14,12 @@ namespace celerity::detail::utils { +/// Like std::move, but move-constructs the result so it does not reference the argument after returning. 
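+/// For example, `variant.emplace<multi_region_transfer>(elem_size, utils::take(ut.pilots))` (see receive_arbiter.cc) must move the
+/// pilot vector out *before* `emplace` destroys `ut`, which lives inside that same variant; passing `std::move(ut.pilots)` instead
+/// would leave a reference into the destroyed alternative.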
+template +T take(T& from) { + return std::move(from); +} + template bool isa(const P* p) { return dynamic_cast(p) != nullptr; diff --git a/src/mpi_communicator.cc b/src/mpi_communicator.cc new file mode 100644 index 000000000..a68225893 --- /dev/null +++ b/src/mpi_communicator.cc @@ -0,0 +1,282 @@ +#include "mpi_communicator.h" +#include "log.h" +#include "mpi_support.h" +#include "ranges.h" + +#include +#include + +#include + +namespace celerity::detail::mpi_detail { + +/// async_event wrapper around an MPI_Request. +class mpi_event final : public async_event_impl { + public: + explicit mpi_event(MPI_Request req) : m_req(req) {} + + mpi_event(const mpi_event&) = delete; + mpi_event(mpi_event&&) = delete; + mpi_event& operator=(const mpi_event&) = delete; + mpi_event& operator=(mpi_event&&) = delete; + + ~mpi_event() override { + // MPI_Request_free is always incorrect for our use case: events originate from an Isend or Irecv, which must ensure that the user-provided buffer + // remains live until the operation has completed. + MPI_Wait(&m_req, MPI_STATUS_IGNORE); + } + + bool is_complete() const override { + int flag = -1; + MPI_Test(&m_req, &flag, MPI_STATUS_IGNORE); + return flag != 0; + } + + private: + mutable MPI_Request m_req; +}; + +constexpr int pilot_exchange_tag = mpi_support::TAG_COMMUNICATOR; +constexpr int first_message_tag = pilot_exchange_tag + 1; + +constexpr int message_id_to_mpi_tag(message_id msgid) { + // If the resulting tag would overflow INT_MAX in a long-running program with many nodes, we wrap around to `first_message_tag` instead, assuming that + // there will never be a way to cause temporal ambiguity between transfers that are 2^31 message ids apart. + msgid %= static_cast(INT_MAX - first_message_tag); + return first_message_tag + static_cast(msgid); +} + +constexpr int node_id_to_mpi_rank(const node_id nid) { + assert(nid <= static_cast(INT_MAX)); + return static_cast(nid); +} + +constexpr node_id mpi_rank_to_node_id(const int rank) { + assert(rank >= 0); + return static_cast(rank); +} + +/// Strides that only differ e.g. in their dim0 allocation size are equivalent when adjusting the base pointer. This not only improves mpi_communicator type +/// cache efficiency, but is in fact necessary to make sure all boxes that instruction_graph_generator emits for send instructions and inbound pilots +/// are representable in the 32-bit integer world of MPI. +/// @tparam Void Either `void` or `const void`. 
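+/// A worked example (illustrative numbers): allocation_range {1, 64, 64}, transfer = {offset {0, 16, 0}, range {1, 32, 64}},
+/// element_size 4 first collapses to allocation_range {64, 64, 1}, offset {16, 0, 0}, range {32, 64, 1}; the base pointer is then
+/// advanced by 16 * 64 * 1 elements (4096 bytes) so that the offset becomes {0, 0, 0}, and allocation_range[0] is clamped to 32.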
+template <typename Void>
+constexpr std::tuple<Void*, communicator::stride> normalize_strided_pointer(Void* ptr, communicator::stride stride) {
+	using byte_pointer_t = std::conditional_t<std::is_const_v<Void>, const std::byte*, std::byte*>;
+
+	// drop leading buffer dimensions with extent 1, which allows us to do pointer adjustment in d1 / d2
+	while(stride.allocation_range[0] == 1 && stride.allocation_range[1] * stride.allocation_range[2] > 1) {
+		stride.allocation_range[0] = stride.allocation_range[1], stride.allocation_range[1] = stride.allocation_range[2], stride.allocation_range[2] = 1;
+		stride.transfer.range[0] = stride.transfer.range[1], stride.transfer.range[1] = stride.transfer.range[2], stride.transfer.range[2] = 1;
+		stride.transfer.offset[0] = stride.transfer.offset[1], stride.transfer.offset[1] = stride.transfer.offset[2], stride.transfer.offset[2] = 0;
+	}
+
+	// adjust base pointer to remove the offset
+	const auto offset_elements = stride.transfer.offset[0] * stride.allocation_range[1] * stride.allocation_range[2];
+	ptr = static_cast<byte_pointer_t>(ptr) + offset_elements * stride.element_size;
+	stride.transfer.offset[0] = 0;
+
+	// clamp allocation size to subrange (MPI will not access memory beyond subrange.range anyway)
+	stride.allocation_range[0] = stride.transfer.range[0];
+
+	// TODO we can normalize further if we accept arbitrarily large scalar types (via MPI contiguous / struct types):
+	// - collapse fast dimensions if contiguous via `stride.element_size *= stride.subrange.range[d]`
+	// - factorize stride coordinates: `element_size *= gcd(allocation[0], offset[0], range[0], allocation[1], ...)`
+	// Doing all this will complicate instruction_graph_generator_detail::split_into_communicator_compatible_boxes though.
+	return {ptr, stride};
+}
+
+} // namespace celerity::detail::mpi_detail
+
+namespace celerity::detail {
+
+mpi_communicator::mpi_communicator(const collective_clone_from_tag /* tag */, const MPI_Comm mpi_comm) : m_mpi_comm(MPI_COMM_NULL) {
+	assert(mpi_comm != MPI_COMM_NULL);
+#if MPI_VERSION < 3
+	// MPI 2 only has Comm_dup - we assume that the user has not done any obscure things to MPI_COMM_WORLD
+	MPI_Comm_dup(mpi_comm, &m_mpi_comm);
+#else
+	// MPI >= 3.0 provides MPI_Comm_dup_with_info, which allows us to reset all implementation hints on the communicator to our liking
+	MPI_Info info;
+	MPI_Info_create(&info);
+	// See the OpenMPI manpage for MPI_Comm_set_info for keys and values
+	MPI_Info_set(info, "mpi_assert_no_any_tag", "true"); // promise never to use MPI_ANY_TAG (we _do_ use MPI_ANY_SOURCE for pilots)
+	MPI_Info_set(info, "mpi_assert_exact_length", "true"); // promise to exactly match sizes between corresponding MPI_Send and MPI_Recv calls
+	MPI_Info_set(info, "mpi_assert_allow_overtaking", "true"); // we do not care about message ordering since we disambiguate by tag
+	MPI_Comm_dup_with_info(mpi_comm, info, &m_mpi_comm);
+	MPI_Info_free(&info);
+#endif
+}
+
+mpi_communicator::~mpi_communicator() {
+	// All asynchronous sends / receives must have completed at this point - unfortunately we have no easy way of checking this here.
+
+	// Await the completion of all outbound pilot sends. The blocking-wait should usually be unnecessary because completion of payload-sends should imply
+	// completion of the outbound-pilot sends, although there is no real guarantee of this given MPI's freedom to buffer transfers however it likes.
+	// MPI_Wait will also free the async request, so we use this function unconditionally.
+	for(auto& outbound : m_outbound_pilots) {
+		MPI_Wait(&outbound.request, MPI_STATUS_IGNORE);
+	}
+
+	// We always re-start the pilot Irecv immediately, so we need to MPI_Cancel the last such request (and then free it using MPI_Wait).
+	if(m_inbound_pilot.request != MPI_REQUEST_NULL) {
+		MPI_Cancel(&m_inbound_pilot.request);
+		MPI_Wait(&m_inbound_pilot.request, MPI_STATUS_IGNORE);
+	}
+
+	// MPI_Comm_free is itself a collective, but since this call happens from a destructor we implicitly guarantee that it can't be re-ordered against any
+	// other collective operation on this communicator.
+	MPI_Comm_free(&m_mpi_comm);
+}
+
+size_t mpi_communicator::get_num_nodes() const {
+	int size = -1;
+	MPI_Comm_size(m_mpi_comm, &size);
+	assert(size > 0);
+	return static_cast<size_t>(size);
+}
+
+node_id mpi_communicator::get_local_node_id() const {
+	int rank = -1;
+	MPI_Comm_rank(m_mpi_comm, &rank);
+	return mpi_detail::mpi_rank_to_node_id(rank);
+}
+
+void mpi_communicator::send_outbound_pilot(const outbound_pilot& pilot) {
+	CELERITY_DEBUG("[mpi] pilot -> N{} (MSG{}, {}, {})", pilot.to, pilot.message.id, pilot.message.transfer_id, pilot.message.box);
+
+	assert(pilot.to < get_num_nodes());
+	assert(pilot.to != get_local_node_id());
+
+	// Initiate Isend as early as possible to hide latency.
+	in_flight_pilot newly_in_flight;
+	newly_in_flight.message = std::make_unique<pilot_message>(pilot.message);
+	MPI_Isend(newly_in_flight.message.get(), sizeof *newly_in_flight.message, MPI_BYTE, mpi_detail::node_id_to_mpi_rank(pilot.to),
+	    mpi_detail::pilot_exchange_tag, m_mpi_comm, &newly_in_flight.request);
+
+	// Collect finished sends (TODO consider rate-limiting this to avoid quadratic behavior)
+	constexpr auto pilot_send_finished = [](in_flight_pilot& already_in_flight) {
+		int flag = -1;
+		MPI_Test(&already_in_flight.request, &flag, MPI_STATUS_IGNORE);
+		return already_in_flight.request == MPI_REQUEST_NULL;
+	};
+	m_outbound_pilots.erase(std::remove_if(m_outbound_pilots.begin(), m_outbound_pilots.end(), pilot_send_finished), m_outbound_pilots.end());
+
+	// Keep allocation until Isend has completed
+	m_outbound_pilots.push_back(std::move(newly_in_flight));
+}
+
+std::vector<inbound_pilot> mpi_communicator::poll_inbound_pilots() {
+	// Irecv needs to be called initially, and after receiving each pilot to enqueue the next operation.
+ const auto begin_receiving_next_pilot = [this] { + assert(m_inbound_pilot.message != nullptr); + assert(m_inbound_pilot.request == MPI_REQUEST_NULL); + MPI_Irecv(m_inbound_pilot.message.get(), sizeof *m_inbound_pilot.message, MPI_BYTE, MPI_ANY_SOURCE, mpi_detail::pilot_exchange_tag, m_mpi_comm, + &m_inbound_pilot.request); + }; + + if(m_inbound_pilot.request == MPI_REQUEST_NULL) { + // This is the first call to poll_inbound_pilots, spin up the pilot-receiving machinery - we don't do this unconditionally in the constructor + // because communicators for collective groups do not deal with pilots + m_inbound_pilot.message = std::make_unique(); + begin_receiving_next_pilot(); + } + + // MPI might have received and buffered multiple inbound pilots, collect all of them in a loop + std::vector received_pilots; + for(;;) { + int flag = -1; + MPI_Status status; + MPI_Test(&m_inbound_pilot.request, &flag, &status); + if(flag == 0 /* incomplete */) { + return received_pilots; // no more pilots in queue, we're done collecting + } + + const inbound_pilot pilot{mpi_detail::mpi_rank_to_node_id(status.MPI_SOURCE), *m_inbound_pilot.message}; + begin_receiving_next_pilot(); // initiate the next receive asap + + CELERITY_DEBUG("[mpi] pilot <- N{} (MSG{}, {} {})", pilot.from, pilot.message.id, pilot.message.transfer_id, pilot.message.box); + received_pilots.push_back(pilot); + } +} + +async_event mpi_communicator::send_payload(const node_id to, const message_id msgid, const void* const base, const stride& stride) { + CELERITY_DEBUG("[mpi] payload -> N{} (MSG{}) from {} ({}) {}x{}", to, msgid, base, stride.allocation_range, stride.transfer, stride.element_size); + + assert(to < get_num_nodes()); + assert(to != get_local_node_id()); + + MPI_Request req = MPI_REQUEST_NULL; + const auto [adjusted_base, normalized_stride] = mpi_detail::normalize_strided_pointer(base, stride); + MPI_Isend( + adjusted_base, 1, get_array_type(normalized_stride), mpi_detail::node_id_to_mpi_rank(to), mpi_detail::message_id_to_mpi_tag(msgid), m_mpi_comm, &req); + return make_async_event(req); +} + +async_event mpi_communicator::receive_payload(const node_id from, const message_id msgid, void* const base, const stride& stride) { + CELERITY_DEBUG("[mpi] payload <- N{} (MSG{}) into {} ({}) {}x{}", from, msgid, base, stride.allocation_range, stride.transfer, stride.element_size); + + assert(from < get_num_nodes()); + assert(from != get_local_node_id()); + + MPI_Request req = MPI_REQUEST_NULL; + const auto [adjusted_base, normalized_stride] = mpi_detail::normalize_strided_pointer(base, stride); + MPI_Irecv( + adjusted_base, 1, get_array_type(normalized_stride), mpi_detail::node_id_to_mpi_rank(from), mpi_detail::message_id_to_mpi_tag(msgid), m_mpi_comm, &req); + return make_async_event(req); +} + +std::unique_ptr mpi_communicator::collective_clone() { return std::make_unique(collective_clone_from, m_mpi_comm); } + +void mpi_communicator::collective_barrier() { MPI_Barrier(m_mpi_comm); } + +MPI_Datatype mpi_communicator::get_scalar_type(const size_t bytes) { + if(const auto it = m_scalar_type_cache.find(bytes); it != m_scalar_type_cache.end()) { return it->second.get(); } + + assert(bytes > 0); + assert(bytes <= static_cast(INT_MAX)); + MPI_Datatype type = MPI_DATATYPE_NULL; + MPI_Type_contiguous(static_cast(bytes), MPI_BYTE, &type); + MPI_Type_commit(&type); + m_scalar_type_cache.emplace(bytes, unique_datatype(type)); + return type; +} + +MPI_Datatype mpi_communicator::get_array_type(const stride& stride) { + if(const auto it = 
m_array_type_cache.find(stride); it != m_array_type_cache.end()) { return it->second.get(); } + + const int dims = detail::get_effective_dims(stride.allocation_range); + assert(detail::get_effective_dims(stride.transfer) <= dims); + + // MPI (understandably) does not recognize a 0-dimensional subarray as a scalar + if(dims == 0) { return get_scalar_type(stride.element_size); } + + // TODO - can we get runaway behavior from constructing too many MPI data types, especially with Spectrum MPI? + // TODO - eagerly create MPI types ahead-of-time whenever we send or receive a pilot to reduce latency? + + int size_array[3]; + int subsize_array[3]; + int start_array[3]; + for(int d = 0; d < 3; ++d) { + // The instruction graph generator must only ever emit transfers which can be described with a signed-int stride + assert(stride.allocation_range[d] <= static_cast(INT_MAX)); + assert(stride.transfer.range[d] <= static_cast(INT_MAX)); + assert(stride.transfer.offset[d] <= static_cast(INT_MAX)); + size_array[d] = static_cast(stride.allocation_range[d]); + subsize_array[d] = static_cast(stride.transfer.range[d]); + start_array[d] = static_cast(stride.transfer.offset[d]); + } + + MPI_Datatype type = MPI_DATATYPE_NULL; + MPI_Type_create_subarray(dims, size_array, subsize_array, start_array, MPI_ORDER_C, get_scalar_type(stride.element_size), &type); + MPI_Type_commit(&type); + + m_array_type_cache.emplace(stride, unique_datatype(type)); + return type; +} + +void mpi_communicator::datatype_deleter::operator()(MPI_Datatype dtype) const { // + MPI_Type_free(&dtype); +} + +} // namespace celerity::detail diff --git a/src/receive_arbiter.cc b/src/receive_arbiter.cc new file mode 100644 index 000000000..2d662f199 --- /dev/null +++ b/src/receive_arbiter.cc @@ -0,0 +1,257 @@ +#include "receive_arbiter.h" +#include "grid.h" + +#include + +#include +#include + +namespace celerity::detail::receive_arbiter_detail { + +// weak-pointers for referencing stable_region/gather_requests that are held by the receive_arbiter. If they expire, we know the event is complete. +using weak_region_request = std::weak_ptr; +using weak_gather_request = std::weak_ptr; + +/// Event for `receive_arbiter::receive`, which immediately awaits the entire receive-region. +class region_receive_event final : public async_event_impl { + public: + explicit region_receive_event(const stable_region_request& rr) : m_request(rr) {} + + bool is_complete() const override { return m_request.expired(); } + + private: + weak_region_request m_request; +}; + +/// Event for `receive_arbiter::await_split_receive_subregion`, which awaits a specific subregion of a split receive. +class subregion_receive_event final : public async_event_impl { + public: + explicit subregion_receive_event(const stable_region_request& rr, const region<3>& awaited_subregion) + : m_request(rr), m_awaited_region(awaited_subregion) {} + + bool is_complete() const override { + const auto rr = m_request.lock(); + return rr == nullptr || region_intersection(rr->incomplete_region, m_awaited_region).empty(); + } + + private: + weak_region_request m_request; + region<3> m_awaited_region; +}; + +/// Event for `receive_arbiter::gather_receive`, which waits for incoming messages (or empty-box pilots) from all peers. 
+class gather_receive_event final : public async_event_impl { + public: + explicit gather_receive_event(const stable_gather_request& gr) : m_request(gr) {} + + bool is_complete() const override { return m_request.expired(); } + + private: + weak_gather_request m_request; +}; + +bool region_request::do_complete() { + const auto complete_fragment = [&](const incoming_region_fragment& fragment) { + if(!fragment.communication.is_complete()) return false; + incomplete_region = region_difference(incomplete_region, fragment.box); + return true; + }; + incoming_fragments.erase(std::remove_if(incoming_fragments.begin(), incoming_fragments.end(), complete_fragment), incoming_fragments.end()); + assert(!incomplete_region.empty() || incoming_fragments.empty()); + return incomplete_region.empty(); +} + +bool multi_region_transfer::do_complete() { + const auto complete_request = [](stable_region_request& rr) { return rr->do_complete(); }; + active_requests.erase(std::remove_if(active_requests.begin(), active_requests.end(), complete_request), active_requests.end()); + return active_requests.empty() && unassigned_pilots.empty(); +} + +bool gather_request::do_complete() { + const auto complete_chunk = [&](const incoming_gather_chunk& chunk) { + if(!chunk.communication.is_complete()) return false; + assert(num_incomplete_chunks > 0); + num_incomplete_chunks -= 1; + return true; + }; + incoming_chunks.erase(std::remove_if(incoming_chunks.begin(), incoming_chunks.end(), complete_chunk), incoming_chunks.end()); + return num_incomplete_chunks == 0; +} + +bool gather_transfer::do_complete() { return request->do_complete(); } + +bool unassigned_transfer::do_complete() { // NOLINT(readability-make-member-function-const) + // an unassigned_transfer inside receive_arbiter::m_transfers is never empty. 
+	assert(!pilots.empty());
+	return false;
+}
+
+} // namespace celerity::detail::receive_arbiter_detail
+
+namespace celerity::detail {
+
+using namespace receive_arbiter_detail;
+
+receive_arbiter::receive_arbiter(communicator& comm) : m_comm(&comm), m_num_nodes(comm.get_num_nodes()) { assert(m_num_nodes > 0); }
+
+receive_arbiter::~receive_arbiter() { assert(std::uncaught_exceptions() > 0 || m_transfers.empty()); }
+
+receive_arbiter_detail::stable_region_request& receive_arbiter::initiate_region_request(
+    const transfer_id& trid, const region<3>& request, void* const allocation, const box<3>& allocated_box, const size_t elem_size) {
+	assert(allocated_box.covers(bounding_box(request)));
+
+	// Ensure there is a multi_region_transfer present - if there is none, create it by consuming unassigned pilots
+	multi_region_transfer* mrt = nullptr;
+	if(const auto entry = m_transfers.find(trid); entry != m_transfers.end()) {
+		matchbox::match(
+		    entry->second, //
+		    [&](unassigned_transfer& ut) { mrt = &entry->second.emplace<multi_region_transfer>(elem_size, utils::take(ut.pilots)); },
+		    [&](multi_region_transfer& existing_mrt) { mrt = &existing_mrt; },
+		    [&](gather_transfer& gt) { utils::panic("calling receive_arbiter::begin_receive on an active gather transfer"); });
+	} else {
+		mrt = &m_transfers[trid].emplace<multi_region_transfer>(elem_size);
+	}
+
+	// Add a new region_request to the `mrt` (transfers have transfer_id granularity, but there might be multiple receives from independent range mappers)
+	assert(std::all_of(mrt->active_requests.begin(), mrt->active_requests.end(),
+	    [&](const stable_region_request& rr) { return region_intersection(rr->incomplete_region, request).empty(); }));
+	auto& rr = mrt->active_requests.emplace_back(std::make_shared<region_request>(request, allocation, allocated_box));
+
+	// If the new region_request matches any of the still-unassigned pilots associated with `mrt`, immediately initiate the appropriate payload-receives
+	const auto assign_pilot = [&](const inbound_pilot& pilot) {
+		assert((region_intersection(rr->incomplete_region, pilot.message.box) != pilot.message.box)
+		       == region_intersection(rr->incomplete_region, pilot.message.box).empty());
+		if(region_intersection(rr->incomplete_region, pilot.message.box) == pilot.message.box) {
+			handle_region_request_pilot(*rr, pilot, elem_size);
+			return true;
+		}
+		return false;
+	};
+	mrt->unassigned_pilots.erase(std::remove_if(mrt->unassigned_pilots.begin(), mrt->unassigned_pilots.end(), assign_pilot), mrt->unassigned_pilots.end());
+
+	return rr;
+}
+
+void receive_arbiter::begin_split_receive(
+    const transfer_id& trid, const region<3>& request, void* const allocation, const box<3>& allocated_box, const size_t elem_size) {
+	initiate_region_request(trid, request, allocation, allocated_box, elem_size);
+}
+
+async_event receive_arbiter::await_split_receive_subregion(const transfer_id& trid, const region<3>& subregion) {
+	// If there is no known associated `transfer`, we must have erased it previously due to the entire `begin_split_receive` being completed. Any (partial)
+	// await thus immediately completes as well.
+	const auto transfer_it = m_transfers.find(trid);
+	if(transfer_it == m_transfers.end()) { return make_complete_event(); }
+
+	auto& mrt = std::get<multi_region_transfer>(transfer_it->second);
+
+#ifndef NDEBUG
+	// all boxes from the awaited region must be contained in a single allocation
+	const auto awaited_bounds = bounding_box(subregion);
+	assert(std::all_of(mrt.active_requests.begin(), mrt.active_requests.end(), [&](const stable_region_request& rr) {
+		const auto overlap = box_intersection(rr->allocated_box, awaited_bounds);
+		return overlap.empty() || overlap == awaited_bounds;
+	}));
+#endif
+
+	// If the transfer (by transfer_id) as a whole has not completed yet but the subregion has, this "await" also completes immediately.
+	const auto req_it = std::find_if(mrt.active_requests.begin(), mrt.active_requests.end(),
+	    [&](const stable_region_request& rr) { return !region_intersection(rr->incomplete_region, subregion).empty(); });
+	if(req_it == mrt.active_requests.end()) { return make_complete_event(); }
+
+	return make_async_event<subregion_receive_event>(*req_it, subregion);
+}
+
+async_event receive_arbiter::receive(
+    const transfer_id& trid, const region<3>& request, void* const allocation, const box<3>& allocated_box, const size_t elem_size) {
+	return make_async_event<region_receive_event>(initiate_region_request(trid, request, allocation, allocated_box, elem_size));
+}
+
+async_event receive_arbiter::gather_receive(const transfer_id& trid, void* const allocation, const size_t node_chunk_size) {
+	auto gr = std::make_shared<gather_request>(allocation, node_chunk_size, m_num_nodes - 1 /* number of peers */);
+
+	if(const auto entry = m_transfers.find(trid); entry != m_transfers.end()) {
+		// If we are already tracking a transfer `trid`, it must be unassigned, and we can initiate payload-receives for all unassigned pilots right away.
+		auto& ut = std::get<unassigned_transfer>(entry->second);
+		for(auto& pilot : ut.pilots) {
+			handle_gather_request_pilot(*gr, pilot);
+		}
+		entry->second = gather_transfer{gr};
+	} else {
+		// Otherwise, we insert the transfer as pending and wait for the first pilots to arrive.
+		m_transfers.emplace(trid, gather_transfer{gr});
+	}
+
+	return make_async_event<gather_receive_event>(gr);
+}
+
+void receive_arbiter::poll_communicator() {
+	// Try completing all pending payload sends / receives by polling their communicator events
+	for(auto entry = m_transfers.begin(); entry != m_transfers.end();) {
+		if(std::visit([](auto& transfer) { return transfer.do_complete(); }, entry->second)) {
+			entry = m_transfers.erase(entry);
+		} else {
+			++entry;
+		}
+	}
+
+	for(const auto& pilot : m_comm->poll_inbound_pilots()) {
+		if(const auto entry = m_transfers.find(pilot.message.transfer_id); entry != m_transfers.end()) {
+			// If we already know the transfer id, initiate pending payload-receives or add the pilot to the unassigned-list.
+ matchbox::match( + entry->second, // + [&](unassigned_transfer& ut) { // + ut.pilots.push_back(pilot); + }, + [&](multi_region_transfer& mrt) { + // find the unique region-request this pilot belongs to + const auto rr = std::find_if(mrt.active_requests.begin(), mrt.active_requests.end(), [&](const stable_region_request& rr) { + assert((region_intersection(rr->incomplete_region, pilot.message.box) != pilot.message.box) + == region_intersection(rr->incomplete_region, pilot.message.box).empty()); + return region_intersection(rr->incomplete_region, pilot.message.box) == pilot.message.box; + }); + if(rr != mrt.active_requests.end()) { + handle_region_request_pilot(**rr, pilot, mrt.elem_size); + } else { + mrt.unassigned_pilots.push_back(pilot); + } + }, + [&](gather_transfer& gt) { // + handle_gather_request_pilot(*gt.request, pilot); + }); + } else { + // If we haven't seen the transfer id before, create a new unassigned_transfer for it. + m_transfers.emplace(pilot.message.transfer_id, unassigned_transfer{{pilot}}); + } + } +} + +void receive_arbiter::handle_region_request_pilot(region_request& rr, const inbound_pilot& pilot, const size_t elem_size) { + assert(region_intersection(rr.incomplete_region, pilot.message.box) == pilot.message.box); + assert(rr.allocated_box.covers(pilot.message.box)); + + // Initiate a strided payload-receive directly into the allocation passed to receive() / begin_split_receive() + const auto offset_in_allocation = pilot.message.box.get_offset() - rr.allocated_box.get_offset(); + const communicator::stride stride{ + rr.allocated_box.get_range(), + subrange<3>{offset_in_allocation, pilot.message.box.get_range()}, + elem_size, + }; + auto event = m_comm->receive_payload(pilot.from, pilot.message.id, rr.allocation, stride); + rr.incoming_fragments.push_back({pilot.message.box, std::move(event)}); +} + +void receive_arbiter::handle_gather_request_pilot(gather_request& gr, const inbound_pilot& pilot) { + if(pilot.message.box.empty()) { + // Peers will send a pilot with an empty box to signal that they don't contribute to a reduction + assert(gr.num_incomplete_chunks > 0); + gr.num_incomplete_chunks -= 1; + } else { + // Initiate a region-receive with a simple stride to address the chunk id in the allocation + const communicator::stride stride{range_cast<3>(range(m_num_nodes)), subrange(id_cast<3>(id(pilot.from)), range_cast<3>(range(1))), gr.chunk_size}; + auto event = m_comm->receive_payload(pilot.from, pilot.message.id, gr.allocation, stride); + gr.incoming_chunks.push_back(incoming_gather_chunk{std::move(event)}); + } +} + +} // namespace celerity::detail diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 73d45482c..8151de9b9 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -44,6 +44,7 @@ set(TEST_TARGETS print_graph_tests region_map_tests range_tests + receive_arbiter_tests runtime_tests runtime_deprecation_tests sycl_tests diff --git a/test/receive_arbiter_tests.cc b/test/receive_arbiter_tests.cc new file mode 100644 index 000000000..c54ee4375 --- /dev/null +++ b/test/receive_arbiter_tests.cc @@ -0,0 +1,539 @@ +#include "buffer_storage.h" // for memcpy_strided_host +#include "receive_arbiter.h" +#include "test_utils.h" + +#include +#include + +#include +#include +#include + +using namespace celerity; +using namespace celerity::detail; + +/// A mock communicator implementation that allows tests to manually push inbound pilots and incoming receive payloads that the receive_arbiter is waiting for. 
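+///
+/// Typical test flow (sketch): the test calls `push_inbound_pilot()` so the next `poll_inbound_pilots()` returns that pilot, and
+/// later calls `complete_receiving_payload()` to copy a fragment into the destination of the matching `receive_payload()` call and
+/// mark its `async_event` as complete.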
+class mock_recv_communicator : public communicator { + public: + /// `num_nodes` and `local_node_id` simply are the values reported by the respective getters. + explicit mock_recv_communicator(const size_t num_nodes, const node_id local_node_id) : m_num_nodes(num_nodes), m_local_nid(local_node_id) {} + mock_recv_communicator(const mock_recv_communicator&) = delete; + mock_recv_communicator(mock_recv_communicator&&) = delete; + mock_recv_communicator& operator=(const mock_recv_communicator&) = delete; + mock_recv_communicator& operator=(mock_recv_communicator&&) = delete; + ~mock_recv_communicator() override { CHECK(m_pending_recvs.empty()); } + + size_t get_num_nodes() const override { return m_num_nodes; } + node_id get_local_node_id() const override { return m_local_nid; } + + void send_outbound_pilot(const outbound_pilot& /* pilot */) override { + utils::panic("unimplemented"); // receive_arbiter does not send stuff + } + + [[nodiscard]] std::vector poll_inbound_pilots() override { return std::move(m_inbound_pilots); } + + [[nodiscard]] async_event send_payload( + const node_id /* to */, const message_id /* outbound_pilot_tag */, const void* const /* base */, const stride& /* stride */) override { + utils::panic("unimplemented"); // receive_arbiter does not send stuff + } + + [[nodiscard]] async_event receive_payload(const node_id from, const message_id msgid, void* const base, const stride& stride) override { + const auto key = std::pair(from, msgid); + REQUIRE(m_pending_recvs.count(key) == 0); + completion_flag flag = std::make_shared(false); + m_pending_recvs.emplace(key, std::tuple(base, stride, flag)); + return make_async_event(flag); + } + + void push_inbound_pilot(const inbound_pilot& pilot) { m_inbound_pilots.push_back(pilot); } + + void complete_receiving_payload(const node_id from, const message_id msgid, const void* const src, const range<3>& src_range) { + const auto key = std::pair(from, msgid); + const auto [dest, stride, flag] = m_pending_recvs.at(key); + REQUIRE(src_range == stride.transfer.range); + memcpy_strided_host(src, dest, stride.element_size, src_range, zeros, stride.allocation_range, stride.transfer.offset, stride.transfer.range); + *flag = true; + m_pending_recvs.erase(key); + } + + std::unique_ptr collective_clone() override { utils::panic("unimplemented"); } + void collective_barrier() override { utils::panic("unimplemented"); } + + private: + using completion_flag = std::shared_ptr; + + class mock_event final : public async_event_impl { + public: + explicit mock_event(const completion_flag& flag) : m_flag(flag) {} + bool is_complete() const override { return *m_flag; } + + private: + completion_flag m_flag; + }; + + size_t m_num_nodes; + node_id m_local_nid; + std::vector m_inbound_pilots; + std::unordered_map, std::tuple, utils::pair_hash> m_pending_recvs; +}; + +/// Instructs the test loop to perform a specific operation on receive_arbiter or mock_recv_communicator in order to execute tests in all possible event orders +/// with the help of `enumerate_all_event_orders`. 
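+/// For two fragments and one requested region, one legal order is, for example:
+///   call_to_receive[0], incoming_pilot[0], incoming_pilot[1], incoming_data[1], incoming_data[0].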
+struct receive_event { + enum { + call_to_receive, ///< `receive_arbiter::receive` or `begin_split_receive` is called + incoming_pilot, ///< `communicator::poll_inbound_pilots` returns a pilot with matching transfer id + incoming_data, ///< The async_event from `communicator::receive_payload` completes (after call_to_receive and incoming_pilot) + } transition; + + /// For call_to_receive, the index in requested_regions; for incoming_pilot/incoming_data, the index in incoming_fragments + size_t which; + + friend bool operator==(const receive_event& lhs, const receive_event& rhs) { return lhs.transition == rhs.transition && lhs.which == rhs.which; } + friend bool operator!=(const receive_event& lhs, const receive_event& rhs) { return !(lhs == rhs); } +}; + +template <> +struct Catch::StringMaker { + static std::string convert(const receive_event& event) { + switch(event.transition) { + case receive_event::call_to_receive: return fmt::format("call_to_receive[{}]", event.which); + case receive_event::incoming_pilot: return fmt::format("incoming_pilot[{}]", event.which); + case receive_event::incoming_data: return fmt::format("incoming_data[{}]", event.which); + default: abort(); + } + } +}; + +/// Enumerates all O(N!) possible `receive_event` orders that would complete the `requested_regions` with the `incoming_fragments`. +std::vector> enumerate_all_event_orders( + const std::vector>& requested_regions, const std::vector>& incoming_fragments) // +{ + constexpr static auto permutation_order = [](const receive_event& lhs, const receive_event& rhs) { + if(lhs.transition < rhs.transition) return true; + if(lhs.transition > rhs.transition) return false; + return lhs.which < rhs.which; + }; + + // construct the first permutation according to permutation_order + std::vector current_permutation; + for(size_t region_id = 0; region_id < requested_regions.size(); ++region_id) { + current_permutation.push_back({receive_event::call_to_receive, region_id}); + } + for(size_t fragment_id = 0; fragment_id < incoming_fragments.size(); ++fragment_id) { + current_permutation.push_back({receive_event::incoming_pilot, fragment_id}); + } + for(size_t fragment_id = 0; fragment_id < incoming_fragments.size(); ++fragment_id) { + current_permutation.push_back({receive_event::incoming_data, fragment_id}); + } + + // helper: get the index within current_permutation + const auto index_of = [&](const receive_event& event) { + for(size_t i = 0; i < current_permutation.size(); ++i) { + if(current_permutation[i].transition == event.transition && current_permutation[i].which == event.which) return i; + } + abort(); + }; + + // collect all legal permutations (i.e. 
pilots are received before data, and calls to receive() also happen before receiving data) + std::vector> transition_orders; + for(;;) { + bool is_valid_order = true; + for(size_t fragment_id = 0; fragment_id < incoming_fragments.size(); ++fragment_id) { + is_valid_order &= index_of({receive_event::incoming_pilot, fragment_id}) < index_of({receive_event::incoming_data, fragment_id}); + } + for(size_t region_id = 0; region_id < requested_regions.size(); ++region_id) { + const auto receive_called_at = index_of({receive_event::call_to_receive, region_id}); + for(size_t i = 0; i < receive_called_at; ++i) { + if(current_permutation[i].transition == receive_event::incoming_data) { + is_valid_order &= region_intersection(incoming_fragments[current_permutation[i].which], requested_regions[region_id]).empty(); + } + } + } + if(is_valid_order) { transition_orders.push_back(current_permutation); } + + if(!std::next_permutation(current_permutation.begin(), current_permutation.end(), permutation_order)) { + // we wrapped around to the first permutation according to permutation_order + return transition_orders; + } + } +} + +TEST_CASE("receive_arbiter aggregates receives from multiple incoming fragments", "[receive_arbiter]") { + static const transfer_id trid(task_id(1), buffer_id(420), no_reduction_id); + static const box<3> alloc_box = {{2, 1, 0}, {39, 10, 10}}; + static const std::vector> incoming_fragments{ + box<3>({4, 2, 1}, {22, 9, 4}), + box<3>({4, 2, 4}, {22, 9, 8}), + box<3>({22, 2, 1}, {37, 9, 4}), + }; + static const std::vector> requested_regions{ + region(box_vector<3>(incoming_fragments.begin(), incoming_fragments.end())), + }; + static const size_t elem_size = sizeof(int); + + const auto& event_order = GENERATE(from_range(enumerate_all_event_orders(requested_regions, incoming_fragments))); + CAPTURE(event_order); + + const auto receive_method = GENERATE(values({"single", "split_await"})); + CAPTURE(receive_method); + + mock_recv_communicator comm(4, 0); + receive_arbiter ra(comm); + + std::vector allocation(alloc_box.get_range().size()); + std::optional receive; + + for(const auto& [transition, which] : event_order) { + const node_id peer = 1 + which; + const message_id msgid = 10 + which; + CAPTURE(transition, which, peer, msgid); + + // only the last event (always an incoming_data transition) will complete the receive + if(receive.has_value()) { CHECK(!receive->is_complete()); } + + switch(transition) { + case receive_event::call_to_receive: { + CHECK_FALSE(receive.has_value()); + if(receive_method == "single") { + receive = ra.receive(trid, requested_regions[0], allocation.data(), alloc_box, elem_size); + } else if(receive_method == "split_await") { + ra.begin_split_receive(trid, requested_regions[0], allocation.data(), alloc_box, elem_size); + receive = ra.await_split_receive_subregion(trid, requested_regions[0]); + } + break; + } + case receive_event::incoming_pilot: { + comm.push_inbound_pilot(inbound_pilot{peer, pilot_message{msgid, trid, incoming_fragments[which]}}); + break; + } + case receive_event::incoming_data: { + std::vector fragment(incoming_fragments[which].get_range().size(), static_cast(peer)); + comm.complete_receiving_payload(peer, msgid, fragment.data(), incoming_fragments[which].get_range()); + break; + } + } + ra.poll_communicator(); + } + + REQUIRE(receive.has_value()); + CHECK(receive->is_complete()); + + // it is legal to `await` a transfer that has already been completed and is not tracked by the receive_arbiter anymore + 
CHECK(ra.await_split_receive_subregion(trid, requested_regions[0]).is_complete()); + CHECK(ra.await_split_receive_subregion(trid, incoming_fragments[0]).is_complete()); + + std::vector expected_allocation(alloc_box.get_range().size()); + for(size_t which = 0; which < incoming_fragments.size(); ++which) { + const auto& box = incoming_fragments[which]; + const node_id peer = 1 + which; + test_utils::for_each_in_range(box.get_range(), box.get_offset() - alloc_box.get_offset(), [&](const id<3>& id_in_allocation) { + const auto linear_index = get_linear_index(alloc_box.get_range(), id_in_allocation); + expected_allocation[linear_index] = static_cast(peer); + }); + } + CHECK(allocation == expected_allocation); +} + +TEST_CASE("receive_arbiter can complete await-receives through differently-shaped overlapping fragments", "[receive_arbiter]") { + static const transfer_id trid(task_id(1), buffer_id(420), no_reduction_id); + static const box<3> alloc_box = {{2, 1, 0}, {19, 20, 1}}; + static const std::vector> requested_regions{ + box<3>{{4, 1, 0}, {19, 18, 1}}, + }; + static const std::vector> awaited_regions{ + region<3>{{{{4, 1, 0}, {14, 10, 1}}, {{14, 1, 0}, {19, 18, 1}}}}, + box<3>{{4, 10, 0}, {14, 18, 1}}, + }; + static const std::vector> incoming_fragments{ + box<3>{{4, 1, 0}, {14, 18, 1}}, + box<3>{{14, 1, 0}, {19, 18, 1}}, + }; + static const size_t elem_size = sizeof(int); + + const auto& event_order = GENERATE(from_range(enumerate_all_event_orders(requested_regions, incoming_fragments))); + CAPTURE(event_order); + + mock_recv_communicator comm(2, 0); + receive_arbiter ra(comm); + + const node_id peer = 1; + + std::vector allocation(alloc_box.get_range().size()); + std::optional awaits[2]; + region<3> region_received; + + for(const auto& [transition, which] : event_order) { + const message_id msgid = 10 + which; + CAPTURE(transition, which, peer, msgid); + + // check that fragments[0] completes awaits[1] + for(size_t await_id = 0; await_id < 2; ++await_id) { + if(!awaits[await_id].has_value()) continue; + CHECK(awaits[await_id]->is_complete() == (region_intersection(region_received, awaited_regions[await_id]) == awaited_regions[await_id])); + } + + switch(transition) { + case receive_event::call_to_receive: { + ra.begin_split_receive(trid, requested_regions[0], allocation.data(), alloc_box, elem_size); + awaits[0] = ra.await_split_receive_subregion(trid, awaited_regions[0]); + awaits[1] = ra.await_split_receive_subregion(trid, awaited_regions[1]); + break; + } + case receive_event::incoming_pilot: { + comm.push_inbound_pilot(inbound_pilot{peer, pilot_message{msgid, trid, incoming_fragments[which]}}); + break; + } + case receive_event::incoming_data: { + std::vector fragment(incoming_fragments[which].get_range().size(), static_cast(1 + which)); + comm.complete_receiving_payload(peer, msgid, fragment.data(), incoming_fragments[which].get_range()); + region_received = region_union(region_received, incoming_fragments[which]); + break; + } + } + ra.poll_communicator(); + } + + REQUIRE(awaits[0].has_value()); + REQUIRE(awaits[0]->is_complete()); + REQUIRE(awaits[1].has_value()); + REQUIRE(awaits[1]->is_complete()); + + // it is legal to `await` a transfer that has already been completed and is not tracked by the receive_arbiter anymore + CHECK(ra.await_split_receive_subregion(trid, requested_regions[0]).is_complete()); + CHECK(ra.await_split_receive_subregion(trid, incoming_fragments[0]).is_complete()); + + std::vector expected_allocation(alloc_box.get_range().size()); + for(size_t which = 0; 
which < incoming_fragments.size(); ++which) { + const auto& box = incoming_fragments[which]; + test_utils::for_each_in_range(box.get_range(), box.get_offset() - alloc_box.get_offset(), [&](const id<3>& id_in_allocation) { + const auto linear_index = get_linear_index(alloc_box.get_range(), id_in_allocation); + expected_allocation[linear_index] = static_cast(1 + which); + }); + } + CHECK(allocation == expected_allocation); +} + +TEST_CASE("receive_arbiter immediately completes await-receives for which all corresponding fragments have already been received", "[receive_arbiter]") { + static const transfer_id trid(task_id(1), buffer_id(420), no_reduction_id); + static const box<3> alloc_box = {{2, 1, 0}, {19, 20, 1}}; + static const std::vector> requested_regions{ + box<3>{{4, 1, 0}, {19, 18, 1}}, + }; + static const std::vector> awaited_regions{ + region<3>{{{{4, 1, 0}, {14, 10, 1}}, {{14, 1, 0}, {19, 18, 1}}}}, + box<3>{{4, 10, 0}, {14, 18, 1}}, + }; + static const std::vector> incoming_fragments{ + box<3>{{4, 1, 0}, {14, 18, 1}}, + box<3>{{14, 1, 0}, {19, 18, 1}}, + }; + static const size_t elem_size = sizeof(int); + + mock_recv_communicator comm(2, 0); + receive_arbiter ra(comm); + + const node_id peer = 1; + + std::vector allocation(alloc_box.get_range().size()); + region<3> region_received; + + const auto receive_fragment = [&](const size_t which) { + const message_id msgid = 10 + which; + comm.push_inbound_pilot(inbound_pilot{peer, pilot_message{msgid, trid, incoming_fragments[which]}}); + ra.poll_communicator(); + std::vector fragment(incoming_fragments[which].get_range().size(), static_cast(1 + which)); + comm.complete_receiving_payload(peer, msgid, fragment.data(), incoming_fragments[which].get_range()); + ra.poll_communicator(); + region_received = region_union(region_received, incoming_fragments[which]); + }; + + ra.begin_split_receive(trid, requested_regions[0], allocation.data(), alloc_box, elem_size); + receive_fragment(0); + + auto await0 = ra.await_split_receive_subregion(trid, awaited_regions[0]); + CHECK_FALSE(await0.is_complete()); + auto await1 = ra.await_split_receive_subregion(trid, awaited_regions[1]); + CHECK(await1.is_complete()); + + receive_fragment(1); + CHECK(await0.is_complete()); + + std::vector expected_allocation(alloc_box.get_range().size()); + for(size_t which = 0; which < incoming_fragments.size(); ++which) { + const auto& box = incoming_fragments[which]; + test_utils::for_each_in_range(box.get_range(), box.get_offset() - alloc_box.get_offset(), [&](const id<3>& id_in_allocation) { + const auto linear_index = get_linear_index(alloc_box.get_range(), id_in_allocation); + expected_allocation[linear_index] = static_cast(1 + which); + }); + } + CHECK(allocation == expected_allocation); +} + +TEST_CASE("receive_arbiter handles multiple receive instructions for the same transfer id", "[receive_arbiter]") { + static const transfer_id trid(task_id(1), buffer_id(420), no_reduction_id); + static const box<3> alloc_box = {{0, 0, 0}, {20, 20, 1}}; + static const std::vector incoming_fragments{ + box<3>({2, 2, 0}, {8, 18, 1}), + box<3>({12, 2, 0}, {18, 18, 1}), + }; + static const std::vector requested_regions{ + region(incoming_fragments[0]), + region(incoming_fragments[1]), + }; + static const size_t elem_size = sizeof(int); + + const auto& event_order = GENERATE(from_range(enumerate_all_event_orders(requested_regions, incoming_fragments))); + CAPTURE(event_order); + + mock_recv_communicator comm(3, 0); + receive_arbiter ra(comm); + + std::vector 
allocation(alloc_box.get_range().size()); + std::map events; + + for(const auto& [transition, which] : event_order) { + const node_id peer = 1 + which; + const message_id msgid = 10 + which; + CAPTURE(transition, which, peer, msgid); + + switch(transition) { + case receive_event::call_to_receive: { + events.emplace(peer, ra.receive(trid, requested_regions[which], allocation.data(), alloc_box, elem_size)); + break; + } + case receive_event::incoming_pilot: { + comm.push_inbound_pilot(inbound_pilot{peer, pilot_message{msgid, trid, incoming_fragments[which]}}); + break; + } + case receive_event::incoming_data: { + std::vector fragment(incoming_fragments[which].get_range().size(), static_cast(peer)); + comm.complete_receiving_payload(peer, msgid, fragment.data(), incoming_fragments[which].get_range()); + break; + } + } + ra.poll_communicator(); + + if(events.count(peer) > 0) { CHECK(events.at(peer).is_complete() == (transition == receive_event::incoming_data)); } + } + + for(auto& [from, event] : events) { + CAPTURE(from); + CHECK(event.is_complete()); + } + + std::vector expected(alloc_box.get_range().size()); + for(size_t from = 0; from < incoming_fragments.size(); ++from) { + const auto& box = incoming_fragments[from]; + test_utils::for_each_in_range(box.get_range(), box.get_offset() - alloc_box.get_offset(), [&, from = from](const id<3>& id_in_allocation) { + const auto linear_index = get_linear_index(alloc_box.get_range(), id_in_allocation); + expected[linear_index] = static_cast(1 + from); + }); + } + CHECK(allocation == expected); +} + +TEST_CASE("receive_arbiter::gather_receive works", "[receive_arbiter]") { + static const transfer_id trid(task_id(2), buffer_id(0), reduction_id(1)); + static const box<3> unit_box{{0, 0, 0}, {1, 1, 1}}; + static const std::vector requested_regions{region(unit_box)}; + static const std::vector incoming_fragments{unit_box, unit_box, unit_box}; // each fragment is the chunk from a peer + static const size_t chunk_size = sizeof(int); + + const auto& event_order = GENERATE(from_range(enumerate_all_event_orders(requested_regions, incoming_fragments))); + CAPTURE(event_order); + + mock_recv_communicator comm(4, 0); + receive_arbiter ra(comm); + + std::vector allocation(comm.get_num_nodes(), -1); + std::optional receive; + + for(const auto& [transition, which] : event_order) { + CAPTURE(transition, which); + + // only the last event (always an incoming_data transition) will complete the receive + if(receive.has_value()) { CHECK(!receive->is_complete()); } + + switch(transition) { + case receive_event::call_to_receive: { + receive = ra.gather_receive(trid, allocation.data(), chunk_size); + break; + } + case receive_event::incoming_pilot: { + const node_id peer = 1 + which; + const message_id msgid = 10 + which; + CAPTURE(peer, msgid); + comm.push_inbound_pilot(inbound_pilot{peer, pilot_message{msgid, trid, incoming_fragments[which]}}); + break; + } + case receive_event::incoming_data: { + const node_id peer = 1 + which; + const message_id msgid = 10 + which; + CAPTURE(peer, msgid); + std::vector fragment(incoming_fragments[which].get_range().size(), static_cast(peer)); + comm.complete_receiving_payload(peer, msgid, fragment.data(), incoming_fragments[which].get_range()); + break; + } + } + ra.poll_communicator(); + } + + REQUIRE(receive.has_value()); + CHECK(receive->is_complete()); + CHECK(allocation == std::vector{-1 /* unchanged */, 1, 2, 3}); +} + +// peers will send a pilot with an empty box to signal that they don't contribute to a reduction 
+TEST_CASE("receive_arbiter knows how to handle empty pilot boxes in gathers", "[receive_arbiter]") { + const transfer_id trid(task_id(2), buffer_id(0), reduction_id(1)); + const box<3> empty_box; + static const box<3> unit_box{{0, 0, 0}, {1, 1, 1}}; + static const std::vector requested_regions{region(unit_box)}; + static const std::vector incoming_fragments{unit_box, empty_box}; + static const size_t chunk_size = sizeof(int); + + const auto& event_order = GENERATE(from_range(enumerate_all_event_orders(requested_regions, incoming_fragments))); + CAPTURE(event_order); + + mock_recv_communicator comm(3, 0); + receive_arbiter ra(comm); + + std::vector allocation(comm.get_num_nodes(), -1); + std::optional receive; + + for(const auto& [transition, which] : event_order) { + CAPTURE(transition, which); + + switch(transition) { + case receive_event::call_to_receive: { + receive = ra.gather_receive(trid, allocation.data(), chunk_size); + break; + } + case receive_event::incoming_pilot: { + const node_id peer = 1 + which; + const message_id msgid = 10 + which; + CAPTURE(peer, msgid); + comm.push_inbound_pilot(inbound_pilot{peer, pilot_message{msgid, trid, incoming_fragments[which]}}); + break; + } + case receive_event::incoming_data: { + if(!incoming_fragments[which].empty()) { + const node_id peer = 1 + which; + const message_id msgid = 10 + which; + CAPTURE(peer, msgid); + std::vector fragment(incoming_fragments[which].get_range().size(), static_cast(peer)); + comm.complete_receiving_payload(peer, msgid, fragment.data(), incoming_fragments[which].get_range()); + } + break; + } + } + ra.poll_communicator(); + } + + REQUIRE(receive.has_value()); + CHECK(receive->is_complete()); + + // only a single chunk, `1`, is actually written to `allocation` + CHECK(allocation == std::vector{-1, 1, -1}); +} diff --git a/test/system/CMakeLists.txt b/test/system/CMakeLists.txt index 252953ca1..1b0c027ee 100644 --- a/test/system/CMakeLists.txt +++ b/test/system/CMakeLists.txt @@ -1,14 +1,14 @@ -add_executable(distr_tests distr_tests.cc) +set(SYSTEM_TEST_TARGETS + distr_tests + mpi_tests +) -target_link_libraries(distr_tests PRIVATE test_main) +foreach(TEST_TARGET ${SYSTEM_TEST_TARGETS}) + set(TEST_SOURCE ${TEST_TARGET}.cc) -set_property(TARGET distr_tests PROPERTY CXX_STANDARD ${CELERITY_CXX_STANDARD}) -set_property(TARGET distr_tests PROPERTY FOLDER "tests/system") + add_executable(${TEST_TARGET} ${TEST_SOURCE}) + target_link_libraries(${TEST_TARGET} PRIVATE test_main) + set_test_target_parameters(${TEST_TARGET} ${TEST_SOURCE}) -add_celerity_to_target(TARGET distr_tests SOURCES distr_tests.cc) - -if(MSVC) - target_compile_options(distr_tests PRIVATE /D_CRT_SECURE_NO_WARNINGS /MP /W3 /bigobj) -elseif(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") - target_compile_options(distr_tests PRIVATE -Wall -Wextra -Wno-unused-parameter -Wno-unused-variable) -endif() + ParseAndAddCatchTests_ParseFile(${TEST_SOURCE} ${TEST_TARGET}) +endforeach() diff --git a/test/system/mpi_tests.cc b/test/system/mpi_tests.cc new file mode 100644 index 000000000..0cd555a69 --- /dev/null +++ b/test/system/mpi_tests.cc @@ -0,0 +1,360 @@ +#include "../test_utils.h" + +#include "communicator.h" +#include "mpi_communicator.h" +#include "types.h" + +#include + +#include +#include + + +using namespace celerity; +using namespace celerity::detail; +using namespace std::chrono_literals; + + +namespace celerity::detail { + +struct mpi_communicator_testspy { + static size_t get_num_active_outbound_pilots(const mpi_communicator& comm) { return 
comm.m_outbound_pilots.size(); } + static size_t get_num_cached_array_types(const mpi_communicator& comm) { return comm.m_array_type_cache.size(); } + static size_t get_num_cached_scalar_types(const mpi_communicator& comm) { return comm.m_scalar_type_cache.size(); } +}; + +} // namespace celerity::detail + + +TEST_CASE_METHOD(test_utils::mpi_fixture, "mpi_communicator sends and receives pilot messages", "[mpi]") { + mpi_communicator comm(collective_clone_from, MPI_COMM_WORLD); + const auto num_nodes = comm.get_num_nodes(); + const auto self = comm.get_local_node_id(); + CAPTURE(num_nodes, self); + + if(num_nodes <= 1) { SKIP("test must be run on at least 2 ranks"); } + + const auto make_pilot_message = [&](const node_id sender, const node_id receiver) { + // Compute a unique id for the (sender, receiver) tuple and base all other members of the pilot message on this ID to test that we receive the correct + // pilots on the correct nodes (and everything remains uncorrupted). + const auto p2p_id = 1 + sender * num_nodes + receiver; + const message_id msgid = p2p_id * 13; + const buffer_id bid = p2p_id * 11; + const task_id consumer_tid = p2p_id * 17; + const reduction_id rid = p2p_id * 19; + const transfer_id trid(consumer_tid, bid, rid); + const box<3> box = {id{p2p_id, p2p_id * 2, p2p_id * 3}, id{p2p_id * 4, p2p_id * 5, p2p_id * 6}}; + return outbound_pilot{receiver, pilot_message{msgid, trid, box}}; + }; + + // Send a pilot from each node to each other node + for(node_id other = 0; other < num_nodes; ++other) { + if(other == self) continue; + comm.send_outbound_pilot(make_pilot_message(self, other)); + } + + size_t num_pilots_received = 0; + while(num_pilots_received < num_nodes - 1) { + // busy-wait for all expected pilots to arrive + for(const auto& pilot : comm.poll_inbound_pilots()) { + CAPTURE(pilot.from); + const auto expect = make_pilot_message(pilot.from, self); + CHECK(pilot.message.id == expect.message.id); + CHECK(pilot.message.transfer_id == expect.message.transfer_id); + CHECK(pilot.message.box == expect.message.box); + ++num_pilots_received; + } + } + + SUCCEED("it didn't deadlock 👏"); +} + +TEST_CASE_METHOD(test_utils::mpi_fixture, "mpi_communicator sends and receives payloads", "[mpi]") { + mpi_communicator comm(collective_clone_from, MPI_COMM_WORLD); + const auto num_nodes = comm.get_num_nodes(); + const auto self = comm.get_local_node_id(); + CAPTURE(num_nodes, self); + + if(num_nodes <= 1) { SKIP("test must be run on at least 2 ranks"); } + + const auto make_msgid = [=](const node_id sender, const node_id receiver) { // + return message_id(1 + sender * num_nodes + receiver); + }; + + constexpr static communicator::stride stride{{12, 11, 11}, {{1, 0, 3}, {5, 4, 6}}, sizeof(int)}; + + std::vector> send_buffers; + std::vector> receive_buffers; + std::vector events; + for(node_id other = 0; other < num_nodes; ++other) { + if(other == self) continue; + + // allocate a send buffer and fill with a (sender, receiver) specific pattern that can be tested after receiving + auto& send = send_buffers.emplace_back(stride.allocation_range.size()); + std::iota(send.begin(), send.end(), make_msgid(self, other)); + + // allocate and zero-fill a receive buffer (zero is never part of a payload) + auto& receive = receive_buffers.emplace_back(stride.allocation_range.size()); + + // start send and receive operations + events.push_back(comm.send_payload(other, make_msgid(self, other), send.data(), stride)); + events.push_back(comm.receive_payload(other, make_msgid(other, self), receive.data(), 
stride)); + } + + // busy-wait for all send / receive events to complete + while(!events.empty()) { + const auto end_incomplete = std::remove_if(events.begin(), events.end(), std::mem_fn(&async_event::is_complete)); + events.erase(end_incomplete, events.end()); + } + + auto received = receive_buffers.begin(); + for(node_id other = 0; other < num_nodes; ++other) { + if(other == self) continue; + + // reconstruct the expected receive buffer + std::vector other_send(stride.allocation_range.size()); + std::iota(other_send.begin(), other_send.end(), make_msgid(other, self)); + std::vector expected(stride.allocation_range.size()); + test_utils::for_each_in_range(stride.transfer.range, stride.transfer.offset, [&](const id<3>& id) { + const auto linear_index = get_linear_index(stride.allocation_range, id); + expected[linear_index] = other_send[linear_index]; + }); + + CHECK(*received == expected); + ++received; // not equivalent to receive_buffers[other] because we skip `other == self` + } +} + +// We require that it's well-defined to send a scalar from an n-dimensional stride and receive it in an m-dimensional stride, since stride dimensionality is +// determined from effective allocation dimensionality, which can vary between participating nodes depending on the size of their buffer host allocations. +TEST_CASE_METHOD(test_utils::mpi_fixture, "mpi_communicator correctly transfers scalars between strides of different dimensionality", "[mpi]") { + // All GENERATEs must happen before an early-return, otherwise different nodes will execute this test case different numbers of times + const auto send_dims = GENERATE(values({0, 1, 2, 3})); + const auto recv_dims = GENERATE(values({0, 1, 2, 3})); + CAPTURE(send_dims, recv_dims); + + mpi_communicator comm(collective_clone_from, MPI_COMM_WORLD); + const auto num_nodes = comm.get_num_nodes(); + const auto local_node_id = comm.get_local_node_id(); + CAPTURE(num_nodes, local_node_id); + + if(num_nodes <= 1) { SKIP("test must be run on at least 2 ranks"); } + if(local_node_id >= 2) return; // needs exactly 2 nodes to participate + + constexpr communicator::stride dim_strides[] = { + {{1, 1, 1}, {{0, 0, 0}, {1, 1, 1}}, 4}, // 0-dimensional + {{2, 1, 1}, {{1, 0, 0}, {1, 1, 1}}, 4}, // 1-dimensional + {{2, 3, 1}, {{1, 2, 0}, {1, 1, 1}}, 4}, // 2-dimensional + {{2, 3, 5}, {{1, 2, 3}, {1, 1, 1}}, 4}, // 3-dimensional + }; + + const auto& send_stride = dim_strides[send_dims]; + const auto& recv_stride = dim_strides[recv_dims]; + + std::vector buf(dim_strides[3].allocation_range.size()); + async_event evt; + if(local_node_id == 1) { // sender + buf[get_linear_index(send_stride.allocation_range, send_stride.transfer.offset)] = 42; + evt = comm.send_payload(0, 99, buf.data(), send_stride); + } else { // receiver + evt = comm.receive_payload(1, 99, buf.data(), recv_stride); + } + // busy-wait for event + while(!evt.is_complete()) {} + + if(local_node_id == 0) { // receiver + std::vector expected(dim_strides[3].allocation_range.size()); + expected[get_linear_index(recv_stride.allocation_range, recv_stride.transfer.offset)] = 42; + CHECK(buf == expected); + } +} + +TEST_CASE_METHOD(test_utils::mpi_fixture, "mpi_communicator correctly transfers boxes that map to different subranges on sender and receiver", "[mpi]") { + // All GENERATEs must happen before an early-return, otherwise different nodes will execute this test case a different number of times + const auto dims = GENERATE(values({1, 2, 3})); + CAPTURE(dims); + + mpi_communicator comm(collective_clone_from, 
MPI_COMM_WORLD); + const auto num_nodes = comm.get_num_nodes(); + const auto local_node_id = comm.get_local_node_id(); + CAPTURE(num_nodes, local_node_id); + + if(num_nodes <= 1) { SKIP("test must be run on at least 2 ranks"); } + if(local_node_id >= 2) return; // needs exactly 2 nodes + + range box_range{3, 4, 5}; + range sender_allocation{10, 7, 11}; + id sender_offset{1, 2, 3}; + range receiver_allocation{8, 10, 13}; + id receiver_offset{2, 0, 4}; + // manually truncate to runtime value `dims` + for(int d = dims; d < 3; ++d) { + box_range[d] = 1; + sender_allocation[d] = 1; + sender_offset[d] = 0; + receiver_allocation[d] = 1; + receiver_offset[d] = 0; + } + + std::vector send_buf(sender_allocation.size()); + std::vector recv_buf(receiver_allocation.size()); + + std::iota(send_buf.begin(), send_buf.end(), 0); + + async_event evt; + if(local_node_id == 1) { // sender + evt = comm.send_payload(0, 99, send_buf.data(), communicator::stride{sender_allocation, subrange{sender_offset, box_range}, sizeof(int)}); + } else { // receiver + evt = comm.receive_payload(1, 99, recv_buf.data(), communicator::stride{receiver_allocation, subrange{receiver_offset, box_range}, sizeof(int)}); + } + while(!evt.is_complete()) {} // busy-wait for evt + + if(local_node_id == 0) { + std::vector expected(receiver_allocation.size()); + test_utils::for_each_in_range(box_range, [&](const id<3>& id) { + const auto sender_idx = get_linear_index(sender_allocation, sender_offset + id); + const auto receiver_idx = get_linear_index(receiver_allocation, receiver_offset + id); + expected[receiver_idx] = send_buf[sender_idx]; + }); + CHECK(recv_buf == expected); + } +} + +TEST_CASE_METHOD(test_utils::mpi_fixture, "collectives are concurrent between distinct mpi_communicators", "[mpi][smoke-test]") { + constexpr static size_t concurrency = 16; + + // create a bunch of communicators that we can then operate on from concurrent threads + std::vector> roots; + for(size_t i = 0; i < concurrency; ++i) { + roots.push_back(std::make_unique(collective_clone_from, MPI_COMM_WORLD)); + } + + // for each communicator, spawn a thread that creates more communicators + std::vector>> concurrent_clones(concurrency); + std::vector concurrent_threads(concurrency); + for(size_t i = 0; i < concurrency; ++i) { + concurrent_threads[i] = std::thread([&, i] { + for(size_t j = 0; j < concurrency; ++j) { + concurrent_clones[i].push_back(roots[i]->collective_clone()); + std::this_thread::sleep_for(10ms); // ensure the OS doesn't serialize all threads by chance + } + }); + } + for(size_t i = 0; i < concurrency; ++i) { + concurrent_threads[i].join(); + } + + // flip the iteration order and issue a barrier from each new collective group + for(size_t i = 0; i < concurrency; ++i) { + concurrent_threads[i] = std::thread([&, i] { + for(size_t j = 0; j < concurrency; ++j) { + concurrent_clones[j][i]->collective_barrier(); + std::this_thread::sleep_for(10ms); // ensure the OS doesn't serialize all threads by chance + } + }); + } + for(size_t i = 0; i < concurrency; ++i) { + concurrent_threads[i].join(); + } + + // ~mpi_communicator is also a collective operation; and it shouldn't matter if we destroy parents before their children + roots.clear(); + + for(size_t i = 0; i < concurrency; ++i) { + concurrent_threads[i] = std::thread([&, i] { + concurrent_clones[i].clear(); // ~mpi_communicator is a collective operation + }); + } + for(size_t i = 0; i < concurrency; ++i) { + concurrent_threads[i].join(); + } + + SUCCEED("it didn't deadlock or crash 🎉"); +} + 
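To make the concurrency contract exercised by the test above easier to follow, here is a minimal hedged sketch of the usage pattern it relies on; `barrier_concurrency_sketch` is illustrative only and not part of this change. Barriers issued on distinct clones may overlap, while collectives on any single communicator are issued in the same order on every node:

#include "communicator.h"

#include <memory>
#include <thread>

// Illustrative only: one root communicator, two clones, one thread per clone.
// Collectives on *different* clones may run concurrently; collectives on the *same*
// communicator must be issued in the same order on every node (standard MPI-style semantics).
void barrier_concurrency_sketch(celerity::detail::communicator& root) {
	auto clone_a = root.collective_clone(); // collective: every node clones in the same order
	auto clone_b = root.collective_clone();

	std::thread thread_a([&] { clone_a->collective_barrier(); });
	std::thread thread_b([&] { clone_b->collective_barrier(); }); // may complete before thread_a's barrier
	thread_a.join();
	thread_b.join();
} // destroying the clones is again a collective operation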
+TEST_CASE("mpi_communicator normalizes strides to cache and re-uses MPI data types", "[mpi]") { + static const std::vector> sets_of_equivalent_strides{ + // strides only differ in allocation size / offset in dim 0 and can be normalized by adjusting the base pointer + { + {{13, 12, 11}, {{1, 0, 3}, {5, 4, 6}}, sizeof(int)}, + {{5, 12, 11}, {{0, 0, 3}, {5, 4, 6}}, sizeof(int)}, + {{20, 12, 11}, {{4, 0, 3}, {5, 4, 6}}, sizeof(int)}, + }, + { + {{13, 1, 1}, {{1, 0, 0}, {5, 1, 1}}, sizeof(int)}, + {{1, 13, 1}, {{0, 1, 0}, {1, 5, 1}}, sizeof(int)}, + {{1, 1, 13}, {{0, 0, 1}, {1, 1, 5}}, sizeof(int)}, + }, + }; + // All GENERATEs must happen before an early-return, otherwise different nodes will execute this test case different numbers of times + const auto equivalent_strides = GENERATE(from_range(sets_of_equivalent_strides)); + + mpi_communicator comm(collective_clone_from, MPI_COMM_WORLD); + const auto num_nodes = comm.get_num_nodes(); + const auto self = comm.get_local_node_id(); + CAPTURE(num_nodes, self); + + if(num_nodes <= 1) { SKIP("test must be run on at least 2 ranks"); } + if(self >= 2) return; // needs exactly 2 nodes + const node_id peer = 1 - self; + + message_id msgid = 0; + for(int repeat = 0; repeat < 2; ++repeat) { + CAPTURE(repeat); + for(const auto& stride : equivalent_strides) { + CAPTURE(stride.allocation_range, stride.transfer, stride.element_size); + std::vector send_buf(stride.allocation_range.size()); + std::iota(send_buf.begin(), send_buf.end(), 1); + const auto send_evt = comm.send_payload(peer, msgid, std::as_const(send_buf).data(), stride); + + std::vector recv_buf(stride.allocation_range.size()); + const auto recv_event = comm.receive_payload(peer, msgid, recv_buf.data(), stride); + + while(!send_evt.is_complete() || !recv_event.is_complete()) {} // busy-wait for events to complete + + std::vector expected(stride.allocation_range.size()); + test_utils::for_each_in_range(stride.transfer.range, stride.transfer.offset, [&](const id<3>& id) { + const auto linear_id = get_linear_index(stride.allocation_range, id); + expected[linear_id] = send_buf[linear_id]; + }); + + CHECK(recv_buf == expected); + ++msgid; + } + } + + CHECK(mpi_communicator_testspy::get_num_cached_array_types(comm) == 1); // all strides we sent/received were equivalent under normalization + CHECK(mpi_communicator_testspy::get_num_cached_scalar_types(comm) == 1); // only scalar type used was int +} + +TEST_CASE("successfully sent pilots are garbage-collected by communicator", "[mpi]") { + mpi_communicator comm(collective_clone_from, MPI_COMM_WORLD); + const auto num_nodes = comm.get_num_nodes(); + const auto self = comm.get_local_node_id(); + CAPTURE(num_nodes, self); + + if(num_nodes <= 1) { SKIP("test must be run on at least 2 ranks"); } + + const bool participate = self < 2; // needs exactly 2 participating nodes + const node_id peer = 1 - self; + + if(participate) { + for(int i = 0; i < 3; ++i) { + comm.send_outbound_pilot(outbound_pilot{peer, pilot_message{0, transfer_id{0, 0, 0}, box<3>{id{0, 0, 0}, id{1, 1, 1}}}}); + } + CHECK(mpi_communicator_testspy::get_num_active_outbound_pilots(comm) <= 3); + + size_t num_received_pilots = 0; + while(num_received_pilots < 3) { + num_received_pilots += comm.poll_inbound_pilots().size(); + } + } + + comm.collective_barrier(); // hope that this also means all p2p transfers have complete... 
+
+	if(participate) {
+		// send_outbound_pilot will garbage-collect all finished pilot-sends
+		comm.send_outbound_pilot(outbound_pilot{peer, pilot_message{0, transfer_id{0, 0, 0}, box<3>{id{0, 0, 0}, id{1, 1, 1}}}});
+		CHECK(mpi_communicator_testspy::get_num_active_outbound_pilots(comm) <= 1);
+	}
+}
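For orientation, the peer-to-peer pattern exercised by the mpi_tests above can be condensed into the following hedged sketch; `exchange_sketch`, the node ids, the message id, the transfer id and the stride values are made up for illustration and are not part of this change:

#include "communicator.h"

#include <vector>

// Illustrative two-rank exchange: announce the transfer with a pilot, then move the
// payload with a matching message_id and stride, busy-waiting on the returned events.
void exchange_sketch(celerity::detail::communicator& comm) {
	using namespace celerity::detail;

	const communicator::stride stride{{8, 8, 8}, {{0, 0, 0}, {4, 4, 4}}, sizeof(int)};
	std::vector<int> buffer(stride.allocation_range.size());

	if(comm.get_local_node_id() == 0) { // sender
		comm.send_outbound_pilot(outbound_pilot{1, pilot_message{7, transfer_id{0, 0, 0}, box<3>{id{0, 0, 0}, id{4, 4, 4}}}});
		auto event = comm.send_payload(1, 7, buffer.data(), stride);
		while(!event.is_complete()) {} // busy-wait, as in the tests above
	} else if(comm.get_local_node_id() == 1) { // receiver
		// a real receiver would first match the inbound pilot returned by poll_inbound_pilots()
		auto event = comm.receive_payload(0, 7, buffer.data(), stride);
		while(!event.is_complete()) {}
	}
}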