5 #ifndef GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
6 #define GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_
10 #include <type_traits>
12 #include <ginkgo/core/base/abstract_factory.hpp>
13 #include <ginkgo/core/base/composition.hpp>
14 #include <ginkgo/core/base/exception.hpp>
15 #include <ginkgo/core/base/exception_helpers.hpp>
16 #include <ginkgo/core/base/lin_op.hpp>
17 #include <ginkgo/core/base/precision_dispatch.hpp>
18 #include <ginkgo/core/base/type_traits.hpp>
19 #include <ginkgo/core/config/config.hpp>
20 #include <ginkgo/core/config/registry.hpp>
21 #include <ginkgo/core/factorization/par_ic.hpp>
22 #include <ginkgo/core/matrix/dense.hpp>
23 #include <ginkgo/core/solver/solver_traits.hpp>
24 #include <ginkgo/core/solver/triangular.hpp>
25 #include <ginkgo/core/stop/combined.hpp>
26 #include <ginkgo/core/stop/iteration.hpp>
27 #include <ginkgo/core/stop/residual_norm.hpp>
// Preconditioner operators. This chunk contains the `Ic` preconditioner and
// the config-parsing helpers it relies on.
31 namespace preconditioner {
35 template <
typename Type>
36 constexpr
bool support_ic_parse =
37 std::is_same_v<typename Type::l_solver_type, LinOp>;
/**
 * Fallback `ic_parse` overload, selected through SFINAE when
 * `support_ic_parse<Ic>` is false, i.e. when `Ic::l_solver_type` is not the
 * type-erased `LinOp`.
 *
 * Only the message of its body is visible in this chunk:
 * "preconditioner::Ic only supports limited type for parse.".
 * NOTE(review): the statement that emits this message is not visible here;
 * presumably it reports an error rather than returning parameters -- confirm.
 *
 * @param config        the configuration node
 * @param context       the config::registry
 * @param td_for_child  type descriptor, presumably for nested objects
 */
40 template <
typename Ic, std::enable_if_t<!support_ic_parse<Ic>>* =
nullptr>
41 typename Ic::parameters_type ic_parse(
42 const config::pnode& config,
const config::registry& context,
43 const config::type_descriptor& td_for_child)
46 "preconditioner::Ic only supports limited type for parse.");
/**
 * `ic_parse` overload enabled when `support_ic_parse<Ic>` is true, i.e. when
 * `Ic::l_solver_type` is the type-erased `LinOp`.
 *
 * Only declared here; the definition is outside this chunk. It returns an
 * `Ic::parameters_type`, presumably populated from `config` -- confirm
 * against the definition.
 *
 * @param config        the configuration node to read
 * @param context       the config::registry passed through to parsing
 * @param td_for_child  type descriptor, presumably used for nested objects
 * @return              the resulting `Ic::parameters_type`
 */
49 template <
typename Ic, std::enable_if_t<support_ic_parse<Ic>>* =
nullptr>
50 typename Ic::parameters_type ic_parse(
51 const config::pnode& config,
const config::registry& context,
52 const config::type_descriptor& td_for_child);
/**
 * Template head and member type aliases of the `Ic` preconditioner. The class
 * declaration line itself is not visible in this chunk.
 *
 * @tparam LSolverTypeOrValueType  either the solver type used for the first
 *                                 factor (default `solver::LowerTrs<>`) or a
 *                                 type that is not a Ginkgo LinOp. In the
 *                                 latter case `l_solver_type` becomes the
 *                                 type-erased `LinOp`; the parameter name
 *                                 suggests the argument is then the value
 *                                 type.
 * @tparam IndexType               index type, exposed as `index_type`
 *                                 (default `int32`)
 */
107 template <
typename LSolverTypeOrValueType = solver::LowerTrs<>,
108 typename IndexType =
int32>
// Solver type for the first factor: the template argument itself if it is a
// Ginkgo LinOp (gko::detail::is_ginkgo_linop), otherwise the type-erased LinOp.
115 using l_solver_type =
116 std::conditional_t<gko::detail::is_ginkgo_linop<LSolverTypeOrValueType>,
117 LSolverTypeOrValueType,
LinOp>;
// Transposing l_solver_type twice must give back l_solver_type. The transpose
// implementations rely on this: they store a transposed_type<lh_solver_type>
// into a pointer to const l_solver_type.
118 static_assert(std::is_same<gko::detail::transposed_type<
119 gko::detail::transposed_type<l_solver_type>>,
120 l_solver_type>::value,
121 "l_solver_type::transposed_type must be symmetric");
// Value type derived from the first template argument via
// gko::detail::get_value_type.
122 using value_type = gko::detail::get_value_type<LSolverTypeOrValueType>;
// Solver type for the second (lh) factor: the transposed type of l_solver_type.
123 using lh_solver_type = gko::detail::transposed_type<l_solver_type>;
124 using index_type = IndexType;
// Factory type for the first-factor solver. The declarator is on a line not
// visible here; presumably this is the `l_solver_factory` member that the
// deferred "l_solver" callback assigns.
134 std::shared_ptr<const gko::detail::factory_type<l_solver_type>>
// Deprecated setter taking a factory for the first-factor solver; callers are
// directed to with_l_solver instead.
142 GKO_DEPRECATED(
"use with_l_solver instead")
145 const
gko::detail::factory_type<l_solver_type>>
// Fragment of a setter, presumably `with_l_solver` (the replacement named in
// the deprecated setter's message); its head is not visible in this chunk.
// It stores the deferred generator in l_solver_generator and registers a
// deferred "l_solver" callback. When a generator was set, that callback
// materializes it on the executor `exec` into params.l_solver_factory.
158 const gko::detail::factory_type<l_solver_type>>
161 this->l_solver_generator = std::move(solver);
162 this->deferred_factories[
"l_solver"] = [](
const auto& exec,
164 if (!params.l_solver_generator.is_empty()) {
165 params.l_solver_factory =
166 params.l_solver_generator.on(exec);
// Deprecated factorization setter: it simply forwards to with_factorization.
172 GKO_DEPRECATED(
"use with_factorization instead")
176 return with_factorization(std::move(factorization));
// Stores the deferred factorization generator and registers a deferred
// "factorization" callback. When a generator was set, that callback
// materializes it on the executor `exec` into params.factorization_factory.
179 parameters_type& with_factorization(
182 this->factorization_generator = std::move(factorization);
183 this->deferred_factories[
"factorization"] = [](
const auto& exec,
185 if (!params.factorization_generator.is_empty()) {
186 params.factorization_factory =
187 params.factorization_generator.on(exec);
// Deferred generators recorded by the setters. The registered
// deferred_factories callbacks resolve them into the corresponding *_factory
// members once an executor is known. The first declarator (presumably
// l_solver_generator) is on a line not visible here.
194 deferred_factory_parameter<
195 const gko::detail::factory_type<l_solver_type>>
198 deferred_factory_parameter<const LinOpFactory> factorization_generator;
// Tail of the static config-parsing entry point; its head is not visible.
// The visible expression appears to be the default argument of td_for_child:
// a type descriptor built from this class's value_type and index_type.
225 config::make_type_descriptor<value_type, index_type>())
// Delegates to detail::ic_parse. Which overload is enabled depends on
// support_ic_parse<Ic>, so only instantiations with l_solver_type == LinOp
// reach the working overload.
228 return detail::ic_parse<Ic>(config, context, td_for_child);
// Fragment of what is presumably transpose(); its head is not visible. It
// builds a new operator `transposed` and swaps the roles of the two solvers:
// - the new l_solver_ is cast to transposed_type<lh_solver_type>, which the
//   class's static_assert makes equal to l_solver_type;
// - the new lh_solver_ is cast to transposed_type<l_solver_type>, i.e.
//   lh_solver_type.
// NOTE(review): the operands passed to as<>(...) are on lines not visible
// here; presumably the transposes of this object's lh and l solvers.
253 std::unique_ptr<transposed_type> transposed{
256 transposed->l_solver_ =
257 share(
as<gko::detail::transposed_type<lh_solver_type>>(
259 transposed->lh_solver_ =
260 share(
as<gko::detail::transposed_type<l_solver_type>>(
263 return std::move(transposed);
// Fragment of what is presumably conj_transpose(); its head is not visible.
// It mirrors the transpose fragment: the new l_solver_ is cast to
// transposed_type<lh_solver_type> (== l_solver_type) and the new lh_solver_
// to transposed_type<l_solver_type> (== lh_solver_type).
// NOTE(review): the operands passed to as<>(...) are on lines not visible
// here; presumably the conjugate transposes of this object's lh and l
// solvers.
268 std::unique_ptr<transposed_type> transposed{
271 transposed->l_solver_ =
272 share(
as<gko::detail::transposed_type<lh_solver_type>>(
274 transposed->lh_solver_ =
275 share(
as<gko::detail::transposed_type<l_solver_type>>(
278 return std::move(transposed);
// Body fragment of what is presumably the copy assignment operator; its head
// is not visible. Self-assignment is skipped. The solver shared_ptrs are
// copied, so both objects share the same solver instances afterwards, and the
// parameters are copied as well. Statements before, between and after these
// lines are not visible in this chunk.
288 if (&other !=
this) {
291 l_solver_ = other.l_solver_;
292 lh_solver_ = other.lh_solver_;
293 parameters_ = other.parameters_;
// Body fragment of what is presumably the move assignment operator.
// Self-assignment is skipped and the solver pointers are moved out of `other`.
310 if (&other !=
this) {
313 l_solver_ = std::move(other.l_solver_);
314 lh_solver_ = std::move(other.lh_solver_);
// NOTE(review): this compares other's executor with a local `exec` declared
// on a line not visible here. The branch body is also not visible; presumably
// it re-creates the moved solvers on this object's executor -- confirm.
316 if (other.get_executor() != exec) {
/**
 * Applies the preconditioner to `b`, writing the result to `x`.
 *
 * The work runs inside precision_dispatch_real_complex<value_type>. It
 * prepares the scratch vector cache_.intermediate from `b`, solves with
 * l_solver_ into that scratch vector, and then solves with lh_solver_ from
 * the scratch vector into `x`.
 */
338 void apply_impl(
const LinOp* b,
LinOp* x)
const override
341 precision_dispatch_real_complex<value_type>(
342 [&](
auto dense_b,
auto dense_x) {
343 this->set_cache_to(dense_b);
// cache_.intermediate now holds a copy of b; that is what l_solver_ finds in
// its output vector when it starts.
344 l_solver_->apply(dense_b, cache_.intermediate);
// lh_solver_ reports that it reads the incoming contents of its output as an
// initial guess, so x is seeded with the intermediate result instead of
// whatever it held before.
345 if (lh_solver_->apply_uses_initial_guess()) {
346 dense_x->copy_from(cache_.intermediate);
348 lh_solver_->apply(cache_.intermediate, dense_x);
// Tail of the signature of the scaled apply_impl. The preceding parameters
// are on a line not visible here; presumably alpha, b and beta, matching the
// lambda below.
// It copies b into cache_.intermediate via set_cache_to and applies l_solver_
// into that vector. It then forwards alpha and beta to lh_solver_'s
// four-argument apply, which writes into x.
// NOTE(review): no initial-guess seeding of x (apply_uses_initial_guess) is
// visible in this overload; x reaches lh_solver_ with its incoming contents --
// confirm this is intended.
354 LinOp* x)
const override
356 precision_dispatch_real_complex<value_type>(
357 [&](
auto dense_alpha,
auto dense_b,
auto dense_beta,
auto dense_x) {
358 this->set_cache_to(dense_b);
359 l_solver_->apply(dense_b, cache_.intermediate);
360 lh_solver_->apply(dense_alpha, cache_.intermediate, dense_beta,
/**
 * Constructor taking only an executor. The visible initializer forwards the
 * executor to the EnableLinOp<Ic> base.
 *
 * @param exec  the executor for the operator
 */
366 explicit Ic(std::shared_ptr<const Executor> exec)
367 : EnableLinOp<
Ic>(std::move(exec))
/**
 * Generates the preconditioner from `lin_op` using the parameters of
 * `factory`.
 *
 * Visible steps:
 * - `lin_op` is cast to Composition<value_type>. When that does not yield a
 *   composition (the guarding condition is not visible; presumably a null
 *   check), a factorization::ParIc<value_type, index_type> with
 *   with_both_factors(false) is built on lin_op's executor. Its result is
 *   then cast to Composition<value_type> with gko::as.
 * - The composition must hold one or two operators; otherwise
 *   GKO_NOT_SUPPORTED(comp) is invoked. The first operator is the first
 *   factor and must be square.
 * - l_solver_ can come from generate_default_solver. That call uses
 *   solver::LowerTrs<value_type, index_type> when l_solver_type is LinOp.
 *   Any branch using a user-provided factory is on lines not visible here.
 * - With two operators, the second is the lh factor and must have the same
 *   dimensions as the first. Otherwise lh_solver_ is produced from an
 *   operand on a line not visible here.
 *
 * @param factory  the factory whose parameters are copied into parameters_
 * @param lin_op   presumably the system matrix or an existing factorization
 *                 given as a Composition
 */
370 explicit Ic(
const Factory* factory, std::shared_ptr<const LinOp> lin_op)
372 parameters_{
factory->get_parameters()}
375 std::dynamic_pointer_cast<
const Composition<value_type>>(lin_op);
376 std::shared_ptr<const LinOp> l_factor;
380 auto exec = lin_op->get_executor();
384 factorization::ParIc<value_type, index_type>::build()
385 .with_both_factors(
false)
388 auto fact = std::shared_ptr<const LinOp>(
391 comp = gko::as<const Composition<value_type>>(fact);
// Only one- or two-factor compositions are accepted.
394 if (comp->get_operators().size() > 2 || comp->get_operators().empty()) {
395 GKO_NOT_SUPPORTED(comp);
397 l_factor = comp->get_operators()[0];
398 GKO_ASSERT_IS_SQUARE_MATRIX(l_factor);
405 l_solver_ = generate_default_solver<std::conditional_t<
406 std::is_same_v<l_solver_type, LinOp>,
407 solver::LowerTrs<value_type, index_type>, l_solver_type>>(
// A second factor is present: it must match the first one's dimensions and
// is used to build lh_solver_.
413 if (comp->get_operators().size() == 2) {
414 auto lh_factor = comp->get_operators()[1];
415 GKO_ASSERT_EQUAL_DIMENSIONS(l_factor, lh_factor);
417 lh_solver_ = as<lh_solver_type>(
// NOTE(review): assignment used when there is no second factor; its operand is
// on a line not visible here (presumably derived from l_solver_).
421 lh_solver_ = as<lh_solver_type>(
/**
 * Copies `b` into the scratch vector cache_.intermediate. The vector is
 * allocated first when it is still null; the allocation expression is on a
 * line not visible in this chunk.
 *
 * The method is const. It can write the cache because the cache struct is
 * declared mutable.
 *
 * @param b  the vector to copy into the scratch storage
 */
433 void set_cache_to(
const LinOp* b)
const
435 if (cache_.intermediate ==
nullptr) {
436 cache_.intermediate =
440 cache_.intermediate->copy_from(b);
/**
 * Builds a default solver of type SolverType for `mtx`. This overload is
 * enabled for solver types that accept stopping criteria
 * (solver::has_with_criteria) and are not the type-erased LinOp.
 *
 * Stopping criteria visible here:
 * - an iteration limit equal to the number of rows of `mtx`;
 * - a criterion configured with with_reduction_factor(default_reduce_residual).
 *   Its builder and the value of default_reduce_residual are on lines not
 *   visible in this chunk.
 *
 * @tparam SolverType  the solver type to build
 * @param exec         the executor, presumably passed to the solver factory
 *                     (that call is not visible here)
 * @param mtx          the matrix the solver is generated for
 * @return             a std::unique_ptr<SolverType>
 */
450 template <
typename SolverType>
451 static std::enable_if_t<solver::has_with_criteria<SolverType>::value &&
452 !std::is_same_v<SolverType, LinOp>,
453 std::unique_ptr<SolverType>>
454 generate_default_solver(
const std::shared_ptr<const Executor>& exec,
455 const std::shared_ptr<const LinOp>& mtx)
// NOTE(review): the row count is narrowed to unsigned int; if get_size()[0]
// is wider, row counts above UINT_MAX would wrap.
458 const unsigned int default_max_iters{
459 static_cast<unsigned int>(mtx->get_size()[0])};
461 return SolverType::build()
463 gko::stop::Iteration::build().with_max_iters(default_max_iters),
465 .with_reduction_factor(default_reduce_residual))
/**
 * Builds a default solver of type SolverType for `mtx` on `exec`. This
 * overload is enabled for solver types without stopping criteria
 * (!solver::has_with_criteria) that are not the type-erased LinOp. The
 * factory is built without setting any parameters.
 *
 * @tparam SolverType  the solver type to build
 * @param exec         the executor the solver factory is created on
 * @param mtx          the matrix the solver is generated for
 * @return             the solver generated for `mtx`
 */
473 template <
typename SolverType>
474 static std::enable_if_t<!solver::has_with_criteria<SolverType>::value &&
475 !std::is_same_v<SolverType, LinOp>,
476 std::unique_ptr<SolverType>>
477 generate_default_solver(
const std::shared_ptr<const Executor>& exec,
478 const std::shared_ptr<const LinOp>& mtx)
480 return SolverType::build().on(exec)->generate(mtx);
// Solver for the first factor and solver for the second (lh) factor;
// apply_impl applies them in that order. They are held through shared_ptr,
// and copy assignment copies the pointers, so copies share the solver
// instances.
484 std::shared_ptr<const l_solver_type> l_solver_{};
485 std::shared_ptr<const lh_solver_type> lh_solver_{};
/**
 * Scratch storage used by apply_impl. It is declared mutable so that const
 * member functions (set_cache_to, apply_impl) can fill it.
 *
 * Copy and move construction have empty bodies, so a newly constructed copy
 * starts with a null `intermediate`. Copy and move assignment only return
 * *this, so the target keeps its own `intermediate`. A missing vector is
 * allocated on demand by set_cache_to. Presumably this keeps each Ic's
 * scratch buffer private to that object.
 * NOTE(review): the move operations cannot throw but are not marked noexcept.
 */
496 mutable struct cache_struct {
497 cache_struct() =
default;
498 ~cache_struct() =
default;
499 cache_struct(
const cache_struct&) {}
500 cache_struct(cache_struct&&) {}
501 cache_struct&
operator=(
const cache_struct&) {
return *
this; }
502 cache_struct&
operator=(cache_struct&&) {
return *
this; }
// Intermediate vector between the two triangular solves; null until first use.
503 std::unique_ptr<LinOp> intermediate{};
512 #endif // GKO_PUBLIC_CORE_PRECONDITIONER_IC_HPP_