Ginkgo  Generated from pipelines/1589998975 branch based on develop. Ginkgo version 1.10.0
A numerical linear algebra library targeting many-core architectures
half.hpp
1 // SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_HALF_HPP_
6 #define GKO_PUBLIC_CORE_BASE_HALF_HPP_
7 
8 
#include <climits>
#include <complex>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
14 
15 
16 class __half;
17 
18 
19 namespace gko {
20 
21 
22 template <typename, std::size_t, std::size_t>
23 class truncated;
24 
25 
26 class half;
27 
28 
29 namespace detail {
30 
31 
// Number of bits in one byte on the target platform.
constexpr std::size_t byte_size = CHAR_BIT;

// Primary template: widths above 64 bits have no matching unsigned type, so
// no nested `type` is provided for them.
template <std::size_t, typename = void>
struct uint_of_impl {};

// Every width up to 64 bits maps to the smallest of uint16_t / uint32_t /
// uint64_t that can hold `Bits` bits (anything up to 16 bits uses uint16_t).
template <std::size_t Bits>
struct uint_of_impl<Bits, std::enable_if_t<(Bits <= 64)>> {
    using type = std::conditional_t<
        (Bits <= 16), std::uint16_t,
        std::conditional_t<(Bits <= 32), std::uint32_t, std::uint64_t>>;
};

// Convenience alias for the unsigned integer type able to store `Bits` bits.
template <std::size_t Bits>
using uint_of = typename uint_of_impl<Bits>::type;
54 
55 
56 template <typename T>
57 struct basic_float_traits {};
58 
59 template <>
60 struct basic_float_traits<half> {
61  using type = half;
62  static constexpr int sign_bits = 1;
63  static constexpr int significand_bits = 10;
64  static constexpr int exponent_bits = 5;
65  static constexpr bool rounds_to_nearest = true;
66 };
67 
68 template <>
69 struct basic_float_traits<__half> {
70  using type = __half;
71  static constexpr int sign_bits = 1;
72  static constexpr int significand_bits = 10;
73  static constexpr int exponent_bits = 5;
74  static constexpr bool rounds_to_nearest = true;
75 };
76 
77 template <>
78 struct basic_float_traits<float> {
79  using type = float;
80  static constexpr int sign_bits = 1;
81  static constexpr int significand_bits = 23;
82  static constexpr int exponent_bits = 8;
83  static constexpr bool rounds_to_nearest = true;
84 };
85 
86 template <>
87 struct basic_float_traits<double> {
88  using type = double;
89  static constexpr int sign_bits = 1;
90  static constexpr int significand_bits = 52;
91  static constexpr int exponent_bits = 11;
92  static constexpr bool rounds_to_nearest = true;
93 };
94 
95 template <typename FloatType, std::size_t NumComponents,
96  std::size_t ComponentId>
97 struct basic_float_traits<truncated<FloatType, NumComponents, ComponentId>> {
98  using type = truncated<FloatType, NumComponents, ComponentId>;
99  static constexpr int sign_bits = ComponentId == 0 ? 1 : 0;
100  static constexpr int exponent_bits =
101  ComponentId == 0 ? basic_float_traits<FloatType>::exponent_bits : 0;
102  static constexpr int significand_bits =
103  ComponentId == 0 ? sizeof(type) * byte_size - exponent_bits - 1
104  : sizeof(type) * byte_size;
105  static constexpr bool rounds_to_nearest = false;
106 };
107 
108 
109 template <typename UintType>
110 constexpr UintType create_ones(int n)
111 {
112  return (n == sizeof(UintType) * byte_size ? static_cast<UintType>(0)
113  : static_cast<UintType>(1) << n) -
114  static_cast<UintType>(1);
115 }
116 
117 
118 template <typename T>
119 struct float_traits {
120  using type = typename basic_float_traits<T>::type;
121  using bits_type = uint_of<sizeof(type) * byte_size>;
122  static constexpr int sign_bits = basic_float_traits<T>::sign_bits;
123  static constexpr int significand_bits =
124  basic_float_traits<T>::significand_bits;
125  static constexpr int exponent_bits = basic_float_traits<T>::exponent_bits;
126  static constexpr bits_type significand_mask =
127  create_ones<bits_type>(significand_bits);
128  static constexpr bits_type exponent_mask =
129  create_ones<bits_type>(significand_bits + exponent_bits) -
130  significand_mask;
131  static constexpr bits_type bias_mask =
132  create_ones<bits_type>(significand_bits + exponent_bits - 1) -
133  significand_mask;
134  static constexpr bits_type sign_mask =
135  create_ones<bits_type>(sign_bits + significand_bits + exponent_bits) -
136  exponent_mask - significand_mask;
137  static constexpr bool rounds_to_nearest =
138  basic_float_traits<T>::rounds_to_nearest;
139 
140  static constexpr auto eps =
141  1.0 / (1ll << (significand_bits + rounds_to_nearest));
142 
143  static constexpr bool is_inf(bits_type data)
144  {
145  return (data & exponent_mask) == exponent_mask &&
146  (data & significand_mask) == bits_type{};
147  }
148 
149  static constexpr bool is_nan(bits_type data)
150  {
151  return (data & exponent_mask) == exponent_mask &&
152  (data & significand_mask) != bits_type{};
153  }
154 
155  static constexpr bool is_denom(bits_type data)
156  {
157  return (data & exponent_mask) == bits_type{};
158  }
159 };
160 
161 
162 template <typename SourceType, typename ResultType,
163  bool = (sizeof(SourceType) <= sizeof(ResultType))>
164 struct precision_converter;
165 
166 // upcasting implementation details
167 template <typename SourceType, typename ResultType>
168 struct precision_converter<SourceType, ResultType, true> {
169  using source_traits = float_traits<SourceType>;
170  using result_traits = float_traits<ResultType>;
171  using source_bits = typename source_traits::bits_type;
172  using result_bits = typename result_traits::bits_type;
173 
174  static_assert(source_traits::exponent_bits <=
175  result_traits::exponent_bits &&
176  source_traits::significand_bits <=
177  result_traits::significand_bits,
178  "SourceType has to have both lower range and precision or "
179  "higher range and precision than ResultType");
180 
181  static constexpr int significand_offset =
182  result_traits::significand_bits - source_traits::significand_bits;
183  static constexpr int exponent_offset = significand_offset;
184  static constexpr int sign_offset = result_traits::exponent_bits -
185  source_traits::exponent_bits +
186  exponent_offset;
187  static constexpr result_bits bias_change =
188  result_traits::bias_mask -
189  (static_cast<result_bits>(source_traits::bias_mask) << exponent_offset);
190 
191  static constexpr result_bits shift_significand(source_bits data) noexcept
192  {
193  return static_cast<result_bits>(data & source_traits::significand_mask)
194  << significand_offset;
195  }
196 
197  static constexpr result_bits shift_exponent(source_bits data) noexcept
198  {
199  return update_bias(
200  static_cast<result_bits>(data & source_traits::exponent_mask)
201  << exponent_offset);
202  }
203 
204  static constexpr result_bits shift_sign(source_bits data) noexcept
205  {
206  return static_cast<result_bits>(data & source_traits::sign_mask)
207  << sign_offset;
208  }
209 
210 private:
211  static constexpr result_bits update_bias(result_bits data) noexcept
212  {
213  return data == typename result_traits::bits_type{} ? data
214  : data + bias_change;
215  }
216 };
217 
218 // downcasting implementation details
219 template <typename SourceType, typename ResultType>
220 struct precision_converter<SourceType, ResultType, false> {
221  using source_traits = float_traits<SourceType>;
222  using result_traits = float_traits<ResultType>;
223  using source_bits = typename source_traits::bits_type;
224  using result_bits = typename result_traits::bits_type;
225 
226  static_assert(source_traits::exponent_bits >=
227  result_traits::exponent_bits &&
228  source_traits::significand_bits >=
229  result_traits::significand_bits,
230  "SourceType has to have both lower range and precision or "
231  "higher range and precision than ResultType");
232 
233  static constexpr int significand_offset =
234  source_traits::significand_bits - result_traits::significand_bits;
235  static constexpr int exponent_offset = significand_offset;
236  static constexpr int sign_offset = source_traits::exponent_bits -
237  result_traits::exponent_bits +
238  exponent_offset;
239  static constexpr source_bits bias_change =
240  (source_traits::bias_mask >> exponent_offset) -
241  static_cast<source_bits>(result_traits::bias_mask);
242 
243  static constexpr result_bits shift_significand(source_bits data) noexcept
244  {
245  return static_cast<result_bits>(
246  (data & source_traits::significand_mask) >> significand_offset);
247  }
248 
249  static constexpr result_bits shift_exponent(source_bits data) noexcept
250  {
251  return static_cast<result_bits>(update_bias(
252  (data & source_traits::exponent_mask) >> exponent_offset));
253  }
254 
255  static constexpr result_bits shift_sign(source_bits data) noexcept
256  {
257  return static_cast<result_bits>((data & source_traits::sign_mask) >>
258  sign_offset);
259  }
260 
261 private:
262  static constexpr source_bits update_bias(source_bits data) noexcept
263  {
264  return data <= bias_change ? typename source_traits::bits_type{}
265  : limit_exponent(data - bias_change);
266  }
267 
268  static constexpr source_bits limit_exponent(source_bits data) noexcept
269  {
270  return data >= static_cast<source_bits>(result_traits::exponent_mask)
271  ? static_cast<source_bits>(result_traits::exponent_mask)
272  : data;
273  }
274 };
275 
276 
277 } // namespace detail
278 
279 
286 class alignas(std::uint16_t) half {
287 public:
288  // create half value from the bits directly.
289  static constexpr half create_from_bits(const std::uint16_t& bits) noexcept
290  {
291  half result;
292  result.data_ = bits;
293  return result;
294  }
295 
296  // TODO: NVHPC (host side) may not use zero initialization for the data
297  // member by default constructor in some cases. Not sure whether it is
298  // caused by something else in jacobi or isai.
299  constexpr half() noexcept : data_(0){};
300 
301  template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
302  half(const T& val) : data_(0)
303  {
304  this->float2half(static_cast<float>(val));
305  }
306 
307  template <typename V>
308  half& operator=(const V& val)
309  {
310  this->float2half(static_cast<float>(val));
311  return *this;
312  }
313 
314  operator float() const noexcept
315  {
316  const auto bits = half2float(data_);
317  float ans(0);
318  std::memcpy(&ans, &bits, sizeof(float));
319  return ans;
320  }
321 
322  // can not use half operator _op(const half) for half + half
323  // operation will cast it to float and then do float operation such that it
324  // becomes float in the end.
325 #define HALF_OPERATOR(_op, _opeq) \
326  friend half operator _op(const half& lhf, const half& rhf) \
327  { \
328  return static_cast<half>(static_cast<float>(lhf) \
329  _op static_cast<float>(rhf)); \
330  } \
331  half& operator _opeq(const half& hf) \
332  { \
333  auto result = *this _op hf; \
334  data_ = result.data_; \
335  return *this; \
336  }
337 
338  HALF_OPERATOR(+, +=)
339  HALF_OPERATOR(-, -=)
340  HALF_OPERATOR(*, *=)
341  HALF_OPERATOR(/, /=)
342 
343 #undef HALF_OPERATOR
344 
345  // Do operation with different type
346  // If it is floating point, using floating point as type.
347  // If it is integer, using half as type
348 #define HALF_FRIEND_OPERATOR(_op, _opeq) \
349  template <typename T> \
350  friend std::enable_if_t< \
351  !std::is_same<T, half>::value && std::is_scalar<T>::value, \
352  std::conditional_t<std::is_floating_point<T>::value, T, half>> \
353  operator _op(const half& hf, const T& val) \
354  { \
355  using type = \
356  std::conditional_t<std::is_floating_point<T>::value, T, half>; \
357  auto result = static_cast<type>(hf); \
358  result _opeq static_cast<type>(val); \
359  return result; \
360  } \
361  template <typename T> \
362  friend std::enable_if_t< \
363  !std::is_same<T, half>::value && std::is_scalar<T>::value, \
364  std::conditional_t<std::is_floating_point<T>::value, T, half>> \
365  operator _op(const T& val, const half& hf) \
366  { \
367  using type = \
368  std::conditional_t<std::is_floating_point<T>::value, T, half>; \
369  auto result = static_cast<type>(val); \
370  result _opeq static_cast<type>(hf); \
371  return result; \
372  }
373 
374  HALF_FRIEND_OPERATOR(+, +=)
375  HALF_FRIEND_OPERATOR(-, -=)
376  HALF_FRIEND_OPERATOR(*, *=)
377  HALF_FRIEND_OPERATOR(/, /=)
378 
379 #undef HALF_FRIEND_OPERATOR
380 
381  // the negative
382  half operator-() const
383  {
384  auto val = 0.0f - *this;
385  return static_cast<half>(val);
386  }
387 
388 private:
389  using f16_traits = detail::float_traits<half>;
390  using f32_traits = detail::float_traits<float>;
391 
392  void float2half(const float& val) noexcept
393  {
394  std::uint32_t bit_val(0);
395  std::memcpy(&bit_val, &val, sizeof(float));
396  data_ = float2half(bit_val);
397  }
398 
399  static constexpr std::uint16_t float2half(std::uint32_t data_) noexcept
400  {
401  using conv = detail::precision_converter<float, half>;
402  if (f32_traits::is_inf(data_)) {
403  return conv::shift_sign(data_) | f16_traits::exponent_mask;
404  } else if (f32_traits::is_nan(data_)) {
405  return conv::shift_sign(data_) | f16_traits::exponent_mask |
406  f16_traits::significand_mask;
407  } else {
408  const auto exp = conv::shift_exponent(data_);
409  if (f16_traits::is_inf(exp)) {
410  return conv::shift_sign(data_) | exp;
411  } else if (f16_traits::is_denom(exp)) {
412  // TODO: handle denormals
413  return conv::shift_sign(data_);
414  } else {
415  // Rounding to even
416  const auto result = conv::shift_sign(data_) | exp |
417  conv::shift_significand(data_);
418  const auto tail =
419  data_ & static_cast<f32_traits::bits_type>(
420  (1 << conv::significand_offset) - 1);
421 
422  constexpr auto half = static_cast<f32_traits::bits_type>(
423  1 << (conv::significand_offset - 1));
424  return result +
425  (tail > half || ((tail == half) && (result & 1)));
426  }
427  }
428  }
429 
430  static constexpr std::uint32_t half2float(std::uint16_t data_) noexcept
431  {
432  using conv = detail::precision_converter<half, float>;
433  if (f16_traits::is_inf(data_)) {
434  return conv::shift_sign(data_) | f32_traits::exponent_mask;
435  } else if (f16_traits::is_nan(data_)) {
436  return conv::shift_sign(data_) | f32_traits::exponent_mask |
437  f32_traits::significand_mask;
438  } else if (f16_traits::is_denom(data_)) {
439  // TODO: handle denormals
440  return conv::shift_sign(data_);
441  } else {
442  return conv::shift_sign(data_) | conv::shift_exponent(data_) |
443  conv::shift_significand(data_);
444  }
445  }
446 
447  std::uint16_t data_;
448 };
449 
450 
451 } // namespace gko
452 
453 
454 namespace std {
455 
456 
457 template <>
458 class complex<gko::half> {
459 public:
460  using value_type = gko::half;
461 
462  complex(const value_type& real = value_type(0.f),
463  const value_type& imag = value_type(0.f))
464  : real_(real), imag_(imag)
465  {}
466 
467  template <typename T, typename U,
468  typename = std::enable_if_t<std::is_scalar<T>::value &&
469  std::is_scalar<U>::value>>
470  explicit complex(const T& real, const U& imag)
471  : real_(static_cast<value_type>(real)),
472  imag_(static_cast<value_type>(imag))
473  {}
474 
475  template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
476  complex(const T& real)
477  : real_(static_cast<value_type>(real)),
478  imag_(static_cast<value_type>(0.f))
479  {}
480 
481  // When using complex(real, imag), MSVC with CUDA try to recognize the
482  // complex is a member not constructor.
483  template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
484  explicit complex(const complex<T>& other)
485  : real_(static_cast<value_type>(other.real())),
486  imag_(static_cast<value_type>(other.imag()))
487  {}
488 
489  value_type real() const noexcept { return real_; }
490 
491  value_type imag() const noexcept { return imag_; }
492 
493  operator std::complex<float>() const noexcept
494  {
495  return std::complex<float>(static_cast<float>(real_),
496  static_cast<float>(imag_));
497  }
498 
499  template <typename V>
500  complex& operator=(const V& val)
501  {
502  real_ = val;
503  imag_ = value_type();
504  return *this;
505  }
506 
507  template <typename V>
508  complex& operator=(const std::complex<V>& val)
509  {
510  real_ = val.real();
511  imag_ = val.imag();
512  return *this;
513  }
514 
515  complex& operator+=(const value_type& real)
516  {
517  real_ += real;
518  return *this;
519  }
520 
521  complex& operator-=(const value_type& real)
522  {
523  real_ -= real;
524  return *this;
525  }
526 
527  complex& operator*=(const value_type& real)
528  {
529  real_ *= real;
530  imag_ *= real;
531  return *this;
532  }
533 
534  complex& operator/=(const value_type& real)
535  {
536  real_ /= real;
537  imag_ /= real;
538  return *this;
539  }
540 
541  template <typename T>
542  complex& operator+=(const complex<T>& val)
543  {
544  real_ += val.real();
545  imag_ += val.imag();
546  return *this;
547  }
548 
549  template <typename T>
550  complex& operator-=(const complex<T>& val)
551  {
552  real_ -= val.real();
553  imag_ -= val.imag();
554  return *this;
555  }
556 
557  template <typename T>
558  complex& operator*=(const complex<T>& val)
559  {
560  auto val_f = static_cast<std::complex<float>>(val);
561  auto result_f = static_cast<std::complex<float>>(*this);
562  result_f *= val_f;
563  real_ = result_f.real();
564  imag_ = result_f.imag();
565  return *this;
566  }
567 
568  template <typename T>
569  complex& operator/=(const complex<T>& val)
570  {
571  auto val_f = static_cast<std::complex<float>>(val);
572  auto result_f = static_cast<std::complex<float>>(*this);
573  result_f /= val_f;
574  real_ = result_f.real();
575  imag_ = result_f.imag();
576  return *this;
577  }
578 
579 #define COMPLEX_HALF_OPERATOR(_op, _opeq) \
580  friend complex operator _op(const complex& lhf, const complex& rhf) \
581  { \
582  auto a = lhf; \
583  a _opeq rhf; \
584  return a; \
585  }
586 
587  COMPLEX_HALF_OPERATOR(+, +=)
588  COMPLEX_HALF_OPERATOR(-, -=)
589  COMPLEX_HALF_OPERATOR(*, *=)
590  COMPLEX_HALF_OPERATOR(/, /=)
591 
592 #undef COMPLEX_HALF_OPERATOR
593 
594 private:
595  value_type real_;
596  value_type imag_;
597 };
598 
599 
600 template <>
601 struct numeric_limits<gko::half> {
602  static constexpr bool is_specialized{true};
603  static constexpr bool is_signed{true};
604  static constexpr bool is_integer{false};
605  static constexpr bool is_exact{false};
606  static constexpr bool is_bounded{true};
607  static constexpr bool is_modulo{false};
608  static constexpr int digits{
609  gko::detail::float_traits<gko::half>::significand_bits + 1};
610  // 3/10 is approx. log_10(2)
611  static constexpr int digits10{digits * 3 / 10};
612 
613  static constexpr gko::half epsilon()
614  {
615  constexpr auto bits = static_cast<std::uint16_t>(0b0'00101'0000000000u);
616  return gko::half::create_from_bits(bits);
617  }
618 
619  static constexpr gko::half infinity()
620  {
621  constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'0000000000u);
622  return gko::half::create_from_bits(bits);
623  }
624 
625  static constexpr gko::half min()
626  {
627  constexpr auto bits = static_cast<std::uint16_t>(0b0'00001'0000000000u);
628  return gko::half::create_from_bits(bits);
629  }
630 
631  static constexpr gko::half max()
632  {
633  constexpr auto bits = static_cast<std::uint16_t>(0b0'11110'1111111111u);
634  return gko::half::create_from_bits(bits);
635  }
636 
637  static constexpr gko::half lowest()
638  {
639  constexpr auto bits = static_cast<std::uint16_t>(0b1'11110'1111111111u);
640  return gko::half::create_from_bits(bits);
641  };
642 
643  static constexpr gko::half quiet_NaN()
644  {
645  constexpr auto bits = static_cast<std::uint16_t>(0b0'11111'1111111111u);
646  return gko::half::create_from_bits(bits);
647  }
648 };
649 
650 
651 // complex using a template on operator= for any kind of complex<T>, so we can
652 // do full specialization for half
653 template <>
654 inline complex<double>& complex<double>::operator=(
655  const std::complex<gko::half>& a)
656 {
657  complex<double> t(a.real(), a.imag());
658  operator=(t);
659  return *this;
660 }
661 
662 
663 // For MSVC
664 template <>
665 inline complex<float>& complex<float>::operator=(
666  const std::complex<gko::half>& a)
667 {
668  complex<float> t(a.real(), a.imag());
669  operator=(t);
670  return *this;
671 }
672 
673 
674 } // namespace std
675 
676 
677 #endif // GKO_PUBLIC_CORE_BASE_HALF_HPP_
Referenced symbols:
- gko::truncated. Definition: half.hpp:23
- gko::byte_size (constexpr size_type byte_size): Number of bits in a byte. Definition: types.hpp:177
- gko: The Ginkgo namespace. Definition: abstract_factory.hpp:20
- gko::half: A class providing basic support for half precision floating point types. Definition: half.hpp:286