Ginkgo: Generated from pipelines/1478841010 (branch based on develop). Ginkgo version 1.9.0
A numerical linear algebra library targeting many-core architectures
batch_solver_base.hpp
1 // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
6 #define GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
7 
8 
9 #include <ginkgo/core/base/abstract_factory.hpp>
10 #include <ginkgo/core/base/batch_lin_op.hpp>
11 #include <ginkgo/core/base/batch_multi_vector.hpp>
12 #include <ginkgo/core/base/utils_helper.hpp>
13 #include <ginkgo/core/log/batch_logger.hpp>
14 #include <ginkgo/core/matrix/batch_identity.hpp>
15 #include <ginkgo/core/stop/batch_stop_enum.hpp>
16 
17 
18 namespace gko {
19 namespace batch {
20 namespace solver {
21 
22 
29 class BatchSolver {
30 public:
36  std::shared_ptr<const BatchLinOp> get_system_matrix() const
37  {
38  return this->system_matrix_;
39  }
40 
46  std::shared_ptr<const BatchLinOp> get_preconditioner() const
47  {
48  return this->preconditioner_;
49  }
50 
56  double get_tolerance() const { return this->residual_tol_; }
57 
64  void reset_tolerance(double res_tol)
65  {
66  if (res_tol < 0) {
67  GKO_INVALID_STATE("Tolerance cannot be negative!");
68  }
69  this->residual_tol_ = res_tol;
70  }
71 
77  int get_max_iterations() const { return this->max_iterations_; }
78 
85  void reset_max_iterations(int max_iterations)
86  {
87  if (max_iterations < 0) {
88  GKO_INVALID_STATE("Max iterations cannot be negative!");
89  }
90  this->max_iterations_ = max_iterations;
91  }
92 
98  ::gko::batch::stop::tolerance_type get_tolerance_type() const
99  {
100  return this->tol_type_;
101  }
102 
108  void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
109  {
110  if (tol_type == ::gko::batch::stop::tolerance_type::absolute ||
111  tol_type == ::gko::batch::stop::tolerance_type::relative) {
112  this->tol_type_ = tol_type;
113  } else {
114  GKO_INVALID_STATE("Invalid tolerance type specified!");
115  }
116  }
117 
118 protected:
119  BatchSolver() {}
120 
121  BatchSolver(std::shared_ptr<const BatchLinOp> system_matrix,
122  std::shared_ptr<const BatchLinOp> gen_preconditioner,
123  const double res_tol, const int max_iterations,
124  const ::gko::batch::stop::tolerance_type tol_type)
125  : system_matrix_{std::move(system_matrix)},
126  preconditioner_{std::move(gen_preconditioner)},
127  residual_tol_{res_tol},
128  max_iterations_{max_iterations},
129  tol_type_{tol_type},
130  workspace_{}
131  {}
132 
133  void set_system_matrix_base(std::shared_ptr<const BatchLinOp> system_matrix)
134  {
135  this->system_matrix_ = std::move(system_matrix);
136  }
137 
138  void set_preconditioner_base(std::shared_ptr<const BatchLinOp> precond)
139  {
140  this->preconditioner_ = std::move(precond);
141  }
142 
143  std::shared_ptr<const BatchLinOp> system_matrix_{};
144  std::shared_ptr<const BatchLinOp> preconditioner_{};
145  double residual_tol_{};
146  int max_iterations_{};
147  ::gko::batch::stop::tolerance_type tol_type_{};
148  mutable array<unsigned char> workspace_{};
149 };
150 
151 
152 template <typename Parameters, typename Factory>
154  : enable_parameters_type<Parameters, Factory> {
162 
170 
175  ::gko::batch::stop::tolerance_type GKO_FACTORY_PARAMETER_SCALAR(
176  tolerance_type, ::gko::batch::stop::tolerance_type::absolute);
177 
182  std::shared_ptr<const BatchLinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
184 
189  std::shared_ptr<const BatchLinOp> GKO_FACTORY_PARAMETER_SCALAR(
191 };
192 
193 
202 template <typename ConcreteSolver, typename ValueType,
203  typename PolymorphicBase = BatchLinOp>
205  : public BatchSolver,
206  public EnableBatchLinOp<ConcreteSolver, PolymorphicBase> {
207 public:
208  using real_type = remove_complex<ValueType>;
209 
210  const ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> b,
212  {
213  this->validate_application_parameters(b.get(), x.get());
214  auto exec = this->get_executor();
215  this->apply_impl(make_temporary_clone(exec, b).get(),
216  make_temporary_clone(exec, x).get());
217  return self();
218  }
219 
220  const ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> alpha,
222  ptr_param<const MultiVector<ValueType>> beta,
224  {
225  this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
226  x.get());
227  auto exec = this->get_executor();
228  this->apply_impl(make_temporary_clone(exec, alpha).get(),
229  make_temporary_clone(exec, b).get(),
230  make_temporary_clone(exec, beta).get(),
231  make_temporary_clone(exec, x).get());
232  return self();
233  }
234 
235  ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> b,
237  {
238  this->validate_application_parameters(b.get(), x.get());
239  auto exec = this->get_executor();
240  this->apply_impl(make_temporary_clone(exec, b).get(),
241  make_temporary_clone(exec, x).get());
242  return self();
243  }
244 
245  ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> alpha,
247  ptr_param<const MultiVector<ValueType>> beta,
249  {
250  this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
251  x.get());
252  auto exec = this->get_executor();
253  this->apply_impl(make_temporary_clone(exec, alpha).get(),
254  make_temporary_clone(exec, b).get(),
255  make_temporary_clone(exec, beta).get(),
256  make_temporary_clone(exec, x).get());
257  return self();
258  }
259 
260 protected:
261  GKO_ENABLE_SELF(ConcreteSolver);
262 
263  explicit EnableBatchSolver(std::shared_ptr<const Executor> exec)
265  {}
266 
267  template <typename FactoryParameters>
268  explicit EnableBatchSolver(std::shared_ptr<const Executor> exec,
269  std::shared_ptr<const BatchLinOp> system_matrix,
270  const FactoryParameters& params)
271  : BatchSolver(system_matrix, nullptr, params.tolerance,
272  params.max_iterations, params.tolerance_type),
274  exec, gko::transpose(system_matrix->get_size()))
275  {
276  GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(system_matrix_);
277 
278  using value_type = typename ConcreteSolver::value_type;
279  using Identity = matrix::Identity<value_type>;
280  using real_type = remove_complex<value_type>;
281 
282  if (params.generated_preconditioner) {
283  GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(params.generated_preconditioner,
284  this);
285  preconditioner_ = std::move(params.generated_preconditioner);
286  } else if (params.preconditioner) {
287  preconditioner_ = params.preconditioner->generate(system_matrix_);
288  } else {
289  auto id = Identity::create(exec, system_matrix->get_size());
290  preconditioner_ = std::move(id);
291  }
292  // We use a workspace here to store the logger data (iteration count
293  // and solver residual), and require a minimum size of
294  // `sizeof(real_type)+ sizeof(int)`
295  const size_type workspace_size =
296  system_matrix->get_num_batch_items() * 32;
297  workspace_.set_executor(exec);
298  workspace_.resize_and_reset(workspace_size);
299  }
300 
301  void set_system_matrix(std::shared_ptr<const BatchLinOp> new_system_matrix)
302  {
303  auto exec = self()->get_executor();
304  if (new_system_matrix) {
305  GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(), new_system_matrix);
306  GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_system_matrix);
307  if (new_system_matrix->get_executor() != exec) {
308  new_system_matrix = gko::clone(exec, new_system_matrix);
309  }
310  }
311  this->set_system_matrix_base(new_system_matrix);
312  }
313 
314  void set_preconditioner(std::shared_ptr<const BatchLinOp> new_precond)
315  {
316  auto exec = self()->get_executor();
317  if (new_precond) {
318  GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(), new_precond);
319  GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_precond);
320  if (new_precond->get_executor() != exec) {
321  new_precond = gko::clone(exec, new_precond);
322  }
323  }
324  this->set_preconditioner_base(new_precond);
325  }
326 
327  EnableBatchSolver& operator=(const EnableBatchSolver& other)
328  {
329  if (&other != this) {
330  this->set_size(other.get_size());
331  this->set_system_matrix(other.get_system_matrix());
332  this->set_preconditioner(other.get_preconditioner());
333  this->reset_tolerance(other.get_tolerance());
336  }
337 
338  return *this;
339  }
340 
341  EnableBatchSolver& operator=(EnableBatchSolver&& other)
342  {
343  if (&other != this) {
344  this->set_size(other.get_size());
345  this->set_system_matrix(other.get_system_matrix());
346  this->set_preconditioner(other.get_preconditioner());
347  this->reset_tolerance(other.get_tolerance());
350  other.set_system_matrix(nullptr);
351  other.set_preconditioner(nullptr);
352  }
353  return *this;
354  }
355 
358  other.self()->get_executor(), other.self()->get_size())
359  {
360  *this = other;
361  }
362 
365  other.self()->get_executor(), other.self()->get_size())
366  {
367  *this = std::move(other);
368  }
369 
370  void apply_impl(const MultiVector<ValueType>* b,
371  MultiVector<ValueType>* x) const
372  {
373  auto exec = this->get_executor();
374  if (b->get_common_size()[1] > 1) {
375  GKO_NOT_IMPLEMENTED;
376  }
377  auto workspace_view = workspace_.as_view();
378  auto log_data_ = std::make_unique<log::detail::log_data<real_type>>(
379  exec, b->get_num_batch_items(), workspace_view);
380 
381  this->solver_apply(b, x, log_data_.get());
382 
383  this->template log<gko::log::Logger::batch_solver_completed>(
384  log_data_->iter_counts, log_data_->res_norms);
385  }
386 
387  void apply_impl(const MultiVector<ValueType>* alpha,
388  const MultiVector<ValueType>* b,
389  const MultiVector<ValueType>* beta,
390  MultiVector<ValueType>* x) const
391  {
392  auto x_clone = x->clone();
393  this->apply(b, x_clone.get());
394  x->scale(beta);
395  x->add_scaled(alpha, x_clone.get());
396  }
397 
398  virtual void solver_apply(const MultiVector<ValueType>* b,
400  log::detail::log_data<real_type>* info) const = 0;
401 };
402 
403 
404 } // namespace solver
405 } // namespace batch
406 } // namespace gko
407 
408 
409 #endif // GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
gko::batch::EnableBatchLinOp
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of...
Definition: batch_lin_op.hpp:250
gko::batch::MultiVector::scale
void scale(ptr_param< const MultiVector< ValueType >> alpha)
Scales the vector with a scalar (aka: BLAS scal).
gko::log::profile_event_category::solver
Solver events.
gko::batch::MultiVector::add_scaled
void add_scaled(ptr_param< const MultiVector< ValueType >> alpha, ptr_param< const MultiVector< ValueType >> b)
Adds b scaled by alpha to the vector (aka: BLAS axpy).
gko::batch::BatchLinOp
Definition: batch_lin_op.hpp:59
gko::size_type
std::size_t size_type
Integral type used for allocation quantities.
Definition: types.hpp:86
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::tolerance_type
::gko::batch::stop::tolerance_type tolerance_type
Specifies which type of tolerance check is to be used: absolute, or relative (to the l2 norm of the right-hand side).
Definition: batch_solver_base.hpp:176
gko::batch::MultiVector
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition: batch_multi_vector.hpp:52
GKO_FACTORY_PARAMETER_SCALAR
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition: abstract_factory.hpp:445
gko::clone
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition: utils_helper.hpp:173
gko::batch::solver::EnableBatchSolver
This mixin provides apply and common iterative solver functionality to all the batched solvers.
Definition: batch_solver_base.hpp:204
gko::batch::solver::BatchSolver::reset_max_iterations
void reset_max_iterations(int max_iterations)
Set the maximum number of iterations for the solver to use, independent of the factory that created it.
Definition: batch_solver_base.hpp:85
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::tolerance
double tolerance
Default residual tolerance.
Definition: batch_solver_base.hpp:169
gko::batch::solver::BatchSolver
The BatchSolver is a base class for all batched solvers and provides the common getters and setters for these batched solver classes.
Definition: batch_solver_base.hpp:29
gko::batch::solver::BatchSolver::get_tolerance_type
::gko::batch::stop::tolerance_type get_tolerance_type() const
Get the tolerance type.
Definition: batch_solver_base.hpp:98
gko::batch::MultiVector::get_common_size
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition: batch_multi_vector.hpp:126
gko::batch::matrix::Identity
The batch Identity matrix, which represents a batch of Identity matrices.
Definition: batch_identity.hpp:32
gko::batch::MultiVector::get_num_batch_items
size_type get_num_batch_items() const
Returns the number of batch items.
Definition: batch_multi_vector.hpp:116
gko::ptr_param
This class is used for function parameters in the place of raw pointers.
Definition: utils_helper.hpp:41
gko::transpose
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition: batch_dim.hpp:119
gko::batch::solver::BatchSolver::get_preconditioner
std::shared_ptr< const BatchLinOp > get_preconditioner() const
Returns the generated preconditioner.
Definition: batch_solver_base.hpp:46
gko::batch::solver::BatchSolver::get_tolerance
double get_tolerance() const
Get the residual tolerance used by the solver.
Definition: batch_solver_base.hpp:56
gko::batch::solver::BatchSolver::reset_tolerance_type
void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
Set the type of tolerance check to use inside the solver.
Definition: batch_solver_base.hpp:108
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::preconditioner
std::shared_ptr< const BatchLinOpFactory > preconditioner
The preconditioner to be used by the iterative solver.
Definition: batch_solver_base.hpp:183
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::max_iterations
int max_iterations
Default maximum number of iterations allowed.
Definition: batch_solver_base.hpp:161
gko::batch::solver::BatchSolver::reset_tolerance
void reset_tolerance(double res_tol)
Update the residual tolerance to be used by the solver.
Definition: batch_solver_base.hpp:64
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters
Definition: batch_solver_base.hpp:153
gko::make_temporary_clone
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Creates a temporary_clone.
Definition: temporary_clone.hpp:208
gko::batch::solver::enable_preconditioned_iterative_solver_factory_parameters::generated_preconditioner
std::shared_ptr< const BatchLinOp > generated_preconditioner
Already generated preconditioner.
Definition: batch_solver_base.hpp:190
gko::batch::solver::BatchSolver::get_max_iterations
int get_max_iterations() const
Get the maximum number of iterations set on the solver.
Definition: batch_solver_base.hpp:77
gko::batch::solver::BatchSolver::get_system_matrix
std::shared_ptr< const BatchLinOp > get_system_matrix() const
Returns the system operator (matrix) of the linear system.
Definition: batch_solver_base.hpp:36
gko::enable_parameters_type
The enable_parameters_type mixin is used to create a base implementation of the factory parameters structure.
Definition: abstract_factory.hpp:211
gko::remove_complex
typename detail::remove_complex_s< T >::type remove_complex
Obtain the type which removed the complex of complex/scalar type or the template parameter of class b...
Definition: math.hpp:325