5 #ifndef GKO_PUBLIC_CORE_BASE_HALF_HPP_
6 #define GKO_PUBLIC_CORE_BASE_HALF_HPP_
13 #include <type_traits>
22 template <
typename, std::
size_t, std::
size_t>
/**
 * Number of bits in one byte (`CHAR_BIT`).
 *
 * Used to convert `sizeof` results into bit widths, e.g.
 * `sizeof(type) * byte_size`.
 *
 * Declared `inline` (C++17) so that every translation unit including this
 * header refers to one shared definition. A plain namespace-scope `constexpr`
 * variable would instead give each translation unit its own internal-linkage
 * copy.
 */
inline constexpr std::size_t byte_size = CHAR_BIT;
/**
 * Selects an unsigned integer type wide enough to hold `Bits` bits.
 *
 * This primary template is deliberately empty. Only the partial
 * specializations, enabled through the `Enable` parameter for widths up to
 * 64 bits, provide a nested `type`. Any other width therefore leaves
 * `uint_of_impl<Bits>::type` undefined.
 */
template <std::size_t Bits, typename Enable = void>
struct uint_of_impl {
};
39 template <std::
size_t Bits>
40 struct uint_of_impl<Bits, std::enable_if_t<(Bits <= 16)>> {
41 using type = std::uint16_t;
44 template <std::
size_t Bits>
45 struct uint_of_impl<Bits, std::enable_if_t<(16 < Bits && Bits <= 32)>> {
46 using type = std::uint32_t;
49 template <std::size_t Bits>
50 struct uint_of_impl<Bits, std::enable_if_t<(32 < Bits) && (Bits <= 64)>> {
51 using type = std::uint64_t;
54 template <std::size_t Bits>
55 using uint_of = typename uint_of_impl<Bits>::type;
59 struct basic_float_traits {};
62 struct basic_float_traits<half> {
64 static constexpr int sign_bits = 1;
65 static constexpr int significand_bits = 10;
66 static constexpr int exponent_bits = 5;
67 static constexpr bool rounds_to_nearest = true;
71 struct basic_float_traits<__half> {
73 static constexpr int sign_bits = 1;
74 static constexpr int significand_bits = 10;
75 static constexpr int exponent_bits = 5;
76 static constexpr bool rounds_to_nearest = true;
80 struct basic_float_traits<float> {
82 static constexpr int sign_bits = 1;
83 static constexpr int significand_bits = 23;
84 static constexpr int exponent_bits = 8;
85 static constexpr bool rounds_to_nearest = true;
89 struct basic_float_traits<double> {
91 static constexpr int sign_bits = 1;
92 static constexpr int significand_bits = 52;
93 static constexpr int exponent_bits = 11;
94 static constexpr bool rounds_to_nearest = true;
97 template <typename FloatType, std::size_t NumComponents,
98 std::size_t ComponentId>
99 struct basic_float_traits<truncated<FloatType, NumComponents, ComponentId>> {
100 using type = truncated<FloatType, NumComponents, ComponentId>;
101 static constexpr int sign_bits = ComponentId == 0 ? 1 : 0;
102 static constexpr int exponent_bits =
103 ComponentId == 0 ? basic_float_traits<FloatType>::exponent_bits : 0;
104 static constexpr int significand_bits =
105 ComponentId == 0 ? sizeof(type) * byte_size - exponent_bits - 1
106 : sizeof(type) * byte_size;
107 static constexpr bool rounds_to_nearest = false;
111 template <typename UintType>
112 constexpr UintType create_ones(int n)
114 return (n == sizeof(UintType) * byte_size ? static_cast<UintType>(0)
115 : static_cast<UintType>(1) << n) -
116 static_cast<UintType>(1);
120 template <typename T>
121 struct float_traits {
122 using type = typename basic_float_traits<T>::type;
123 using bits_type = uint_of<sizeof(type) * byte_size>;
124 static constexpr int sign_bits = basic_float_traits<T>::sign_bits;
125 static constexpr int significand_bits =
126 basic_float_traits<T>::significand_bits;
127 static constexpr int exponent_bits = basic_float_traits<T>::exponent_bits;
128 static constexpr bits_type significand_mask =
129 create_ones<bits_type>(significand_bits);
130 static constexpr bits_type exponent_mask =
131 create_ones<bits_type>(significand_bits + exponent_bits) -
133 static constexpr bits_type bias_mask =
134 create_ones<bits_type>(significand_bits + exponent_bits - 1) -
136 static constexpr bits_type sign_mask =
137 create_ones<bits_type>(sign_bits + significand_bits + exponent_bits) -
138 exponent_mask - significand_mask;
139 static constexpr bool rounds_to_nearest =
140 basic_float_traits<T>::rounds_to_nearest;
142 static constexpr auto eps =
143 1.0 / (1ll << (significand_bits + rounds_to_nearest));
145 static constexpr bool is_inf(bits_type data)
147 return (data & exponent_mask) == exponent_mask &&
148 (data & significand_mask) == bits_type{};
151 static constexpr bool is_nan(bits_type data)
153 return (data & exponent_mask) == exponent_mask &&
154 (data & significand_mask) != bits_type{};
157 static constexpr bool is_denom(bits_type data)
159 return (data & exponent_mask) == bits_type{};
/**
 * Bit-level conversion helper between two floating-point formats.
 *
 * Only declared here. The trailing boolean selects the implementation:
 * - `true` (widening): shifts fields left into a format with at least as many
 *   exponent and significand bits.
 * - `false` (narrowing): shifts fields right into a smaller format.
 *
 * By default it is `true` exactly when `SourceType` is not larger than
 * `ResultType`.
 */
template <typename SourceType, typename ResultType,
          bool Widening = (sizeof(SourceType) <= sizeof(ResultType))>
struct precision_converter;
169 template <typename SourceType, typename ResultType>
170 struct precision_converter<SourceType, ResultType, true> {
171 using source_traits = float_traits<SourceType>;
172 using result_traits = float_traits<ResultType>;
173 using source_bits = typename source_traits::bits_type;
174 using result_bits = typename result_traits::bits_type;
176 static_assert(source_traits::exponent_bits <=
177 result_traits::exponent_bits &&
178 source_traits::significand_bits <=
179 result_traits::significand_bits,
180 "SourceType has to have both lower range and precision or "
181 "higher range and precision than ResultType");
183 static constexpr int significand_offset =
184 result_traits::significand_bits - source_traits::significand_bits;
185 static constexpr int exponent_offset = significand_offset;
186 static constexpr int sign_offset = result_traits::exponent_bits -
187 source_traits::exponent_bits +
189 static constexpr result_bits bias_change =
190 result_traits::bias_mask -
191 (static_cast<result_bits>(source_traits::bias_mask) << exponent_offset);
193 static constexpr result_bits shift_significand(source_bits data) noexcept
195 return static_cast<result_bits>(data & source_traits::significand_mask)
196 << significand_offset;
199 static constexpr result_bits shift_exponent(source_bits data) noexcept
202 static_cast<result_bits>(data & source_traits::exponent_mask)
206 static constexpr result_bits shift_sign(source_bits data) noexcept
208 return static_cast<result_bits>(data & source_traits::sign_mask)
213 static constexpr result_bits update_bias(result_bits data) noexcept
215 return data == typename result_traits::bits_type{} ? data
216 : data + bias_change;
221 template <typename SourceType, typename ResultType>
222 struct precision_converter<SourceType, ResultType, false> {
223 using source_traits = float_traits<SourceType>;
224 using result_traits = float_traits<ResultType>;
225 using source_bits = typename source_traits::bits_type;
226 using result_bits = typename result_traits::bits_type;
228 static_assert(source_traits::exponent_bits >=
229 result_traits::exponent_bits &&
230 source_traits::significand_bits >=
231 result_traits::significand_bits,
232 "SourceType has to have both lower range and precision or "
233 "higher range and precision than ResultType");
235 static constexpr int significand_offset =
236 source_traits::significand_bits - result_traits::significand_bits;
237 static constexpr int exponent_offset = significand_offset;
238 static constexpr int sign_offset = source_traits::exponent_bits -
239 result_traits::exponent_bits +
241 static constexpr source_bits bias_change =
242 (source_traits::bias_mask >> exponent_offset) -
243 static_cast<source_bits>(result_traits::bias_mask);
245 static constexpr result_bits shift_significand(source_bits data) noexcept
247 return static_cast<result_bits>(
248 (data & source_traits::significand_mask) >> significand_offset);
251 static constexpr result_bits shift_exponent(source_bits data) noexcept
253 return static_cast<result_bits>(update_bias(
254 (data & source_traits::exponent_mask) >> exponent_offset));
257 static constexpr result_bits shift_sign(source_bits data) noexcept
259 return static_cast<result_bits>((data & source_traits::sign_mask) >>
264 static constexpr source_bits update_bias(source_bits data) noexcept
266 return data <= bias_change ? typename source_traits::bits_type{}
267 : limit_exponent(data - bias_change);
270 static constexpr source_bits limit_exponent(source_bits data) noexcept
272 return data >= static_cast<source_bits>(result_traits::exponent_mask)
273 ? static_cast<source_bits>(result_traits::exponent_mask)
288 class alignas(std::uint16_t) half {
291 static constexpr half create_from_bits(const std::uint16_t& bits) noexcept
301 constexpr half() noexcept : data_(0){};
303 template <typename T,
304 typename = std::enable_if_t<std::is_scalar<T>::value ||
305 std::is_same_v<T, bfloat16>>>
306 half(const T& val) : data_(0)
308 this->float2half(static_cast<float>(val));
311 template <typename V>
312 half& operator=(const V& val)
314 this->float2half(static_cast<float>(val));
318 operator float() const noexcept
320 const auto bits = half2float(data_);
322 std::memcpy(&ans, &bits, sizeof(float));
329 #define HALF_OPERATOR(_op, _opeq) \
330 friend half operator _op(const half& lhf, const half& rhf) \
332 return static_cast<half>(static_cast<float>(lhf) \
333 _op static_cast<float>(rhf)); \
335 half& operator _opeq(const half& hf) \
337 auto result = *this _op hf; \
338 data_ = result.data_; \
354 #define HALF_FRIEND_OPERATOR(_op, _opeq) \
355 template <typename T> \
356 friend std::enable_if_t< \
357 !std::is_same<T, half>::value && std::is_scalar<T>::value, \
358 std::conditional_t<std::is_floating_point<T>::value, T, half>> \
359 operator _op(const half& hf, const T& val) \
362 std::conditional_t<std::is_floating_point<T>::value, T, half>; \
363 auto result = static_cast<type>(hf); \
364 result _opeq static_cast<type>(val); \
367 template <typename T> \
368 friend std::enable_if_t< \
369 !std::is_same<T, half>::value && std::is_scalar<T>::value, \
370 std::conditional_t<std::is_floating_point<T>::value, T, half>> \
371 operator _op(const T& val, const half& hf) \
374 std::conditional_t<std::is_floating_point<T>::value, T, half>; \
375 auto result = static_cast<type>(val); \
376 result _opeq static_cast<type>(hf); \
380 HALF_FRIEND_OPERATOR(+, +=)
381 HALF_FRIEND_OPERATOR(-, -=)
382 HALF_FRIEND_OPERATOR(*, *=)
383 HALF_FRIEND_OPERATOR(/, /=)
385 #undef HALF_FRIEND_OPERATOR
388 half operator-() const
390 auto val = 0.0f - *this;
391 return static_cast<half>(val);
395 using f16_traits = detail::float_traits<half>;
396 using f32_traits = detail::float_traits<float>;
398 void float2half(const float& val) noexcept
400 std::uint32_t bit_val(0);
401 std::memcpy(&bit_val, &val, sizeof(float));
402 data_ = float2half(bit_val);
405 static constexpr std::uint16_t float2half(std::uint32_t data_) noexcept
407 using conv = detail::precision_converter<float, half>;
408 if (f32_traits::is_inf(data_)) {
409 return conv::shift_sign(data_) | f16_traits::exponent_mask;
410 } else if (f32_traits::is_nan(data_)) {
411 return conv::shift_sign(data_) | f16_traits::exponent_mask |
412 f16_traits::significand_mask;
414 const auto exp = conv::shift_exponent(data_);
415 if (f16_traits::is_inf(exp)) {
416 return conv::shift_sign(data_) | exp;
417 } else if (f16_traits::is_denom(exp)) {
419 return conv::shift_sign(data_);
422 const auto result = conv::shift_sign(data_) | exp |
423 conv::shift_significand(data_);
425 data_ & static_cast<f32_traits::bits_type>(
426 (1 << conv::significand_offset) - 1);
428 constexpr auto half = static_cast<f32_traits::bits_type>(
429 1 << (conv::significand_offset - 1));
431 (tail > half || ((tail == half) && (result & 1)));
436 static constexpr std::uint32_t half2float(std::uint16_t data_) noexcept
438 using conv = detail::precision_converter<half, float>;
439 if (f16_traits::is_inf(data_)) {
440 return conv::shift_sign(data_) | f32_traits::exponent_mask;
441 } else if (f16_traits::is_nan(data_)) {
442 return conv::shift_sign(data_) | f32_traits::exponent_mask |
443 f32_traits::significand_mask;
444 } else if (f16_traits::is_denom(data_)) {
446 return conv::shift_sign(data_);
448 return conv::shift_sign(data_) | conv::shift_exponent(data_) |
449 conv::shift_significand(data_);
464 class complex<gko::half> {
466 using value_type = gko::half;
468 complex(const value_type& real = value_type(0.f),
469 const value_type& imag = value_type(0.f))
470 : real_(real), imag_(imag)
474 typename T, typename U,
475 typename = std::enable_if_t<
476 (std::is_scalar<T>::value || std::is_same_v<T, gko::bfloat16>)&&(
477 std::is_scalar<U>::value || std::is_same_v<U, gko::bfloat16>)>>
478 explicit complex(const T& real, const U& imag)
479 : real_(static_cast<value_type>(real)),
480 imag_(static_cast<value_type>(imag))
483 template <typename T,
484 typename = std::enable_if_t<std::is_scalar<T>::value ||
485 std::is_same_v<T, gko::bfloat16>>>
486 complex(const T& real)
487 : real_(static_cast<value_type>(real)),
488 imag_(static_cast<value_type>(0.f))
493 template <typename T,
494 typename = std::enable_if_t<std::is_scalar<T>::value ||
495 std::is_same_v<T, gko::bfloat16>>>
496 explicit complex(const complex<T>& other)
497 : real_(static_cast<value_type>(other.real())),
498 imag_(static_cast<value_type>(other.imag()))
501 value_type real() const noexcept { return real_; }
503 value_type imag() const noexcept { return imag_; }
505 operator std::complex<float>() const noexcept
507 return std::complex<float>(static_cast<float>(real_),
508 static_cast<float>(imag_));
511 template <typename V>
512 complex& operator=(const V& val)
515 imag_ = value_type();
519 template <typename V>
520 complex& operator=(const std::complex<V>& val)
527 complex& operator+=(const value_type& real)
533 complex& operator-=(const value_type& real)
539 complex& operator*=(const value_type& real)
546 complex& operator/=(const value_type& real)
553 template <typename T>
554 complex& operator+=(const complex<T>& val)
561 template <typename T>
562 complex& operator-=(const complex<T>& val)
569 template <typename T>
570 complex& operator*=(const complex<T>& val)
572 auto val_f = static_cast<std::complex<float>>(val);
573 auto result_f = static_cast<std::complex<float>>(*this);
575 real_ = result_f.real();
576 imag_ = result_f.imag();
580 template <typename T>
581 complex& operator/=(const complex<T>& val)
583 auto val_f = static_cast<std::complex<float>>(val);
584 auto result_f = static_cast<std::complex<float>>(*this);
586 real_ = result_f.real();
587 imag_ = result_f.imag();
591 #define COMPLEX_HALF_OPERATOR(_op, _opeq) \
592 friend complex operator _op(const complex& lhf, const complex& rhf) \
599 COMPLEX_HALF_OPERATOR(+, +=)
600 COMPLEX_HALF_OPERATOR(-, -=)
601 COMPLEX_HALF_OPERATOR(*, *=)
602 COMPLEX_HALF_OPERATOR(/, /=)
604 #undef COMPLEX_HALF_OPERATOR
613 struct numeric_limits<gko::half> {
614 static constexpr bool is_specialized{true};
615 static constexpr bool is_signed{true};
616 static constexpr bool is_integer{false};
617 static constexpr bool is_exact{false};
618 static constexpr bool is_bounded{true};
619 static constexpr bool is_modulo{false};
620 static constexpr int digits{
621 gko::detail::float_traits<gko::half>::significand_bits + 1};
623 static constexpr int digits10{digits * 3 / 10};
625 static constexpr gko::half epsilon()
627 constexpr auto bits = static_cast<std::uint16_t>(0b0'00101'0000000000u);
628 return gko::half::create_from_bits(bits);
631 static constexpr gko::half infinity()
633 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'0000000000u);
634 return gko::half::create_from_bits(bits);
637 static constexpr gko::half min()
639 constexpr auto bits = static_cast<std::uint16_t>(0b0'00001'0000000000u);
640 return gko::half::create_from_bits(bits);
643 static constexpr gko::half max()
645 constexpr auto bits = static_cast<std::uint16_t>(0b0'11110'1111111111u);
646 return gko::half::create_from_bits(bits);
649 static constexpr gko::half lowest()
651 constexpr auto bits = static_cast<std::uint16_t>(0b1'11110'1111111111u);
652 return gko::half::create_from_bits(bits);
655 static constexpr gko::half quiet_NaN()
657 constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'1111111111u);
658 return gko::half::create_from_bits(bits);
666 inline complex<double>& complex<double>::operator=(
667 const std::complex<gko::half>& a)
669 complex<double> t(a.real(), a.imag());
677 inline complex<float>& complex<float>::operator=(
678 const std::complex<gko::half>& a)
680 complex<float> t(a.real(), a.imag());