From 259e79e3ff5ffcc53dfd7d92a09435362d573ffb Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Wed, 30 Jul 2025 08:11:59 -0700 Subject: [PATCH] Move Half to headeronly (#159172) Essence of this copypasta: - combine Half-inl.h and Half.h in c10/util -> torch/headeronly/util/Half.h - Add NOLINTNEXTLINE's to the portions of Half-inl.h that were previously in the ignore list of clangtidy - Re-expose all APIs in namespaces and through includes of the original files. Ideally, we would have the APIs in torch::headeronly and reexpose them in c10, but that runs into BC issues (see D78997465) so for now we are keeping the APIs in c10 but reexposing them in torch::headeronly. - Change test cases in test_aoti_abi_check to test torch::headeronly::Half vs c10::Half (they're the same thing but we eventually want all the tests for headeronly APIs to only import from headeronly). Pull Request resolved: https://github.com/pytorch/pytorch/pull/159172 Approved by: https://github.com/albanD, https://github.com/desertfire --- c10/util/Half-inl.h | 351 +--------------- c10/util/Half.h | 235 +---------- test/cpp/aoti_abi_check/test_dtype.cpp | 25 +- torch/headeronly/util/Half.h | 550 ++++++++++++++++++++++++- 4 files changed, 568 insertions(+), 593 deletions(-) diff --git a/c10/util/Half-inl.h b/c10/util/Half-inl.h index ae4469e5636..fe66779a0e5 100644 --- a/c10/util/Half-inl.h +++ b/c10/util/Half-inl.h @@ -1,350 +1 @@ -#pragma once - -#include -#include - -#include -#include - -#ifdef __CUDACC__ -#include -#endif - -#ifdef __HIPCC__ -#include -#endif - -#if defined(CL_SYCL_LANGUAGE_VERSION) -#include // for SYCL 1.2.1 -#elif defined(SYCL_LANGUAGE_VERSION) -#include // for SYCL 2020 -#endif - -#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ - !defined(__APPLE__) -#include -#endif - -C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") -C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") -#endif - -namespace c10 { - -#if defined(__aarch64__) && !defined(__CUDACC__) -/// Constructors -inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {} -inline Half::operator float16_t() const { - return detail::fp16_from_bits(x); -} -#else - -inline C10_HOST_DEVICE Half::Half(float value) - : -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) - x(__half_as_short(__float2half(value))) -#elif defined(__SYCL_DEVICE_ONLY__) - x(c10::bit_cast(sycl::half(value))) -#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ - !defined(__APPLE__) - x(at::vec::float2half_scalar(value)) -#else - x(detail::fp16_ieee_from_fp32_value(value)) -#endif -{ -} - -/// Implicit conversions - -inline C10_HOST_DEVICE Half::operator float() const { -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) - return __half2float(*reinterpret_cast(&x)); -#elif defined(__SYCL_DEVICE_ONLY__) - return float(c10::bit_cast(x)); -#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ - !defined(__APPLE__) - return at::vec::half2float_scalar(x); -#elif defined(__aarch64__) && !defined(__CUDACC__) - return detail::native_fp16_to_fp32_value(x); -#else - return detail::fp16_ieee_to_fp32_value(x); -#endif -} - -#endif /* !defined(__aarch64__) || defined(__CUDACC__) \ - */ - -#if defined(__CUDACC__) || defined(__HIPCC__) -inline C10_HOST_DEVICE Half::Half(const __half& value) { - x = *reinterpret_cast(&value); -} -inline C10_HOST_DEVICE Half::operator __half() const { - return *reinterpret_cast(&x); -} -#endif - -#ifdef SYCL_LANGUAGE_VERSION -inline C10_HOST_DEVICE Half::Half(const sycl::half& value) { - x = *reinterpret_cast(&value); -} -inline C10_HOST_DEVICE Half::operator sycl::half() const { - return *reinterpret_cast(&x); -} -#endif - -// CUDA intrinsics - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \ - (defined(__clang__) && defined(__CUDA__)) -inline __device__ Half __ldg(const Half* ptr) { - return __ldg(reinterpret_cast(ptr)); -} -#endif - -/// Arithmetic - -inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) { - return static_cast(a) + static_cast(b); -} - -inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) { - return static_cast(a) - static_cast(b); -} - -inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { - return static_cast(a) * static_cast(b); -} - -inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / static_cast(b); -} - -inline C10_HOST_DEVICE Half operator-(const Half& a) { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ - defined(__HIP_DEVICE_COMPILE__) - return __hneg(a); -#elif defined(__SYCL_DEVICE_ONLY__) - return -c10::bit_cast(a); -#else - return -static_cast(a); -#endif -} - -inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { - a = a + b; - return a; -} - -inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { - a = a - b; - return a; -} - -inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { - a = a * b; - return a; -} - -inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { - a = a / b; - return a; -} - -/// Arithmetic with floats - -inline C10_HOST_DEVICE float operator+(Half a, float b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE float operator-(Half a, float b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE float operator*(Half a, float b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE float operator/(Half a, float b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE float operator+(float a, Half b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE float operator-(float a, Half b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE float operator*(float a, Half b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE float operator/(float a, Half b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) { - return a += static_cast(b); -} -inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) { - return a -= static_cast(b); -} -inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) { - return a *= static_cast(b); -} -inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) { - return a /= static_cast(b); -} - -/// Arithmetic with doubles - -inline C10_HOST_DEVICE double operator+(Half a, double b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE double operator-(Half a, double b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE double operator*(Half a, double b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE double operator/(Half a, double b) - __ubsan_ignore_float_divide_by_zero__ { - return static_cast(a) / b; -} - -inline C10_HOST_DEVICE double operator+(double a, Half b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE double operator-(double a, Half b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE double operator*(double a, Half b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE double operator/(double a, Half b) - __ubsan_ignore_float_divide_by_zero__ { - return a / static_cast(b); -} - -/// Arithmetic with ints - -inline C10_HOST_DEVICE Half operator+(Half a, int b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Half operator-(Half a, int b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Half operator*(Half a, int b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Half operator/(Half a, int b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Half operator+(int a, Half b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Half operator-(int a, Half b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Half operator*(int a, Half b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Half operator/(int a, Half b) { - return static_cast(a) / b; -} - -//// Arithmetic with int64_t - -inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) { - return a + static_cast(b); -} -inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) { - return a - static_cast(b); -} -inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) { - return a * static_cast(b); -} -inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) { - return a / static_cast(b); -} - -inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) { - return static_cast(a) + b; -} -inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) { - return static_cast(a) - b; -} -inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) { - return static_cast(a) * b; -} -inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) { - return static_cast(a) / b; -} - -/// NOTE: we do not define comparisons directly and instead rely on the implicit -/// conversion from c10::Half to float. - -} // namespace c10 - -namespace std { - -template <> -class numeric_limits { - public: - static constexpr bool is_specialized = true; - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool is_exact = false; - static constexpr bool has_infinity = true; - static constexpr bool has_quiet_NaN = true; - static constexpr bool has_signaling_NaN = true; - static constexpr auto has_denorm = numeric_limits::has_denorm; - static constexpr auto has_denorm_loss = - numeric_limits::has_denorm_loss; - static constexpr auto round_style = numeric_limits::round_style; - static constexpr bool is_iec559 = true; - static constexpr bool is_bounded = true; - static constexpr bool is_modulo = false; - static constexpr int digits = 11; - static constexpr int digits10 = 3; - static constexpr int max_digits10 = 5; - static constexpr int radix = 2; - static constexpr int min_exponent = -13; - static constexpr int min_exponent10 = -4; - static constexpr int max_exponent = 16; - static constexpr int max_exponent10 = 4; - static constexpr auto traps = numeric_limits::traps; - static constexpr auto tinyness_before = - numeric_limits::tinyness_before; - static constexpr c10::Half min() { - return c10::Half(0x0400, c10::Half::from_bits()); - } - static constexpr c10::Half lowest() { - return c10::Half(0xFBFF, c10::Half::from_bits()); - } - static constexpr c10::Half max() { - return c10::Half(0x7BFF, c10::Half::from_bits()); - } - static constexpr c10::Half epsilon() { - return c10::Half(0x1400, c10::Half::from_bits()); - } - static constexpr c10::Half round_error() { - return c10::Half(0x3800, c10::Half::from_bits()); - } - static constexpr c10::Half infinity() { - return c10::Half(0x7C00, c10::Half::from_bits()); - } - static constexpr c10::Half quiet_NaN() { - return c10::Half(0x7E00, c10::Half::from_bits()); - } - static constexpr c10::Half signaling_NaN() { - return c10::Half(0x7D00, c10::Half::from_bits()); - } - static constexpr c10::Half denorm_min() { - return c10::Half(0x0001, c10::Half::from_bits()); - } -}; - -} // namespace std - -C10_CLANG_DIAGNOSTIC_POP() +#include diff --git a/c10/util/Half.h b/c10/util/Half.h index 9ac1c898013..98480b22db3 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -1,233 +1,8 @@ -#pragma once - -/// Defines the Half type (half-precision floating-point) including conversions -/// to standard C types and basic arithmetic operations. Note that arithmetic -/// operations are implemented by converting to floating point and -/// performing the operation in float32, instead of using CUDA half intrinsics. -/// Most uses of this type within ATen are memory bound, including the -/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. -/// If you are writing a compute bound kernel, you can use the CUDA half -/// intrinsics directly on the Half type from device code. - -#include -#include -#include -#include #include -#include -#if defined(__cplusplus) -#include -#elif !defined(__OPENCL_VERSION__) -#include +// need to keep the following for BC because the APIs in here were exposed +// before migrating Half to torch/headeronly +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +#include #endif - -#ifdef _MSC_VER -#include -#endif - -#include -#include -#include -#include -#include - -#ifdef __CUDACC__ -#include -#endif - -#ifdef __HIPCC__ -#include -#endif - -#if defined(CL_SYCL_LANGUAGE_VERSION) -#include // for SYCL 1.2.1 -#elif defined(SYCL_LANGUAGE_VERSION) -#include // for SYCL 2020 -#endif - -#if defined(__aarch64__) && !defined(__CUDACC__) -#include -#endif - -#if defined(__GNUC__) || defined(__clang__) -#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \ - defined(_M_IX86) -#if defined(__F16C__) && \ - !(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \ - defined(__HIP_DEVICE_COMPILE__)) -#define C10_X86_F16 1 -#include // import conversion ops from f16cintrin.h -#endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__) - // || defined(__HIP_DEVICE_COMPILE__)) -#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86 -#endif // __GNUC__ || __clang__ - -namespace c10 { - -namespace detail { - -/* - * Convert a 16-bit floating-point number in IEEE half-precision format, in bit - * representation, to a 32-bit floating-point number in IEEE single-precision - * format, in bit representation. - * - * @note The implementation doesn't use any floating-point operations. - */ -inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { - /* - * Extend the half-precision floating-point number to 32 bits and shift to the - * upper part of the 32-bit word: - * +---+-----+------------+-------------------+ - * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| - * +---+-----+------------+-------------------+ - * Bits 31 26-30 16-25 0-15 - * - * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - * - zero bits. - */ - const uint32_t w = (uint32_t)h << 16; - /* - * Extract the sign of the input number into the high bit of the 32-bit word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = w & UINT32_C(0x80000000); - /* - * Extract mantissa and biased exponent of the input number into the bits 0-30 - * of the 32-bit word: - * - * +---+-----+------------+-------------------+ - * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| - * +---+-----+------------+-------------------+ - * Bits 30 27-31 17-26 0-16 - */ - const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); - /* - * Renorm shift is the number of bits to shift mantissa left to make the - * half-precision number normalized. If the initial number is normalized, some - * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case - * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note - * that if we shift denormalized nonsign by renorm_shift, the unit bit of - * mantissa will shift into exponent, turning the biased exponent into 1, and - * making mantissa normalized (i.e. without leading 1). - */ -#ifdef _MSC_VER - unsigned long nonsign_bsr; - _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); - uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; -#else - uint32_t renorm_shift = __builtin_clz(nonsign); -#endif - renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; - /* - * Iff half-precision number has exponent of 15, the addition overflows - * it into bit 31, and the subsequent shift turns the high 9 bits - * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number - * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise - */ - const int32_t inf_nan_mask = - ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); - /* - * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 - * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 - * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == - * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) - * 0x00000000 otherwise - */ - const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; - /* - * 1. Shift nonsign left by renorm_shift to normalize it (if the input - * was denormal) - * 2. Shift nonsign right by 3 so the exponent (5 bits originally) - * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high - * bits of the 23-bit mantissa of IEEE single-precision number. - * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the - * different in exponent bias (0x7F for single-precision number less 0xF - * for half-precision number). - * 4. Subtract renorm_shift from the exponent (starting at bit 23) to - * account for renormalization. As renorm_shift is less than 0x70, this - * can be combined with step 3. - * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the - * input was NaN or infinity. - * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent - * into zero if the input was zero. - * 7. Combine with the sign of the input number. - */ - return sign | - ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | - inf_nan_mask) & - ~zero_mask); -} - -#ifdef C10_X86_F16 -#undef C10_X86_F16 -#endif // C10_X86_F16 - -#if defined(__aarch64__) && !defined(__CUDACC__) -inline float16_t fp16_from_bits(uint16_t h) { - return c10::bit_cast(h); -} - -inline uint16_t fp16_to_bits(float16_t f) { - return c10::bit_cast(f); -} - -// According to https://godbolt.org/z/frExdbsWG it would translate to single -// fcvt s0, h0 -inline float native_fp16_to_fp32_value(uint16_t h) { - return static_cast(fp16_from_bits(h)); -} - -inline uint16_t native_fp16_from_fp32_value(float f) { - return fp16_to_bits(static_cast(f)); -} -#endif - -} // namespace detail - -struct alignas(2) Half { - unsigned short x; - - struct from_bits_t {}; - C10_HOST_DEVICE static constexpr from_bits_t from_bits() { - return from_bits_t(); - } - - // HIP wants __host__ __device__ tag, CUDA does not -#if defined(USE_ROCM) - C10_HOST_DEVICE Half() = default; -#else - Half() = default; -#endif - - constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {} -#if defined(__aarch64__) && !defined(__CUDACC__) - inline Half(float16_t value); - inline operator float16_t() const; -#else - inline C10_HOST_DEVICE Half(float value); - inline C10_HOST_DEVICE operator float() const; -#endif - -#if defined(__CUDACC__) || defined(__HIPCC__) - inline C10_HOST_DEVICE Half(const __half& value); - inline C10_HOST_DEVICE operator __half() const; -#endif -#ifdef SYCL_LANGUAGE_VERSION - inline C10_HOST_DEVICE Half(const sycl::half& value); - inline C10_HOST_DEVICE operator sycl::half() const; -#endif -}; - -inline std::ostream& operator<<(std::ostream& out, const Half& value) { - out << (float)value; - return out; -} - -} // namespace c10 - -#include // IWYU pragma: keep diff --git a/test/cpp/aoti_abi_check/test_dtype.cpp b/test/cpp/aoti_abi_check/test_dtype.cpp index ad4393b3cce..25385de5b10 100644 --- a/test/cpp/aoti_abi_check/test_dtype.cpp +++ b/test/cpp/aoti_abi_check/test_dtype.cpp @@ -6,10 +6,10 @@ #include #include #include -#include #include #include +#include #include #include #include @@ -93,17 +93,28 @@ TEST(TestDtype, TestFloat4) { } TEST(TestDtype, TestHalf) { - c10::Half a = 1.0f; - c10::Half b = 2.0f; - c10::Half add = 3.0f; - c10::Half sub = -1.0f; - c10::Half mul = 2.0f; - c10::Half div = 0.5f; + torch::headeronly::Half a = 1.0f; + torch::headeronly::Half b = 2.0f; + torch::headeronly::Half add = 3.0f; + torch::headeronly::Half sub = -1.0f; + torch::headeronly::Half mul = 2.0f; + torch::headeronly::Half div = 0.5f; EXPECT_EQ(a + b, add); EXPECT_EQ(a - b, sub); EXPECT_EQ(a * b, mul); EXPECT_EQ(a / b, div); + EXPECT_EQ(a += b, add); + EXPECT_EQ(a -= b, add - b); + EXPECT_EQ(a *= b, b); + EXPECT_EQ(a /= b, mul * div); + +#if defined(__aarch64__) && !defined(__CUDACC__) + EXPECT_EQ( + torch::headeronly::detail::fp16_to_bits( + torch::headeronly::detail::fp16_from_bits(32)), + 32); +#endif } TEST(TestDtype, TestComplexFloat) { diff --git a/torch/headeronly/util/Half.h b/torch/headeronly/util/Half.h index 5c9c15d2743..59a86f07e33 100644 --- a/torch/headeronly/util/Half.h +++ b/torch/headeronly/util/Half.h @@ -1,5 +1,14 @@ #pragma once +/// Defines the Half type (half-precision floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32, instead of using CUDA half intrinsics. +/// Most uses of this type within ATen are memory bound, including the +/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. +/// If you are writing a compute bound kernel, you can use the CUDA half +/// intrinsics directly on the Half type from device code. + #include #include #include @@ -16,6 +25,7 @@ #include #include +#include #ifdef __CUDACC__ #include @@ -31,6 +41,11 @@ #include // for SYCL 2020 #endif +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) +#include +#endif + #if defined(__aarch64__) && !defined(__CUDACC__) #include #endif @@ -48,7 +63,48 @@ #endif // __x86_64__ || _M_X64 || __i386 || _M_IX86 #endif // __GNUC__ || __clang__ -namespace torch::headeronly::detail { +namespace c10 { + +struct alignas(2) Half { + unsigned short x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + // HIP wants __host__ __device__ tag, CUDA does not +#if defined(USE_ROCM) + C10_HOST_DEVICE Half() = default; +#else + Half() = default; +#endif + + constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {} +#if defined(__aarch64__) && !defined(__CUDACC__) + inline Half(float16_t value); + inline operator float16_t() const; +#else + inline C10_HOST_DEVICE Half(float value); + inline C10_HOST_DEVICE operator float() const; +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_HOST_DEVICE Half(const __half& value); + inline C10_HOST_DEVICE operator __half() const; +#endif +#ifdef SYCL_LANGUAGE_VERSION + inline C10_HOST_DEVICE Half(const sycl::half& value); + inline C10_HOST_DEVICE operator sycl::half() const; +#endif +}; + +inline std::ostream& operator<<(std::ostream& out, const Half& value) { + out << (float)value; + return out; +} + +namespace detail { /* * Convert a 16-bit floating-point number in IEEE half-precision format, in bit * representation, to a 32-bit floating-point number in IEEE single-precision @@ -241,9 +297,491 @@ inline uint16_t fp16_ieee_from_fp32_value(float f) { #endif // C10_X86_F16 } -} // namespace torch::headeronly::detail +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit + * representation, to a 32-bit floating-point number in IEEE single-precision + * format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the + * upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 + * - zero bits. + */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 + * of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the + * half-precision number normalized. If the initial number is normalized, some + * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case + * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note + * that if we shift denormalized nonsign by renorm_shift, the unit bit of + * mantissa will shift into exponent, turning the biased exponent into 1, and + * making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the + * different in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); +} -namespace c10::detail { -using torch::headeronly::detail::fp16_ieee_from_fp32_value; -using torch::headeronly::detail::fp16_ieee_to_fp32_value; -} // namespace c10::detail +#ifdef C10_X86_F16 +#undef C10_X86_F16 +#endif // C10_X86_F16 + +#if defined(__aarch64__) && !defined(__CUDACC__) +inline float16_t fp16_from_bits(uint16_t h) { + return c10::bit_cast(h); +} + +inline uint16_t fp16_to_bits(float16_t f) { + return c10::bit_cast(f); +} + +// According to https://godbolt.org/z/frExdbsWG it would translate to single +// fcvt s0, h0 +inline float native_fp16_to_fp32_value(uint16_t h) { + return static_cast(fp16_from_bits(h)); +} + +inline uint16_t native_fp16_from_fp32_value(float f) { + return fp16_to_bits(static_cast(f)); +} +#endif + +} // namespace detail + +//---------- below is copied from c10/util/Half-inl.h ----------------// +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +#if defined(__aarch64__) && !defined(__CUDACC__) +/// Constructors +inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {} +inline Half::operator float16_t() const { + return detail::fp16_from_bits(x); +} +#else + +inline C10_HOST_DEVICE Half::Half(float value) + : +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + x(__half_as_short(__float2half(value))) +#elif defined(__SYCL_DEVICE_ONLY__) + x(c10::bit_cast(sycl::half(value))) +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + x(at::vec::float2half_scalar(value)) +#else + x(detail::fp16_ieee_from_fp32_value(value)) +#endif +{ +} + +/// Implicit conversions + +inline C10_HOST_DEVICE Half::operator float() const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return __half2float(*reinterpret_cast(&x)); +#elif defined(__SYCL_DEVICE_ONLY__) + return float(c10::bit_cast(x)); +#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \ + !defined(__APPLE__) + return at::vec::half2float_scalar(x); +#elif defined(__aarch64__) && !defined(__CUDACC__) + return detail::native_fp16_to_fp32_value(x); +#else + return detail::fp16_ieee_to_fp32_value(x); +#endif +} + +#endif /* !defined(__aarch64__) || defined(__CUDACC__) \ + */ + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline C10_HOST_DEVICE Half::Half(const __half& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE Half::operator __half() const { + return *reinterpret_cast(&x); +} +#endif + +#ifdef SYCL_LANGUAGE_VERSION +inline C10_HOST_DEVICE Half::Half(const sycl::half& value) { + x = *reinterpret_cast(&value); +} +inline C10_HOST_DEVICE Half::operator sycl::half() const { + return *reinterpret_cast(&x); +} +#endif + +// CUDA intrinsics + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \ + (defined(__clang__) && defined(__CUDA__)) +inline __device__ Half __ldg(const Half* ptr) { + return __ldg(reinterpret_cast(ptr)); +} +#endif + +/// Arithmetic + +inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator-(const Half& a) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \ + defined(__HIP_DEVICE_COMPILE__) + return __hneg(a); +#elif defined(__SYCL_DEVICE_ONLY__) + return -c10::bit_cast(a); +#else + return -static_cast(a); +#endif +} + +inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Half a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Half a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Half a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Half a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Half b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Half b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Half b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Half a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Half a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Half a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Half a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Half b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Half b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Half b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Half b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Half operator+(Half a, int b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Half operator-(Half a, int b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Half operator*(Half a, int b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Half operator/(Half a, int b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator+(int a, Half b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Half operator-(int a, Half b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Half operator*(int a, Half b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Half operator/(int a, Half b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return a + static_cast(b); +} +inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return a - static_cast(b); +} +inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return a * static_cast(b); +} +inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) { + // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Half to float. + +C10_CLANG_DIAGNOSTIC_POP() + +} // namespace c10 + +namespace torch::headeronly { + +using c10::Half; +using c10::operator+; +using c10::operator-; +using c10::operator*; +using c10::operator/; +using c10::operator+=; +using c10::operator-=; +using c10::operator*=; +using c10::operator/=; +using c10::operator<<; + +namespace detail { +#if defined(__aarch64__) && !defined(__CUDACC__) +using c10::detail::fp16_from_bits; +using c10::detail::fp16_to_bits; +using c10::detail::native_fp16_from_fp32_value; +using c10::detail::native_fp16_to_fp32_value; +#endif + +using c10::detail::fp16_ieee_from_fp32_value; +using c10::detail::fp16_ieee_to_fp32_bits; +using c10::detail::fp16_ieee_to_fp32_value; +} // namespace detail + +} // namespace torch::headeronly + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = true; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 11; + static constexpr int digits10 = 3; + static constexpr int max_digits10 = 5; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + static constexpr c10::Half min() { + return c10::Half(0x0400, c10::Half::from_bits()); + } + static constexpr c10::Half lowest() { + return c10::Half(0xFBFF, c10::Half::from_bits()); + } + static constexpr c10::Half max() { + return c10::Half(0x7BFF, c10::Half::from_bits()); + } + static constexpr c10::Half epsilon() { + return c10::Half(0x1400, c10::Half::from_bits()); + } + static constexpr c10::Half round_error() { + return c10::Half(0x3800, c10::Half::from_bits()); + } + static constexpr c10::Half infinity() { + return c10::Half(0x7C00, c10::Half::from_bits()); + } + static constexpr c10::Half quiet_NaN() { + return c10::Half(0x7E00, c10::Half::from_bits()); + } + static constexpr c10::Half signaling_NaN() { + return c10::Half(0x7D00, c10::Half::from_bits()); + } + static constexpr c10::Half denorm_min() { + return c10::Half(0x0001, c10::Half::from_bits()); + } +}; + +} // namespace std