Skip to content

Commit be2e457

Browse files
committed
Addressing PR#19096 review comments
1 parent e0ee48c commit be2e457

File tree

10 files changed

+248
-115
lines changed

10 files changed

+248
-115
lines changed

xla/hlo/builder/lib/math_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,20 @@ class MathTypedTest : public MathTest {
9696

9797
bool has_inf = std::numeric_limits<T>::has_infinity;
9898
bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
99-
bool is_finite = !has_inf && !has_nan;
100-
bool is_nan_only = !has_inf && has_nan;
99+
bool has_finite = !has_inf && !has_nan;
100+
bool has_nan_only = !has_inf && has_nan;
101101

102102
auto expected = LiteralUtil::MakeTupleOwned(
103-
LiteralUtil::CreateR1<bool>({true, true, true, true, true, is_finite,
104-
is_finite, is_finite, is_finite}),
103+
LiteralUtil::CreateR1<bool>({true, true, true, true, true, has_finite,
104+
has_finite, has_finite, has_finite}),
105105
LiteralUtil::CreateR1<bool>({false, false, false, false, false, has_inf,
106106
has_inf, false, false}),
107107
LiteralUtil::CreateR1<bool>(
108108
{false, false, false, false, false, has_inf, false, false, false}),
109109
LiteralUtil::CreateR1<bool>(
110110
{false, false, false, false, false, false, has_inf, false, false}),
111111
LiteralUtil::CreateR1<bool>({false, false, false, false, false,
112-
is_nan_only, is_nan_only, has_nan,
112+
has_nan_only, has_nan_only, has_nan,
113113
has_nan}));
114114
ComputeAndCompareLiteral(&b, expected, {});
115115
}

xla/literal.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -593,13 +593,6 @@ class LiteralBase {
593593
static_assert(8 % bits_per_element == 0);
594594

595595
constexpr int elements_per_byte = 8 / bits_per_element;
596-
constexpr auto cast = [](NativeT x) -> uint8_t {
597-
if constexpr (primitive_util::IsFloatingPointType(primitive_type)) {
598-
return Eigen::numext::bit_cast<uint8_t>(x);
599-
}
600-
return static_cast<uint8_t>(x);
601-
};
602-
603596
int64_t bytes = elements.size() / elements_per_byte;
604597
for (int64_t i = 0; i < bytes; ++i) {
605598
uint8_t byte = 0;
@@ -710,14 +703,14 @@ class LiteralBase {
710703
static_assert(!primitive_util::IsComplexType(primitive_type));
711704
static_assert(8 % bits_per_element == 0);
712705

713-
constexpr int elements_per_byte = 8 / bits_per_element;
714706
constexpr auto cast = [](uint8_t x) -> NativeT {
715707
if constexpr (primitive_util::IsFloatingPointType(primitive_type)) {
716708
return Eigen::numext::bit_cast<NativeT>(x);
717709
}
718710
return static_cast<NativeT>(x);
719711
};
720712

713+
constexpr int elements_per_byte = 8 / bits_per_element;
721714
int64_t bytes = elements.size() / elements_per_byte;
722715
for (int64_t i = 0; i < bytes; ++i) {
723716
uint8_t byte;

xla/literal_comparison_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) {
7373
auto actual = LiteralUtil::CreateR0<TypeParam>(TypeParam(1.0));
7474
float expV = 1.5; // F8E4M3*
7575
if (type == F8E5M2 || type == F8E5M2FNUZ)
76-
expV = 1.75;
76+
expV = 2.0;
7777
else if (type == F8E3M4)
7878
expV = 1.25;
7979
else if (type == F4E2M1FN)
@@ -99,7 +99,7 @@ TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) {
9999
auto actual = LiteralUtil::CreateR0<float>(1.0);
100100
float expV = 1.51; // F8E4M3*
101101
if (type == F8E5M2 || type == F8E5M2FNUZ)
102-
expV = 1.76;
102+
expV = 2.01;
103103
else if (type == F8E3M4)
104104
expV = 1.26;
105105
else if (type == F4E2M1FN)

xla/python/ifrt/dtype.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ message DTypeProto {
8181
// collision.
8282
KIND_STRING = 99;
8383

84-
// Next: 31
84+
// Next: 32
8585
}
8686
// LINT.ThenChange()
8787
Kind kind = 1;

xla/python/ifrt/dtype_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ TEST(DTypeTest, BitSize) {
7474
{DType::kF8E3M4, 8}, {DType::kF8E4M3, 8},
7575
{DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8},
7676
{DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8},
77-
{DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 4},
77+
{DType::kF8E5M2FNUZ, 8}, {DType::kF8E8M0FNU, 8},
7878
{DType::kS16, 16}, {DType::kU16, 16},
7979
{DType::kF16, 16}, {DType::kBF16, 16},
8080
{DType::kS32, 32}, {DType::kU32, 32},

xla/service/elemental_ir_emitter.cc

Lines changed: 168 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -811,118 +811,217 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value,
811811

812812
absl::StatusOr<llvm::Value*> EmitF16ToF4e2m1fn(llvm::Value* f16_value,
813813
llvm::IRBuilder<>* b) {
814+
auto i8_const = [&](int val) {
815+
return llvm::ConstantInt::get(b->getInt8Ty(), val);
816+
};
817+
auto i16_const = [&](int val) {
818+
return llvm::ConstantInt::get(b->getInt16Ty(), val);
819+
};
820+
constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4
821+
constexpr int bias_diff = 14; // 15 for F16, 1 for F4
822+
823+
// Cast the input value to an integer for bitwise manipulation.
824+
// Get the absolute value of the input (discard the sign).
825+
// f16_bits = bitcast(f16_value, int)
826+
// f16_abs_bits = f16_bits & 0x7FFF
827+
llvm::Value* f16_bits = b->CreateBitCast(f16_value, b->getInt16Ty());
828+
llvm::Value* f16_abs_bits = b->CreateAnd(f16_bits, i16_const(0x7FFF));
829+
830+
// If the input absolute value is >= 7.0 or an infinity, the result saturates
831+
// to max value (6.0). If (0.75 <= input < 1), the result is rounded to 1.0.
832+
// If (0 <= input <= 0.25), the result is rounded to 0.0.
833+
// If the input is NaN, the result is undefined (implemented as minus zero).
834+
// The rest of the cases are handled by the "happy path".
835+
// is_overflow = f16_abs_bits >= 0x1.Cp2
836+
// is_one = f16_abs_bits >= 0x1.8p-1 (used only if exponent underflows)
837+
// is_zero = f16_abs_bits <= 0x1p-2 (used only if exponent underflows)
838+
// is_nan = f16_abs_bits > 0x7C00 (F16 NaN threshold)
839+
llvm::Value* is_overflow =
840+
b->CreateICmpUGE(f16_abs_bits, i16_const(0x4700)); // 7.0
841+
llvm::Value* is_one =
842+
b->CreateICmpUGE(f16_abs_bits, i16_const(0x3A00)); // 0.75
843+
llvm::Value* is_zero =
844+
b->CreateICmpULE(f16_abs_bits, i16_const(0x3400)); // 0.25
845+
llvm::Value* is_nan =
846+
b->CreateICmpUGT(f16_abs_bits, i16_const(0x7C00)); // inf
847+
848+
// Truncate the mantissa to 1 bit and the exponent to 3 bits (not 2 bits, as
849+
// the type doesn't have Inf/NaN and can represent unbiased exponent 2).
850+
// This case, as well as the denormal, is handled below.
814851
TF_ASSIGN_OR_RETURN(
815852
llvm::Value * reduced_precision,
816853
EmitReducePrecisionIR(
817854
/*src_ty=*/F16, f16_value,
818855
/*dest_exponent_bits=*/primitive_util::ExponentWidth(F4E2M1FN) + 1,
819856
/*dest_mantissa_bits=*/primitive_util::SignificandWidth(F4E2M1FN) - 1,
820857
/*quiet_nans=*/false, b));
858+
859+
// Cast the reduced precision value to an integer for bitwise manipulation.
860+
// Discard the least significant (9) mantissa bits leaving 1 bit.
861+
// Truncate to
862+
// as_int16 = bitcast(reduced_precision, int)
863+
// as_int8 = as_int16 >> (f16_mantissa - f4_mantissa)
821864
llvm::Value* as_int16 = b->CreateBitCast(reduced_precision, b->getInt16Ty());
822865
llvm::Value* as_int8 =
823-
b->CreateTrunc(b->CreateLShr(as_int16, 9), b->getInt8Ty());
866+
b->CreateTrunc(b->CreateLShr(as_int16, mantissa_diff), b->getInt8Ty());
824867

825-
// Extract sign, exponent and mantissa from reduced precision value.
826-
auto i8_const = [&](int val) {
827-
return llvm::ConstantInt::get(b->getInt8Ty(), val);
828-
};
868+
// Get the sign (0 or 1).
869+
// f4_sign = as_int8 >> 6
829870
llvm::Value* f4_sign = b->CreateLShr(as_int8, 6);
871+
872+
// Get exponent and mantissa bits without the sign.
873+
// Important: the mask is 0x3F (not 0x7F), discard bit #6.
874+
// f4_bits = as_int8 & 0x3F
830875
llvm::Value* f4_bits = b->CreateAnd(as_int8, i8_const(0x3F));
831-
llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(28));
832876

833-
// Special case for exponent overflow.
834-
auto i16_const = [&](int val) {
835-
return llvm::ConstantInt::get(b->getInt16Ty(), val);
836-
};
837-
llvm::Value* f16_bits = b->CreateAnd(
838-
b->CreateBitCast(f16_value, b->getInt16Ty()), i16_const(0x7FFF));
839-
llvm::Value* is_overflow =
840-
b->CreateICmpUGE(f16_bits, i16_const(0x4700)); // 7.0
841-
llvm::Value* is_nan = b->CreateICmpUGT(f16_bits, i16_const(0x7C00)); // inf
842-
llvm::Value* max_or_nan =
843-
b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7));
844-
llvm::Value* f4_normal_or_overflow =
845-
b->CreateSelect(is_overflow, max_or_nan, f4_normal);
846-
847-
// Special case for exponent underflow.
877+
// Convert F16 exponent to F4 exponent by readjusting the exponent bias.
878+
// This produces the "normal" result, i.e. not Inf or NaN or denormal.
879+
// f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa)
880+
constexpr int f4_exponent_offset = bias_diff << 1;
881+
llvm::Value* f4_normal = b->CreateSub(f4_bits, i8_const(f4_exponent_offset));
882+
883+
// If the rounding resulted in zero exponent, the value is incorrect.
884+
// This happens when the input is < 1.0
885+
// is_underflow = f4_normal <= 1
848886
llvm::Value* is_underflow = b->CreateICmpSLE(f4_normal, i8_const(1));
849-
llvm::Value* is_one = b->CreateICmpUGE(f16_bits, i16_const(0x3A00)); // 0.75
850-
llvm::Value* is_zero = b->CreateICmpULE(f16_bits, i16_const(0x3400)); // 0.25
851-
llvm::Value* denorm_or_zero =
852-
b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1));
853-
llvm::Value* f4_small =
854-
b->CreateSelect(is_one, i8_const(0x2), denorm_or_zero);
855-
llvm::Value* f4_result =
856-
b->CreateSelect(is_underflow, f4_small, f4_normal_or_overflow);
887+
888+
// Chain of selects that handles the special cases.
889+
// f4_result =
890+
// is_underflow ? (is_one ? 1.0 : (is_zero ? 0.0 : 0.5)) :
891+
// is_overflow ? (is_nan ? -0.0 : 6.0) :
892+
// f4_normal
893+
llvm::Value* f4_result = b->CreateSelect(
894+
is_underflow,
895+
// If underflow, the input is < 1.0; the result is either 0.0, 0.5 or 1.0
896+
b->CreateSelect(is_one, i8_const(0x2),
897+
b->CreateSelect(is_zero, i8_const(0x0), i8_const(0x1))),
898+
// If overflow, the input is >= 7.0 or infinity or NaN.
899+
b->CreateSelect(is_overflow,
900+
b->CreateSelect(is_nan, i8_const(0x8), i8_const(0x7)),
901+
f4_normal));
857902

858903
// Add sign to the resulting value.
904+
// f4_signed_result = (f4_sign << 3) | f4_result
859905
return b->CreateOr(f4_result, b->CreateShl(f4_sign, 3));
860906
}
861907

862908
llvm::Value* EmitF4e2m1fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) {
863-
llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty());
864-
865-
// Extract sign, exponent and mantissa from reduced precision value.
866909
auto i16_const = [&](int val) {
867910
return llvm::ConstantInt::get(b->getInt16Ty(), val);
868911
};
869-
llvm::Value* sign = b->CreateLShr(as_int16, 3);
870-
llvm::Value* sign_shifted = b->CreateShl(sign, 15);
871-
llvm::Value* bits = b->CreateAnd(as_int16, i16_const(0x7));
872-
llvm::Value* bits_shifted = b->CreateShl(bits, 9);
873-
874-
// Re-bias the exponent and handle denormals.
875-
llvm::Value* f16_normal = b->CreateAdd(bits_shifted, i16_const(14 << 10));
876-
llvm::Value* is_denorm_or_zero = b->CreateICmpULE(bits, i16_const(1));
877-
llvm::Value* is_zero = b->CreateICmpEQ(bits, i16_const(0));
878-
llvm::Value* denorm_or_zero =
879-
b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800));
880-
llvm::Value* f16_result =
881-
b->CreateSelect(is_denorm_or_zero, denorm_or_zero, f16_normal);
912+
constexpr int mantissa_diff = 9; // 10 for F16, 1 for F4
913+
constexpr int bias_diff = 14; // 15 for F16, 1 for F4
914+
915+
// The input value is a 8-bit integer, extend it to 16-bit integer.
916+
// as_int16 = bitcast(f8_value, int)
917+
llvm::Value* as_int16 = b->CreateZExt(f8_value, b->getInt16Ty());
918+
919+
// Get the sign and shift it to F16 position.
920+
// f4_sign = as_int16 >> 3
921+
// f16_sign_bit = f4_sign << 15
922+
llvm::Value* f4_sign = b->CreateLShr(as_int16, 3);
923+
llvm::Value* f16_sign_bit = b->CreateShl(f4_sign, 15);
924+
925+
// Get exponent and mantissa bits without the sign.
926+
// f4_bits = as_int16 & 0x7
927+
// f16_bits = f4_bits << (f16_mantissa - f4_mantissa)
928+
llvm::Value* f4_bits = b->CreateAnd(as_int16, i16_const(0x7));
929+
llvm::Value* f16_bits = b->CreateShl(f4_bits, mantissa_diff);
930+
931+
// Convert F16 exponent to F4 exponent by readjusting the exponent bias.
932+
// f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa)
933+
constexpr int f16_exponent_offset = bias_diff << 10;
934+
llvm::Value* f16_normal =
935+
b->CreateAdd(f16_bits, i16_const(f16_exponent_offset));
936+
937+
// For denormal and zero, the exponent is different. Handle these cases
938+
// separately below.
939+
// is_denorm_or_zero = f4_bits <= 1
940+
// is_zero = f4_bits == 0
941+
llvm::Value* is_denorm_or_zero = b->CreateICmpULE(f4_bits, i16_const(1));
942+
llvm::Value* is_zero = b->CreateICmpEQ(f4_bits, i16_const(0));
943+
944+
// Chain of selects that handles the special cases.
945+
// f16_result = is_denorm_or_zero ? (is_zero ? 0.0 : 0.5) : f16_normal
946+
llvm::Value* f16_result = b->CreateSelect(
947+
is_denorm_or_zero,
948+
b->CreateSelect(is_zero, i16_const(0x0000), i16_const(0x3800)),
949+
f16_normal);
882950

883951
// Add sign to the resulting value.
884-
llvm::Value* f16_signed = b->CreateOr(f16_result, sign_shifted);
885-
return b->CreateBitCast(f16_signed, b->getHalfTy());
952+
// f16_signed_result = f16_sign_bit | f16_result
953+
llvm::Value* f16_signed_result = b->CreateOr(f16_result, f16_sign_bit);
954+
return b->CreateBitCast(f16_signed_result, b->getHalfTy());
886955
}
887956

888957
llvm::Value* EmitF32ToF8e8m0fnu(llvm::Value* f32_value, llvm::IRBuilder<>* b) {
889-
llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty());
890-
891-
// Result is NaN if input is zero, negative, infinity or NaN.
892958
auto i32_const = [&](int val) {
893959
return llvm::ConstantInt::get(b->getInt32Ty(), val);
894960
};
895-
llvm::Value* is_denorm = b->CreateICmpULE(as_int32, i32_const(0x400000));
896-
llvm::Value* is_nan =
897-
b->CreateOr(b->CreateICmpSLE(as_int32, i32_const(0)),
898-
b->CreateICmpSGE(as_int32, i32_const(0x7F400000)));
899961

900-
// Round the value and extract exponent.
901-
llvm::Value* rounded = b->CreateAdd(as_int32, i32_const(0x400000));
902-
llvm::Value* shifted = b->CreateAShr(rounded, 23);
903-
llvm::Value* finite = b->CreateSelect(is_denorm, i32_const(0x00), shifted);
904-
llvm::Value* f32_result = b->CreateSelect(is_nan, i32_const(0xFF), finite);
962+
// Cast the input value to an integer for bitwise manipulation.
963+
// as_int32 = bitcast(f32_value, int)
964+
llvm::Value* as_int32 = b->CreateBitCast(f32_value, b->getInt32Ty());
965+
966+
// Check if the input is zero, negative, overflow, infinity or NaN.
967+
// All of these cases cannot be represented in the E8M0 format.
968+
// is_zero_or_negative = as_int32 <= 0
969+
// is_overflow_or_nan = as_int32 >= 0x1.8p127
970+
// is_nan = is_zero_or_negative | is_overflow_or_nan
971+
llvm::Value* is_zero_or_negative = b->CreateICmpSLE(as_int32, i32_const(0));
972+
llvm::Value* is_overflow_or_nan =
973+
b->CreateICmpSGE(as_int32, i32_const(0x7F400000)); // 1.5 * 2^127
974+
llvm::Value* is_nan = b->CreateOr(is_zero_or_negative, is_overflow_or_nan);
975+
976+
// Check if the input is a denormal which should round to the minimum value
977+
// (2^-127), as there is no zero value.
978+
// is_denorm = as_int32 <= 0x1p-127
979+
llvm::Value* is_denorm =
980+
b->CreateICmpULE(as_int32, i32_const(0x400000)); // 1.0 * 2^-127
981+
982+
// Round the value (always up) and discard the mantissa.
983+
// rounded = as_int32 + 0x1p-127
984+
// f8_normal = as_int32 >> f32_mantissa
985+
llvm::Value* rounded =
986+
b->CreateAdd(as_int32, i32_const(0x400000)); // 1.0 * 2^-127
987+
llvm::Value* f8_normal = b->CreateAShr(rounded, 23);
988+
989+
// Chain of selects that handles the special cases.
990+
// f8_result = is_nan ? 0xFF : (is_denorm ? 0x00 : f8_normal)
991+
llvm::Value* f8_result =
992+
b->CreateSelect(is_nan, i32_const(0xFF),
993+
b->CreateSelect(is_denorm, i32_const(0x00), f8_normal));
905994

906995
// Truncate to the result type.
907-
return b->CreateTrunc(f32_result, b->getInt8Ty());
996+
return b->CreateTrunc(f8_result, b->getInt8Ty());
908997
}
909998

910999
llvm::Value* EmitF8e8m0fnuToF32(llvm::Value* f8_value, llvm::IRBuilder<>* b) {
911-
// Shift exponent to the left for the normal case.
912-
llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty());
913-
llvm::Value* shifted = b->CreateShl(as_int32, 23);
914-
915-
// Special values for 0x00 (denorm) and 0xFF (NaN).
9161000
auto i32_const = [&](int val) {
9171001
return llvm::ConstantInt::get(b->getInt32Ty(), val);
9181002
};
1003+
1004+
// The input value is a 8-bit integer, extend it to 32-bit integer.
1005+
// as_int32 = bitcast(f8_value, int)
1006+
llvm::Value* as_int32 = b->CreateZExt(f8_value, b->getInt32Ty());
1007+
1008+
// Check if the input is a denormal or NaN.
1009+
// is_zero = as_int32 == 0x00
1010+
// is_nan = as_int32 == 0xFF
9191011
llvm::Value* is_zero = b->CreateICmpEQ(as_int32, i32_const(0));
9201012
llvm::Value* is_nan = b->CreateICmpEQ(as_int32, i32_const(0xFF));
921-
llvm::Value* denorm_or_shifted =
922-
b->CreateSelect(is_zero, i32_const(0x00400000), shifted);
923-
llvm::Value* f32_result =
924-
b->CreateSelect(is_nan, i32_const(0x7FC00000), denorm_or_shifted);
9251013

1014+
// Shift exponent to the left for the normal case.
1015+
// f32_normal = as_int32 << mantissa_diff
1016+
llvm::Value* f32_normal = b->CreateShl(as_int32, 23);
1017+
1018+
// Chain of selects that handles the special cases.
1019+
// f32_result = is_nan ? 0x7FC00000 : (is_zero ? 0x1p-127 : f32_normal)
1020+
llvm::Value* f32_result = b->CreateSelect(
1021+
is_nan, i32_const(0x7FC00000),
1022+
b->CreateSelect(is_zero, i32_const(0x400000), f32_normal));
1023+
1024+
// Bitcast integer bits to the result type.
9261025
return b->CreateBitCast(f32_result, b->getFloatTy());
9271026
}
9281027

0 commit comments

Comments
 (0)