5 #ifndef GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
6 #define GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
9 #include <ginkgo/config.hpp>
15 #include <ginkgo/core/base/dense_cache.hpp>
16 #include <ginkgo/core/base/lin_op.hpp>
17 #include <ginkgo/core/base/mpi.hpp>
18 #include <ginkgo/core/base/std_extensions.hpp>
19 #include <ginkgo/core/distributed/base.hpp>
20 #include <ginkgo/core/distributed/index_map.hpp>
27 template <
typename ValueType,
typename IndexType>
37 template <
typename ValueType,
typename IndexType>
51 template <
typename Builder,
typename ValueType,
typename IndexType,
53 struct is_matrix_type_builder : std::false_type {};
56 template <
typename Builder,
typename ValueType,
typename IndexType>
57 struct is_matrix_type_builder<
58 Builder, ValueType, IndexType,
60 decltype(std::declval<Builder>().template create<ValueType, IndexType>(
61 std::declval<std::shared_ptr<const Executor>>()))>>
65 template <
template <
typename,
typename>
class MatrixType,
66 typename... CreateArgs>
67 struct MatrixTypeBuilderFromValueAndIndex {
68 template <
typename ValueType,
typename IndexType, std::size_t... I>
69 auto create_impl(std::shared_ptr<const Executor> exec,
70 std::index_sequence<I...>)
72 return MatrixType<ValueType, IndexType>::create(
73 exec, std::get<I>(create_args)...);
77 template <
typename ValueType,
typename IndexType>
78 auto create(std::shared_ptr<const Executor> exec)
81 static constexpr
auto size =
sizeof...(CreateArgs);
82 return create_impl<ValueType, IndexType>(
83 std::move(exec), std::make_index_sequence<size>{});
86 std::tuple<CreateArgs...> create_args;
124 template <
template <
typename,
typename>
class MatrixType,
typename... Args>
127 return detail::MatrixTypeBuilderFromValueAndIndex<MatrixType, Args...>{
128 std::forward_as_tuple(create_args...)};
132 namespace experimental {
133 namespace distributed {
148 template <
typename LocalIndexType,
typename GlobalIndexType>
150 template <
typename ValueType>
259 typename LocalIndexType =
int32,
typename GlobalIndexType =
int64>
261 :
public EnableLinOp<Matrix<ValueType, LocalIndexType, GlobalIndexType>>,
263 Matrix<next_precision<ValueType>, LocalIndexType, GlobalIndexType>>,
264 #if GINKGO_ENABLE_HALF
265 public ConvertibleTo<Matrix<next_precision<next_precision<ValueType>>,
266 LocalIndexType, GlobalIndexType>>,
270 friend class Matrix<previous_precision<ValueType>, LocalIndexType,
276 using value_type = ValueType;
277 using index_type = GlobalIndexType;
278 using local_index_type = LocalIndexType;
279 using global_index_type = GlobalIndexType;
287 GlobalIndexType>>::convert_to;
289 GlobalIndexType>>::move_to;
292 global_index_type>* result)
const override;
295 global_index_type>* result)
override;
296 #if GINKGO_ENABLE_HALF
297 friend class Matrix<previous_precision<previous_precision<ValueType>>,
298 LocalIndexType, GlobalIndexType>;
301 global_index_type>>::convert_to;
303 local_index_type, global_index_type>>::move_to;
307 global_index_type>* result)
const override;
310 local_index_type, global_index_type>* result)
override;
408 return non_local_mtx_;
454 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
477 template <
typename MatrixType,
478 typename = std::enable_if_t<gko::detail::is_matrix_type_builder<
479 MatrixType, ValueType, LocalIndexType>::value>>
480 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
482 MatrixType matrix_template)
486 matrix_template.template create<ValueType, LocalIndexType>(exec));
517 template <
typename LocalMatrixType,
typename NonLocalMatrixType,
518 typename = std::enable_if_t<
519 gko::detail::is_matrix_type_builder<
520 LocalMatrixType, ValueType, LocalIndexType>::value &&
521 gko::detail::is_matrix_type_builder<
522 NonLocalMatrixType, ValueType, LocalIndexType>::value>>
525 LocalMatrixType local_matrix_template,
526 NonLocalMatrixType non_local_matrix_template)
530 local_matrix_template.template create<ValueType, LocalIndexType>(
532 non_local_matrix_template
533 .template create<ValueType, LocalIndexType>(exec));
550 static std::unique_ptr<Matrix>
create(
570 static std::unique_ptr<Matrix>
create(
587 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
589 std::shared_ptr<LinOp> local_linop);
609 static std::unique_ptr<Matrix>
create(
611 dim<2> size, std::shared_ptr<LinOp> local_linop,
612 std::shared_ptr<LinOp> non_local_linop,
613 std::vector<comm_index_type> recv_sizes,
614 std::vector<comm_index_type> recv_offsets,
636 explicit Matrix(std::shared_ptr<const Executor> exec,
639 explicit Matrix(std::shared_ptr<const Executor> exec,
644 explicit Matrix(std::shared_ptr<const Executor> exec,
646 std::shared_ptr<LinOp> local_linop);
648 explicit Matrix(std::shared_ptr<const Executor> exec,
650 std::shared_ptr<LinOp> local_linop,
651 std::shared_ptr<LinOp> non_local_linop,
652 std::vector<comm_index_type> recv_sizes,
653 std::vector<comm_index_type> recv_offsets,
664 mpi::request communicate(
const local_vector_type* local_b)
const;
666 void apply_impl(
const LinOp* b,
LinOp* x)
const override;
669 LinOp* x)
const override;
672 std::vector<comm_index_type> send_offsets_;
673 std::vector<comm_index_type> send_sizes_;
674 std::vector<comm_index_type> recv_offsets_;
675 std::vector<comm_index_type> recv_sizes_;
678 gko::detail::DenseCache<value_type> one_scalar_;
679 gko::detail::DenseCache<value_type> host_send_buffer_;
680 gko::detail::DenseCache<value_type> host_recv_buffer_;
681 gko::detail::DenseCache<value_type> send_buffer_;
682 gko::detail::DenseCache<value_type> recv_buffer_;
683 std::shared_ptr<LinOp> local_mtx_;
684 std::shared_ptr<LinOp> non_local_mtx_;
696 #endif // GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_