Ginkgo — Generated from pipelines/1554403166 (branch based on develop). Ginkgo version 1.9.0
A numerical linear algebra library targeting many-core architectures
batch_lin_op.hpp
1 // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
6 #define GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
7 
8 
9 #include <memory>
10 #include <type_traits>
11 #include <utility>
12 
13 #include <ginkgo/core/base/abstract_factory.hpp>
14 #include <ginkgo/core/base/batch_multi_vector.hpp>
15 #include <ginkgo/core/base/dim.hpp>
16 #include <ginkgo/core/base/exception_helpers.hpp>
17 #include <ginkgo/core/base/math.hpp>
18 #include <ginkgo/core/base/matrix_assembly_data.hpp>
19 #include <ginkgo/core/base/matrix_data.hpp>
20 #include <ginkgo/core/base/polymorphic_object.hpp>
21 #include <ginkgo/core/base/types.hpp>
22 #include <ginkgo/core/base/utils.hpp>
23 #include <ginkgo/core/log/logger.hpp>
24 
25 
26 namespace gko {
27 namespace batch {
28 
29 
59 class BatchLinOp : public EnableAbstractPolymorphicObject<BatchLinOp> {
60 public:
66  size_type get_num_batch_items() const noexcept
67  {
68  return get_size().get_num_batch_items();
69  }
70 
77 
83  const batch_dim<2>& get_size() const noexcept { return size_; }
84 
90  template <typename ValueType>
92  MultiVector<ValueType>* x) const
93  {
94  GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
95  GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
96 
97  GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
98  GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
99  GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
100  }
101 
107  template <typename ValueType>
109  const MultiVector<ValueType>* b,
110  const MultiVector<ValueType>* beta,
111  MultiVector<ValueType>* x) const
112  {
113  GKO_ASSERT_EQ(b->get_num_batch_items(), this->get_num_batch_items());
114  GKO_ASSERT_EQ(this->get_num_batch_items(), x->get_num_batch_items());
115 
116  GKO_ASSERT_CONFORMANT(this->get_common_size(), b->get_common_size());
117  GKO_ASSERT_EQUAL_ROWS(this->get_common_size(), x->get_common_size());
118  GKO_ASSERT_EQUAL_COLS(b->get_common_size(), x->get_common_size());
119  GKO_ASSERT_EQUAL_DIMENSIONS(alpha->get_common_size(),
120  gko::dim<2>(1, 1));
121  GKO_ASSERT_EQUAL_DIMENSIONS(beta->get_common_size(), gko::dim<2>(1, 1));
122  }
123 
124 protected:
130  void set_size(const batch_dim<2>& size) { size_ = size; }
131 
138  explicit BatchLinOp(std::shared_ptr<const Executor> exec,
139  const batch_dim<2>& batch_size)
140  : EnableAbstractPolymorphicObject<BatchLinOp>(exec), size_{batch_size}
141  {}
142 
151  explicit BatchLinOp(std::shared_ptr<const Executor> exec,
152  const size_type num_batch_items = 0,
153  const dim<2>& common_size = dim<2>{})
154  : BatchLinOp{std::move(exec),
155  num_batch_items > 0
156  ? batch_dim<2>(num_batch_items, common_size)
157  : batch_dim<2>{}}
158  {}
159 
160 private:
161  batch_dim<2> size_{};
162 };
163 
164 
195  : public AbstractFactory<BatchLinOp, std::shared_ptr<const BatchLinOp>> {
196 public:
198  std::shared_ptr<const BatchLinOp>>::AbstractFactory;
199 
200  std::unique_ptr<BatchLinOp> generate(
201  std::shared_ptr<const BatchLinOp> input) const
202  {
203  this->template log<
204  gko::log::Logger::batch_linop_factory_generate_started>(
205  this, input.get());
206  const auto exec = this->get_executor();
207  std::unique_ptr<BatchLinOp> generated;
208  if (input->get_executor() == exec) {
209  generated = this->AbstractFactory::generate(input);
210  } else {
211  generated =
212  this->AbstractFactory::generate(gko::clone(exec, input));
213  }
214  this->template log<
215  gko::log::Logger::batch_linop_factory_generate_completed>(
216  this, input.get(), generated.get());
217  return generated;
218  }
219 };
220 
221 
249 template <typename ConcreteBatchLinOp, typename PolymorphicBase = BatchLinOp>
251  : public EnablePolymorphicObject<ConcreteBatchLinOp, PolymorphicBase>,
252  public EnablePolymorphicAssignment<ConcreteBatchLinOp> {
253 public:
254  using EnablePolymorphicObject<ConcreteBatchLinOp,
255  PolymorphicBase>::EnablePolymorphicObject;
256 };
257 
258 
275 template <typename ConcreteFactory, typename ConcreteBatchLinOp,
276  typename ParametersType, typename PolymorphicBase = BatchLinOpFactory>
278  EnableDefaultFactory<ConcreteFactory, ConcreteBatchLinOp, ParametersType,
279  PolymorphicBase>;
280 
281 
/**
 * Declares, inside the body of a batch operator class, a nested factory class
 * plus the storage and accessor for that factory's parameters.
 *
 * The expansion defines:
 * - a public accessor `get_<_parameters_name>()` returning the stored
 *   parameters;
 * - a public nested class `_factory_name`, derived from
 *   EnableDefaultBatchLinOpFactory, whose two constructors (executor only, and
 *   executor plus parameters) are private and reachable through the friend
 *   declarations it contains;
 * - a private member `<_parameters_name>_` of type `_parameters_name##_type`.
 *
 * The type `_parameters_name##_type` must be declared before the macro is
 * used. The expansion leaves the access specifier at `public`, and ends in a
 * `static_assert` so that the invocation can (and should) be followed by a
 * semicolon.
 *
 * @param _batch_lin_op  the concrete batch operator class
 * @param _parameters_name  base name of the parameters member and accessor
 * @param _factory_name  name of the nested factory class
 */
#define GKO_ENABLE_BATCH_LIN_OP_FACTORY(_batch_lin_op, _parameters_name,       \
                                        _factory_name)                         \
public:                                                                        \
    const _parameters_name##_type& get_##_parameters_name() const              \
    {                                                                          \
        return _parameters_name##_;                                            \
    }                                                                          \
                                                                               \
    class _factory_name                                                        \
        : public ::gko::batch::EnableDefaultBatchLinOpFactory<                 \
              _factory_name, _batch_lin_op, _parameters_name##_type> {         \
        friend class ::gko::EnablePolymorphicObject<                           \
            _factory_name, ::gko::batch::BatchLinOpFactory>;                   \
        friend class ::gko::enable_parameters_type<_parameters_name##_type,    \
                                                   _factory_name>;             \
        explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec)    \
            : ::gko::batch::EnableDefaultBatchLinOpFactory<                    \
                  _factory_name, _batch_lin_op, _parameters_name##_type>(      \
                  std::move(exec))                                             \
        {}                                                                     \
        explicit _factory_name(std::shared_ptr<const ::gko::Executor> exec,    \
                               const _parameters_name##_type& parameters)      \
            : ::gko::batch::EnableDefaultBatchLinOpFactory<                    \
                  _factory_name, _batch_lin_op, _parameters_name##_type>(      \
                  std::move(exec), parameters)                                 \
        {}                                                                     \
    };                                                                         \
    friend ::gko::batch::EnableDefaultBatchLinOpFactory<                       \
        _factory_name, _batch_lin_op, _parameters_name##_type>;                \
                                                                               \
                                                                               \
private:                                                                       \
    _parameters_name##_type _parameters_name##_;                               \
                                                                               \
public:                                                                        \
    static_assert(true,                                                        \
                  "This assert is used to counter the false positive extra "   \
                  "semi-colon warnings")
396 
397 
398 } // namespace batch
399 } // namespace gko
400 
401 
402 #endif // GKO_PUBLIC_CORE_BASE_BATCH_LIN_OP_HPP_
gko::batch::EnableBatchLinOp
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of...
Definition: batch_lin_op.hpp:250
gko::batch_dim::get_num_batch_items
size_type get_num_batch_items() const
Get the number of batch items stored.
Definition: batch_dim.hpp:36
gko::batch_dim::get_common_size
dim< dimensionality, dimension_type > get_common_size() const
Get the common size of the batch items.
Definition: batch_dim.hpp:43
gko::batch::BatchLinOp::validate_application_parameters
void validate_application_parameters(const MultiVector< ValueType > *alpha, const MultiVector< ValueType > *b, const MultiVector< ValueType > *beta, MultiVector< ValueType > *x) const
Validates the sizes for the apply(alpha, b, beta, x) operation in the concrete BatchLinOp.
Definition: batch_lin_op.hpp:108
gko::batch::BatchLinOp::validate_application_parameters
void validate_application_parameters(const MultiVector< ValueType > *b, MultiVector< ValueType > *x) const
Validates the sizes for the apply(b, x) operation in the concrete BatchLinOp.
Definition: batch_lin_op.hpp:91
gko::batch::BatchLinOpFactory
A BatchLinOpFactory represents a higher-order mapping which transforms one batch linear operator into...
Definition: batch_lin_op.hpp:194
gko::AbstractFactory
The AbstractFactory is a generic interface template that enables easy implementation of the abstract ...
Definition: abstract_factory.hpp:45
gko::batch::BatchLinOp
Definition: batch_lin_op.hpp:59
gko::batch::BatchLinOp::get_num_batch_items
size_type get_num_batch_items() const noexcept
Returns the number of items in the batch operator.
Definition: batch_lin_op.hpp:66
gko::AbstractFactory::generate
std::unique_ptr< abstract_product_type > generate(Args &&... args) const
Creates a new product from the given components.
Definition: abstract_factory.hpp:67
gko::EnableAbstractPolymorphicObject
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition: polymorphic_object.hpp:345
gko::size_type
std::size_t size_type
Integral type used for allocation quantities.
Definition: types.hpp:86
gko::EnableDefaultFactory
This mixin provides a default implementation of a concrete factory.
Definition: abstract_factory.hpp:124
gko::EnablePolymorphicAssignment
This mixin is used to enable a default PolymorphicObject::copy_from() implementation for objects that...
Definition: polymorphic_object.hpp:723
gko::batch::MultiVector
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition: batch_multi_vector.hpp:52
gko::clone
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition: utils_helper.hpp:173
gko::batch::BatchLinOp::get_common_size
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition: batch_lin_op.hpp:76
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::dim< 2 >
gko::batch::MultiVector::get_common_size
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition: batch_multi_vector.hpp:126
gko::batch_dim< 2 >
gko::batch::MultiVector::get_num_batch_items
size_type get_num_batch_items() const
Returns the number of batch items.
Definition: batch_multi_vector.hpp:116
gko::PolymorphicObject::get_executor
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition: polymorphic_object.hpp:234
gko::batch::BatchLinOp::get_size
const batch_dim< 2 > & get_size() const noexcept
Returns the size of the batch operator.
Definition: batch_lin_op.hpp:83
gko::EnablePolymorphicObject
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition: polymorphic_object.hpp:661