5 #ifndef GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
6 #define GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_
13 #include <type_traits>
15 #include <ginkgo/core/base/half.hpp>
// Bit layout of gko bfloat16: 1 sign bit, 8 exponent bits and 7 significand
// bits, i.e. 16 bits in total.
31 struct basic_float_traits<bfloat16> {
// The floating-point type described by this specialization.
32 using type = bfloat16;
// Width of the sign field in bits.
33 static constexpr
int sign_bits = 1;
// Number of explicitly stored significand (mantissa) bits.
34 static constexpr
int significand_bits = 7;
// Width of the exponent field in bits.
35 static constexpr
int exponent_bits = 8;
// Flag read by code outside this view; presumably selects round-to-nearest
// when converting into this format - TODO confirm against the consumer.
36 static constexpr
bool rounds_to_nearest =
true;
// Layout description for __nv_bfloat16 (presumably the CUDA device bfloat16
// type - confirm against the guarding preprocessor condition, which is not
// visible here). The values equal those of the gko bfloat16 specialization:
// 1 sign bit, 8 exponent bits, 7 significand bits.
40 struct basic_float_traits<__nv_bfloat16> {
41 using type = __nv_bfloat16;
// Width of the sign field in bits.
42 static constexpr
int sign_bits = 1;
// Number of explicitly stored significand bits.
43 static constexpr
int significand_bits = 7;
// Width of the exponent field in bits.
44 static constexpr
int exponent_bits = 8;
// Same rounding flag value as the gko bfloat16 specialization.
45 static constexpr
bool rounds_to_nearest =
true;
// Layout description for hip_bfloat16 (presumably a HIP bfloat16 type - the
// guarding preprocessor condition is not visible here). Uses the same 1/8/7
// split of sign, exponent and significand bits as gko bfloat16.
49 struct basic_float_traits<hip_bfloat16> {
50 using type = hip_bfloat16;
// Width of the sign field in bits.
51 static constexpr
int sign_bits = 1;
// Number of explicitly stored significand bits.
52 static constexpr
int significand_bits = 7;
// Width of the exponent field in bits.
53 static constexpr
int exponent_bits = 8;
// Same rounding flag value as the gko bfloat16 specialization.
54 static constexpr
bool rounds_to_nearest =
true;
// Layout description for __hip_bfloat16 (presumably the alternative HIP
// bfloat16 type - the guarding preprocessor condition is not visible here).
// Uses the same 1/8/7 split of sign, exponent and significand bits as
// gko bfloat16.
58 struct basic_float_traits<__hip_bfloat16> {
59 using type = __hip_bfloat16;
// Width of the sign field in bits.
60 static constexpr
int sign_bits = 1;
// Number of explicitly stored significand bits.
61 static constexpr
int significand_bits = 7;
// Width of the exponent field in bits.
62 static constexpr
int exponent_bits = 8;
// Same rounding flag value as the gko bfloat16 specialization.
63 static constexpr
bool rounds_to_nearest =
true;
// Factory that builds a bfloat16 from a raw 16-bit pattern instead of from a
// numeric value. The body is not visible in this view; presumably the bits
// are stored unchanged (numeric_limits passes literal bit patterns to it) -
// TODO confirm against the body.
79 static constexpr
bfloat16 create_from_bits(
80 const std::uint16_t& bits) noexcept
// Default constructor: the stored 16-bit pattern data_ is set to 0.
// NOTE(review): the semicolon after the empty body is a redundant empty
// declaration; harmless, but it can trigger pedantic warnings.
90 constexpr
bfloat16() noexcept : data_(0){};
// Defaulted template parameter of a member template whose head and signature
// are not visible in this view (presumably the converting constructor from
// T). The overload only takes part in resolution when T is a scalar type or
// half; half is listed explicitly next to the is_scalar check.
93 typename = std::enable_if_t<std::is_scalar<T>::value ||
94 std::is_same_v<T, half>>>
// The argument is converted to float first and then handed to the non-static
// float2bfloat16 overload, which writes the resulting bits into data_.
97 this->float2bfloat16(static_cast<float>(val));
// Member template taking an arbitrary type V; its signature is not visible
// in this view (presumably a templated assignment operator - TODO confirm).
100 template <
typename V>
// The value is converted to float and stored through the non-static
// float2bfloat16 overload, which updates data_.
103 this->float2bfloat16(static_cast<float>(val));
// Conversion to float. The stored 16-bit pattern is expanded to a 32-bit
// pattern by bfloat162float and then copied byte-wise into a float.
107 operator float()
const noexcept
109 const auto bits = bfloat162float(data_);
// Copies sizeof(float) bytes from the integer pattern into ans (ans is
// declared on a line not visible here), reinterpreting the bits as a float.
111 std::memcpy(&ans, &bits,
sizeof(
float));
// Generates one binary operator and the matching compound assignment for two
// bfloat16 operands. The binary form casts both operands to float, applies
// _op in float and casts the result back to bfloat16. The compound form
// evaluates (*this _op hf) and copies the data_ of that result into this
// object. The macro is instantiated for +, -, * and / and then undefined so
// that it does not leak out of the header.
118 #define BFLOAT16_OPERATOR(_op, _opeq) \
119 friend bfloat16 operator _op(const bfloat16& lhf, const bfloat16& rhf) \
121 return static_cast<bfloat16>(static_cast<float>(lhf) \
122 _op static_cast<float>(rhf)); \
124 bfloat16& operator _opeq(const bfloat16& hf) \
126 auto result = *this _op hf; \
127 data_ = result.data_; \
131 BFLOAT16_OPERATOR(+, +=)
132 BFLOAT16_OPERATOR(-, -=)
133 BFLOAT16_OPERATOR(*, *=)
134 BFLOAT16_OPERATOR(/, /=)
136 #undef BFLOAT16_OPERATOR
// Generates the mixed-type binary operators (bfloat16 _op T) and
// (T _op bfloat16). They are enabled only when T is not bfloat16 itself and
// is either a scalar type or half.
// Declared return type: T when T is a standard floating-point type, float
// when T is half, bfloat16 otherwise (for example integers).
// Inside each operator both operands are cast to the local alias type
// (T for floating-point T, bfloat16 otherwise) and combined with _opeq.
// NOTE(review): for T = half the declared return type is float while the
// local alias type is bfloat16, so the arithmetic appears to run in bfloat16
// and only the result is widened to float - confirm that this is intended.
// The macro is instantiated for +, -, * and / and then undefined.
142 #define BFLOAT16_FRIEND_OPERATOR(_op, _opeq) \
143 template <typename T> \
144 friend std::enable_if_t< \
145 !std::is_same<T, bfloat16>::value && \
146 (std::is_scalar<T>::value || std::is_same_v<T, half>), \
147 std::conditional_t< \
148 std::is_floating_point<T>::value, T, \
149 std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
150 operator _op(const bfloat16& hf, const T& val) \
153 std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>; \
154 auto result = static_cast<type>(hf); \
155 result _opeq static_cast<type>(val); \
158 template <typename T> \
159 friend std::enable_if_t< \
160 !std::is_same<T, bfloat16>::value && \
161 (std::is_scalar<T>::value || std::is_same_v<T, half>), \
162 std::conditional_t< \
163 std::is_floating_point<T>::value, T, \
164 std::conditional_t<std::is_same_v<T, half>, float, bfloat16>>> \
165 operator _op(const T& val, const bfloat16& hf) \
168 std::conditional_t<std::is_floating_point<T>::value, T, bfloat16>; \
169 auto result = static_cast<type>(val); \
170 result _opeq static_cast<type>(hf); \
174 BFLOAT16_FRIEND_OPERATOR(+, +=)
175 BFLOAT16_FRIEND_OPERATOR(-, -=)
176 BFLOAT16_FRIEND_OPERATOR(*, *=)
177 BFLOAT16_FRIEND_OPERATOR(/, /=)
179 #undef BFLOAT16_FRIEND_OPERATOR
// Body of a member whose signature is not visible in this view (presumably
// unary minus). The value is computed as 0.0f minus this object and the
// result is converted back to bfloat16.
// NOTE(review): if this is the negation operator, subtracting from 0.0f maps
// a positive zero to positive zero rather than negative zero - confirm that
// the sign of zero does not matter to callers.
184 auto val = 0.0f - *
this;
185 return static_cast<bfloat16>(val);
// Shorthands for the bit-level helper traits of the 16-bit format (bfloat16)
// and the 32-bit format (float); the conversion routines use them for the
// is_inf / is_nan / is_denom checks and the field masks.
189 using f16_traits = detail::float_traits<bfloat16>;
190 using f32_traits = detail::float_traits<float>;
// Stores a float value in this object: the bytes of val are copied into a
// 32-bit integer, converted by the static float2bfloat16 overload and the
// resulting 16-bit pattern is written to data_.
192 void float2bfloat16(
const float& val) noexcept
194 std::uint32_t bit_val(0);
// Byte-wise copy of the float into the integer to obtain its bit pattern.
195 std::memcpy(&bit_val, &val,
sizeof(
float));
196 data_ = float2bfloat16(bit_val);
// Converts the raw 32-bit pattern of a float into a 16-bit bfloat16 pattern.
// Several lines of this function are not visible in this view, so only the
// visible steps are described. The parameter is named data_ like the member
// written by the non-static overload; this function is static and therefore
// only works on its parameter.
199 static constexpr std::uint16_t float2bfloat16(std::uint32_t data_) noexcept
201 using conv = detail::precision_converter<float, bfloat16>;
// Infinite input: shifted sign combined with a fully set exponent field.
202 if (f32_traits::is_inf(data_)) {
203 return conv::shift_sign(data_) | f16_traits::exponent_mask;
204 }
// NaN input: shifted sign combined with fully set exponent and significand.
else if (f32_traits::is_nan(data_)) {
205 return conv::shift_sign(data_) | f16_traits::exponent_mask |
206 f16_traits::significand_mask;
// Finite input: the exponent is converted first and inspected on its own.
208 const auto exp = conv::shift_exponent(data_);
// The converted exponent already denotes infinity: return sign and exponent.
209 if (f16_traits::is_inf(exp)) {
210 return conv::shift_sign(data_) | exp;
211 }
// is_denom reports true for the converted exponent: only the shifted sign is
// returned (presumably a signed zero, i.e. such values are flushed to zero).
else if (f16_traits::is_denom(exp)) {
213 return conv::shift_sign(data_);
// Regular value: truncated result assembled from sign, exponent and
// significand, before any rounding correction.
216 const auto result = conv::shift_sign(data_) | exp |
217 conv::shift_significand(data_);
// Low conv::significand_offset bits of the input; the name they are bound to
// is on a line not visible here (the rounding test below reads tail).
219 data_ & static_cast<f32_traits::bits_type>(
220 (1 << conv::significand_offset) - 1);
// Value with only the highest bit of that low-bit range set, i.e. the
// halfway point of the discarded part.
// NOTE(review): this local constant is named bfloat16 and shadows the class
// name inside this scope; consider a name that says halfway value.
222 constexpr
auto bfloat16 = static_cast<f32_traits::bits_type>(
223 1 << (conv::significand_offset - 1));
// Visible part of the rounding condition: an exact tie counts as round-up
// only when the truncated result is odd (presumably round-half-to-even; the
// first half of the expression is not visible here).
225 ((tail ==
bfloat16) && (result & 1)));
// Expands a 16-bit bfloat16 pattern into the 32-bit pattern of a float.
// Special values are handled first, every other value is assembled from the
// separately shifted sign, exponent and significand fields.
230 static constexpr std::uint32_t bfloat162float(std::uint16_t data_) noexcept
232 using conv = detail::precision_converter<bfloat16, float>;
// Infinite input: shifted sign combined with a fully set float exponent.
233 if (f16_traits::is_inf(data_)) {
234 return conv::shift_sign(data_) | f32_traits::exponent_mask;
235 }
// NaN input: shifted sign with fully set float exponent and significand.
else if (f16_traits::is_nan(data_)) {
236 return conv::shift_sign(data_) | f32_traits::exponent_mask |
237 f32_traits::significand_mask;
238 }
// is_denom reports true: only the shifted sign is returned (presumably a
// signed zero in float).
else if (f16_traits::is_denom(data_)) {
240 return conv::shift_sign(data_);
// Regular value: combine the three converted fields.
242 return conv::shift_sign(data_) | conv::shift_exponent(data_) |
243 conv::shift_significand(data_);
// Specialization of complex for gko::bfloat16. The enclosing namespace is not
// visible in this view; later code in this file refers to the type as
// std::complex<gko::bfloat16>.
258 class complex<
gko::bfloat16> {
// Constructs the number from its real and imaginary part; both arguments
// default to value_type(0.f). value_type is declared on a line not visible
// here.
262 complex(
const value_type& real = value_type(0.f),
263 const value_type& imag = value_type(0.f))
264 : real_(real), imag_(imag)
// Explicit two-argument constructor from values of other types. It only takes
// part in overload resolution when both T and U are scalar types or
// gko::half. Each argument is converted to value_type with static_cast.
268 typename T,
typename U,
269 typename = std::enable_if_t<
270 (std::is_scalar<T>::value || std::is_same_v<T, gko::half>)&&(
271 std::is_scalar<U>::value || std::is_same_v<U, gko::half>)>>
272 explicit complex(
const T& real,
const U& imag)
273 : real_(static_cast<value_type>(
real)),
274 imag_(static_cast<value_type>(
imag))
// Implicit (non-explicit) constructor from a single real value of a scalar
// type or gko::half. The real part is the converted argument, the imaginary
// part is set to value_type(0.f).
277 template <
typename T,
278 typename = std::enable_if_t<std::is_scalar<T>::value ||
279 std::is_same_v<T, gko::half>>>
280 complex(
const T& real)
281 : real_(static_cast<value_type>(
real)),
282 imag_(static_cast<value_type>(0.f))
// Explicit converting constructor from complex<T>, enabled when T is a scalar
// type or gko::half. Real and imaginary part are read through the accessors
// of other and converted to value_type one by one.
287 template <
typename T,
288 typename = std::enable_if_t<std::is_scalar<T>::value ||
289 std::is_same_v<T, gko::half>>>
290 explicit complex(
const complex<T>& other)
291 : real_(static_cast<value_type>(other.
real())),
292 imag_(static_cast<value_type>(other.
imag()))
// Returns a copy of the real part.
295 value_type
real() const noexcept {
return real_; }
// Returns a copy of the imaginary part.
297 value_type
imag() const noexcept {
return imag_; }
// Implicit conversion to std::complex<float>; each component is converted to
// float separately.
299 operator std::complex<float>() const noexcept
301 return std::complex<float>(static_cast<float>(real_),
302 static_cast<float>(imag_));
// Assignment from a single value of arbitrary type V. The visible part resets
// the imaginary part to a value-initialized value_type; the statement that
// handles the real part is on a line not visible in this view.
305 template <
typename V>
306 complex& operator=(
const V& val)
309 imag_ = value_type();
// Assignment from a std::complex with component type V. Only the signature is
// visible in this view; presumably both components are copied from val -
// TODO confirm against the body.
313 template <
typename V>
314 complex& operator=(
const std::complex<V>& val)
// Compound assignment operators that take a real scalar of value_type and
// return a reference to complex. Only the signatures are visible in this
// view; the bodies are on lines that are not shown.
321 complex& operator+=(
const value_type& real)
327 complex& operator-=(
const value_type& real)
333 complex& operator*=(
const value_type& real)
340 complex& operator/=(
const value_type& real)
// Compound addition and subtraction with a complex<T> operand of arbitrary
// component type T. Only the signatures are visible in this view.
347 template <
typename T>
348 complex& operator+=(
const complex<T>& val)
355 template <
typename T>
356 complex& operator-=(
const complex<T>& val)
// Compound multiplication with a complex<T> operand. Both the operand and
// this object are converted to std::complex<float>; the components of
// result_f are then written back into real_ and imag_. The statement between
// the conversions and the write-back is not visible in this view (presumably
// result_f *= val_f).
363 template <
typename T>
364 complex& operator*=(
const complex<T>& val)
366 auto val_f =
static_cast<std::complex<float>
>(val);
367 auto result_f =
static_cast<std::complex<float>
>(*this);
// Float components are narrowed back to value_type on assignment.
369 real_ = result_f.real();
370 imag_ = result_f.imag();
// Compound division by a complex<T> operand. Both the operand and this object
// are converted to std::complex<float>; the components of result_f are then
// written back into real_ and imag_. The statement between the conversions
// and the write-back is not visible in this view (presumably
// result_f /= val_f).
374 template <
typename T>
375 complex& operator/=(
const complex<T>& val)
377 auto val_f =
static_cast<std::complex<float>
>(val);
378 auto result_f =
static_cast<std::complex<float>
>(*this);
// Float components are narrowed back to value_type on assignment.
380 real_ = result_f.real();
381 imag_ = result_f.imag();
// Generates a friend binary operator for two complex operands. Only the
// signature line of the macro body is visible in this view; judging by the
// unused-here _opeq parameter the body presumably delegates to the compound
// assignment - TODO confirm. Instantiated for +, -, * and /, then undefined.
385 #define COMPLEX_BFLOAT16_OPERATOR(_op, _opeq) \
386 friend complex operator _op(const complex& lhf, const complex& rhf) \
393 COMPLEX_BFLOAT16_OPERATOR(+, +=)
394 COMPLEX_BFLOAT16_OPERATOR(-, -=)
395 COMPLEX_BFLOAT16_OPERATOR(*, *=)
396 COMPLEX_BFLOAT16_OPERATOR(/, /=)
398 #undef COMPLEX_BFLOAT16_OPERATOR
// numeric_limits specialization for gko::bfloat16: a signed, bounded,
// inexact, non-integer and non-modulo type.
407 struct numeric_limits<
gko::bfloat16> {
408 static constexpr
bool is_specialized{
true};
409 static constexpr
bool is_signed{
true};
410 static constexpr
bool is_integer{
false};
411 static constexpr
bool is_exact{
false};
412 static constexpr
bool is_bounded{
true};
413 static constexpr
bool is_modulo{
false};
// Stored significand bits plus one (presumably accounting for the implicit
// leading bit). With the 7 significand bits declared in
// basic_float_traits<bfloat16> this presumably evaluates to 8.
414 static constexpr
int digits{
415 gko::detail::float_traits<gko::bfloat16>::significand_bits + 1};
// Integer approximation of digits * log10(2) using 3 / 10; evaluates to 2
// when digits is 8.
417 static constexpr
int digits10{digits * 3 / 10};
// Bodies of six value functions whose signatures are not visible in this
// view. Each builds a bfloat16 from a literal written as sign, 8 exponent
// bits and 7 significand bits. The function names given below are inferred
// from the bit patterns - TODO confirm against the signatures.

// Sign 0, exponent field 01111000 (120), significand 0. Presumably epsilon,
// i.e. 2^-7 if the exponent bias is 127.
421 constexpr
auto bits = static_cast<std::uint16_t>(0b0
'01111000'0000000u);
422 return gko::bfloat16::create_from_bits(bits);
// Sign 0, all exponent bits set, significand 0: the pattern float2bfloat16
// produces for an infinite input. Presumably infinity.
427 constexpr
auto bits = static_cast<std::uint16_t>(0b0
'11111111'0000000u);
428 return gko::bfloat16::create_from_bits(bits);
// Sign 0, exponent field 00000001, significand 0. Presumably min, the
// smallest positive normal value.
433 constexpr
auto bits = static_cast<std::uint16_t>(0b0
'00000001'0000000u);
434 return gko::bfloat16::create_from_bits(bits);
// Sign 0, exponent field 11111110, all significand bits set. Presumably max,
// the largest finite value.
439 constexpr
auto bits = static_cast<std::uint16_t>(0b0
'11111110'1111111u);
440 return gko::bfloat16::create_from_bits(bits);
// Same magnitude as the previous pattern with the sign bit set. Presumably
// lowest, the most negative finite value.
445 constexpr
auto bits = static_cast<std::uint16_t>(0b1
'11111110'1111111u);
446 return gko::bfloat16::create_from_bits(bits);
// Sign 0, all exponent and significand bits set: the pattern float2bfloat16
// produces for a NaN input. Presumably a NaN value function.
451 constexpr
auto bits = static_cast<std::uint16_t>(0b0
'11111111'1111111u);
452 return gko::bfloat16::create_from_bits(bits);
// Out-of-class definition of the assignment from std::complex<gko::bfloat16>
// to complex<double>. A temporary complex<double> is built from the real and
// imaginary part of a; the rest of the body is not visible in this view.
460 inline complex<double>& complex<double>::operator=(
461 const std::complex<gko::bfloat16>& a)
463 complex<double> t(a.real(), a.imag());
// Out-of-class definition of the assignment from std::complex<gko::bfloat16>
// to complex<float>. A temporary complex<float> is built from the real and
// imaginary part of a; the rest of the body is not visible in this view.
471 inline complex<float>& complex<float>::operator=(
472 const std::complex<gko::bfloat16>& a)
474 complex<float> t(a.real(), a.imag());
483 #endif // GKO_PUBLIC_CORE_BASE_BFLOAT16_HPP_