5 #ifndef GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
6 #define GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
10 #include <type_traits>
12 #include <ginkgo/core/base/abstract_factory.hpp>
13 #include <ginkgo/core/base/composition.hpp>
14 #include <ginkgo/core/base/exception.hpp>
15 #include <ginkgo/core/base/exception_helpers.hpp>
16 #include <ginkgo/core/base/lin_op.hpp>
17 #include <ginkgo/core/base/precision_dispatch.hpp>
18 #include <ginkgo/core/config/config.hpp>
19 #include <ginkgo/core/config/registry.hpp>
20 #include <ginkgo/core/factorization/par_ic.hpp>
21 #include <ginkgo/core/matrix/dense.hpp>
22 #include <ginkgo/core/preconditioner/isai.hpp>
23 #include <ginkgo/core/preconditioner/utils.hpp>
24 #include <ginkgo/core/solver/gmres.hpp>
25 #include <ginkgo/core/solver/ir.hpp>
26 #include <ginkgo/core/solver/solver_traits.hpp>
27 #include <ginkgo/core/solver/triangular.hpp>
28 #include <ginkgo/core/stop/combined.hpp>
29 #include <ginkgo/core/stop/iteration.hpp>
30 #include <ginkgo/core/stop/residual_norm.hpp>
34 namespace preconditioner {
38 template <
typename Type>
39 constexpr
bool support_ic_parse =
40 is_instantiation_of<Type, solver::LowerTrs>::value ||
41 is_instantiation_of<Type, solver::Ir>::value ||
42 is_instantiation_of<Type, solver::Gmres>::value ||
43 is_instantiation_of<Type, preconditioner::LowerIsai>::value;
48 std::enable_if_t<!support_ic_parse<typename Ic::l_solver_type>>* =
nullptr>
49 typename Ic::parameters_type ic_parse(
50 const config::pnode& config,
const config::registry& context,
51 const config::type_descriptor& td_for_child)
54 "preconditioner::Ic only supports limited type for parse.");
59 std::enable_if_t<support_ic_parse<typename Ic::l_solver_type>>* =
nullptr>
60 typename Ic::parameters_type ic_parse(
61 const config::pnode& config,
const config::registry& context,
62 const config::type_descriptor& td_for_child);
112 template <
typename LSolverType = solver::LowerTrs<>,
typename IndexType =
int32>
119 std::is_same<
typename LSolverType::transposed_type::transposed_type,
121 "LSolverType::transposed_type must be symmetric");
122 using value_type =
typename LSolverType::value_type;
123 using l_solver_type = LSolverType;
124 using lh_solver_type =
typename LSolverType::transposed_type;
125 using index_type = IndexType;
135 std::shared_ptr<const typename l_solver_type::Factory>
143 GKO_DEPRECATED(
"use with_l_solver instead")
148 return with_l_solver(std::move(solver));
151 parameters_type& with_l_solver(
155 this->l_solver_generator = std::move(solver);
156 this->deferred_factories[
"l_solver"] = [](
const auto& exec,
158 if (!params.l_solver_generator.is_empty()) {
159 params.l_solver_factory =
160 params.l_solver_generator.on(exec);
166 GKO_DEPRECATED(
"use with_factorization instead")
167 parameters_type& with_factorization_factory(
168 deferred_factory_parameter<const LinOpFactory> factorization)
170 return with_factorization(std::move(factorization));
173 parameters_type& with_factorization(
174 deferred_factory_parameter<const LinOpFactory> factorization)
176 this->factorization_generator = std::move(factorization);
177 this->deferred_factories[
"factorization"] = [](
const auto& exec,
179 if (!params.factorization_generator.is_empty()) {
180 params.factorization_factory =
181 params.factorization_generator.on(exec);
188 deferred_factory_parameter<const typename l_solver_type::Factory>
191 deferred_factory_parameter<const LinOpFactory> factorization_generator;
216 config::make_type_descriptor<value_type, index_type>())
218 return detail::ic_parse<Ic>(config, context, td_for_child);
243 std::unique_ptr<transposed_type> transposed{
246 transposed->l_solver_ =
247 share(as<typename lh_solver_type::transposed_type>(
249 transposed->lh_solver_ =
250 share(as<typename l_solver_type::transposed_type>(
253 return std::move(transposed);
258 std::unique_ptr<transposed_type> transposed{
261 transposed->l_solver_ =
262 share(as<typename lh_solver_type::transposed_type>(
264 transposed->lh_solver_ =
265 share(as<typename l_solver_type::transposed_type>(
268 return std::move(transposed);
278 if (&other !=
this) {
281 l_solver_ = other.l_solver_;
282 lh_solver_ = other.lh_solver_;
283 parameters_ = other.parameters_;
300 if (&other !=
this) {
303 l_solver_ = std::move(other.l_solver_);
304 lh_solver_ = std::move(other.lh_solver_);
306 if (other.get_executor() != exec) {
328 void apply_impl(
const LinOp* b,
LinOp* x)
const override
331 precision_dispatch_real_complex<value_type>(
332 [&](
auto dense_b,
auto dense_x) {
333 this->set_cache_to(dense_b);
334 l_solver_->apply(dense_b, cache_.intermediate);
335 if (lh_solver_->apply_uses_initial_guess()) {
336 dense_x->copy_from(cache_.intermediate);
338 lh_solver_->apply(cache_.intermediate, dense_x);
344 LinOp* x)
const override
346 precision_dispatch_real_complex<value_type>(
347 [&](
auto dense_alpha,
auto dense_b,
auto dense_beta,
auto dense_x) {
348 this->set_cache_to(dense_b);
349 l_solver_->apply(dense_b, cache_.intermediate);
350 lh_solver_->apply(dense_alpha, cache_.intermediate, dense_beta,
356 explicit Ic(std::shared_ptr<const Executor> exec)
357 : EnableLinOp<
Ic>(std::move(exec))
360 explicit Ic(
const Factory* factory, std::shared_ptr<const LinOp> lin_op)
362 parameters_{
factory->get_parameters()}
365 std::dynamic_pointer_cast<
const Composition<value_type>>(lin_op);
366 std::shared_ptr<const LinOp> l_factor;
370 auto exec = lin_op->get_executor();
373 factorization::ParIc<value_type, index_type>::build()
374 .with_both_factors(
false)
377 auto fact = std::shared_ptr<const LinOp>(
381 std::dynamic_pointer_cast<
const Composition<value_type>>(fact);
383 GKO_NOT_SUPPORTED(comp);
387 if (comp->get_operators().size() > 2 || comp->get_operators().empty()) {
388 GKO_NOT_SUPPORTED(comp);
390 l_factor = comp->get_operators()[0];
391 GKO_ASSERT_IS_SQUARE_MATRIX(l_factor);
397 l_solver_ = generate_default_solver<l_solver_type>(exec, l_factor);
400 if (comp->get_operators().size() == 2) {
401 auto lh_factor = comp->get_operators()[1];
402 GKO_ASSERT_EQUAL_DIMENSIONS(l_factor, lh_factor);
403 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
405 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
409 lh_solver_ = as<lh_solver_type>(l_solver_->conj_transpose());
420 void set_cache_to(
const LinOp* b)
const
422 if (cache_.intermediate ==
nullptr) {
423 cache_.intermediate =
427 cache_.intermediate->copy_from(b);
438 template <
typename SolverType>
439 static std::enable_if_t<solver::has_with_criteria<SolverType>::value,
440 std::unique_ptr<SolverType>>
441 generate_default_solver(
const std::shared_ptr<const Executor>& exec,
442 const std::shared_ptr<const LinOp>& mtx)
445 const unsigned int default_max_iters{
446 static_cast<unsigned int>(mtx->get_size()[0])};
448 return SolverType::build()
450 gko::stop::Iteration::build().with_max_iters(default_max_iters),
452 .with_reduction_factor(default_reduce_residual))
460 template <
typename SolverType>
461 static std::enable_if_t<!solver::has_with_criteria<SolverType>::value,
462 std::unique_ptr<SolverType>>
463 generate_default_solver(
const std::shared_ptr<const Executor>& exec,
464 const std::shared_ptr<const LinOp>& mtx)
466 return SolverType::build().on(exec)->generate(mtx);
470 std::shared_ptr<const l_solver_type> l_solver_{};
471 std::shared_ptr<const lh_solver_type> lh_solver_{};
482 mutable struct cache_struct {
483 cache_struct() =
default;
484 ~cache_struct() =
default;
485 cache_struct(
const cache_struct&) {}
486 cache_struct(cache_struct&&) {}
487 cache_struct&
operator=(
const cache_struct&) {
return *
this; }
488 cache_struct&
operator=(cache_struct&&) {
return *
this; }
489 std::unique_ptr<LinOp> intermediate{};
498 #endif // GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_