5 #ifndef GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_
6 #define GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_
10 #include <type_traits>
13 #include <ginkgo/core/base/abstract_factory.hpp>
14 #include <ginkgo/core/base/composition.hpp>
15 #include <ginkgo/core/base/exception.hpp>
16 #include <ginkgo/core/base/exception_helpers.hpp>
17 #include <ginkgo/core/base/lin_op.hpp>
18 #include <ginkgo/core/base/precision_dispatch.hpp>
19 #include <ginkgo/core/base/std_extensions.hpp>
20 #include <ginkgo/core/config/config.hpp>
21 #include <ginkgo/core/config/registry.hpp>
22 #include <ginkgo/core/factorization/par_ilu.hpp>
23 #include <ginkgo/core/matrix/dense.hpp>
24 #include <ginkgo/core/preconditioner/isai.hpp>
25 #include <ginkgo/core/preconditioner/utils.hpp>
26 #include <ginkgo/core/solver/gmres.hpp>
27 #include <ginkgo/core/solver/ir.hpp>
28 #include <ginkgo/core/solver/solver_traits.hpp>
29 #include <ginkgo/core/solver/triangular.hpp>
30 #include <ginkgo/core/stop/combined.hpp>
31 #include <ginkgo/core/stop/iteration.hpp>
32 #include <ginkgo/core/stop/residual_norm.hpp>
36 namespace preconditioner {
40 template <
typename LSolverType,
typename USolverType>
41 constexpr
bool support_ilu_parse =
42 std::is_same<typename USolverType::transposed_type, LSolverType>::value &&
43 (is_instantiation_of<LSolverType, solver::LowerTrs>::value ||
44 is_instantiation_of<LSolverType, solver::Ir>::value ||
45 is_instantiation_of<LSolverType, solver::Gmres>::value ||
46 is_instantiation_of<LSolverType, preconditioner::LowerIsai>::value);
49 template <
typename Ilu,
50 std::enable_if_t<!support_ilu_parse<
typename Ilu::l_solver_type,
51 typename Ilu::u_solver_type>>* =
53 typename Ilu::parameters_type ilu_parse(
54 const config::pnode& config,
const config::registry& context,
55 const config::type_descriptor& td_for_child)
58 "preconditioner::Ilu only supports limited type for parse.");
63 std::enable_if_t<support_ilu_parse<
typename Ilu::l_solver_type,
64 typename Ilu::u_solver_type>>* =
nullptr>
65 typename Ilu::parameters_type ilu_parse(
66 const config::pnode& config,
const config::registry& context,
67 const config::type_descriptor& td_for_child);
121 template <
typename LSolverType = solver::LowerTrs<>,
122 typename USolverType = solver::UpperTrs<>,
bool ReverseApply = false,
123 typename IndexType =
int32>
125 Ilu<LSolverType, USolverType, ReverseApply, IndexType>>,
132 std::is_same<
typename LSolverType::value_type,
133 typename USolverType::value_type>::value,
134 "Both the L- and the U-solver must use the same `value_type`!");
135 using value_type =
typename LSolverType::value_type;
136 using l_solver_type = LSolverType;
137 using u_solver_type = USolverType;
138 static constexpr
bool performs_reverse_apply = ReverseApply;
139 using index_type = IndexType;
141 Ilu<
typename USolverType::transposed_type,
142 typename LSolverType::transposed_type, ReverseApply, IndexType>;
151 std::shared_ptr<const typename l_solver_type::Factory>
157 std::shared_ptr<const typename u_solver_type::Factory>
165 GKO_DEPRECATED(
"use with_l_solver instead")
170 return with_l_solver(std::move(solver));
173 parameters_type& with_l_solver(
177 this->l_solver_generator = std::move(solver);
178 this->deferred_factories[
"l_solver"] = [](
const auto& exec,
180 if (!params.l_solver_generator.is_empty()) {
181 params.l_solver_factory =
182 params.l_solver_generator.on(exec);
188 GKO_DEPRECATED(
"use with_u_solver instead")
189 parameters_type& with_u_solver_factory(
190 deferred_factory_parameter<const typename u_solver_type::Factory>
193 return with_u_solver(std::move(solver));
196 parameters_type& with_u_solver(
197 deferred_factory_parameter<const typename u_solver_type::Factory>
200 this->u_solver_generator = std::move(solver);
201 this->deferred_factories[
"u_solver"] = [](
const auto& exec,
203 if (!params.u_solver_generator.is_empty()) {
204 params.u_solver_factory =
205 params.u_solver_generator.on(exec);
211 GKO_DEPRECATED(
"use with_factorization instead")
212 parameters_type& with_factorization_factory(
213 deferred_factory_parameter<const LinOpFactory> factorization)
215 return with_factorization(std::move(factorization));
218 parameters_type& with_factorization(
219 deferred_factory_parameter<const LinOpFactory> factorization)
221 this->factorization_generator = std::move(factorization);
222 this->deferred_factories[
"factorization"] = [](
const auto& exec,
224 if (!params.factorization_generator.is_empty()) {
225 params.factorization_factory =
226 params.factorization_generator.on(exec);
233 deferred_factory_parameter<const typename l_solver_type::Factory>
236 deferred_factory_parameter<const typename u_solver_type::Factory>
239 deferred_factory_parameter<const LinOpFactory> factorization_generator;
265 config::make_type_descriptor<value_type, index_type>())
267 return detail::ilu_parse<Ilu>(config, context, td_for_child);
292 std::unique_ptr<transposed_type> transposed{
295 transposed->l_solver_ =
296 share(as<typename u_solver_type::transposed_type>(
298 transposed->u_solver_ =
299 share(as<typename l_solver_type::transposed_type>(
302 return std::move(transposed);
307 std::unique_ptr<transposed_type> transposed{
310 transposed->l_solver_ =
311 share(as<typename u_solver_type::transposed_type>(
313 transposed->u_solver_ =
314 share(as<typename l_solver_type::transposed_type>(
317 return std::move(transposed);
327 if (&other !=
this) {
330 l_solver_ = other.l_solver_;
331 u_solver_ = other.u_solver_;
332 parameters_ = other.parameters_;
349 if (&other !=
this) {
352 l_solver_ = std::move(other.l_solver_);
353 u_solver_ = std::move(other.u_solver_);
355 if (other.get_executor() != exec) {
377 void apply_impl(
const LinOp* b,
LinOp* x)
const override
380 precision_dispatch_real_complex<value_type>(
381 [&](
auto dense_b,
auto dense_x) {
382 this->set_cache_to(dense_b);
384 l_solver_->apply(dense_b, cache_.intermediate);
385 if (u_solver_->apply_uses_initial_guess()) {
386 dense_x->copy_from(cache_.intermediate);
388 u_solver_->apply(cache_.intermediate, dense_x);
390 u_solver_->apply(dense_b, cache_.intermediate);
391 if (l_solver_->apply_uses_initial_guess()) {
392 dense_x->copy_from(cache_.intermediate);
394 l_solver_->apply(cache_.intermediate, dense_x);
401 LinOp* x)
const override
403 precision_dispatch_real_complex<value_type>(
404 [&](
auto dense_alpha,
auto dense_b,
auto dense_beta,
auto dense_x) {
405 this->set_cache_to(dense_b);
407 l_solver_->apply(dense_b, cache_.intermediate);
408 u_solver_->apply(dense_alpha, cache_.intermediate,
409 dense_beta, dense_x);
411 u_solver_->apply(dense_b, cache_.intermediate);
412 l_solver_->apply(dense_alpha, cache_.intermediate,
413 dense_beta, dense_x);
419 explicit Ilu(std::shared_ptr<const Executor> exec)
420 : EnableLinOp<
Ilu>(std::move(exec))
423 explicit Ilu(
const Factory* factory, std::shared_ptr<const LinOp> lin_op)
425 parameters_{
factory->get_parameters()}
428 std::dynamic_pointer_cast<
const Composition<value_type>>(lin_op);
429 std::shared_ptr<const LinOp> l_factor;
430 std::shared_ptr<const LinOp> u_factor;
434 auto exec = lin_op->get_executor();
437 factorization::ParIlu<value_type, index_type>::build().on(
440 auto fact = std::shared_ptr<const LinOp>(
444 std::dynamic_pointer_cast<
const Composition<value_type>>(fact);
446 GKO_NOT_SUPPORTED(comp);
449 if (comp->get_operators().size() == 2) {
450 l_factor = comp->get_operators()[0];
451 u_factor = comp->get_operators()[1];
453 GKO_NOT_SUPPORTED(comp);
455 GKO_ASSERT_EQUAL_DIMENSIONS(l_factor, u_factor);
461 l_solver_ = generate_default_solver<l_solver_type>(exec, l_factor);
466 u_solver_ = generate_default_solver<u_solver_type>(exec, u_factor);
479 void set_cache_to(
const LinOp* b)
const
481 if (cache_.intermediate ==
nullptr) {
482 cache_.intermediate =
486 cache_.intermediate->copy_from(b);
497 template <
typename SolverType>
498 static std::enable_if_t<solver::has_with_criteria<SolverType>::value,
499 std::unique_ptr<SolverType>>
500 generate_default_solver(
const std::shared_ptr<const Executor>& exec,
501 const std::shared_ptr<const LinOp>& mtx)
504 const unsigned int default_max_iters{
505 static_cast<unsigned int>(mtx->get_size()[0])};
507 return SolverType::build()
509 gko::stop::Iteration::build().with_max_iters(default_max_iters),
511 .with_reduction_factor(default_reduce_residual))
519 template <
typename SolverType>
520 static std::enable_if_t<!solver::has_with_criteria<SolverType>::value,
521 std::unique_ptr<SolverType>>
522 generate_default_solver(
const std::shared_ptr<const Executor>& exec,
523 const std::shared_ptr<const LinOp>& mtx)
525 return SolverType::build().on(exec)->generate(mtx);
529 std::shared_ptr<const l_solver_type> l_solver_{};
530 std::shared_ptr<const u_solver_type> u_solver_{};
541 mutable struct cache_struct {
542 cache_struct() =
default;
543 ~cache_struct() =
default;
544 cache_struct(
const cache_struct&) {}
545 cache_struct(cache_struct&&) {}
546 cache_struct&
operator=(
const cache_struct&) {
return *
this; }
547 cache_struct&
operator=(cache_struct&&) {
return *
this; }
548 std::unique_ptr<LinOp> intermediate{};
557 #endif // GKO_PUBLIC_CORE_PRECONDITIONER_ILU_HPP_