Ginkgo  Generated from pipelines/2011557978 branch based on develop. Ginkgo version 1.11.0
A numerical linear algebra library targeting many-core architectures
bfloat16.hpp
1 // SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2 //
3 // SPDX-License-Identifier: BSD-3-Clause
4 
5 #ifndef GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
6 #define GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
7 
8 
#include <climits>
#include <complex>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>

#include <ginkgo/core/base/half.hpp>
16 
17 
18 class __nv_bfloat16;
19 class hip_bfloat16;
20 class __hip_bfloat16;
21 
22 
23 namespace gko {
24 
25 
26 class bfloat16;
27 
28 
29 namespace detail {
30 template <>
31 struct basic_float_traits<bfloat16> {
32  using type = bfloat16;
33  static constexpr int sign_bits = 1;
34  static constexpr int significand_bits = 7;
35  static constexpr int exponent_bits = 8;
36  static constexpr bool rounds_to_nearest = true;
37 };
38 
39 template <>
40 struct basic_float_traits<__nv_bfloat16> {
41  using type = __nv_bfloat16;
42  static constexpr int sign_bits = 1;
43  static constexpr int significand_bits = 7;
44  static constexpr int exponent_bits = 8;
45  static constexpr bool rounds_to_nearest = true;
46 };
47 
48 template <>
49 struct basic_float_traits<hip_bfloat16> {
50  using type = hip_bfloat16;
51  static constexpr int sign_bits = 1;
52  static constexpr int significand_bits = 7;
53  static constexpr int exponent_bits = 8;
54  static constexpr bool rounds_to_nearest = true;
55 };
56 
57 template <>
58 struct basic_float_traits<__hip_bfloat16> {
59  using type = __hip_bfloat16;
60  static constexpr int sign_bits = 1;
61  static constexpr int significand_bits = 7;
62  static constexpr int exponent_bits = 8;
63  static constexpr bool rounds_to_nearest = true;
64 };
65 
66 
67 } // namespace detail
68 
69 
76 class alignas(std::uint16_t) bfloat16 {
77 public:
78  // create bfloat16 value from the bits directly.
79  static constexpr bfloat16 create_from_bits(
80  const std::uint16_t& bits) noexcept
81  {
82  bfloat16 result;
83  result.data_ = bits;
84  return result;
85  }
86 
87  // TODO: NVHPC (host side) may not use zero initialization for the data
88  // member by default constructor in some cases. Not sure whether it is
89  // caused by something else in jacobi or isai.
90  constexpr bfloat16() noexcept : data_(0){};
91 
92  template <typename T,
93  typename = std::enable_if_t<std::is_scalar<T>::value ||
94  std::is_same_v<T, half>>>
95  bfloat16(const T& val) : data_(0)
96  {
97  this->float2bfloat16(static_cast<float>(val));
98  }
99 
100  template <typename V>
101  bfloat16& operator=(const V& val)
102  {
103  this->float2bfloat16(static_cast<float>(val));
104  return *this;
105  }
106 
107  operator float() const noexcept
108  {
109  const auto bits = bfloat162float(data_);
110  float ans(0);
111  std::memcpy(&ans, &bits, sizeof(float));
112  return ans;
113  }
114 
115  // can not use bfloat16 operator _op(const bfloat16) for bfloat16 + bfloat16
116  // operation will cast it to float and then do float operation such that it
117  // becomes float in the end.
118 #define BFLOAT16_OPERATOR(_op, _opeq) \
119  friend bfloat16 operator _op(const bfloat16& lhf, const bfloat16& rhf) \
120  { \
121  return static_cast<bfloat16>(static_cast<float>(lhf) \
122  _op static_cast<float>(rhf)); \
123  } \
124  bfloat16& operator _opeq(const bfloat16& hf) \
125  { \
126  auto result = *this _op hf; \
127  data_ = result.data_; \
128  return *this; \
129  }
130 
131  BFLOAT16_OPERATOR(+, +=)
132  BFLOAT16_OPERATOR(-, -=)
133  BFLOAT16_OPERATOR(*, *=)
134  BFLOAT16_OPERATOR(/, /=)
135 
136 #undef BFLOAT16_OPERATOR
137 
138  // Do operation with different type
139  // If it is floating point, using floating point as type.
140  // If it is bfloat16, using float as type.
141  // If it is integer, using bfloat16 as type.
142 #define BFLOAT16_FRIEND_OPERATOR(_op, _opeq) \
143  template <typename T> \
144  friend std::enable_if_t< \
145  !std::is_same<T, bfloat16>::value && \
146  (std::is_scalar<T>::value || std::is_same_v<T, half>), \
147  std::conditional_t< \
148  std::is_floating_point<T>::value, T, \
149  std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
150  operator _op(const bfloat16& hf, const T& val) \
151  { \
152  using type = \
153  std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>; \
154  auto result = static_cast<type>(hf); \
155  result _opeq static_cast<type>(val); \
156  return result; \
157  } \
158  template <typename T> \
159  friend std::enable_if_t< \
160  !std::is_same<T, bfloat16>::value && \
161  (std::is_scalar<T>::value || std::is_same_v<T, half>), \
162  std::conditional_t< \
163  std::is_floating_point<T>::value, T, \
164  std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
165  operator _op(const T& val, const bfloat16& hf) \
166  { \
167  using type = \
168  std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>; \
169  auto result = static_cast<type>(val); \
170  result _opeq static_cast<type>(hf); \
171  return result; \
172  }
173 
174  BFLOAT16_FRIEND_OPERATOR(+, +=)
175  BFLOAT16_FRIEND_OPERATOR(-, -=)
176  BFLOAT16_FRIEND_OPERATOR(*, *=)
177  BFLOAT16_FRIEND_OPERATOR(/, /=)
178 
179 #undef BFLOAT16_FRIEND_OPERATOR
180 
181  // the negative
182  bfloat16 operator-() const
183  {
184  auto val = 0.0f - *this;
185  return static_cast<bfloat16>(val);
186  }
187 
188 private:
189  using f16_traits = detail::float_traits<bfloat16>;
190  using f32_traits = detail::float_traits<float>;
191 
192  void float2bfloat16(const float& val) noexcept
193  {
194  std::uint32_t bit_val(0);
195  std::memcpy(&bit_val, &val, sizeof(float));
196  data_ = float2bfloat16(bit_val);
197  }
198 
199  static constexpr std::uint16_t float2bfloat16(std::uint32_t data_) noexcept
200  {
201  using conv = detail::precision_converter<float, bfloat16>;
202  if (f32_traits::is_inf(data_)) {
203  return conv::shift_sign(data_) | f16_traits::exponent_mask;
204  } else if (f32_traits::is_nan(data_)) {
205  return conv::shift_sign(data_) | f16_traits::exponent_mask |
206  f16_traits::significand_mask;
207  } else {
208  const auto exp = conv::shift_exponent(data_);
209  if (f16_traits::is_inf(exp)) {
210  return conv::shift_sign(data_) | exp;
211  } else if (f16_traits::is_denom(exp)) {
212  // TODO: handle denormals
213  return conv::shift_sign(data_);
214  } else {
215  // Rounding to even
216  const auto result = conv::shift_sign(data_) | exp |
217  conv::shift_significand(data_);
218  const auto tail =
219  data_ & static_cast<f32_traits::bits_type>(
220  (1 << conv::significand_offset) - 1);
221 
222  constexpr auto bfloat16 = static_cast<f32_traits::bits_type>(
223  1 << (conv::significand_offset - 1));
224  return result + (tail > bfloat16 ||
225  ((tail == bfloat16) && (result & 1)));
226  }
227  }
228  }
229 
230  static constexpr std::uint32_t bfloat162float(std::uint16_t data_) noexcept
231  {
232  using conv = detail::precision_converter<bfloat16, float>;
233  if (f16_traits::is_inf(data_)) {
234  return conv::shift_sign(data_) | f32_traits::exponent_mask;
235  } else if (f16_traits::is_nan(data_)) {
236  return conv::shift_sign(data_) | f32_traits::exponent_mask |
237  f32_traits::significand_mask;
238  } else if (f16_traits::is_denom(data_)) {
239  // TODO: handle denormals
240  return conv::shift_sign(data_);
241  } else {
242  return conv::shift_sign(data_) | conv::shift_exponent(data_) |
243  conv::shift_significand(data_);
244  }
245  }
246 
247  std::uint16_t data_;
248 };
249 
250 
251 } // namespace gko
252 
253 
254 namespace std {
255 
256 
257 template <>
258 class complex<gko::bfloat16> {
259 public:
260  using value_type = gko::bfloat16;
261 
262  complex(const value_type& real = value_type(0.f),
263  const value_type& imag = value_type(0.f))
264  : real_(real), imag_(imag)
265  {}
266 
267  template <
268  typename T, typename U,
269  typename = std::enable_if_t<
270  (std::is_scalar<T>::value || std::is_same_v<T, gko::half>)&&(
271  std::is_scalar<U>::value || std::is_same_v<U, gko::half>)>>
272  explicit complex(const T& real, const U& imag)
273  : real_(static_cast<value_type>(real)),
274  imag_(static_cast<value_type>(imag))
275  {}
276 
277  template <typename T,
278  typename = std::enable_if_t<std::is_scalar<T>::value ||
279  std::is_same_v<T, gko::half>>>
280  complex(const T& real)
281  : real_(static_cast<value_type>(real)),
282  imag_(static_cast<value_type>(0.f))
283  {}
284 
285  // When using complex(real, imag), MSVC with CUDA try to recognize the
286  // complex is a member not constructor.
287  template <typename T,
288  typename = std::enable_if_t<std::is_scalar<T>::value ||
289  std::is_same_v<T, gko::half>>>
290  explicit complex(const complex<T>& other)
291  : real_(static_cast<value_type>(other.real())),
292  imag_(static_cast<value_type>(other.imag()))
293  {}
294 
295  value_type real() const noexcept { return real_; }
296 
297  value_type imag() const noexcept { return imag_; }
298 
299  operator std::complex<float>() const noexcept
300  {
301  return std::complex<float>(static_cast<float>(real_),
302  static_cast<float>(imag_));
303  }
304 
305  template <typename V>
306  complex& operator=(const V& val)
307  {
308  real_ = val;
309  imag_ = value_type();
310  return *this;
311  }
312 
313  template <typename V>
314  complex& operator=(const std::complex<V>& val)
315  {
316  real_ = val.real();
317  imag_ = val.imag();
318  return *this;
319  }
320 
321  complex& operator+=(const value_type& real)
322  {
323  real_ += real;
324  return *this;
325  }
326 
327  complex& operator-=(const value_type& real)
328  {
329  real_ -= real;
330  return *this;
331  }
332 
333  complex& operator*=(const value_type& real)
334  {
335  real_ *= real;
336  imag_ *= real;
337  return *this;
338  }
339 
340  complex& operator/=(const value_type& real)
341  {
342  real_ /= real;
343  imag_ /= real;
344  return *this;
345  }
346 
347  template <typename T>
348  complex& operator+=(const complex<T>& val)
349  {
350  real_ += val.real();
351  imag_ += val.imag();
352  return *this;
353  }
354 
355  template <typename T>
356  complex& operator-=(const complex<T>& val)
357  {
358  real_ -= val.real();
359  imag_ -= val.imag();
360  return *this;
361  }
362 
363  template <typename T>
364  complex& operator*=(const complex<T>& val)
365  {
366  auto val_f = static_cast<std::complex<float>>(val);
367  auto result_f = static_cast<std::complex<float>>(*this);
368  result_f *= val_f;
369  real_ = result_f.real();
370  imag_ = result_f.imag();
371  return *this;
372  }
373 
374  template <typename T>
375  complex& operator/=(const complex<T>& val)
376  {
377  auto val_f = static_cast<std::complex<float>>(val);
378  auto result_f = static_cast<std::complex<float>>(*this);
379  result_f /= val_f;
380  real_ = result_f.real();
381  imag_ = result_f.imag();
382  return *this;
383  }
384 
385 #define COMPLEX_BFLOAT16_OPERATOR(_op, _opeq) \
386  friend complex operator _op(const complex& lhf, const complex& rhf) \
387  { \
388  auto a = lhf; \
389  a _opeq rhf; \
390  return a; \
391  }
392 
393  COMPLEX_BFLOAT16_OPERATOR(+, +=)
394  COMPLEX_BFLOAT16_OPERATOR(-, -=)
395  COMPLEX_BFLOAT16_OPERATOR(*, *=)
396  COMPLEX_BFLOAT16_OPERATOR(/, /=)
397 
398 #undef COMPLEX_BFLOAT16_OPERATOR
399 
400 private:
401  value_type real_;
402  value_type imag_;
403 };
404 
405 
406 template <>
407 struct numeric_limits<gko::bfloat16> {
408  static constexpr bool is_specialized{true};
409  static constexpr bool is_signed{true};
410  static constexpr bool is_integer{false};
411  static constexpr bool is_exact{false};
412  static constexpr bool is_bounded{true};
413  static constexpr bool is_modulo{false};
414  static constexpr int digits{
415  gko::detail::float_traits<gko::bfloat16>::significand_bits + 1};
416  // 3/10 is approx. log_10(2)
417  static constexpr int digits10{digits * 3 / 10};
418 
419  static constexpr gko::bfloat16 epsilon()
420  {
421  constexpr auto bits = static_cast<std::uint16_t>(0b0'01111000'0000000u);
422  return gko::bfloat16::create_from_bits(bits);
423  }
424 
425  static constexpr gko::bfloat16 infinity()
426  {
427  constexpr auto bits = static_cast<std::uint16_t>(0b0'11111111'0000000u);
428  return gko::bfloat16::create_from_bits(bits);
429  }
430 
431  static constexpr gko::bfloat16 min()
432  {
433  constexpr auto bits = static_cast<std::uint16_t>(0b0'00000001'0000000u);
434  return gko::bfloat16::create_from_bits(bits);
435  }
436 
437  static constexpr gko::bfloat16 max()
438  {
439  constexpr auto bits = static_cast<std::uint16_t>(0b0'11111110'1111111u);
440  return gko::bfloat16::create_from_bits(bits);
441  }
442 
443  static constexpr gko::bfloat16 lowest()
444  {
445  constexpr auto bits = static_cast<std::uint16_t>(0b1'11111110'1111111u);
446  return gko::bfloat16::create_from_bits(bits);
447  };
448 
449  static constexpr gko::bfloat16 quiet_NaN()
450  {
451  constexpr auto bits = static_cast<std::uint16_t>(0b0'11111111'1111111u);
452  return gko::bfloat16::create_from_bits(bits);
453  }
454 };
455 
456 
457 // complex using a template on operator= for any kind of complex<T>, so we can
458 // do full specialization for bfloat16
459 template <>
460 inline complex<double>& complex<double>::operator=(
461  const std::complex<gko::bfloat16>& a)
462 {
463  complex<double> t(a.real(), a.imag());
464  operator=(t);
465  return *this;
466 }
467 
468 
469 // For MSVC
470 template <>
471 inline complex<float>& complex<float>::operator=(
472  const std::complex<gko::bfloat16>& a)
473 {
474  complex<float> t(a.real(), a.imag());
475  operator=(t);
476  return *this;
477 }
478 
479 
480 } // namespace std
481 
482 
#endif  // GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
gko::max
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition: math.hpp:732
gko::bfloat16
A class providing basic support for bfloat16 precision floating point types.
Definition: bfloat16.hpp:76
gko
The Ginkgo namespace.
Definition: abstract_factory.hpp:20
gko::min
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition: math.hpp:750
gko::real
constexpr auto real(const T &x)
Returns the real part of the object.
Definition: math.hpp:900
gko::imag
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition: math.hpp:916