Ginkgo: generated from pipelines/1589998975 (branch based on develop). Ginkgo version 1.10.0
A numerical linear algebra library targeting many-core architectures
precision_dispatch.hpp
1 // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
6 #define GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
7 
8 
9 #include <ginkgo/config.hpp>
10 #include <ginkgo/core/base/math.hpp>
11 #include <ginkgo/core/base/temporary_conversion.hpp>
12 #include <ginkgo/core/distributed/vector.hpp>
13 #include <ginkgo/core/matrix/dense.hpp>
14 
15 
16 namespace gko {
17 
18 
/**
 * Converts the LinOp pointed to by `matrix` into a temporary
 * matrix::Dense<ValueType>; the result is const-qualified exactly when the
 * pointee of `Ptr` is const.
 *
 * The conversion is delegated to detail::temporary_conversion::create with
 * NextDense and NextNextDense as additional candidate source types. If
 * create yields an empty result, GKO_NOT_SUPPORTED(matrix) is raised.
 *
 * @tparam ValueType  the value type of the resulting Dense matrix
 * @tparam Ptr  a pointer-like type to a (possibly const) LinOp
 * @param matrix  the object to convert
 * @return  a temporary_conversion handle to the (possibly const) Dense matrix
 */
43 template <typename ValueType, typename Ptr>
44 detail::temporary_conversion<std::conditional_t<
45  std::is_const<detail::pointee<Ptr>>::value, const matrix::Dense<ValueType>,
46  matrix::Dense<ValueType>>>
// NOTE(review): listing line 47 is missing from this extraction; the index
// entry for this function gives it as `make_temporary_conversion(Ptr &&matrix)`.
48 {
49  using Pointee = detail::pointee<Ptr>;
50  using Dense = matrix::Dense<ValueType>;
51  using NextDense = matrix::Dense<next_precision<ValueType>>;
52  using NextNextDense =
// NOTE(review): listing line 53 (the type NextNextDense aliases) is missing
// here; by its name it is presumably the Dense type two next_precision steps
// away from ValueType -- confirm against the real header.
54  using MaybeConstDense =
55  std::conditional_t<std::is_const<Pointee>::value, const Dense, Dense>;
56  auto result = detail::temporary_conversion<
57  MaybeConstDense>::template create<NextDense, NextNextDense>(matrix);
58  if (!result) {
59  GKO_NOT_SUPPORTED(matrix);
60  }
61  return result;
62 }
63 
64 
/**
 * Calls `fn` with every argument LinOp temporarily converted into
 * matrix::Dense<ValueType> (const for const arguments) via
 * make_temporary_conversion<ValueType>.
 *
 * The temporary_conversion objects are temporaries of the call expression,
 * so they stay alive until `fn` returns. Any argument that cannot be
 * converted makes make_temporary_conversion raise GKO_NOT_SUPPORTED.
 * Whether results are written back to non-const arguments is up to
 * detail::temporary_conversion (not shown here) -- presumably yes; confirm.
 *
 * @tparam ValueType  the value type the Dense arguments are converted to
 * @tparam Function  a callable accepting one Dense pointer per argument
 * @param fn  the function to call
 * @param linops  the LinOps to convert and pass to `fn`
 */
79 template <typename ValueType, typename Function, typename... Args>
80 void precision_dispatch(Function fn, Args*... linops)
81 {
82  fn(make_temporary_conversion<ValueType>(linops).get()...);
83 }
84 
85 
95 template <typename ValueType, typename Function>
96 void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
97 {
98  // do we need to convert complex Dense to real Dense?
99  // all real dense vectors are intra-convertible, thus by casting to
100  // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
101  // dense matrix:
102  auto complex_to_real =
103  !(is_complex<ValueType>() ||
104  dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
105  if (complex_to_real) {
106  auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
107  auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
108  using Dense = matrix::Dense<ValueType>;
109  // These dynamic_casts are only needed to make the code compile
110  // If ValueType is complex, this branch will never be taken
111  // If ValueType is real, the cast is a no-op
112  fn(dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
113  dynamic_cast<Dense*>(dense_out->create_real_view().get()));
114  } else {
115  precision_dispatch<ValueType>(fn, in, out);
116  }
117 }
118 
119 
129 template <typename ValueType, typename Function>
130 void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
131  const LinOp* in, LinOp* out)
132 {
133  // do we need to convert complex Dense to real Dense?
134  // all real dense vectors are intra-convertible, thus by casting to
135  // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
136  // dense matrix:
137  auto complex_to_real =
138  !(is_complex<ValueType>() ||
139  dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
140  if (complex_to_real) {
141  auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
142  auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
143  auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
144  using Dense = matrix::Dense<ValueType>;
145  // These dynamic_casts are only needed to make the code compile
146  // If ValueType is complex, this branch will never be taken
147  // If ValueType is real, the cast is a no-op
148  fn(dense_alpha.get(),
149  dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
150  dynamic_cast<Dense*>(dense_out->create_real_view().get()));
151  } else {
152  precision_dispatch<ValueType>(fn, alpha, in, out);
153  }
154 }
155 
156 
166 template <typename ValueType, typename Function>
167 void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
168  const LinOp* in, const LinOp* beta,
169  LinOp* out)
170 {
171  // do we need to convert complex Dense to real Dense?
172  // all real dense vectors are intra-convertible, thus by casting to
173  // ConvertibleTo<matrix::Dense<>>, we can check whether a LinOp is a real
174  // dense matrix:
175  auto complex_to_real =
176  !(is_complex<ValueType>() ||
177  dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in));
178  if (complex_to_real) {
179  auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
180  auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
181  auto dense_alpha = make_temporary_conversion<ValueType>(alpha);
182  auto dense_beta = make_temporary_conversion<ValueType>(beta);
183  using Dense = matrix::Dense<ValueType>;
184  // These dynamic_casts are only needed to make the code compile
185  // If ValueType is complex, this branch will never be taken
186  // If ValueType is real, the cast is a no-op
187  fn(dense_alpha.get(),
188  dynamic_cast<const Dense*>(dense_in->create_real_view().get()),
189  dense_beta.get(),
190  dynamic_cast<Dense*>(dense_out->create_real_view().get()));
191  } else {
192  precision_dispatch<ValueType>(fn, alpha, in, beta, out);
193  }
194 }
195 
196 
/**
 * Calls `fn` with `in` and `out` cast to their dynamic Dense types, so the
 * two operands may have different precisions.
 *
 * With GINKGO_MIXED_PRECISION defined, `in` and then `out` are each matched
 * by dynamic_cast (no data conversion) against fst_type, snd_type and
 * trd_type. `fn` is then called with the two matched pointers. If `in`
 * matches none of them, GKO_NOT_SUPPORTED(in) is raised before `out` is
 * inspected; if `out` matches none, GKO_NOT_SUPPORTED(out) is raised.
 *
 * Without GINKGO_MIXED_PRECISION this falls back to
 * precision_dispatch<ValueType>, which converts both operands to
 * matrix::Dense<ValueType>.
 *
 * @tparam ValueType  the value type whose precision chain is tried
 * @tparam Function  a callable accepting every (input, output) pointer
 *                   combination generated below (e.g. a generic lambda)
 * @param fn  the function to call
 * @param in  the input operand
 * @param out  the output operand
 */
226 template <typename ValueType, typename Function>
227 void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out)
228 {
229 #ifdef GINKGO_MIXED_PRECISION
230  using fst_type = matrix::Dense<ValueType>;
231  using snd_type = matrix::Dense<next_precision<ValueType>>;
// NOTE(review): listing line 232 -- the definition of trd_type -- is missing
// from this extraction; presumably a third Dense type further along the
// next_precision chain. Confirm against the real header.
// Second-level dispatch: given the already-typed input pointer, match `out`.
233  auto dispatch_out_vector = [&](auto dense_in) {
234  if (auto dense_out = dynamic_cast<fst_type*>(out)) {
235  fn(dense_in, dense_out);
236  } else if (auto dense_out = dynamic_cast<snd_type*>(out)) {
237  fn(dense_in, dense_out);
238  } else if (auto dense_out = dynamic_cast<trd_type*>(out)) {
239  fn(dense_in, dense_out);
240  } else {
241  GKO_NOT_SUPPORTED(out);
242  }
243  };
// First-level dispatch: match `in` against the three Dense types.
244  if (auto dense_in = dynamic_cast<const fst_type*>(in)) {
245  dispatch_out_vector(dense_in);
246  } else if (auto dense_in = dynamic_cast<const snd_type*>(in)) {
247  dispatch_out_vector(dense_in);
248  } else if (auto dense_in = dynamic_cast<const trd_type*>(in)) {
249  dispatch_out_vector(dense_in);
250  } else {
251  GKO_NOT_SUPPORTED(in);
252  }
253 #else
254  precision_dispatch<ValueType>(fn, in, out);
255 #endif
256 }
257 
258 
268 template <typename ValueType, typename Function,
269  std::enable_if_t<is_complex<ValueType>()>* = nullptr>
270 void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in,
271  LinOp* out)
272 {
273 #ifdef GINKGO_MIXED_PRECISION
274  mixed_precision_dispatch<ValueType>(fn, in, out);
275 #else
276  precision_dispatch<ValueType>(fn, in, out);
277 #endif
278 }
279 
280 
281 template <typename ValueType, typename Function,
282  std::enable_if_t<!is_complex<ValueType>()>* = nullptr>
283 void mixed_precision_dispatch_real_complex(Function fn, const LinOp* in,
284  LinOp* out)
285 {
286 #ifdef GINKGO_MIXED_PRECISION
287  if (!dynamic_cast<const ConvertibleTo<matrix::Dense<>>*>(in)) {
288  mixed_precision_dispatch<to_complex<ValueType>>(
289  [&fn](auto dense_in, auto dense_out) {
290  fn(dense_in->create_real_view().get(),
291  dense_out->create_real_view().get());
292  },
293  in, out);
294  } else {
295  mixed_precision_dispatch<ValueType>(fn, in, out);
296  }
297 #else
298  precision_dispatch_real_complex<ValueType>(fn, in, out);
299 #endif
300 }
301 
302 
303 namespace experimental {
304 
305 
306 #if GINKGO_BUILD_MPI
307 
308 
309 namespace distributed {
310 
311 
/**
 * Converts the given (non-const) LinOp into a temporary
 * experimental::distributed::Vector<ValueType>.
 *
 * Creation is delegated to gko::detail::temporary_conversion::create. If it
 * yields an empty result, GKO_NOT_SUPPORTED(matrix) is raised.
 *
 * @tparam ValueType  the value type of the resulting Vector
 * @param matrix  the LinOp to convert
 * @return  a temporary_conversion handle to the converted Vector
 */
337 template <typename ValueType>
338 gko::detail::temporary_conversion<Vector<ValueType>> make_temporary_conversion(
339  LinOp* matrix)
340 {
341  auto result =
342  gko::detail::temporary_conversion<Vector<ValueType>>::template create<
// NOTE(review): listing line 343 (the candidate source type(s) of create<...>
// and the call on `matrix`) is missing from this extraction. The const
// overload in this file uses Vector<next_precision_base<ValueType>>, so
// presumably the same here -- confirm against the real header.
344  if (!result) {
345  GKO_NOT_SUPPORTED(matrix);
346  }
347  return result;
348 }
349 
350 
/**
 * Const overload: converts the given LinOp into a temporary
 * const experimental::distributed::Vector<ValueType>.
 *
 * Vector<next_precision_base<ValueType>> is offered to
 * temporary_conversion::create as the candidate source type. If create
 * yields an empty result, GKO_NOT_SUPPORTED(matrix) is raised.
 *
 * @tparam ValueType  the value type of the resulting Vector
 * @return  a temporary_conversion handle to the converted const Vector
 */
354 template <typename ValueType>
355 gko::detail::temporary_conversion<const Vector<ValueType>>
// NOTE(review): listing line 356 is missing from this extraction; presumably
// `make_temporary_conversion(const LinOp* matrix)` -- confirm.
357 {
358  auto result = gko::detail::temporary_conversion<const Vector<ValueType>>::
359  template create<Vector<next_precision_base<ValueType>>>(matrix);
360  if (!result) {
361  GKO_NOT_SUPPORTED(matrix);
362  }
363  return result;
364 }
365 
366 
381 template <typename ValueType, typename Function, typename... Args>
382 void precision_dispatch(Function fn, Args*... linops)
383 {
384  if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
385  GKO_NOT_SUPPORTED(nullptr);
386  } else {
387  fn(distributed::make_temporary_conversion<ValueType>(linops).get()...);
388  }
389 }
390 
391 
/**
 * Distributed counterpart of gko::precision_dispatch_real_complex(fn, in,
 * out): calls `fn` with `in` and `out` as distributed Vectors.
 *
 * - Half-based ValueTypes are rejected with GKO_NOT_SUPPORTED(nullptr).
 * - For a real ValueType with an input that fails the dynamic_cast check,
 *   both operands are converted to Vector<to_complex<ValueType>>, and `fn`
 *   receives their real views.
 * - Otherwise this forwards to distributed::precision_dispatch<ValueType>.
 */
401 template <typename ValueType, typename Function>
402 void precision_dispatch_real_complex(Function fn, const LinOp* in, LinOp* out)
403 {
404  if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
405  GKO_NOT_SUPPORTED(nullptr);
406  } else {
407  auto complex_to_real = !(
408  is_complex<ValueType>() ||
409  dynamic_cast<
// NOTE(review): listing line 410 (the target type of this dynamic_cast and
// its operand) is missing from this extraction; presumably a ConvertibleTo
// check identifying a real distributed Vector, applied to `in` -- confirm.
411  if (complex_to_real) {
412  auto dense_in =
413  distributed::make_temporary_conversion<to_complex<ValueType>>(
414  in);
415  auto dense_out =
416  distributed::make_temporary_conversion<to_complex<ValueType>>(
417  out);
// NOTE(review): listing line 418 is missing; the unqualified `Vector` used
// below suggests a local alias for the real distributed Vector<ValueType>
// -- confirm.
419  // These dynamic_casts are only needed to make the code compile
420  // If ValueType is complex, this branch will never be taken
421  // If ValueType is real, the cast is a no-op
422  fn(dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
423  dynamic_cast<Vector*>(dense_out->create_real_view().get()));
424  } else {
425  distributed::precision_dispatch<ValueType>(fn, in, out);
426  }
427  }
428 }
429 
430 
/**
 * Distributed counterpart of gko::precision_dispatch_real_complex(fn, alpha,
 * in, out).
 *
 * `alpha` is always converted with the non-distributed
 * gko::make_temporary_conversion<ValueType>, while `in`/`out` go through
 * distributed::make_temporary_conversion. Half-based ValueTypes are rejected
 * with GKO_NOT_SUPPORTED(nullptr).
 */
434 template <typename ValueType, typename Function>
435 void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
436  const LinOp* in, LinOp* out)
437 {
438  if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
439  GKO_NOT_SUPPORTED(nullptr);
440  } else {
441  auto complex_to_real = !(
442  is_complex<ValueType>() ||
443  dynamic_cast<
// NOTE(review): listing line 444 (the target type of this dynamic_cast and
// its operand) is missing from this extraction -- confirm against the header.
445  if (complex_to_real) {
446  auto dense_in =
447  distributed::make_temporary_conversion<to_complex<ValueType>>(
448  in);
449  auto dense_out =
450  distributed::make_temporary_conversion<to_complex<ValueType>>(
451  out);
452  auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
// NOTE(review): listing line 453 is missing; presumably the local alias that
// gives the unqualified `Vector` below its meaning -- confirm.
454  // These dynamic_casts are only needed to make the code compile
455  // If ValueType is complex, this branch will never be taken
456  // If ValueType is real, the cast is a no-op
457  fn(dense_alpha.get(),
458  dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
459  dynamic_cast<Vector*>(dense_out->create_real_view().get()));
460  } else {
// Calls fn directly rather than distributed::precision_dispatch, because
// alpha is converted to a non-distributed Dense matrix.
461  fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
462  distributed::make_temporary_conversion<ValueType>(in).get(),
463  distributed::make_temporary_conversion<ValueType>(out).get());
464  }
465  }
466 }
467 
468 
/**
 * Distributed counterpart of gko::precision_dispatch_real_complex(fn, alpha,
 * in, beta, out).
 *
 * `alpha` and `beta` are converted with the non-distributed
 * gko::make_temporary_conversion<ValueType>, while `in`/`out` go through
 * distributed::make_temporary_conversion. Half-based ValueTypes are rejected
 * with GKO_NOT_SUPPORTED(nullptr).
 */
472 template <typename ValueType, typename Function>
473 void precision_dispatch_real_complex(Function fn, const LinOp* alpha,
474  const LinOp* in, const LinOp* beta,
475  LinOp* out)
476 {
477  if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
478  GKO_NOT_SUPPORTED(nullptr);
479  } else {
480  auto complex_to_real = !(
481  is_complex<ValueType>() ||
482  dynamic_cast<
// NOTE(review): listing line 483 (the target type of this dynamic_cast and
// its operand) is missing from this extraction -- confirm against the header.
484  if (complex_to_real) {
485  auto dense_in =
486  distributed::make_temporary_conversion<to_complex<ValueType>>(
487  in);
488  auto dense_out =
489  distributed::make_temporary_conversion<to_complex<ValueType>>(
490  out);
491  auto dense_alpha = gko::make_temporary_conversion<ValueType>(alpha);
492  auto dense_beta = gko::make_temporary_conversion<ValueType>(beta);
// NOTE(review): listing line 493 is missing; presumably the local alias that
// gives the unqualified `Vector` below its meaning -- confirm.
494  // These dynamic_casts are only needed to make the code compile
495  // If ValueType is complex, this branch will never be taken
496  // If ValueType is real, the cast is a no-op
497  fn(dense_alpha.get(),
498  dynamic_cast<const Vector*>(dense_in->create_real_view().get()),
499  dense_beta.get(),
500  dynamic_cast<Vector*>(dense_out->create_real_view().get()));
501  } else {
// Calls fn directly because the scalars are non-distributed Dense matrices.
502  fn(gko::make_temporary_conversion<ValueType>(alpha).get(),
503  distributed::make_temporary_conversion<ValueType>(in).get(),
504  gko::make_temporary_conversion<ValueType>(beta).get(),
505  distributed::make_temporary_conversion<ValueType>(out).get());
506  }
507  }
508 }
509 
510 
511 } // namespace distributed
512 
513 
527 template <typename ValueType, typename Function>
528 void precision_dispatch_real_complex_distributed(Function fn, const LinOp* in,
529  LinOp* out)
530 {
531  if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
532  experimental::distributed::precision_dispatch_real_complex<ValueType>(
533  fn, in, out);
534  } else {
535  gko::precision_dispatch_real_complex<ValueType>(fn, in, out);
536  }
537 }
538 
539 
544 template <typename ValueType, typename Function>
545 void precision_dispatch_real_complex_distributed(Function fn,
546  const LinOp* alpha,
547  const LinOp* in, LinOp* out)
548 {
549  if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
550  experimental::distributed::precision_dispatch_real_complex<ValueType>(
551  fn, alpha, in, out);
552  } else {
553  gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, out);
554  }
555 }
556 
557 
562 template <typename ValueType, typename Function>
563 void precision_dispatch_real_complex_distributed(Function fn,
564  const LinOp* alpha,
565  const LinOp* in,
566  const LinOp* beta, LinOp* out)
567 {
568  if (dynamic_cast<const experimental::distributed::DistributedBase*>(in)) {
569  experimental::distributed::precision_dispatch_real_complex<ValueType>(
570  fn, alpha, in, beta, out);
571 
572  } else {
573  gko::precision_dispatch_real_complex<ValueType>(fn, alpha, in, beta,
574  out);
575  }
576 }
577 
578 
579 #else
580 
581 
592 template <typename ValueType, typename Function, typename... Args>
593 void precision_dispatch_real_complex_distributed(Function fn, Args*... args)
594 {
595  precision_dispatch_real_complex<ValueType>(fn, args...);
596 }
597 
598 
599 #endif
600 
601 
602 } // namespace experimental
603 } // namespace gko
604 
605 
606 #endif // GKO_PUBLIC_CORE_BASE_PRECISION_DISPATCH_HPP_
gko::LinOp
Definition: lin_op.hpp:117
gko::matrix::Dense
Dense is a matrix format which explicitly stores all values of the matrix.
Definition: dense_cache.hpp:19
gko::experimental::distributed::Vector
Vector is a format which explicitly stores (multiple) distributed column vectors in a dense storage format.
Definition: matrix.hpp:151
gko::mixed_precision_dispatch_real_complex
void mixed_precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps cast to their dynamic type matrix::Dense<ValueType>* a...
Definition: precision_dispatch.hpp:270
gko::experimental::distributed::precision_dispatch_real_complex
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to experimental::distributed::Ve...
Definition: precision_dispatch.hpp:402
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::precision_dispatch
void precision_dispatch(Function fn, Args *... linops)
Calls the given function with each given argument LinOp temporarily converted into matrix::Dense<Valu...
Definition: precision_dispatch.hpp:80
gko::experimental::distributed::precision_dispatch
void precision_dispatch(Function fn, Args *... linops)
Calls the given function with each given argument LinOp temporarily converted into experimental::dist...
Definition: precision_dispatch.hpp:382
gko::mixed_precision_dispatch
void mixed_precision_dispatch(Function fn, const LinOp *in, LinOp *out)
Calls the given function with each given argument LinOp converted into matrix::Dense<ValueType> as pa...
Definition: precision_dispatch.hpp:227
gko::half
A class providing basic support for half precision floating point types.
Definition: half.hpp:286
gko::precision_dispatch_real_complex
void precision_dispatch_real_complex(Function fn, const LinOp *in, LinOp *out)
Calls the given function with the given LinOps temporarily converted to matrix::Dense<ValueType>* as ...
Definition: precision_dispatch.hpp:96
gko::make_temporary_conversion
detail::temporary_conversion< std::conditional_t< std::is_const< detail::pointee< Ptr > >::value, const matrix::Dense< ValueType >, matrix::Dense< ValueType > > > make_temporary_conversion(Ptr &&matrix)
Convert the given LinOp from matrix::Dense<...> to matrix::Dense<ValueType>.
Definition: precision_dispatch.hpp:47
gko::ConvertibleTo
ConvertibleTo interface is used to mark that the implementer can be converted to the object of ResultType.
Definition: polymorphic_object.hpp:479
gko::experimental::distributed::make_temporary_conversion
gko::detail::temporary_conversion< Vector< ValueType > > make_temporary_conversion(LinOp *matrix)
Convert the given LinOp from experimental::distributed::Vector<...> to experimental::distributed::Vector<ValueType>.
Definition: precision_dispatch.hpp:338
gko::remove_complex
typename detail::remove_complex_s< T >::type remove_complex
Obtain the type which removed the complex of complex/scalar type or the template parameter of class b...
Definition: math.hpp:260