5 #ifndef GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
6 #define GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
10 #include <type_traits>
14 #include <ginkgo/core/base/lin_op.hpp>
15 #include <ginkgo/core/base/math.hpp>
16 #include <ginkgo/core/log/logger.hpp>
17 #include <ginkgo/core/matrix/dense.hpp>
18 #include <ginkgo/core/matrix/identity.hpp>
19 #include <ginkgo/core/solver/workspace.hpp>
20 #include <ginkgo/core/stop/combined.hpp>
21 #include <ginkgo/core/stop/criterion.hpp>
24 GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS
// multigrid::detail::MultigridState is declared a friend of this interface.
// NOTE(review): which non-public members it relies on is not visible in this
// chunk of the file.
68 friend class multigrid::detail::MultigridState;
// Virtual apply-with-initial-guess entry point taking raw pointers: b is the
// input and x the output (presumably also the starting guess, depending on the
// trailing guess-mode parameter, whose declaration is elided in this view).
83 virtual void apply_with_initial_guess(
const LinOp* b,
LinOp* x,
// Smart-pointer convenience overload: unwraps b and x with .get() and forwards
// to the raw-pointer overload, passing `guess` through unchanged.
89 apply_with_initial_guess(b.
get(), x.
get(), guess);
// Advanced (scaled) variant taking alpha and beta in addition to b and x.
// By the usual LinOp convention this computes x = alpha * op(b) + beta * x.
// NOTE(review): confirm against the elided documentation.
104 virtual void apply_with_initial_guess(
const LinOp* alpha,
const LinOp* b,
// Smart-pointer convenience overload: unwraps alpha, b, beta and x with .get()
// and forwards to the raw-pointer advanced overload (the remaining argument
// list is elided in this view).
115 apply_with_initial_guess(alpha.
get(), b.
get(), beta.
get(), x.
get(),
// CRTP mixin parameterized on DerivedType (the class that inherits from it).
// It wraps apply_with_initial_guess with logging and dimension checks, then
// hands the real work to the virtual apply_with_initial_guess_impl hooks.
// NOTE(review): base classes and access specifiers are elided in this view.
161 template <
typename DerivedType>
// multigrid::detail::MultigridState is granted friend access to this mixin.
164 friend class multigrid::detail::MultigridState;
// Simple apply with initial guess. It emits linop_apply_started, then checks:
// self() is conformant with b, self() and x have the same number of rows, and
// b and x have the same number of columns. After the (elided) dispatch it
// emits linop_apply_completed.
175 void apply_with_initial_guess(
const LinOp* b,
LinOp* x,
178 self()->
template log<log::Logger::linop_apply_started>(
self(), b, x);
// The solver's executor. Its use lies in lines elided from this view
// (presumably to move b/x onto it before dispatching — confirm).
179 auto exec =
self()->get_executor();
180 GKO_ASSERT_CONFORMANT(
self(), b);
181 GKO_ASSERT_EQUAL_ROWS(
self(), x);
182 GKO_ASSERT_EQUAL_COLS(b, x);
186 self()->
template log<log::Logger::linop_apply_completed>(
self(), b, x);
// Advanced apply with initial guess. It emits linop_advanced_apply_started and
// performs the same shape checks as the simple variant. It additionally
// requires alpha and beta to be 1x1 (scalar) operators, delegates to
// apply_with_initial_guess_impl, and emits linop_advanced_apply_completed.
193 void apply_with_initial_guess(
const LinOp* alpha,
const LinOp* b,
197 self()->
template log<log::Logger::linop_advanced_apply_started>(
198 self(), alpha, b, beta, x);
199 auto exec =
self()->get_executor();
200 GKO_ASSERT_CONFORMANT(
self(), b);
201 GKO_ASSERT_EQUAL_ROWS(
self(), x);
202 GKO_ASSERT_EQUAL_COLS(b, x);
// alpha and beta are scaling factors and must be stored as 1x1 operators.
203 GKO_ASSERT_EQUAL_DIMENSIONS(alpha,
dim<2>(1, 1));
204 GKO_ASSERT_EQUAL_DIMENSIONS(beta,
dim<2>(1, 1));
// Delegates the computation to the derived class's implementation hook.
205 this->apply_with_initial_guess_impl(
210 self()->
template log<log::Logger::linop_advanced_apply_completed>(
211 self(), alpha, b, beta, x);
// Customization points overridden by DerivedType: the simple and the advanced
// implementation. Their parameter lists are elided in this view.
219 virtual void apply_with_initial_guess_impl(
226 virtual void apply_with_initial_guess_impl(
// GKO_ENABLE_SELF presumably defines the self() accessors used above, which
// return this as DerivedType (TODO confirm against the macro definition).
230 GKO_ENABLE_SELF(DerivedType);
// Default workspace traits for a solver type `Solver`. They describe an empty
// workspace: zero operators, zero arrays, and empty name and index lists.
// The Solver argument is ignored. Solvers that need a workspace presumably
// provide their own specialization (TODO confirm).
238 template <
typename Solver>
// Number of workspace operators (vectors) the solver allocates.
241 static int num_vectors(
const Solver&) {
return 0; }
// Number of workspace arrays the solver allocates.
243 static int num_arrays(
const Solver&) {
return 0; }
// Human-readable names of the workspace operators.
245 static std::vector<std::string> op_names(
const Solver&) {
return {}; }
// Human-readable names of the workspace arrays.
247 static std::vector<std::string> array_names(
const Solver&) {
return {}; }
// Workspace operator ids that hold scalars.
249 static std::vector<int> scalars(
const Solver&) {
return {}; }
// Workspace operator ids that hold vectors.
251 static std::vector<int> vectors(
const Solver&) {
return {}; }
// CRTP mixin that lets DerivedType hold a preconditioner.
// NOTE(review): most of the class body is elided in this view.
270 template <
typename DerivedType>
// Preconditioner setter fragment: the new preconditioner must have the same
// dimensions as self() and be square. An executor mismatch is detected here;
// its handling lies in elided lines (presumably a clone onto exec — confirm).
281 auto exec =
self()->get_executor();
283 GKO_ASSERT_EQUAL_DIMENSIONS(
self(), new_precond);
284 GKO_ASSERT_IS_SQUARE_MATRIX(new_precond);
285 if (new_precond->get_executor() != exec) {
// Self-assignment guard (presumably in the copy-assignment operator; the body
// is elided).
298 if (&other !=
this) {
// Self-assignment guard for the transferring assignment. After the transfer,
// the source's preconditioner is reset to nullptr.
311 if (&other !=
this) {
313 other.set_preconditioner(
nullptr);
// Constructor body that delegates to the move assignment above.
339 *
this = std::move(other);
// CRTP downcast helpers: view this mixin as the derived class.
343 DerivedType*
self() {
return static_cast<DerivedType*>(
this); }
345 const DerivedType*
self()
const
347 return static_cast<const DerivedType*>(
this);
// Template-free solver base. It stores the system matrix as a
// shared_ptr<const LinOp> and owns a mutable workspace of cached operators and
// arrays, which is created on the executor passed to the constructor.
363 class SolverBaseLinOp {
// Constructs the workspace on `exec`. The system matrix starts out unset.
365 SolverBaseLinOp(std::shared_ptr<const Executor> exec)
366 : workspace_{std::move(exec)}
369 virtual ~SolverBaseLinOp() =
default;
// Returns the stored system matrix (shared ownership).
376 std::shared_ptr<const LinOp> get_system_matrix()
const
378 return system_matrix_;
// Returns the workspace operator stored under vector_id, as a raw pointer
// obtained from the workspace.
381 const LinOp* get_workspace_op(
int vector_id)
const
383 return workspace_.get_op(vector_id);
// Workspace introspection hooks. These defaults report an empty workspace;
// EnableSolverBase overrides them using workspace_traits<DerivedType>.
386 virtual int get_num_workspace_ops()
const {
return 0; }
388 virtual std::vector<std::string> get_workspace_op_names()
const
397 virtual std::vector<int> get_workspace_scalars()
const {
return {}; }
403 virtual std::vector<int> get_workspace_vectors()
const {
return {}; }
// Stores the system matrix as-is (no validation at this level).
406 void set_system_matrix_base(std::shared_ptr<const LinOp> system_matrix)
408 system_matrix_ = std::move(system_matrix);
// Resizes the workspace. Declared const because workspace_ is mutable.
411 void set_workspace_size(
int num_operators,
int num_arrays)
const
413 workspace_.set_size(num_operators, num_arrays);
// Returns the workspace operator at vector_id. If none is cached, it is created
// via LinOpType::create on the workspace executor with the given size. The
// trailing (type, size, size[1]) arguments presumably let the workspace decide
// whether a cached object can be reused — confirm in workspace.hpp.
416 template <
typename LinOpType>
417 LinOpType* create_workspace_op(
int vector_id,
gko::dim<2> size)
const
419 return workspace_.template create_or_get_op<LinOpType>(
422 return LinOpType::create(this->workspace_.get_executor(), size);
424 typeid(LinOpType), size, size[1]);
// Like create_workspace_op, but a new object is built with
// LinOpType::create_with_config_of(vec). The reuse key uses vec's dynamic type,
// size and stride.
427 template <
typename LinOpType>
428 LinOpType* create_workspace_op_with_config_of(
int vector_id,
429 const LinOpType* vec)
const
431 return workspace_.template create_or_get_op<LinOpType>(
432 vector_id, [&] {
return LinOpType::create_with_config_of(vec); },
433 typeid(*vec), vec->get_size(), vec->get_stride());
// Like create_workspace_op, but a new object is built with
// LinOpType::create_with_type_of(vec, ...) on the workspace executor, using
// `size` and stride size[1].
436 template <
typename LinOpType>
437 LinOpType* create_workspace_op_with_type_of(
int vector_id,
438 const LinOpType* vec,
441 return workspace_.template create_or_get_op<LinOpType>(
444 return LinOpType::create_with_type_of(
445 vec, workspace_.get_executor(), size, size[1]);
447 typeid(*vec), size, size[1]);
// Overload taking separate global and local sizes (presumably for distributed
// vectors — confirm). The stride used in the reuse key is local_size[1].
450 template <
typename LinOpType>
451 LinOpType* create_workspace_op_with_type_of(
int vector_id,
452 const LinOpType* vec,
454 dim<2> local_size)
const
456 return workspace_.template create_or_get_op<LinOpType>(
459 return LinOpType::create_with_type_of(
460 vec, workspace_.get_executor(), global_size, local_size,
463 typeid(*vec), global_size, local_size[1]);
// Returns a 1 x size dense row of scalars stored under vector_id, creating it
// on the workspace executor if needed.
466 template <
typename ValueType>
467 matrix::Dense<ValueType>* create_workspace_scalar(
int vector_id,
470 return workspace_.template create_or_get_op<matrix::Dense<ValueType>>(
474 workspace_.get_executor(), dim<2>{1, size});
476 typeid(matrix::Dense<ValueType>),
gko::dim<2>{1, size}, size);
// Returns a reference to the workspace array at array_id, created or resized
// to `size` as decided by workspace::create_or_get_array (TODO confirm exact
// semantics there).
479 template <
typename ValueType>
480 array<ValueType>& create_workspace_array(
int array_id,
size_type size)
const
482 return workspace_.template create_or_get_array<ValueType>(array_id,
// Returns a reference to the workspace array at array_id without specifying a
// size (delegates to workspace::init_or_get_array).
486 template <
typename ValueType>
487 array<ValueType>& create_workspace_array(
int array_id)
const
489 return workspace_.template init_or_get_array<ValueType>(array_id);
// mutable so that const solver methods (e.g. apply) can lazily allocate
// scratch storage.
493 mutable detail::workspace workspace_;
495 std::shared_ptr<const LinOp> system_matrix_;
// Deprecated typed wrapper over detail::SolverBaseLinOp. It inherits its
// constructors and exposes the system matrix as MatrixType.
502 template <
typename MatrixType>
505 GKO_DEPRECATED(
"This class will be replaced by the template-less detail::SolverBaseLinOp in a future release")
SolverBase
507 : public detail::SolverBaseLinOp {
509 using detail::SolverBaseLinOp::SolverBaseLinOp;
// Typed accessor: std::dynamic_pointer_cast yields nullptr when the stored
// matrix is not a MatrixType.
520 return std::dynamic_pointer_cast<const MatrixType>(
521 SolverBaseLinOp::get_system_matrix());
// Typed setter that forwards to the untyped base, which stores without checks.
525 void set_system_matrix_base(std::shared_ptr<const MatrixType> system_matrix)
527 SolverBaseLinOp::set_system_matrix_base(std::move(system_matrix));
// CRTP mixin giving DerivedType a validated system matrix. Its workspace
// introspection is driven by workspace_traits<DerivedType>.
541 template <
typename DerivedType,
typename MatrixType = LinOp>
// Self-assignment guard (presumably in the copy-assignment operator; the body
// is elided).
550 if (&other !=
this) {
// Transferring assignment: takes other's system matrix, then clears it on
// `other`.
562 if (&other !=
this) {
563 set_system_matrix(other.get_system_matrix());
564 other.set_system_matrix(
nullptr);
// Constructs the base on DerivedType's executor and validates/stores the matrix.
// NOTE(review): self()->get_executor() is called while SolverBase is still
// being constructed. This presumably relies on DerivedType's executor-owning
// base being initialized earlier in the inheritance order (elided here;
// confirm).
571 EnableSolverBase(std::shared_ptr<const MatrixType> system_matrix)
572 : SolverBase<MatrixType>{
self()->get_executor()}
574 set_system_matrix(std::move(system_matrix));
// Copy/move constructors build the base on `other`'s executor.
581 :
SolverBase<MatrixType>{other.self()->get_executor()}
591 :
SolverBase<MatrixType>{other.self()->get_executor()}
593 *
this = std::move(other);
// Workspace introspection overrides, all answered by
// workspace_traits<DerivedType>.
596 int get_num_workspace_ops()
const override
598 using traits = workspace_traits<DerivedType>;
599 return traits::num_vectors(*
self());
602 std::vector<std::string> get_workspace_op_names()
const override
604 using traits = workspace_traits<DerivedType>;
605 return traits::op_names(*
self());
615 return traits::scalars(*
self());
625 return traits::vectors(*
self());
// Validates and stores a new system matrix. A non-null matrix must match
// self()'s dimensions and be square. If it lives on a different executor, it is
// cloned onto self()'s executor. A nullptr is stored without checks.
629 void set_system_matrix(std::shared_ptr<const MatrixType> new_system_matrix)
631 auto exec =
self()->get_executor();
632 if (new_system_matrix) {
633 GKO_ASSERT_EQUAL_DIMENSIONS(
self(), new_system_matrix);
634 GKO_ASSERT_IS_SQUARE_MATRIX(new_system_matrix);
635 if (new_system_matrix->get_executor() != exec) {
636 new_system_matrix =
gko::clone(exec, new_system_matrix);
// NOTE(review): std::move(new_system_matrix) would avoid one shared_ptr
// refcount round-trip here.
639 this->set_system_matrix_base(new_system_matrix);
// Sizes the workspace from workspace_traits<DerivedType> (operators, arrays).
642 void setup_workspace()
const
644 using traits = workspace_traits<DerivedType>;
645 this->set_workspace_size(traits::num_vectors(*
self()),
646 traits::num_arrays(*
self()));
// CRTP downcast helpers.
650 DerivedType*
self() {
return static_cast<DerivedType*>(
this); }
652 const DerivedType*
self()
const
654 return static_cast<const DerivedType*>(
this);
// Accessor body: returns the stored stopping-criterion factory.
675 return stop_factory_;
// Setter fragment (signature partially elided). It stores the factory as-is;
// executor handling is added by the override in EnableIterativeBase.
684 std::shared_ptr<const stop::CriterionFactory> new_stop_factory)
686 stop_factory_ = new_stop_factory;
// Factory used to build the stopping criterion of iterative solves.
690 std::shared_ptr<const stop::CriterionFactory> stop_factory_;
// CRTP mixin that manages a stopping-criterion factory for DerivedType.
703 template <
typename DerivedType>
// Self-assignment guard (presumably in the copy-assignment operator; the body
// is elided).
712 if (&other !=
this) {
// Transferring assignment: afterwards, other's stop factory is reset to
// nullptr.
725 if (&other !=
this) {
727 other.set_stop_criterion_factory(
nullptr);
// Constructor parameter fragment (the rest is elided).
735 std::shared_ptr<const stop::CriterionFactory> stop_factory)
// Constructor body that delegates to the move assignment.
751 *
this = std::move(other);
// Override of the stop-factory setter: a non-null factory living on a different
// executor is cloned onto self()'s executor. Storing the result happens in
// elided lines.
755 std::shared_ptr<const stop::CriterionFactory> new_stop_factory)
override
757 auto exec =
self()->get_executor();
758 if (new_stop_factory && new_stop_factory->get_executor() != exec) {
759 new_stop_factory =
gko::clone(exec, new_stop_factory);
// CRTP downcast helpers.
765 DerivedType*
self() {
return static_cast<DerivedType*>(
this); }
767 const DerivedType*
self()
const
769 return static_cast<const DerivedType*>(
this);
// Mixin combining system matrix, stopping-criterion factory and preconditioner
// for a preconditioned iterative solver (the class header is elided).
784 template <
typename ValueType,
typename DerivedType>
// Constructor parameters: system matrix, stop-criterion factory and
// preconditioner.
793 std::shared_ptr<const LinOp> system_matrix,
794 std::shared_ptr<const stop::CriterionFactory> stop_factory,
795 std::shared_ptr<const LinOp> preconditioner)
// Constructor from factory parameters. The preconditioner argument is obtained
// from generate_preconditioner(system_matrix, params).
801 template <
typename FactoryParameters>
803 std::shared_ptr<const LinOp> system_matrix,
804 const FactoryParameters& params)
807 generate_preconditioner(system_matrix, params)}
// Chooses the preconditioner in this order:
// 1. params.generated_preconditioner, if set, is returned as-is;
// 2. otherwise params.preconditioner, if set, is used as a factory and its
//    generate(system_matrix) result is returned;
// 3. otherwise an object is created on system_matrix's executor with
//    system_matrix's size (presumably matrix::Identity, which this header
//    includes — confirm in the elided lines).
811 template <
typename FactoryParameters>
812 static std::shared_ptr<const LinOp> generate_preconditioner(
813 std::shared_ptr<const LinOp> system_matrix,
814 const FactoryParameters& params)
816 if (params.generated_preconditioner) {
817 return params.generated_preconditioner;
818 }
else if (params.preconditioner) {
819 return params.preconditioner->generate(system_matrix);
822 system_matrix->get_executor(), system_matrix->get_size());
// Parameter helper templated on Parameters and Factory (the body is elided).
// It declares a list of stopping-criterion factories, presumably the
// `criteria` parameter of iterative solvers — confirm.
828 template <
typename Parameters,
typename Factory>
834 std::vector<std::shared_ptr<const stop::CriterionFactory>>
// Parameter helper templated on Parameters and Factory (the body is elided).
// It declares a deferred LinOpFactory parameter, presumably the preconditioner
// factory — confirm.
839 template <
typename Parameters,
typename Factory>
846 std::shared_ptr<const LinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
862 GKO_END_DISABLE_DEPRECATION_WARNINGS
865 #endif // GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_