Ginkgo: generated from pipelines/2118098289 (branch based on develop). Ginkgo version 1.11.0
A numerical linear algebra library targeting many-core architectures
half.hpp
1 // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_HALF_HPP_
6 #define GKO_PUBLIC_CORE_BASE_HALF_HPP_
7 
8 
#include <climits>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
14 
15 
// Forward declaration of CUDA's half precision type, so that
// gko::detail::basic_float_traits can be specialized for it without
// including any device headers.
class __half;
17 
18 
19 namespace gko {
20 
21 
// Reduced-precision types described by detail::basic_float_traits below.
// Only declarations are needed here; half is defined later in this header.
class bfloat16;

class half;

template <typename FloatType, std::size_t NumComponents,
          std::size_t ComponentId>
class truncated;
29 
30 
31 namespace detail {
32 
33 
// Number of bits in a byte. Declared `inline` (C++17) so this header-defined
// constant has a single definition shared by all translation units instead of
// one internal-linkage copy per translation unit.
inline constexpr std::size_t byte_size = CHAR_BIT;
35 
/**
 * Maps a bit count to the narrowest of std::uint16_t / std::uint32_t /
 * std::uint64_t that can hold that many bits (counts up to 16, including 0,
 * map to std::uint16_t). Counts above 64 match no specialization, so
 * uint_of_impl has no `type` member and uint_of is ill-formed for them.
 */
template <std::size_t, typename = void>
struct uint_of_impl {};

// 33 .. 64 bits
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(Bits > 32) && !(Bits > 64)>> {
    using type = std::uint64_t;
};

// 17 .. 32 bits
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(Bits > 16) && !(Bits > 32)>> {
    using type = std::uint32_t;
};

// 0 .. 16 bits
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<!(Bits > 16)>> {
    using type = std::uint16_t;
};

template <std::size_t Bits>
using uint_of = typename uint_of_impl<Bits>::type;
56 
57 
58 template <typename T>
59 struct basic_float_traits {};
60 
61 template <>
62 struct basic_float_traits<half> {
63  using type = half;
64  static constexpr int sign_bits = 1;
65  static constexpr int significand_bits = 10;
66  static constexpr int exponent_bits = 5;
67  static constexpr bool rounds_to_nearest = true;
68 };
69 
70 template <>
71 struct basic_float_traits<__half> {
72  using type = __half;
73  static constexpr int sign_bits = 1;
74  static constexpr int significand_bits = 10;
75  static constexpr int exponent_bits = 5;
76  static constexpr bool rounds_to_nearest = true;
77 };
78 
79 template <>
80 struct basic_float_traits<float> {
81  using type = float;
82  static constexpr int sign_bits = 1;
83  static constexpr int significand_bits = 23;
84  static constexpr int exponent_bits = 8;
85  static constexpr bool rounds_to_nearest = true;
86 };
87 
88 template <>
89 struct basic_float_traits<double> {
90  using type = double;
91  static constexpr int sign_bits = 1;
92  static constexpr int significand_bits = 52;
93  static constexpr int exponent_bits = 11;
94  static constexpr bool rounds_to_nearest = true;
95 };
96 
97 template <typename FloatType, std::size_t NumComponents,
98  std::size_t ComponentId>
99 struct basic_float_traits<truncated<FloatType, NumComponents, ComponentId>> {
100  using type = truncated<FloatType, NumComponents, ComponentId>;
101  static constexpr int sign_bits = ComponentId == 0 ? 1 : 0;
102  static constexpr int exponent_bits =
103  ComponentId == 0 ? basic_float_traits<FloatType>::exponent_bits : 0;
104  static constexpr int significand_bits =
105  ComponentId == 0 ? sizeof(type) * byte_size - exponent_bits - 1
106  : sizeof(type) * byte_size;
107  static constexpr bool rounds_to_nearest = false;
108 };
109 
110 
111 template <typename UintType>
112 constexpr UintType create_ones(int n)
113 {
114  return (n == sizeof(UintType) * byte_size ? static_cast<UintType>(0)
115  : static_cast<UintType>(1) << n) -
116  static_cast<UintType>(1);
117 }
118 
119 
120 template <typename T>
121 struct float_traits {
122  using type = typename basic_float_traits<T>::type;
123  using bits_type = uint_of<sizeof(type) * byte_size>;
124  static constexpr int sign_bits = basic_float_traits<T>::sign_bits;
125  static constexpr int significand_bits =
126  basic_float_traits<T>::significand_bits;
127  static constexpr int exponent_bits = basic_float_traits<T>::exponent_bits;
128  static constexpr bits_type significand_mask =
129  create_ones<bits_type>(significand_bits);
130  static constexpr bits_type exponent_mask =
131  create_ones<bits_type>(significand_bits + exponent_bits) -
132  significand_mask;
133  static constexpr bits_type bias_mask =
134  create_ones<bits_type>(significand_bits + exponent_bits - 1) -
135  significand_mask;
136  static constexpr bits_type sign_mask =
137  create_ones<bits_type>(sign_bits + significand_bits + exponent_bits) -
138  exponent_mask - significand_mask;
139  static constexpr bool rounds_to_nearest =
140  basic_float_traits<T>::rounds_to_nearest;
141 
142  static constexpr auto eps =
143  1.0 / (1ll << (significand_bits + rounds_to_nearest));
144 
145  static constexpr bool is_inf(bits_type data)
146  {
147  return (data & exponent_mask) == exponent_mask &&
148  (data & significand_mask) == bits_type{};
149  }
150 
151  static constexpr bool is_nan(bits_type data)
152  {
153  return (data & exponent_mask) == exponent_mask &&
154  (data & significand_mask) != bits_type{};
155  }
156 
157  static constexpr bool is_denom(bits_type data)
158  {
159  return (data & exponent_mask) == bits_type{};
160  }
161 };
162 
163 
164 template <typename SourceType, typename ResultType,
165  bool = (sizeof(SourceType) <= sizeof(ResultType))>
166 struct precision_converter;
167 
168 // upcasting implementation details
169 template <typename SourceType, typename ResultType>
170 struct precision_converter<SourceType, ResultType, true> {
171  using source_traits = float_traits<SourceType>;
172  using result_traits = float_traits<ResultType>;
173  using source_bits = typename source_traits::bits_type;
174  using result_bits = typename result_traits::bits_type;
175 
176  static_assert(source_traits::exponent_bits <=
177  result_traits::exponent_bits &&
178  source_traits::significand_bits <=
179  result_traits::significand_bits,
180  "SourceType has to have both lower range and precision or "
181  "higher range and precision than ResultType");
182 
183  static constexpr int significand_offset =
184  result_traits::significand_bits - source_traits::significand_bits;
185  static constexpr int exponent_offset = significand_offset;
186  static constexpr int sign_offset = result_traits::exponent_bits -
187  source_traits::exponent_bits +
188  exponent_offset;
189  static constexpr result_bits bias_change =
190  result_traits::bias_mask -
191  (static_cast<result_bits>(source_traits::bias_mask) << exponent_offset);
192 
193  static constexpr result_bits shift_significand(source_bits data) noexcept
194  {
195  return static_cast<result_bits>(data & source_traits::significand_mask)
196  << significand_offset;
197  }
198 
199  static constexpr result_bits shift_exponent(source_bits data) noexcept
200  {
201  return update_bias(
202  static_cast<result_bits>(data & source_traits::exponent_mask)
203  << exponent_offset);
204  }
205 
206  static constexpr result_bits shift_sign(source_bits data) noexcept
207  {
208  return static_cast<result_bits>(data & source_traits::sign_mask)
209  << sign_offset;
210  }
211 
212 private:
213  static constexpr result_bits update_bias(result_bits data) noexcept
214  {
215  return data == typename result_traits::bits_type{} ? data
216  : data + bias_change;
217  }
218 };
219 
220 // downcasting implementation details
221 template <typename SourceType, typename ResultType>
222 struct precision_converter<SourceType, ResultType, false> {
223  using source_traits = float_traits<SourceType>;
224  using result_traits = float_traits<ResultType>;
225  using source_bits = typename source_traits::bits_type;
226  using result_bits = typename result_traits::bits_type;
227 
228  static_assert(source_traits::exponent_bits >=
229  result_traits::exponent_bits &&
230  source_traits::significand_bits >=
231  result_traits::significand_bits,
232  "SourceType has to have both lower range and precision or "
233  "higher range and precision than ResultType");
234 
235  static constexpr int significand_offset =
236  source_traits::significand_bits - result_traits::significand_bits;
237  static constexpr int exponent_offset = significand_offset;
238  static constexpr int sign_offset = source_traits::exponent_bits -
239  result_traits::exponent_bits +
240  exponent_offset;
241  static constexpr source_bits bias_change =
242  (source_traits::bias_mask >> exponent_offset) -
243  static_cast<source_bits>(result_traits::bias_mask);
244 
245  static constexpr result_bits shift_significand(source_bits data) noexcept
246  {
247  return static_cast<result_bits>(
248  (data & source_traits::significand_mask) >> significand_offset);
249  }
250 
251  static constexpr result_bits shift_exponent(source_bits data) noexcept
252  {
253  return static_cast<result_bits>(update_bias(
254  (data & source_traits::exponent_mask) >> exponent_offset));
255  }
256 
257  static constexpr result_bits shift_sign(source_bits data) noexcept
258  {
259  return static_cast<result_bits>((data & source_traits::sign_mask) >>
260  sign_offset);
261  }
262 
263 private:
264  static constexpr source_bits update_bias(source_bits data) noexcept
265  {
266  return data <= bias_change ? typename source_traits::bits_type{}
267  : limit_exponent(data - bias_change);
268  }
269 
270  static constexpr source_bits limit_exponent(source_bits data) noexcept
271  {
272  return data >= static_cast<source_bits>(result_traits::exponent_mask)
273  ? static_cast<source_bits>(result_traits::exponent_mask)
274  : data;
275  }
276 };
277 
278 
279 } // namespace detail
280 
281 
288 class alignas(std::uint16_t) half {
289 public:
290  // create half value from the bits directly.
291  static constexpr half create_from_bits(const std::uint16_t& bits) noexcept
292  {
293  half result;
294  result.data_ = bits;
295  return result;
296  }
297 
298  // TODO: NVHPC (host side) may not use zero initialization for the data
299  // member by default constructor in some cases. Not sure whether it is
300  // caused by something else in jacobi or isai.
301  constexpr half() noexcept : data_(0){};
302 
303  template <typename T,
304  typename = std::enable_if_t<std::is_scalar<T>::value ||
305  std::is_same_v<T, bfloat16>>>
306  half(const T& val) : data_(0)
307  {
308  this->float2half(static_cast<float>(val));
309  }
310 
311  template <typename V>
312  half& operator=(const V& val)
313  {
314  this->float2half(static_cast<float>(val));
315  return *this;
316  }
317 
318  operator float() const noexcept
319  {
320  const auto bits = half2float(data_);
321  float ans(0);
322  std::memcpy(&ans, &bits, sizeof(float));
323  return ans;
324  }
325 
326  // can not use half operator _op(const half) for half + half
327  // operation will cast it to float and then do float operation such that it
328  // becomes float in the end.
329 #define HALF_OPERATOR(_op, _opeq) \
330  friend half operator _op(const half& lhf, const half& rhf) \
331  { \
332  return static_cast<half>(static_cast<float>(lhf) \
333  _op static_cast<float>(rhf)); \
334  } \
335  half& operator _opeq(const half& hf) \
336  { \
337  auto result = *this _op hf; \
338  data_ = result.data_; \
339  return *this; \
340  }
341 
342  HALF_OPERATOR(+, +=)
343  HALF_OPERATOR(-, -=)
344  HALF_OPERATOR(*, *=)
345  HALF_OPERATOR(/, /=)
346 
347 #undef HALF_OPERATOR
348 
349  // Do operation with different type
350  // If it is floating point, using floating point as type.
351  // If it is integer, using half as type
352  // Note: we do not define the operation with bfloat16, which is already
353  // defined in bfloat16.hpp
354 #define HALF_FRIEND_OPERATOR(_op, _opeq) \
355  template <typename T> \
356  friend std::enable_if_t< \
357  !std::is_same<T, half>::value && std::is_scalar<T>::value, \
358  std::conditional_t<std::is_floating_point<T>::value, T, half>> \
359  operator _op(const half& hf, const T& val) \
360  { \
361  using type = \
362  std::conditional_t<std::is_floating_point<T>::value, T, half>; \
363  auto result = static_cast<type>(hf); \
364  result _opeq static_cast<type>(val); \
365  return result; \
366  } \
367  template <typename T> \
368  friend std::enable_if_t< \
369  !std::is_same<T, half>::value && std::is_scalar<T>::value, \
370  std::conditional_t<std::is_floating_point<T>::value, T, half>> \
371  operator _op(const T& val, const half& hf) \
372  { \
373  using type = \
374  std::conditional_t<std::is_floating_point<T>::value, T, half>; \
375  auto result = static_cast<type>(val); \
376  result _opeq static_cast<type>(hf); \
377  return result; \
378  }
379 
380  HALF_FRIEND_OPERATOR(+, +=)
381  HALF_FRIEND_OPERATOR(-, -=)
382  HALF_FRIEND_OPERATOR(*, *=)
383  HALF_FRIEND_OPERATOR(/, /=)
384 
385 #undef HALF_FRIEND_OPERATOR
386 
387  // the negative
388  half operator-() const
389  {
390  auto val = 0.0f - *this;
391  return static_cast<half>(val);
392  }
393 
394 private:
395  using f16_traits = detail::float_traits<half>;
396  using f32_traits = detail::float_traits<float>;
397 
398  void float2half(const float& val) noexcept
399  {
400  std::uint32_t bit_val(0);
401  std::memcpy(&bit_val, &val, sizeof(float));
402  data_ = float2half(bit_val);
403  }
404 
405  static constexpr std::uint16_t float2half(std::uint32_t data_) noexcept
406  {
407  using conv = detail::precision_converter<float, half>;
408  if (f32_traits::is_inf(data_)) {
409  return conv::shift_sign(data_) | f16_traits::exponent_mask;
410  } else if (f32_traits::is_nan(data_)) {
411  return conv::shift_sign(data_) | f16_traits::exponent_mask |
412  f16_traits::significand_mask;
413  } else {
414  const auto exp = conv::shift_exponent(data_);
415  if (f16_traits::is_inf(exp)) {
416  return conv::shift_sign(data_) | exp;
417  } else if (f16_traits::is_denom(exp)) {
418  // TODO: handle denormals
419  return conv::shift_sign(data_);
420  } else {
421  // Rounding to even
422  const auto result = conv::shift_sign(data_) | exp |
423  conv::shift_significand(data_);
424  const auto tail =
425  data_ & static_cast<f32_traits::bits_type>(
426  (1 << conv::significand_offset) - 1);
427 
428  constexpr auto half = static_cast<f32_traits::bits_type>(
429  1 << (conv::significand_offset - 1));
430  return result +
431  (tail > half || ((tail == half) && (result & 1)));
432  }
433  }
434  }
435 
436  static constexpr std::uint32_t half2float(std::uint16_t data_) noexcept
437  {
438  using conv = detail::precision_converter<half, float>;
439  if (f16_traits::is_inf(data_)) {
440  return conv::shift_sign(data_) | f32_traits::exponent_mask;
441  } else if (f16_traits::is_nan(data_)) {
442  return conv::shift_sign(data_) | f32_traits::exponent_mask |
443  f32_traits::significand_mask;
444  } else if (f16_traits::is_denom(data_)) {
445  // TODO: handle denormals
446  return conv::shift_sign(data_);
447  } else {
448  return conv::shift_sign(data_) | conv::shift_exponent(data_) |
449  conv::shift_significand(data_);
450  }
451  }
452 
453  std::uint16_t data_;
454 };
455 
456 
457 } // namespace gko
458 
459 
460 namespace std {
461 
462 
463 template <>
464 class complex<gko::half> {
465 public:
466  using value_type = gko::half;
467 
468  complex(const value_type& real = value_type(0.f),
469  const value_type& imag = value_type(0.f))
470  : real_(real), imag_(imag)
471  {}
472 
473  template <
474  typename T, typename U,
475  typename = std::enable_if_t<
476  (std::is_scalar<T>::value || std::is_same_v<T, gko::bfloat16>)&&(
477  std::is_scalar<U>::value || std::is_same_v<U, gko::bfloat16>)>>
478  explicit complex(const T& real, const U& imag)
479  : real_(static_cast<value_type>(real)),
480  imag_(static_cast<value_type>(imag))
481  {}
482 
483  template <typename T,
484  typename = std::enable_if_t<std::is_scalar<T>::value ||
485  std::is_same_v<T, gko::bfloat16>>>
486  complex(const T& real)
487  : real_(static_cast<value_type>(real)),
488  imag_(static_cast<value_type>(0.f))
489  {}
490 
491  // When using complex(real, imag), MSVC with CUDA try to recognize the
492  // complex is a member not constructor.
493  template <typename T,
494  typename = std::enable_if_t<std::is_scalar<T>::value ||
495  std::is_same_v<T, gko::bfloat16>>>
496  explicit complex(const complex<T>& other)
497  : real_(static_cast<value_type>(other.real())),
498  imag_(static_cast<value_type>(other.imag()))
499  {}
500 
501  value_type real() const noexcept { return real_; }
502 
503  value_type imag() const noexcept { return imag_; }
504 
505  operator std::complex<float>() const noexcept
506  {
507  return std::complex<float>(static_cast<float>(real_),
508  static_cast<float>(imag_));
509  }
510 
511  template <typename V>
512  complex& operator=(const V& val)
513  {
514  real_ = val;
515  imag_ = value_type();
516  return *this;
517  }
518 
519  template <typename V>
520  complex& operator=(const std::complex<V>& val)
521  {
522  real_ = val.real();
523  imag_ = val.imag();
524  return *this;
525  }
526 
527  complex& operator+=(const value_type& real)
528  {
529  real_ += real;
530  return *this;
531  }
532 
533  complex& operator-=(const value_type& real)
534  {
535  real_ -= real;
536  return *this;
537  }
538 
539  complex& operator*=(const value_type& real)
540  {
541  real_ *= real;
542  imag_ *= real;
543  return *this;
544  }
545 
546  complex& operator/=(const value_type& real)
547  {
548  real_ /= real;
549  imag_ /= real;
550  return *this;
551  }
552 
553  template <typename T>
554  complex& operator+=(const complex<T>& val)
555  {
556  real_ += val.real();
557  imag_ += val.imag();
558  return *this;
559  }
560 
561  template <typename T>
562  complex& operator-=(const complex<T>& val)
563  {
564  real_ -= val.real();
565  imag_ -= val.imag();
566  return *this;
567  }
568 
569  template <typename T>
570  complex& operator*=(const complex<T>& val)
571  {
572  auto val_f = static_cast<std::complex<float>>(val);
573  auto result_f = static_cast<std::complex<float>>(*this);
574  result_f *= val_f;
575  real_ = result_f.real();
576  imag_ = result_f.imag();
577  return *this;
578  }
579 
580  template <typename T>
581  complex& operator/=(const complex<T>& val)
582  {
583  auto val_f = static_cast<std::complex<float>>(val);
584  auto result_f = static_cast<std::complex<float>>(*this);
585  result_f /= val_f;
586  real_ = result_f.real();
587  imag_ = result_f.imag();
588  return *this;
589  }
590 
591 #define COMPLEX_HALF_OPERATOR(_op, _opeq) \
592  friend complex operator _op(const complex& lhf, const complex& rhf) \
593  { \
594  auto a = lhf; \
595  a _opeq rhf; \
596  return a; \
597  }
598 
599  COMPLEX_HALF_OPERATOR(+, +=)
600  COMPLEX_HALF_OPERATOR(-, -=)
601  COMPLEX_HALF_OPERATOR(*, *=)
602  COMPLEX_HALF_OPERATOR(/, /=)
603 
604 #undef COMPLEX_HALF_OPERATOR
605 
606 private:
607  value_type real_;
608  value_type imag_;
609 };
610 
611 
612 template <>
613 struct numeric_limits<gko::half> {
614  static constexpr bool is_specialized{true};
615  static constexpr bool is_signed{true};
616  static constexpr bool is_integer{false};
617  static constexpr bool is_exact{false};
618  static constexpr bool is_bounded{true};
619  static constexpr bool is_modulo{false};
620  static constexpr int digits{
621  gko::detail::float_traits<gko::half>::significand_bits + 1};
622  // 3/10 is approx. log_10(2)
623  static constexpr int digits10{digits * 3 / 10};
624 
625  static constexpr gko::half epsilon()
626  {
627  constexpr auto bits = static_cast<std::uint16_t>(0b0'00101'0000000000u);
628  return gko::half::create_from_bits(bits);
629  }
630 
631  static constexpr gko::half infinity()
632  {
633  constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'0000000000u);
634  return gko::half::create_from_bits(bits);
635  }
636 
637  static constexpr gko::half min()
638  {
639  constexpr auto bits = static_cast<std::uint16_t>(0b0'00001'0000000000u);
640  return gko::half::create_from_bits(bits);
641  }
642 
643  static constexpr gko::half max()
644  {
645  constexpr auto bits = static_cast<std::uint16_t>(0b0'11110'1111111111u);
646  return gko::half::create_from_bits(bits);
647  }
648 
649  static constexpr gko::half lowest()
650  {
651  constexpr auto bits = static_cast<std::uint16_t>(0b1'11110'1111111111u);
652  return gko::half::create_from_bits(bits);
653  };
654 
655  static constexpr gko::half quiet_NaN()
656  {
657  constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'1111111111u);
658  return gko::half::create_from_bits(bits);
659  }
660 };
661 
662 
663 // complex using a template on operator= for any kind of complex<T>, so we can
664 // do full specialization for half
665 template <>
666 inline complex<double>& complex<double>::operator=(
667  const std::complex<gko::half>& a)
668 {
669  complex<double> t(a.real(), a.imag());
670  operator=(t);
671  return *this;
672 }
673 
674 
675 // For MSVC
676 template <>
677 inline complex<float>& complex<float>::operator=(
678  const std::complex<gko::half>& a)
679 {
680  complex<float> t(a.real(), a.imag());
681  operator=(t);
682  return *this;
683 }
684 
685 
686 } // namespace std
687 
688 
689 #endif // GKO_PUBLIC_CORE_BASE_HALF_HPP_
gko::bfloat16
A class providing basic support for bfloat16 precision floating point types.
Definition: bfloat16.hpp:76
gko::truncated
Definition: half.hpp:23
gko::byte_size
constexpr size_type byte_size
Number of bits in a byte.
Definition: types.hpp:178
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::half
A class providing basic support for half precision floating point types.
Definition: half.hpp:288