5 #ifndef GKO_PUBLIC_CORE_BASE_MATH_HPP_
6 #define GKO_PUBLIC_CORE_BASE_MATH_HPP_
13 #include <type_traits>
17 #include <ginkgo/config.hpp>
18 #include <ginkgo/core/base/types.hpp>
19 #include <ginkgo/core/base/utils.hpp>
109 template <
typename T>
110 struct remove_complex_impl {
117 template <
typename T>
118 struct remove_complex_impl<std::complex<T>> {
128 template <
typename T>
129 struct to_complex_impl {
130 using type = std::complex<T>;
138 template <
typename T>
139 struct to_complex_impl<std::complex<T>> {
140 using type = std::complex<T>;
/**
 * Type trait telling whether `T` is a `std::complex` specialization.
 *
 * The primary template answers `false` for every type; the partial
 * specialization that follows overrides the answer for `std::complex<T>`.
 *
 * @tparam T  the type to inspect
 */
template <class T>
struct is_complex_impl : std::false_type {};

/**
 * Any `std::complex<T>` is complex, regardless of its component type.
 *
 * @tparam T  the component (real / imaginary part) type
 */
template <class T>
struct is_complex_impl<std::complex<T>> : std::true_type {};
/**
 * Type trait telling whether `T` is a scalar type, or a `std::complex` whose
 * component type is scalar.
 *
 * "Scalar" follows `std::is_scalar`, which also accepts pointers, enums,
 * member pointers and `std::nullptr_t`, not just arithmetic types.
 *
 * @tparam T  the type to inspect
 */
template <class T>
struct is_complex_or_scalar_impl : std::is_scalar<T> {};

/**
 * `std::complex<T>` qualifies exactly when its component type `T` is scalar.
 *
 * @tparam T  the component (real / imaginary part) type
 */
template <class T>
struct is_complex_or_scalar_impl<std::complex<T>> : std::is_scalar<T> {};
/**
 * Primary template of the helper that applies a type-level `Converter` to
 * every template argument of a class-template specialization.
 *
 * Types of the form `T<Args...>` are handled by a partial specialization.
 * This primary template is the fallback for every other type and
 * intentionally defines no members (in particular no `type`).
 *
 * @tparam Converter  a single-parameter class template whose nested `type`
 *                    names the converted type
 * @tparam T  the type whose template arguments would be converted
 */
template <template <class> class Converter, class T>
struct template_converter {};
178 template <
template <
typename>
class converter,
template <
typename...>
class T,
180 struct template_converter<converter, T<Rest...>> {
181 using type = T<typename converter<Rest>::type...>;
/**
 * Primary template of the `remove_complex` type transformation.
 *
 * It defines no members. The `type` member is supplied by partial
 * specializations that are constrained through the second template
 * parameter via `std::enable_if_t`. That parameter exists only for this
 * SFINAE dispatch and should be left at its default.
 *
 * @tparam T  the type to transform
 */
template <class T, class Enable = void>
struct remove_complex_s {};
194 template <
typename T>
195 struct remove_complex_s<T,
196 std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
197 using type =
typename detail::remove_complex_impl<T>::type;
206 template <
typename T>
207 struct remove_complex_s<
208 T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
210 typename detail::template_converter<detail::remove_complex_impl,
/**
 * Primary template of the `to_complex` type transformation.
 *
 * It defines no members. The `type` member is supplied by partial
 * specializations that are constrained through the second template
 * parameter via `std::enable_if_t`. That parameter exists only for this
 * SFINAE dispatch and should be left at its default.
 *
 * @tparam T  the type to transform
 */
template <class T, class Enable = void>
struct to_complex_s {};
224 template <
typename T>
225 struct to_complex_s<T, std::enable_if_t<is_complex_or_scalar_impl<T>::value>> {
226 using type =
typename detail::to_complex_impl<T>::type;
235 template <
typename T>
236 struct to_complex_s<T, std::enable_if_t<!is_complex_or_scalar_impl<T>::value>> {
238 typename detail::template_converter<detail::to_complex_impl, T>::type;
250 template <
typename T>
261 template <
typename T>
264 using type =
typename std::complex<T>::value_type;
276 template <
typename T>
286 template <
typename T>
289 return detail::is_complex_impl<T>::value;
300 template <
typename T>
310 template <
typename T>
313 return detail::is_complex_or_scalar_impl<T>::value;
325 template <
typename T>
344 template <
typename T>
353 template <
typename T>
/**
 * Primary template of the `next_precision` mapping.
 *
 * It deliberately defines no `type` member. Only the explicitly specialized
 * value types (such as `float`, `double` and `std::complex` of those)
 * provide one, so asking for the next precision of any other type fails to
 * compile.
 *
 * @tparam T  the value type whose next precision is requested
 */
template <class T>
struct next_precision_impl {};
365 struct next_precision_impl<float> {
370 struct next_precision_impl<double> {
374 template <
typename T>
375 struct next_precision_impl<std::complex<T>> {
376 using type = std::complex<typename next_precision_impl<T>::type>;
380 template <
typename T>
381 struct reduce_precision_impl {
385 template <
typename T>
386 struct reduce_precision_impl<std::complex<T>> {
387 using type = std::complex<typename reduce_precision_impl<T>::type>;
391 struct reduce_precision_impl<double> {
396 struct reduce_precision_impl<float> {
401 template <
typename T>
402 struct increase_precision_impl {
406 template <
typename T>
407 struct increase_precision_impl<std::complex<T>> {
408 using type = std::complex<typename increase_precision_impl<T>::type>;
412 struct increase_precision_impl<float> {
417 struct increase_precision_impl<half> {
422 template <
typename T>
423 struct infinity_impl {
426 static constexpr
auto value = std::numeric_limits<T>::infinity();
433 template <
typename T1,
typename T2>
434 struct highest_precision_impl {
435 using type = decltype(T1{} + T2{});
438 template <
typename T1,
typename T2>
439 struct highest_precision_impl<std::complex<T1>, std::complex<T2>> {
440 using type = std::complex<typename highest_precision_impl<T1, T2>::type>;
443 template <
typename Head,
typename... Tail>
444 struct highest_precision_variadic {
445 using type =
typename highest_precision_impl<
446 Head,
typename highest_precision_variadic<Tail...>::type>::type;
449 template <
typename Head>
450 struct highest_precision_variadic<Head> {
461 template <
typename T>
471 template <
typename T>
478 template <
typename T>
485 template <
typename T>
500 template <
typename... Ts>
502 typename detail::highest_precision_variadic<Ts...>::type;
514 template <
typename T>
530 template <
typename T>
537 template <
typename FloatType,
size_type NumComponents,
size_type ComponentId>
544 template <
typename T>
545 struct truncate_type_impl {
549 template <
typename T,
size_type Components>
550 struct truncate_type_impl<
truncated<T, Components, 0>> {
554 template <
typename T>
555 struct truncate_type_impl<std::complex<T>> {
556 using type = std::complex<typename truncate_type_impl<T>::type>;
560 template <
typename T>
561 struct type_size_impl {
562 static constexpr
auto value =
sizeof(T) *
byte_size;
565 template <
typename T>
566 struct type_size_impl<std::complex<T>> {
567 static constexpr
auto value =
sizeof(T) *
byte_size;
578 template <
typename T,
size_type Limit = sizeof(u
int16) *
byte_size>
580 std::conditional_t<detail::type_size_impl<T>::value >= 2 * Limit,
590 template <
typename S,
typename R>
598 GKO_ATTRIBUTES R
operator()(S val) {
return static_cast<R>(val); }
615 return (num + den - 1) / den;
619 #if defined(__HIPCC__) && GINKGO_HIP_PLATFORM_HCC
627 template <
typename T>
628 GKO_INLINE __host__ constexpr T
zero()
643 template <
typename T>
644 GKO_INLINE __host__ constexpr T
zero(
const T&)
655 template <
typename T>
656 GKO_INLINE __host__ constexpr T
one()
671 template <
typename T>
672 GKO_INLINE __host__ constexpr T
one(
const T&)
683 template <
typename T>
684 GKO_INLINE __device__ constexpr std::enable_if_t<
685 !std::is_same<T, std::complex<remove_complex<T>>>::value, T>
701 template <
typename T>
702 GKO_INLINE __device__ constexpr T
zero(
const T&)
713 template <
typename T>
714 GKO_INLINE __device__ constexpr std::enable_if_t<
715 !std::is_same<T, std::complex<remove_complex<T>>>::value, T>
731 template <
typename T>
732 GKO_INLINE __device__ constexpr T
one(
const T&)
746 template <
typename T>
747 GKO_INLINE GKO_ATTRIBUTES constexpr T
zero()
762 template <
typename T>
763 GKO_INLINE GKO_ATTRIBUTES constexpr T
zero(
const T&)
774 template <
typename T>
775 GKO_INLINE GKO_ATTRIBUTES constexpr T
one()
790 template <
typename T>
791 GKO_INLINE GKO_ATTRIBUTES constexpr T
one(
const T&)
797 #endif // defined(__HIPCC__) && GINKGO_HIP_PLATFORM_HCC
800 #undef GKO_BIND_ZERO_ONE
811 template <
typename T>
812 GKO_INLINE GKO_ATTRIBUTES constexpr
bool is_zero(T value)
814 return value == zero<T>();
826 template <
typename T>
829 return value != zero<T>();
844 template <
typename T>
845 GKO_INLINE GKO_ATTRIBUTES constexpr T
max(
const T& x,
const T& y)
847 return x >= y ? x : y;
862 template <
typename T>
863 GKO_INLINE GKO_ATTRIBUTES constexpr T
min(
const T& x,
const T& y)
865 return x <= y ? x : y;
881 template <
typename Ref,
typename Dummy = xstd::
void_t<>>
882 struct has_to_arithmetic_type : std::false_type {
883 static_assert(std::is_same<Dummy, void>::value,
884 "Do not modify the Dummy value!");
888 template <
typename Ref>
889 struct has_to_arithmetic_type<
890 Ref, xstd::void_t<decltype(std::declval<Ref>().to_arithmetic_type())>>
892 using type = decltype(std::declval<Ref>().to_arithmetic_type());
900 template <
typename Ref,
typename Dummy = xstd::
void_t<>>
901 struct has_arithmetic_type : std::false_type {
902 static_assert(std::is_same<Dummy, void>::value,
903 "Do not modify the Dummy value!");
906 template <
typename Ref>
907 struct has_arithmetic_type<Ref, xstd::void_t<typename Ref::arithmetic_type>>
922 template <
typename Ref>
923 constexpr GKO_ATTRIBUTES
924 std::enable_if_t<has_to_arithmetic_type<Ref>::value,
925 typename has_to_arithmetic_type<Ref>::type>
926 to_arithmetic_type(
const Ref& ref)
928 return ref.to_arithmetic_type();
931 template <
typename Ref>
932 constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
933 has_arithmetic_type<Ref>::value,
934 typename Ref::arithmetic_type>
935 to_arithmetic_type(
const Ref& ref)
940 template <
typename Ref>
941 constexpr GKO_ATTRIBUTES std::enable_if_t<!has_to_arithmetic_type<Ref>::value &&
942 !has_arithmetic_type<Ref>::value,
944 to_arithmetic_type(
const Ref& ref)
953 template <
typename T>
954 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
955 real_impl(
const T& x)
960 template <
typename T>
961 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
963 real_impl(
const T& x)
969 template <
typename T>
970 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
976 template <
typename T>
977 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value,
979 imag_impl(
const T& x)
985 template <
typename T>
986 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<!is_complex_s<T>::value, T>
987 conj_impl(
const T& x)
992 template <
typename T>
993 GKO_ATTRIBUTES GKO_INLINE constexpr std::enable_if_t<is_complex_s<T>::value, T>
994 conj_impl(
const T& x)
996 return T{real_impl(x), -imag_impl(x)};
1012 template <
typename T>
1013 GKO_ATTRIBUTES GKO_INLINE constexpr
auto real(
const T& x)
1015 return detail::real_impl(detail::to_arithmetic_type(x));
1028 template <
typename T>
1029 GKO_ATTRIBUTES GKO_INLINE constexpr
auto imag(
const T& x)
1031 return detail::imag_impl(detail::to_arithmetic_type(x));
1042 template <
typename T>
1043 GKO_ATTRIBUTES GKO_INLINE constexpr
auto conj(
const T& x)
1045 return detail::conj_impl(detail::to_arithmetic_type(x));
1056 template <
typename T>
1073 template <
typename T>
1075 GKO_ATTRIBUTES constexpr xstd::enable_if_t<!is_complex_s<T>::value, T>
1078 return x >= zero<T>() ? x : -x;
1082 template <
typename T>
1083 GKO_INLINE GKO_ATTRIBUTES constexpr xstd::enable_if_t<is_complex_s<T>::value,
1096 template <
typename T>
1097 GKO_INLINE GKO_ATTRIBUTES constexpr T
pi()
1099 return static_cast<T>(3.1415926535897932384626433);
1111 template <
typename T>
1112 GKO_INLINE GKO_ATTRIBUTES constexpr std::complex<remove_complex<T>>
unit_root(
1132 template <
typename T>
1150 template <
typename T>
1152 const T& hint = T{1}) noexcept
1169 template <
typename T>
1170 GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<!is_complex_s<T>::value,
bool>
1173 constexpr T infinity{detail::infinity_impl<T>::value};
1174 return abs(value) < infinity;
1189 template <
typename T>
1190 GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value,
bool>
1208 template <
typename T>
1211 return b == zero<T>() ? zero<T>() : a / b;
1224 template <
typename T>
1225 GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<!is_complex_s<T>::value,
bool>
1228 return std::isnan(value);
1241 template <
typename T>
1242 GKO_INLINE GKO_ATTRIBUTES std::enable_if_t<is_complex_s<T>::value,
bool>
is_nan(
1245 return std::isnan(value.real()) || std::isnan(value.imag());
1256 template <
typename T>
1257 GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t<!is_complex_s<T>::value, T>
1260 return std::numeric_limits<T>::quiet_NaN();
1271 template <
typename T>
1272 GKO_INLINE GKO_ATTRIBUTES constexpr std::enable_if_t<is_complex_s<T>::value, T>
1282 #endif // GKO_PUBLIC_CORE_BASE_MATH_HPP_