Ginkgo 1.10.0
A numerical linear algebra library targeting many-core architectures
mpi.hpp
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_PUBLIC_CORE_BASE_MPI_HPP_
#define GKO_PUBLIC_CORE_BASE_MPI_HPP_


#include <complex>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include <ginkgo/config.hpp>
#include <ginkgo/core/base/exception.hpp>
#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/executor.hpp>
#include <ginkgo/core/base/half.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/base/utils_helper.hpp>


#if GINKGO_BUILD_MPI


#include <mpi.h>


namespace gko {
namespace experimental {
namespace mpi {


/** Return if GPU-aware functionality is available. */
inline constexpr bool is_gpu_aware()
{
#if GINKGO_HAVE_GPU_AWARE_MPI
    return true;
#else
    return false;
#endif
}


/** Maps each MPI rank to a single device id in a round-robin manner. */
int map_rank_to_device_id(MPI_Comm comm, int num_devices);

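/*
 * Usage sketch (illustrative only, not part of this header): selecting a
 * device per rank in a CUDA build. The device count query is just an example;
 * any device enumeration works, and exec/master names are made up here.
 *
 * @code
 * int num_devices = 0;
 * cudaGetDeviceCount(&num_devices);  // hypothetical enumeration step
 * auto device_id = gko::experimental::mpi::map_rank_to_device_id(
 *     MPI_COMM_WORLD, num_devices);
 * auto exec = gko::CudaExecutor::create(device_id,
 *                                       gko::ReferenceExecutor::create());
 * @endcode
 */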

/**
 * Registers the input_type/mpi_type pair by specializing type_impl, so that
 * type_impl<input_type>::get_type() returns mpi_type.
 */
#define GKO_REGISTER_MPI_TYPE(input_type, mpi_type)         \
    template <>                                             \
    struct type_impl<input_type> {                          \
        static MPI_Datatype get_type() { return mpi_type; } \
    }


/** A struct that is used to determine the MPI_Datatype of a specified type. */
template <typename T>
struct type_impl {};


GKO_REGISTER_MPI_TYPE(char, MPI_CHAR);
GKO_REGISTER_MPI_TYPE(unsigned char, MPI_UNSIGNED_CHAR);
GKO_REGISTER_MPI_TYPE(unsigned, MPI_UNSIGNED);
GKO_REGISTER_MPI_TYPE(int, MPI_INT);
GKO_REGISTER_MPI_TYPE(unsigned short, MPI_UNSIGNED_SHORT);
GKO_REGISTER_MPI_TYPE(unsigned long, MPI_UNSIGNED_LONG);
GKO_REGISTER_MPI_TYPE(long, MPI_LONG);
GKO_REGISTER_MPI_TYPE(long long, MPI_LONG_LONG_INT);
GKO_REGISTER_MPI_TYPE(unsigned long long, MPI_UNSIGNED_LONG_LONG);
GKO_REGISTER_MPI_TYPE(float, MPI_FLOAT);
GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE);
GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE);
#if GINKGO_ENABLE_HALF
// OpenMPI 5.0 and MPICH v3.4a1 provide MPIX_C_FLOAT16 for half precision;
// only OpenMPI supports complex float16.
// TODO: use the native type when MPI is configured with half support
GKO_REGISTER_MPI_TYPE(half, MPI_UNSIGNED_SHORT);
GKO_REGISTER_MPI_TYPE(std::complex<half>, MPI_FLOAT);
#endif  // GINKGO_ENABLE_HALF
#if GINKGO_ENABLE_BFLOAT16
GKO_REGISTER_MPI_TYPE(bfloat16, MPI_UNSIGNED_SHORT);
GKO_REGISTER_MPI_TYPE(std::complex<bfloat16>, MPI_FLOAT);
#endif  // GINKGO_ENABLE_BFLOAT16
GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);

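/*
 * Illustration (not part of this header): the specializations above let
 * generic code look up the MPI_Datatype that matches a C++ type, e.g.
 *
 * @code
 * MPI_Datatype t = gko::experimental::mpi::type_impl<double>::get_type();
 * // t == MPI_DOUBLE
 * @endcode
 */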

/** A move-only wrapper for a contiguous MPI_Datatype. */
class contiguous_type {
public:
    /** Constructs a wrapper for a contiguous MPI_Datatype. */
    contiguous_type(int count, MPI_Datatype old_type) : type_(MPI_DATATYPE_NULL)
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_contiguous(count, old_type, &type_));
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_commit(&type_));
    }

    /** Constructs an empty wrapper with MPI_DATATYPE_NULL. */
    contiguous_type() : type_(MPI_DATATYPE_NULL) {}

    /** Disallow copying of wrapper type. */
    contiguous_type(const contiguous_type&) = delete;

    /** Disallow copy-assignment of wrapper type. */
    contiguous_type& operator=(const contiguous_type&) = delete;

    /** Move constructor, leaves other with MPI_DATATYPE_NULL. */
    contiguous_type(contiguous_type&& other) noexcept : type_(MPI_DATATYPE_NULL)
    {
        *this = std::move(other);
    }

    /** Move assignment, leaves other with MPI_DATATYPE_NULL. */
    contiguous_type& operator=(contiguous_type&& other) noexcept
    {
        if (this != &other) {
            this->type_ = std::exchange(other.type_, MPI_DATATYPE_NULL);
        }
        return *this;
    }

    /** Destructs the object by freeing the wrapped MPI_Datatype. */
    ~contiguous_type()
    {
        if (type_ != MPI_DATATYPE_NULL) {
            MPI_Type_free(&type_);
        }
    }

    /** Access the underlying MPI_Datatype. */
    MPI_Datatype get() const { return type_; }

private:
    MPI_Datatype type_;
};

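/*
 * Usage sketch (illustrative only): sending three doubles per element as one
 * derived datatype. The buffer size, dest_rank, tag and comm are made up for
 * the example and assumed to be defined elsewhere.
 *
 * @code
 * gko::experimental::mpi::contiguous_type vec3(3, MPI_DOUBLE);
 * double buffer[12];  // four elements of three doubles each
 * MPI_Send(buffer, 4, vec3.get(), dest_rank, tag, comm);
 * // vec3's destructor calls MPI_Type_free automatically
 * @endcode
 */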

/**
 * This enum specifies the threading type to be used when creating an MPI
 * environment.
 */
enum class thread_type {
    serialized = MPI_THREAD_SERIALIZED,
    funneled = MPI_THREAD_FUNNELED,
    single = MPI_THREAD_SINGLE,
    multiple = MPI_THREAD_MULTIPLE
};


/** Class that sets up and finalizes the MPI environment. */
class environment {
public:
    /** Return whether MPI_Finalize has already been called. */
    static bool is_finalized()
    {
        int flag = 0;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Finalized(&flag));
        return flag;
    }

    /** Return whether MPI_Init has already been called. */
    static bool is_initialized()
    {
        int flag = 0;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Initialized(&flag));
        return flag;
    }

    /** Return the provided thread support. */
    int get_provided_thread_support() const { return provided_thread_support_; }

    /** Call MPI_Init_thread and initialize the MPI environment. */
    environment(int& argc, char**& argv,
                const thread_type thread_t = thread_type::serialized)
    {
        this->required_thread_support_ = static_cast<int>(thread_t);
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Init_thread(&argc, &argv, this->required_thread_support_,
                            &(this->provided_thread_support_)));
    }

    /** Call MPI_Finalize at the end of the scope of this class. */
    ~environment() { MPI_Finalize(); }

    environment(const environment&) = delete;
    environment(environment&&) = delete;
    environment& operator=(const environment&) = delete;
    environment& operator=(environment&&) = delete;

private:
    int required_thread_support_;
    int provided_thread_support_;
};

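/*
 * Usage sketch (illustrative only): RAII initialization of MPI in main().
 * MPI_Finalize runs automatically when env leaves scope.
 *
 * @code
 * int main(int argc, char* argv[])
 * {
 *     gko::experimental::mpi::environment env(argc, argv);
 *     // ... use MPI via gko::experimental::mpi::communicator ...
 * }  // env's destructor calls MPI_Finalize here
 * @endcode
 */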

namespace {


/** Deleter that frees and deletes an owned MPI_Comm. */
class comm_deleter {
public:
    using pointer = MPI_Comm*;
    void operator()(pointer comm) const
    {
        GKO_ASSERT(*comm != MPI_COMM_NULL);
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(comm));
        delete comm;
    }
};


}  // namespace


/** The status struct is a light wrapper around the MPI_Status struct. */
struct status {
    /** The default constructor. */
    status() : status_(MPI_Status{}) {}

    /** Get a pointer to the underlying MPI_Status object. */
    MPI_Status* get() { return &this->status_; }

    /**
     * Get the count of the number of elements received by the communication
     * call. The data argument is only used to deduce the element type.
     */
    template <typename T>
    int get_count(const T* data) const
    {
        int count;
        MPI_Get_count(&status_, type_impl<T>::get_type(), &count);
        return count;
    }

private:
    MPI_Status status_;
};


/** The request class is a light, move-only wrapper around MPI_Request. */
class request {
public:
    /** The default constructor, initializing to MPI_REQUEST_NULL. */
    request() : req_(MPI_REQUEST_NULL) {}

    request(const request&) = delete;

    request& operator=(const request&) = delete;

    request(request&& o) noexcept { *this = std::move(o); }

    request& operator=(request&& o) noexcept
    {
        if (this != &o) {
            this->req_ = std::exchange(o.req_, MPI_REQUEST_NULL);
        }
        return *this;
    }

    ~request()
    {
        if (req_ != MPI_REQUEST_NULL) {
            if (MPI_Request_free(&req_) != MPI_SUCCESS) {
                std::terminate();  // since we can't throw in destructors, we
                                   // have to terminate the program
            }
        }
    }

    /** Get a pointer to the underlying MPI_Request handle. */
    MPI_Request* get() { return &this->req_; }

    /** Allows a rank to wait on a particular request handle. */
    status wait()
    {
        status status;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Wait(&req_, status.get()));
        return status;
    }

private:
    MPI_Request req_;
};


/** Allows a rank to wait on multiple request handles. */
inline std::vector<status> wait_all(std::vector<request>& req)
{
    std::vector<status> stat;
    for (std::size_t i = 0; i < req.size(); ++i) {
        stat.emplace_back(req[i].wait());
    }
    return stat;
}

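/*
 * Usage sketch (illustrative only): overlapping a non-blocking send and
 * receive, then waiting on both. comm, exec, the buffers, ranks and tag are
 * assumed to be defined elsewhere.
 *
 * @code
 * std::vector<gko::experimental::mpi::request> reqs;
 * reqs.push_back(comm.i_send(exec, send_data, count, dest, tag));
 * reqs.push_back(comm.i_recv(exec, recv_data, count, src, tag));
 * auto statuses = gko::experimental::mpi::wait_all(reqs);
 * @endcode
 */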

/** A thin wrapper of MPI_Comm that supports most MPI calls. */
class communicator {
public:
    /** Non-owning constructor for an existing communicator of type MPI_Comm. */
    communicator(const MPI_Comm& comm, bool force_host_buffer = false)
        : comm_(), force_host_buffer_(force_host_buffer)
    {
        this->comm_.reset(new MPI_Comm(comm));
    }

    /**
     * Create a communicator object from an existing MPI_Comm object using
     * color and key.
     */
    communicator(const MPI_Comm& comm, int color, int key)
    {
        MPI_Comm comm_out;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split(comm, color, key, &comm_out));
        this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
    }

    /**
     * Create a communicator object from an existing communicator using color
     * and key.
     */
    communicator(const communicator& comm, int color, int key)
    {
        MPI_Comm comm_out;
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Comm_split(comm.get(), color, key, &comm_out));
        this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
    }

    /** Creates a new communicator and takes ownership of the MPI_Comm. */
    static communicator create_owning(const MPI_Comm& comm,
                                      bool force_host_buffer = false)
    {
        communicator comm_out(MPI_COMM_NULL, force_host_buffer);
        comm_out.comm_.reset(new MPI_Comm(comm), comm_deleter{});
        return comm_out;
    }

    /** Copy constructor. */
    communicator(const communicator& other) = default;

    /** Move constructor. */
    communicator(communicator&& other) { *this = std::move(other); }

    /** Copy assignment operator. */
    communicator& operator=(const communicator& other) = default;

    /** Move assignment operator, leaves other with MPI_COMM_NULL. */
    communicator& operator=(communicator&& other)
    {
        if (this != &other) {
            comm_ = std::exchange(other.comm_,
                                  std::make_shared<MPI_Comm>(MPI_COMM_NULL));
            force_host_buffer_ = other.force_host_buffer_;
        }
        return *this;
    }

    /** Return the underlying MPI_Comm object. */
    const MPI_Comm& get() const { return *(this->comm_.get()); }

    bool force_host_buffer() const { return force_host_buffer_; }

    /** Return the size of the communicator (number of ranks). */
    int size() const { return get_num_ranks(); }

    /** Return the rank of the calling process in the communicator. */
    int rank() const { return get_my_rank(); }

    /** Return the node-local rank of the calling process in the communicator. */
    int node_local_rank() const { return get_node_local_rank(); }

    /** Compare two communicator objects for equality. */
    bool operator==(const communicator& rhs) const { return is_identical(rhs); }

    /** Compare two communicator objects for non-equality. */
    bool operator!=(const communicator& rhs) const { return !(*this == rhs); }

    /** Checks if the rhs communicator is identical to this communicator. */
    bool is_identical(const communicator& rhs) const
    {
        if (get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
            return get() == rhs.get();
        }
        int flag;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), rhs.get(), &flag));
        return flag == MPI_IDENT;
    }

    /** Checks if the rhs communicator is congruent to this communicator. */
    bool is_congruent(const communicator& rhs) const
    {
        if (get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
            return get() == rhs.get();
        }
        int flag;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), rhs.get(), &flag));
        return flag == MPI_CONGRUENT;
    }

    /** Synchronize the ranks in the communicator (MPI_Barrier). */
    void synchronize() const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Barrier(this->get()));
    }

    /** Send (blocking) data from the calling process to the destination rank. */
    template <typename SendType>
    void send(std::shared_ptr<const Executor> exec, const SendType* send_buffer,
              const int send_count, const int destination_rank,
              const int send_tag) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Send(send_buffer, send_count, type_impl<SendType>::get_type(),
                     destination_rank, send_tag, this->get()));
    }

    /**
     * Send (non-blocking, immediate return) data from the calling process to
     * the destination rank.
     */
    template <typename SendType>
    request i_send(std::shared_ptr<const Executor> exec,
                   const SendType* send_buffer, const int send_count,
                   const int destination_rank, const int send_tag) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Isend(send_buffer, send_count, type_impl<SendType>::get_type(),
                      destination_rank, send_tag, this->get(), req.get()));
        return req;
    }

    /** Receive (blocking) data from the source rank. */
    template <typename RecvType>
    status recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
                const int recv_count, const int source_rank,
                const int recv_tag) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        status st;
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Recv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
                     source_rank, recv_tag, this->get(), st.get()));
        return st;
    }

    /** Receive (non-blocking, immediate return) data from the source rank. */
    template <typename RecvType>
    request i_recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
                   const int recv_count, const int source_rank,
                   const int recv_tag) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Irecv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
                      source_rank, recv_tag, this->get(), req.get()));
        return req;
    }

    /**
     * Broadcast data from the calling process to all ranks in the
     * communicator.
     */
    template <typename BroadcastType>
    void broadcast(std::shared_ptr<const Executor> exec, BroadcastType* buffer,
                   int count, int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Bcast(buffer, count, type_impl<BroadcastType>::get_type(),
                      root_rank, this->get()));
    }

    /**
     * (Non-blocking) Broadcast data from the calling process to all ranks in
     * the communicator.
     */
    template <typename BroadcastType>
    request i_broadcast(std::shared_ptr<const Executor> exec,
                        BroadcastType* buffer, int count, int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Ibcast(buffer, count, type_impl<BroadcastType>::get_type(),
                       root_rank, this->get(), req.get()));
        return req;
    }

    /**
     * Reduce data into root from all calling processes on the same
     * communicator.
     */
    template <typename ReduceType>
    void reduce(std::shared_ptr<const Executor> exec,
                const ReduceType* send_buffer, ReduceType* recv_buffer,
                int count, MPI_Op operation, int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Reduce(send_buffer, recv_buffer, count,
                                            type_impl<ReduceType>::get_type(),
                                            operation, root_rank, this->get()));
    }

    /**
     * (Non-blocking) Reduce data into root from all calling processes on the
     * same communicator.
     */
    template <typename ReduceType>
    request i_reduce(std::shared_ptr<const Executor> exec,
                     const ReduceType* send_buffer, ReduceType* recv_buffer,
                     int count, MPI_Op operation, int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Ireduce(
            send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
            operation, root_rank, this->get(), req.get()));
        return req;
    }

    /**
     * (In-place) Reduce data from all calling processes on the same
     * communicator.
     */
    template <typename ReduceType>
    void all_reduce(std::shared_ptr<const Executor> exec,
                    ReduceType* recv_buffer, int count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
            MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
            operation, this->get()));
    }

    /**
     * (In-place, non-blocking) Reduce data from all calling processes on the
     * same communicator.
     */
    template <typename ReduceType>
    request i_all_reduce(std::shared_ptr<const Executor> exec,
                         ReduceType* recv_buffer, int count,
                         MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
            MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
            operation, this->get(), req.get()));
        return req;
    }

    /** Reduce data from all calling processes on the same communicator. */
    template <typename ReduceType>
    void all_reduce(std::shared_ptr<const Executor> exec,
                    const ReduceType* send_buffer, ReduceType* recv_buffer,
                    int count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
            send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
            operation, this->get()));
    }

    /**
     * (Non-blocking) Reduce data from all calling processes on the same
     * communicator.
     */
    template <typename ReduceType>
    request i_all_reduce(std::shared_ptr<const Executor> exec,
                         const ReduceType* send_buffer, ReduceType* recv_buffer,
                         int count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
            send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
            operation, this->get(), req.get()));
        return req;
    }

    /** Gather data onto the root rank from all ranks in the communicator. */
    template <typename SendType, typename RecvType>
    void gather(std::shared_ptr<const Executor> exec,
                const SendType* send_buffer, const int send_count,
                RecvType* recv_buffer, const int recv_count,
                int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Gather(send_buffer, send_count, type_impl<SendType>::get_type(),
                       recv_buffer, recv_count, type_impl<RecvType>::get_type(),
                       root_rank, this->get()));
    }

    /**
     * (Non-blocking) Gather data onto the root rank from all ranks in the
     * communicator.
     */
    template <typename SendType, typename RecvType>
    request i_gather(std::shared_ptr<const Executor> exec,
                     const SendType* send_buffer, const int send_count,
                     RecvType* recv_buffer, const int recv_count,
                     int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Igather(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
            this->get(), req.get()));
        return req;
    }

    /**
     * Gather data onto the root rank from all ranks in the communicator with
     * offsets.
     */
    template <typename SendType, typename RecvType>
    void gather_v(std::shared_ptr<const Executor> exec,
                  const SendType* send_buffer, const int send_count,
                  RecvType* recv_buffer, const int* recv_counts,
                  const int* displacements, int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Gatherv(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_counts, displacements,
            type_impl<RecvType>::get_type(), root_rank, this->get()));
    }

    /**
     * (Non-blocking) Gather data onto the root rank from all ranks in the
     * communicator with offsets.
     */
    template <typename SendType, typename RecvType>
    request i_gather_v(std::shared_ptr<const Executor> exec,
                       const SendType* send_buffer, const int send_count,
                       RecvType* recv_buffer, const int* recv_counts,
                       const int* displacements, int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Igatherv(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_counts, displacements,
            type_impl<RecvType>::get_type(), root_rank, this->get(),
            req.get()));
        return req;
    }

    /** Gather data onto all ranks from all ranks in the communicator. */
    template <typename SendType, typename RecvType>
    void all_gather(std::shared_ptr<const Executor> exec,
                    const SendType* send_buffer, const int send_count,
                    RecvType* recv_buffer, const int recv_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Allgather(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(),
            this->get()));
    }

    /**
     * (Non-blocking) Gather data onto all ranks from all ranks in the
     * communicator.
     */
    template <typename SendType, typename RecvType>
    request i_all_gather(std::shared_ptr<const Executor> exec,
                         const SendType* send_buffer, const int send_count,
                         RecvType* recv_buffer, const int recv_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallgather(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(),
            this->get(), req.get()));
        return req;
    }

    /** Scatter data from the root rank to all ranks in the communicator. */
    template <typename SendType, typename RecvType>
    void scatter(std::shared_ptr<const Executor> exec,
                 const SendType* send_buffer, const int send_count,
                 RecvType* recv_buffer, const int recv_count,
                 int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatter(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
            this->get()));
    }

    /**
     * (Non-blocking) Scatter data from the root rank to all ranks in the
     * communicator.
     */
    template <typename SendType, typename RecvType>
    request i_scatter(std::shared_ptr<const Executor> exec,
                      const SendType* send_buffer, const int send_count,
                      RecvType* recv_buffer, const int recv_count,
                      int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscatter(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
            this->get(), req.get()));
        return req;
    }

    /**
     * Scatter data from the root rank to all ranks in the communicator with
     * offsets.
     */
    template <typename SendType, typename RecvType>
    void scatter_v(std::shared_ptr<const Executor> exec,
                   const SendType* send_buffer, const int* send_counts,
                   const int* displacements, RecvType* recv_buffer,
                   const int recv_count, int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatterv(
            send_buffer, send_counts, displacements,
            type_impl<SendType>::get_type(), recv_buffer, recv_count,
            type_impl<RecvType>::get_type(), root_rank, this->get()));
    }

    /**
     * (Non-blocking) Scatter data from the root rank to all ranks in the
     * communicator with offsets.
     */
    template <typename SendType, typename RecvType>
    request i_scatter_v(std::shared_ptr<const Executor> exec,
                        const SendType* send_buffer, const int* send_counts,
                        const int* displacements, RecvType* recv_buffer,
                        const int recv_count, int root_rank) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Iscatterv(send_buffer, send_counts, displacements,
                          type_impl<SendType>::get_type(), recv_buffer,
                          recv_count, type_impl<RecvType>::get_type(),
                          root_rank, this->get(), req.get()));
        return req;
    }

    /**
     * (In-place) Communicate data from all ranks to all other ranks in place
     * (MPI_Alltoall).
     */
    template <typename RecvType>
    void all_to_all(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
                    const int recv_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
            MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(),
            this->get()));
    }

    /**
     * (In-place, non-blocking) Communicate data from all ranks to all other
     * ranks in place (MPI_Ialltoall).
     */
    template <typename RecvType>
    request i_all_to_all(std::shared_ptr<const Executor> exec,
                         RecvType* recv_buffer, const int recv_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
            MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(),
            this->get(), req.get()));
        return req;
    }

    /** Communicate data from all ranks to all other ranks (MPI_Alltoall). */
    template <typename SendType, typename RecvType>
    void all_to_all(std::shared_ptr<const Executor> exec,
                    const SendType* send_buffer, const int send_count,
                    RecvType* recv_buffer, const int recv_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(),
            this->get()));
    }

    /**
     * (Non-blocking) Communicate data from all ranks to all other ranks
     * (MPI_Ialltoall).
     */
    template <typename SendType, typename RecvType>
    request i_all_to_all(std::shared_ptr<const Executor> exec,
                         const SendType* send_buffer, const int send_count,
                         RecvType* recv_buffer, const int recv_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
            send_buffer, send_count, type_impl<SendType>::get_type(),
            recv_buffer, recv_count, type_impl<RecvType>::get_type(),
            this->get(), req.get()));
        return req;
    }

    /**
     * Communicate data from all ranks to all other ranks with offsets
     * (MPI_Alltoallv).
     */
    template <typename SendType, typename RecvType>
    void all_to_all_v(std::shared_ptr<const Executor> exec,
                      const SendType* send_buffer, const int* send_counts,
                      const int* send_offsets, RecvType* recv_buffer,
                      const int* recv_counts, const int* recv_offsets) const
    {
        this->all_to_all_v(std::move(exec), send_buffer, send_counts,
                           send_offsets, type_impl<SendType>::get_type(),
                           recv_buffer, recv_counts, recv_offsets,
                           type_impl<RecvType>::get_type());
    }

    /**
     * Communicate data from all ranks to all other ranks with offsets
     * (MPI_Alltoallv).
     */
    void all_to_all_v(std::shared_ptr<const Executor> exec,
                      const void* send_buffer, const int* send_counts,
                      const int* send_offsets, MPI_Datatype send_type,
                      void* recv_buffer, const int* recv_counts,
                      const int* recv_offsets, MPI_Datatype recv_type) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoallv(
            send_buffer, send_counts, send_offsets, send_type, recv_buffer,
            recv_counts, recv_offsets, recv_type, this->get()));
    }

    /**
     * (Non-blocking) Communicate data from all ranks to all other ranks with
     * offsets (MPI_Ialltoallv).
     */
    request i_all_to_all_v(std::shared_ptr<const Executor> exec,
                           const void* send_buffer, const int* send_counts,
                           const int* send_offsets, MPI_Datatype send_type,
                           void* recv_buffer, const int* recv_counts,
                           const int* recv_offsets,
                           MPI_Datatype recv_type) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoallv(
            send_buffer, send_counts, send_offsets, send_type, recv_buffer,
            recv_counts, recv_offsets, recv_type, this->get(), req.get()));
        return req;
    }

    /**
     * (Non-blocking) Communicate data from all ranks to all other ranks with
     * offsets (MPI_Ialltoallv).
     */
    template <typename SendType, typename RecvType>
    request i_all_to_all_v(std::shared_ptr<const Executor> exec,
                           const SendType* send_buffer, const int* send_counts,
                           const int* send_offsets, RecvType* recv_buffer,
                           const int* recv_counts,
                           const int* recv_offsets) const
    {
        return this->i_all_to_all_v(
            std::move(exec), send_buffer, send_counts, send_offsets,
            type_impl<SendType>::get_type(), recv_buffer, recv_counts,
            recv_offsets, type_impl<RecvType>::get_type());
    }

    /** Does a scan operation with the given operator (MPI_Scan). */
    template <typename ScanType>
    void scan(std::shared_ptr<const Executor> exec, const ScanType* send_buffer,
              ScanType* recv_buffer, int count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Scan(send_buffer, recv_buffer, count,
                                          type_impl<ScanType>::get_type(),
                                          operation, this->get()));
    }

    /**
     * (Non-blocking) Does a scan operation with the given operator
     * (MPI_Iscan).
     */
    template <typename ScanType>
    request i_scan(std::shared_ptr<const Executor> exec,
                   const ScanType* send_buffer, ScanType* recv_buffer,
                   int count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscan(send_buffer, recv_buffer, count,
                                           type_impl<ScanType>::get_type(),
                                           operation, this->get(), req.get()));
        return req;
    }

private:
    std::shared_ptr<MPI_Comm> comm_;
    bool force_host_buffer_;

    int get_my_rank() const
    {
        int my_rank = 0;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(get(), &my_rank));
        return my_rank;
    }

    int get_node_local_rank() const
    {
        MPI_Comm local_comm;
        int rank;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split_type(
            this->get(), MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm));
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(local_comm, &rank));
        MPI_Comm_free(&local_comm);
        return rank;
    }

    int get_num_ranks() const
    {
        int size = 1;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_size(this->get(), &size));
        return size;
    }
};

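/*
 * Usage sketch (illustrative only): a global dot-product style reduction.
 * exec is assumed to be an Executor created elsewhere, and compute_local_part
 * is a hypothetical helper.
 *
 * @code
 * gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
 * double local_result = compute_local_part();        // hypothetical helper
 * comm.all_reduce(exec, &local_result, 1, MPI_SUM);   // in-place MPI_Allreduce
 * if (comm.rank() == 0) {
 *     std::cout << "global result: " << local_result << "\n";
 * }
 * @endcode
 */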

/**
 * Checks if the combination of Executor and communicator requires passing
 * MPI buffers from host memory.
 */
bool requires_host_buffer(const std::shared_ptr<const Executor>& exec,
                          const communicator& comm);


/** Get the current wall time via MPI_Wtime. */
inline double get_walltime() { return MPI_Wtime(); }

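/*
 * Usage sketch (illustrative only): timing a communication region with the
 * MPI wall clock; comm is assumed to be a communicator defined elsewhere.
 *
 * @code
 * auto start = gko::experimental::mpi::get_walltime();
 * comm.synchronize();
 * auto elapsed = gko::experimental::mpi::get_walltime() - start;
 * @endcode
 */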

/** This class wraps MPI_Win with RAII functionality. */
template <typename ValueType>
class window {
public:
    /** The create type for the window object. */
    enum class create_type { allocate = 1, create = 2, dynamic_create = 3 };

    /** The lock type for passive target synchronization of the windows. */
    enum class lock_type { shared = 1, exclusive = 2 };

    /** The default constructor, creating an MPI_WIN_NULL window. */
    window() : window_(MPI_WIN_NULL) {}

    window(const window& other) = delete;

    window& operator=(const window& other) = delete;

    /** The move constructor, leaves other with MPI_WIN_NULL. */
    window(window&& other) : window_{std::exchange(other.window_, MPI_WIN_NULL)}
    {}

    /** The move assignment operator, leaves other with MPI_WIN_NULL. */
    window& operator=(window&& other)
    {
        window_ = std::exchange(other.window_, MPI_WIN_NULL);
        return *this;
    }

    /** Create a window object with a given data pointer and type. */
    window(std::shared_ptr<const Executor> exec, ValueType* base, int num_elems,
           const communicator& comm, const int disp_unit = sizeof(ValueType),
           MPI_Info input_info = MPI_INFO_NULL,
           create_type c_type = create_type::create)
    {
        auto guard = exec->get_scoped_device_id_guard();
        unsigned size = num_elems * sizeof(ValueType);
        if (c_type == create_type::create) {
            GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_create(
                base, size, disp_unit, input_info, comm.get(), &this->window_));
        } else if (c_type == create_type::dynamic_create) {
            GKO_ASSERT_NO_MPI_ERRORS(
                MPI_Win_create_dynamic(input_info, comm.get(), &this->window_));
        } else if (c_type == create_type::allocate) {
            GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_allocate(
                size, disp_unit, input_info, comm.get(), base, &this->window_));
        } else {
            GKO_NOT_IMPLEMENTED;
        }
    }

    /** Get the underlying window object of MPI_Win type. */
    MPI_Win get_window() const { return this->window_; }

    /** Active target synchronization using MPI_Win_fence for the window. */
    void fence(int assert = 0) const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_fence(assert, this->window_));
    }

    /** Create an epoch using MPI_Win_lock for the window object. */
    void lock(int rank, lock_type lock_t = lock_type::shared,
              int assert = 0) const
    {
        if (lock_t == lock_type::shared) {
            GKO_ASSERT_NO_MPI_ERRORS(
                MPI_Win_lock(MPI_LOCK_SHARED, rank, assert, this->window_));
        } else if (lock_t == lock_type::exclusive) {
            GKO_ASSERT_NO_MPI_ERRORS(
                MPI_Win_lock(MPI_LOCK_EXCLUSIVE, rank, assert, this->window_));
        } else {
            GKO_NOT_IMPLEMENTED;
        }
    }

    /** Close the epoch using MPI_Win_unlock for the window object. */
    void unlock(int rank) const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock(rank, this->window_));
    }

    /** Create the epoch on all ranks using MPI_Win_lock_all for the window. */
    void lock_all(int assert = 0) const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_lock_all(assert, this->window_));
    }

    /** Close the epoch on all ranks using MPI_Win_unlock_all for the window. */
    void unlock_all() const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock_all(this->window_));
    }

    /** Flush the existing RDMA operations on the target rank for the calling
     * process. */
    void flush(int rank) const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush(rank, this->window_));
    }

    /** Flush the existing RDMA operations on the calling rank from the target
     * rank. */
    void flush_local(int rank) const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local(rank, this->window_));
    }

    /** Flush all the existing RDMA operations for the calling process. */
    void flush_all() const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_all(this->window_));
    }

    /** Flush all the local existing RDMA operations on the calling rank. */
    void flush_all_local() const
    {
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local_all(this->window_));
    }

    /** Synchronize the public and private buffers for the window object. */
    void sync() const { GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_sync(this->window_)); }

    /** The destructor, which calls MPI_Win_free when the window leaves scope. */
    ~window()
    {
        if (this->window_ && this->window_ != MPI_WIN_NULL) {
            MPI_Win_free(&this->window_);
        }
    }

    /** Put data into the target window. */
    template <typename PutType>
    void put(std::shared_ptr<const Executor> exec, const PutType* origin_buffer,
             const int origin_count, const int target_rank,
             const unsigned int target_disp, const int target_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Put(origin_buffer, origin_count, type_impl<PutType>::get_type(),
                    target_rank, target_disp, target_count,
                    type_impl<PutType>::get_type(), this->get_window()));
    }

    /** (Non-blocking) Put data (with handle) into the target window. */
    template <typename PutType>
    request r_put(std::shared_ptr<const Executor> exec,
                  const PutType* origin_buffer, const int origin_count,
                  const int target_rank, const unsigned int target_disp,
                  const int target_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Rput(
            origin_buffer, origin_count, type_impl<PutType>::get_type(),
            target_rank, target_disp, target_count,
            type_impl<PutType>::get_type(), this->get_window(), req.get()));
        return req;
    }

    /** Accumulate data into the target window. */
    template <typename PutType>
    void accumulate(std::shared_ptr<const Executor> exec,
                    const PutType* origin_buffer, const int origin_count,
                    const int target_rank, const unsigned int target_disp,
                    const int target_count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Accumulate(
            origin_buffer, origin_count, type_impl<PutType>::get_type(),
            target_rank, target_disp, target_count,
            type_impl<PutType>::get_type(), operation, this->get_window()));
    }

    /** (Non-blocking) Accumulate data (with handle) into the target window. */
    template <typename PutType>
    request r_accumulate(std::shared_ptr<const Executor> exec,
                         const PutType* origin_buffer, const int origin_count,
                         const int target_rank, const unsigned int target_disp,
                         const int target_count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Raccumulate(
            origin_buffer, origin_count, type_impl<PutType>::get_type(),
            target_rank, target_disp, target_count,
            type_impl<PutType>::get_type(), operation, this->get_window(),
            req.get()));
        return req;
    }

    /** Get data from the target window. */
    template <typename GetType>
    void get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
             const int origin_count, const int target_rank,
             const unsigned int target_disp, const int target_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(
            MPI_Get(origin_buffer, origin_count, type_impl<GetType>::get_type(),
                    target_rank, target_disp, target_count,
                    type_impl<GetType>::get_type(), this->get_window()));
    }

    /** (Non-blocking) Get data (with handle) from the target window. */
    template <typename GetType>
    request r_get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
                  const int origin_count, const int target_rank,
                  const unsigned int target_disp, const int target_count) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget(
            origin_buffer, origin_count, type_impl<GetType>::get_type(),
            target_rank, target_disp, target_count,
            type_impl<GetType>::get_type(), this->get_window(), req.get()));
        return req;
    }

    /** Get and accumulate data from the target window. */
    template <typename GetType>
    void get_accumulate(std::shared_ptr<const Executor> exec,
                        GetType* origin_buffer, const int origin_count,
                        GetType* result_buffer, const int result_count,
                        const int target_rank, const unsigned int target_disp,
                        const int target_count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Get_accumulate(
            origin_buffer, origin_count, type_impl<GetType>::get_type(),
            result_buffer, result_count, type_impl<GetType>::get_type(),
            target_rank, target_disp, target_count,
            type_impl<GetType>::get_type(), operation, this->get_window()));
    }

    /** (Non-blocking) Get and accumulate data (with handle) from the target
     * window. */
    template <typename GetType>
    request r_get_accumulate(std::shared_ptr<const Executor> exec,
                             GetType* origin_buffer, const int origin_count,
                             GetType* result_buffer, const int result_count,
                             const int target_rank,
                             const unsigned int target_disp,
                             const int target_count, MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        request req;
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget_accumulate(
            origin_buffer, origin_count, type_impl<GetType>::get_type(),
            result_buffer, result_count, type_impl<GetType>::get_type(),
            target_rank, target_disp, target_count,
            type_impl<GetType>::get_type(), operation, this->get_window(),
            req.get()));
        return req;
    }

    /**
     * Fetch and operate on data from the target window (an optimized version
     * of get_accumulate).
     */
    template <typename GetType>
    void fetch_and_op(std::shared_ptr<const Executor> exec,
                      GetType* origin_buffer, GetType* result_buffer,
                      const int target_rank, const unsigned int target_disp,
                      MPI_Op operation) const
    {
        auto guard = exec->get_scoped_device_id_guard();
        GKO_ASSERT_NO_MPI_ERRORS(MPI_Fetch_and_op(
            origin_buffer, result_buffer, type_impl<GetType>::get_type(),
            target_rank, target_disp, operation, this->get_window()));
    }

private:
    MPI_Win window_;
};

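/*
 * Usage sketch (illustrative only): exposing a local buffer for one-sided
 * access and writing into rank 1's window between two fences. The sizes and
 * ranks are made up; comm and exec are assumed to be defined elsewhere.
 *
 * @code
 * std::vector<double> local(100, 0.0);
 * gko::experimental::mpi::window<double> win(
 *     exec, local.data(), static_cast<int>(local.size()), comm);
 * win.fence();                                    // open access epoch
 * if (comm.rank() == 0) {
 *     win.put(exec, local.data(), 100, 1, 0, 100);  // write into rank 1
 * }
 * win.fence();                                    // close access epoch
 * @endcode
 */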
}  // namespace mpi
}  // namespace experimental
}  // namespace gko


#endif  // GINKGO_BUILD_MPI


#endif  // GKO_PUBLIC_CORE_BASE_MPI_HPP_