5 #ifndef GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_
6 #define GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_
9 #include <ginkgo/config.hpp>
15 #include <ginkgo/core/base/dense_cache.hpp>
16 #include <ginkgo/core/base/lin_op.hpp>
17 #include <ginkgo/core/base/mpi.hpp>
18 #include <ginkgo/core/distributed/base.hpp>
19 #include <ginkgo/core/matrix/dense.hpp>
23 namespace experimental {
24 namespace distributed {
28 template <
typename ValueType>
35 template <
typename LocalIndexType,
typename GlobalIndexType>
66 template <
typename ValueType =
double>
68 :
public EnableLinOp<Vector<ValueType>>,
69 public ConvertibleTo<Vector<next_precision<ValueType>>>,
70 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
71 public ConvertibleTo<Vector<next_precision<ValueType, 2>>>,
73 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
74 public ConvertibleTo<Vector<next_precision<ValueType, 3>>>,
76 public EnableAbsoluteComputation<remove_complex<Vector<ValueType>>>,
77 public DistributedBase {
/**
 * Grants the EnablePolymorphicObject<Vector, LinOp> instantiation access to
 * the non-public members of this class.
 */
78 friend class EnablePolymorphicObject<Vector,
LinOp>;
/**
 * Grants detail::VectorCache<ValueType> access to the non-public members of
 * this class.
 */
82 friend class detail::VectorCache<ValueType>;
83 GKO_ASSERT_SUPPORTED_VALUE_TYPE;
/**
 * Re-exposes the convert_to / move_to names inherited from
 * ConvertibleTo<Vector<next_precision<ValueType>>>. Member functions of the
 * same name declared in this class would otherwise hide the inherited
 * overloads during name lookup.
 */
88 using ConvertibleTo<Vector<next_precision<ValueType>>>::convert_to;
89 using ConvertibleTo<Vector<next_precision<ValueType>>>::move_to;
/** Element type of this vector, i.e. the class template argument. */
91 using value_type = ValueType;
/** Type obtained by applying remove_complex to this Vector type. */
92 using absolute_type = remove_complex<Vector>;
/** Alias of absolute_type. */
93 using real_type = absolute_type;
/** Vector instantiated with to_complex<value_type>. */
94 using complex_type = Vector<to_complex<value_type>>;
104 ptr_param<const Vector> other);
119 ptr_param<const Vector> other, std::shared_ptr<const Executor> exec);
134 ptr_param<const Vector> other, std::shared_ptr<const Executor> exec,
135 const dim<2>& global_size,
const dim<2>& local_size,
size_type stride);
152 ptr_param<
const Partition<int64, int64>> partition);
155 ptr_param<
const Partition<int32, int64>> partition);
158 ptr_param<
const Partition<int32, int32>> partition);
170 ptr_param<
const Partition<int64, int64>> partition);
173 ptr_param<
const Partition<int32, int64>> partition);
176 ptr_param<
const Partition<int32, int32>> partition);
/**
 * Override of convert_to for the target type
 * Vector<next_precision<ValueType>>. Const-qualified, so it is callable on a
 * const vector.
 *
 * @param result  destination vector of the next precision.
 *                NOTE(review): null handling is not visible in this
 *                header; confirm in the implementation.
 */
178 void convert_to(Vector<next_precision<ValueType>>* result)
const override;
/**
 * Override of move_to for the target type
 * Vector<next_precision<ValueType>>. Non-const, unlike convert_to.
 * NOTE(review): presumably leaves this vector in a moved-from state;
 * confirm in the implementation.
 *
 * @param result  destination vector of the next precision.
 */
180 void move_to(Vector<next_precision<ValueType>>* result)
override;
182 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
184 using ConvertibleTo<Vector<next_precision<ValueType, 2>>>::convert_to;
185 using ConvertibleTo<Vector<next_precision<ValueType, 2>>>::move_to;
188 Vector<next_precision<ValueType, 2>>* result)
const override;
/**
 * Override of move_to for the target type
 * Vector<next_precision<ValueType, 2>>. It follows the
 * `#if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16` guard, so it is
 * presumably only compiled when that condition holds (the matching #endif is
 * not visible here; confirm).
 *
 * @param result  destination vector, two precision steps away.
 */
190 void move_to(Vector<next_precision<ValueType, 2>>* result)
override;
193 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
195 using ConvertibleTo<Vector<next_precision<ValueType, 3>>>::convert_to;
196 using ConvertibleTo<Vector<next_precision<ValueType, 3>>>::move_to;
199 Vector<next_precision<ValueType, 3>>* result)
const override;
/**
 * Override of move_to for the target type
 * Vector<next_precision<ValueType, 3>>. It follows the
 * `#if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16` guard, so it is
 * presumably only compiled when both options are enabled (the matching
 * #endif is not visible here; confirm).
 *
 * @param result  destination vector, three precision steps away.
 */
201 void move_to(Vector<next_precision<ValueType, 3>>* result)
override;
/**
 * Const member that takes a destination of complex_type, i.e.
 * Vector<to_complex<value_type>>.
 * NOTE(review): presumably writes a complex-valued copy of this vector into
 * `result`; the implementation is not visible here, confirm.
 *
 * @param result  destination vector of complex_type.
 */
219 void make_complex(ptr_param<complex_type> result)
const;
/**
 * Allocating overload: returns a newly owned real_type vector.
 * NOTE(review): presumably holds the real parts of this vector's entries;
 * confirm in the implementation.
 *
 * @return owning pointer to a real_type vector.
 */
225 std::unique_ptr<real_type>
get_real()
const;
/**
 * Overload that takes a caller-provided destination and returns nothing.
 *
 * @param result  destination vector of real_type.
 */
230 void get_real(ptr_param<real_type> result)
const;
/**
 * Allocating overload: returns a newly owned real_type vector.
 * NOTE(review): presumably holds the imaginary parts of this vector's
 * entries; confirm in the implementation.
 *
 * @return owning pointer to a real_type vector.
 */
236 std::unique_ptr<real_type>
get_imag()
const;
/**
 * Overload that takes a caller-provided destination and returns nothing.
 *
 * @param result  destination vector of real_type.
 */
242 void get_imag(ptr_param<real_type> result)
const;
/**
 * Non-const member taking a single ValueType by value.
 * NOTE(review): presumably assigns `value` to every entry of the vector;
 * confirm in the implementation.
 *
 * @param value  the value passed to the fill operation.
 */
249 void fill(ValueType value);
/**
 * Non-const member taking the scaling factor as a read-only LinOp.
 * NOTE(review): presumably multiplies this vector by `alpha` in place; the
 * accepted shapes of `alpha` are not visible here, confirm.
 *
 * @param alpha  scaling factor, passed as a const LinOp.
 */
260 void scale(ptr_param<const LinOp> alpha);
/**
 * Same signature as scale.
 * NOTE(review): presumably divides this vector by `alpha` in place; confirm.
 *
 * @param alpha  scaling factor, passed as a const LinOp.
 */
271 void inv_scale(ptr_param<const LinOp> alpha);
/**
 * Non-const member taking two read-only LinOp operands.
 * NOTE(review): presumably computes this += alpha * b; confirm.
 *
 * @param alpha  scaling factor, passed as a const LinOp.
 * @param b      second operand, passed as a const LinOp.
 */
282 void add_scaled(ptr_param<const LinOp> alpha, ptr_param<const LinOp> b);
/**
 * Same signature as add_scaled.
 * NOTE(review): presumably computes this -= alpha * b; confirm.
 *
 * @param alpha  scaling factor, passed as a const LinOp.
 * @param b      second operand, passed as a const LinOp.
 */
292 void sub_scaled(ptr_param<const LinOp> alpha, ptr_param<const LinOp> b);
/**
 * Const member: reads `b` and writes into the mutable LinOp `result`.
 * NOTE(review): presumably a dot product of this vector with `b`; the
 * expected shape of `result` is not visible here, confirm.
 *
 * @param b       second operand, passed as a const LinOp.
 * @param result  output LinOp.
 */
303 void compute_dot(ptr_param<const LinOp> b, ptr_param<LinOp> result)
const;
/**
 * Overload with an extra caller-supplied byte array passed by non-const
 * reference.
 * NOTE(review): `tmp` looks like reusable scratch storage that the call may
 * resize or overwrite; confirm.
 *
 * @param b       second operand, passed as a const LinOp.
 * @param result  output LinOp.
 * @param tmp     caller-owned byte array, mutable by the callee.
 */
317 void compute_dot(ptr_param<const LinOp> b, ptr_param<LinOp> result,
318 array<char>& tmp)
const;
330 ptr_param<LinOp> result)
const;
345 array<char>& tmp)
const;
/**
 * Const member writing into `result`, with a caller-supplied byte array
 * passed by non-const reference.
 * NOTE(review): presumably the Euclidean norm, with `tmp` as scratch
 * storage; confirm.
 *
 * @param result  output LinOp.
 * @param tmp     caller-owned byte array, mutable by the callee.
 */
391 void compute_norm2(ptr_param<LinOp> result, array<char>& tmp)
const;
/**
 * Same signature as compute_norm2.
 * NOTE(review): presumably the 1-norm; confirm.
 *
 * @param result  output LinOp.
 * @param tmp     caller-owned byte array, mutable by the callee.
 */
413 void compute_norm1(ptr_param<LinOp> result, array<char>& tmp)
const;
/**
 * Same signature as compute_norm2.
 * NOTE(review): presumably the mean of the entries; confirm.
 *
 * @param result  output LinOp.
 * @param tmp     caller-owned byte array, mutable by the callee.
 */
436 void compute_mean(ptr_param<LinOp> result, array<char>& tmp)
const;
539 static std::unique_ptr<Vector>
create(std::shared_ptr<const Executor> exec,
540 mpi::communicator comm,
541 dim<2> global_size, dim<2> local_size,
/**
 * Static factory returning an owning pointer to a new Vector.
 * Both size arguments default to a value-initialized dim<2>, so the call
 * create(exec, comm) is valid.
 *
 * @param exec         executor handed to the new vector.
 * @param comm         MPI communicator handed to the new vector.
 * @param global_size  global dimensions; defaults to `{}`.
 * @param local_size   local dimensions; defaults to `{}`.
 * @return owning pointer to the created Vector.
 */
555 static std::unique_ptr<Vector>
create(std::shared_ptr<const Executor> exec,
556 mpi::communicator comm,
557 dim<2> global_size = {},
558 dim<2> local_size = {});
/**
 * Static factory taking an existing local vector together with an explicit
 * global size. `local_vector` is a std::unique_ptr passed by value, so
 * ownership transfers from the caller.
 *
 * @param exec          executor handed to the new vector.
 * @param comm          MPI communicator handed to the new vector.
 * @param global_size   global dimensions.
 * @param local_vector  local vector whose ownership is handed over.
 * @return owning pointer to the created Vector.
 */
577 static std::unique_ptr<Vector>
create(
578 std::shared_ptr<const Executor> exec, mpi::communicator comm,
579 dim<2> global_size, std::unique_ptr<local_vector_type> local_vector);
/**
 * Static factory taking an existing local vector without an explicit global
 * size. `local_vector` is a std::unique_ptr passed by value, so ownership
 * transfers from the caller.
 * NOTE(review): the global size is presumably derived from the local
 * vectors across the communicator; confirm in the implementation.
 *
 * @param exec          executor handed to the new vector.
 * @param comm          MPI communicator handed to the new vector.
 * @param local_vector  local vector whose ownership is handed over.
 * @return owning pointer to the created Vector.
 */
599 static std::unique_ptr<Vector>
create(
600 std::shared_ptr<const Executor> exec, mpi::communicator comm,
601 std::unique_ptr<local_vector_type> local_vector);
616 std::shared_ptr<const Executor> exec, mpi::communicator comm,
618 std::unique_ptr<const local_vector_type> local_vector);
633 std::shared_ptr<const Executor> exec, mpi::communicator comm,
634 std::unique_ptr<const local_vector_type> local_vector);
/**
 * Constructor taking explicit global and local sizes plus a stride.
 * NOTE(review): `stride` presumably applies to the local storage; confirm.
 *
 * @param exec         executor for the vector.
 * @param comm         MPI communicator for the vector.
 * @param global_size  global dimensions.
 * @param local_size   local dimensions.
 * @param stride       stride value.
 */
637 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
638 dim<2> global_size, dim<2> local_size,
size_type stride);
/**
 * Explicit constructor without a stride argument. Both size arguments
 * default to a value-initialized dim<2>.
 *
 * @param exec         executor for the vector.
 * @param comm         MPI communicator for the vector.
 * @param global_size  global dimensions; defaults to `{}`.
 * @param local_size   local dimensions; defaults to `{}`.
 */
640 explicit Vector(std::shared_ptr<const Executor> exec,
641 mpi::communicator comm, dim<2> global_size = {},
642 dim<2> local_size = {});
/**
 * Constructor taking an existing local vector and an explicit global size.
 * `local_vector` is a std::unique_ptr passed by value, so ownership
 * transfers from the caller.
 *
 * @param exec          executor for the vector.
 * @param comm          MPI communicator for the vector.
 * @param global_size   global dimensions.
 * @param local_vector  local vector whose ownership is handed over.
 */
644 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
645 dim<2> global_size, std::unique_ptr<local_vector_type> local_vector);
/**
 * Constructor taking an existing local vector without a global size.
 * NOTE(review): the global size is presumably derived from the local
 * vectors; confirm in the implementation.
 *
 * @param exec          executor for the vector.
 * @param comm          MPI communicator for the vector.
 * @param local_vector  local vector whose ownership is handed over.
 */
647 Vector(std::shared_ptr<const Executor> exec, mpi::communicator comm,
648 std::unique_ptr<local_vector_type> local_vector);
/**
 * Non-const member taking new global and local dimensions.
 * NOTE(review): whether existing entries are preserved is not visible here;
 * confirm in the implementation.
 *
 * @param global_size  new global dimensions.
 * @param local_size   new local dimensions.
 */
650 void resize(dim<2> global_size, dim<2> local_size);
/**
 * Member template parameterized on the local and global index types. The
 * data argument is indexed by GlobalIndexType only; the partition uses both
 * index types.
 * NOTE(review): presumably the shared implementation behind the public
 * read overloads; confirm.
 *
 * @tparam LocalIndexType   local index type of the partition.
 * @tparam GlobalIndexType  global index type of the data and partition.
 * @param data       matrix data to read, passed by const reference.
 * @param partition  raw pointer to a const partition; null handling is not
 *                   visible here.
 */
652 template <
typename LocalIndexType,
typename GlobalIndexType>
653 void read_distributed_impl(
654 const device_matrix_data<ValueType, GlobalIndexType>& data,
655 const Partition<LocalIndexType, GlobalIndexType>* partition);
/**
 * Const override of the two-argument apply_impl hook. Both parameters are
 * left unnamed in this declaration: a read-only LinOp and a mutable LinOp.
 * NOTE(review): which behaviour the override provides is not visible in
 * this header; confirm in the implementation.
 */
657 void apply_impl(
const LinOp*,
LinOp*)
const override;
660 LinOp*)
const override;
/**
 * Virtual const factory with no arguments, returning an owning pointer to a
 * Vector. Being virtual, derived classes may override it.
 * NOTE(review): presumably yields a vector sharing this one's executor,
 * communicator and sizes; confirm in the implementation.
 *
 * @return owning pointer to the created Vector.
 */
668 virtual std::unique_ptr<Vector> create_with_same_config()
const;
/**
 * Virtual const factory taking an executor, global and local sizes, and a
 * stride. Being virtual, derived classes may override it.
 * NOTE(review): presumably yields a vector of the same dynamic type as this
 * object; confirm in the implementation.
 *
 * @param exec         executor for the new vector.
 * @param global_size  global dimensions.
 * @param local_size   local dimensions.
 * @param stride       stride value.
 * @return owning pointer to the created Vector.
 */
682 virtual std::unique_ptr<Vector> create_with_type_of_impl(
683 std::shared_ptr<const Executor> exec,
const dim<2>& global_size,
684 const dim<2>& local_size,
size_type stride)
const;
689 virtual std::unique_ptr<Vector> create_submatrix_impl(local_span rows,
/**
 * Local vector stored by value. The local_vector_type alias is declared
 * outside the visible lines.
 */
694 local_vector_type local_;
/**
 * Dense cache with entries of ValueType.
 * NOTE(review): by its name, presumably a host-side staging buffer for
 * reductions; confirm.
 */
695 ::gko::detail::DenseCache<ValueType> host_reduction_buffer_;
/**
 * Dense cache with entries of remove_complex<ValueType>.
 * NOTE(review): by its name, presumably a host-side staging buffer for norm
 * results; confirm.
 */
696 ::gko::detail::DenseCache<remove_complex<ValueType>> host_norm_buffer_;
/**
 * Primary template of conversion_target_helper. It is only declared here,
 * not defined; usable behaviour comes from specializations.
 *
 * @tparam TargetType  the type the helper is specialized for.
 */
707 template <
typename TargetType>
708 struct conversion_target_helper;
720 template <
typename ValueType>
721 struct conversion_target_helper<experimental::distributed::Vector<ValueType>> {
722 using target_type = experimental::distributed::Vector<ValueType>;
724 experimental::distributed::Vector<previous_precision<ValueType>>;
/**
 * Builds a target_type by calling target_type::create with the executor and
 * communicator read from `source`. No size arguments are forwarded.
 * `source` is dereferenced without a null check.
 *
 * @param source  vector of source_type supplying executor and communicator.
 * @return owning pointer to the created target_type.
 */
726 static std::unique_ptr<target_type> create_empty(
const source_type* source)
728 return target_type::create(source->get_executor(),
729 source->get_communicator());
/**
 * Overload for a source that already has the target type. Builds a
 * target_type by calling target_type::create with the executor and
 * communicator read from `source`. No size arguments are forwarded.
 * `source` is dereferenced without a null check.
 *
 * @param source  vector of target_type supplying executor and communicator.
 * @return owning pointer to the created target_type.
 */
735 static std::unique_ptr<target_type> create_empty(
const target_type* source)
737 return target_type::create(source->get_executor(),
738 source->get_communicator());
741 #if GINKGO_ENABLE_HALF || GINKGO_ENABLE_BFLOAT16
/** Source type two precision steps before ValueType. */
742 using snd_source_type =
743 experimental::distributed::Vector<previous_precision<ValueType, 2>>;
/**
 * Overload for snd_source_type. Builds a target_type by calling
 * target_type::create with the executor and communicator read from
 * `source`. No size arguments are forwarded. `source` is dereferenced
 * without a null check.
 *
 * @param source  vector of snd_source_type supplying executor and
 *                communicator.
 * @return owning pointer to the created target_type.
 */
745 static std::unique_ptr<target_type> create_empty(
746 const snd_source_type* source)
748 return target_type::create(source->get_executor(),
749 source->get_communicator());
752 #if GINKGO_ENABLE_HALF && GINKGO_ENABLE_BFLOAT16
/** Source type three precision steps before ValueType. */
753 using trd_source_type =
754 experimental::distributed::Vector<previous_precision<ValueType, 3>>;
/**
 * Overload for trd_source_type. Builds a target_type by calling
 * target_type::create with the executor and communicator read from
 * `source`. No size arguments are forwarded. `source` is dereferenced
 * without a null check.
 *
 * @param source  vector of trd_source_type supplying executor and
 *                communicator.
 * @return owning pointer to the created target_type.
 */
756 static std::unique_ptr<target_type> create_empty(
757 const trd_source_type* source)
759 return target_type::create(source->get_executor(),
760 source->get_communicator());
770 #endif // GINKGO_BUILD_MPI
773 #endif // GKO_PUBLIC_CORE_DISTRIBUTED_VECTOR_HPP_