From 382fbcc1e43ae5d46ec148bdfd8dcfb73da81b77 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 20 Feb 2025 13:55:42 +0000
Subject: [PATCH] add the `torch.float8_e8m0fnu` dtype to PyTorch (#147466)

Summary:

Continuing the work from https://github.com/pytorch/pytorch/pull/146427

Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in
https://github.com/pytorch/pytorch/issues/146414 . Please see the issue for a
detailed definition of the format.

Example of basic functionality:

```python
import torch

# round trip
x0 = torch.randn(4, 4, dtype=torch.float32)
x1 = x0.to(torch.float8_e8m0fnu)  # RNE rounding
x2 = x1.to(torch.float32)  # 2 ** exponent

# creation with empty
x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu)

# printing
print(x0)
```

Done in this PR:
* numerical correctness
* op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32
* printing a tensor works

For future PRs:
* performance optimizations for casting
* torch._scaled_mm
* PT2
* various cleanups (detailed in comments with issue numbers)

Test Plan:

```
pytest test/quantization/core/experimental/test_float8.py -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147466
Approved by: https://github.com/drisspg
---
 aten/src/ATen/DLConvertor.cpp            |   2 +
 aten/src/ATen/Dispatch_v2.h              |   2 +-
 aten/src/ATen/native/Copy.cpp            |   4 +-
 aten/src/ATen/native/TensorCompare.cpp   |   3 +-
 aten/src/ATen/native/cpu/CopyKernel.cpp  |   8 +-
 aten/src/ATen/native/cpu/FillKernel.cpp  |   3 +
 aten/src/ATen/native/cpu/IndexKernel.cpp |   8 +-
 aten/src/ATen/native/cuda/Copy.cu        |  24 ++-
 aten/src/ATen/native/cuda/Indexing.cu    |  40 ++++-
 aten/src/ATen/native/cuda/jit_utils.h    |   4 +
 c10/core/Scalar.h                        |  11 +-
 c10/core/ScalarType.cpp                  |   3 +
 c10/core/ScalarType.h                    |  22 ++-
 c10/util/Float8_e8m0fnu-inl.h            | 112 ++++++++++++++
 c10/util/Float8_e8m0fnu.cpp              |  12 ++
 c10/util/Float8_e8m0fnu.h                | 120 +++++++++++++++
 c10/util/TypeCast.h                      |  14 ++
 .../core/experimental/test_float8.py     | 142 +++++++++++++++++-
 tools/pyi/gen_pyi.py                     |   3 +-
 torch/_tensor_str.py                     |  10 ++
 torch/csrc/TypeInfo.cpp                  |  15 +-
 torch/csrc/utils/python_scalars.h        |  13 +-
 torch/storage.py                         |   1 +
 torchgen/api/types/types.py              |   2 +
 torchgen/model.py                        |   1 +
 25 files changed, 535 insertions(+), 44 deletions(-)
 create mode 100644 c10/util/Float8_e8m0fnu-inl.h
 create mode 100644 c10/util/Float8_e8m0fnu.cpp
 create mode 100644 c10/util/Float8_e8m0fnu.h

diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index 137cb8456d7..2d16299c780 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -63,10 +63,12 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::BFloat16:
       dtype.code = DLDataTypeCode::kDLBfloat;
       break;
+    // TODO(#146647): use macro here instead of spelling out each shell dtype
     case ScalarType::Float8_e5m2:
     case ScalarType::Float8_e5m2fnuz:
     case ScalarType::Float8_e4m3fn:
     case ScalarType::Float8_e4m3fnuz:
+    case ScalarType::Float8_e8m0fnu:
       TORCH_CHECK(false, "float8 types are not supported by dlpack");
       break;
     case ScalarType::QInt8:
diff --git a/aten/src/ATen/Dispatch_v2.h b/aten/src/ATen/Dispatch_v2.h
index 31dd12f8de9..d0b77220fae 100644
--- a/aten/src/ATen/Dispatch_v2.h
+++ b/aten/src/ATen/Dispatch_v2.h
@@ -87,7 +87,7 @@
 #define AT_FLOAT8_TYPES                                          \
   c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
-      c10::kFloat8_e4m3fnuz
+      c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
 
 #define AT_INTEGRAL_TYPES \
   c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
 
diff --git
a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 5517d945748..4cd46f3b002 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -59,8 +59,8 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) { #if !defined(C10_MOBILE) #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_V2( \ - TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \ - kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) + TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, \ + AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #else #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 862b5fdaa25..f37376b5fc8 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -460,7 +460,8 @@ Tensor isinf(const Tensor& self) { Tensor isfinite(const Tensor& self) { // Note: Integral tensor values are always finite - if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) { + if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) || + self.scalar_type() == kFloat8_e8m0fnu) { return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve); } diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index 3992490ff8a..78651bca746 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -204,12 +204,12 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \ kComplexHalf, kHalf, kBool, \ - kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \ - kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) + kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \ + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \ AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \ - kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \ - kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) + kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \ + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #else #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) 
\ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ diff --git a/aten/src/ATen/native/cpu/FillKernel.cpp b/aten/src/ATen/native/cpu/FillKernel.cpp index e059636a43c..e22df01635f 100644 --- a/aten/src/ATen/native/cpu/FillKernel.cpp +++ b/aten/src/ATen/native/cpu/FillKernel.cpp @@ -51,6 +51,9 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) { fill_non_native_type(iter, value_scalar); } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) { fill_non_native_type(iter, value_scalar); + } else if (iter.dtype() == ScalarType::Float8_e8m0fnu) { + // TODO(#146647): use macro here instead of spelling out each float8 dtype + fill_non_native_type(iter, value_scalar); } else { AT_DISPATCH_V2( iter.dtype(), "fill_cpu", AT_WRAP([&]() { diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index 966105b29e4..1e6723b5f08 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -184,7 +184,13 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef } }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), - AT_EXPAND(AT_FLOAT8_TYPES), + // AT_EXPAND(AT_FLOAT8_TYPES), + // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True + // should not be supported here, then reenable AT_FLOAT8_DTYPES + kFloat8_e4m3fn, + kFloat8_e5m2, + kFloat8_e4m3fnuz, + kFloat8_e5m2fnuz, kComplexHalf, kHalf, kBool, diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 0113d9f0e33..7e6b66d8e06 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -144,6 +144,28 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) { gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; }); break; } + } else if (dtype == kFloat8_e8m0fnu) { + // TODO(#146647): clean this up, too much copy-pasta + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { + return Float8_e8m0fnu(value); + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e8m0fnu(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e8m0fnu(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e8m0fnu x) { return x; }); + break; + } } else { TORCH_CHECK(false, "This supposed ot be called only for Float8 types"); } @@ -157,7 +179,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) { AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] { gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); }); - } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) { + } else if (isFloat8Type(dtype)) { float8_copy_kernel_cuda(iter); } else if (iter.dtype(1) == kFloat && (dtype == kBFloat16 || dtype == kHalf)) { if (dtype == kBFloat16) { diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index ecea5e08f6b..dbd49cd4c1c 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -582,7 +582,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List inline std::string typeName() { template <> inline std::string typeName() { return "at::Float8_e4m3fnuz"; } +template <> inline std::string typeName() { + // TODO(#146647): Can the code here be made generic for any scalartype? 
+ return "at::Float8_e8m0fnu"; +} #define TYPE_NAME_CASE(ctype, scalartype) \ case ScalarType::scalartype: return typeName(); diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 9d1dad2d993..2a40114573c 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -49,16 +49,9 @@ class C10_API Scalar { #define DEFINE_IMPLICIT_CTOR(type, name) \ Scalar(type vv) : Scalar(vv, true) {} - AT_FORALL_SCALAR_TYPES_AND7( - Half, - BFloat16, - Float8_e5m2, - Float8_e4m3fn, - Float8_e5m2fnuz, - Float8_e4m3fnuz, - ComplexHalf, - DEFINE_IMPLICIT_CTOR) + AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR) AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR) + AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR) // Helper constructors to allow Scalar creation from long and long long types // As std::is_same_v is false(except Android), one needs to diff --git a/c10/core/ScalarType.cpp b/c10/core/ScalarType.cpp index e3fe4b07532..d00d0240a29 100644 --- a/c10/core/ScalarType.cpp +++ b/c10/core/ScalarType.cpp @@ -222,6 +222,9 @@ std::pair getDtypeNames(c10::ScalarType scalarType) { return std::make_pair("float8_e5m2fnuz", ""); case c10::ScalarType::Float8_e4m3fnuz: return std::make_pair("float8_e4m3fnuz", ""); + case c10::ScalarType::Float8_e8m0fnu: + // TODO(#146647): macroify all of this + return std::make_pair("float8_e8m0fnu", ""); default: throw std::runtime_error("Unimplemented scalar type"); } diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index fa0ef9be841..32ae5aaee8f 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -102,7 +103,8 @@ struct dummy_int1_7_t {}; _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ - _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ + _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \ + _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ // If you want to support ComplexHalf for real, add ComplexHalf // into this macro (and change the name). 
But beware: convert() @@ -146,7 +148,8 @@ struct dummy_int1_7_t {}; _(at::Float8_e5m2, Float8_e5m2) \ _(at::Float8_e4m3fn, Float8_e4m3fn) \ _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \ - _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) + _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \ + _(at::Float8_e8m0fnu, Float8_e8m0fnu) enum class ScalarType : int8_t { #define DEFINE_ST_ENUM_VAL_(_1, n) n, @@ -317,6 +320,13 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) _(c10::quint4x2, QUInt4x2) \ _(c10::quint2x4, QUInt2x4) +#define AT_FORALL_FLOAT8_TYPES(_) \ + _(at::Float8_e5m2, Float8_e5m2) \ + _(at::Float8_e4m3fn, Float8_e4m3fn) \ + _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \ + _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \ + _(at::Float8_e8m0fnu, Float8_e8m0fnu) + #define AT_FORALL_COMPLEX_TYPES(_) \ _(c10::complex, ComplexFloat) \ _(c10::complex, ComplexDouble) @@ -372,7 +382,8 @@ inline bool isIntegralType(ScalarType t) { inline bool isFloat8Type(ScalarType t) { return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz || - t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz; + t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz || + t == ScalarType::Float8_e8m0fnu; } inline bool isReducedFloatingType(ScalarType t) { @@ -446,6 +457,10 @@ inline bool isSignedType(ScalarType t) { return std::numeric_limits< \ ::c10::impl::ScalarTypeToCPPTypeT>::is_signed; + // TODO(#146647): If we expect to have numeric_limits for everything, + // let's just have a big macro for the whole thing. + // If we're hardcoding it, let's just use the macro and a "true"/"false" + // below? switch (t) { case ScalarType::QInt8: case ScalarType::QUInt8: @@ -467,6 +482,7 @@ inline bool isSignedType(ScalarType t) { CASE_ISSIGNED(Float8_e5m2fnuz); CASE_ISSIGNED(Float8_e4m3fn); CASE_ISSIGNED(Float8_e4m3fnuz); + CASE_ISSIGNED(Float8_e8m0fnu); CASE_ISSIGNED(Byte); CASE_ISSIGNED(Char); CASE_ISSIGNED(Short); diff --git a/c10/util/Float8_e8m0fnu-inl.h b/c10/util/Float8_e8m0fnu-inl.h new file mode 100644 index 00000000000..7d67934abd1 --- /dev/null +++ b/c10/util/Float8_e8m0fnu-inl.h @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include +#include + +// TODO(#146647): Can we remove the below warning? +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value) + : x(detail::fp8e8m0fnu_from_fp32_value(value)) {} + +/// Implicit conversions + +inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const { + // TODO(#146647): maybe rewrite without control flow + + // if exponent is zero, need to special case to return 2^-127 instead of zero + if (x == 0) { + return c10::detail::fp32_from_bits(0x00400000); + } + + // if exponent is NaN, need to special case to return properly encoded NaN + if (isnan()) { + return c10::detail::fp32_from_bits(0x7f800001); + } + + // leave sign at 0, set the exponent bits, leave stored mantissa at 0 + uint32_t res = x << 23; + + return c10::detail::fp32_from_bits(res); +} + +/// Special values helper + +inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const { + return x == 0b11111111; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e8m0fnu to float. 
+ +} // namespace c10 + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = false; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = false; + static constexpr auto has_denorm_loss = false; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 1; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 1; // just a 2! + static constexpr int radix = 2; + static constexpr int min_exponent = -126; + static constexpr int min_exponent10 = -38; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = false; + + static constexpr c10::Float8_e8m0fnu min() { + // 2^-127 + return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu lowest() { + // 2^-127 + return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu max() { + // 254 biased, which is 127 unbiased, so 2^127 + return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu epsilon() { + // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this + // is "the difference between 1.0 and the next representable value of the + // given floating-point type". The next representable value is 2.0, so the + // difference is 1.0 which is 2^0. 0 unbiased is 127 biased. + return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu round_error() { + // 0.5 in float, which is 2^-1, and -1 + 127 = 126 + return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits()); + } + static constexpr c10::Float8_e8m0fnu quiet_NaN() { + return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits()); + } +}; + +} // namespace std + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/c10/util/Float8_e8m0fnu.cpp b/c10/util/Float8_e8m0fnu.cpp new file mode 100644 index 00000000000..5787d2637de --- /dev/null +++ b/c10/util/Float8_e8m0fnu.cpp @@ -0,0 +1,12 @@ +#include +#include + +namespace c10 { + +// TODO(#146647): Can we have these in a single shared cpp file +// built with macro to remove the need for a new cpp file? +static_assert( + std::is_standard_layout_v, + "c10::Float8_e8m0fnu must be standard layout."); + +} // namespace c10 diff --git a/c10/util/Float8_e8m0fnu.h b/c10/util/Float8_e8m0fnu.h new file mode 100644 index 00000000000..91db8409174 --- /dev/null +++ b/c10/util/Float8_e8m0fnu.h @@ -0,0 +1,120 @@ +#pragma once + +/// Defines the Float8_e8m0fnu type (8-bit floating-point) including +/// conversions to standard C types +/// Binary configuration : +/// eeeeeeee +/// no sign bits +/// 8 exponent bits +/// no mantissa bits +/// +/// This is the E8M0 dtype from the OCP MX format spec +/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, +/// Section 5.4.1) + +#include +#include +#include +#include + +// TODO(#146647): do we need to special case OPENCL? 
+#if defined(__cplusplus) +#include +#elif !defined(__OPENCL_VERSION__) +#include +#include +#endif + +#include +#include + +namespace c10 { + +namespace detail { + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a + * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation. + */ +inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) { + // TODO(#146647): maybe rewrite without control flow + + uint32_t f_bits = c10::detail::fp32_to_bits(f); + + // extract the exponent + uint32_t exponent = (f_bits >> 23) & 0b11111111; + + // special case float32 NaN and +-inf to map to e8m0 nan + if (exponent == 0b11111111) { + return exponent; + } + + // next, we use guard, round, sticky bits and the LSB to implement round to + // nearest, with ties to even + + // guard bit - bit 23, or 22 zero-indexed + uint8_t g = (f_bits & 0x400000) > 0; + // round bit - bit 22, or 21 zero-indexed + uint8_t r = (f_bits & 0x200000) > 0; + // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed + uint8_t s = (f_bits & 0x1FFFFF) > 0; + // in casting to e8m0, LSB is the implied mantissa bit. It equals to 0 if the + // original float32 is denormal, and to 1 if the original float32 is normal. + uint8_t lsb = exponent > 0; + + // implement the RNE logic + bool round_up = false; + + // if g == 0, round down (no-op) + if (g == 1) { + if ((r == 1) || (s == 1)) { + // round up + round_up = true; + } else { + if (lsb == 1) { + // round up + round_up = true; + } + // if lsb == 0, round down (no-op) + } + } + + if (round_up) { + // adjust exponent + // note that if exponent was 255 we would have already returned earlier, so + // we know we can add one safely without running out of bounds + exponent++; + } + + return exponent; +} + +} // namespace detail + +struct alignas(1) Float8_e8m0fnu { + uint8_t x; + + struct from_bits_t {}; + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + Float8_e8m0fnu() = default; + + constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t) + : x(bits) {} + inline C10_HOST_DEVICE Float8_e8m0fnu(float value); + inline C10_HOST_DEVICE operator float() const; + inline C10_HOST_DEVICE bool isnan() const; +}; + +C10_API inline std::ostream& operator<<( + std::ostream& out, + const Float8_e8m0fnu& value) { + out << (float)value; + return out; +} + +} // namespace c10 + +#include // IWYU pragma: keep diff --git a/c10/util/TypeCast.h b/c10/util/TypeCast.h index 7406b83f51f..3291fce2c41 100644 --- a/c10/util/TypeCast.h +++ b/c10/util/TypeCast.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -151,6 +152,19 @@ struct static_cast_with_inter_type< } }; +// TODO(#146647): Can we make all these template specialization happen +// based off our apply macros? 
+template <> +struct static_cast_with_inter_type< + c10::complex, + c10::Float8_e8m0fnu> { + C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< + c10::Half> + apply(c10::Float8_e8m0fnu src) { + return static_cast>(c10::complex{src}); + } +}; + template <> struct static_cast_with_inter_type, c10::Half> { C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< diff --git a/test/quantization/core/experimental/test_float8.py b/test/quantization/core/experimental/test_float8.py index e6b40d3edc1..1c4956d551a 100644 --- a/test/quantization/core/experimental/test_float8.py +++ b/test/quantization/core/experimental/test_float8.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: quantization"] +import struct import unittest import torch @@ -14,6 +15,7 @@ from torch.testing._internal.common_utils import ( parametrize, run_tests, subtest, + TemporaryFileName, TestCase, ) @@ -23,11 +25,13 @@ FLOAT8_DTYPES = [ torch.float8_e5m2fnuz, torch.float8_e4m3fn, torch.float8_e4m3fnuz, + torch.float8_e8m0fnu, ] CUDA_FLOAT8_DTYPES = [ torch.float8_e5m2, torch.float8_e4m3fn, + torch.float8_e8m0fnu, ] # The following information are not yet provided by torch.finfo. @@ -37,6 +41,7 @@ MANTISSA_BITS = { torch.float8_e5m2fnuz: 2, torch.float8_e4m3fn: 3, torch.float8_e4m3fnuz: 3, + torch.float8_e8m0fnu: 0, } # As in np.finfo(dtype).minexp @@ -45,6 +50,7 @@ MINEXP = { torch.float8_e5m2fnuz: -15, torch.float8_e4m3fn: -6, torch.float8_e4m3fnuz: -7, + torch.float8_e8m0fnu: -127, } SPECIAL_NUMBERS = { @@ -108,11 +114,24 @@ SPECIAL_NUMBERS = { ("00000001", 0.125 * (2**-7), "min_subnorm"), ("10000001", -0.125 * (2**-7), "neg_min_subnorm"), ], + torch.float8_e8m0fnu: [ + ("00000000", float(2**-127), "smallest_number"), + ("11111110", float(2**127), "largest_number"), + ("01111110", 0.5, "zero_point_five"), + ("01111111", 1.0, "one"), + ("10000000", 2.0, "two"), + ("11111111", float("nan"), "nan"), + ], } FLOAT8_DTYPES_WITH_INF = [torch.float8_e5m2] +def _int_bits_to_float(x): + y = struct.unpack("!f", struct.pack("!I", x))[0] + return y + + def simulate_fp8_precision(input, variant): """Round input (as float32) to the given float8 datatype variant.""" @@ -165,6 +184,24 @@ def simulate_fp8_precision(input, variant): return vals * signs +def _round_e8m0_rne(biased_exponent, lsb, g, r, s): + round_up = False + + # apply g,r,s rounding rules for RNE rounding + if g == 1: + if (r == 1) or (s == 1): + round_up = True + else: + if lsb: + round_up = True + + # round up if necessary + if round_up: + biased_exponent += 1 + + return biased_exponent + + ROUND_TRIP_TEST_CASES = ( # A general 'soak test'. 
subtest( @@ -198,17 +235,19 @@ ROUND_TRIP_TEST_CASES = ( class TestFloat8Dtype(TestCase): - """ - Sanity test for zeros comparison - """ - @dtypes(*FLOAT8_DTYPES) @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES) def test_creation_with_zeros(self, dtype, device): """Sanity test, round-trip casting of zeros.""" - x = torch.zeros(8, dtype=torch.float, device=device) x8 = torch.zeros(8, dtype=dtype, device=device) - self.assertEqual(x, x8.float(), atol=0, rtol=0) + if dtype is torch.float8_e8m0fnu: + # zeros are not supported for this dtype, values get clamped + # to 2 ^ -127 + x = torch.full((8,), 2**-127, dtype=torch.float, device=device) + self.assertEqual(x, x8.float(), atol=0, rtol=0) + else: + x = torch.zeros(8, dtype=torch.float, device=device) + self.assertEqual(x, x8.float(), atol=0, rtol=0) @dtypes(*FLOAT8_DTYPES) @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES) @@ -217,12 +256,69 @@ class TestFloat8Dtype(TestCase): """Numerical test of float8 conversion, by performing a round-trip cast to the float8 dtype and back to float32, comparing against simulated lower precision.""" + if dtype is torch.float8_e8m0fnu: + return unittest.skip("numerics for e8m0fnu are tested elsewhere") + x = get_input(dtype, device) x = torch.cat((x, -x)) x8 = x.to(dtype) x8_simulated = simulate_fp8_precision(x, dtype) self.assertEqual(x8_simulated, x8.float()) + def test_float8_e8m0fnu_rne_rounding(self, device): + """ + For every possible e8m0 exponent (256 options) and for every possible + g, r, s bits of the float32 mantissa, verify that RNE rounding is + correctly applied when casting from float32 to e8m0 + + Note: this code is morally similar to `test_cast_round_trip`, but + IMO simpler to special case e8m0 here. + """ + + for biased_exponent in range(0, 256): + # iterate through all the possible options of guard, round, sticky bits + # for the current exponent + for grs in range(8): + # create a positive floating point number with the specified exponent + # and mantissa guard, round, sticky bits + uint32_t_start = (biased_exponent << 23) + (grs << 20) + fp32_start = _int_bits_to_float(uint32_t_start) + + # create an RNE rounded version of the exponent + if biased_exponent == 255: + new_biased_exponent = biased_exponent + else: + lsb = biased_exponent > 0 + g = grs >> 2 + r = (grs >> 1) & 0b1 + s = grs & 0b1 + new_biased_exponent = _round_e8m0_rne(biased_exponent, lsb, g, r, s) + + # create an RNE rounded version of the float + fp32_e8m0_fp32_emulated = _int_bits_to_float(new_biased_exponent << 23) + + # now, do the same in PyTorch and see if results match + fp32_pt_start = torch.full( + (1,), fp32_start, device=device, dtype=torch.float + ) + fp32_pt_e8m0 = fp32_pt_start.to(torch.float8_e8m0fnu) + fp32_pt_e8m0_fp32 = fp32_pt_e8m0.to(torch.float) + + expected = fp32_e8m0_fp32_emulated + if biased_exponent == 254 and grs >= 4: + # special case rounding up from the largest representable float32 exponent, which + # saturates to nan + expected = float("nan") + elif biased_exponent == 255: + # special case inf and nan, which becomes nan + expected = float("nan") + + actual = fp32_pt_e8m0_fp32.item() + + self.assertEqual( + expected, actual, f"expected: {expected}, actual: {actual}" + ) + @dtypes(*FLOAT8_DTYPES) @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES) def test_special_numbers(self, dtype, device): @@ -269,6 +365,32 @@ class TestFloat8Dtype(TestCase): torch.use_deterministic_algorithms(use_deterministic) torch.empty(4, 4, device=device, dtype=dtype) + @dtypes(*FLOAT8_DTYPES) + @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES) + def test_to_string(self, 
dtype, device): + x = torch.empty(4, 4, device=device, dtype=dtype) + str(x) + + @dtypes(*FLOAT8_DTYPES) + def test_finfo(self, dtype, device): + torch.finfo(dtype) + + @dtypes(*FLOAT8_DTYPES) + @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES) + def test_cat(self, dtype, device): + x1 = torch.empty(4, 4, device=device, dtype=dtype) + x2 = torch.empty(4, 4, device=device, dtype=dtype) + torch.cat([x1, x2]) + + @dtypes(*FLOAT8_DTYPES) + @dtypesIfCUDA(*CUDA_FLOAT8_DTYPES) + def test_save_load(self, dtype, device): + x1 = torch.randint(0, 10, (4, 4), device=device, dtype=torch.uint8).view(dtype) + with TemporaryFileName() as fname: + torch.save(x1, fname) + x1_save_load = torch.load(fname) + torch.testing.assert_close(x1, x1_save_load, atol=0, rtol=0) + instantiate_device_type_tests(TestFloat8Dtype, globals()) @@ -285,6 +407,9 @@ class TestFloat8DtypeCPUOnly(TestCase): @dtypes(*CUDA_FLOAT8_DTYPES) def test_mul(self, dtype): + # TODO(#113663): remove arithmetic support from all float8 dtypes + if dtype is torch.float8_e8m0fnu: + return unittest.skip("arithmetic not supported for torch.float8_e8m0fnu") shape = (10, 10) a = torch.randn(shape) a8_simulated = simulate_fp8_precision(a, dtype) @@ -299,6 +424,11 @@ class TestFloat8DtypeCPUOnly(TestCase): @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on Windows yet") @dtypes(*CUDA_FLOAT8_DTYPES) def test_pt2_traceable_aot_eager(self, dtype): + if dtype is torch.float8_e8m0fnu: + return unittest.skip( + "PT2 support for torch.float8_e8m0fnu is not implemented yet" + ) + @torch.compile(backend="aot_eager", fullgraph=True) def f(x): x = x.to(dtype) diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index bd3f83a31c9..8f032e4c8ee 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -1362,7 +1362,7 @@ def gen_pyi( # Generate type signatures for dtype classes # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # TODO: don't explicitly list dtypes here; get it from canonical + # TODO(#146647): don't explicitly list dtypes here; get it from canonical # source dtype_class_hints = [ f"{n}: dtype = ..." @@ -1377,6 +1377,7 @@ def gen_pyi( "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", + "float8_e8m0fnu", "half", "uint8", "uint16", diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index df298fe8fd3..182236d62e7 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -150,7 +150,17 @@ class _Formatter: # no valid number, do nothing return + if tensor.dtype == torch.float8_e8m0fnu: # type: ignore[attr-defined] + # float8_e8m0fnu is special and does not define arithmetic ops, + # and printing code further in this file assumes the existence + # of various arithmetic ops to figure out what to print. We hack + # and convert to float here to make printing work correctly. + # TODO(#113663): also add the other float8 dtypes here after arithmetic + # support for them is removed + nonzero_finite_vals = nonzero_finite_vals.float() + # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU. + nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs()) nonzero_finite_min = tensor_totype(nonzero_finite_abs.min()) nonzero_finite_max = tensor_totype(nonzero_finite_abs.max()) diff --git a/torch/csrc/TypeInfo.cpp b/torch/csrc/TypeInfo.cpp index 479d88ac206..9c944fa79d4 100644 --- a/torch/csrc/TypeInfo.cpp +++ b/torch/csrc/TypeInfo.cpp @@ -123,16 +123,15 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) { } #define _AT_DISPATCH_FINFO_TYPES(TYPE, NAME, ...) 
\ - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \ - at::kHalf, \ - at::ScalarType::BFloat16, \ - at::ScalarType::Float8_e5m2, \ - at::ScalarType::Float8_e5m2fnuz, \ - at::ScalarType::Float8_e4m3fn, \ - at::ScalarType::Float8_e4m3fnuz, \ + AT_DISPATCH_V2( \ TYPE, \ NAME, \ - __VA_ARGS__) + AT_WRAP(__VA_ARGS__), \ + AT_EXPAND(AT_FLOATING_TYPES), \ + AT_EXPAND(AT_COMPLEX_TYPES), \ + at::kHalf, \ + at::ScalarType::BFloat16, \ + AT_EXPAND(AT_FLOAT8_TYPES)) static PyObject* THPFInfo_eps(THPFInfo* self, void*) { HANDLE_TH_ERRORS diff --git a/torch/csrc/utils/python_scalars.h b/torch/csrc/utils/python_scalars.h index eeeebb709c9..89ce38353be 100644 --- a/torch/csrc/utils/python_scalars.h +++ b/torch/csrc/utils/python_scalars.h @@ -79,6 +79,7 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) { *(at::BFloat16*)data = at::convert(THPUtils_unpackDouble(obj)); break; + // TODO(#146647): simplify below with macros case at::kFloat8_e5m2: *(at::Float8_e5m2*)data = at::convert(THPUtils_unpackDouble(obj)); @@ -95,8 +96,12 @@ inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) { *(at::Float8_e4m3fnuz*)data = at::convert(THPUtils_unpackDouble(obj)); break; + case at::kFloat8_e8m0fnu: + *(at::Float8_e8m0fnu*)data = + at::convert(THPUtils_unpackDouble(obj)); + break; default: - throw std::runtime_error("invalid type"); + throw std::runtime_error("store_scalar: invalid type"); } } @@ -143,6 +148,7 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) { case at::kBFloat16: return PyFloat_FromDouble( at::convert(*(at::BFloat16*)data)); + // TODO(#146647): simplify below with macros case at::kFloat8_e5m2: return PyFloat_FromDouble( at::convert(*(at::Float8_e5m2*)data)); @@ -155,8 +161,11 @@ inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) { case at::kFloat8_e4m3fnuz: return PyFloat_FromDouble(at::convert( *(at::Float8_e4m3fnuz*)data)); + case at::kFloat8_e8m0fnu: + return PyFloat_FromDouble( + at::convert(*(at::Float8_e8m0fnu*)data)); default: - throw std::runtime_error("invalid type"); + throw std::runtime_error("load_scalar: invalid type"); } } diff --git a/torch/storage.py b/torch/storage.py index d543a9d5550..aebab0d5e6b 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -535,6 +535,7 @@ def _new_dtypes(): torch.float8_e4m3fn, torch.float8_e5m2fnuz, torch.float8_e4m3fnuz, + torch.float8_e8m0fnu, torch.bits8, torch.bits16, torch.bits1x8, diff --git a/torchgen/api/types/types.py b/torchgen/api/types/types.py index 6bf753c727b..8e068291738 100644 --- a/torchgen/api/types/types.py +++ b/torchgen/api/types/types.py @@ -51,6 +51,7 @@ float8_e5m2T = BaseCppType("at", "Float8_e5m2") float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz") float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn") float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz") +float8_e8m0fnuT = BaseCppType("at", "Float8_e8m0fnu") stringT = BaseCppType("c10", "string_view") generatorT = BaseCppType("at", "Generator") scalarTypeT = BaseCppType("at", "ScalarType") @@ -102,6 +103,7 @@ ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = { ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT, ScalarType.Float8_e4m3fn: float8_e4m3fnT, ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT, + ScalarType.Float8_e8m0fnu: float8_e8m0fnuT, } BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = { diff --git a/torchgen/model.py b/torchgen/model.py index 54bb8087dc0..0c35e3b98a6 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -374,6 +374,7 @@ class 
ScalarType(Enum):
     Float8_e5m2fnuz = auto()
     Float8_e4m3fn = auto()
     Float8_e4m3fnuz = auto()
+    Float8_e8m0fnu = auto()
 
     def __str__(self) -> str:
         return self.name
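
The new `c10/util/Float8_e8m0fnu.h` / `Float8_e8m0fnu-inl.h` headers above carry the
numerics of this dtype: `fp8e8m0fnu_from_fp32_value` keeps only the float32 exponent,
rounded to nearest-even using the guard/round/sticky bits of the mantissa, and
`operator float()` decodes the stored byte back to `2 ** (exponent - 127)`, with `0x00`
decoding to `2 ** -127` and `0xFF` reserved for NaN. As an illustration only (not part
of the patch), the pure-Python sketch below emulates that encode/decode logic; the
helper names `fp32_to_e8m0_bits` and `e8m0_bits_to_fp32` are made up for this example.

```python
import math
import struct


def _bits_to_fp32(bits: int) -> float:
    return struct.unpack("!f", struct.pack("!I", bits))[0]


def _fp32_to_bits(value: float) -> int:
    return struct.unpack("!I", struct.pack("!f", value))[0]


def fp32_to_e8m0_bits(value: float) -> int:
    """Emulate fp8e8m0fnu_from_fp32_value: keep the float32 exponent, rounded
    to nearest-even using the guard/round/sticky bits of the mantissa."""
    f_bits = _fp32_to_bits(value)
    exponent = (f_bits >> 23) & 0xFF
    if exponent == 0xFF:  # float32 NaN and +-inf map to the single e8m0 NaN encoding
        return 0xFF
    g = (f_bits & 0x400000) != 0  # guard bit (mantissa MSB)
    r = (f_bits & 0x200000) != 0  # round bit
    s = (f_bits & 0x1FFFFF) != 0  # sticky bits
    lsb = exponent > 0            # implied mantissa bit: 1 for normals, 0 for denormals
    if g and (r or s or lsb):     # round-to-nearest-even on the exponent
        exponent += 1             # 254 rounds up to 255, i.e. saturates to NaN
    return exponent


def e8m0_bits_to_fp32(x: int) -> float:
    """Emulate operator float(): decode the 8 exponent bits back to float32."""
    if x == 0:
        return _bits_to_fp32(0x00400000)  # 2 ** -127, a float32 denormal
    if x == 0xFF:
        return float("nan")  # the C++ code returns a specific NaN bit pattern
    return _bits_to_fp32(x << 23)  # 2 ** (x - 127)


# spot checks against the SPECIAL_NUMBERS entries added to test_float8.py
assert e8m0_bits_to_fp32(0b00000000) == 2.0 ** -127
assert e8m0_bits_to_fp32(0b01111110) == 0.5
assert e8m0_bits_to_fp32(0b01111111) == 1.0
assert e8m0_bits_to_fp32(0b10000000) == 2.0
assert fp32_to_e8m0_bits(1.0) == 0b01111111
assert fp32_to_e8m0_bits(3.0) == 0b10000001  # halfway between 2 and 4, rounds up to 4
assert math.isnan(e8m0_bits_to_fp32(fp32_to_e8m0_bits(float("inf"))))
```

Mirroring the C++ code, a value exactly halfway between two powers of two (such as 3.0)
rounds to the larger one, because the implied mantissa bit of a normal float32 plays the
role of the LSB in the tie-breaking rule.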
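
To try the dtype end to end, the short script below strings together the operations this
PR covers (creation, float32 round trip, printing, `torch.cat`, and save/load), roughly
following the new tests in `test_float8.py`. It is a usage sketch that assumes a PyTorch
build containing this patch; it is not part of the change itself.

```python
import os
import tempfile

import torch

x = torch.randn(4, 4, dtype=torch.float32)
x_e8m0 = x.to(torch.float8_e8m0fnu)  # cast rounds the exponent with RNE
x_back = x_e8m0.to(torch.float32)    # decodes to 2 ** (exponent - 127)
print(x_e8m0)                        # tensor printing is covered by this PR

# zeros are not representable in e8m0; they clamp to 2 ** -127,
# matching the updated test_creation_with_zeros expectations
z = torch.zeros(8, dtype=torch.float8_e8m0fnu)
assert torch.all(z.float() == 2.0 ** -127)

# concatenation and serialization round trip, as in test_cat / test_save_load
y = torch.cat([x_e8m0, x_e8m0])
with tempfile.TemporaryDirectory() as tmpdir:
    fname = os.path.join(tmpdir, "e8m0.pt")
    torch.save(x_e8m0, fname)
    reloaded = torch.load(fname)
torch.testing.assert_close(x_e8m0, reloaded, atol=0, rtol=0)
```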