Ginkgo  Generated from pipelines/2457497073 branch based on develop. Ginkgo version 2.0.0
A numerical linear algebra library targeting many-core architectures
batch_solver_base.hpp
1 // SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
6 #define GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
7 
8 
9 #include <ginkgo/core/base/abstract_factory.hpp>
10 #include <ginkgo/core/base/batch_lin_op.hpp>
11 #include <ginkgo/core/base/batch_multi_vector.hpp>
12 #include <ginkgo/core/base/utils_helper.hpp>
13 #include <ginkgo/core/log/batch_logger.hpp>
14 #include <ginkgo/core/matrix/batch_identity.hpp>
15 #include <ginkgo/core/stop/batch_stop_enum.hpp>
16 
17 
18 namespace gko {
19 namespace batch {
20 namespace solver {
21 
22 
29 class BatchSolver {
30 public:
36  std::shared_ptr<const BatchLinOp> get_system_matrix() const
37  {
38  return this->system_matrix_;
39  }
40 
46  std::shared_ptr<const BatchLinOp> get_preconditioner() const
47  {
48  return this->preconditioner_;
49  }
50 
56  double get_tolerance() const { return this->residual_tol_; }
57 
64  void reset_tolerance(double res_tol)
65  {
66  if (res_tol < 0) {
67  GKO_INVALID_STATE("Tolerance cannot be negative!");
68  }
69  this->residual_tol_ = res_tol;
70  }
71 
77  int get_max_iterations() const { return this->max_iterations_; }
78 
85  void reset_max_iterations(int max_iterations)
86  {
87  if (max_iterations < 0) {
88  GKO_INVALID_STATE("Max iterations cannot be negative!");
89  }
90  this->max_iterations_ = max_iterations;
91  }
92 
98  ::gko::batch::stop::tolerance_type get_tolerance_type() const
99  {
100  return this->tol_type_;
101  }
102 
108  void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
109  {
110  if (tol_type == ::gko::batch::stop::tolerance_type::absolute ||
111  tol_type == ::gko::batch::stop::tolerance_type::relative) {
112  this->tol_type_ = tol_type;
113  } else {
114  GKO_INVALID_STATE("Invalid tolerance type specified!");
115  }
116  }
117 
118 protected:
119  BatchSolver() {}
120 
121  BatchSolver(std::shared_ptr<const BatchLinOp> system_matrix,
122  std::shared_ptr<const BatchLinOp> gen_preconditioner,
123  const double res_tol, const int max_iterations,
124  const ::gko::batch::stop::tolerance_type tol_type)
125  : system_matrix_{std::move(system_matrix)},
126  preconditioner_{std::move(gen_preconditioner)},
127  residual_tol_{res_tol},
128  max_iterations_{max_iterations},
129  tol_type_{tol_type},
130  workspace_{}
131  {}
132 
133  void set_system_matrix_base(std::shared_ptr<const BatchLinOp> system_matrix)
134  {
135  this->system_matrix_ = std::move(system_matrix);
136  }
137 
138  void set_preconditioner_base(std::shared_ptr<const BatchLinOp> precond)
139  {
140  this->preconditioner_ = std::move(precond);
141  }
142 
143  std::shared_ptr<const BatchLinOp> system_matrix_{};
144  std::shared_ptr<const BatchLinOp> preconditioner_{};
145  double residual_tol_{};
146  int max_iterations_{};
147  ::gko::batch::stop::tolerance_type tol_type_{};
148  mutable array<unsigned char> workspace_{};
149 };
150 
151 
152 template <typename Parameters, typename Factory>
154  : enable_parameters_type<Parameters, Factory> {
162 
170 
175  ::gko::batch::stop::tolerance_type GKO_FACTORY_PARAMETER_SCALAR(
176  tolerance_type, ::gko::batch::stop::tolerance_type::absolute);
177 
182  std::shared_ptr<const BatchLinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
184 
189  std::shared_ptr<const BatchLinOp> GKO_FACTORY_PARAMETER_SCALAR(
191 };
192 
193 
202 template <typename ConcreteSolver, typename ValueType,
203  typename PolymorphicBase = BatchLinOp>
205  : public BatchSolver,
206  public EnableBatchLinOp<ConcreteSolver, PolymorphicBase> {
207 public:
208  using real_type = remove_complex<ValueType>;
209 
210  void apply(ptr_param<const MultiVector<ValueType>> b,
212  {
213  this->validate_application_parameters(b.get(), x.get());
214  auto exec = this->get_executor();
215  this->apply_impl(make_temporary_clone(exec, b).get(),
216  make_temporary_clone(exec, x).get());
217  }
218 
219  void apply(ptr_param<const MultiVector<ValueType>> alpha,
221  ptr_param<const MultiVector<ValueType>> beta,
223  {
224  this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
225  x.get());
226  auto exec = this->get_executor();
227  this->apply_impl(make_temporary_clone(exec, alpha).get(),
228  make_temporary_clone(exec, b).get(),
229  make_temporary_clone(exec, beta).get(),
230  make_temporary_clone(exec, x).get());
231  }
232 
233 protected:
234  GKO_ENABLE_SELF(ConcreteSolver);
235 
236  explicit EnableBatchSolver(std::shared_ptr<const Executor> exec)
238  {}
239 
240  template <typename FactoryParameters>
241  explicit EnableBatchSolver(std::shared_ptr<const Executor> exec,
242  std::shared_ptr<const BatchLinOp> system_matrix,
243  const FactoryParameters& params)
244  : BatchSolver(system_matrix, nullptr, params.tolerance,
245  params.max_iterations, params.tolerance_type),
247  exec, gko::transpose(system_matrix->get_size()))
248  {
249  GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(system_matrix_);
250 
251  using value_type = typename ConcreteSolver::value_type;
252  using Identity = matrix::Identity<value_type>;
253  using real_type = remove_complex<value_type>;
254 
255  if (params.generated_preconditioner) {
256  GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(params.generated_preconditioner,
257  this);
258  preconditioner_ = std::move(params.generated_preconditioner);
259  } else if (params.preconditioner) {
260  preconditioner_ = params.preconditioner->generate(system_matrix_);
261  } else {
262  auto id = Identity::create(exec, system_matrix->get_size());
263  preconditioner_ = std::move(id);
264  }
265  // We use a workspace here to store the logger data (iteration count
266  // and solver residual), and require a minimum size of
267  // `sizeof(real_type)+ sizeof(int)`
268  const size_type workspace_size =
269  system_matrix->get_num_batch_items() * 32;
270  workspace_.set_executor(exec);
271  workspace_.resize_and_reset(workspace_size);
272  }
273 
274  void set_system_matrix(std::shared_ptr<const BatchLinOp> new_system_matrix)
275  {
276  auto exec = self()->get_executor();
277  if (new_system_matrix) {
278  GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(), new_system_matrix);
279  GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_system_matrix);
280  if (new_system_matrix->get_executor() != exec) {
281  new_system_matrix = gko::clone(exec, new_system_matrix);
282  }
283  }
284  this->set_system_matrix_base(new_system_matrix);
285  }
286 
287  void set_preconditioner(std::shared_ptr<const BatchLinOp> new_precond)
288  {
289  auto exec = self()->get_executor();
290  if (new_precond) {
291  GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(), new_precond);
292  GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_precond);
293  if (new_precond->get_executor() != exec) {
294  new_precond = gko::clone(exec, new_precond);
295  }
296  }
297  this->set_preconditioner_base(new_precond);
298  }
299 
300  EnableBatchSolver& operator=(const EnableBatchSolver& other)
301  {
302  if (&other != this) {
303  this->set_size(other.get_size());
304  this->set_system_matrix(other.get_system_matrix());
305  this->set_preconditioner(other.get_preconditioner());
306  this->reset_tolerance(other.get_tolerance());
309  }
310 
311  return *this;
312  }
313 
314  EnableBatchSolver& operator=(EnableBatchSolver&& other)
315  {
316  if (&other != this) {
317  this->set_size(other.get_size());
318  this->set_system_matrix(other.get_system_matrix());
319  this->set_preconditioner(other.get_preconditioner());
320  this->reset_tolerance(other.get_tolerance());
323  other.set_system_matrix(nullptr);
324  other.set_preconditioner(nullptr);
325  }
326  return *this;
327  }
328 
331  other.self()->get_executor(), other.self()->get_size())
332  {
333  *this = other;
334  }
335 
338  other.self()->get_executor(), other.self()->get_size())
339  {
340  *this = std::move(other);
341  }
342 
343  void apply_impl(const MultiVector<ValueType>* b,
344  MultiVector<ValueType>* x) const
345  {
346  auto exec = this->get_executor();
347  if (b->get_common_size()[1] > 1) {
348  GKO_NOT_IMPLEMENTED;
349  }
350  auto workspace_view = workspace_.as_view();
351  auto log_data_ = std::make_unique<log::detail::log_data<real_type>>(
352  exec, b->get_num_batch_items(), workspace_view);
353 
354  this->solver_apply(b, x, log_data_.get());
355 
356  this->template log<gko::log::Logger::batch_solver_completed>(
357  log_data_->iter_counts, log_data_->res_norms);
358  }
359 
360  void apply_impl(const MultiVector<ValueType>* alpha,
361  const MultiVector<ValueType>* b,
362  const MultiVector<ValueType>* beta,
363  MultiVector<ValueType>* x) const
364  {
365  auto x_clone = x->clone();
366  this->apply(b, x_clone.get());
367  x->scale(beta);
368  x->add_scaled(alpha, x_clone.get());
369  }
370 
371  virtual void solver_apply(const MultiVector<ValueType>* b,
373  log::detail::log_data<real_type>* info) const = 0;
374 };
375 
376 
377 } // namespace solver
378 } // namespace batch
379 } // namespace gko
380 
381 
382 #endif // GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
gko::batch::EnableBatchLinOp
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of...
Definition: batch_lin_op.hpp:250
gko::batch::MultiVector::scale
void scale(ptr_param< const MultiVector< ValueType >> alpha)
Scales the vector with a scalar (aka: BLAS scal).
gko::log::profile_event_category::solver
Solver events.
gko::batch::MultiVector::add_scaled
void add_scaled(ptr_param< const MultiVector< ValueType >> alpha, ptr_param< const MultiVector< ValueType >> b)
Adds b scaled by alpha to the vector (aka: BLAS axpy).
gko::batch::BatchLinOp
Definition: batch_lin_op.hpp:59
gko::size_type
std::size_t size_type
Integral type used for allocation quantities.
Definition: types.hpp:90
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::tolerance_type
::gko::batch::stop::tolerance_type tolerance_type
To specify which type of tolerance check is to be considered, absolute or relative (to the rhs l2 nor...
Definition: batch_solver_base.hpp:176
gko::batch::MultiVector
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition: batch_multi_vector.hpp:52
GKO_FACTORY_PARAMETER_SCALAR
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition: abstract_factory.hpp:445
gko::clone
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition: utils_helper.hpp:173
gko::batch::solver::EnableBatchSolver
This mixin provides apply and common iterative solver functionality to all the batched solvers.
Definition: batch_solver_base.hpp:204
gko::batch::solver::BatchSolver::reset_max_iterations
void reset_max_iterations(int max_iterations)
Set the maximum number of iterations for the solver to use, independent of the factory that created i...
Definition: batch_solver_base.hpp:85
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::tolerance
double tolerance
Default residual tolerance.
Definition: batch_solver_base.hpp:169
gko::batch::solver::BatchSolver
The BatchSolver is a base class for all batched solvers and provides the common getters and setter fo...
Definition: batch_solver_base.hpp:29
gko::batch::solver::BatchSolver::get_tolerance_type
::gko::batch::stop::tolerance_type get_tolerance_type() const
Get the tolerance type.
Definition: batch_solver_base.hpp:98
gko::batch::MultiVector::get_common_size
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition: batch_multi_vector.hpp:155
gko::batch::matrix::Identity
The batch Identity matrix, which represents a batch of Identity matrices.
Definition: batch_identity.hpp:32
gko::batch::MultiVector::get_num_batch_items
size_type get_num_batch_items() const
Returns the number of batch items.
Definition: batch_multi_vector.hpp:145
gko::ptr_param
This class is used for function parameters in the place of raw pointers.
Definition: utils_helper.hpp:41
gko::transpose
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition: batch_dim.hpp:119
gko::batch::solver::BatchSolver::get_preconditioner
std::shared_ptr< const BatchLinOp > get_preconditioner() const
Returns the generated preconditioner.
Definition: batch_solver_base.hpp:46
gko::batch::solver::BatchSolver::get_tolerance
double get_tolerance() const
Get the residual tolerance used by the solver.
Definition: batch_solver_base.hpp:56
gko::batch::solver::BatchSolver::reset_tolerance_type
void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
Set the type of tolerance check to use inside the solver.
Definition: batch_solver_base.hpp:108
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::preconditioner
std::shared_ptr< const BatchLinOpFactory > preconditioner
The preconditioner to be used by the iterative solver.
Definition: batch_solver_base.hpp:183
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::max_iterations
int max_iterations
Default maximum number of iterations allowed.
Definition: batch_solver_base.hpp:161
gko::batch::solver::BatchSolver::reset_tolerance
void reset_tolerance(double res_tol)
Update the residual tolerance to be used by the solver.
Definition: batch_solver_base.hpp:64
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters
Definition: batch_solver_base.hpp:153
gko::make_temporary_clone
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Creates a temporary_clone.
Definition: temporary_clone.hpp:208
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::generated_preconditioner
std::shared_ptr< const BatchLinOp > generated_preconditioner
Already generated preconditioner.
Definition: batch_solver_base.hpp:190
gko::batch::solver::BatchSolver::get_max_iterations
int get_max_iterations() const
Get the maximum number of iterations set on the solver.
Definition: batch_solver_base.hpp:77
gko::batch::solver::BatchSolver::get_system_matrix
std::shared_ptr< const BatchLinOp > get_system_matrix() const
Returns the system operator (matrix) of the linear system.
Definition: batch_solver_base.hpp:36
gko::enable_parameters_type
The enable_parameters_type mixin is used to create a base implementation of the factory parameters st...
Definition: abstract_factory.hpp:211
gko::remove_complex
typename detail::remove_complex_s< T >::type remove_complex
Obtain the type which removed the complex of complex/scalar type or the template parameter of class b...
Definition: math.hpp:264