5 #ifndef GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
6 #define GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
9 #include <ginkgo/config.hpp>
15 #include <ginkgo/core/base/dense_cache.hpp>
16 #include <ginkgo/core/base/mpi.hpp>
17 #include <ginkgo/core/distributed/base.hpp>
18 #include <ginkgo/core/distributed/index_map.hpp>
19 #include <ginkgo/core/distributed/lin_op.hpp>
26 template <
typename ValueType,
typename IndexType>
36 template <
typename ValueType,
typename IndexType>
50 template <
typename Builder,
typename ValueType,
typename IndexType,
52 struct is_matrix_type_builder : std::false_type {};
55 template <
typename Builder,
typename ValueType,
typename IndexType>
56 struct is_matrix_type_builder<
57 Builder, ValueType, IndexType,
59 decltype(std::declval<Builder>().template create<ValueType, IndexType>(
60 std::declval<std::shared_ptr<const Executor>>()))>>
64 template <
template <
typename,
typename>
class MatrixType,
65 typename... CreateArgs>
66 struct MatrixTypeBuilderFromValueAndIndex {
67 template <
typename ValueType,
typename IndexType, std::size_t... I>
68 auto create_impl(std::shared_ptr<const Executor> exec,
69 std::index_sequence<I...>)
71 return MatrixType<ValueType, IndexType>::create(
72 exec, std::get<I>(create_args)...);
76 template <
typename ValueType,
typename IndexType>
77 auto create(std::shared_ptr<const Executor> exec)
80 static constexpr
auto size =
sizeof...(CreateArgs);
81 return create_impl<ValueType, IndexType>(
82 std::move(exec), std::make_index_sequence<size>{});
85 std::tuple<CreateArgs...> create_args;
123 template <
template <
typename,
typename>
class MatrixType,
typename... Args>
126 return detail::MatrixTypeBuilderFromValueAndIndex<MatrixType, Args...>{
127 std::forward_as_tuple(create_args...)};
131 namespace experimental {
132 namespace distributed {
135 template <
typename LocalIndexType,
typename GlobalIndexType>
137 template <
typename ValueType>
246 typename LocalIndexType =
int32,
typename GlobalIndexType =
int64>
249 Matrix<ValueType, LocalIndexType, GlobalIndexType>>,
251 Matrix<next_precision<ValueType>, LocalIndexType, GlobalIndexType>>,
259 using value_type = ValueType;
260 using index_type = GlobalIndexType;
261 using local_index_type = LocalIndexType;
262 using global_index_type = GlobalIndexType;
270 GlobalIndexType>>::convert_to;
272 GlobalIndexType>>::move_to;
275 global_index_type>* result)
const override;
278 global_index_type>* result)
override;
369 return non_local_mtx_;
415 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
438 template <
typename MatrixType,
439 typename = std::enable_if_t<detail::is_matrix_type_builder<
440 MatrixType, ValueType, LocalIndexType>::value>>
441 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
443 MatrixType matrix_template)
447 matrix_template.template create<ValueType, LocalIndexType>(exec));
478 template <
typename LocalMatrixType,
typename NonLocalMatrixType,
479 typename = std::enable_if_t<
480 detail::is_matrix_type_builder<LocalMatrixType, ValueType,
481 LocalIndexType>::value &&
482 detail::is_matrix_type_builder<NonLocalMatrixType, ValueType,
483 LocalIndexType>::value>>
486 LocalMatrixType local_matrix_template,
487 NonLocalMatrixType non_local_matrix_template)
491 local_matrix_template.template create<ValueType, LocalIndexType>(
493 non_local_matrix_template
494 .template create<ValueType, LocalIndexType>(exec));
511 static std::unique_ptr<Matrix>
create(
531 static std::unique_ptr<Matrix>
create(
548 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
550 std::shared_ptr<LinOp> local_linop);
570 static std::unique_ptr<Matrix>
create(
572 dim<2> size, std::shared_ptr<LinOp> local_linop,
573 std::shared_ptr<LinOp> non_local_linop,
574 std::vector<comm_index_type> recv_sizes,
575 std::vector<comm_index_type> recv_offsets,
579 explicit Matrix(std::shared_ptr<const Executor> exec,
582 explicit Matrix(std::shared_ptr<const Executor> exec,
587 explicit Matrix(std::shared_ptr<const Executor> exec,
589 std::shared_ptr<LinOp> local_linop);
591 explicit Matrix(std::shared_ptr<const Executor> exec,
593 std::shared_ptr<LinOp> local_linop,
594 std::shared_ptr<LinOp> non_local_linop,
595 std::vector<comm_index_type> recv_sizes,
596 std::vector<comm_index_type> recv_offsets,
607 mpi::request communicate(
const local_vector_type* local_b)
const;
609 void apply_impl(
const LinOp* b,
LinOp* x)
const override;
612 LinOp* x)
const override;
615 std::vector<comm_index_type> send_offsets_;
616 std::vector<comm_index_type> send_sizes_;
617 std::vector<comm_index_type> recv_offsets_;
618 std::vector<comm_index_type> recv_sizes_;
621 gko::detail::DenseCache<value_type> one_scalar_;
622 gko::detail::DenseCache<value_type> host_send_buffer_;
623 gko::detail::DenseCache<value_type> host_recv_buffer_;
624 gko::detail::DenseCache<value_type> send_buffer_;
625 gko::detail::DenseCache<value_type> recv_buffer_;
626 std::shared_ptr<LinOp> local_mtx_;
627 std::shared_ptr<LinOp> non_local_mtx_;
639 #endif // GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_