5 #ifndef GKO_PUBLIC_CORE_BASE_MATH_HPP_
6 #define GKO_PUBLIC_CORE_BASE_MATH_HPP_
14 #include <type_traits>
17 #include <ginkgo/config.hpp>
18 #include <ginkgo/core/base/half.hpp>
19 #include <ginkgo/core/base/types.hpp>
20 #include <ginkgo/core/base/utils.hpp>
41 struct remove_complex_impl {
49 struct remove_complex_impl<std::complex<T>> {
60 struct to_complex_impl {
61 using type = std::complex<T>;
70 struct to_complex_impl<std::complex<T>> {
71 using type = std::complex<T>;
76 struct is_complex_impl :
public std::integral_constant<bool, false> {};
79 struct is_complex_impl<std::complex<T>>
80 :
public std::integral_constant<bool, true> {};
84 struct is_complex_or_scalar_impl : std::is_scalar<T> {};
87 struct is_complex_or_scalar_impl<half> : std::true_type {};
90 struct is_complex_or_scalar_impl<bfloat16> : std::true_type {};
93 struct is_complex_or_scalar_impl<std::complex<T>>
94 : is_complex_or_scalar_impl<T> {};
/**
 * Fallback for `template_converter`.
 *
 * It is chosen whenever `T` does not match the `T<Rest...>` pattern handled by
 * the partial specialization, so it provides no `type` member.
 *
 * @tparam Converter  unary type trait that the partial specialization applies
 *                    to every template argument of `T`
 * @tparam T  the type whose template arguments would be converted
 */
template <template <typename> typename Converter, typename T>
struct template_converter {
    // intentionally empty: there is nothing to convert for such a `T`
};
116 template <
template <
typename>
class converter,
template <
typename...>
class T,
118 struct template_converter<converter, T<Rest...>> {
119 using type = T<typename converter<Rest>::type...>;
/**
 * Primary template of `remove_complex_s`.
 *
 * It defines no `type` member. The partial specializations choose the actual
 * conversion based on `is_complex_or_scalar_impl<T>::value`, matched through
 * the defaulted `Enable` parameter.
 *
 * @tparam T  the type to convert
 * @tparam Enable  SFINAE slot used with `std::enable_if_t`; keep the default
 */
template <typename T, typename Enable = void>
struct remove_complex_s {
    // no `type` member in the primary template
};
132 template <
typename T>
133 struct remove_complex_s<T,
134 std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
135 using type =
typename detail::remove_complex_impl<T>::type;
144 template <
typename T>
145 struct remove_complex_s<
146 T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
148 typename detail::template_converter<detail::remove_complex_impl,
/**
 * Primary template of `to_complex_s`.
 *
 * It defines no `type` member. The partial specializations choose the actual
 * conversion based on `is_complex_or_scalar_impl<T>::value`, matched through
 * the defaulted `Enable` parameter.
 *
 * @tparam T  the type to convert
 * @tparam Enable  SFINAE slot used with `std::enable_if_t`; keep the default
 */
template <typename T, typename Enable = void>
struct to_complex_s {
    // no `type` member in the primary template
};
162 template <
typename T>
163 struct to_complex_s<T, std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
164 using type =
typename detail::to_complex_impl<T>::type;
173 template <
typename T>
174 struct to_complex_s<T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
176 typename detail::template_converter<detail::to_complex_impl, T>::type;
188 template <
typename T>
199 template <
typename T>
202 using type =
typename std::complex<T>::value_type;
214 template <
typename T>
224 template <
typename T>
227 return detail::is_complex_impl<T>::value;
238 template <
typename T>
248 template <
typename T>
251 return detail::is_complex_or_scalar_impl<T>::value;
263 template <
typename T>
282 template <
typename T>
291 template <
typename T>
/**
 * Primary template of `next_precision_base_impl`.
 *
 * It provides no `type` member. Only value types with an explicit
 * specialization (such as `float`, `double`, or `std::complex` of a supported
 * type) have a next precision.
 *
 * @tparam ValueType  the value type to look up
 */
template <typename ValueType>
struct next_precision_base_impl {
    // no `type` member for value types without a specialization
};
303 struct next_precision_base_impl<float> {
308 struct next_precision_base_impl<double> {
312 template <
typename T>
313 struct next_precision_base_impl<std::complex<T>> {
314 using type = std::complex<typename next_precision_base_impl<T>::type>;
/**
 * Searches the candidate list for `Target` and selects the type `Step`
 * positions after it, treating the list as cyclic.
 *
 * The primary template is only declared here. The search and the selection
 * are implemented by the partial specializations.
 *
 * @tparam Target  the type to look up
 * @tparam Step  offset from `Target`, reduced modulo the number of candidates
 * @tparam VisitedTuple  `std::tuple` of candidates already skipped by the search
 * @tparam Candidates  candidates that have not been examined yet
 */
template <typename Target, int Step, typename VisitedTuple,
          typename... Candidates>
struct find_precision_list_impl;
326 template <
typename T,
int step,
typename... Visited,
typename U,
328 struct find_precision_list_impl<T, step, std::tuple<Visited...>, U, Rest...> {
330 typename find_precision_list_impl<T, step, std::tuple<Visited..., U>,
334 template <
typename T,
int step,
typename... Visited,
typename... Rest>
335 struct find_precision_list_impl<T, step, std::tuple<Visited...>, T, Rest...> {
336 using tuple = std::tuple<T, Rest..., Visited...>;
337 constexpr
static auto tuple_size =
338 static_cast<int>(std::tuple_size_v<tuple>);
340 constexpr
static int index = (tuple_size + step % tuple_size) % tuple_size;
341 using type = std::tuple_element_t<index, tuple>;
345 template <
typename T,
int step = 1>
346 struct find_precision_impl {
347 using type =
typename find_precision_list_impl<T, step, std::tuple<>,
348 #if GINKGO_ENABLE_HALF
351 #if GINKGO_ENABLE_BFLOAT16
354 float,
double>::type;
358 template <
typename T,
int step>
359 struct find_precision_impl<std::complex<T>, step> {
360 using type = std::complex<typename find_precision_impl<T, step>::type>;
364 template <
typename T>
365 struct reduce_precision_impl {
369 template <
typename T>
370 struct reduce_precision_impl<std::complex<T>> {
371 using type = std::complex<typename reduce_precision_impl<T>::type>;
375 struct reduce_precision_impl<double> {
381 struct reduce_precision_impl<float> {
386 template <
typename T>
387 struct increase_precision_impl {
391 template <
typename T>
392 struct increase_precision_impl<std::complex<T>> {
393 using type = std::complex<typename increase_precision_impl<T>::type>;
397 struct increase_precision_impl<float> {
403 struct increase_precision_impl<
half> {
408 template <
typename T>
409 struct infinity_impl {
412 static constexpr
auto value = std::numeric_limits<T>::infinity();
419 template <
typename T1,
typename T2>
420 struct highest_precision_impl {
421 using type = decltype(T1{} + T2{});
424 template <
typename T1,
typename T2>
425 struct highest_precision_impl<std::complex<T1>, std::complex<T2>> {
426 using type = std::complex<typename highest_precision_impl<T1, T2>::type>;
429 template <
typename Head,
typename... Tail>
430 struct highest_precision_variadic {
431 using type =
typename highest_precision_impl<
432 Head,
typename highest_precision_variadic<Tail...>::type>::type;
435 template <
typename Head>
436 struct highest_precision_variadic<Head> {
447 template <
typename T>
457 template <
typename T>
465 template <
typename T,
int step = 1>
472 template <
typename T,
int step = 1>
479 template <
typename T>
486 template <
typename T>
501 template <
typename... Ts>
503 typename detail::highest_precision_variadic<Ts...>::type;
515 template <
typename T>
531 template <
typename T>
538 template <
typename FloatType,
size_type NumComponents,
size_type ComponentId>
545 template <
typename T>
546 struct truncate_type_impl {
547 using type = truncated<T, 2, 0>;
550 template <
typename T,
size_type Components>
551 struct truncate_type_impl<truncated<T, Components, 0>> {
552 using type = truncated<T, 2 * Components, 0>;
555 template <
typename T>
556 struct truncate_type_impl<std::complex<T>> {
557 using type = std::complex<typename truncate_type_impl<T>::type>;
561 template <
typename T>
562 struct type_size_impl {
563 static constexpr
auto value =
sizeof(T) *
byte_size;
566 template <
typename T>
567 struct type_size_impl<std::complex<T>> {
568 static constexpr
auto value =
sizeof(T) *
byte_size;
579 template <
typename T,
size_type Limit = sizeof(u
int16) *
byte_size>
581 std::conditional_t<detail::type_size_impl<T>::value >= 2 * Limit,
591 template <
typename S,
typename R>
599 GKO_ATTRIBUTES R
operator()(S val) {
return static_cast<R>(val); }
616 return (num + den - 1) / den;
625 template <
typename T>
641 template <
typename T>
642 GKO_INLINE constexpr T
zero(
const T&)
653 template <
typename T>
654 GKO_INLINE constexpr T
one()
660 GKO_INLINE constexpr half one<half>()
662 constexpr
auto bits = static_cast<uint16>(0b0
'01111'0000000000u);
663 return half::create_from_bits(bits);
667 GKO_INLINE constexpr bfloat16 one<bfloat16>()
669 constexpr
auto bits = static_cast<uint16>(0b0
'01111111'0000000u);
670 return bfloat16::create_from_bits(bits);
683 template <
typename T>
684 GKO_INLINE constexpr T
one(
const T&)
698 template <
typename T>
701 return value == zero<T>();
713 template <
typename T>
716 return value != zero<T>();
731 template <
typename T>
732 GKO_INLINE constexpr T
max(
const T& x,
const T& y)
734 return x >= y ? x : y;
749 template <
typename T>
750 GKO_INLINE constexpr T
min(
const T& x,
const T& y)
752 return x <= y ? x : y;
768 template <
typename Ref,
typename Dummy = std::
void_t<>>
769 struct has_to_arithmetic_type : std::false_type {
770 static_assert(std::is_same<Dummy, void>::value,
771 "Do not modify the Dummy value!");
775 template <
typename Ref>
776 struct has_to_arithmetic_type<
777 Ref, std::
void_t<decltype(std::declval<Ref>().to_arithmetic_type())>>
779 using type = decltype(std::declval<Ref>().to_arithmetic_type());
787 template <
typename Ref,
typename Dummy = std::
void_t<>>
788 struct has_arithmetic_type : std::false_type {
789 static_assert(std::is_same<Dummy, void>::value,
790 "Do not modify the Dummy value!");
793 template <
typename Ref>
794 struct has_arithmetic_type<Ref, std::
void_t<typename Ref::arithmetic_type>>
809 template <
typename Ref>
810 constexpr GKO_ATTRIBUTES
811 std::enable_if_t<has_to_arithmetic_type<Ref>::value,
812 typename has_to_arithmetic_type<Ref>::type>
813 to_arithmetic_type(
const Ref& ref)
815 return ref.to_arithmetic_type();
818 template <
typename Ref>
819 constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
820 has_arithmetic_type<Ref>::value,
821 typename Ref::arithmetic_type>
822 to_arithmetic_type(
const Ref& ref)
827 template <
typename Ref>
828 constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
829 !has_arithmetic_type<Ref>::value,
831 to_arithmetic_type(
const Ref& ref)
840 template <
typename T>
841 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
842 real_impl(
const T& x)
847 template <
typename T>
848 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
850 real_impl(
const T& x)
856 template <
typename T>
857 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
863 template <
typename T>
864 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
866 imag_impl(
const T& x)
872 template <
typename T>
873 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
874 conj_impl(
const T& x)
879 template <
typename T>
880 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T>
881 conj_impl(
const T& x)
883 return T{real_impl(x), -imag_impl(x)};
899 template <
typename T>
900 GKO_ATTRIBUTES GKO_INLINE constexpr
auto real(
const T& x)
902 return detail::real_impl(detail::to_arithmetic_type(x));
915 template <
typename T>
916 GKO_ATTRIBUTES GKO_INLINE constexpr
auto imag(
const T& x)
918 return detail::imag_impl(detail::to_arithmetic_type(x));
929 template <
typename T>
930 GKO_ATTRIBUTES GKO_INLINE constexpr
auto conj(
const T& x)
932 return detail::conj_impl(detail::to_arithmetic_type(x));
943 template <
typename T>
961 template <
typename T>
962 GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
abs(
965 return x >= zero<T>() ? x : -x;
969 template <
typename T>
970 GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, remove_complex<T>>
977 GKO_INLINE
gko::half abs(
const std::complex<gko::half>& x)
980 return static_cast<gko::half>(
abs(std::complex<float>(x)));
986 return static_cast<gko::bfloat16>(
abs(std::complex<float>(x)));
997 GKO_INLINE std::complex<gko::half> sqrt(std::complex<gko::half> a)
999 return std::complex<gko::half>(sqrt(std::complex<float>(
1000 static_cast<float>(a.real()), static_cast<float>(a.imag()))));
1008 GKO_INLINE std::complex<gko::bfloat16> sqrt(std::complex<gko::bfloat16> a)
1010 return std::complex<gko::bfloat16>(sqrt(std::complex<float>(
1011 static_cast<float>(a.real()), static_cast<float>(a.imag()))));
1020 template <
typename T>
1021 GKO_INLINE constexpr T
pi()
1023 return static_cast<T>(3.1415926535897932384626433);
1035 template <
typename T>
1056 template <
typename T>
1074 template <
typename T>
1076 const T& hint = T{1}) noexcept
1093 template <
typename T>
1094 GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<!is_complex_s<T>::value,
bool>
1097 constexpr T infinity{detail::infinity_impl<T>::value};
1098 return abs(value) < infinity;
1113 template <
typename T>
1114 GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value,
bool>
1132 template <
typename T>
1135 return b == zero<T>() ? zero<T>() : a / b;
1148 template <
typename T>
1150 "is_nan can't be used safely on the device (MSVC+CUDA), and will thus be "
1151 "removed in a future release, without replacement")
1152 GKO_INLINE GKO_ATTRIBUTES
1156 return isnan(value);
1169 template <
typename T>
1171 "is_nan can't be used safely on the device (MSVC+CUDA), and will thus be "
1172 "removed in a future release, without replacement")
1187 template <
typename T>
1188 GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
nan()
1190 return std::numeric_limits<T>::quiet_NaN();
1201 template <
typename T>
1202 GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T>
nan()
1211 #endif // GKO_PUBLIC_CORE_BASE_MATH_HPP_