Ginkgo  Generated from pipelines/1330831941 branch based on master. Ginkgo version 1.8.0
A numerical linear algebra library targeting many-core architectures
batch_lin_op.hpp
1 // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
6 #define GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
7 
8 
9 #include <memory>
10 #include <type_traits>
11 #include <utility>
12 
13 
14 #include <ginkgo/core/base/abstract_factory.hpp>
15 #include <ginkgo/core/base/batch_multi_vector.hpp>
16 #include <ginkgo/core/base/dim.hpp>
17 #include <ginkgo/core/base/exception_helpers.hpp>
18 #include <ginkgo/core/base/math.hpp>
19 #include <ginkgo/core/base/matrix_assembly_data.hpp>
20 #include <ginkgo/core/base/matrix_data.hpp>
21 #include <ginkgo/core/base/polymorphic_object.hpp>
22 #include <ginkgo/core/base/types.hpp>
23 #include <ginkgo/core/base/utils.hpp>
24 #include <ginkgo/core/log/logger.hpp>
25 
26 
27 namespace gko {
28 namespace batch {
29 
30 
// Base class for batched linear operators. It stores the batch dimensions
// (a batch_dim<2>) and offers size accessors plus helpers that check the
// operands of an apply call against those dimensions.
//
// NOTE(review): this is a rendered listing; the original doc comments are
// folded away and a few declaration lines are absent (flagged below).
60 class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
61 public:
// Returns the number of items in the batch, read from the stored size.
67  size_type get_num_batch_items() const noexcept
68  {
69  return get_size().get_num_batch_items();
70  }
71 
// NOTE(review): listing line 77 is absent here. It declares
// `dim<2> get_common_size() const` (returns the common size of the batch
// items); the validators below rely on it. Its body is not visible.
78 
// Returns a const reference to the stored batch dimensions.
84  const batch_dim<2>& get_size() const noexcept { return size_; }
85 
// Validates the operand sizes for apply(b, x).
// NOTE(review): the declaration line (listing line 92) is absent here:
//   void validate_application_parameters(const MultiVector<ValueType>* b,
// Both pointers are dereferenced without a null check.
91  template <typename ValueType>
93  MultiVector<ValueType>* x) const
94  {
// b and x must hold as many batch items as this operator.
95  GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
96  GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
97 
// Per-item shape checks on the common sizes. The GKO_ASSERT_* macros are
// defined elsewhere; how a failed check is reported is not visible here.
// CONFORMANT presumably means operator columns == b rows -- TODO confirm.
98  GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
99  GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
100  GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
101  }
102 
// Validates the operand sizes for apply(alpha, b, beta, x).
// NOTE(review): the declaration line (listing line 109) is absent here:
//   void validate_application_parameters(const MultiVector<ValueType>* alpha,
// All four pointers are dereferenced without a null check.
108  template <typename ValueType>
110  const MultiVector<ValueType>* b,
111  const MultiVector<ValueType>* beta,
112  MultiVector<ValueType>* x) const
113  {
// Same batch-count and shape checks as the two-argument overload.
114  GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
115  GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
116 
117  GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
118  GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
119  GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
// alpha and beta must have a 1x1 common size. Their batch item counts are
// not checked in this function.
120  GKO_ASSERT_EQUAL_DIMENSIONS(alpha->get_common_size(),
121  gko::dim<2>(1, 1));
122  GKO_ASSERT_EQUAL_DIMENSIONS(beta->get_common_size(), gko::dim<2>(1, 1));
123  }
124 
125 protected:
// Overwrites the stored batch dimensions; only accessible to subclasses.
131  void set_size(const batch_dim<2>& size) { size_ = size; }
132 
// Constructs the operator on `exec` with the given batch dimensions.
139  explicit BatchLinOp(std::shared_ptr<const Executor> exec,
140  const batch_dim<2>& batch_size)
141  : EnableAbstractPolymorphicObject<BatchLinOp>(exec), size_{batch_size}
142  {}
143 
// Convenience constructor taking the item count and the per-item size.
// Delegates to the constructor above. When num_batch_items is 0 the
// common_size argument is ignored and a default batch_dim<2> is stored.
152  explicit BatchLinOp(std::shared_ptr<const Executor> exec,
153  const size_type num_batch_items = 0,
154  const dim<2>& common_size = dim<2>{})
155  : BatchLinOp{std::move(exec),
156  num_batch_items > 0
157  ? batch_dim<2>(num_batch_items, common_size)
158  : batch_dim<2>{}}
159  {}
160 
161 private:
// Batch dimensions of the operator; value-initialized by default.
162  batch_dim<2> size_{};
163 };
164 
165 
// Factory base that produces a BatchLinOp from an input BatchLinOp.
// NOTE(review): the class header (listing line 195, `class BatchLinOpFactory`)
// and listing line 198 are absent here. Line 198 is presumably the opening
// `using AbstractFactory<BatchLinOp,` of the using-declaration completed on
// line 199 -- TODO confirm against the real header.
196  : public AbstractFactory<BatchLinOp, std::shared_ptr<const BatchLinOp>> {
197 public:
199  std::shared_ptr<const BatchLinOp>>::AbstractFactory;
200 
// Generates a new BatchLinOp from `input`, logging before and after.
// `input` must be non-null: it is dereferenced unconditionally below.
201  std::unique_ptr<BatchLinOp> generate(
202  std::shared_ptr<const BatchLinOp> input) const
203  {
204  this->template log<
205  gko::log::Logger::batch_linop_factory_generate_started>(
206  this, input.get());
207  const auto exec = this->get_executor();
208  std::unique_ptr<BatchLinOp> generated;
// Use the input directly when it already has the factory's executor;
// otherwise generate from gko::clone(exec, input) -- presumably a copy of
// the input placed on the factory's executor.
209  if (input->get_executor() == exec) {
210  generated = this->AbstractFactory::generate(input);
211  } else {
212  generated =
213  this->AbstractFactory::generate(gko::clone(exec, input));
214  }
// The completion event reports the caller's original input, not the clone.
215  this->template log<
216  gko::log::Logger::batch_linop_factory_generate_completed>(
217  this, input.get(), generated.get());
218  return generated;
219  }
220 };
221 
222 
// Mixin for concrete batch operators. It combines
// EnablePolymorphicObject<ConcreteBatchLinOp, PolymorphicBase> with
// EnablePolymorphicAssignment<ConcreteBatchLinOp> and inherits the
// constructors of the former. PolymorphicBase defaults to BatchLinOp.
// NOTE(review): the class-name line (listing line 251,
// `class EnableBatchLinOp`) is absent from this listing.
250 template <typename ConcreteBatchLinOp, typename PolymorphicBase = BatchLinOp>
252  : public EnablePolymorphicObject<ConcreteBatchLinOp, PolymorphicBase>,
253  public EnablePolymorphicAssignment<ConcreteBatchLinOp> {
254 public:
255  using EnablePolymorphicObject<ConcreteBatchLinOp,
256  PolymorphicBase>::EnablePolymorphicObject;
257 };
258 
259 
// Alias template mapping to EnableDefaultFactory with the same arguments;
// PolymorphicBase defaults to BatchLinOpFactory.
// NOTE(review): listing line 278 is absent here. It is presumably
// `using EnableDefaultBatchLinOpFactory =`, the name the macro below uses as
// the factory base class -- TODO confirm against the real header.
276 template <typename ConcreteFactory, typename ConcreteBatchLinOp,
277  typename ParametersType, typename PolymorphicBase = BatchLinOpFactory>
279  EnableDefaultFactory<ConcreteFactory, ConcreteBatchLinOp, ParametersType,
280  PolymorphicBase>;
281 
282 
// Injects factory boilerplate into the body of a batch operator class.
//
// Parameters:
//   _batch_lin_op    - the operator type passed to the factory base
//   _parameters_name - stem of the parameter names; the macro expects a type
//                      `<_parameters_name>_type` to exist already
//   _factory_name    - name of the nested factory class to declare
//
// The expansion adds, in this order:
//   * a public getter `get_<_parameters_name>()` returning a const reference
//     to the `<_parameters_name>_` member;
//   * a nested class `_factory_name` deriving from
//     ::gko::batch::EnableDefaultBatchLinOpFactory. Its two explicit
//     constructors (executor only; executor plus parameters) follow no access
//     specifier and are therefore private; the two friend declarations inside
//     the class name the templates that are allowed to call them;
//   * a friend declaration in the enclosing class for that factory base;
//   * the private data member `<_parameters_name>_`;
//   * a final `public:` followed by `static_assert(true, ...)`, so an
//     invocation can end in a semicolon without triggering extra-semicolon
//     warnings (as the assert message states). Members written after the
//     invocation are therefore public.
359 #define GKO_ENABLE_BATCH_LIN_OP_FACTORY(_batch_lin_op, _parameters_name, \
360  _factory_name) \
361 public: \
362  const _parameters_name##_type& get_##_parameters_name() const \
363  { \
364  return _parameters_name##_; \
365  } \
366  \
367  class _factory_name \
368  : public ::gko::batch::EnableDefaultBatchLinOpFactory< \
369  _factory_name, _batch_lin_op, _parameters_name##_type> { \
370  friend class ::gko::EnablePolymorphicObject< \
371  _factory_name, ::gko::batch::BatchLinOpFactory>; \
372  friend class ::gko::enable_parameters_type<_parameters_name##_type, \
373  _factory_name>; \
374  explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec) \
375  : ::gko::batch::EnableDefaultBatchLinOpFactory< \
376  _factory_name, _batch_lin_op, _parameters_name##_type>( \
377  std::move(exec)) \
378  {} \
379  explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec, \
380  const _parameters_name##_type& parameters) \
381  : ::gko::batch::EnableDefaultBatchLinOpFactory< \
382  _factory_name, _batch_lin_op, _parameters_name##_type>( \
383  std::move(exec), parameters) \
384  {} \
385  }; \
386  friend ::gko::batch::EnableDefaultBatchLinOpFactory< \
387  _factory_name, _batch_lin_op, _parameters_name##_type>; \
388  \
389  \
390 private: \
391  _parameters_name##_type _parameters_name##_; \
392  \
393 public: \
394  static_assert(true, \
395  "This assert is used to counter the false positive extra " \
396  "semi-colon warnings")
397 
398 
399 } // namespace batch
400 } // namespace gko
401 
402 
403 #endif // GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
gko::batch::EnableBatchLinOp
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of the BatchLinOp and PolymorphicObject interface.
Definition: batch_lin_op.hpp:251
gko::batch_dim::get_num_batch_items
size_type get_num_batch_items() const
Get the number of batch items stored.
Definition: batch_dim.hpp:37
gko::batch_dim::get_common_size
dim< dimensionality, dimension_type > get_common_size() const
Get the common size of the batch items.
Definition: batch_dim.hpp:44
gko::batch::BatchLinOp::validate_application_parameters
void validate_application_parameters(const MultiVector< ValueType > *alpha, const MultiVector< ValueType > *b, const MultiVector< ValueType > *beta, MultiVector< ValueType > *x) const
Validates the sizes for the apply(alpha, b, beta, x) operation in the concrete BatchLinOp.
Definition: batch_lin_op.hpp:109
gko::batch::BatchLinOp::validate_application_parameters
void validate_application_parameters(const MultiVector< ValueType > *b, MultiVector< ValueType > *x) const
Validates the sizes for the apply(b, x) operation in the concrete BatchLinOp.
Definition: batch_lin_op.hpp:92
gko::batch::BatchLinOpFactory
A BatchLinOpFactory represents a higher order mapping which transforms one batch linear operator into another.
Definition: batch_lin_op.hpp:195
gko::AbstractFactory
The AbstractFactory is a generic interface template that enables easy implementation of the abstract factory design pattern.
Definition: abstract_factory.hpp:45
gko::batch::BatchLinOp
Definition: batch_lin_op.hpp:60
gko::batch::BatchLinOp::get_num_batch_items
size_type get_num_batch_items() const noexcept
Returns the number of items in the batch operator.
Definition: batch_lin_op.hpp:67
gko::AbstractFactory::generate
std::unique_ptr< abstract_product_type > generate(Args &&... args) const
Creates a new product from the given components.
Definition: abstract_factory.hpp:67
gko::EnableAbstractPolymorphicObject
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition: polymorphic_object.hpp:346
gko::size_type
std::size_t size_type
Integral type used for allocation quantities.
Definition: types.hpp:108
gko::EnableDefaultFactory
This mixin provides a default implementation of a concrete factory.
Definition: abstract_factory.hpp:124
gko::EnablePolymorphicAssignment
This mixin is used to enable a default PolymorphicObject::copy_from() implementation for objects that...
Definition: polymorphic_object.hpp:724
gko::batch::MultiVector
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition: batch_multi_vector.hpp:53
gko::clone
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition: utils_helper.hpp:175
gko::batch::BatchLinOp::get_common_size
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition: batch_lin_op.hpp:77
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::dim< 2 >
gko::batch::MultiVector::get_common_size
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition: batch_multi_vector.hpp:127
gko::batch_dim< 2 >
gko::batch::MultiVector::get_num_batch_items
size_type get_num_batch_items() const
Returns the number of batch items.
Definition: batch_multi_vector.hpp:117
gko::PolymorphicObject::get_executor
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition: polymorphic_object.hpp:235
gko::batch::BatchLinOp::get_size
const batch_dim< 2 > & get_size() const noexcept
Returns the size of the batch operator.
Definition: batch_lin_op.hpp:84
gko::EnablePolymorphicObject
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition: polymorphic_object.hpp:662