5 #ifndef GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
6 #define GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_
9 #include <ginkgo/config.hpp>
15 #include <ginkgo/core/base/dense_cache.hpp>
16 #include <ginkgo/core/base/lin_op.hpp>
17 #include <ginkgo/core/base/mpi.hpp>
18 #include <ginkgo/core/base/std_extensions.hpp>
19 #include <ginkgo/core/distributed/base.hpp>
20 #include <ginkgo/core/distributed/index_map.hpp>
21 #include <ginkgo/core/distributed/row_gatherer.hpp>
22 #include <ginkgo/core/distributed/vector_cache.hpp>
29 template <
typename ValueType,
typename IndexType>
39 template <
typename ValueType,
typename IndexType>
53 template <
typename Builder,
typename ValueType,
typename IndexType,
55 struct is_matrix_type_builder : std::false_type {};
58 template <
typename Builder,
typename ValueType,
typename IndexType>
59 struct is_matrix_type_builder<
60 Builder, ValueType, IndexType,
62 decltype(std::declval<Builder>().template create<ValueType, IndexType>(
63 std::declval<std::shared_ptr<const Executor>>()))>>
67 template <
template <
typename,
typename>
class MatrixType,
68 typename... CreateArgs>
69 struct MatrixTypeBuilderFromValueAndIndex {
70 template <
typename ValueType,
typename IndexType, std::size_t... I>
71 auto create_impl(std::shared_ptr<const Executor> exec,
72 std::index_sequence<I...>)
74 return MatrixType<ValueType, IndexType>::create(
75 exec, std::get<I>(create_args)...);
79 template <
typename ValueType,
typename IndexType>
80 auto create(std::shared_ptr<const Executor> exec)
83 static constexpr
auto size =
sizeof...(CreateArgs);
84 return create_impl<ValueType, IndexType>(
85 std::move(exec), std::make_index_sequence<size>{});
88 std::tuple<CreateArgs...> create_args;
126 template <
template <
typename,
typename>
class MatrixType,
typename... Args>
129 return detail::MatrixTypeBuilderFromValueAndIndex<MatrixType, Args...>{
130 std::forward_as_tuple(create_args...)};
134 namespace experimental {
135 namespace distributed {
150 template <
typename LocalIndexType,
typename GlobalIndexType>
152 template <
typename ValueType>
261 typename LocalIndexType =
int32,
typename GlobalIndexType =
int64>
263 :
public EnableLinOp<Matrix<ValueType, LocalIndexType, GlobalIndexType>>,
265 Matrix<next_precision<ValueType>, LocalIndexType, GlobalIndexType>>,
266 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
267 public ConvertibleTo<Matrix<next_precision<ValueType, 2>, LocalIndexType,
270 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
271 public ConvertibleTo<Matrix<next_precision<ValueType, 3>, LocalIndexType,
279 GKO_ASSERT_SUPPORTED_VALUE_AND_DIST_INDEX_TYPE;
// Public type aliases exposing the template parameters of this matrix.
// NOTE: `index_type` aliases the *global* index type (GlobalIndexType),
// not the local one; use `local_index_type` for the index type of the
// process-local storage.
282 using value_type = ValueType;
283 using index_type = GlobalIndexType;
284 using local_index_type = LocalIndexType;
285 using global_index_type = GlobalIndexType;
293 GlobalIndexType>>::convert_to;
295 GlobalIndexType>>::move_to;
298 global_index_type>* result)
const override;
301 global_index_type>* result)
override;
303 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
307 global_index_type>>::convert_to;
309 global_index_type>>::move_to;
312 global_index_type>* result)
const override;
315 global_index_type>* result)
override;
318 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
322 global_index_type>>::convert_to;
324 global_index_type>>::move_to;
327 global_index_type>* result)
const override;
330 global_index_type>* result)
override;
433 return off_diag_mtx_;
439 GKO_DEPRECATED(
"use get_diag_matrix() instead")
448 GKO_DEPRECATED(
"use get_off_diag_matrix() instead")
497 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
513 static std::unique_ptr<Matrix>
create(
514 std::shared_ptr<const Executor> exec,
516 row_gatherer_template);
538 template <
typename MatrixType,
539 typename = std::enable_if_t<gko::detail::is_matrix_type_builder<
540 MatrixType, ValueType, LocalIndexType>::value>>
541 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
543 MatrixType matrix_template)
547 matrix_template.template create<ValueType, LocalIndexType>(exec));
578 template <
typename DiagMatrixType,
typename OffDiagMatrixType,
579 typename = std::enable_if_t<
580 gko::detail::is_matrix_type_builder<DiagMatrixType, ValueType,
581 LocalIndexType>::value &&
582 gko::detail::is_matrix_type_builder<
583 OffDiagMatrixType, ValueType, LocalIndexType>::value>>
586 DiagMatrixType diag_matrix_template,
587 OffDiagMatrixType off_diag_matrix_template)
591 diag_matrix_template.template create<ValueType, LocalIndexType>(
593 off_diag_matrix_template.template create<ValueType, LocalIndexType>(
611 static std::unique_ptr<Matrix>
create(
631 static std::unique_ptr<Matrix>
create(
648 static std::unique_ptr<Matrix>
create(std::shared_ptr<const Executor> exec,
650 std::shared_ptr<LinOp> diag_linop);
671 "Please use the overload with an index_map instead.")]]
static std::
674 dim<2> size, std::shared_ptr<LinOp> diag_linop,
675 std::shared_ptr<LinOp> off_diag_linop,
676 std::vector<comm_index_type> recv_sizes,
677 std::vector<comm_index_type> recv_offsets,
693 static std::unique_ptr<Matrix>
create(
696 std::shared_ptr<LinOp> diag_linop,
697 std::shared_ptr<LinOp> off_diag_linop);
718 explicit Matrix(std::shared_ptr<const Executor> exec,
721 explicit Matrix(std::shared_ptr<const Executor> exec,
723 row_gatherer_template,
727 explicit Matrix(std::shared_ptr<const Executor> exec,
729 std::shared_ptr<LinOp> diag_linop);
731 explicit Matrix(std::shared_ptr<const Executor> exec,
734 std::shared_ptr<LinOp> diag_linop,
735 std::shared_ptr<LinOp> off_diag_linop);
// Override of the two-operand LinOp application hook for this distributed
// matrix (`b` is the input operand, `x` the output operand).
// NOTE(review): presumably computes x = A * b using the diagonal and
// off-diagonal blocks; confirm against the implementation file.
737 void apply_impl(
const LinOp* b,
LinOp* x)
const override;
740 LinOp* x)
const override;
// Row gatherer over local indices. A create() overload and a constructor
// accept a `row_gatherer_template`.
// NOTE(review): presumably gathers the non-local input-vector entries
// needed by the off-diagonal block; confirm.
743 std::shared_ptr<RowGatherer<LocalIndexType>> row_gatherer_;
// Cached scalar. The name suggests it holds the constant one (e.g. for
// advanced apply); TODO confirm.
745 gko::detail::ScalarCache one_scalar_;
// Reusable receive buffers (generic vector caches).
// NOTE(review): host_recv_buffer_ is presumably a host-memory staging
// copy used when communication cannot target device memory; confirm.
746 detail::GenericVectorCache recv_buffer_;
747 detail::GenericVectorCache host_recv_buffer_;
// Process-local matrix blocks stored as type-erased LinOps. The
// off-diagonal block is returned by a getter in this class; the diagonal
// block is presumably exposed through get_diag_matrix() (confirm).
748 std::shared_ptr<LinOp> diag_mtx_;
749 std::shared_ptr<LinOp> off_diag_mtx_;
761 #endif // GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_