5 #ifndef GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_ 
    6 #define GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_ 
    9 #include <ginkgo/config.hpp> 
   15 #include <ginkgo/core/base/dense_cache.hpp> 
   16 #include <ginkgo/core/base/mpi.hpp> 
   17 #include <ginkgo/core/distributed/base.hpp> 
   18 #include <ginkgo/core/distributed/index_map.hpp> 
   19 #include <ginkgo/core/distributed/lin_op.hpp> 
   26 template <
typename ValueType, 
typename IndexType>
 
   36 template <
typename ValueType, 
typename IndexType>
 
   // Primary template of the is_matrix_type_builder trait. As declared on
   // the visible line below, it derives from std::false_type.
   // NOTE(review): the parameter list ends with a trailing comma, so at least
   // one more template parameter (and the closing `>`) is missing from this
   // listing - presumably a defaulted parameter that the specialization
   // fills in; confirm against the original header.
   50 template <
typename Builder, 
typename ValueType, 
typename IndexType,
 
   52 struct is_matrix_type_builder : std::false_type {};
 
   // Partial specialization of is_matrix_type_builder for
   // <Builder, ValueType, IndexType, ...>. The last argument is built from the
   // type of the expression
   //   std::declval<Builder>().template create<ValueType, IndexType>(
   //       std::declval<std::shared_ptr<const Executor>>())
   // i.e. it names a member template `create<ValueType, IndexType>` on Builder
   // that is callable with a std::shared_ptr<const Executor>.
   // NOTE(review): the template that wraps the decltype (closed by the final
   // `>>`) and the base class / body of this specialization are not part of
   // this listing - presumably a void_t-style wrapper and std::true_type;
   // confirm against the original header.
   55 template <
typename Builder, 
typename ValueType, 
typename IndexType>
 
   56 struct is_matrix_type_builder<
 
   57     Builder, ValueType, IndexType,
 
   59         decltype(std::declval<Builder>().template create<ValueType, IndexType>(
 
   60             std::declval<std::shared_ptr<const Executor>>()))>>
 
   // Builder that stores a tuple of creation arguments (create_args) and uses
   // it to instantiate MatrixType<ValueType, IndexType> through that type's
   // static create function.
   //
   // Template parameters:
   //   MatrixType  - class template taking two type parameters.
   //   CreateArgs  - element types of the stored argument tuple.
   //
   // NOTE(review): the embedded listing numbers jump (69 -> 71, 72 -> 76,
   // 77 -> 80, 82 -> 85), so lines are missing here - at least the braces of
   // the member function bodies and the closing of the struct. Restore them
   // from the original header before compiling.
   64 template <
template <
typename, 
typename> 
class MatrixType,
 
   65           typename... CreateArgs>
 
   66 struct MatrixTypeBuilderFromValueAndIndex {
 
   // Helper: expands the index pack I to pick every element of create_args
   // and returns MatrixType<ValueType, IndexType>::create(exec, args...).
   // create_args is accessed as an lvalue, so the stored arguments are passed
   // on without being moved out of the tuple.
   67     template <
typename ValueType, 
typename IndexType, std::size_t... I>
 
   68     auto create_impl(std::shared_ptr<const Executor> exec,
 
   69                      std::index_sequence<I...>)
 
   71         return MatrixType<ValueType, IndexType>::create(
 
   72             exec, std::get<I>(create_args)...);
 
   // Entry point: builds an index sequence with one index per stored
   // argument (sizeof...(CreateArgs)) and forwards to create_impl, moving
   // exec into the call.
   76     template <
typename ValueType, 
typename IndexType>
 
   77     auto create(std::shared_ptr<const Executor> exec)
 
   80         static constexpr 
auto size = 
sizeof...(CreateArgs);
 
   81         return create_impl<ValueType, IndexType>(
 
   82             std::move(exec), std::make_index_sequence<size>{});
 
   // Arguments handed to MatrixType<...>::create after the executor.
   85     std::tuple<CreateArgs...> create_args;
 
  // Factory fragment: returns a
  // detail::MatrixTypeBuilderFromValueAndIndex<MatrixType, Args...> whose
  // tuple member is initialized from std::forward_as_tuple(create_args...).
  // NOTE(review): the function name and parameter list (listing lines
  // 124-125) are missing, so the declaration of create_args is not visible.
  // NOTE(review): create_args is expanded without std::forward, so
  // forward_as_tuple receives lvalues. If the pack is declared as forwarding
  // references, rvalue arguments are copied rather than moved into the
  // builder - confirm whether std::forward<Args>(create_args)... is intended.
  123 template <
template <
typename, 
typename> 
class MatrixType, 
typename... Args>
 
  126     return detail::MatrixTypeBuilderFromValueAndIndex<MatrixType, Args...>{
 
  127         std::forward_as_tuple(create_args...)};
 
  131 namespace experimental {
 
  132 namespace distributed {
 
  135 template <
typename LocalIndexType, 
typename GlobalIndexType>
 
  137 template <
typename ValueType>
 
  // Tail of a template parameter list: LocalIndexType defaults to int32 and
  // GlobalIndexType defaults to int64.
  // The two following lines name
  //   Matrix<ValueType, LocalIndexType, GlobalIndexType> and
  //   Matrix<next_precision<ValueType>, LocalIndexType, GlobalIndexType>
  // as template arguments of enclosing templates whose names are cut off.
  // NOTE(review): the first template parameter, the class-head and the names
  // of the templates these arguments belong to (presumably base classes of
  // Matrix) are missing from this listing; confirm against the original.
  246           typename LocalIndexType = 
int32, 
typename GlobalIndexType = 
int64>
 
  249           Matrix<ValueType, LocalIndexType, GlobalIndexType>>,
 
  251           Matrix<next_precision<ValueType>, LocalIndexType, GlobalIndexType>>,
 
  // Member type aliases exposing the template parameters.
  // Note that index_type and global_index_type both alias GlobalIndexType;
  // the local index type is only available as local_index_type.
  259     using value_type = ValueType;
 
  260     using index_type = GlobalIndexType;
 
  261     using local_index_type = LocalIndexType;
 
  262     using global_index_type = GlobalIndexType;
 
  // Tails of two using-declarations that make the convert_to and move_to
  // members of a templated scope visible in this class.
  // NOTE(review): the beginning of both using-declarations (the scope whose
  // template argument list ends in `GlobalIndexType>>`) is cut off.
  270                                GlobalIndexType>>::convert_to;
 
  272                                GlobalIndexType>>::move_to;
 
  // Tails of two overriding member declarations. Each takes a pointer
  // `result` to a type whose last template argument is global_index_type;
  // the first is const-qualified, the second is not.
  // NOTE(review): return types and function names are missing - presumably
  // these are the convert_to (const) and move_to (non-const) overrides that
  // match the using-declarations above; confirm.
  275                            global_index_type>* result) 
const override;
 
  278                         global_index_type>* result) 
override;
 
  // Body of an accessor that returns the non_local_mtx_ member.
  // NOTE(review): the accessor's signature is not part of this listing.
  369         return non_local_mtx_;
 
  // Static factory returning std::unique_ptr<Matrix>; the first parameter is
  // the executor.
  // NOTE(review): the remaining parameters are cut off in this listing.
  415     static std::unique_ptr<Matrix> 
create(std::shared_ptr<const Executor> exec,
 
  // Static factory template taking a matrix builder object.
  // The overload is enabled only when
  // detail::is_matrix_type_builder<MatrixType, ValueType, LocalIndexType>
  // has a true `value`. The visible part of the body calls
  // matrix_template.template create<ValueType, LocalIndexType>(exec) and
  // passes the result on as the last argument of an enclosing call.
  // NOTE(review): the parameter(s) between exec and matrix_template and the
  // enclosing call (listing lines 442, 444-446) are missing.
  438     template <
typename MatrixType,
 
  439               typename = std::enable_if_t<detail::is_matrix_type_builder<
 
  440                   MatrixType, ValueType, LocalIndexType>::value>>
 
  441     static std::unique_ptr<Matrix> 
create(std::shared_ptr<const Executor> exec,
 
  443                                           MatrixType matrix_template)
 
  447             matrix_template.template create<ValueType, LocalIndexType>(exec));
 
  // Factory template taking two builder objects: one for the local matrix
  // and one for the non-local matrix.
  // The overload is enabled only when detail::is_matrix_type_builder is true
  // for both LocalMatrixType and NonLocalMatrixType (each checked with
  // ValueType and LocalIndexType). The visible part of the body calls
  // `.template create<ValueType, LocalIndexType>(...)` on both builders; the
  // non-local one is called with exec.
  // NOTE(review): the return type, function name, leading parameters and the
  // enclosing call (listing lines 484-485, 488-490, 492) are missing. The
  // argument of the local builder's create call is cut off as well.
  478     template <
typename LocalMatrixType, 
typename NonLocalMatrixType,
 
  479               typename = std::enable_if_t<
 
  480                   detail::is_matrix_type_builder<LocalMatrixType, ValueType,
 
  481                                                  LocalIndexType>::value &&
 
  482                   detail::is_matrix_type_builder<NonLocalMatrixType, ValueType,
 
  483                                                  LocalIndexType>::value>>
 
  486         LocalMatrixType local_matrix_template,
 
  487         NonLocalMatrixType non_local_matrix_template)
 
  491             local_matrix_template.template create<ValueType, LocalIndexType>(
 
  493             non_local_matrix_template
 
  494                 .template create<ValueType, LocalIndexType>(exec));
 
  // Four further static `create` overloads, each returning
  // std::unique_ptr<Matrix>.
  // NOTE(review): most parameter lists are cut off in this listing; only the
  // parts shown below are known.
  //
  // Overloads at listing lines 511 and 531: parameters not visible.
  511     static std::unique_ptr<Matrix> 
create(
 
  531     static std::unique_ptr<Matrix> 
create(
 
  // Overload whose visible parameters are the executor and, last, a shared
  // pointer to a LinOp named local_linop.
  548     static std::unique_ptr<Matrix> 
create(std::shared_ptr<const Executor> exec,
 
  550                                           std::shared_ptr<LinOp> local_linop);
 
  // Overload whose visible parameters are: size (dim<2>), local_linop,
  // non_local_linop (both std::shared_ptr<LinOp>), and the vectors
  // recv_sizes and recv_offsets (std::vector<comm_index_type>, taken by
  // value). The parameter list continues past the visible lines.
  570     static std::unique_ptr<Matrix> 
create(
 
  572         dim<2> size, std::shared_ptr<LinOp> local_linop,
 
  573         std::shared_ptr<LinOp> non_local_linop,
 
  574         std::vector<comm_index_type> recv_sizes,
 
  575         std::vector<comm_index_type> recv_offsets,
 
  // Four explicit constructors, each taking the executor first.
  // NOTE(review): the access specifier and most of the remaining parameters
  // are not part of this listing. The visible parameters of the last two
  // constructors match the `create` overloads taking local_linop, and
  // local_linop / non_local_linop / recv_sizes / recv_offsets respectively.
  579     explicit Matrix(std::shared_ptr<const Executor> exec,
 
  582     explicit Matrix(std::shared_ptr<const Executor> exec,
 
  587     explicit Matrix(std::shared_ptr<const Executor> exec,
 
  589                     std::shared_ptr<LinOp> local_linop);
 
  591     explicit Matrix(std::shared_ptr<const Executor> exec,
 
  593                     std::shared_ptr<LinOp> local_linop,
 
  594                     std::shared_ptr<LinOp> non_local_linop,
 
  595                     std::vector<comm_index_type> recv_sizes,
 
  596                     std::vector<comm_index_type> recv_offsets,
 
  // Const member taking a pointer to a local vector and returning an
  // mpi::request.
  // NOTE(review): presumably starts a non-blocking exchange that the caller
  // must wait on via the returned request - confirm in the implementation.
  607     mpi::request communicate(
const local_vector_type* local_b) 
const;
 
  // Override taking an input operand b and an output operand x.
  609     void apply_impl(
const LinOp* b, 
LinOp* x) 
const override;
 
  // Tail of a second const override whose last parameter is `LinOp* x`.
  // NOTE(review): the leading parameters are cut off - presumably the scaled
  // apply_impl overload; confirm against the original header.
  612                     LinOp* x) 
const override;
 
  // Data members.
  // NOTE(review): the listing numbers jump from 618 to 621, so member
  // declarations may be missing between recv_sizes_ and one_scalar_.
  //
  // Send/receive sizes and offsets, stored as comm_index_type.
  // NOTE(review): presumably per-rank counts and displacements for the
  // exchange performed by communicate() - confirm in the implementation.
  615     std::vector<comm_index_type> send_offsets_;
 
  616     std::vector<comm_index_type> send_sizes_;
 
  617     std::vector<comm_index_type> recv_offsets_;
 
  618     std::vector<comm_index_type> recv_sizes_;
 
  // Cached dense objects of value_type (gko::detail::DenseCache).
  // NOTE(review): judging by the names, one_scalar_ holds a constant one and
  // the *_buffer_ members are communication buffers, with host_* variants
  // kept in host memory - confirm in the implementation.
  621     gko::detail::DenseCache<value_type> one_scalar_;
 
  622     gko::detail::DenseCache<value_type> host_send_buffer_;
 
  623     gko::detail::DenseCache<value_type> host_recv_buffer_;
 
  624     gko::detail::DenseCache<value_type> send_buffer_;
 
  625     gko::detail::DenseCache<value_type> recv_buffer_;
 
  // Local and non-local operators, held as shared pointers to LinOp.
  626     std::shared_ptr<LinOp> local_mtx_;
 
  627     std::shared_ptr<LinOp> non_local_mtx_;
 
  639 #endif  // GKO_PUBLIC_CORE_DISTRIBUTED_MATRIX_HPP_