Ginkgo  Generated from pipelines/1706354773 branch based on develop. Ginkgo version 1.10.0
A numerical linear algebra library targeting many-core architectures
collective_communicator.hpp
1 // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_DISTRIBUTED_COLLECTIVE_COMMUNICATOR_HPP_
6 #define GKO_PUBLIC_CORE_DISTRIBUTED_COLLECTIVE_COMMUNICATOR_HPP_
7 
8 
9 #include <ginkgo/config.hpp>
10 
11 
12 #if GINKGO_BUILD_MPI
13 
14 
15 #include <ginkgo/core/base/mpi.hpp>
16 #include <ginkgo/core/distributed/index_map_fwd.hpp>
17 
18 
19 namespace gko {
20 namespace experimental {
21 namespace mpi {
22 
23 
31 public:
// Virtual, defaulted destructor: the class declares pure virtual members
// below, so it is meant to be used through derived types.
32  virtual ~CollectiveCommunicator() = default;
33 
// Constructs the object from a base communicator. `explicit` prevents
// implicit conversion from `communicator`; when no argument is given,
// `base` defaults to MPI_COMM_NULL.
// NOTE(review): the definition is not part of this listing; presumably
// `base` is stored in the private `base_` member - confirm.
34  explicit CollectiveCommunicator(communicator base = MPI_COMM_NULL);
35 
// Returns a const reference to a `communicator`; the result must not be
// discarded ([[nodiscard]]).
// NOTE(review): presumably this is the `base_` member - the definition is
// not shown in this listing.
36  [[nodiscard]] const communicator& get_base_communicator() const;
37 
// Non-blocking all-to-all communication (typed overload).
//
// exec         executor handle, passed on to the untyped overload
// send_buffer  pointer to the elements to send
// recv_buffer  pointer to the storage for the received elements
//
// Returns a `request` (a move-only wrapper around MPI_Request); it is
// [[nodiscard]], so callers are expected to keep it.
// The MPI datatype for the send side is derived from `SendType` via
// `type_impl<SendType>::get_type()` in the definition further below.
54  template <typename SendType, typename RecvType>
55  [[nodiscard]] request i_all_to_all_v(std::shared_ptr<const Executor> exec,
56  const SendType* send_buffer,
57  RecvType* recv_buffer) const;
58 
// Untyped overload of the non-blocking all-to-all: buffers are passed as
// void pointers together with explicit MPI datatypes for the send and the
// receive side. This is the overload the typed template forwards to.
// NOTE(review): the definition is not shown here; presumably it delegates
// to the protected `i_all_to_all_v_impl` hook, which has the same
// parameter list - confirm. Unlike the typed overload, this declaration
// carries no [[nodiscard]].
63  request i_all_to_all_v(std::shared_ptr<const Executor> exec,
64  const void* send_buffer, MPI_Datatype send_type,
65  void* recv_buffer, MPI_Datatype recv_type) const;
66 
// Creates a new CollectiveCommunicator with the same dynamic type.
// Pure virtual; returns ownership of the new object via std::unique_ptr.
// NOTE(review): listing line 76 (the function name and first parameter) is
// missing from this extraction. The index at the end of this page gives
// the full signature as
//   create_with_same_type(communicator base,
//                         const distributed::index_map_variant &imap) const = 0
75  [[nodiscard]] virtual std::unique_ptr<CollectiveCommunicator>
77  const distributed::index_map_variant& imap) const = 0;
78 
// Creates a CollectiveCommunicator whose communication pattern is the
// inverse of this object's pattern. Pure virtual; ownership of the new
// object is returned via std::unique_ptr.
86  [[nodiscard]] virtual std::unique_ptr<CollectiveCommunicator>
87  create_inverse() const = 0;
88 
// Returns the total number of received elements this communication
// pattern expects, as a `comm_index_type` (an `int`). Pure virtual.
95  [[nodiscard]] virtual comm_index_type get_recv_size() const = 0;
96 
// Returns the total number of sent elements this communication pattern
// expects, as a `comm_index_type` (an `int`). Pure virtual.
103  [[nodiscard]] virtual comm_index_type get_send_size() const = 0;
104 
105 protected:
// Protected, pure virtual implementation hook that derived classes must
// provide. It takes the same arguments as the public untyped
// `i_all_to_all_v` overload: an executor, untyped send/receive buffers and
// the MPI datatype of each, and returns a `request`.
// NOTE(review): presumably invoked by the public `i_all_to_all_v`; the
// calling code is not visible in this listing.
106  virtual request i_all_to_all_v_impl(std::shared_ptr<const Executor> exec,
107  const void* send_buffer,
108  MPI_Datatype send_type,
109  void* recv_buffer,
110  MPI_Datatype recv_type) const = 0;
111 
112 private:
// Communicator held by value (`communicator` is a thin wrapper of
// MPI_Comm). NOTE(review): presumably set by the constructor and exposed
// through get_base_communicator() - those definitions are not shown here.
113  communicator base_;
114 };
115 
116 
// Definition of the typed, non-blocking all-to-all. It forwards to another
// `i_all_to_all_v` overload, moving the executor handle and passing the
// send buffer, the MPI datatype obtained from
// `type_impl<SendType>::get_type()`, and the receive buffer.
// NOTE(review): listing lines 118 and 124 are missing from this
// extraction. Line 118 is where the index places the definition of
// CollectiveCommunicator::i_all_to_all_v (return type `request`); line 124
// presumably supplies the receive-side MPI datatype and closes the call -
// confirm against the original header.
117 template <typename SendType, typename RecvType>
119  std::shared_ptr<const Executor> exec, const SendType* send_buffer,
120  RecvType* recv_buffer) const
121 {
122  return this->i_all_to_all_v(std::move(exec), send_buffer,
123  type_impl<SendType>::get_type(), recv_buffer,
125 }
126 
127 
128 } // namespace mpi
129 } // namespace experimental
130 } // namespace gko
131 
132 
133 #endif
134 #endif // GKO_PUBLIC_CORE_DISTRIBUTED_COLLECTIVE_COMMUNICATOR_HPP_
gko::experimental::mpi::CollectiveCommunicator::i_all_to_all_v
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, RecvType *recv_buffer) const
Non-blocking all-to-all communication.
Definition: collective_communicator.hpp:118
gko::experimental::mpi::request
The request class is a light, move-only wrapper around the MPI_Request handle.
Definition: mpi.hpp:327
gko::experimental::mpi::CollectiveCommunicator
Interface for a collective communicator.
Definition: collective_communicator.hpp:30
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::experimental::mpi::CollectiveCommunicator::create_with_same_type
virtual std::unique_ptr< CollectiveCommunicator > create_with_same_type(communicator base, const distributed::index_map_variant &imap) const =0
Creates a new CollectiveCommunicator with the same dynamic type.
gko::experimental::mpi::CollectiveCommunicator::create_inverse
virtual std::unique_ptr< CollectiveCommunicator > create_inverse() const =0
Creates a CollectiveCommunicator with the inverse communication pattern of this object.
gko::experimental::mpi::CollectiveCommunicator::get_send_size
virtual comm_index_type get_send_size() const =0
Get the total number of sent elements this communication pattern expects.
gko::experimental::mpi::communicator
A thin wrapper of MPI_Comm that supports most MPI calls.
Definition: mpi.hpp:416
gko::experimental::mpi::comm_index_type
int comm_index_type
Index type for enumerating processes in a distributed application.
Definition: types.hpp:967
gko::experimental::mpi::type_impl
A struct that is used to determine the MPI_Datatype of a specified type.
Definition: mpi.hpp:77
gko::experimental::mpi::CollectiveCommunicator::get_recv_size
virtual comm_index_type get_recv_size() const =0
Get the total number of received elements this communication pattern expects.