Ginkgo: generated from pipelines/1680925034 (branch based on develop). Ginkgo version 1.10.0.
A numerical linear algebra library targeting many-core architectures
precision_dispatch.hpp
1 // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
6 #define GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
7 
8 
9 #include <ginkgo/config.hpp>
10 #include <ginkgo/core/base/math.hpp>
11 #include <ginkgo/core/base/temporary_conversion.hpp>
12 #include <ginkgo/core/distributed/vector.hpp>
13 #include <ginkgo/core/matrix/dense.hpp>
14 
15 
16 namespace gko {
17 
18 
// make_temporary_conversion<ValueType>(matrix)
//
// Returns a detail::temporary_conversion whose target is
// matrix::Dense<ValueType>. The target is `const` exactly when the pointee of
// `Ptr` is const (see the std::conditional_t in the return type and the
// MaybeConstDense alias below).
//
// The conversion is built by temporary_conversion<MaybeConstDense>::create,
// which is given NextDense (matrix::Dense<next_precision<ValueType>>) and
// NextNextDense as additional candidate source types. If create() yields an
// empty result (no candidate matched), GKO_NOT_SUPPORTED(matrix) is invoked.
// NOTE(review): presumably that macro throws gko::NotSupported; confirm
// against its definition.
//
// NOTE(review): this listing lacks original line 47 (presumably
// `make_temporary_conversion(Ptr&& matrix)`) and line 53 (the type that
// NextNextDense aliases). Restore both from the real header before editing
// the code.
43 template <typename ValueType, typename Ptr>
44 detail::temporary_conversion<std::conditional_t<
45  std::is_const<detail::pointee<Ptr>>::value, const matrix::Dense<ValueType>,
46  matrix::Dense<ValueType>>>
48 {
49  using Pointee = detail::pointee<Ptr>;
50  using Dense = matrix::Dense<ValueType>;
51  using NextDense = matrix::Dense<next_precision<ValueType>>;
52  using NextNextDense =
// Keep the conversion target const whenever the caller's pointee is const.
54  using MaybeConstDense =
55  std::conditional_t<std::is_const<Pointee>::value, const Dense, Dense>;
56  auto result = detail::temporary_conversion<
57  MaybeConstDense>::template create<NextDense, NextNextDense>(matrix);
// An empty result means none of the candidate Dense types matched `matrix`.
58  if (!result) {
59  GKO_NOT_SUPPORTED(matrix);
60  }
61  return result;
62 }
63 
64 
79 template <typename ValueType, typename Function, typename... Args>
80 void precision_dispatch(Function fn, Args*... linops)
81 {
82  fn(make_temporary_conversion<ValueType>(linops).get()...);
83 }
84 
85 
95 template <typename ValueType, typename Function>
96 void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
97 {
98  // do we need to convert complex Dense to real Dense?
99  // all real dense vectors are intra-convertible, thus by casting to
100  // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
101  // dense matrix:
102  auto complex_to_real =
103  !(is_complex<ValueType>() ||
104  dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
105  if (complex_to_real) {
106  auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
107  auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
108  using Dense = matrix::Dense<ValueType>;
109  // These dynamic_casts are only needed to make the code compile
110  // If ValueType is complex, this branch will never be taken
111  // If ValueType is real, the cast is a no-op
112  fn(dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
113  dynamic_cast<Dense*>(dense_out->create_real_view().get()));
114  } else {
115  precision_dispatch<ValueType>(fn, in, out);
116  }
117 }
118 
119 
129 template <typename ValueType, typename Function>
130 void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
131  const LinOp* in, LinOp* out)
132 {
133  // do we need to convert complex Dense to real Dense?
134  // all real dense vectors are intra-convertible, thus by casting to
135  // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
136  // dense matrix:
137  auto complex_to_real =
138  !(is_complex<ValueType>() ||
139  dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
140  if (complex_to_real) {
141  auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
142  auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
143  auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
144  using Dense = matrix::Dense<ValueType>;
145  // These dynamic_casts are only needed to make the code compile
146  // If ValueType is complex, this branch will never be taken
147  // If ValueType is real, the cast is a no-op
148  fn(dense_alpha.get(),
149  dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
150  dynamic_cast<Dense*>(dense_out->create_real_view().get()));
151  } else {
152  precision_dispatch<ValueType>(fn, alpha, in, out);
153  }
154 }
155 
156 
166 template <typename ValueType, typename Function>
167 void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
168  const LinOp* in, const LinOp* beta,
169  LinOp* out)
170 {
171  // do we need to convert complex Dense to real Dense?
172  // all real dense vectors are intra-convertible, thus by casting to
173  // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
174  // dense matrix:
175  auto complex_to_real =
176  !(is_complex<ValueType>() ||
177  dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
178  if (complex_to_real) {
179  auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
180  auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
181  auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
182  auto dense_beta = make_temporary_conversion<ValueType>(beta);
183  using Dense = matrix::Dense<ValueType>;
184  // These dynamic_casts are only needed to make the code compile
185  // If ValueType is complex, this branch will never be taken
186  // If ValueType is real, the cast is a no-op
187  fn(dense_alpha.get(),
188  dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
189  dense_beta.get(),
190  dynamic_cast<Dense*>(dense_out->create_real_view().get()));
191  } else {
192  precision_dispatch<ValueType>(fn, alpha, in, beta, out);
193  }
194 }
195 
196 
// mixed_precision_dispatch<ValueType>(fn, in, out)
//
// With GINKGO_MIXED_PRECISION defined: `in` and `out` are NOT converted.
// Each operand is dynamic_cast to fst_type (matrix::Dense<ValueType>),
// snd_type (matrix::Dense<next_precision<ValueType>>) or trd_type, in that
// order. `fn` is then called with the first matching pointer pair, so `fn`
// must be generic over every combination of these types (e.g. a lambda with
// `auto` parameters). `in` is resolved before `out`. If neither cast
// matches, GKO_NOT_SUPPORTED is raised for the unmatched operand, and an
// unsupported `in` is reported before `out` is inspected.
//
// Without GINKGO_MIXED_PRECISION: falls back to
// precision_dispatch<ValueType>(fn, in, out), i.e. both operands are
// converted to matrix::Dense<ValueType>.
//
// NOTE(review): this listing lacks original line 232, which declares
// trd_type (presumably a further-reduced precision Dense type). Restore it
// from the real header before editing the code.
226 template <typename ValueType, typename Function>
227 void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
228 {
229 #ifdef GINKGO_MIXED_PRECISION
230  using fst_type = matrix::Dense<ValueType>;
231  using snd_type = matrix::Dense<next_precision<ValueType>>;
// Resolves the dynamic type of `out` for an already-resolved `in`.
233  auto dispatch_out_vector = [&](auto dense_in) {
234  if (auto dense_out = dynamic_cast<fst_type*>(out)) {
235  fn(dense_in, dense_out);
236  } else if (auto dense_out = dynamic_cast<snd_type*>(out)) {
237  fn(dense_in, dense_out);
238  } else if (auto dense_out = dynamic_cast<trd_type*>(out)) {
239  fn(dense_in, dense_out);
240  } else {
241  GKO_NOT_SUPPORTED(out);
242  }
243  };
// Resolve the dynamic type of `in` first, then dispatch on `out`.
244  if (auto dense_in = dynamic_cast<const fst_type*>(in)) {
245  dispatch_out_vector(dense_in);
246  } else if (auto dense_in = dynamic_cast<const snd_type*>(in)) {
247  dispatch_out_vector(dense_in);
248  } else if (auto dense_in = dynamic_cast<const trd_type*>(in)) {
249  dispatch_out_vector(dense_in);
250  } else {
251  GKO_NOT_SUPPORTED(in);
252  }
253 #else
254  precision_dispatch<ValueType>(fn, in, out);
255 #endif
256 }
257 
258 
268 template <typename ValueType, typename Function,
269  std::enable_if_t<is_complex<ValueType>()>* = nullptr>
270 void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in,
271  LinOp* out)
272 {
273 #ifdef GINKGO_MIXED_PRECISION
274  mixed_precision_dispatch<ValueType>(fn, in, out);
275 #else
276  precision_dispatch<ValueType>(fn, in, out);
277 #endif
278 }
279 
280 
281 template <typename ValueType, typename Function,
282  std::enable_if_t<!is_complex<ValueType>()>* = nullptr>
283 void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in,
284  LinOp* out)
285 {
286 #ifdef GINKGO_MIXED_PRECISION
287  if (!dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in)) {
288  mixed_precision_dispatch<to_complex<ValueType>>(
289  [&fn](auto dense_in, auto dense_out) {
290  fn(dense_in->create_real_view().get(),
291  dense_out->create_real_view().get());
292  },
293  in, out);
294  } else {
295  mixed_precision_dispatch<ValueType>(fn, in, out);
296  }
297 #else
298  precision_dispatch_real_complex<ValueType>(fn, in, out);
299 #endif
300 }
301 
302 
303 namespace experimental {
304 
305 
306 #if GINKGO_BUILD_MPI
307 
308 
309 namespace distributed {
310 
311 
// distributed::make_temporary_conversion<ValueType>(LinOp* matrix)
//
// Non-const overload. Returns a gko::detail::temporary_conversion whose
// target is experimental::distributed::Vector<ValueType>, built by
// temporary_conversion<Vector<ValueType>>::create. If the result is empty
// (no candidate type matched), GKO_NOT_SUPPORTED(matrix) is invoked.
//
// NOTE(review): this listing lacks original lines 343-344, which hold the
// template arguments of create<...> and its call arguments. Restore them
// from the real header before editing the code.
337 template <typename ValueType>
338 gko::detail::temporary_conversion<Vector<ValueType>> make_temporary_conversion(
339  LinOp* matrix)
340 {
341  auto result =
342  gko::detail::temporary_conversion<Vector<ValueType>>::template create<
345  if (!result) {
346  GKO_NOT_SUPPORTED(matrix);
347  }
348  return result;
349 }
350 
351 
// distributed::make_temporary_conversion<ValueType>(const LinOp* matrix)
//
// Const overload. Returns a gko::detail::temporary_conversion whose target
// is `const experimental::distributed::Vector<ValueType>`. It is built by
// temporary_conversion<const Vector<ValueType>>::create, with
// Vector<next_precision<ValueType>> as the first additional candidate source
// type. If the result is empty, GKO_NOT_SUPPORTED(matrix) is invoked.
//
// NOTE(review): this listing lacks original line 357 (presumably
// `make_temporary_conversion(const LinOp* matrix)`) and line 361 (the
// remaining create<...> candidate type). Restore them from the real header
// before editing the code.
355 template <typename ValueType>
356 gko::detail::temporary_conversion<const Vector<ValueType>>
358 {
359  auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
360  template create<Vector<next_precision<ValueType>>,
362  matrix);
363  if (!result) {
364  GKO_NOT_SUPPORTED(matrix);
365  }
366  return result;
367 }
368 
369 
384 template <typename ValueType, typename Function, typename... Args>
385 void precision_dispatch(Function fn, Args*... linops)
386 {
387  fn(distributed::make_temporary_conversion<ValueType>(linops).get()...);
388 }
389 
390 
// distributed::precision_dispatch_real_complex<ValueType>(fn, in, out)
//
// Distributed analogue of gko::precision_dispatch_real_complex. When
// `complex_to_real` is false, it delegates to
// distributed::precision_dispatch<ValueType>(fn, in, out). Otherwise `in`
// and `out` are converted to distributed vectors of to_complex<ValueType>,
// and `fn` receives their real views (create_real_view()).
//
// NOTE(review): this listing lacks original line 405 (the cast applied to
// `in` inside the complex_to_real condition; presumably a real-distributed-
// vector check) and line 412 (presumably the `Vector` alias for
// experimental::distributed::Vector<ValueType>). Restore both before
// editing the code.
400 template <typename ValueType, typename Function>
401 void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
402 {
403  auto complex_to_real = !(
404  is_complex<ValueType>() ||
406  in));
407  if (complex_to_real) {
408  auto dense_in =
409  distributed::make_temporary_conversion<to_complex<ValueType>>(in);
410  auto dense_out =
411  distributed::make_temporary_conversion<to_complex<ValueType>>(out);
413  // These dynamic_casts are only needed to make the code compile
414  // If ValueType is complex, this branch will never be taken
415  // If ValueType is real, the cast is a no-op
416  fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
417  dynamic_cast<Vector*>(dense_out->create_real_view().get()));
418  } else {
419  distributed::precision_dispatch<ValueType>(fn, in, out);
420  }
421 }
422 
423 
// distributed::precision_dispatch_real_complex<ValueType>(fn, alpha, in, out)
//
// The scalar `alpha` is always converted with gko::make_temporary_conversion
// (to non-distributed matrix::Dense<ValueType>). The vectors `in` and `out`
// go through distributed::make_temporary_conversion. When `complex_to_real`
// holds, `in`/`out` are converted to to_complex<ValueType> and passed as
// real views, while `alpha` stays real (ValueType). Otherwise all three
// operands are converted at ValueType directly. `fn` is called inline here
// rather than through distributed::precision_dispatch, because `alpha` is
// not a distributed vector.
//
// NOTE(review): this listing lacks original line 433 (the cast applied to
// `in` inside the complex_to_real condition) and line 441 (presumably the
// `Vector` alias). Restore both before editing the code.
427 template <typename ValueType, typename Function>
428 void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
429  const LinOp* in, LinOp* out)
430 {
431  auto complex_to_real = !(
432  is_complex<ValueType>() ||
434  in));
435  if (complex_to_real) {
436  auto dense_in =
437  distributed::make_temporary_conversion<to_complex<ValueType>>(in);
438  auto dense_out =
439  distributed::make_temporary_conversion<to_complex<ValueType>>(out);
440  auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
442  // These dynamic_casts are only needed to make the code compile
443  // If ValueType is complex, this branch will never be taken
444  // If ValueType is real, the cast is a no-op
445  fn(dense_alpha.get(),
446  dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
447  dynamic_cast<Vector*>(dense_out->create_real_view().get()));
448  } else {
449  fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
450  distributed::make_temporary_conversion<ValueType>(in).get(),
451  distributed::make_temporary_conversion<ValueType>(out).get());
452  }
453 }
454 
455 
// distributed::precision_dispatch_real_complex<ValueType>(fn, alpha, in,
//                                                         beta, out)
//
// The scalars `alpha` and `beta` are always converted with
// gko::make_temporary_conversion (to non-distributed
// matrix::Dense<ValueType>). The vectors `in` and `out` go through
// distributed::make_temporary_conversion. When `complex_to_real` holds,
// `in`/`out` are converted to to_complex<ValueType> and passed as real
// views, while the scalars stay real. Otherwise all four operands are
// converted at ValueType directly.
//
// NOTE(review): this listing lacks original line 466 (the cast applied to
// `in` inside the complex_to_real condition) and line 475 (presumably the
// `Vector` alias). Restore both before editing the code.
459 template <typename ValueType, typename Function>
460 void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
461  const LinOp* in, const LinOp* beta,
462  LinOp* out)
463 {
464  auto complex_to_real = !(
465  is_complex<ValueType>() ||
467  in));
468  if (complex_to_real) {
469  auto dense_in =
470  distributed::make_temporary_conversion<to_complex<ValueType>>(in);
471  auto dense_out =
472  distributed::make_temporary_conversion<to_complex<ValueType>>(out);
473  auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
474  auto dense_beta = gko::make_temporary_conversion<ValueType>(beta);
476  // These dynamic_casts are only needed to make the code compile
477  // If ValueType is complex, this branch will never be taken
478  // If ValueType is real, the cast is a no-op
479  fn(dense_alpha.get(),
480  dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
481  dense_beta.get(),
482  dynamic_cast<Vector*>(dense_out->create_real_view().get()));
483  } else {
484  fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
485  distributed::make_temporary_conversion<ValueType>(in).get(),
486  gko::make_temporary_conversion<ValueType>(beta).get(),
487  distributed::make_temporary_conversion<ValueType>(out).get());
488  }
489 }
490 
491 
492 } // namespace distributed
493 
494 
508 template <typename ValueType, typename Function>
509 void precision_dispatch_real_complex_distributed(Function fn, const LinOp* in,
510  LinOp* out)
511 {
512  if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
513  experimental::distributed::precision_dispatch_real_complex<ValueType>(
514  fn, in, out);
515  } else {
516  gko::precision_dispatch_real_complex<ValueType>(fn, in, out);
517  }
518 }
519 
520 
525 template <typename ValueType, typename Function>
526 void precision_dispatch_real_complex_distributed(Function fn,
527  const LinOp* alpha,
528  const LinOp* in, LinOp* out)
529 {
530  if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
531  experimental::distributed::precision_dispatch_real_complex<ValueType>(
532  fn, alpha, in, out);
533  } else {
534  gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, out);
535  }
536 }
537 
538 
543 template <typename ValueType, typename Function>
544 void precision_dispatch_real_complex_distributed(Function fn,
545  const LinOp* alpha,
546  const LinOp* in,
547  const LinOp* beta, LinOp* out)
548 {
549  if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
550  experimental::distributed::precision_dispatch_real_complex<ValueType>(
551  fn, alpha, in, beta, out);
552 
553  } else {
554  gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta,
555  out);
556  }
557 }
558 
559 
560 #else
561 
562 
573 template <typename ValueType, typename Function, typename... Args>
574 void precision_dispatch_real_complex_distributed(Function fn, Args*... args)
575 {
576  precision_dispatch_real_complex<ValueType>(fn, args...);
577 }
578 
579 
580 #endif
581 
582 
583 } // namespace experimental
584 } // namespace gko
585 
586 
587 #endif // GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
gko::LinOp
Definition: lin_op.hpp:117
gko::matrix::Dense
Dense is a matrix format which explicitly stores all values of the matrix.
Definition: dense_cache.hpp:19
gko::experimental::distributed::Vector
Vector is a format which explicitly stores (multiple) distributed column vectors in a dense storage f...
Definition: matrix.hpp:151
gko::mixed_precision_dispatch_real_complex
void mixed_precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps cast to their dynamic type matrix::Dense<ValueType>* a...
Definition: precision_dispatch.hpp:270
gko::experimental::distributed::precision_dispatch_real_complex
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to experimental::distributed::Ve...
Definition: precision_dispatch.hpp:401
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::precision_dispatch
void precision_dispatch(Function fn, Args *... linops)
Calls the given function with each given argument LinOp temporarily converted into matrix::Dense<Valu...
Definition: precision_dispatch.hpp:80
gko::experimental::distributed::precision_dispatch
void precision_dispatch(Function fn, Args *... linops)
Calls the given function with each given argument LinOp temporarily converted into experimental::dist...
Definition: precision_dispatch.hpp:385
gko::mixed_precision_dispatch
void mixed_precision_dispatch(Function fn, const LinOp *in, LinOp *out)
Calls the given function with each given argument LinOp converted into matrix::Dense<ValueType> as pa...
Definition: precision_dispatch.hpp:227
gko::precision_dispatch_real_complex
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to matrix::Dense<ValueType>* as ...
Definition: precision_dispatch.hpp:96
gko::make_temporary_conversion
detail::temporary_conversion< std::conditional_t< std::is_const< detail::pointee< Ptr > >::value, const matrix::Dense< ValueType >, matrix::Dense< ValueType > > > make_temporary_conversion(Ptr &&matrix)
Convert the given LinOp from matrix::Dense<...> to matrix::Dense<ValueType>.
Definition: precision_dispatch.hpp:47
gko::ConvertibleTo
ConvertibleTo interface is used to mark that the implementer can be converted to the object of Result...
Definition: polymorphic_object.hpp:479
gko::experimental::distributed::make_temporary_conversion
gko::detail::temporary_conversion< Vector< ValueType > > make_temporary_conversion(LinOp *matrix)
Convert the given LinOp from experimental::distributed::Vector<...> to experimental::distributed::Vec...
Definition: precision_dispatch.hpp:338