Ginkgo — generated from pipelines/1589998975 (branch based on develop). Ginkgo version 1.10.0
A numerical linear algebra library targeting many-core architectures
combination.hpp
1 // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_COMBINATION_HPP_
6 #define GKO_PUBLIC_CORE_BASE_COMBINATION_HPP_
7 
8 
9 #include <type_traits>
10 #include <vector>
11 
12 #include <ginkgo/core/base/lin_op.hpp>
13 
14 
15 namespace gko {
16 
17 
// Combination<ValueType>: a LinOp assembled from a list of
// (coefficient, operator) pairs. The tooltip further down this page
// describes it as a "linear combination of multiple linear operators
// c1 * ...". Presumably it represents c1 * op1 + c2 * op2 + ... — the
// arithmetic lives in the out-of-line apply_impl definitions, which are not
// part of this listing (TODO confirm there).
//
// NOTE(review): this listing is a Doxygen rendering. The original doc
// comments were stripped and some declarations are not shown. For example,
// the tooltip entry for operator=(const Combination&) has no matching line
// here. Compare against the real header before relying on this text.
30 template <typename ValueType = default_precision>
31 class Combination : public EnableLinOp<Combination<ValueType>>,
32  public EnableCreateMethod<Combination<ValueType>>,
33  public Transposable {
    // Lets the static create() supplied by EnableCreateMethod reach the
    // protected constructors declared below.
35  friend class EnableCreateMethod<Combination>;
36 
37 public:
38  using value_type = ValueType;
40 
    // Returns the coefficients, one per operator and in the same order as
    // get_operators(). add_operators() asserts that each coefficient is 1x1
    // and stores it on this object's executor.
46  const std::vector<std::shared_ptr<const LinOp>>& get_coefficients()
47  const noexcept
48  {
49  return coefficients_;
50  }
51 
    // Returns the operators, in the same order as get_coefficients().
    // add_operators() asserts that each operator has the same size as this
    // Combination and stores it on this object's executor.
57  const std::vector<std::shared_ptr<const LinOp>>& get_operators()
58  const noexcept
59  {
60  return operators_;
61  }
62 
    // Transposable interface; both functions are defined out of line.
63  std::unique_ptr<LinOp> transpose() const override;
64 
65  std::unique_ptr<LinOp> conj_transpose() const override;
66 
72 
80 
    // Copy constructor, defined out of line.
85  Combination(const Combination&);
86 
93 
94 protected:
    // Terminates the recursion of the variadic add_operators() below.
95  void add_operators() {}
96 
    // Appends one (coef, oper) pair, then recurses on the remaining
    // arguments.
    // - Arguments must come in coefficient/operator pairs. With an odd count,
    //   the last recursive call matches neither overload and compilation
    //   fails.
    // - Each remaining argument must be convertible to
    //   std::shared_ptr<const LinOp>.
    // - coef must be 1x1 and oper must match this->get_size(), so the size
    //   has to be set before the first call. Both constructors below call
    //   set_size() first.
97  template <typename... Rest>
98  void add_operators(std::shared_ptr<const LinOp> coef,
99  std::shared_ptr<const LinOp> oper, Rest&&... rest)
100  {
101  GKO_ASSERT_EQUAL_DIMENSIONS(coef, dim<2>(1, 1));
102  GKO_ASSERT_EQUAL_DIMENSIONS(oper, this->get_size());
103  auto exec = this->get_executor();
        // Store first. Then, if the input lives on another executor, replace
        // the stored entry with a clone on this object's executor, so that
        // every stored LinOp shares one executor.
104  coefficients_.push_back(std::move(coef));
105  operators_.push_back(std::move(oper));
106  if (coefficients_.back()->get_executor() != exec) {
107  coefficients_.back() = gko::clone(exec, coefficients_.back());
108  }
109  if (operators_.back()->get_executor() != exec) {
110  operators_.back() = gko::clone(exec, operators_.back());
111  }
112  add_operators(std::forward<Rest>(rest)...);
113  }
114 
    // Creates a Combination on exec without adding any coefficients or
    // operators.
120  explicit Combination(std::shared_ptr<const Executor> exec)
121  : EnableLinOp<Combination>(exec)
122  {}
123 
    // Builds a Combination from two parallel ranges: the coefficients
    // [coefficient_begin, coefficient_end) and the operators
    // [operator_begin, operator_end).
    // - The defaulted third template parameter removes this overload unless
    //   both iterator types expose std::iterator_traits<>::iterator_category.
    // - The executor and size are taken from the first operator.
    // - Throws OutOfBoundsError if the operator range is empty. This check
    //   runs first, inside the base-class initializer, because the executor
    //   is needed before the base can be constructed.
    // - GKO_ASSERT_EQ then checks that both ranges have the same length.
    // NOTE(review): each range is walked once by std::distance and again by
    // the loop, and *operator_begin is read twice. Single-pass input
    // iterators are therefore not supported in practice; forward iterators
    // (or stronger) are assumed.
138  template <
139  typename CoefficientIterator, typename OperatorIterator,
140  typename = std::void_t<
141  typename std::iterator_traits<
142  CoefficientIterator>::iterator_category,
143  typename std::iterator_traits<OperatorIterator>::iterator_category>>
144  explicit Combination(CoefficientIterator coefficient_begin,
145  CoefficientIterator coefficient_end,
146  OperatorIterator operator_begin,
147  OperatorIterator operator_end)
148  : EnableLinOp<Combination>([&] {
149  if (operator_begin == operator_end) {
150  throw OutOfBoundsError(__FILE__, __LINE__, 1, 0);
151  }
152  return (*operator_begin)->get_executor();
153  }())
154  {
155  GKO_ASSERT_EQ(std::distance(coefficient_begin, coefficient_end),
156  std::distance(operator_begin, operator_end));
157  this->set_size((*operator_begin)->get_size());
        // coefficient_it advances in lockstep with operator_it. The length
        // check above keeps it in range.
158  auto coefficient_it = coefficient_begin;
159  for (auto operator_it = operator_begin; operator_it != operator_end;
160  ++operator_it) {
161  add_operators(*coefficient_it, *operator_it);
162  ++coefficient_it;
163  }
164  }
165 
    // Builds a Combination from an interleaved argument list
    // coef1, oper1, coef2, oper2, ... (the pairing and size rules are those
    // of add_operators()).
    // - The executor and size are taken from the first operator.
    // - oper must be non-null: it is dereferenced in the
    //   delegating-constructor call, before any check runs.
    // - oper->get_size() is read before oper is moved from.
176  template <typename... Rest>
177  explicit Combination(std::shared_ptr<const LinOp> coef,
178  std::shared_ptr<const LinOp> oper, Rest&&... rest)
179  : Combination(oper->get_executor())
180  {
181  this->set_size(oper->get_size());
182  add_operators(std::move(coef), std::move(oper),
183  std::forward<Rest>(rest)...);
184  }
185 
    // LinOp apply overrides, without and with alpha/beta scaling. Both are
    // defined out of line.
186  void apply_impl(const LinOp* b, LinOp* x) const override;
187 
188  void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
189  LinOp* x) const override;
190 
191 private:
    // Parallel vectors: coefficients_[i] pairs with operators_[i], because
    // add_operators() always pushes one entry to each.
192  std::vector<std::shared_ptr<const LinOp>> coefficients_;
193  std::vector<std::shared_ptr<const LinOp>> operators_;
194 
    // Scratch objects kept across calls. The struct is `mutable` so that
    // const member functions (presumably apply_impl) can modify it. The
    // member names suggest cached scalar constants and an intermediate
    // result — TODO confirm in the implementation file.
    //
    // The copy constructor and copy assignment deliberately copy nothing, so
    // a copied Combination starts with an empty cache. Because the copy
    // operations and the destructor are user-declared, no implicit move
    // operations exist. Moves therefore fall back to these copy operations
    // and also yield an empty cache.
195  // TODO: solve race conditions when multithreading
196  mutable struct cache_struct {
197  cache_struct() = default;
198  ~cache_struct() = default;
199  cache_struct(const cache_struct& other) {}
200  cache_struct& operator=(const cache_struct& other) { return *this; }
201 
202  std::unique_ptr<LinOp> zero;
203  std::unique_ptr<LinOp> one;
204  std::unique_ptr<LinOp> intermediate_x;
205  } cache_;
206 };
207 
208 
209 } // namespace gko
210 
211 
212 #endif // GKO_PUBLIC_CORE_BASE_COMBINATION_HPP_
gko::Combination
The Combination class can be used to construct a linear combination of multiple linear operators c1 * op1 + c2 * op2 + ... + ck * opk.
Definition: combination.hpp:31
gko::LinOp
Definition: lin_op.hpp:117
gko::EnableCreateMethod
This mixin implements a static create() method on ConcreteType that dynamically allocates the memory,...
Definition: polymorphic_object.hpp:767
gko::Combination::transpose
std::unique_ptr< LinOp > transpose() const override
Returns a LinOp representing the transpose of the Transposable object.
gko::Transposable
Linear operators which support transposition should implement the Transposable interface.
Definition: lin_op.hpp:433
gko::clone
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition: utils_helper.hpp:173
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::Combination::get_coefficients
const std::vector< std::shared_ptr< const LinOp > > & get_coefficients() const noexcept
Returns a list of coefficients of the combination.
Definition: combination.hpp:46
gko::Combination::Combination
Combination(const Combination &)
Copy-constructs a Combination.
gko::Combination::operator=
Combination & operator=(const Combination &)
Copy-assigns a Combination.
gko::Combination::conj_transpose
std::unique_ptr< LinOp > conj_transpose() const override
Returns a LinOp representing the conjugate transpose of the Transposable object.
gko::PolymorphicObject::get_executor
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition: polymorphic_object.hpp:243
gko::LinOp::get_size
const dim< 2 > & get_size() const noexcept
Returns the size of the operator.
Definition: lin_op.hpp:210
gko::Combination::get_operators
const std::vector< std::shared_ptr< const LinOp > > & get_operators() const noexcept
Returns a list of operators of the combination.
Definition: combination.hpp:57
gko::EnableLinOp
The EnableLinOp mixin can be used to provide sensible default implementations of the majority of the ...
Definition: lin_op.hpp:877
gko::LinOp::LinOp
LinOp(const LinOp &)=default
Copy-constructs a LinOp.
gko::zero
constexpr T zero()
Returns the additive identity for T.
Definition: math.hpp:602
gko::one
constexpr T one()
Returns the multiplicative identity for T.
Definition: math.hpp:630
gko::EnablePolymorphicObject
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition: polymorphic_object.hpp:667