5 #ifndef GKO_PUBLIC_CORE_BASE_HALF_HPP_
6 #define GKO_PUBLIC_CORE_BASE_HALF_HPP_
13 #include <type_traits>
/**
 * Forward declaration of the `truncated` class template.
 *
 * It takes one type parameter followed by two `std::size_t` parameters;
 * `basic_float_traits` is specialized for it further down in this file.
 */
template <typename, std::size_t, std::size_t>
class truncated;
// Number of bits in one byte (`CHAR_BIT`). Multiplying a `sizeof` result by
// this constant yields the width of a type in bits.
constexpr std::size_t byte_size = CHAR_BIT;
/**
 * Selects an unsigned integer type by bit count.
 *
 * The primary template deliberately has no `type` member, so a bit count
 * that matches none of the specializations below (more than 64 bits) fails
 * to compile when `uint_of` is used.
 */
template <std::size_t, typename = void>
struct uint_of_impl {};

// Up to 16 bits: std::uint16_t (there is no narrower mapping).
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(Bits <= 16)>> {
    using type = std::uint16_t;
};

// 17 to 32 bits: std::uint32_t.
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(16 < Bits && Bits <= 32)>> {
    using type = std::uint32_t;
};

// 33 to 64 bits: std::uint64_t.
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(32 < Bits) && (Bits <= 64)>> {
    using type = std::uint64_t;
};

/**
 * Unsigned integer type selected by `uint_of_impl` for `Bits` bits.
 */
template <std::size_t Bits>
using uint_of = typename uint_of_impl<Bits>::type;
/**
 * Primary template of the per-format description of a floating point type.
 *
 * It is intentionally empty; every supported type provides a specialization
 * with `type`, `sign_bits`, `significand_bits`, `exponent_bits` and
 * `rounds_to_nearest`.
 */
template <typename T>
struct basic_float_traits {};
60 struct basic_float_traits<half> {
62 static constexpr int sign_bits = 1;
63 static constexpr int significand_bits = 10;
64 static constexpr int exponent_bits = 5;
65 static constexpr bool rounds_to_nearest = true;
69 struct basic_float_traits<__half> {
71 static constexpr int sign_bits = 1;
72 static constexpr int significand_bits = 10;
73 static constexpr int exponent_bits = 5;
74 static constexpr bool rounds_to_nearest = true;
78 struct basic_float_traits<float> {
80 static constexpr int sign_bits = 1;
81 static constexpr int significand_bits = 23;
82 static constexpr int exponent_bits = 8;
83 static constexpr bool rounds_to_nearest = true;
87 struct basic_float_traits<double> {
89 static constexpr int sign_bits = 1;
90 static constexpr int significand_bits = 52;
91 static constexpr int exponent_bits = 11;
92 static constexpr bool rounds_to_nearest = true;
95 template <typename FloatType, std::size_t NumComponents,
96 std::size_t ComponentId>
97 struct basic_float_traits<truncated<FloatType, NumComponents, ComponentId>> {
98 using type = truncated<FloatType, NumComponents, ComponentId>;
99 static constexpr int sign_bits = ComponentId == 0 ? 1 : 0;
100 static constexpr int exponent_bits =
101 ComponentId == 0 ? basic_float_traits<FloatType>::exponent_bits : 0;
102 static constexpr int significand_bits =
103 ComponentId == 0 ? sizeof(type) * byte_size - exponent_bits - 1
104 : sizeof(type) * byte_size;
105 static constexpr bool rounds_to_nearest = false;
109 template <typename UintType>
110 constexpr UintType create_ones(int n)
112 return (n == sizeof(UintType) * byte_size ? static_cast<UintType>(0)
113 : static_cast<UintType>(1) << n) -
114 static_cast<UintType>(1);
118 template <typename T>
119 struct float_traits {
120 using type = typename basic_float_traits<T>::type;
121 using bits_type = uint_of<sizeof(type) * byte_size>;
122 static constexpr int sign_bits = basic_float_traits<T>::sign_bits;
123 static constexpr int significand_bits =
124 basic_float_traits<T>::significand_bits;
125 static constexpr int exponent_bits = basic_float_traits<T>::exponent_bits;
126 static constexpr bits_type significand_mask =
127 create_ones<bits_type>(significand_bits);
128 static constexpr bits_type exponent_mask =
129 create_ones<bits_type>(significand_bits + exponent_bits) -
131 static constexpr bits_type bias_mask =
132 create_ones<bits_type>(significand_bits + exponent_bits - 1) -
134 static constexpr bits_type sign_mask =
135 create_ones<bits_type>(sign_bits + significand_bits + exponent_bits) -
136 exponent_mask - significand_mask;
137 static constexpr bool rounds_to_nearest =
138 basic_float_traits<T>::rounds_to_nearest;
140 static constexpr auto eps =
141 1.0 / (1ll << (significand_bits + rounds_to_nearest));
143 static constexpr bool is_inf(bits_type data)
145 return (data & exponent_mask) == exponent_mask &&
146 (data & significand_mask) == bits_type{};
149 static constexpr bool is_nan(bits_type data)
151 return (data & exponent_mask) == exponent_mask &&
152 (data & significand_mask) != bits_type{};
155 static constexpr bool is_denom(bits_type data)
157 return (data & exponent_mask) == bits_type{};
/**
 * Bit-level conversion helpers between two floating point formats.
 *
 * The defaulted third parameter selects the implementation: `true` when
 * `SourceType` is not larger than `ResultType`, `false` otherwise.
 */
template <typename SourceType, typename ResultType,
          bool = (sizeof(SourceType) <= sizeof(ResultType))>
struct precision_converter;
167 template <typename SourceType, typename ResultType>
168 struct precision_converter<SourceType, ResultType, true> {
169 using source_traits = float_traits<SourceType>;
170 using result_traits = float_traits<ResultType>;
171 using source_bits = typename source_traits::bits_type;
172 using result_bits = typename result_traits::bits_type;
174 static_assert(source_traits::exponent_bits <=
175 result_traits::exponent_bits &&
176 source_traits::significand_bits <=
177 result_traits::significand_bits,
178 "SourceType has to have both lower range and precision or "
179 "higher range and precision than ResultType");
181 static constexpr int significand_offset =
182 result_traits::significand_bits - source_traits::significand_bits;
183 static constexpr int exponent_offset = significand_offset;
184 static constexpr int sign_offset = result_traits::exponent_bits -
185 source_traits::exponent_bits +
187 static constexpr result_bits bias_change =
188 result_traits::bias_mask -
189 (static_cast<result_bits>(source_traits::bias_mask) << exponent_offset);
191 static constexpr result_bits shift_significand(source_bits data) noexcept
193 return static_cast<result_bits>(data & source_traits::significand_mask)
194 << significand_offset;
197 static constexpr result_bits shift_exponent(source_bits data) noexcept
200 static_cast<result_bits>(data & source_traits::exponent_mask)
204 static constexpr result_bits shift_sign(source_bits data) noexcept
206 return static_cast<result_bits>(data & source_traits::sign_mask)
211 static constexpr result_bits update_bias(result_bits data) noexcept
213 return data == typename result_traits::bits_type{} ? data
214 : data + bias_change;
219 template <typename SourceType, typename ResultType>
220 struct precision_converter<SourceType, ResultType, false> {
221 using source_traits = float_traits<SourceType>;
222 using result_traits = float_traits<ResultType>;
223 using source_bits = typename source_traits::bits_type;
224 using result_bits = typename result_traits::bits_type;
226 static_assert(source_traits::exponent_bits >=
227 result_traits::exponent_bits &&
228 source_traits::significand_bits >=
229 result_traits::significand_bits,
230 "SourceType has to have both lower range and precision or "
231 "higher range and precision than ResultType");
233 static constexpr int significand_offset =
234 source_traits::significand_bits - result_traits::significand_bits;
235 static constexpr int exponent_offset = significand_offset;
236 static constexpr int sign_offset = source_traits::exponent_bits -
237 result_traits::exponent_bits +
239 static constexpr source_bits bias_change =
240 (source_traits::bias_mask >> exponent_offset) -
241 static_cast<source_bits>(result_traits::bias_mask);
243 static constexpr result_bits shift_significand(source_bits data) noexcept
245 return static_cast<result_bits>(
246 (data & source_traits::significand_mask) >> significand_offset);
249 static constexpr result_bits shift_exponent(source_bits data) noexcept
251 return static_cast<result_bits>(update_bias(
252 (data & source_traits::exponent_mask) >> exponent_offset));
255 static constexpr result_bits shift_sign(source_bits data) noexcept
257 return static_cast<result_bits>((data & source_traits::sign_mask) >>
262 static constexpr source_bits update_bias(source_bits data) noexcept
264 return data <= bias_change ? typename source_traits::bits_type{}
265 : limit_exponent(data - bias_change);
268 static constexpr source_bits limit_exponent(source_bits data) noexcept
270 return data >= static_cast<source_bits>(result_traits::exponent_mask)
271 ? static_cast<source_bits>(result_traits::exponent_mask)
286 class alignas(std::uint16_t) half {
289 static constexpr half create_from_bits(const std::uint16_t& bits) noexcept
299 constexpr half() noexcept : data_(0){};
301 template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
302 half(const T& val) : data_(0)
304 this->float2half(static_cast<float>(val));
307 template <typename V>
308 half& operator=(const V& val)
310 this->float2half(static_cast<float>(val));
314 operator float() const noexcept
316 const auto bits = half2float(data_);
318 std::memcpy(&ans, &bits, sizeof(float));
325 #define HALF_OPERATOR(_op, _opeq) \
326 friend half operator _op(const half& lhf, const half& rhf) \
328 return static_cast<half>(static_cast<float>(lhf) \
329 _op static_cast<float>(rhf)); \
331 half& operator _opeq(const half& hf) \
333 auto result = *this _op hf; \
334 data_ = result.data_; \
348 #define HALF_FRIEND_OPERATOR(_op, _opeq) \
349 template <typename T> \
350 friend std::enable_if_t< \
351 !std::is_same<T, half>::value && std::is_scalar<T>::value, \
352 std::conditional_t<std::is_floating_point<T>::value, T, half>> \
353 operator _op(const half& hf, const T& val) \
356 std::conditional_t<std::is_floating_point<T>::value, T, half>; \
357 auto result = static_cast<type>(hf); \
358 result _opeq static_cast<type>(val); \
361 template <typename T> \
362 friend std::enable_if_t< \
363 !std::is_same<T, half>::value && std::is_scalar<T>::value, \
364 std::conditional_t<std::is_floating_point<T>::value, T, half>> \
365 operator _op(const T& val, const half& hf) \
368 std::conditional_t<std::is_floating_point<T>::value, T, half>; \
369 auto result = static_cast<type>(val); \
370 result _opeq static_cast<type>(hf); \
374 HALF_FRIEND_OPERATOR(+, +=)
375 HALF_FRIEND_OPERATOR(-, -=)
376 HALF_FRIEND_OPERATOR(*, *=)
377 HALF_FRIEND_OPERATOR(/, /=)
379 #undef HALF_FRIEND_OPERATOR
382 half operator-() const
384 auto val = 0.0f - *this;
385 return static_cast<half>(val);
389 using f16_traits = detail::float_traits<half>;
390 using f32_traits = detail::float_traits<float>;
392 void float2half(const float& val) noexcept
394 std::uint32_t bit_val(0);
395 std::memcpy(&bit_val, &val, sizeof(float));
396 data_ = float2half(bit_val);
399 static constexpr std::uint16_t float2half(std::uint32_t data_) noexcept
401 using conv = detail::precision_converter<float, half>;
402 if (f32_traits::is_inf(data_)) {
403 return conv::shift_sign(data_) | f16_traits::exponent_mask;
404 } else if (f32_traits::is_nan(data_)) {
405 return conv::shift_sign(data_) | f16_traits::exponent_mask |
406 f16_traits::significand_mask;
408 const auto exp = conv::shift_exponent(data_);
409 if (f16_traits::is_inf(exp)) {
410 return conv::shift_sign(data_) | exp;
411 } else if (f16_traits::is_denom(exp)) {
413 return conv::shift_sign(data_);
416 const auto result = conv::shift_sign(data_) | exp |
417 conv::shift_significand(data_);
419 data_ & static_cast<f32_traits::bits_type>(
420 (1 << conv::significand_offset) - 1);
422 constexpr auto half = static_cast<f32_traits::bits_type>(
423 1 << (conv::significand_offset - 1));
425 (tail > half || ((tail == half) && (result & 1)));
430 static constexpr std::uint32_t half2float(std::uint16_t data_) noexcept
432 using conv = detail::precision_converter<half, float>;
433 if (f16_traits::is_inf(data_)) {
434 return conv::shift_sign(data_) | f32_traits::exponent_mask;
435 } else if (f16_traits::is_nan(data_)) {
436 return conv::shift_sign(data_) | f32_traits::exponent_mask |
437 f32_traits::significand_mask;
438 } else if (f16_traits::is_denom(data_)) {
440 return conv::shift_sign(data_);
442 return conv::shift_sign(data_) | conv::shift_exponent(data_) |
443 conv::shift_significand(data_);
458 class complex<gko::half> {
460 using value_type = gko::half;
462 complex(const value_type& real = value_type(0.f),
463 const value_type& imag = value_type(0.f))
464 : real_(real), imag_(imag)
467 template <typename T, typename U,
468 typename = std::enable_if_t<std::is_scalar<T>::value &&
469 std::is_scalar<U>::value>>
470 explicit complex(const T& real, const U& imag)
471 : real_(static_cast<value_type>(real)),
472 imag_(static_cast<value_type>(imag))
475 template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
476 complex(const T& real)
477 : real_(static_cast<value_type>(real)),
478 imag_(static_cast<value_type>(0.f))
483 template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
484 explicit complex(const complex<T>& other)
485 : real_(static_cast<value_type>(other.real())),
486 imag_(static_cast<value_type>(other.imag()))
489 value_type real() const noexcept { return real_; }
491 value_type imag() const noexcept { return imag_; }
493 operator std::complex<float>() const noexcept
495 return std::complex<float>(static_cast<float>(real_),
496 static_cast<float>(imag_));
499 template <typename V>
500 complex& operator=(const V& val)
503 imag_ = value_type();
507 template <typename V>
508 complex& operator=(const std::complex<V>& val)
515 complex& operator+=(const value_type& real)
521 complex& operator-=(const value_type& real)
527 complex& operator*=(const value_type& real)
534 complex& operator/=(const value_type& real)
541 template <typename T>
542 complex& operator+=(const complex<T>& val)
549 template <typename T>
550 complex& operator-=(const complex<T>& val)
557 template <typename T>
558 complex& operator*=(const complex<T>& val)
560 auto val_f = static_cast<std::complex<float>>(val);
561 auto result_f = static_cast<std::complex<float>>(*this);
563 real_ = result_f.real();
564 imag_ = result_f.imag();
568 template <typename T>
569 complex& operator/=(const complex<T>& val)
571 auto val_f = static_cast<std::complex<float>>(val);
572 auto result_f = static_cast<std::complex<float>>(*this);
574 real_ = result_f.real();
575 imag_ = result_f.imag();
579 #define COMPLEX_HALF_OPERATOR(_op, _opeq) \
580 friend complex operator _op(const complex& lhf, const complex& rhf) \
587 COMPLEX_HALF_OPERATOR(+, +=)
588 COMPLEX_HALF_OPERATOR(-, -=)
589 COMPLEX_HALF_OPERATOR(*, *=)
590 COMPLEX_HALF_OPERATOR(/, /=)
592 #undef COMPLEX_HALF_OPERATOR
601 struct numeric_limits<gko::half> {
602 static constexpr bool is_specialized{true};
603 static constexpr bool is_signed{true};
604 static constexpr bool is_integer{false};
605 static constexpr bool is_exact{false};
606 static constexpr bool is_bounded{true};
607 static constexpr bool is_modulo{false};
608 static constexpr int digits{
609 gko::detail::float_traits<gko::half>::significand_bits + 1};
611 static constexpr int digits10{digits * 3 / 10};
613 static constexpr gko::half epsilon()
615 constexpr auto bits = static_cast<std::uint16_t>(0b0'00101'0000000000u);
616 return gko::half::create_from_bits(bits);
619 static constexpr gko::half infinity()
621 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'0000000000u);
622 return gko::half::create_from_bits(bits);
625 static constexpr gko::half min()
627 constexpr auto bits = static_cast<std::uint16_t>(0b0'00001'0000000000u);
628 return gko::half::create_from_bits(bits);
631 static constexpr gko::half max()
633 constexpr auto bits = static_cast<std::uint16_t>(0b0'11110'1111111111u);
634 return gko::half::create_from_bits(bits);
637 static constexpr gko::half lowest()
639 constexpr auto bits = static_cast<std::uint16_t>(0b1'11110'1111111111u);
640 return gko::half::create_from_bits(bits);
643 static constexpr gko::half quiet_NaN()
645 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'1111111111u);
646 return gko::half::create_from_bits(bits);
654 inline complex<double>& complex<double>::operator=(
655 const std::complex<gko::half>& a)
657 complex<double> t(a.real(), a.imag());
665 inline complex<float>& complex<float>::operator=(
666 const std::complex<gko::half>& a)
668 complex<float> t(a.real(), a.imag());