5 #ifndef GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
6 #define GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
9 #include <ginkgo/config.hpp>
10 #include <ginkgo/core/base/math.hpp>
11 #include <ginkgo/core/base/temporary_conversion.hpp>
12 #include <ginkgo/core/distributed/vector.hpp>
13 #include <ginkgo/core/matrix/dense.hpp>
/**
 * Wraps `matrix` in a temporary_conversion to matrix::Dense<ValueType>.
 *
 * The target type mirrors the constness of the pointee of `Ptr`:
 * - a const pointee yields a conversion to `const matrix::Dense<ValueType>`;
 * - any other pointee yields a conversion to `matrix::Dense<ValueType>`.
 *
 * @tparam ValueType  value type of the requested Dense matrix
 * @tparam Ptr  pointer-like type of the input; its pointee type is obtained
 *              via detail::pointee
 *
 * @param matrix  the object to convert
 *
 * @return the object produced by
 *         temporary_conversion<MaybeConstDense>::create<NextDense,
 *         NextNextDense>(matrix)
 *
 * NOTE(review): the Dense / NextDense / NextNextDense aliases and the
 * condition guarding GKO_NOT_SUPPORTED are not visible in this view.
 * Presumably NextDense/NextNextDense are Dense in the next two precisions
 * and the macro fires when no conversion applied. Confirm against the full
 * definition.
 */
43 template <
typename ValueType,
typename Ptr>
44 detail::temporary_conversion<std::conditional_t<
45 std::is_const<detail::pointee<Ptr>>::value,
const matrix::Dense<ValueType>,
46 matrix::Dense<ValueType>>>
49 using Pointee = detail::pointee<Ptr>;
// Keep the constness of the input in the requested Dense type.
54 using MaybeConstDense =
55 std::conditional_t<std::is_const<Pointee>::value,
const Dense, Dense>;
// NextDense / NextNextDense are passed as additional candidate types to
// create(); presumably these are the alternative source types it may
// convert from.
56 auto result = detail::temporary_conversion<
57 MaybeConstDense>::template create<NextDense, NextNextDense>(matrix);
59 GKO_NOT_SUPPORTED(matrix);
/**
 * Invokes `fn` with every argument in `linops` converted to
 * matrix::Dense<ValueType>.
 *
 * Each argument is wrapped by make_temporary_conversion<ValueType>, and `fn`
 * receives the raw pointer obtained from `.get()`. The wrappers are
 * temporaries of the call expression, so they stay alive until `fn`
 * returns.
 *
 * @tparam ValueType  value type every argument is converted to
 * @tparam Function  callable taking one Dense pointer per argument, in order
 * @tparam Args  types of the forwarded arguments
 */
79 template <
typename ValueType,
typename Function,
typename... Args>
82 fn(make_temporary_conversion<ValueType>(linops).get()...);
/**
 * Variant of precision_dispatch for an (in, out) pair that also allows a
 * real-valued operation to be applied to complex vectors.
 *
 * `complex_to_real` can only be true when ValueType is not complex, and also
 * requires that the second operand of the `||` is false.
 * - When it is true, `in` and `out` are converted to
 *   Dense<to_complex<ValueType>>. `fn` is then called on the views returned
 *   by create_real_view(), dynamic_cast to the real Dense type.
 * - Otherwise the call is forwarded to precision_dispatch<ValueType>.
 *
 * @tparam ValueType  real or complex value type the operation works in
 * @tparam Function  callable taking (const Dense*, Dense*)
 *
 * NOTE(review): the second operand of the `||` (original line 104) and the
 * `Dense` alias are not visible here. Presumably Dense is
 * matrix::Dense<ValueType>; confirm.
 */
95 template <
typename ValueType,
typename Function>
102 auto complex_to_real =
103 !(is_complex<ValueType>() ||
105 if (complex_to_real) {
106 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
107 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
// The views created below are temporaries of this call expression, so they
// remain valid for the whole execution of fn.
112 fn(dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
113 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
115 precision_dispatch<ValueType>(fn, in, out);
/**
 * Variant of precision_dispatch for (alpha, in, out) that also allows a
 * real-valued operation to be applied to complex vectors.
 *
 * When `complex_to_real` holds (ValueType is not complex and the rest of
 * the condition is false):
 * - only `in` and `out` are converted to Dense<to_complex<ValueType>>;
 * - `alpha` keeps ValueType;
 * - `fn` receives alpha's pointer plus the create_real_view() views of
 *   in/out, cast to the real Dense type.
 *
 * Otherwise the call is forwarded to precision_dispatch<ValueType>.
 *
 * NOTE(review): the rest of the `||` condition and the `Dense` alias are not
 * visible here. Presumably Dense is matrix::Dense<ValueType>; confirm.
 */
129 template <
typename ValueType,
typename Function>
137 auto complex_to_real =
138 !(is_complex<ValueType>() ||
140 if (complex_to_real) {
141 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
142 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
// The scalar is not complexified.
143 auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
148 fn(dense_alpha.get(),
149 dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
150 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
152 precision_dispatch<ValueType>(fn, alpha, in, out);
/**
 * Variant of precision_dispatch for (alpha, in, beta, out) that also allows
 * a real-valued operation to be applied to complex vectors.
 *
 * When `complex_to_real` holds:
 * - `in` and `out` are converted to Dense<to_complex<ValueType>>;
 * - `alpha` and `beta` are converted to Dense<ValueType>;
 * - `fn` receives the scalars plus the create_real_view() views of in/out.
 *
 * Otherwise the call is forwarded to precision_dispatch<ValueType>.
 *
 * NOTE(review): the argument passed between the `in` and `out` views
 * (original line 189) is not visible in this view. Presumably it is
 * `dense_beta.get()`; confirm. The rest of the `||` condition and the
 * `Dense` alias are also not visible.
 */
166 template <
typename ValueType,
typename Function>
175 auto complex_to_real =
176 !(is_complex<ValueType>() ||
178 if (complex_to_real) {
179 auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
180 auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
// The scalars are not complexified.
181 auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
182 auto dense_beta = make_temporary_conversion<ValueType>(beta);
187 fn(dense_alpha.get(),
188 dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
190 dynamic_cast<Dense*>(dense_out->create_real_view().get()));
192 precision_dispatch<ValueType>(fn, alpha, in, beta, out);
/**
 * Invokes `fn` with `in` and `out` cast to their concrete Dense types,
 * allowing the two to have different precisions.
 *
 * When GINKGO_MIXED_PRECISION is defined:
 * - `in` is matched via dynamic_cast against fst_type, snd_type and
 *   trd_type, in that order;
 * - for the first match, `out` is matched against the same three types;
 * - `fn` is then called with both typed pointers. No conversion or copy
 *   happens on this path.
 * - An `in` or `out` matching none of the types triggers GKO_NOT_SUPPORTED.
 *
 * Because `fn` is called with every combination of the three pointer types,
 * it must accept all of them, e.g. a generic lambda.
 *
 * The final precision_dispatch<ValueType>(fn, in, out) call is presumably
 * the fallback used without GINKGO_MIXED_PRECISION. The #else/#endif lines
 * are not visible here.
 *
 * NOTE(review): the fst_type/snd_type/trd_type aliases are not visible.
 * Presumably they are Dense in ValueType and its next precisions; confirm.
 */
226 template <
typename ValueType,
typename Function>
229 #ifdef GINKGO_MIXED_PRECISION
// Resolves the dynamic type of `out` for an already-typed `dense_in`.
233 auto dispatch_out_vector = [&](
auto dense_in) {
234 if (
auto dense_out = dynamic_cast<fst_type*>(out)) {
235 fn(dense_in, dense_out);
236 }
else if (
auto dense_out = dynamic_cast<snd_type*>(out)) {
237 fn(dense_in, dense_out);
238 }
else if (
auto dense_out = dynamic_cast<trd_type*>(out)) {
239 fn(dense_in, dense_out);
241 GKO_NOT_SUPPORTED(out);
// Resolve the dynamic type of `in`, then let dispatch_out_vector handle
// `out`.
244 if (
auto dense_in = dynamic_cast<const fst_type*>(in)) {
245 dispatch_out_vector(dense_in);
246 }
else if (
auto dense_in = dynamic_cast<const snd_type*>(in)) {
247 dispatch_out_vector(dense_in);
248 }
else if (
auto dense_in = dynamic_cast<const trd_type*>(in)) {
249 dispatch_out_vector(dense_in);
251 GKO_NOT_SUPPORTED(in);
// Presumably the branch taken without GINKGO_MIXED_PRECISION.
254 precision_dispatch<ValueType>(fn, in, out);
/**
 * Mixed-precision dispatch for complex ValueType (selected via enable_if on
 * is_complex<ValueType>()).
 *
 * With GINKGO_MIXED_PRECISION defined, this forwards to
 * mixed_precision_dispatch<ValueType>. The precision_dispatch<ValueType>
 * call is presumably the #else branch; that line is not visible here.
 */
268 template <
typename ValueType,
typename Function,
269 std::enable_if_t<is_complex<ValueType>()>* =
nullptr>
273 #ifdef GINKGO_MIXED_PRECISION
274 mixed_precision_dispatch<ValueType>(fn, in, out);
276 precision_dispatch<ValueType>(fn, in, out);
/**
 * Mixed-precision dispatch for real ValueType (selected via enable_if on
 * !is_complex<ValueType>()) that also accepts complex vectors.
 *
 * With GINKGO_MIXED_PRECISION defined:
 * - If `in` is not ConvertibleTo matrix::Dense<> (Dense with its default
 *   value type), in/out are dispatched by
 *   mixed_precision_dispatch<to_complex<ValueType>>. `fn` then receives the
 *   create_real_view() views of the typed pointers.
 * - Otherwise in/out are presumably dispatched directly via
 *   mixed_precision_dispatch<ValueType> (the `else` line is not visible).
 *
 * The precision_dispatch_real_complex<ValueType> call is presumably the
 * #else branch.
 *
 * NOTE(review): the ConvertibleTo<matrix::Dense<>> test appears to serve as
 * a "real input" check. This relies on complex Dense types not implementing
 * that interface; confirm.
 */
281 template <
typename ValueType,
typename Function,
282 std::enable_if_t<!is_complex<ValueType>()>* =
nullptr>
286 #ifdef GINKGO_MIXED_PRECISION
287 if (!
dynamic_cast<const ConvertibleTo<matrix::Dense<>
>*>(in)) {
// Wrap fn so that it sees the views produced by create_real_view() of the
// complex-typed operands. fn is captured by reference.
288 mixed_precision_dispatch<to_complex<ValueType>>(
289 [&fn](
auto dense_in,
auto dense_out) {
290 fn(dense_in->create_real_view().get(),
291 dense_out->create_real_view().get());
295 mixed_precision_dispatch<ValueType>(fn, in, out);
298 precision_dispatch_real_complex<ValueType>(fn, in, out);
303 namespace experimental {
309 namespace distributed {
/**
 * Distributed counterpart of make_temporary_conversion for non-const input.
 * Wraps `matrix` in a temporary_conversion to Vector<ValueType>.
 *
 * It calls gko::detail::temporary_conversion<Vector<ValueType>>::create. The
 * list of alternative types passed to create() is not visible in this view.
 * GKO_NOT_SUPPORTED(matrix) is raised, presumably when no conversion
 * applies; the guarding condition is not visible.
 */
337 template <
typename ValueType>
342 gko::detail::temporary_conversion<Vector<ValueType>>::template create<
346 GKO_NOT_SUPPORTED(matrix);
/**
 * Distributed counterpart of make_temporary_conversion for const input.
 *
 * @return a gko::detail::temporary_conversion<const Vector<ValueType>>,
 *         created with Vector<next_precision<ValueType>> as the first listed
 *         alternative type. The remaining alternatives are not visible in
 *         this view.
 *
 * GKO_NOT_SUPPORTED(matrix) is raised, presumably when no conversion
 * applies; the guarding condition is not visible.
 */
355 template <
typename ValueType>
356 gko::detail::temporary_conversion<const Vector<ValueType>>
359 auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
360 template create<Vector<next_precision<ValueType>>,
364 GKO_NOT_SUPPORTED(matrix);
/**
 * Distributed counterpart of precision_dispatch. Invokes `fn` with every
 * argument in `linops` converted by
 * distributed::make_temporary_conversion<ValueType>.
 *
 * The explicit `distributed::` qualification selects the Vector-converting
 * overload rather than the Dense one. The temporary wrappers live until
 * `fn` returns.
 *
 * @tparam ValueType  value type every argument is converted to
 * @tparam Function  callable taking one Vector pointer per argument
 * @tparam Args  types of the forwarded arguments
 */
384 template <
typename ValueType,
typename Function,
typename... Args>
387 fn(distributed::make_temporary_conversion<ValueType>(linops).get()...);
/**
 * Distributed variant of precision_dispatch_real_complex for (in, out).
 *
 * - When `complex_to_real` holds (ValueType is not complex and the rest of
 *   the condition is false), `in` and `out` are converted to distributed
 *   vectors of to_complex<ValueType>. `fn` then receives their
 *   create_real_view() views, dynamic_cast to the real Vector type.
 * - Otherwise the call is forwarded to distributed::precision_dispatch.
 *
 * NOTE(review): the rest of the `||` condition, the `auto dense_in =` /
 * `auto dense_out =` declarations and the `Vector` alias are not visible in
 * this view.
 */
400 template <
typename ValueType,
typename Function>
403 auto complex_to_real = !(
404 is_complex<ValueType>() ||
407 if (complex_to_real) {
409 distributed::make_temporary_conversion<to_complex<ValueType>>(in);
411 distributed::make_temporary_conversion<to_complex<ValueType>>(out);
// The views are temporaries of the call expression and outlive fn.
416 fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
417 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
419 distributed::precision_dispatch<ValueType>(fn, in, out);
/**
 * Distributed variant of precision_dispatch_real_complex for
 * (alpha, in, out).
 *
 * `in`/`out` are distributed vectors converted with
 * distributed::make_temporary_conversion. The scalar `alpha` is converted
 * with gko::make_temporary_conversion, i.e. to a non-distributed Dense.
 * - In the complex-to-real case, in/out are complexified, and `fn` receives
 *   alpha plus the create_real_view() views of in/out.
 * - Otherwise `fn` is called directly with all three converted to ValueType.
 *
 * NOTE(review): the rest of the `||` condition, the in/out declarations and
 * the `Vector` alias are not visible in this view.
 */
427 template <
typename ValueType,
typename Function>
431 auto complex_to_real = !(
432 is_complex<ValueType>() ||
435 if (complex_to_real) {
437 distributed::make_temporary_conversion<to_complex<ValueType>>(in);
439 distributed::make_temporary_conversion<to_complex<ValueType>>(out);
// The scalar is a local Dense and is not complexified.
440 auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
445 fn(dense_alpha.get(),
446 dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
447 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
// Mixed dispatch: the scalar goes through the Dense conversion and the
// vectors through the distributed one.
449 fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
450 distributed::make_temporary_conversion<ValueType>(in).get(),
451 distributed::make_temporary_conversion<ValueType>(out).get());
/**
 * Distributed variant of precision_dispatch_real_complex for
 * (alpha, in, beta, out).
 *
 * The scalars `alpha`/`beta` are converted with gko::make_temporary_conversion
 * (non-distributed Dense). The vectors `in`/`out` are converted with
 * distributed::make_temporary_conversion.
 * - In the complex-to-real case, only in/out are complexified, and `fn`
 *   receives the create_real_view() views of in/out.
 * - Otherwise `fn` is called directly with all four converted to ValueType.
 *
 * NOTE(review): the argument between the `in` and `out` views (original
 * line 481) is not visible. Presumably it is `dense_beta.get()`; confirm.
 */
459 template <
typename ValueType,
typename Function>
464 auto complex_to_real = !(
465 is_complex<ValueType>() ||
468 if (complex_to_real) {
470 distributed::make_temporary_conversion<to_complex<ValueType>>(in);
472 distributed::make_temporary_conversion<to_complex<ValueType>>(out);
// The scalars are local Dense objects and are not complexified.
473 auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
474 auto dense_beta = gko::make_temporary_conversion<ValueType>(beta);
479 fn(dense_alpha.get(),
480 dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
482 dynamic_cast<Vector*>(dense_out->create_real_view().get()));
// Scalars go through the Dense conversion, vectors through the distributed
// one.
484 fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
485 distributed::make_temporary_conversion<ValueType>(in).get(),
486 gko::make_temporary_conversion<ValueType>(beta).get(),
487 distributed::make_temporary_conversion<ValueType>(out).get());
/**
 * Chooses between the distributed and non-distributed
 * precision_dispatch_real_complex for (in, out).
 *
 * If `in` is an experimental::distributed::DistributedBase, the
 * experimental::distributed overload is used; otherwise
 * gko::precision_dispatch_real_complex. Only `in` is inspected, so `out` is
 * assumed to be of the same kind.
 */
508 template <
typename ValueType,
typename Function>
509 void precision_dispatch_real_complex_distributed(Function fn,
const LinOp* in,
512 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
513 experimental::distributed::precision_dispatch_real_complex<ValueType>(
516 gko::precision_dispatch_real_complex<ValueType>(fn, in, out);
/**
 * Chooses between the distributed and non-distributed
 * precision_dispatch_real_complex for (alpha, in, out).
 *
 * The choice depends only on whether `in` is an
 * experimental::distributed::DistributedBase.
 */
525 template <
typename ValueType,
typename Function>
526 void precision_dispatch_real_complex_distributed(Function fn,
528 const LinOp* in, LinOp* out)
530 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
531 experimental::distributed::precision_dispatch_real_complex<ValueType>(
534 gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, out);
/**
 * Chooses between the distributed and non-distributed
 * precision_dispatch_real_complex for (alpha, in, beta, out).
 *
 * The choice depends only on whether `in` is an
 * experimental::distributed::DistributedBase. All arguments are forwarded
 * unchanged in the same order.
 */
543 template <
typename ValueType,
typename Function>
544 void precision_dispatch_real_complex_distributed(Function fn,
547 const LinOp* beta, LinOp* out)
549 if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
550 experimental::distributed::precision_dispatch_real_complex<ValueType>(
551 fn, alpha, in, beta, out);
554 gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta,
/**
 * Variadic overload that forwards all pointer arguments unchanged to
 * precision_dispatch_real_complex<ValueType>.
 *
 * No DistributedBase check is performed here. NOTE(review): this is
 * presumably the variant compiled when distributed support is unavailable;
 * the surrounding preprocessor condition is not visible in this view.
 */
573 template <
typename ValueType,
typename Function,
typename... Args>
574 void precision_dispatch_real_complex_distributed(Function fn, Args*... args)
576 precision_dispatch_real_complex<ValueType>(fn, args...);
587 #endif // GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_