5 #ifndef GKO_PUBLIC_CORE_BASE_MATH_HPP_
6 #define GKO_PUBLIC_CORE_BASE_MATH_HPP_
13 #include <type_traits>
16 #include <ginkgo/config.hpp>
17 #include <ginkgo/core/base/half.hpp>
18 #include <ginkgo/core/base/types.hpp>
19 #include <ginkgo/core/base/utils.hpp>
40 struct remove_complex_impl {
48 struct remove_complex_impl<std::complex<T>> {
59 struct to_complex_impl {
60 using type = std::complex<T>;
69 struct to_complex_impl<std::complex<T>> {
70 using type = std::complex<T>;
75 struct is_complex_impl :
public std::integral_constant<bool, false> {};
78 struct is_complex_impl<std::complex<T>>
79 :
public std::integral_constant<bool, true> {};
83 struct is_complex_or_scalar_impl : std::is_scalar<T> {};
86 struct is_complex_or_scalar_impl<half> : std::true_type {};
89 struct is_complex_or_scalar_impl<std::complex<T>>
90 : is_complex_or_scalar_impl<T> {};
// Primary template of template_converter. It is an empty struct and declares
// no `type` member. `converter` is a template template parameter that takes a
// single type argument; `T` is the type handed in for conversion. The partial
// specialization for T<Rest...> elsewhere in this file is the one that
// provides `type`, built by applying `converter` to every argument of T.
100 template <
template <
typename>
class converter,
typename T>
101 struct template_converter {};
112 template <
template <
typename>
class converter,
template <
typename...>
class T,
114 struct template_converter<converter, T<Rest...>> {
115 using type = T<typename converter<Rest>::type...>;
// Primary template of remove_complex_s. It is empty and declares no `type`.
// The second, unnamed parameter defaults to void; the partial specializations
// in this file fill that slot with std::enable_if_t<...> conditions on
// is_complex_or_scalar_impl<T>::value, so one of them is chosen per T.
119 template <
typename T,
typename =
void>
120 struct remove_complex_s {};
128 template <
typename T>
129 struct remove_complex_s<T,
130 std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
131 using type =
typename detail::remove_complex_impl<T>::type;
140 template <
typename T>
141 struct remove_complex_s<
142 T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
144 typename detail::template_converter<detail::remove_complex_impl,
// Primary template of to_complex_s. It is empty and declares no `type`.
// The second, unnamed parameter defaults to void; the partial specializations
// in this file fill that slot with std::enable_if_t<...> conditions on
// is_complex_or_scalar_impl<T>::value, so one of them is chosen per T.
149 template <
typename T,
typename =
void>
150 struct to_complex_s {};
158 template <
typename T>
159 struct to_complex_s<T, std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
160 using type =
typename detail::to_complex_impl<T>::type;
169 template <
typename T>
170 struct to_complex_s<T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
172 typename detail::template_converter<detail::to_complex_impl, T>::type;
184 template <
typename T>
195 template <
typename T>
198 using type =
typename std::complex<T>::value_type;
210 template <
typename T>
220 template <
typename T>
223 return detail::is_complex_impl<T>::value;
234 template <
typename T>
244 template <
typename T>
247 return detail::is_complex_or_scalar_impl<T>::value;
259 template <
typename T>
278 template <
typename T>
287 template <
typename T>
// Primary template of next_precision_base_impl. It is empty and declares no
// `type`; only the specializations in this file (float, double, and
// std::complex<T>, which recurses on T) are meant to expose one.
295 template <
typename T>
296 struct next_precision_base_impl {};
299 struct next_precision_base_impl<float> {
304 struct next_precision_base_impl<double> {
308 template <
typename T>
309 struct next_precision_base_impl<std::complex<T>> {
310 using type = std::complex<typename next_precision_base_impl<T>::type>;
// Primary template of next_precision_impl. It is empty and declares no
// `type`; the specializations in this file cover gko::half, float, double,
// and std::complex<T> (which recurses on the underlying T).
314 template <
typename T>
315 struct next_precision_impl {};
319 struct next_precision_impl<
gko::
half> {
324 struct next_precision_impl<float> {
329 struct next_precision_impl<double> {
333 template <
typename T>
334 struct next_precision_impl<std::complex<T>> {
335 using type = std::complex<typename next_precision_impl<T>::type>;
339 template <
typename T>
340 struct reduce_precision_impl {
344 template <
typename T>
345 struct reduce_precision_impl<std::complex<T>> {
346 using type = std::complex<typename reduce_precision_impl<T>::type>;
350 struct reduce_precision_impl<double> {
355 struct reduce_precision_impl<float> {
360 template <
typename T>
361 struct increase_precision_impl {
365 template <
typename T>
366 struct increase_precision_impl<std::complex<T>> {
367 using type = std::complex<typename increase_precision_impl<T>::type>;
371 struct increase_precision_impl<float> {
376 struct increase_precision_impl<
half> {
381 template <
typename T>
382 struct infinity_impl {
385 static constexpr
auto value = std::numeric_limits<T>::infinity();
392 template <
typename T1,
typename T2>
393 struct highest_precision_impl {
394 using type = decltype(T1{} + T2{});
397 template <
typename T1,
typename T2>
398 struct highest_precision_impl<std::complex<T1>, std::complex<T2>> {
399 using type = std::complex<typename highest_precision_impl<T1, T2>::type>;
402 template <
typename Head,
typename... Tail>
403 struct highest_precision_variadic {
404 using type =
typename highest_precision_impl<
405 Head,
typename highest_precision_variadic<Tail...>::type>::type;
408 template <
typename Head>
409 struct highest_precision_variadic<Head> {
420 template <
typename T>
430 template <
typename T>
436 #if GINKGO_ENABLE_HALF
437 template <
typename T>
438 using next_precision =
typename detail::next_precision_impl<T>::type;
440 template <
typename T>
444 template <
typename T>
447 template <
typename T>
455 template <
typename T>
462 template <
typename T>
477 template <
typename... Ts>
479 typename detail::highest_precision_variadic<Ts...>::type;
491 template <
typename T>
507 template <
typename T>
514 template <
typename FloatType,
size_type NumComponents,
size_type ComponentId>
521 template <
typename T>
522 struct truncate_type_impl {
523 using type = truncated<T, 2, 0>;
526 template <
typename T,
size_type Components>
527 struct truncate_type_impl<truncated<T, Components, 0>> {
528 using type = truncated<T, 2 * Components, 0>;
531 template <
typename T>
532 struct truncate_type_impl<std::complex<T>> {
533 using type = std::complex<typename truncate_type_impl<T>::type>;
537 template <
typename T>
538 struct type_size_impl {
539 static constexpr
auto value =
sizeof(T) *
byte_size;
542 template <
typename T>
543 struct type_size_impl<std::complex<T>> {
544 static constexpr
auto value =
sizeof(T) *
byte_size;
555 template <
typename T,
size_type Limit = sizeof(u
int16) *
byte_size>
557 std::conditional_t<detail::type_size_impl<T>::value >= 2 * Limit,
567 template <
typename S,
typename R>
575 GKO_ATTRIBUTES R
operator()(S val) {
return static_cast<R>(val); }
592 return (num + den - 1) / den;
601 template <
typename T>
617 template <
typename T>
618 GKO_INLINE constexpr T
zero(
const T&)
629 template <
typename T>
630 GKO_INLINE constexpr T
one()
636 GKO_INLINE constexpr half one<half>()
638 constexpr
auto bits = static_cast<uint16>(0b0
'01111'0000000000u);
639 return half::create_from_bits(bits);
652 template <
typename T>
653 GKO_INLINE constexpr T
one(
const T&)
667 template <
typename T>
670 return value == zero<T>();
682 template <
typename T>
685 return value != zero<T>();
700 template <
typename T>
701 GKO_INLINE constexpr T
max(
const T& x,
const T& y)
703 return x >= y ? x : y;
718 template <
typename T>
719 GKO_INLINE constexpr T
min(
const T& x,
const T& y)
721 return x <= y ? x : y;
737 template <
typename Ref,
typename Dummy = std::
void_t<>>
738 struct has_to_arithmetic_type : std::false_type {
739 static_assert(std::is_same<Dummy, void>::value,
740 "Do not modify the Dummy value!");
744 template <
typename Ref>
745 struct has_to_arithmetic_type<
746 Ref, std::
void_t<decltype(std::declval<Ref>().to_arithmetic_type())>>
748 using type = decltype(std::declval<Ref>().to_arithmetic_type());
756 template <
typename Ref,
typename Dummy = std::
void_t<>>
757 struct has_arithmetic_type : std::false_type {
758 static_assert(std::is_same<Dummy, void>::value,
759 "Do not modify the Dummy value!");
762 template <
typename Ref>
763 struct has_arithmetic_type<Ref, std::
void_t<typename Ref::arithmetic_type>>
778 template <
typename Ref>
779 constexpr GKO_ATTRIBUTES
780 std::enable_if_t<has_to_arithmetic_type<Ref>::value,
781 typename has_to_arithmetic_type<Ref>::type>
782 to_arithmetic_type(
const Ref& ref)
784 return ref.to_arithmetic_type();
787 template <
typename Ref>
788 constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
789 has_arithmetic_type<Ref>::value,
790 typename Ref::arithmetic_type>
791 to_arithmetic_type(
const Ref& ref)
796 template <
typename Ref>
797 constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
798 !has_arithmetic_type<Ref>::value,
800 to_arithmetic_type(
const Ref& ref)
809 template <
typename T>
810 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
811 real_impl(
const T& x)
816 template <
typename T>
817 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
819 real_impl(
const T& x)
825 template <
typename T>
826 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
832 template <
typename T>
833 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
835 imag_impl(
const T& x)
841 template <
typename T>
842 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
843 conj_impl(
const T& x)
848 template <
typename T>
849 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T>
850 conj_impl(
const T& x)
852 return T{real_impl(x), -imag_impl(x)};
868 template <
typename T>
869 GKO_ATTRIBUTES GKO_INLINE constexpr
auto real(
const T& x)
871 return detail::real_impl(detail::to_arithmetic_type(x));
884 template <
typename T>
885 GKO_ATTRIBUTES GKO_INLINE constexpr
auto imag(
const T& x)
887 return detail::imag_impl(detail::to_arithmetic_type(x));
898 template <
typename T>
899 GKO_ATTRIBUTES GKO_INLINE constexpr
auto conj(
const T& x)
901 return detail::conj_impl(detail::to_arithmetic_type(x));
912 template <
typename T>
930 template <
typename T>
931 GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
abs(
934 return x >= zero<T>() ? x : -x;
938 template <
typename T>
939 GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, remove_complex<T>>
946 GKO_INLINE
gko::half abs(
const std::complex<gko::half>& x)
949 return static_cast<gko::half>(
abs(std::complex<float>(x)));
960 GKO_INLINE std::complex<gko::half> sqrt(std::complex<gko::half> a)
962 return std::complex<gko::half>(sqrt(std::complex<float>(
963 static_cast<float>(a.real()), static_cast<float>(a.imag()))));
972 template <
typename T>
973 GKO_INLINE constexpr T
pi()
975 return static_cast<T>(3.1415926535897932384626433);
987 template <
typename T>
1008 template <
typename T>
1026 template <
typename T>
1028 const T& hint = T{1}) noexcept
1045 template <
typename T>
1046 GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<!is_complex_s<T>::value,
bool>
1049 constexpr T infinity{detail::infinity_impl<T>::value};
1050 return abs(value) < infinity;
1065 template <
typename T>
1066 GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value,
bool>
1084 template <
typename T>
1087 return b == zero<T>() ? zero<T>() : a / b;
1100 template <
typename T>
1102 "is_nan can't be used safely on the device (MSVC+CUDA), and will thus be "
1103 "removed in a future release, without replacement")
1104 GKO_INLINE GKO_ATTRIBUTES
1108 return isnan(value);
1121 template <
typename T>
1123 "is_nan can't be used safely on the device (MSVC+CUDA), and will thus be "
1124 "removed in a future release, without replacement")
1139 template <
typename T>
1140 GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
nan()
1142 return std::numeric_limits<T>::quiet_NaN();
1153 template <
typename T>
1154 GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T>
nan()
1163 #endif // GKO_PUBLIC_CORE_BASE_MATH_HPP_