5 #ifndef GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
6 #define GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
9 #include <ginkgo/config.hpp>
10 #include <ginkgo/core/base/math.hpp>
11 #include <ginkgo/core/base/temporary_conversion.hpp>
12 #include <ginkgo/core/distributed/vector.hpp>
13 #include <ginkgo/core/matrix/dense.hpp>
/**
 * Converts `matrix` into a temporary matrix::Dense<ValueType>.
 *
 * The returned temporary_conversion wraps `const matrix::Dense<ValueType>`
 * when the pointee of `Ptr` is const, and `matrix::Dense<ValueType>`
 * otherwise, so the constness of the argument is preserved.
 *
 * NOTE(review): this excerpt omits the function name/parameter line, the
 * `Dense`, `NextDense`, `Next2Dense` and `Next3Dense` aliases, and the
 * condition guarding GKO_NOT_SUPPORTED. Presumably the Next*Dense types are
 * Dense types of other precisions that are accepted as conversion sources;
 * confirm against the full header.
 */
43 template <
typename ValueType,
typename Ptr>
44 detail::temporary_conversion<std::conditional_t<
45 std::is_const<detail::pointee<Ptr>>::value,
const matrix::Dense<ValueType>,
46 matrix::Dense<ValueType>>>
49 using Pointee = detail::pointee<Ptr>;
// Mirror the pointee's constness onto the target Dense type.
54 using MaybeConstDense =
55 std::conditional_t<std::is_const<Pointee>::value,
const Dense, Dense>;
// Try the conversion, accepting NextDense/Next2Dense/Next3Dense as
// alternative source types.
57 detail::temporary_conversion<MaybeConstDense>::template create<
58 NextDense, Next2Dense, Next3Dense>(matrix);
// Failure path: the guarding condition is not visible in this excerpt.
60 GKO_NOT_SUPPORTED(matrix);
/**
 * Calls `fn` with every argument in `linops` converted to a temporary
 * matrix::Dense<ValueType> via make_temporary_conversion<ValueType>.
 *
 * The temporaries are created inside the call expression. They are
 * therefore destroyed only after `fn` returns, so the raw pointers obtained
 * with .get() stay valid for the whole call.
 */
80 template <
typename ValueType,
typename Function,
typename... Args>
83 fn(make_temporary_conversion<ValueType>(linops).get()...);
/**
 * Dispatches `fn(in, out)` on Dense operands. Complex-valued operands can be
 * handled by an `fn` written for ValueType.
 *
 * When `complex_to_real` holds, `in` and `out` are converted to
 * Dense<to_complex<ValueType>>. `fn` is then called on their real views,
 * dynamic_cast to const Dense* and Dense* respectively. Otherwise the call
 * reduces to precision_dispatch<ValueType>(fn, in, out).
 *
 * NOTE(review): the second operand of the `complex_to_real` condition and
 * the `Dense` alias are elided here. Presumably `Dense` is
 * matrix::Dense<ValueType> and the flag is set for a real ValueType with
 * complex arguments; confirm.
 */
96 template <
typename ValueType,
typename Function>
103 auto complex_to_real =
104 !(is_complex<ValueType>() ||
106 if (complex_to_real) {
107 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
108 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
// The real views are temporaries that live until the end of this call
// expression. A failed dynamic_cast would hand a null pointer to fn.
113 fn(dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
114 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
// Non-complex-to-real path: plain same-precision dispatch.
116 precision_dispatch<ValueType>(fn, in, out);
/**
 * Variant of precision_dispatch_real_complex taking a scalar `alpha`:
 * dispatches `fn(alpha, in, out)`.
 *
 * On the complex-to-real path, `in` and `out` are converted to
 * Dense<to_complex<ValueType>> and passed as real views. `alpha` is
 * converted to Dense<ValueType>, not the complex type. Otherwise the call
 * reduces to precision_dispatch<ValueType>(fn, alpha, in, out).
 *
 * NOTE(review): the rest of the `complex_to_real` condition and the
 * `Dense` alias are not visible in this excerpt.
 */
130 template <
typename ValueType,
typename Function>
138 auto complex_to_real =
139 !(is_complex<ValueType>() ||
141 if (complex_to_real) {
142 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
143 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
// alpha keeps the non-complex ValueType.
144 auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
149 fn(dense_alpha.get(),
150 dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
151 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
153 precision_dispatch<ValueType>(fn, alpha, in, out);
/**
 * Variant of precision_dispatch_real_complex taking scalars `alpha` and
 * `beta`: dispatches `fn(alpha, in, beta, out)`.
 *
 * On the complex-to-real path, `in` and `out` are converted to
 * Dense<to_complex<ValueType>> and passed as real views. `alpha` and `beta`
 * are converted to Dense<ValueType>. Otherwise the call reduces to
 * precision_dispatch<ValueType>(fn, alpha, in, beta, out).
 *
 * NOTE(review): the argument line that presumably passes dense_beta to fn
 * is elided in this excerpt, as is the rest of the `complex_to_real`
 * condition.
 */
167 template <
typename ValueType,
typename Function>
176 auto complex_to_real =
177 !(is_complex<ValueType>() ||
179 if (complex_to_real) {
180 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
181 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
// Scalars keep the non-complex ValueType.
182 auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
183 auto dense_beta = make_temporary_conversion<ValueType>(beta);
188 fn(dense_alpha.get(),
189 dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
191 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
193 precision_dispatch<ValueType>(fn, alpha, in, beta, out);
/**
 * Calls `fn(in_typed, out_typed)` with `in` and `out` downcast to their
 * concrete types when GINKGO_MIXED_PRECISION is defined. No precision
 * conversion is done on this path; `fn` receives the dynamic_cast results
 * directly.
 *
 * `in` is tried against const fst_type, snd_type, trd_type and fth_type, in
 * that order, and `out` against the non-const versions. The first
 * successful cast wins. If neither operand matches, GKO_NOT_SUPPORTED is
 * raised for that operand.
 *
 * NOTE(review): the fst_type..fth_type aliases and the #else/#endif lines
 * are elided. Presumably the aliases are Dense types of ValueType and its
 * other precisions, and the trailing precision_dispatch call is the
 * non-mixed-precision fallback; confirm.
 */
227 template <
typename ValueType,
typename Function>
230 #ifdef GINKGO_MIXED_PRECISION
// Resolves the concrete type of `out` once `in` has been resolved.
235 auto dispatch_out_vector = [&](
auto dense_in) {
236 if (
auto dense_out = dynamic_cast<fst_type*>(out)) {
237 fn(dense_in, dense_out);
238 }
else if (
auto dense_out = dynamic_cast<snd_type*>(out)) {
239 fn(dense_in, dense_out);
240 }
else if (
auto dense_out = dynamic_cast<trd_type*>(out)) {
241 fn(dense_in, dense_out);
242 }
else if (
auto dense_out = dynamic_cast<fth_type*>(out)) {
243 fn(dense_in, dense_out);
245 GKO_NOT_SUPPORTED(out);
// Resolve the concrete type of `in`, then dispatch on `out`.
248 if (
auto dense_in = dynamic_cast<const fst_type*>(in)) {
249 dispatch_out_vector(dense_in);
250 }
else if (
auto dense_in = dynamic_cast<const snd_type*>(in)) {
251 dispatch_out_vector(dense_in);
252 }
else if (
auto dense_in = dynamic_cast<const trd_type*>(in)) {
253 dispatch_out_vector(dense_in);
254 }
else if (
auto dense_in = dynamic_cast<const fth_type*>(in)) {
255 dispatch_out_vector(dense_in);
257 GKO_NOT_SUPPORTED(in);
// Presumably the non-mixed-precision branch (#else elided): convert both
// operands to ValueType.
260 precision_dispatch<ValueType>(fn, in, out);
/**
 * Overload for complex ValueType, selected by the
 * std::enable_if_t<is_complex<ValueType>()> template parameter.
 *
 * Uses mixed_precision_dispatch<ValueType> when GINKGO_MIXED_PRECISION is
 * defined, and precision_dispatch<ValueType> otherwise. The #else/#endif
 * lines are not visible in this excerpt.
 */
274 template <
typename ValueType,
typename Function,
275 std::enable_if_t<is_complex<ValueType>()>* =
nullptr>
279 #ifdef GINKGO_MIXED_PRECISION
280 mixed_precision_dispatch<ValueType>(fn, in, out);
282 precision_dispatch<ValueType>(fn, in, out);
/**
 * Overload for real (non-complex) ValueType, selected by
 * std::enable_if_t<!is_complex<ValueType>()>.
 *
 * With GINKGO_MIXED_PRECISION:
 * - if `in` is not a ConvertibleTo<matrix::Dense<>>, dispatch over
 *   to_complex<ValueType> and call `fn` on the real views of both operands;
 * - otherwise use mixed_precision_dispatch<ValueType> directly.
 * Without it, fall back to precision_dispatch_real_complex<ValueType>.
 *
 * NOTE(review): matrix::Dense<> uses the default value type. Presumably the
 * negated check detects complex-valued inputs; confirm. Unlike
 * precision_dispatch_real_complex, the real views here are passed without
 * a dynamic_cast to a specific Dense type. The else/#else lines are elided.
 */
287 template <
typename ValueType,
typename Function,
288 std::enable_if_t<!is_complex<ValueType>()>* =
nullptr>
292 #ifdef GINKGO_MIXED_PRECISION
293 if (!
dynamic_cast<const ConvertibleTo<matrix::Dense<>
>*>(in)) {
294 mixed_precision_dispatch<to_complex<ValueType>>(
295 [&fn](
auto dense_in,
auto dense_out) {
296 fn(dense_in->create_real_view().get(),
297 dense_out->create_real_view().get());
// Operands are already real-valued: dispatch on ValueType directly.
301 mixed_precision_dispatch<ValueType>(fn, in, out);
// Non-mixed-precision build: convert both operands before calling fn.
304 precision_dispatch_real_complex<ValueType>(fn, in, out);
309 namespace experimental {
315 namespace distributed {
/**
 * Distributed counterpart of make_temporary_conversion. Converts `matrix`
 * into a temporary Vector<ValueType> via
 * gko::detail::temporary_conversion<Vector<ValueType>>::create.
 *
 * NOTE(review): the signature, the alternative source types passed to
 * create, and the condition guarding GKO_NOT_SUPPORTED are elided in this
 * excerpt.
 */
343 template <
typename ValueType>
348 gko::detail::temporary_conversion<Vector<ValueType>>::template create<
353 GKO_NOT_SUPPORTED(matrix);
/**
 * Const overload of the distributed make_temporary_conversion. Returns a
 * gko::detail::temporary_conversion<const Vector<ValueType>>, accepting
 * Vector<next_precision<ValueType>> (and further types elided here) as
 * conversion sources.
 *
 * NOTE(review): the parameter list, the remaining create<> arguments, and
 * the check on `result` that leads to GKO_NOT_SUPPORTED are not visible.
 */
362 template <
typename ValueType>
363 gko::detail::temporary_conversion<const Vector<ValueType>>
366 auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
367 template create<Vector<next_precision<ValueType>>,
371 GKO_NOT_SUPPORTED(matrix);
/**
 * Distributed counterpart of precision_dispatch. Calls `fn` with every
 * argument in `linops` converted to a temporary Vector<ValueType> via
 * distributed::make_temporary_conversion<ValueType>.
 *
 * The temporaries live until the end of the call expression, so the raw
 * pointers passed to `fn` remain valid during the call.
 */
391 template <
typename ValueType,
typename Function,
typename... Args>
394 fn(distributed::make_temporary_conversion<ValueType>(linops).get()...);
/**
 * Distributed counterpart of mixed_precision_dispatch. With
 * GINKGO_MIXED_PRECISION, calls `fn(in_typed, out_typed)` with `in` and
 * `out` downcast to one of Vector<ValueType>,
 * Vector<next_precision<ValueType, 2>> or
 * Vector<next_precision<ValueType, 3>>. No precision conversion happens.
 *
 * Types are tried in that order and the first successful cast wins. An
 * unmatched operand raises GKO_NOT_SUPPORTED.
 *
 * NOTE(review): the #else/#endif lines are elided. Presumably the trailing
 * distributed::precision_dispatch call is the non-mixed-precision
 * fallback; confirm.
 */
398 template <
typename ValueType,
typename Function>
401 #ifdef GINKGO_MIXED_PRECISION
402 using fst_type = Vector<ValueType>;
403 using snd_type = Vector<next_precision<ValueType, 2>>;
404 using trd_type = Vector<next_precision<ValueType, 3>>;
// Resolves the concrete type of `out` once `in` has been resolved.
405 auto dispatch_out_vector = [&](
auto vector_in) {
406 if (
auto vector_out = dynamic_cast<fst_type*>(out)) {
407 fn(vector_in, vector_out);
408 }
else if (
auto vector_out = dynamic_cast<snd_type*>(out)) {
409 fn(vector_in, vector_out);
410 }
else if (
auto vector_out = dynamic_cast<trd_type*>(out)) {
411 fn(vector_in, vector_out);
413 GKO_NOT_SUPPORTED(out);
// Resolve the concrete type of `in`, then dispatch on `out`.
416 if (
auto vector_in = dynamic_cast<const fst_type*>(in)) {
417 dispatch_out_vector(vector_in);
418 }
else if (
auto vector_in = dynamic_cast<const snd_type*>(in)) {
419 dispatch_out_vector(vector_in);
420 }
else if (
auto vector_in = dynamic_cast<const trd_type*>(in)) {
421 dispatch_out_vector(vector_in);
423 GKO_NOT_SUPPORTED(in);
427 distributed::precision_dispatch<ValueType>(fn, in, out);
/**
 * Distributed counterpart of precision_dispatch_real_complex for
 * `fn(in, out)`.
 *
 * When `complex_to_real` holds, `in` and `out` are converted to
 * Vector<to_complex<ValueType>>. `fn` receives their real views,
 * dynamic_cast to const Vector* and Vector*. Otherwise the call reduces to
 * distributed::precision_dispatch<ValueType>(fn, in, out).
 *
 * NOTE(review): the rest of the `complex_to_real` condition, the
 * `dense_in`/`dense_out` declarations, and the `Vector` alias used in the
 * casts are elided in this excerpt.
 */
441 template <
typename ValueType,
typename Function>
444 auto complex_to_real = !(
445 is_complex<ValueType>() ||
448 if (complex_to_real) {
450 distributed::make_temporary_conversion<to_complex<ValueType>>(in);
452 distributed::make_temporary_conversion<to_complex<ValueType>>(out);
// The real views are temporaries that live until the end of this call.
457 fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
458 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
460 distributed::precision_dispatch<ValueType>(fn, in, out);
/**
 * Distributed counterpart of mixed_precision_dispatch_real_complex.
 *
 * When `complex_to_real` holds, dispatches over to_complex<ValueType> and
 * calls `fn` on the real views of both Vector operands. Otherwise it uses
 * distributed::mixed_precision_dispatch<ValueType> directly.
 *
 * NOTE(review): the rest of the `complex_to_real` condition and the
 * closing/else lines are elided in this excerpt.
 */
465 template <
typename ValueType,
typename Function>
469 auto complex_to_real = !(
470 is_complex<ValueType>() ||
473 if (complex_to_real) {
474 distributed::mixed_precision_dispatch<to_complex<ValueType>>(
475 [&fn](
auto vector_in,
auto vector_out) {
476 fn(vector_in->create_real_view().get(),
477 vector_out->create_real_view().get());
481 distributed::mixed_precision_dispatch<ValueType>(fn, in, out);
/**
 * Distributed precision_dispatch_real_complex for `fn(alpha, in, out)`.
 *
 * `alpha` is converted with gko::make_temporary_conversion (a Dense
 * scalar). `in` and `out` are converted with
 * distributed::make_temporary_conversion. On the complex-to-real path,
 * `in`/`out` use to_complex<ValueType> and are passed as real views;
 * otherwise all three are converted to ValueType inside the call
 * expression.
 *
 * NOTE(review): the rest of the `complex_to_real` condition and the
 * `dense_in`/`dense_out` declarations are elided in this excerpt.
 */
489 template <
typename ValueType,
typename Function>
493 auto complex_to_real = !(
494 is_complex<ValueType>() ||
497 if (complex_to_real) {
499 distributed::make_temporary_conversion<to_complex<ValueType>>(in);
501 distributed::make_temporary_conversion<to_complex<ValueType>>(out);
// The scalar is a (non-distributed) Dense object of the real ValueType.
502 auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
507 fn(dense_alpha.get(),
508 dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
509 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
// Same-precision path: the temporaries live until fn returns.
511 fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
512 distributed::make_temporary_conversion<ValueType>(in).get(),
513 distributed::make_temporary_conversion<ValueType>(out).get());
/**
 * Distributed precision_dispatch_real_complex for
 * `fn(alpha, in, beta, out)`.
 *
 * The scalars `alpha` and `beta` are converted with
 * gko::make_temporary_conversion (Dense). The vectors `in` and `out` are
 * converted with distributed::make_temporary_conversion. On the
 * complex-to-real path, the vectors use to_complex<ValueType> and are
 * passed as real views.
 *
 * NOTE(review): the line that presumably passes dense_beta on the
 * complex-to-real path is elided, as are the rest of the `complex_to_real`
 * condition and the `dense_in`/`dense_out` declarations.
 */
521 template <
typename ValueType,
typename Function>
526 auto complex_to_real = !(
527 is_complex<ValueType>() ||
530 if (complex_to_real) {
532 distributed::make_temporary_conversion<to_complex<ValueType>>(in);
534 distributed::make_temporary_conversion<to_complex<ValueType>>(out);
// Scalars are (non-distributed) Dense objects of the real ValueType.
535 auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
536 auto dense_beta = gko::make_temporary_conversion<ValueType>(beta);
541 fn(dense_alpha.get(),
542 dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
544 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
// Same-precision path: the temporaries live until fn returns.
546 fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
547 distributed::make_temporary_conversion<ValueType>(in).get(),
548 gko::make_temporary_conversion<ValueType>(beta).get(),
549 distributed::make_temporary_conversion<ValueType>(out).get());
/**
 * Chooses between the distributed and the non-distributed
 * precision_dispatch_real_complex for `fn(in, out)`.
 *
 * If `in` is an experimental::distributed::DistributedBase, the
 * experimental::distributed overload is used; otherwise
 * gko::precision_dispatch_real_complex is called.
 *
 * NOTE(review): only `in` is inspected. The arguments forwarded on the
 * distributed path are elided in this excerpt.
 */
570 template <
typename ValueType,
typename Function>
571 void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* in,
574 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
575 experimental::distributed::precision_dispatch_real_complex<ValueType>(
578 gko::precision_dispatch_real_complex<ValueType>(fn, in, out);
/**
 * Chooses between the distributed and the non-distributed
 * precision_dispatch_real_complex for `fn(alpha, in, out)`, based on
 * whether `in` is an experimental::distributed::DistributedBase.
 *
 * NOTE(review): the `alpha` parameter line and the arguments forwarded on
 * the distributed path are elided in this excerpt.
 */
587 template <
typename ValueType,
typename Function>
588 void precision_dispatch_real_complex_distributed(Function fn,
590 const LinOp* in, LinOp* out)
592 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
593 experimental::distributed::precision_dispatch_real_complex<ValueType>(
596 gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, out);
/**
 * Chooses between the distributed and the non-distributed
 * precision_dispatch_real_complex for `fn(alpha, in, beta, out)`, based on
 * whether `in` is an experimental::distributed::DistributedBase.
 *
 * NOTE(review): the `alpha`/`in` parameter lines and the tail of the
 * non-distributed call are elided in this excerpt.
 */
605 template <
typename ValueType,
typename Function>
606 void precision_dispatch_real_complex_distributed(Function fn,
609 const LinOp* beta, LinOp* out)
611 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
612 experimental::distributed::precision_dispatch_real_complex<ValueType>(
613 fn, alpha, in, beta, out);
616 gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta,
/**
 * Variadic form that forwards all arguments unchanged to
 * precision_dispatch_real_complex<ValueType>. There is no distributed
 * check here.
 *
 * NOTE(review): presumably this is the variant used when distributed
 * support is compiled out (the surrounding preprocessor lines are elided);
 * confirm against the full header.
 */
635 template <
typename ValueType,
typename Function,
typename... Args>
636 void precision_dispatch_real_complex_distributed(Function fn, Args*... args)
638 precision_dispatch_real_complex<ValueType>(fn, args...);
649 #endif // GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_