5 #ifndef GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_
6 #define GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_
10 #include <type_traits>
12 #include <ginkgo/core/base/abstract_factory.hpp>
13 #include <ginkgo/core/base/composition.hpp>
14 #include <ginkgo/core/base/exception.hpp>
15 #include <ginkgo/core/base/exception_helpers.hpp>
16 #include <ginkgo/core/base/lin_op.hpp>
17 #include <ginkgo/core/base/precision_dispatch.hpp>
18 #include <ginkgo/core/base/type_traits.hpp>
19 #include <ginkgo/core/config/config.hpp>
20 #include <ginkgo/core/config/registry.hpp>
21 #include <ginkgo/core/factorization/par_ilu.hpp>
22 #include <ginkgo/core/matrix/dense.hpp>
23 #include <ginkgo/core/solver/solver_traits.hpp>
24 #include <ginkgo/core/solver/triangular.hpp>
25 #include <ginkgo/core/stop/combined.hpp>
26 #include <ginkgo/core/stop/iteration.hpp>
27 #include <ginkgo/core/stop/residual_norm.hpp>
31 namespace preconditioner {
35 template <
typename Type>
36 constexpr
bool support_ilu_parse =
37 std::is_same_v<typename Type::l_solver_type, LinOp>&&
38 std::is_same_v<typename Type::u_solver_type, LinOp>;
// Overload of ilu_parse that participates in overload resolution only when
// support_ilu_parse<Ilu> is false (see the enable_if_t condition below).
// The only visible part of its body is the diagnostic message on the last
// line.
// NOTE(review): the statement that consumes the message is not visible in
// this excerpt - presumably an error-raising macro; confirm.
41 template <
typename Ilu, std::enable_if_t<!support_ilu_parse<Ilu>>* =
nullptr>
42 typename Ilu::parameters_type ilu_parse(
43 const config::pnode& config,
const config::registry& context,
44 const config::type_descriptor& td_for_child)
47 "preconditioner::Ilu only supports limited type for parse.");
// Overload of ilu_parse that participates in overload resolution only when
// support_ilu_parse<Ilu> is true. It is only declared here and returns
// `Ilu::parameters_type`.
// NOTE(review): the definition is not visible in this excerpt - presumably
// it lives in a source file; keep this template head in sync with it.
50 template <
typename Ilu, std::enable_if_t<support_ilu_parse<Ilu>>* =
nullptr>
51 typename Ilu::parameters_type ilu_parse(
52 const config::pnode& config,
const config::registry& context,
53 const config::type_descriptor& td_for_child);
// Template head, base class and public type aliases of the Ilu
// preconditioner (fragment - the line carrying the class keyword and name is
// not visible in this excerpt).
//
// Template parameters and their defaults:
//   LSolverTypeOrValueType - defaults to solver::LowerTrs<>
//   USolverTypeOrValueType - defaults to the transposed type of the L
//                            parameter
//   ReverseApply           - defaults to false
//   IndexType              - defaults to int32
110 template <
typename LSolverTypeOrValueType = solver::LowerTrs<>,
111 typename USolverTypeOrValueType =
112 gko::detail::transposed_type<LSolverTypeOrValueType>,
113 bool ReverseApply = false,
typename IndexType =
int32>
115 :
public EnableLinOp<Ilu<LSolverTypeOrValueType, USolverTypeOrValueType,
116 ReverseApply, IndexType>>,
// Compile-time check that the value types deduced from both template
// arguments agree (see the message below).
// NOTE(review): the opening `static_assert(` line is not visible here.
123 std::is_same_v<gko::detail::get_value_type<LSolverTypeOrValueType>,
124 gko::detail::get_value_type<USolverTypeOrValueType>>,
125 "Both the L- and the U-solver must use the same `value_type`!");
// The value type is taken from the L-side template argument.
126 using value_type = gko::detail::get_value_type<LSolverTypeOrValueType>;
// If the template argument is a Ginkgo LinOp type it is used as the solver
// type directly; otherwise the solver type falls back to the generic LinOp.
127 using l_solver_type =
128 std::conditional_t<gko::detail::is_ginkgo_linop<LSolverTypeOrValueType>,
129 LSolverTypeOrValueType,
LinOp>;
// Same selection rule for the U-side.
130 using u_solver_type =
131 std::conditional_t<gko::detail::is_ginkgo_linop<USolverTypeOrValueType>,
132 USolverTypeOrValueType,
LinOp>;
// Exposes the ReverseApply template argument as a class constant.
133 static constexpr
bool performs_reverse_apply = ReverseApply;
134 using index_type = IndexType;
// NOTE(review): the next line is an isolated piece of a template argument
// list - presumably part of the `transposed_type` alias used further below;
// confirm against the full file.
137 gko::detail::transposed_type<LSolverTypeOrValueType>, ReverseApply,
// Types of two shared, read-only solver-factory members; their names are not
// visible in this excerpt.
// NOTE(review): presumably `l_solver_factory` and `u_solver_factory`, which
// are assigned in the deferred-factory callbacks below; confirm.
147 std::shared_ptr<const gko::detail::factory_type<l_solver_type>>
153 std::shared_ptr<const gko::detail::factory_type<u_solver_type>>
// Setters for the L-solver factory (fragment - the function names and most
// of the signatures are not visible in this excerpt).
// The first declaration is marked deprecated; its message points users to
// `with_l_solver`.
161 GKO_DEPRECATED(
"use with_l_solver instead")
164 const
gko::detail::factory_type<l_solver_type>>
// NOTE(review): the next line is the parameter type of a second setter -
// presumably `with_l_solver` itself; confirm.
177 const gko::detail::factory_type<l_solver_type>>
// Store the argument so it can be instantiated later.
180 this->l_solver_generator = std::move(solver);
// Register a callback under the key "l_solver": if the stored generator is
// not empty, it is instantiated on `exec` and the result is written to
// `l_solver_factory`.
181 this->deferred_factories[
"l_solver"] = [](
const auto& exec,
183 if (!params.l_solver_generator.is_empty()) {
184 params.l_solver_factory =
185 params.l_solver_generator.on(exec);
// Setters for the U-solver factory (fragment - the function names and most
// of the signatures are not visible in this excerpt).
// The first declaration is marked deprecated; its message points users to
// `with_u_solver`.
191 GKO_DEPRECATED(
"use with_u_solver instead")
194 const
gko::detail::factory_type<u_solver_type>>
// NOTE(review): the next line is the parameter type of a second setter -
// presumably `with_u_solver` itself; confirm.
207 const gko::detail::factory_type<u_solver_type>>
// Store the argument so it can be instantiated later.
210 this->u_solver_generator = std::move(solver);
// Register a callback under the key "u_solver": if the stored generator is
// not empty, it is instantiated on `exec` and the result is written to
// `u_solver_factory`.
211 this->deferred_factories[
"u_solver"] = [](
const auto& exec,
213 if (!params.u_solver_generator.is_empty()) {
214 params.u_solver_factory =
215 params.u_solver_generator.on(exec);
// Setters for the factorization factory (fragment - parts of both signatures
// are not visible in this excerpt).
// The deprecated variant simply forwards its argument to
// `with_factorization`.
221 GKO_DEPRECATED(
"use with_factorization instead")
225 return with_factorization(std::move(factorization));
// Non-deprecated setter; returns a reference to the parameters object.
228 parameters_type& with_factorization(
// Store the argument so it can be instantiated later.
231 this->factorization_generator = std::move(factorization);
// Register a callback under the key "factorization": if the stored
// generator is not empty, it is instantiated on `exec` and the result is
// written to `factorization_factory`.
232 this->deferred_factories[
"factorization"] = [](
const auto& exec,
234 if (!params.factorization_generator.is_empty()) {
235 params.factorization_factory =
236 params.factorization_generator.on(exec);
// Deferred factory parameters backing the setters above.
// NOTE(review): the names of the first two members are not visible in this
// excerpt - presumably `l_solver_generator` and `u_solver_generator`, which
// the setters assign to; confirm.
243 deferred_factory_parameter<
244 const gko::detail::factory_type<l_solver_type>>
247 deferred_factory_parameter<
248 const gko::detail::factory_type<u_solver_type>>
// Deferred parameter for the factorization factory (a generic LinOpFactory).
251 deferred_factory_parameter<const LinOpFactory> factorization_generator;
// Tail of a configuration-parsing function (fragment - its signature is not
// visible in this excerpt).
// NOTE(review): the first line looks like a default argument that builds a
// type descriptor from this class's value_type and index_type; confirm.
277 config::make_type_descriptor<value_type, index_type>())
// Delegates to detail::ilu_parse, instantiated for this Ilu type.
280 return detail::ilu_parse<Ilu>(config, context, td_for_child);
// Body fragment of a function that builds a `transposed_type` object (the
// function signature and the argument expressions are not visible in this
// excerpt).
// NOTE(review): presumably the `transpose()` override; confirm.
305 std::unique_ptr<transposed_type> transposed{
// The solver roles swap: the new object's L-solver is obtained as the
// transposed type of this class's U-solver type ...
308 transposed->l_solver_ =
309 share(
as<gko::detail::transposed_type<u_solver_type>>(
// ... and its U-solver as the transposed type of the L-solver type.
311 transposed->u_solver_ =
312 share(
as<gko::detail::transposed_type<l_solver_type>>(
315 return std::move(transposed);
// Second body fragment with the same visible shape as the previous one: it
// builds a `transposed_type` object with swapped solver roles (the function
// signature and the argument expressions are not visible in this excerpt).
// NOTE(review): presumably the `conj_transpose()` override; confirm.
320 std::unique_ptr<transposed_type> transposed{
// New L-solver: cast to the transposed type of this class's U-solver type.
323 transposed->l_solver_ =
324 share(
as<gko::detail::transposed_type<u_solver_type>>(
// New U-solver: cast to the transposed type of this class's L-solver type.
326 transposed->u_solver_ =
327 share(
as<gko::detail::transposed_type<l_solver_type>>(
330 return std::move(transposed);
// Body fragment of a copy-style assignment (signature not visible in this
// excerpt): guarded against self-assignment, it copies both solver pointers
// and the parameters from `other`.
340 if (&other !=
this) {
343 l_solver_ = other.l_solver_;
344 u_solver_ = other.u_solver_;
345 parameters_ = other.parameters_;
// Body fragment of a move-style assignment (signature not visible in this
// excerpt): guarded against self-assignment, it moves both solver pointers
// out of `other`.
362 if (&other !=
this) {
365 l_solver_ = std::move(other.l_solver_);
366 u_solver_ = std::move(other.u_solver_);
// Special handling when `other` lives on a different executor than `exec`.
// NOTE(review): the body of this branch is not visible in this excerpt.
368 if (other.get_executor() != exec) {
// Applies the preconditioner to `b` and writes the result to `x` by running
// two solves back to back through the cached intermediate vector.
// NOTE(review): the lines selecting between the two solve orders below are
// not visible in this excerpt - presumably chosen by `ReverseApply`; confirm.
390 void apply_impl(
const LinOp* b,
LinOp* x)
const override
// Dispatch to dense views of b and x for this class's value_type.
393 precision_dispatch_real_complex<value_type>(
394 [&](
auto dense_b,
auto dense_x) {
// Prepare the intermediate vector before the first solve.
395 this->set_cache_to(dense_b);
// Order 1: L-solve into the intermediate, then U-solve into x.
397 l_solver_->apply(dense_b, cache_.intermediate);
// If the second solver uses x as an initial guess, seed x with the
// intermediate result first.
398 if (u_solver_->apply_uses_initial_guess()) {
399 dense_x->copy_from(cache_.intermediate);
401 u_solver_->apply(cache_.intermediate, dense_x);
// Order 2: U-solve into the intermediate, then L-solve into x.
403 u_solver_->apply(dense_b, cache_.intermediate);
404 if (l_solver_->apply_uses_initial_guess()) {
405 dense_x->copy_from(cache_.intermediate);
407 l_solver_->apply(cache_.intermediate, dense_x);
// Scaled variant of apply_impl (fragment - the start of the signature is not
// visible in this excerpt). The first solve writes to the cached
// intermediate vector; the second solve is the four-argument apply taking
// alpha and beta, with x as its output argument.
// NOTE(review): the lines selecting between the two solve orders below are
// not visible in this excerpt - presumably chosen by `ReverseApply`; confirm.
414 LinOp* x)
const override
// Dispatch to dense views of alpha, b, beta and x for this value_type.
416 precision_dispatch_real_complex<value_type>(
417 [&](
auto dense_alpha,
auto dense_b,
auto dense_beta,
auto dense_x) {
// Prepare the intermediate vector before the first solve.
418 this->set_cache_to(dense_b);
// Order 1: plain L-solve, then scaled U-solve into x.
420 l_solver_->apply(dense_b, cache_.intermediate);
421 u_solver_->apply(dense_alpha, cache_.intermediate,
422 dense_beta, dense_x);
// Order 2: plain U-solve, then scaled L-solve into x.
424 u_solver_->apply(dense_b, cache_.intermediate);
425 l_solver_->apply(dense_alpha, cache_.intermediate,
426 dense_beta, dense_x);
// Constructor taking only an executor, which it forwards to the EnableLinOp
// base (fragment - the constructor body is not visible in this excerpt).
432 explicit Ilu(std::shared_ptr<const Executor> exec)
433 : EnableLinOp<
Ilu>(std::move(exec))
// Constructor used when generating the preconditioner from a factory and a
// system operator (fragment - several lines, including the branch
// conditions, are not visible in this excerpt).
436 explicit Ilu(
const Factory* factory, std::shared_ptr<const LinOp> lin_op)
// Take over the factory's parameters.
438 parameters_{
factory->get_parameters()}
// Check whether `lin_op` already is a Composition of this value_type.
441 std::dynamic_pointer_cast<
const Composition<value_type>>(lin_op);
442 std::shared_ptr<const LinOp> l_factor;
443 std::shared_ptr<const LinOp> u_factor;
447 auto exec = lin_op->get_executor();
// NOTE(review): the condition guarding the next lines is not visible -
// presumably a ParIlu factorization is built as a fallback and its result
// `fact` is then used as the composition; confirm.
450 factorization::ParIlu<value_type, index_type>::build().on(
453 auto fact = std::shared_ptr<const LinOp>(
456 comp = as<const Composition<value_type>>(fact);
// A composition with exactly two operators supplies the factors:
// operator 0 is taken as L and operator 1 as U.
458 if (comp->get_operators().size() == 2) {
459 l_factor = comp->get_operators()[0];
460 u_factor = comp->get_operators()[1];
// NOTE(review): presumably the `else` branch - other operator counts are
// reported as not supported; the `else` line itself is not visible.
462 GKO_NOT_SUPPORTED(comp);
// Both factors must have matching dimensions.
464 GKO_ASSERT_EQUAL_DIMENSIONS(l_factor, u_factor);
// Default L-solver: LowerTrs when l_solver_type is the generic LinOp,
// otherwise l_solver_type itself.
471 l_solver_ = generate_default_solver<std::conditional_t<
472 std::is_same_v<l_solver_type, LinOp>,
473 solver::LowerTrs<value_type, index_type>, l_solver_type>>(
// Default U-solver: UpperTrs when u_solver_type is the generic LinOp,
// otherwise u_solver_type itself.
480 u_solver_ = generate_default_solver<std::conditional_t<
481 std::is_same_v<u_solver_type, LinOp>,
482 solver::UpperTrs<value_type, index_type>, u_solver_type>>(
// Makes sure the cached intermediate vector exists and then copies `b` into
// it. It is const, so it relies on the cache member being mutable.
496 void set_cache_to(
const LinOp* b)
const
// Create the intermediate vector on first use.
// NOTE(review): the right-hand side of the assignment is not visible in
// this excerpt.
498 if (cache_.intermediate ==
nullptr) {
499 cache_.intermediate =
503 cache_.intermediate->copy_from(b);
// Overload enabled when solver::has_with_criteria<SolverType>::value is
// true. It builds a SolverType with explicit stopping criteria and returns
// it as a unique_ptr (fragment - several lines of the builder chain are not
// visible in this excerpt).
514 template <
typename SolverType>
515 static std::enable_if_t<solver::has_with_criteria<SolverType>::value,
516 std::unique_ptr<SolverType>>
517 generate_default_solver(
const std::shared_ptr<const Executor>& exec,
518 const std::shared_ptr<const LinOp>& mtx)
// The default iteration limit is the first entry of mtx's size, narrowed to
// unsigned int.
522 const unsigned int default_max_iters{
523 static_cast<unsigned int>(mtx->get_size()[0])};
525 return SolverType::build()
// Stopping criteria: an iteration limit plus a criterion configured with a
// reduction factor.
// NOTE(review): `default_reduce_residual` is defined on a line that is not
// visible in this excerpt.
527 gko::stop::Iteration::build().with_max_iters(default_max_iters),
529 .with_reduction_factor(default_reduce_residual))
// Overload enabled when solver::has_with_criteria<SolverType>::value is
// false. It builds the solver factory with default settings on `exec` and
// generates the solver for `mtx`.
537 template <
typename SolverType>
538 static std::enable_if_t<!solver::has_with_criteria<SolverType>::value,
539 std::unique_ptr<SolverType>>
540 generate_default_solver(
const std::shared_ptr<const Executor>& exec,
541 const std::shared_ptr<const LinOp>& mtx)
543 return SolverType::build().on(exec)->generate(mtx);
// Shared, read-only handles to the two solvers used by apply_impl; both
// start out empty.
547 std::shared_ptr<const l_solver_type> l_solver_{};
548 std::shared_ptr<const u_solver_type> u_solver_{};
// Holder for the intermediate vector used between the two solves (fragment -
// the closing line of the declaration is not visible in this excerpt).
// The member is declared mutable so that const member functions can modify
// it.
559 mutable struct cache_struct {
560 cache_struct() =
default;
561 ~cache_struct() =
default;
// Copy and move construction have empty bodies: the new object does not
// take over the other object's `intermediate` and starts with its default
// (empty) value.
562 cache_struct(
const cache_struct&) {}
563 cache_struct(cache_struct&&) {}
// Copy and move assignment leave this object's contents untouched.
564 cache_struct&
operator=(
const cache_struct&) {
return *
this; }
565 cache_struct&
operator=(cache_struct&&) {
return *
this; }
// The cached vector itself; empty until it is first set.
566 std::unique_ptr<LinOp> intermediate{};
575 #endif // GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_