5 #ifndef GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_
6 #define GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_
9 #include <ginkgo/config.hpp>
15 #include <ginkgo/core/base/dense_cache.hpp>
16 #include <ginkgo/core/base/lin_op.hpp>
17 #include <ginkgo/core/base/mpi.hpp>
18 #include <ginkgo/core/distributed/base.hpp>
19 #include <ginkgo/core/matrix/dense.hpp>
23 namespace experimental {
24 namespace distributed {
28 template <
typename ValueType>
35 template <
typename LocalIndexType,
typename GlobalIndexType>
66 template <
typename ValueType =
double>
68 :
public EnableLinOp<Vector<ValueType>>,
69 public ConvertibleTo<Vector<next_precision<ValueType>>>,
70 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
71 public ConvertibleTo<Vector<next_precision<ValueType, 2>>>,
73 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
74 public ConvertibleTo<Vector<next_precision<ValueType, 3>>>,
76 public EnableAbsoluteComputation<remove_complex<Vector<ValueType>>>,
77 public DistributedBase {
78 friend class EnablePolymorphicObject<Vector,
LinOp>;
82 friend class detail::VectorCache<ValueType>;
87 using ConvertibleTo<Vector<next_precision<ValueType>>>::convert_to;
88 using ConvertibleTo<Vector<next_precision<ValueType>>>::move_to;
90 using value_type = ValueType;
91 using absolute_type = remove_complex<Vector>;
92 using real_type = absolute_type;
93 using complex_type = Vector<to_complex<value_type>>;
103 ptr_param<const Vector> other);
118 ptr_param<const Vector> other, std::shared_ptr<const Executor> exec);
133 ptr_param<const Vector> other, std::shared_ptr<const Executor> exec,
134 const dim<2>& global_size,
const dim<2>& local_size,
size_type stride);
151 ptr_param<
const Partition<int64, int64>> partition);
154 ptr_param<
const Partition<int32, int64>> partition);
157 ptr_param<
const Partition<int32, int32>> partition);
169 ptr_param<
const Partition<int64, int64>> partition);
172 ptr_param<
const Partition<int32, int64>> partition);
175 ptr_param<
const Partition<int32, int32>> partition);
177 void convert_to(Vector<next_precision<ValueType>>* result)
const override;
179 void move_to(Vector<next_precision<ValueType>>* result)
override;
181 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
183 using ConvertibleTo<Vector<next_precision<ValueType, 2>>>::convert_to;
184 using ConvertibleTo<Vector<next_precision<ValueType, 2>>>::move_to;
187 Vector<next_precision<ValueType, 2>>* result)
const override;
189 void move_to(Vector<next_precision<ValueType, 2>>* result)
override;
192 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
194 using ConvertibleTo<Vector<next_precision<ValueType, 3>>>::convert_to;
195 using ConvertibleTo<Vector<next_precision<ValueType, 3>>>::move_to;
198 Vector<next_precision<ValueType, 3>>* result)
const override;
200 void move_to(Vector<next_precision<ValueType, 3>>* result)
override;
218 void make_complex(ptr_param<complex_type> result)
const;
224 std::unique_ptr<real_type>
get_real()
const;
229 void get_real(ptr_param<real_type> result)
const;
235 std::unique_ptr<real_type>
get_imag()
const;
241 void get_imag(ptr_param<real_type> result)
const;
248 void fill(ValueType value);
259 void scale(ptr_param<const LinOp> alpha);
270 void inv_scale(ptr_param<const LinOp> alpha);
281 void add_scaled(ptr_param<const LinOp> alpha, ptr_param<const LinOp> b);
291 void sub_scaled(ptr_param<const LinOp> alpha, ptr_param<const LinOp> b);
302 void compute_dot(ptr_param<const LinOp> b, ptr_param<LinOp> result)
const;
316 void compute_dot(ptr_param<const LinOp> b, ptr_param<LinOp> result,
317 array<char>& tmp)
const;
329 ptr_param<LinOp> result)
const;
344 array<char>& tmp)
const;
390 void compute_norm2(ptr_param<LinOp> result, array<char>& tmp)
const;
412 void compute_norm1(ptr_param<LinOp> result, array<char>& tmp)
const;
435 void compute_mean(ptr_param<LinOp> result, array<char>& tmp)
const;
538 static std::unique_ptr<Vector>
create(std::shared_ptr<const Executor> exec,
539 mpi::communicator comm,
540 dim<2> global_size, dim<2> local_size,
554 static std::unique_ptr<Vector>
create(std::shared_ptr<const Executor> exec,
555 mpi::communicator comm,
556 dim<2> global_size = {},
557 dim<2> local_size = {});
576 static std::unique_ptr<Vector>
create(
577 std::shared_ptr<const Executor> exec, mpi::communicator comm,
578 dim<2> global_size, std::unique_ptr<local_vector_type> local_vector);
598 static std::unique_ptr<Vector>
create(
599 std::shared_ptr<const Executor> exec, mpi::communicator comm,
600 std::unique_ptr<local_vector_type> local_vector);
615 std::shared_ptr<const Executor> exec, mpi::communicator comm,
617 std::unique_ptr<const local_vector_type> local_vector);
632 std::shared_ptr<const Executor> exec, mpi::communicator comm,
633 std::unique_ptr<const local_vector_type> local_vector);
636 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
637 dim<2> global_size, dim<2> local_size,
size_type stride);
639 explicit Vector(std::shared_ptr<const Executor> exec,
640 mpi::communicator comm, dim<2> global_size = {},
641 dim<2> local_size = {});
643 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
644 dim<2> global_size, std::unique_ptr<local_vector_type> local_vector);
646 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
647 std::unique_ptr<local_vector_type> local_vector);
649 void resize(dim<2> global_size, dim<2> local_size);
651 template <
typename LocalIndexType,
typename GlobalIndexType>
652 void read_distributed_impl(
653 const device_matrix_data<ValueType, GlobalIndexType>& data,
654 const Partition<LocalIndexType, GlobalIndexType>* partition);
656 void apply_impl(
const LinOp*,
LinOp*)
const override;
659 LinOp*)
const override;
667 virtual std::unique_ptr<Vector> create_with_same_config()
const;
681 virtual std::unique_ptr<Vector> create_with_type_of_impl(
682 std::shared_ptr<const Executor> exec,
const dim<2>& global_size,
683 const dim<2>& local_size,
size_type stride)
const;
688 virtual std::unique_ptr<Vector> create_submatrix_impl(local_span rows,
693 local_vector_type local_;
694 ::gko::detail::DenseCache<ValueType> host_reduction_buffer_;
695 ::gko::detail::DenseCache<remove_complex<ValueType>> host_norm_buffer_;
706 template <
typename TargetType>
707 struct conversion_target_helper;
719 template <
typename ValueType>
720 struct conversion_target_helper<experimental::distributed::Vector<ValueType>> {
721 using target_type = experimental::distributed::Vector<ValueType>;
723 experimental::distributed::Vector<previous_precision<ValueType>>;
725 static std::unique_ptr<target_type> create_empty(
const source_type* source)
727 return target_type::create(source->get_executor(),
728 source->get_communicator());
734 static std::unique_ptr<target_type> create_empty(
const target_type* source)
736 return target_type::create(source->get_executor(),
737 source->get_communicator());
740 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
741 using snd_source_type =
742 experimental::distributed::Vector<previous_precision<ValueType, 2>>;
744 static std::unique_ptr<target_type> create_empty(
745 const snd_source_type* source)
747 return target_type::create(source->get_executor(),
748 source->get_communicator());
751 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
752 using trd_source_type =
753 experimental::distributed::Vector<previous_precision<ValueType, 3>>;
755 static std::unique_ptr<target_type> create_empty(
756 const trd_source_type* source)
758 return target_type::create(source->get_executor(),
759 source->get_communicator());
769 #endif // GINKGO_BUILD_MPI
772 #endif // GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_