mpi.hpp
1 // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_MPI_HPP_
6 #define GKO_PUBLIC_CORE_BASE_MPI_HPP_
7 
8 
9 #include <memory>
10 #include <type_traits>
11 #include <utility>
12 
13 #include <ginkgo/config.hpp>
14 #include <ginkgo/core/base/exception.hpp>
15 #include <ginkgo/core/base/exception_helpers.hpp>
16 #include <ginkgo/core/base/executor.hpp>
17 #include <ginkgo/core/base/half.hpp>
18 #include <ginkgo/core/base/types.hpp>
19 #include <ginkgo/core/base/utils_helper.hpp>
20 
21 
22 #if GINKGO_BUILD_MPI
23 
24 
25 #include <mpi.h>
26 
27 
28 namespace gko {
29 namespace experimental {
36 namespace mpi {
37 
38 
42 inline constexpr bool is_gpu_aware()
43 {
44 #if GINKGO_HAVE_GPU_AWARE_MPI
45  return true;
46 #else
47  return false;
48 #endif
49 }
50 
51 
59 int map_rank_to_device_id(MPI_Comm comm, int num_devices);
60 
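// Usage sketch: pick a device per rank before creating a device executor.
// Illustrative only; assumes a CUDA build and that MPI has already been
// initialized (see the environment class below).
//
//     int num_devices = gko::CudaExecutor::get_num_devices();
//     int device_id = gko::experimental::mpi::map_rank_to_device_id(
//         MPI_COMM_WORLD, num_devices);
//     auto exec = gko::CudaExecutor::create(device_id,
//                                           gko::ReferenceExecutor::create());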
61 
62 #define GKO_REGISTER_MPI_TYPE(input_type, mpi_type) \
63  template <> \
64  struct type_impl<input_type> { \
65  static MPI_Datatype get_type() { return mpi_type; } \
66  }
67 
76 template <typename T>
77 struct type_impl {};
78 
79 
80 GKO_REGISTER_MPI_TYPE(char, MPI_CHAR);
81 GKO_REGISTER_MPI_TYPE(unsigned char, MPI_UNSIGNED_CHAR);
82 GKO_REGISTER_MPI_TYPE(unsigned, MPI_UNSIGNED);
83 GKO_REGISTER_MPI_TYPE(int, MPI_INT);
84 GKO_REGISTER_MPI_TYPE(unsigned short, MPI_UNSIGNED_SHORT);
85 GKO_REGISTER_MPI_TYPE(unsigned long, MPI_UNSIGNED_LONG);
86 GKO_REGISTER_MPI_TYPE(long, MPI_LONG);
87 GKO_REGISTER_MPI_TYPE(long long, MPI_LONG_LONG_INT);
88 GKO_REGISTER_MPI_TYPE(unsigned long long, MPI_UNSIGNED_LONG_LONG);
89 GKO_REGISTER_MPI_TYPE(float, MPI_FLOAT);
90 GKO_REGISTER_MPI_TYPE(double, MPI_DOUBLE);
91 GKO_REGISTER_MPI_TYPE(long double, MPI_LONG_DOUBLE);
92 #if GINKGO_ENABLE_HALF
93 // OpenMPI 5.0 and MPICH v3.4a1 provide MPIX_C_FLOAT16 for half precision.
94 // Only OpenMPI supports complex half.
95 // TODO: use native type when mpi is configured with half feature
96 GKO_REGISTER_MPI_TYPE(half, MPI_UNSIGNED_SHORT);
97 GKO_REGISTER_MPI_TYPE(std::complex<half>, MPI_FLOAT);
98 #endif // GINKGO_ENABLE_HALF
99 GKO_REGISTER_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX);
100 GKO_REGISTER_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
101 
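// Usage sketch: type_impl maps a C++ type to the matching MPI_Datatype at
// compile time; the wrappers below use it for every communication call.
//
//     MPI_Datatype type = gko::experimental::mpi::type_impl<double>::get_type();
//     // type == MPI_DOUBLE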
102 
109 class contiguous_type {
110 public:
117  contiguous_type(int count, MPI_Datatype old_type) : type_(MPI_DATATYPE_NULL)
118  {
119  GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_contiguous(count, old_type, &type_));
120  GKO_ASSERT_NO_MPI_ERRORS(MPI_Type_commit(&type_));
121  }
122 
126  contiguous_type() : type_(MPI_DATATYPE_NULL) {}
127 
131  contiguous_type(const contiguous_type&) = delete;
132 
136  contiguous_type& operator=(const contiguous_type&) = delete;
137 
143  contiguous_type(contiguous_type&& other) noexcept : type_(MPI_DATATYPE_NULL)
144  {
145  *this = std::move(other);
146  }
147 
155  contiguous_type& operator=(contiguous_type&& other) noexcept
156  {
157  if (this != &other) {
158  this->type_ = std::exchange(other.type_, MPI_DATATYPE_NULL);
159  }
160  return *this;
161  }
162 
166  ~contiguous_type()
167  {
168  if (type_ != MPI_DATATYPE_NULL) {
169  MPI_Type_free(&type_);
170  }
171  }
172 
178  MPI_Datatype get() const { return type_; }
179 
180 private:
181  MPI_Datatype type_;
182 };
183 
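// Usage sketch (illustrative): treat blocks of three doubles as one datatype.
// The wrapped MPI_Datatype is freed automatically when the object goes out of
// scope.
//
//     gko::experimental::mpi::contiguous_type block(3, MPI_DOUBLE);
//     // pass block.get() wherever an MPI_Datatype is expected, e.g.
//     // MPI_Send(buffer, 1, block.get(), destination, tag, comm);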
184 
189 enum class thread_type {
190  serialized = MPI_THREAD_SERIALIZED,
191  funneled = MPI_THREAD_FUNNELED,
192  single = MPI_THREAD_SINGLE,
193  multiple = MPI_THREAD_MULTIPLE
194 };
195 
196 
206 class environment {
207 public:
208  static bool is_finalized()
209  {
210  int flag = 0;
211  GKO_ASSERT_NO_MPI_ERRORS(MPI_Finalized(&flag));
212  return flag;
213  }
214 
215  static bool is_initialized()
216  {
217  int flag = 0;
218  GKO_ASSERT_NO_MPI_ERRORS(MPI_Initialized(&flag));
219  return flag;
220  }
221 
227  int get_provided_thread_support() const { return provided_thread_support_; }
228 
237  environment(int& argc, char**& argv,
238  const thread_type thread_t = thread_type::serialized)
239  {
240  this->required_thread_support_ = static_cast<int>(thread_t);
241  GKO_ASSERT_NO_MPI_ERRORS(
242  MPI_Init_thread(&argc, &argv, this->required_thread_support_,
243  &(this->provided_thread_support_)));
244  }
245 
249  ~environment() { MPI_Finalize(); }
250 
251  environment(const environment&) = delete;
252  environment(environment&&) = delete;
253  environment& operator=(const environment&) = delete;
254  environment& operator=(environment&&) = delete;
255 
256 private:
257  int required_thread_support_;
258  int provided_thread_support_;
259 };
260 
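// Usage sketch: initialize MPI via RAII at the start of main; MPI_Finalize is
// called automatically when env goes out of scope.
//
//     int main(int argc, char* argv[])
//     {
//         gko::experimental::mpi::environment env(argc, argv);
//         // ... all MPI-based code runs while env is alive ...
//     }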
261 
262 namespace {
263 
264 
269 class comm_deleter {
270 public:
271  using pointer = MPI_Comm*;
272  void operator()(pointer comm) const
273  {
274  GKO_ASSERT(*comm != MPI_COMM_NULL);
275  GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_free(comm));
276  delete comm;
277  }
278 };
279 
280 
281 } // namespace
282 
283 
287 struct status {
291  status() : status_(MPI_Status{}) {}
292 
298  MPI_Status* get() { return &this->status_; }
299 
310  template <typename T>
311  int get_count(const T* data) const
312  {
313  int count;
314  MPI_Get_count(&status_, type_impl<T>::get_type(), &count);
315  return count;
316  }
317 
318 private:
319  MPI_Status status_;
320 };
321 
322 
327 class request {
328 public:
333  request() : req_(MPI_REQUEST_NULL) {}
334 
335  request(const request&) = delete;
336 
337  request& operator=(const request&) = delete;
338 
339  request(request&& o) noexcept { *this = std::move(o); }
340 
341  request& operator=(request&& o) noexcept
342  {
343  if (this != &o) {
344  this->req_ = std::exchange(o.req_, MPI_REQUEST_NULL);
345  }
346  return *this;
347  }
348 
349  ~request()
350  {
351  if (req_ != MPI_REQUEST_NULL) {
352  if (MPI_Request_free(&req_) != MPI_SUCCESS) {
353  std::terminate(); // since we can't throw in destructors, we
354  // have to terminate the program
355  }
356  }
357  }
358 
364  MPI_Request* get() { return &this->req_; }
365 
372  status wait()
373  {
374  status status;
375  GKO_ASSERT_NO_MPI_ERRORS(MPI_Wait(&req_, status.get()));
376  return status;
377  }
378 
379 
380 private:
381  MPI_Request req_;
382 };
383 
384 
392 inline std::vector<status> wait_all(std::vector<request>& req)
393 {
394  std::vector<status> stat;
395  for (std::size_t i = 0; i < req.size(); ++i) {
396  stat.emplace_back(req[i].wait());
397  }
398  return stat;
399 }
400 
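// Usage sketch (illustrative; comm, exec, and the buffers are assumed to be
// set up as with the communicator class defined below): start a non-blocking
// send and receive, then wait on both handles at once.
//
//     std::vector<gko::experimental::mpi::request> reqs;
//     reqs.push_back(comm.i_send(exec, send_buf, count, /*dest*/ 1, /*tag*/ 0));
//     reqs.push_back(comm.i_recv(exec, recv_buf, count, /*src*/ 1, /*tag*/ 0));
//     auto statuses = gko::experimental::mpi::wait_all(reqs);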
401 
416 class communicator {
417 public:
428  communicator(const MPI_Comm& comm, bool force_host_buffer = false)
429  : comm_(), force_host_buffer_(force_host_buffer)
430  {
431  this->comm_.reset(new MPI_Comm(comm));
432  }
433 
442  communicator(const MPI_Comm& comm, int color, int key)
443  {
444  MPI_Comm comm_out;
445  GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split(comm, color, key, &comm_out));
446  this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
447  }
448 
457  communicator(const communicator& comm, int color, int key)
458  {
459  MPI_Comm comm_out;
460  GKO_ASSERT_NO_MPI_ERRORS(
461  MPI_Comm_split(comm.get(), color, key, &comm_out));
462  this->comm_.reset(new MPI_Comm(comm_out), comm_deleter{});
463  }
464 
474  static communicator create_owning(const MPI_Comm& comm,
475  bool force_host_buffer = false)
476  {
477  communicator comm_out(MPI_COMM_NULL, force_host_buffer);
478  comm_out.comm_.reset(new MPI_Comm(comm), comm_deleter{});
479  return comm_out;
480  }
481 
487  communicator(const communicator& other) = default;
488 
495  communicator(communicator&& other) { *this = std::move(other); }
496 
500  communicator& operator=(const communicator& other) = default;
501 
505  communicator& operator=(communicator&& other)
506  {
507  if (this != &other) {
508  comm_ = std::exchange(other.comm_,
509  std::make_shared<MPI_Comm>(MPI_COMM_NULL));
510  force_host_buffer_ = other.force_host_buffer_;
511  }
512  return *this;
513  }
514 
520  const MPI_Comm& get() const { return *(this->comm_.get()); }
521 
522  bool force_host_buffer() const { return force_host_buffer_; }
523 
529  int size() const { return get_num_ranks(); }
530 
536  int rank() const { return get_my_rank(); }
537 
543  int node_local_rank() const { return get_node_local_rank(); }
544 
550  bool operator==(const communicator& rhs) const { return is_identical(rhs); }
551 
557  bool operator!=(const communicator& rhs) const { return !(*this == rhs); }
558 
568  bool is_identical(const communicator& rhs) const
569  {
570  if (get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
571  return get() == rhs.get();
572  }
573  int flag;
574  GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), rhs.get(), &flag));
575  return flag == MPI_IDENT;
576  }
577 
590  bool is_congruent(const communicator& rhs) const
591  {
592  if (get() == MPI_COMM_NULL || rhs.get() == MPI_COMM_NULL) {
593  return get() == rhs.get();
594  }
595  int flag;
596  GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_compare(get(), rhs.get(), &flag));
597  return flag == MPI_CONGRUENT;
598  }
599 
604  void synchronize() const
605  {
606  GKO_ASSERT_NO_MPI_ERRORS(MPI_Barrier(this->get()));
607  }
608 
622  template <typename SendType>
623  void send(std::shared_ptr<const Executor> exec, const SendType* send_buffer,
624  const int send_count, const int destination_rank,
625  const int send_tag) const
626  {
627  auto guard = exec->get_scoped_device_id_guard();
628  GKO_ASSERT_NO_MPI_ERRORS(
629  MPI_Send(send_buffer, send_count, type_impl<SendType>::get_type(),
630  destination_rank, send_tag, this->get()));
631  }
632 
649  template <typename SendType>
650  request i_send(std::shared_ptr<const Executor> exec,
651  const SendType* send_buffer, const int send_count,
652  const int destination_rank, const int send_tag) const
653  {
654  auto guard = exec->get_scoped_device_id_guard();
655  request req;
656  GKO_ASSERT_NO_MPI_ERRORS(
657  MPI_Isend(send_buffer, send_count, type_impl<SendType>::get_type(),
658  destination_rank, send_tag, this->get(), req.get()));
659  return req;
660  }
661 
677  template <typename RecvType>
678  status recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
679  const int recv_count, const int source_rank,
680  const int recv_tag) const
681  {
682  auto guard = exec->get_scoped_device_id_guard();
683  status st;
684  GKO_ASSERT_NO_MPI_ERRORS(
685  MPI_Recv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
686  source_rank, recv_tag, this->get(), st.get()));
687  return st;
688  }
689 
705  template <typename RecvType>
706  request i_recv(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
707  const int recv_count, const int source_rank,
708  const int recv_tag) const
709  {
710  auto guard = exec->get_scoped_device_id_guard();
711  request req;
712  GKO_ASSERT_NO_MPI_ERRORS(
713  MPI_Irecv(recv_buffer, recv_count, type_impl<RecvType>::get_type(),
714  source_rank, recv_tag, this->get(), req.get()));
715  return req;
716  }
717 
730  template <typename BroadcastType>
731  void broadcast(std::shared_ptr<const Executor> exec, BroadcastType* buffer,
732  int count, int root_rank) const
733  {
734  auto guard = exec->get_scoped_device_id_guard();
735  GKO_ASSERT_NO_MPI_ERRORS(MPI_Bcast(buffer, count,
736  type_impl<BroadcastType>::get_type(),
737  root_rank, this->get()));
738  }
739 
755  template <typename BroadcastType>
756  request i_broadcast(std::shared_ptr<const Executor> exec,
757  BroadcastType* buffer, int count, int root_rank) const
758  {
759  auto guard = exec->get_scoped_device_id_guard();
760  request req;
761  GKO_ASSERT_NO_MPI_ERRORS(
762  MPI_Ibcast(buffer, count, type_impl<BroadcastType>::get_type(),
763  root_rank, this->get(), req.get()));
764  return req;
765  }
766 
781  template <typename ReduceType>
782  void reduce(std::shared_ptr<const Executor> exec,
783  const ReduceType* send_buffer, ReduceType* recv_buffer,
784  int count, MPI_Op operation, int root_rank) const
785  {
786  auto guard = exec->get_scoped_device_id_guard();
787  GKO_ASSERT_NO_MPI_ERRORS(MPI_Reduce(send_buffer, recv_buffer, count,
788  type_impl<ReduceType>::get_type(),
789  operation, root_rank, this->get()));
790  }
791 
808  template <typename ReduceType>
809  request i_reduce(std::shared_ptr<const Executor> exec,
810  const ReduceType* send_buffer, ReduceType* recv_buffer,
811  int count, MPI_Op operation, int root_rank) const
812  {
813  auto guard = exec->get_scoped_device_id_guard();
814  request req;
815  GKO_ASSERT_NO_MPI_ERRORS(MPI_Ireduce(
816  send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
817  operation, root_rank, this->get(), req.get()));
818  return req;
819  }
820 
834  template <typename ReduceType>
835  void all_reduce(std::shared_ptr<const Executor> exec,
836  ReduceType* recv_buffer, int count, MPI_Op operation) const
837  {
838  auto guard = exec->get_scoped_device_id_guard();
839  GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
840  MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
841  operation, this->get()));
842  }
843 
859  template <typename ReduceType>
860  request i_all_reduce(std::shared_ptr<const Executor> exec,
861  ReduceType* recv_buffer, int count,
862  MPI_Op operation) const
863  {
864  auto guard = exec->get_scoped_device_id_guard();
865  request req;
866  GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
867  MPI_IN_PLACE, recv_buffer, count, type_impl<ReduceType>::get_type(),
868  operation, this->get(), req.get()));
869  return req;
870  }
871 
886  template <typename ReduceType>
887  void all_reduce(std::shared_ptr<const Executor> exec,
888  const ReduceType* send_buffer, ReduceType* recv_buffer,
889  int count, MPI_Op operation) const
890  {
891  auto guard = exec->get_scoped_device_id_guard();
892  GKO_ASSERT_NO_MPI_ERRORS(MPI_Allreduce(
893  send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
894  operation, this->get()));
895  }
896 
913  template <typename ReduceType>
914  request i_all_reduce(std::shared_ptr<const Executor> exec,
915  const ReduceType* send_buffer, ReduceType* recv_buffer,
916  int count, MPI_Op operation) const
917  {
918  auto guard = exec->get_scoped_device_id_guard();
919  request req;
920  GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallreduce(
921  send_buffer, recv_buffer, count, type_impl<ReduceType>::get_type(),
922  operation, this->get(), req.get()));
923  return req;
924  }
925 
942  template <typename SendType, typename RecvType>
943  void gather(std::shared_ptr<const Executor> exec,
944  const SendType* send_buffer, const int send_count,
945  RecvType* recv_buffer, const int recv_count,
946  int root_rank) const
947  {
948  auto guard = exec->get_scoped_device_id_guard();
949  GKO_ASSERT_NO_MPI_ERRORS(
950  MPI_Gather(send_buffer, send_count, type_impl<SendType>::get_type(),
951  recv_buffer, recv_count, type_impl<RecvType>::get_type(),
952  root_rank, this->get()));
953  }
954 
974  template <typename SendType, typename RecvType>
975  request i_gather(std::shared_ptr<const Executor> exec,
976  const SendType* send_buffer, const int send_count,
977  RecvType* recv_buffer, const int recv_count,
978  int root_rank) const
979  {
980  auto guard = exec->get_scoped_device_id_guard();
981  request req;
982  GKO_ASSERT_NO_MPI_ERRORS(MPI_Igather(
983  send_buffer, send_count, type_impl<SendType>::get_type(),
984  recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
985  this->get(), req.get()));
986  return req;
987  }
988 
1007  template <typename SendType, typename RecvType>
1008  void gather_v(std::shared_ptr<const Executor> exec,
1009  const SendType* send_buffer, const int send_count,
1010  RecvType* recv_buffer, const int* recv_counts,
1011  const int* displacements, int root_rank) const
1012  {
1013  auto guard = exec->get_scoped_device_id_guard();
1014  GKO_ASSERT_NO_MPI_ERRORS(MPI_Gatherv(
1015  send_buffer, send_count, type_impl<SendType>::get_type(),
1016  recv_buffer, recv_counts, displacements,
1017  type_impl<RecvType>::get_type(), root_rank, this->get()));
1018  }
1019 
1040  template <typename SendType, typename RecvType>
1041  request i_gather_v(std::shared_ptr<const Executor> exec,
1042  const SendType* send_buffer, const int send_count,
1043  RecvType* recv_buffer, const int* recv_counts,
1044  const int* displacements, int root_rank) const
1045  {
1046  auto guard = exec->get_scoped_device_id_guard();
1047  request req;
1048  GKO_ASSERT_NO_MPI_ERRORS(MPI_Igatherv(
1049  send_buffer, send_count, type_impl<SendType>::get_type(),
1050  recv_buffer, recv_counts, displacements,
1051  type_impl<RecvType>::get_type(), root_rank, this->get(),
1052  req.get()));
1053  return req;
1054  }
1055 
1071  template <typename SendType, typename RecvType>
1072  void all_gather(std::shared_ptr<const Executor> exec,
1073  const SendType* send_buffer, const int send_count,
1074  RecvType* recv_buffer, const int recv_count) const
1075  {
1076  auto guard = exec->get_scoped_device_id_guard();
1077  GKO_ASSERT_NO_MPI_ERRORS(MPI_Allgather(
1078  send_buffer, send_count, type_impl<SendType>::get_type(),
1079  recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1080  this->get()));
1081  }
1082 
1101  template <typename SendType, typename RecvType>
1102  request i_all_gather(std::shared_ptr<const Executor> exec,
1103  const SendType* send_buffer, const int send_count,
1104  RecvType* recv_buffer, const int recv_count) const
1105  {
1106  auto guard = exec->get_scoped_device_id_guard();
1107  request req;
1108  GKO_ASSERT_NO_MPI_ERRORS(MPI_Iallgather(
1109  send_buffer, send_count, type_impl<SendType>::get_type(),
1110  recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1111  this->get(), req.get()));
1112  return req;
1113  }
1114 
1130  template <typename SendType, typename RecvType>
1131  void scatter(std::shared_ptr<const Executor> exec,
1132  const SendType* send_buffer, const int send_count,
1133  RecvType* recv_buffer, const int recv_count,
1134  int root_rank) const
1135  {
1136  auto guard = exec->get_scoped_device_id_guard();
1137  GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatter(
1138  send_buffer, send_count, type_impl<SendType>::get_type(),
1139  recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1140  this->get()));
1141  }
1142 
1161  template <typename SendType, typename RecvType>
1162  request i_scatter(std::shared_ptr<const Executor> exec,
1163  const SendType* send_buffer, const int send_count,
1164  RecvType* recv_buffer, const int recv_count,
1165  int root_rank) const
1166  {
1167  auto guard = exec->get_scoped_device_id_guard();
1168  request req;
1169  GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscatter(
1170  send_buffer, send_count, type_impl<SendType>::get_type(),
1171  recv_buffer, recv_count, type_impl<RecvType>::get_type(), root_rank,
1172  this->get(), req.get()));
1173  return req;
1174  }
1175 
1194  template <typename SendType, typename RecvType>
1195  void scatter_v(std::shared_ptr<const Executor> exec,
1196  const SendType* send_buffer, const int* send_counts,
1197  const int* displacements, RecvType* recv_buffer,
1198  const int recv_count, int root_rank) const
1199  {
1200  auto guard = exec->get_scoped_device_id_guard();
1201  GKO_ASSERT_NO_MPI_ERRORS(MPI_Scatterv(
1202  send_buffer, send_counts, displacements,
1203  type_impl<SendType>::get_type(), recv_buffer, recv_count,
1204  type_impl<RecvType>::get_type(), root_rank, this->get()));
1205  }
1206 
1227  template <typename SendType, typename RecvType>
1228  request i_scatter_v(std::shared_ptr<const Executor> exec,
1229  const SendType* send_buffer, const int* send_counts,
1230  const int* displacements, RecvType* recv_buffer,
1231  const int recv_count, int root_rank) const
1232  {
1233  auto guard = exec->get_scoped_device_id_guard();
1234  request req;
1235  GKO_ASSERT_NO_MPI_ERRORS(
1236  MPI_Iscatterv(send_buffer, send_counts, displacements,
1237  type_impl<SendType>::get_type(), recv_buffer,
1238  recv_count, type_impl<RecvType>::get_type(),
1239  root_rank, this->get(), req.get()));
1240  return req;
1241  }
1242 
1259  template <typename RecvType>
1260  void all_to_all(std::shared_ptr<const Executor> exec, RecvType* recv_buffer,
1261  const int recv_count) const
1262  {
1263  auto guard = exec->get_scoped_device_id_guard();
1264  GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1265  MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1266  recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1267  this->get()));
1268  }
1269 
1288  template <typename RecvType>
1289  request i_all_to_all(std::shared_ptr<const Executor> exec,
1290  RecvType* recv_buffer, const int recv_count) const
1291  {
1292  auto guard = exec->get_scoped_device_id_guard();
1293  request req;
1294  GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1295  MPI_IN_PLACE, recv_count, type_impl<RecvType>::get_type(),
1296  recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1297  this->get(), req.get()));
1298  return req;
1299  }
1300 
1317  template <typename SendType, typename RecvType>
1318  void all_to_all(std::shared_ptr<const Executor> exec,
1319  const SendType* send_buffer, const int send_count,
1320  RecvType* recv_buffer, const int recv_count) const
1321  {
1322  auto guard = exec->get_scoped_device_id_guard();
1323  GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoall(
1324  send_buffer, send_count, type_impl<SendType>::get_type(),
1325  recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1326  this->get()));
1327  }
1328 
1347  template <typename SendType, typename RecvType>
1348  request i_all_to_all(std::shared_ptr<const Executor> exec,
1349  const SendType* send_buffer, const int send_count,
1350  RecvType* recv_buffer, const int recv_count) const
1351  {
1352  auto guard = exec->get_scoped_device_id_guard();
1353  request req;
1354  GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoall(
1355  send_buffer, send_count, type_impl<SendType>::get_type(),
1356  recv_buffer, recv_count, type_impl<RecvType>::get_type(),
1357  this->get(), req.get()));
1358  return req;
1359  }
1360 
1380  template <typename SendType, typename RecvType>
1381  void all_to_all_v(std::shared_ptr<const Executor> exec,
1382  const SendType* send_buffer, const int* send_counts,
1383  const int* send_offsets, RecvType* recv_buffer,
1384  const int* recv_counts, const int* recv_offsets) const
1385  {
1386  this->all_to_all_v(std::move(exec), send_buffer, send_counts,
1387  send_offsets, type_impl<SendType>::get_type(),
1388  recv_buffer, recv_counts, recv_offsets,
1389  type_impl<RecvType>::get_type());
1390  }
1391 
1407  void all_to_all_v(std::shared_ptr<const Executor> exec,
1408  const void* send_buffer, const int* send_counts,
1409  const int* send_offsets, MPI_Datatype send_type,
1410  void* recv_buffer, const int* recv_counts,
1411  const int* recv_offsets, MPI_Datatype recv_type) const
1412  {
1413  auto guard = exec->get_scoped_device_id_guard();
1414  GKO_ASSERT_NO_MPI_ERRORS(MPI_Alltoallv(
1415  send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1416  recv_counts, recv_offsets, recv_type, this->get()));
1417  }
1418 
1438  request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1439  const void* send_buffer, const int* send_counts,
1440  const int* send_offsets, MPI_Datatype send_type,
1441  void* recv_buffer, const int* recv_counts,
1442  const int* recv_offsets,
1443  MPI_Datatype recv_type) const
1444  {
1445  auto guard = exec->get_scoped_device_id_guard();
1446  request req;
1447  GKO_ASSERT_NO_MPI_ERRORS(MPI_Ialltoallv(
1448  send_buffer, send_counts, send_offsets, send_type, recv_buffer,
1449  recv_counts, recv_offsets, recv_type, this->get(), req.get()));
1450  return req;
1451  }
1452 
1473  template <typename SendType, typename RecvType>
1474  request i_all_to_all_v(std::shared_ptr<const Executor> exec,
1475  const SendType* send_buffer, const int* send_counts,
1476  const int* send_offsets, RecvType* recv_buffer,
1477  const int* recv_counts,
1478  const int* recv_offsets) const
1479  {
1480  return this->i_all_to_all_v(
1481  std::move(exec), send_buffer, send_counts, send_offsets,
1482  type_impl<SendType>::get_type(), recv_buffer, recv_counts,
1483  recv_offsets, type_impl<RecvType>::get_type());
1484  }
1485 
1500  template <typename ScanType>
1501  void scan(std::shared_ptr<const Executor> exec, const ScanType* send_buffer,
1502  ScanType* recv_buffer, int count, MPI_Op operation) const
1503  {
1504  auto guard = exec->get_scoped_device_id_guard();
1505  GKO_ASSERT_NO_MPI_ERRORS(MPI_Scan(send_buffer, recv_buffer, count,
1506  type_impl<ScanType>::get_type(),
1507  operation, this->get()));
1508  }
1509 
1526  template <typename ScanType>
1527  request i_scan(std::shared_ptr<const Executor> exec,
1528  const ScanType* send_buffer, ScanType* recv_buffer,
1529  int count, MPI_Op operation) const
1530  {
1531  auto guard = exec->get_scoped_device_id_guard();
1532  request req;
1533  GKO_ASSERT_NO_MPI_ERRORS(MPI_Iscan(send_buffer, recv_buffer, count,
1534  type_impl<ScanType>::get_type(),
1535  operation, this->get(), req.get()));
1536  return req;
1537  }
1538 
1539 private:
1540  std::shared_ptr<MPI_Comm> comm_;
1541  bool force_host_buffer_;
1542 
1543  int get_my_rank() const
1544  {
1545  int my_rank = 0;
1546  GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(get(), &my_rank));
1547  return my_rank;
1548  }
1549 
1550  int get_node_local_rank() const
1551  {
1552  MPI_Comm local_comm;
1553  int rank;
1554  GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_split_type(
1555  this->get(), MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm));
1556  GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_rank(local_comm, &rank));
1557  MPI_Comm_free(&local_comm);
1558  return rank;
1559  }
1560 
1561  int get_num_ranks() const
1562  {
1563  int size = 1;
1564  GKO_ASSERT_NO_MPI_ERRORS(MPI_Comm_size(this->get(), &size));
1565  return size;
1566  }
1567 };
1568 
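// Usage sketch: wrap MPI_COMM_WORLD without taking ownership and perform an
// in-place all-reduce (exec is any Ginkgo executor; the value is
// illustrative).
//
//     gko::experimental::mpi::communicator comm(MPI_COMM_WORLD);
//     double value = static_cast<double>(comm.rank());
//     comm.all_reduce(exec, &value, 1, MPI_SUM);
//     // value now holds the sum over all ranks, on every rank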
1569 
1574 bool requires_host_buffer(const std::shared_ptr<const Executor>& exec,
1575  const communicator& comm);
1576 
1577 
1583 inline double get_walltime() { return MPI_Wtime(); }
1584 
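// Usage sketch: time a region on all ranks with MPI's wall clock, assuming a
// communicator comm as above.
//
//     comm.synchronize();
//     auto start = gko::experimental::mpi::get_walltime();
//     // ... communication or computation to be timed ...
//     comm.synchronize();
//     auto elapsed = gko::experimental::mpi::get_walltime() - start;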
1585 
1594 template <typename ValueType>
1595 class window {
1596 public:
1600  enum class create_type { allocate = 1, create = 2, dynamic_create = 3 };
1601 
1605  enum class lock_type { shared = 1, exclusive = 2 };
1606 
1610  window() : window_(MPI_WIN_NULL) {}
1611 
1612  window(const window& other) = delete;
1613 
1614  window& operator=(const window& other) = delete;
1615 
1622  window(window&& other) : window_{std::exchange(other.window_, MPI_WIN_NULL)}
1623  {}
1624 
1631  window& operator=(window&& other)
1632  {
1633  window_ = std::exchange(other.window_, MPI_WIN_NULL);
1634  }
1635 
1648  window(std::shared_ptr<const Executor> exec, ValueType* base, int num_elems,
1649  const communicator& comm, const int disp_unit = sizeof(ValueType),
1650  MPI_Info input_info = MPI_INFO_NULL,
1651  create_type c_type = create_type::create)
1652  {
1653  auto guard = exec->get_scoped_device_id_guard();
1654  unsigned size = num_elems * sizeof(ValueType);
1655  if (c_type == create_type::create) {
1656  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_create(
1657  base, size, disp_unit, input_info, comm.get(), &this->window_));
1658  } else if (c_type == create_type::dynamic_create) {
1659  GKO_ASSERT_NO_MPI_ERRORS(
1660  MPI_Win_create_dynamic(input_info, comm.get(), &this->window_));
1661  } else if (c_type == create_type::allocate) {
1662  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_allocate(
1663  size, disp_unit, input_info, comm.get(), base, &this->window_));
1664  } else {
1665  GKO_NOT_IMPLEMENTED;
1666  }
1667  }
1668 
1674  MPI_Win get_window() const { return this->window_; }
1675 
1682  void fence(int assert = 0) const
1683  {
1684  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_fence(assert, this->window_));
1685  }
1686 
1695  void lock(int rank, lock_type lock_t = lock_type::shared,
1696  int assert = 0) const
1697  {
1698  if (lock_t == lock_type::shared) {
1699  GKO_ASSERT_NO_MPI_ERRORS(
1700  MPI_Win_lock(MPI_LOCK_SHARED, rank, assert, this->window_));
1701  } else if (lock_t == lock_type::exclusive) {
1702  GKO_ASSERT_NO_MPI_ERRORS(
1703  MPI_Win_lock(MPI_LOCK_EXCLUSIVE, rank, assert, this->window_));
1704  } else {
1705  GKO_NOT_IMPLEMENTED;
1706  }
1707  }
1708 
1715  void unlock(int rank) const
1716  {
1717  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock(rank, this->window_));
1718  }
1719 
1726  void lock_all(int assert = 0) const
1727  {
1728  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_lock_all(assert, this->window_));
1729  }
1730 
1735  void unlock_all() const
1736  {
1737  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_unlock_all(this->window_));
1738  }
1739 
1746  void flush(int rank) const
1747  {
1748  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush(rank, this->window_));
1749  }
1750 
1757  void flush_local(int rank) const
1758  {
1759  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local(rank, this->window_));
1760  }
1761 
1766  void flush_all() const
1767  {
1768  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_all(this->window_));
1769  }
1770 
1775  void flush_all_local() const
1776  {
1777  GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_flush_local_all(this->window_));
1778  }
1779 
1783  void sync() const { GKO_ASSERT_NO_MPI_ERRORS(MPI_Win_sync(this->window_)); }
1784 
1788  ~window()
1789  {
1790  if (this->window_ && this->window_ != MPI_WIN_NULL) {
1791  MPI_Win_free(&this->window_);
1792  }
1793  }
1794 
1805  template <typename PutType>
1806  void put(std::shared_ptr<const Executor> exec, const PutType* origin_buffer,
1807  const int origin_count, const int target_rank,
1808  const unsigned int target_disp, const int target_count) const
1809  {
1810  auto guard = exec->get_scoped_device_id_guard();
1811  GKO_ASSERT_NO_MPI_ERRORS(
1812  MPI_Put(origin_buffer, origin_count, type_impl<PutType>::get_type(),
1813  target_rank, target_disp, target_count,
1814  type_impl<PutType>::get_type(), this->get_window()));
1815  }
1816 
1829  template <typename PutType>
1830  request r_put(std::shared_ptr<const Executor> exec,
1831  const PutType* origin_buffer, const int origin_count,
1832  const int target_rank, const unsigned int target_disp,
1833  const int target_count) const
1834  {
1835  auto guard = exec->get_scoped_device_id_guard();
1836  request req;
1837  GKO_ASSERT_NO_MPI_ERRORS(MPI_Rput(
1838  origin_buffer, origin_count, type_impl<PutType>::get_type(),
1839  target_rank, target_disp, target_count,
1840  type_impl<PutType>::get_type(), this->get_window(), req.get()));
1841  return req;
1842  }
1843 
1855  template <typename PutType>
1856  void accumulate(std::shared_ptr<const Executor> exec,
1857  const PutType* origin_buffer, const int origin_count,
1858  const int target_rank, const unsigned int target_disp,
1859  const int target_count, MPI_Op operation) const
1860  {
1861  auto guard = exec->get_scoped_device_id_guard();
1862  GKO_ASSERT_NO_MPI_ERRORS(MPI_Accumulate(
1863  origin_buffer, origin_count, type_impl<PutType>::get_type(),
1864  target_rank, target_disp, target_count,
1865  type_impl<PutType>::get_type(), operation, this->get_window()));
1866  }
1867 
1881  template <typename PutType>
1882  request r_accumulate(std::shared_ptr<const Executor> exec,
1883  const PutType* origin_buffer, const int origin_count,
1884  const int target_rank, const unsigned int target_disp,
1885  const int target_count, MPI_Op operation) const
1886  {
1887  auto guard = exec->get_scoped_device_id_guard();
1888  request req;
1889  GKO_ASSERT_NO_MPI_ERRORS(MPI_Raccumulate(
1890  origin_buffer, origin_count, type_impl<PutType>::get_type(),
1891  target_rank, target_disp, target_count,
1892  type_impl<PutType>::get_type(), operation, this->get_window(),
1893  req.get()));
1894  return req;
1895  }
1896 
1907  template <typename GetType>
1908  void get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1909  const int origin_count, const int target_rank,
1910  const unsigned int target_disp, const int target_count) const
1911  {
1912  auto guard = exec->get_scoped_device_id_guard();
1913  GKO_ASSERT_NO_MPI_ERRORS(
1914  MPI_Get(origin_buffer, origin_count, type_impl<GetType>::get_type(),
1915  target_rank, target_disp, target_count,
1916  type_impl<GetType>::get_type(), this->get_window()));
1917  }
1918 
1931  template <typename GetType>
1932  request r_get(std::shared_ptr<const Executor> exec, GetType* origin_buffer,
1933  const int origin_count, const int target_rank,
1934  const unsigned int target_disp, const int target_count) const
1935  {
1936  auto guard = exec->get_scoped_device_id_guard();
1937  request req;
1938  GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget(
1939  origin_buffer, origin_count, type_impl<GetType>::get_type(),
1940  target_rank, target_disp, target_count,
1941  type_impl<GetType>::get_type(), this->get_window(), req.get()));
1942  return req;
1943  }
1944 
1958  template <typename GetType>
1959  void get_accumulate(std::shared_ptr<const Executor> exec,
1960  GetType* origin_buffer, const int origin_count,
1961  GetType* result_buffer, const int result_count,
1962  const int target_rank, const unsigned int target_disp,
1963  const int target_count, MPI_Op operation) const
1964  {
1965  auto guard = exec->get_scoped_device_id_guard();
1966  GKO_ASSERT_NO_MPI_ERRORS(MPI_Get_accumulate(
1967  origin_buffer, origin_count, type_impl<GetType>::get_type(),
1968  result_buffer, result_count, type_impl<GetType>::get_type(),
1969  target_rank, target_disp, target_count,
1970  type_impl<GetType>::get_type(), operation, this->get_window()));
1971  }
1972 
1988  template <typename GetType>
1989  request r_get_accumulate(std::shared_ptr<const Executor> exec,
1990  GetType* origin_buffer, const int origin_count,
1991  GetType* result_buffer, const int result_count,
1992  const int target_rank,
1993  const unsigned int target_disp,
1994  const int target_count, MPI_Op operation) const
1995  {
1996  auto guard = exec->get_scoped_device_id_guard();
1997  request req;
1998  GKO_ASSERT_NO_MPI_ERRORS(MPI_Rget_accumulate(
1999  origin_buffer, origin_count, type_impl<GetType>::get_type(),
2000  result_buffer, result_count, type_impl<GetType>::get_type(),
2001  target_rank, target_disp, target_count,
2002  type_impl<GetType>::get_type(), operation, this->get_window(),
2003  req.get()));
2004  return req;
2005  }
2006 
2017  template <typename GetType>
2018  void fetch_and_op(std::shared_ptr<const Executor> exec,
2019  GetType* origin_buffer, GetType* result_buffer,
2020  const int target_rank, const unsigned int target_disp,
2021  MPI_Op operation) const
2022  {
2023  auto guard = exec->get_scoped_device_id_guard();
2024  GKO_ASSERT_NO_MPI_ERRORS(MPI_Fetch_and_op(
2025  origin_buffer, result_buffer, type_impl<GetType>::get_type(),
2026  target_rank, target_disp, operation, this->get_window()));
2027  }
2028 
2029 private:
2030  MPI_Win window_;
2031 };
2032 
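// Usage sketch (illustrative; n, origin, comm, and exec are assumed): expose
// a local buffer for one-sided access and write into rank 1's memory inside
// a passive-target epoch.
//
//     std::vector<double> local(n);
//     gko::experimental::mpi::window<double> win(exec, local.data(),
//                                                static_cast<int>(n), comm);
//     win.lock(1, gko::experimental::mpi::window<double>::lock_type::exclusive);
//     win.put(exec, origin, /*origin_count*/ n, /*target_rank*/ 1,
//             /*target_disp*/ 0, /*target_count*/ n);
//     win.flush(1);
//     win.unlock(1);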
2033 
2034 } // namespace mpi
2035 } // namespace experimental
2036 } // namespace gko
2037 
2038 
2039 #endif // GINKGO_BUILD_MPI
2040 
2041 
2042 #endif // GKO_PUBLIC_CORE_BASE_MPI_HPP_
gko::experimental::mpi::window
This class wraps the MPI_Win handle with RAII functionality.
Definition: mpi.hpp:1595
gko::experimental::mpi::environment::get_provided_thread_support
int get_provided_thread_support() const
Return the provided thread support.
Definition: mpi.hpp:227
gko::experimental::mpi::requires_host_buffer
bool requires_host_buffer(const std::shared_ptr< const Executor > &exec, const communicator &comm)
Checks if the combination of Executor and communicator requires passing MPI buffers from the host memory.
gko::experimental::mpi::communicator::i_scan
request i_scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Does a scan operation with the given operator.
Definition: mpi.hpp:1527
gko::experimental::mpi::contiguous_type::contiguous_type
contiguous_type()
Constructs empty wrapper with MPI_DATATYPE_NULL.
Definition: mpi.hpp:126
gko::experimental::mpi::communicator::scan
void scan(std::shared_ptr< const Executor > exec, const ScanType *send_buffer, ScanType *recv_buffer, int count, MPI_Op operation) const
Does a scan operation with the given operator.
Definition: mpi.hpp:1501
gko::experimental::mpi::window::get_accumulate
void get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Get and accumulate data from the target window.
Definition: mpi.hpp:1959
gko::experimental::mpi::window::lock
void lock(int rank, lock_type lock_t=lock_type::shared, int assert=0) const
Create an epoch using MPI_Win_lock for the window object.
Definition: mpi.hpp:1695
gko::experimental::mpi::communicator::i_broadcast
request i_broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
(Non-blocking) Broadcast data from calling process to all ranks in the communicator.
Definition: mpi.hpp:756
gko::experimental::mpi::communicator::communicator
communicator(const communicator &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition: mpi.hpp:457
gko::experimental::mpi::window::window
window(std::shared_ptr< const Executor > exec, ValueType *base, int num_elems, const communicator &comm, const int disp_unit=sizeof(ValueType), MPI_Info input_info=MPI_INFO_NULL, create_type c_type=create_type::create)
Create a window object with a given data pointer and type.
Definition: mpi.hpp:1648
gko::experimental::mpi::communicator::i_all_to_all
request i_all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Communicate data from all ranks to all other ranks (MPI_Ialltoall).
Definition: mpi.hpp:1348
gko::experimental::mpi::environment::environment
environment(int &argc, char **&argv, const thread_type thread_t=thread_type::serialized)
Call MPI_Init_thread and initialize the MPI environment.
Definition: mpi.hpp:237
gko::experimental::mpi::window::put
void put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Put data into the target window.
Definition: mpi.hpp:1806
gko::experimental::mpi::window::create_type
create_type
The create type for the window object.
Definition: mpi.hpp:1600
gko::experimental::mpi::window::accumulate
void accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
Accumulate data into the target window.
Definition: mpi.hpp:1856
gko::experimental::mpi::window::fetch_and_op
void fetch_and_op(std::shared_ptr< const Executor > exec, GetType *origin_buffer, GetType *result_buffer, const int target_rank, const unsigned int target_disp, MPI_Op operation) const
Fetch and operate on data from the target window (An optimized version of Get_accumulate).
Definition: mpi.hpp:2018
gko::experimental::mpi::environment
Class that sets up and finalizes the MPI environment.
Definition: mpi.hpp:206
gko::experimental::mpi::communicator::all_to_all_v
void all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition: mpi.hpp:1407
gko::experimental::mpi::communicator::send
void send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Blocking) data from calling process to destination rank.
Definition: mpi.hpp:623
gko::experimental::mpi::communicator::communicator
communicator(const MPI_Comm &comm, int color, int key)
Create a communicator object from an existing MPI_Comm object using color and key.
Definition: mpi.hpp:442
gko::experimental::mpi::communicator::i_scatter_v
request i_scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator with offsets.
Definition: mpi.hpp:1228
gko::experimental::mpi::communicator::synchronize
void synchronize() const
This function is used to synchronize the ranks in the communicator.
Definition: mpi.hpp:604
gko::experimental::mpi::communicator::all_gather
void all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Gather data onto all ranks from all ranks in the communicator.
Definition: mpi.hpp:1072
gko::experimental::mpi::communicator::all_reduce
void all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
Reduce data from all calling processes and distribute the result to all calling processes on the same communicator.
Definition: mpi.hpp:887
gko::experimental::mpi::window::get_window
MPI_Win get_window() const
Get the underlying window object of MPI_Win type.
Definition: mpi.hpp:1674
gko::experimental::mpi::window::fence
void fence(int assert=0) const
The active target synchronization using MPI_Win_fence for the window object.
Definition: mpi.hpp:1682
gko::experimental::mpi::communicator::broadcast
void broadcast(std::shared_ptr< const Executor > exec, BroadcastType *buffer, int count, int root_rank) const
Broadcast data from calling process to all ranks in the communicator.
Definition: mpi.hpp:731
gko::experimental::mpi::window::unlock_all
void unlock_all() const
Close the epoch on all ranks using MPI_Win_unlock_all for the window object.
Definition: mpi.hpp:1735
gko::experimental::mpi::communicator::all_reduce
void all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place) Reduce data from all calling processes and distribute the result to all calling processes on the same communicator.
Definition: mpi.hpp:835
gko::experimental::mpi::communicator::i_all_to_all
request i_all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place, Non-blocking) Communicate data from all ranks to all other ranks in place (MPI_Ialltoall).
Definition: mpi.hpp:1289
gko::experimental::mpi::contiguous_type::contiguous_type
contiguous_type(contiguous_type &&other) noexcept
Move constructor, leaves other with MPI_DATATYPE_NULL.
Definition: mpi.hpp:143
gko::experimental::mpi::request
The request class is a light, move-only wrapper around the MPI_Request handle.
Definition: mpi.hpp:327
gko::experimental::mpi::communicator::size
int size() const
Return the size of the communicator (number of ranks).
Definition: mpi.hpp:529
gko::experimental::mpi::status::status
status()
The default constructor.
Definition: mpi.hpp:291
gko::experimental::mpi::environment::~environment
~environment()
Call MPI_Finalize at the end of the scope of this class.
Definition: mpi.hpp:249
gko::experimental::mpi::communicator::i_scatter
request i_scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Scatter data from root rank to all ranks in the communicator.
Definition: mpi.hpp:1162
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::experimental::mpi::communicator::reduce
void reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
Reduce data into root from all calling processes on the same communicator.
Definition: mpi.hpp:782
gko::experimental::mpi::request::wait
status wait()
Allows a rank to wait on a particular request handle.
Definition: mpi.hpp:372
gko::experimental::mpi::contiguous_type::operator=
contiguous_type & operator=(contiguous_type &&other) noexcept
Move assignment, leaves other with MPI_DATATYPE_NULL.
Definition: mpi.hpp:155
gko::experimental::mpi::window::flush_all
void flush_all() const
Flush all the existing RDMA operations for the calling process for the window object.
Definition: mpi.hpp:1766
gko::experimental::mpi::communicator::all_to_all_v
void all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
Communicate data from all ranks to all other ranks with offsets (MPI_Alltoallv).
Definition: mpi.hpp:1381
gko::experimental::mpi::communicator::i_gather_v
request i_gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator with offsets.
Definition: mpi.hpp:1041
gko::experimental::mpi::contiguous_type::contiguous_type
contiguous_type(int count, MPI_Datatype old_type)
Constructs a wrapper for a contiguous MPI_Datatype.
Definition: mpi.hpp:117
gko::experimental::mpi::window::unlock
void unlock(int rank) const
Close the epoch using MPI_Win_unlock for the window object.
Definition: mpi.hpp:1715
gko::experimental::mpi::window::flush_all_local
void flush_all_local() const
Flush all the local existing RDMA operations on the calling rank for the window object.
Definition: mpi.hpp:1775
gko::experimental::mpi::is_gpu_aware
constexpr bool is_gpu_aware()
Return whether GPU-aware functionality is available.
Definition: mpi.hpp:42
gko::experimental::mpi::window::lock_all
void lock_all(int assert=0) const
Create the epoch on all ranks using MPI_Win_lock_all for the window object.
Definition: mpi.hpp:1726
gko::experimental::mpi::window::lock_type
lock_type
The lock type for passive target synchronization of the windows.
Definition: mpi.hpp:1605
gko::experimental::mpi::contiguous_type::get
MPI_Datatype get() const
Access the underlying MPI_Datatype.
Definition: mpi.hpp:178
gko::experimental::mpi::communicator::operator==
bool operator==(const communicator &rhs) const
Compare two communicator objects for equality.
Definition: mpi.hpp:550
gko::experimental::mpi::window::r_accumulate
request r_accumulate(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Accumulate data into the target window.
Definition: mpi.hpp:1882
gko::experimental::mpi::communicator
A thin wrapper of MPI_Comm that supports most MPI calls.
Definition: mpi.hpp:416
gko::experimental::mpi::contiguous_type::~contiguous_type
~contiguous_type()
Destructs object by freeing wrapped MPI_Datatype.
Definition: mpi.hpp:166
gko::experimental::mpi::communicator::create_owning
static communicator create_owning(const MPI_Comm &comm, bool force_host_buffer=false)
Creates a new communicator and takes ownership of the MPI_Comm.
Definition: mpi.hpp:474
gko::experimental::mpi::type_impl
A struct that is used to determine the MPI_Datatype of a specified type.
Definition: mpi.hpp:77
gko::experimental::mpi::communicator::communicator
communicator(const MPI_Comm &comm, bool force_host_buffer=false)
Non-owning constructor for an existing communicator of type MPI_Comm.
Definition: mpi.hpp:428
gko::experimental::mpi::communicator::all_to_all
void all_to_all(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count) const
(In-place) Communicate data from all ranks to all other ranks in place (MPI_Alltoall).
Definition: mpi.hpp:1260
gko::experimental::mpi::communicator::i_all_reduce
request i_all_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation) const
(Non-blocking) Reduce data from all calling processes and distribute the result to all calling processes on the same communicator.
Definition: mpi.hpp:914
gko::experimental::mpi::request::request
request()
The default constructor.
Definition: mpi.hpp:333
gko::experimental::mpi::communicator::i_send
request i_send(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, const int destination_rank, const int send_tag) const
Send (Non-blocking, Immediate return) data from calling process to destination rank.
Definition: mpi.hpp:650
gko::experimental::mpi::communicator::scatter
void scatter(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator.
Definition: mpi.hpp:1131
gko::experimental::mpi::window::r_get_accumulate
request r_get_accumulate(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, GetType *result_buffer, const int result_count, const int target_rank, const unsigned int target_disp, const int target_count, MPI_Op operation) const
(Non-blocking) Get and accumulate data (with handle) from the target window.
Definition: mpi.hpp:1989
gko::experimental::mpi::window::~window
~window()
The deleter which calls MPI_Win_free when the window leaves its scope.
Definition: mpi.hpp:1788
gko::experimental::mpi::communicator::operator=
communicator & operator=(communicator &&other)
Definition: mpi.hpp:505
gko::experimental::mpi::map_rank_to_device_id
int map_rank_to_device_id(MPI_Comm comm, int num_devices)
Maps each MPI rank to a single device id in a round robin manner.
gko::experimental::mpi::window::window
window()
The default constructor.
Definition: mpi.hpp:1610
gko::experimental::mpi::get_walltime
double get_walltime()
Get the current wall time, as given by MPI_Wtime.
Definition: mpi.hpp:1583
gko::experimental::mpi::status::get_count
int get_count(const T *data) const
Get the number of elements received by the communication call.
Definition: mpi.hpp:311
gko::experimental::mpi::communicator::i_gather
request i_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
(Non-blocking) Gather data onto the root rank from all ranks in the communicator.
Definition: mpi.hpp:975
gko::experimental::mpi::communicator::rank
int rank() const
Return the rank of the calling process in the communicator.
Definition: mpi.hpp:536
gko::experimental::mpi::communicator::i_all_reduce
request i_all_reduce(std::shared_ptr< const Executor > exec, ReduceType *recv_buffer, int count, MPI_Op operation) const
(In-place, non-blocking) Reduce data from all calling processes and distribute the result to all calling processes on the same communicator.
Definition: mpi.hpp:860
gko::experimental::mpi::communicator::is_congruent
bool is_congruent(const communicator &rhs) const
Checks if the rhs communicator is congruent to this communicator.
Definition: mpi.hpp:590
gko::experimental::mpi::wait_all
std::vector< status > wait_all(std::vector< request > &req)
Allows a rank to wait on multiple request handles.
Definition: mpi.hpp:392
gko::half
A class providing basic support for half precision floating point types.
Definition: half.hpp:286
gko::experimental::mpi::communicator::gather_v
void gather_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int *recv_counts, const int *displacements, int root_rank) const
Gather data onto the root rank from all ranks in the communicator with offsets.
Definition: mpi.hpp:1008
gko::experimental::mpi::communicator::scatter_v
void scatter_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *displacements, RecvType *recv_buffer, const int recv_count, int root_rank) const
Scatter data from root rank to all ranks in the communicator with offsets.
Definition: mpi.hpp:1195
gko::experimental::mpi::communicator::operator=
communicator & operator=(const communicator &other)=default
gko::experimental::mpi::thread_type
thread_type
This enum specifies the threading type to be used when creating an MPI environment.
Definition: mpi.hpp:189
gko::experimental::mpi::communicator::operator!=
bool operator!=(const communicator &rhs) const
Compare two communicator objects for non-equality.
Definition: mpi.hpp:557
gko::experimental::mpi::status::get
MPI_Status * get()
Get a pointer to the underlying MPI_Status object.
Definition: mpi.hpp:298
gko::experimental::mpi::communicator::is_identical
bool is_identical(const communicator &rhs) const
Checks if the rhs communicator is identical to this communicator.
Definition: mpi.hpp:568
gko::experimental::mpi::communicator::gather
void gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count, int root_rank) const
Gather data onto the root rank from all ranks in the communicator.
Definition: mpi.hpp:943
gko::experimental::mpi::contiguous_type::operator=
contiguous_type & operator=(const contiguous_type &)=delete
Disallow copying of wrapper type.
gko::experimental::mpi::communicator::i_all_gather
request i_all_gather(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
(Non-blocking) Gather data onto all ranks from all ranks in the communicator.
Definition: mpi.hpp:1102
gko::experimental::mpi::window::window
window(window &&other)
The move constructor.
Definition: mpi.hpp:1622
gko::experimental::mpi::status
The status struct is a light wrapper around the MPI_Status struct.
Definition: mpi.hpp:287
gko::experimental::mpi::window::flush_local
void flush_local(int rank) const
Flush the existing RDMA operations on the calling rank from the target rank for the window object.
Definition: mpi.hpp:1757
gko::experimental::mpi::request::get
MPI_Request * get()
Get a pointer to the underlying MPI_Request handle.
Definition: mpi.hpp:364
gko::experimental::mpi::communicator::communicator
communicator(communicator &&other)
Move constructor.
Definition: mpi.hpp:495
gko::experimental::mpi::communicator::i_all_to_all_v
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int *send_counts, const int *send_offsets, RecvType *recv_buffer, const int *recv_counts, const int *recv_offsets) const
(Non-blocking) Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition: mpi.hpp:1474
gko::experimental::mpi::communicator::i_recv
request i_recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive (Non-blocking, Immediate return) data from source rank.
Definition: mpi.hpp:706
gko::experimental::mpi::window::sync
void sync() const
Synchronize the public and private buffers for the window object.
Definition: mpi.hpp:1783
gko::experimental::mpi::window::flush
void flush(int rank) const
Flush the existing RDMA operations on the target rank for the calling process for the window object.
Definition: mpi.hpp:1746
gko::experimental::mpi::communicator::node_local_rank
int node_local_rank() const
Return the node local rank of the calling process in the communicator.
Definition: mpi.hpp:543
gko::experimental::mpi::contiguous_type
A move-only wrapper for a contiguous MPI_Datatype.
Definition: mpi.hpp:109
gko::experimental::mpi::communicator::i_all_to_all_v
request i_all_to_all_v(std::shared_ptr< const Executor > exec, const void *send_buffer, const int *send_counts, const int *send_offsets, MPI_Datatype send_type, void *recv_buffer, const int *recv_counts, const int *recv_offsets, MPI_Datatype recv_type) const
(Non-blocking) Communicate data from all ranks to all other ranks with offsets (MPI_Ialltoallv).
Definition: mpi.hpp:1438
gko::experimental::mpi::communicator::all_to_all
void all_to_all(std::shared_ptr< const Executor > exec, const SendType *send_buffer, const int send_count, RecvType *recv_buffer, const int recv_count) const
Communicate data from all ranks to all other ranks (MPI_Alltoall).
Definition: mpi.hpp:1318
gko::experimental::mpi::window::r_get
request r_get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data (with handle) from the target window.
Definition: mpi.hpp:1932
gko::experimental::mpi::communicator::recv
status recv(std::shared_ptr< const Executor > exec, RecvType *recv_buffer, const int recv_count, const int source_rank, const int recv_tag) const
Receive data from source rank.
Definition: mpi.hpp:678
gko::experimental::mpi::communicator::i_reduce
request i_reduce(std::shared_ptr< const Executor > exec, const ReduceType *send_buffer, ReduceType *recv_buffer, int count, MPI_Op operation, int root_rank) const
(Non-blocking) Reduce data into root from all calling processes on the same communicator.
Definition: mpi.hpp:809
gko::experimental::mpi::window::get
void get(std::shared_ptr< const Executor > exec, GetType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
Get data from the target window.
Definition: mpi.hpp:1908
gko::experimental::mpi::window::operator=
window & operator=(window &&other)
The move assignment operator.
Definition: mpi.hpp:1631
gko::experimental::mpi::communicator::get
const MPI_Comm & get() const
Return the underlying MPI_Comm object.
Definition: mpi.hpp:520
gko::experimental::mpi::window::r_put
request r_put(std::shared_ptr< const Executor > exec, const PutType *origin_buffer, const int origin_count, const int target_rank, const unsigned int target_disp, const int target_count) const
(Non-blocking) Put data (with handle) into the target window.
Definition: mpi.hpp:1830