Ginkgo — generated from pipelines/2171896597 (branch based on develop). Ginkgo version 1.11.0
A numerical linear algebra library targeting many-core architectures
combination.hpp
1 // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_COMBINATION_HPP_
6 #define GKO_PUBLIC_CORE_BASE_COMBINATION_HPP_
7 
8 
#include <iterator>
#include <memory>
9 #include <type_traits>
10 #include <vector>
11 
12 #include <ginkgo/core/base/lin_op.hpp>
13 
14 
15 namespace gko {
16 
17 
/**
 * Combination represents a linear combination of linear operators,
 * c1 * op1 + c2 * op2 + ... + ck * opk.
 *
 * Coefficients and operators are kept in two parallel lists: the i-th
 * coefficient belongs to the i-th operator. add_operators() checks that
 * every coefficient is a 1x1 LinOp and that every operator has the same size
 * as the Combination, and clones any coefficient/operator whose executor
 * differs from the Combination's executor onto that executor.
 *
 * @tparam ValueType  value type of the combination (defaults to
 *                    default_precision; NOTE(review): presumably restricted
 *                    to Ginkgo's supported value types by
 *                    GKO_ASSERT_SUPPORTED_VALUE_TYPE)
 */
30 template <typename ValueType = default_precision>
31 class Combination : public EnableLinOp<Combination<ValueType>>,
32  public EnableCreateMethod<Combination<ValueType>>,
33  public Transposable {
    // Befriended so that EnableCreateMethod's static create() can reach the
    // protected constructors declared below.
35  friend class EnableCreateMethod<Combination>;
    // NOTE(review): macro defined elsewhere; presumably a compile-time check
    // on ValueType — confirm in the macro's definition.
36  GKO_ASSERT_SUPPORTED_VALUE_TYPE;
37 
38 public:
39  using value_type = ValueType;
41 
    /**
     * Returns the coefficients of the combination.
     *
     * The i-th entry pairs with the i-th entry of get_operators().
     * add_operators() ensures each coefficient it stores is 1x1 and lives on
     * this Combination's executor.
     *
     * @return a const reference to the internal list of coefficients
     */
47  const std::vector<std::shared_ptr<const LinOp>>& get_coefficients()
48  const noexcept
49  {
50  return coefficients_;
51  }
52 
    /**
     * Returns the operators of the combination.
     *
     * The i-th entry pairs with the i-th entry of get_coefficients().
     * add_operators() ensures each operator it stores has the same size as
     * this Combination and lives on this Combination's executor.
     *
     * @return a const reference to the internal list of operators
     */
58  const std::vector<std::shared_ptr<const LinOp>>& get_operators()
59  const noexcept
60  {
61  return operators_;
62  }
63 
    /**
     * Returns a LinOp representing the transpose of this combination
     * (Transposable interface; defined out of line).
     */
64  std::unique_ptr<LinOp> transpose() const override;
65 
    /**
     * Returns a LinOp representing the conjugate transpose of this
     * combination (Transposable interface; defined out of line).
     */
66  std::unique_ptr<LinOp> conj_transpose() const override;
67 
73 
81 
    /**
     * Copy-constructs a Combination (defined out of line).
     */
86  Combination(const Combination&);
87 
94 
95 protected:
    /**
     * Terminates the recursion of the variadic add_operators() below once
     * every (coefficient, operator) pair has been consumed.
     */
96  void add_operators() {}
97 
    /**
     * Appends one (coefficient, operator) pair, then recurses on the
     * remaining arguments.
     *
     * The remaining arguments are consumed two at a time, so they must also
     * form (coefficient, operator) pairs.
     *
     * @param coef  scaling coefficient; must be 1x1
     * @param oper  operator; must have the same size as this Combination, so
     *              the size has to be set before this is called
     * @param rest  further coefficient/operator pairs
     */
98  template <typename... Rest>
99  void add_operators(std::shared_ptr<const LinOp> coef,
100  std::shared_ptr<const LinOp> oper, Rest&&... rest)
101  {
102  GKO_ASSERT_EQUAL_DIMENSIONS(coef, dim<2>(1, 1));
103  GKO_ASSERT_EQUAL_DIMENSIONS(oper, this->get_size());
104  auto exec = this->get_executor();
105  coefficients_.push_back(std::move(coef));
106  operators_.push_back(std::move(oper));
    // Keep every stored LinOp on this Combination's executor; a clone is made
    // only when the incoming object lives on a different executor, otherwise
    // ownership is simply shared.
107  if (coefficients_.back()->get_executor() != exec) {
108  coefficients_.back() = gko::clone(exec, coefficients_.back());
109  }
110  if (operators_.back()->get_executor() != exec) {
111  operators_.back() = gko::clone(exec, operators_.back());
112  }
113  add_operators(std::forward<Rest>(rest)...);
114  }
115 
    /**
     * Creates a Combination on `exec` without any coefficients or operators.
     *
     * NOTE(review): the size is whatever EnableLinOp's executor-only
     * constructor sets (presumably 0x0) — confirm in lin_op.hpp.
     *
     * @param exec  executor of the Combination
     */
121  explicit Combination(std::shared_ptr<const Executor> exec)
122  : EnableLinOp<Combination>(exec)
123  {}
124 
    /**
     * Creates a Combination from two ranges of equal length, pairing the
     * i-th coefficient with the i-th operator.
     *
     * The executor and the size of the Combination are taken from the first
     * operator. This overload only participates in overload resolution when
     * std::iterator_traits provides an iterator_category for both iterator
     * types.
     *
     * @tparam CoefficientIterator  iterator whose elements are passed to
     *                              add_operators() as coefficients
     * @tparam OperatorIterator     iterator whose elements are passed to
     *                              add_operators() as operators
     *
     * @param coefficient_begin  begin of the coefficient range
     * @param coefficient_end    end of the coefficient range
     * @param operator_begin     begin of the operator range
     * @param operator_end       end of the operator range
     *
     * @throws OutOfBoundsError  if the operator range is empty
     *
     * The two range lengths are compared with GKO_ASSERT_EQ before anything
     * is added.
     *
     * NOTE(review): both ranges are traversed once by std::distance and then
     * again by the loop, so multi-pass (forward) iterators are required in
     * practice; single-pass input iterators would be exhausted.
     */
139  template <
140  typename CoefficientIterator, typename OperatorIterator,
141  typename = std::void_t<
142  typename std::iterator_traits<
143  CoefficientIterator>::iterator_category,
144  typename std::iterator_traits<OperatorIterator>::iterator_category>>
145  explicit Combination(CoefficientIterator coefficient_begin,
146  CoefficientIterator coefficient_end,
147  OperatorIterator operator_begin,
148  OperatorIterator operator_end)
149  : EnableLinOp<Combination>([&] {
    // The base class needs an executor before the constructor body runs,
    // so it is taken from the first operator; an empty operator range is
    // rejected first because there is nothing to take it from.
150  if (operator_begin == operator_end) {
151  throw OutOfBoundsError(__FILE__, __LINE__, 1, 0);
152  }
153  return (*operator_begin)->get_executor();
154  }())
155  {
156  GKO_ASSERT_EQ(std::distance(coefficient_begin, coefficient_end),
157  std::distance(operator_begin, operator_end));
    // The size must be set before add_operators(), which checks every
    // operator against it.
158  this->set_size((*operator_begin)->get_size());
159  auto coefficient_it = coefficient_begin;
160  for (auto operator_it = operator_begin; operator_it != operator_end;
161  ++operator_it) {
162  add_operators(*coefficient_it, *operator_it);
163  ++coefficient_it;
164  }
165  }
166 
    /**
     * Creates a Combination from alternating coefficients and operators:
     * coef, oper[, coef2, oper2, ...].
     *
     * The executor and the size are taken from the first operator. `oper` is
     * therefore dereferenced before any validation and must not be null.
     *
     * @param coef  first coefficient (must be 1x1)
     * @param oper  first operator
     * @param rest  further coefficient/operator pairs
     */
177  template <typename... Rest>
178  explicit Combination(std::shared_ptr<const LinOp> coef,
179  std::shared_ptr<const LinOp> oper, Rest&&... rest)
180  : Combination(oper->get_executor())
181  {
    // Size first: add_operators() validates each operator against it.
182  this->set_size(oper->get_size());
183  add_operators(std::move(coef), std::move(oper),
184  std::forward<Rest>(rest)...);
185  }
186 
    // Overrides of the LinOp apply hooks; they are defined out of line and
    // their implementation is not visible here.
187  void apply_impl(const LinOp* b, LinOp* x) const override;
188 
189  void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
190  LinOp* x) const override;
191 
192 private:
    // Parallel lists: coefficients_[i] belongs to operators_[i].
    // add_operators() appends to both in lockstep.
193  std::vector<std::shared_ptr<const LinOp>> coefficients_;
194  std::vector<std::shared_ptr<const LinOp>> operators_;
195 
    // Scratch storage. It is mutable so that const member functions can
    // modify it; NOTE(review): presumably filled lazily by the const
    // apply_impl overrides (their definitions are not visible here).
    //
    // Copying never transfers the cache: the copy constructor leaves all
    // members default-constructed (null), and copy assignment leaves the
    // destination's cache untouched. Because copy operations and the
    // destructor are user-declared, no implicit move operations exist, so
    // moves fall back to these copies as well.
196  // TODO: solve race conditions when multithreading
197  mutable struct cache_struct {
198  cache_struct() = default;
199  ~cache_struct() = default;
200  cache_struct(const cache_struct& other) {}
201  cache_struct& operator=(const cache_struct& other) { return *this; }
202 
203  std::unique_ptr<LinOp> zero;
204  std::unique_ptr<LinOp> one;
205  std::unique_ptr<LinOp> intermediate_x;
206  } cache_;
207 };
208 
209 
210 } // namespace gko
211 
212 
213 #endif // GKO_PUBLIC_CORE_BASE_COMBINATION_HPP_
gko::Combination
The Combination class can be used to construct a linear combination of multiple linear operators $c_1 \cdot \mathrm{op}_1 + c_2 \cdot \mathrm{op}_2 + \dots + c_k \cdot \mathrm{op}_k$.
Definition: combination.hpp:31
gko::LinOp
Definition: lin_op.hpp:117
gko::EnableCreateMethod
This mixin implements a static create() method on ConcreteType that dynamically allocates the memory,...
Definition: polymorphic_object.hpp:767
gko::Combination::transpose
std::unique_ptr< LinOp > transpose() const override
Returns a LinOp representing the transpose of the Transposable object.
gko::Transposable
Linear operators which support transposition should implement the Transposable interface.
Definition: lin_op.hpp:433
gko::clone
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition: utils_helper.hpp:173
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::Combination::get_coefficients
const std::vector< std::shared_ptr< const LinOp > > & get_coefficients() const noexcept
Returns a list of coefficients of the combination.
Definition: combination.hpp:47
gko::Combination::Combination
Combination(const Combination &)
Copy-constructs a Combination.
gko::Combination::operator=
Combination & operator=(const Combination &)
Copy-assigns a Combination.
gko::Combination::conj_transpose
std::unique_ptr< LinOp > conj_transpose() const override
Returns a LinOp representing the conjugate transpose of the Transposable object.
gko::PolymorphicObject::get_executor
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition: polymorphic_object.hpp:243
gko::LinOp::get_size
const dim< 2 > & get_size() const noexcept
Returns the size of the operator.
Definition: lin_op.hpp:210
gko::Combination::get_operators
const std::vector< std::shared_ptr< const LinOp > > & get_operators() const noexcept
Returns a list of operators of the combination.
Definition: combination.hpp:58
gko::EnableLinOp
The EnableLinOp mixin can be used to provide sensible default implementations of the majority of the ...
Definition: lin_op.hpp:877
gko::LinOp::LinOp
LinOp(const LinOp &)=default
Copy-constructs a LinOp.
gko::zero
constexpr T zero()
Returns the additive identity for T.
Definition: math.hpp:626
gko::one
constexpr T one()
Returns the multiplicative identity for T.
Definition: math.hpp:654
gko::EnablePolymorphicObject
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition: polymorphic_object.hpp:667