5 #ifndef GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
6 #define GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
9 #include <ginkgo/config.hpp>
10 #include <ginkgo/core/base/math.hpp>
11 #include <ginkgo/core/base/temporary_conversion.hpp>
12 #include <ginkgo/core/distributed/vector.hpp>
13 #include <ginkgo/core/matrix/dense.hpp>
43 template <
typename ValueType,
typename Ptr>
44 detail::temporary_conversion<std::conditional_t<
45 std::is_const<detail::pointee<Ptr>>::value,
const matrix::Dense<ValueType>,
46 matrix::Dense<ValueType>>>
49 using Pointee = detail::pointee<Ptr>;
54 using MaybeConstDense =
55 std::conditional_t<std::is_const<Pointee>::value,
const Dense, Dense>;
56 auto result = detail::temporary_conversion<
57 MaybeConstDense>::template create<NextDense, NextNextDense>(matrix);
59 GKO_NOT_SUPPORTED(matrix);
79 template <
typename ValueType,
typename Function,
typename... Args>
82 fn(make_temporary_conversion<ValueType>(linops).get()...);
95 template <
typename ValueType,
typename Function>
102 auto complex_to_real =
103 !(is_complex<ValueType>() ||
105 if (complex_to_real) {
106 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
107 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
112 fn(dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
113 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
115 precision_dispatch<ValueType>(fn, in, out);
129 template <
typename ValueType,
typename Function>
137 auto complex_to_real =
138 !(is_complex<ValueType>() ||
140 if (complex_to_real) {
141 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
142 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
143 auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
148 fn(dense_alpha.get(),
149 dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
150 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
152 precision_dispatch<ValueType>(fn, alpha, in, out);
166 template <
typename ValueType,
typename Function>
175 auto complex_to_real =
176 !(is_complex<ValueType>() ||
178 if (complex_to_real) {
179 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
180 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
181 auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
182 auto dense_beta = make_temporary_conversion<ValueType>(beta);
187 fn(dense_alpha.get(),
188 dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
190 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
192 precision_dispatch<ValueType>(fn, alpha, in, beta, out);
226 template <
typename ValueType,
typename Function>
229 #ifdef GINKGO_MIXED_PRECISION
233 auto dispatch_out_vector = [&](
auto dense_in) {
234 if (
auto dense_out = dynamic_cast<fst_type*>(out)) {
235 fn(dense_in, dense_out);
236 }
else if (
auto dense_out = dynamic_cast<snd_type*>(out)) {
237 fn(dense_in, dense_out);
238 }
else if (
auto dense_out = dynamic_cast<trd_type*>(out)) {
239 fn(dense_in, dense_out);
241 GKO_NOT_SUPPORTED(out);
244 if (
auto dense_in = dynamic_cast<const fst_type*>(in)) {
245 dispatch_out_vector(dense_in);
246 }
else if (
auto dense_in = dynamic_cast<const snd_type*>(in)) {
247 dispatch_out_vector(dense_in);
248 }
else if (
auto dense_in = dynamic_cast<const trd_type*>(in)) {
249 dispatch_out_vector(dense_in);
251 GKO_NOT_SUPPORTED(in);
254 precision_dispatch<ValueType>(fn, in, out);
268 template <
typename ValueType,
typename Function,
269 std::enable_if_t<is_complex<ValueType>()>* =
nullptr>
273 #ifdef GINKGO_MIXED_PRECISION
274 mixed_precision_dispatch<ValueType>(fn, in, out);
276 precision_dispatch<ValueType>(fn, in, out);
281 template <
typename ValueType,
typename Function,
282 std::enable_if_t<!is_complex<ValueType>()>* =
nullptr>
286 #ifdef GINKGO_MIXED_PRECISION
287 if (!
dynamic_cast<const ConvertibleTo<matrix::Dense<>
>*>(in)) {
288 mixed_precision_dispatch<to_complex<ValueType>>(
289 [&fn](
auto dense_in,
auto dense_out) {
290 fn(dense_in->create_real_view().get(),
291 dense_out->create_real_view().get());
295 mixed_precision_dispatch<ValueType>(fn, in, out);
298 precision_dispatch_real_complex<ValueType>(fn, in, out);
303 namespace experimental {
309 namespace distributed {
337 template <
typename ValueType>
342 gko::detail::temporary_conversion<Vector<ValueType>>::template create<
345 GKO_NOT_SUPPORTED(matrix);
354 template <
typename ValueType>
355 gko::detail::temporary_conversion<const Vector<ValueType>>
358 auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
359 template create<Vector<next_precision_base<ValueType>>>(matrix);
361 GKO_NOT_SUPPORTED(matrix);
381 template <
typename ValueType,
typename Function,
typename... Args>
385 GKO_NOT_SUPPORTED(
nullptr);
387 fn(distributed::make_temporary_conversion<ValueType>(linops).get()...);
401 template <
typename ValueType,
typename Function>
405 GKO_NOT_SUPPORTED(
nullptr);
407 auto complex_to_real = !(
408 is_complex<ValueType>() ||
411 if (complex_to_real) {
413 distributed::make_temporary_conversion<to_complex<ValueType>>(
416 distributed::make_temporary_conversion<to_complex<ValueType>>(
422 fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
423 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
425 distributed::precision_dispatch<ValueType>(fn, in, out);
434 template <
typename ValueType,
typename Function>
439 GKO_NOT_SUPPORTED(
nullptr);
441 auto complex_to_real = !(
442 is_complex<ValueType>() ||
445 if (complex_to_real) {
447 distributed::make_temporary_conversion<to_complex<ValueType>>(
450 distributed::make_temporary_conversion<to_complex<ValueType>>(
452 auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
457 fn(dense_alpha.get(),
458 dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
459 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
461 fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
462 distributed::make_temporary_conversion<ValueType>(in).get(),
463 distributed::make_temporary_conversion<ValueType>(out).get());
472 template <
typename ValueType,
typename Function>
478 GKO_NOT_SUPPORTED(
nullptr);
480 auto complex_to_real = !(
481 is_complex<ValueType>() ||
484 if (complex_to_real) {
486 distributed::make_temporary_conversion<to_complex<ValueType>>(
489 distributed::make_temporary_conversion<to_complex<ValueType>>(
491 auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
492 auto dense_beta = gko::make_temporary_conversion<ValueType>(beta);
497 fn(dense_alpha.get(),
498 dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
500 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
502 fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
503 distributed::make_temporary_conversion<ValueType>(in).get(),
504 gko::make_temporary_conversion<ValueType>(beta).get(),
505 distributed::make_temporary_conversion<ValueType>(out).get());
527 template <
typename ValueType,
typename Function>
528 void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* in,
531 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
532 experimental::distributed::precision_dispatch_real_complex<ValueType>(
535 gko::precision_dispatch_real_complex<ValueType>(fn, in, out);
544 template <
typename ValueType,
typename Function>
545 void precision_dispatch_real_complex_distributed(Function fn,
547 const LinOp* in, LinOp* out)
549 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
550 experimental::distributed::precision_dispatch_real_complex<ValueType>(
553 gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, out);
562 template <
typename ValueType,
typename Function>
563 void precision_dispatch_real_complex_distributed(Function fn,
566 const LinOp* beta, LinOp* out)
568 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
569 experimental::distributed::precision_dispatch_real_complex<ValueType>(
570 fn, alpha, in, beta, out);
573 gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta,
592 template <
typename ValueType,
typename Function,
typename... Args>
593 void precision_dispatch_real_complex_distributed(Function fn, Args*... args)
595 precision_dispatch_real_complex<ValueType>(fn, args...);
606 #endif // GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_