5 #ifndef GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
6 #define GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
9 #include <ginkgo/core/base/abstract_factory.hpp>
10 #include <ginkgo/core/base/batch_lin_op.hpp>
11 #include <ginkgo/core/base/batch_multi_vector.hpp>
12 #include <ginkgo/core/base/utils_helper.hpp>
13 #include <ginkgo/core/log/batch_logger.hpp>
14 #include <ginkgo/core/matrix/batch_identity.hpp>
15 #include <ginkgo/core/stop/batch_stop_enum.hpp>
38 return this->system_matrix_;
48 return this->preconditioner_;
67 GKO_INVALID_STATE(
"Tolerance cannot be negative!");
69 this->residual_tol_ = res_tol;
87 if (max_iterations < 0) {
88 GKO_INVALID_STATE(
"Max iterations cannot be negative!");
90 this->max_iterations_ = max_iterations;
100 return this->tol_type_;
110 if (tol_type == ::gko::batch::stop::tolerance_type::absolute ||
111 tol_type == ::gko::batch::stop::tolerance_type::relative) {
112 this->tol_type_ = tol_type;
114 GKO_INVALID_STATE(
"Invalid tolerance type specified!");
121 BatchSolver(std::shared_ptr<const BatchLinOp> system_matrix,
122 std::shared_ptr<const BatchLinOp> gen_preconditioner,
123 const double res_tol,
const int max_iterations,
124 const ::gko::batch::stop::tolerance_type tol_type)
125 : system_matrix_{std::move(system_matrix)},
126 preconditioner_{std::move(gen_preconditioner)},
127 residual_tol_{res_tol},
128 max_iterations_{max_iterations},
133 void set_system_matrix_base(std::shared_ptr<const BatchLinOp> system_matrix)
135 this->system_matrix_ = std::move(system_matrix);
138 void set_preconditioner_base(std::shared_ptr<const BatchLinOp> precond)
140 this->preconditioner_ = std::move(precond);
143 std::shared_ptr<const BatchLinOp> system_matrix_{};
144 std::shared_ptr<const BatchLinOp> preconditioner_{};
145 double residual_tol_{};
146 int max_iterations_{};
147 ::gko::batch::stop::tolerance_type tol_type_{};
148 mutable array<unsigned char> workspace_{};
152 template <
typename Parameters,
typename Factory>
182 std::shared_ptr<const BatchLinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
202 template <
typename ConcreteSolver,
typename ValueType,
213 this->validate_application_parameters(b.get(), x.get());
214 auto exec = this->get_executor();
225 this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
227 auto exec = this->get_executor();
238 this->validate_application_parameters(b.get(), x.get());
239 auto exec = this->get_executor();
250 this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
252 auto exec = this->get_executor();
261 GKO_ENABLE_SELF(ConcreteSolver);
267 template <
typename FactoryParameters>
269 std::shared_ptr<const BatchLinOp> system_matrix,
270 const FactoryParameters& params)
271 :
BatchSolver(system_matrix,
nullptr, params.tolerance,
272 params.max_iterations, params.tolerance_type),
276 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(system_matrix_);
278 using value_type =
typename ConcreteSolver::value_type;
282 if (params.generated_preconditioner) {
283 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(params.generated_preconditioner,
285 preconditioner_ = std::move(params.generated_preconditioner);
286 }
else if (params.preconditioner) {
287 preconditioner_ = params.preconditioner->generate(system_matrix_);
289 auto id = Identity::create(exec, system_matrix->get_size());
290 preconditioner_ = std::move(
id);
296 system_matrix->get_num_batch_items() * 32;
297 workspace_.set_executor(exec);
298 workspace_.resize_and_reset(workspace_size);
301 void set_system_matrix(std::shared_ptr<const BatchLinOp> new_system_matrix)
303 auto exec =
self()->get_executor();
304 if (new_system_matrix) {
305 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(
self(), new_system_matrix);
306 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_system_matrix);
307 if (new_system_matrix->get_executor() != exec) {
308 new_system_matrix =
gko::clone(exec, new_system_matrix);
311 this->set_system_matrix_base(new_system_matrix);
314 void set_preconditioner(std::shared_ptr<const BatchLinOp> new_precond)
316 auto exec =
self()->get_executor();
318 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(
self(), new_precond);
319 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_precond);
320 if (new_precond->get_executor() != exec) {
324 this->set_preconditioner_base(new_precond);
329 if (&other !=
this) {
330 this->set_size(other.get_size());
343 if (&other !=
this) {
344 this->set_size(other.get_size());
350 other.set_system_matrix(
nullptr);
351 other.set_preconditioner(
nullptr);
358 other.self()->get_executor(), other.self()->get_size())
365 other.self()->get_executor(), other.self()->get_size())
367 *
this = std::move(other);
373 auto exec = this->get_executor();
377 auto workspace_view = workspace_.as_view();
378 auto log_data_ = std::make_unique<log::detail::log_data<real_type>>(
381 this->solver_apply(b, x, log_data_.get());
383 this->
template log<gko::log::Logger::batch_solver_completed>(
384 log_data_->iter_counts, log_data_->res_norms);
392 auto x_clone = x->clone();
393 this->apply(b, x_clone.get());
400 log::detail::log_data<real_type>* info)
const = 0;
409 #endif // GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_