Skip to content

Commit 87da2eb

Browse files
committed
Add F4E2M1FN type: add tests
1 parent aabe9c6 commit 87da2eb

File tree

25 files changed

+190
-78
lines changed

25 files changed

+190
-78
lines changed

xla/array2d_test.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,20 @@ TEST(Array2dTest, LinspaceF8E3M4) {
219219
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 3.5);
220220
}
221221

222+
TEST(Array2dTest, LinspaceF4E2M1FN) {
223+
auto arr = MakeLinspaceArray2D<tsl::float4_e2m1fn>(1.0, 3.5, 3, 2);
224+
225+
EXPECT_EQ(arr->n1(), 3);
226+
EXPECT_EQ(arr->n2(), 2);
227+
228+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 0)), 1.0);
229+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(0, 1)), 1.5);
230+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 0)), 2.0);
231+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(1, 1)), 2.0); // 2.5 rounded down
232+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 0)), 3.0);
233+
EXPECT_FLOAT_EQ(static_cast<float>((*arr)(2, 1)), 4.0); // 3.5 rounded up
234+
}
235+
222236
TEST(Array2dTest, Stringification) {
223237
auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
224238
const std::string expected = R"([[1, 1.5],

xla/fp_util_test.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,59 @@ class FP8E4M3DistanceTest : public ::testing::Test {};
119119
using F8E4M3Types = ::testing::Types<tsl::float8_e4m3, tsl::float8_e4m3fn>;
120120
TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types);
121121

122+
TEST(FPDistanceTest, F4E2M1FNDistance) {
123+
// a & b are equal
124+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
125+
tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(4.0)),
126+
0);
127+
128+
// a & b have the same exponents
129+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
130+
tsl::float4_e2m1fn(4.0), tsl::float4_e2m1fn(6.0)),
131+
1);
132+
133+
// a & b have different exponents
134+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
135+
tsl::float4_e2m1fn(2.0), tsl::float4_e2m1fn(4.0)),
136+
2);
137+
138+
// 1 from 0 in the positive direction
139+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
140+
std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
141+
tsl::float4_e2m1fn(0)),
142+
1);
143+
144+
// 1 from 0 in the negative direction
145+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
146+
-std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
147+
tsl::float4_e2m1fn(0)),
148+
1);
149+
150+
// a & b are denormals with different signs
151+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
152+
-std::numeric_limits<tsl::float4_e2m1fn>::denorm_min(),
153+
std::numeric_limits<tsl::float4_e2m1fn>::denorm_min()),
154+
2);
155+
156+
// smallest normal value: 2 steps from 0 in the positive direction
157+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
158+
std::numeric_limits<tsl::float4_e2m1fn>::min(),
159+
tsl::float4_e2m1fn(0)),
160+
2);
161+
162+
// smallest normal value: 2 steps from 0 in the negative direction
163+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
164+
-std::numeric_limits<tsl::float4_e2m1fn>::min(),
165+
tsl::float4_e2m1fn(0)),
166+
2);
167+
168+
// a & b are smallest normals with different signs
169+
EXPECT_EQ(CalculateDistanceInFloats<tsl::float4_e2m1fn>(
170+
-std::numeric_limits<tsl::float4_e2m1fn>::min(),
171+
std::numeric_limits<tsl::float4_e2m1fn>::min()),
172+
4);
173+
}
174+
122175
TEST(FPDistanceTest, F8E3M4Distance) {
123176
// a & b are equal
124177
EXPECT_EQ(CalculateDistanceInFloats<tsl::float8_e3m4>(tsl::float8_e3m4(8.0),

xla/hlo/builder/lib/math.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ XlaOp IsNegZero(XlaOp operand) {
184184
case F32:
185185
return Eq(BitcastConvertType(operand, U32),
186186
ConstantR0WithType(&b, U32, uint32_t{1} << 31));
187+
case F4E2M1FN:
187188
case F8E3M4:
188189
case F8E4M3:
189190
case F8E5M2:
@@ -971,8 +972,9 @@ XlaOp Igamma(XlaOp a, XlaOp x) {
971972
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a));
972973
PrimitiveType a_x_type = a_shape.element_type();
973974
bool needs_upcast = false;
974-
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
975-
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
975+
for (PrimitiveType type :
976+
{BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
977+
F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
976978
if (a_shape.element_type() == type) {
977979
needs_upcast = true;
978980
break;
@@ -1024,8 +1026,9 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) {
10241026
}
10251027
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a));
10261028
bool needs_upcast = false;
1027-
for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN,
1028-
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) {
1029+
for (PrimitiveType type :
1030+
{BF16, F16, F4E2M1FN, F8E3M4, F8E4M3, F8E4M3B11FNUZ, F8E4M3FN,
1031+
F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ}) {
10291032
if (a_shape.element_type() == type) {
10301033
needs_upcast = true;
10311034
break;

xla/hlo/builder/lib/math_test.cc

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,22 @@ class MathTypedTest : public MathTest {
9595
Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)});
9696

9797
bool has_inf = std::numeric_limits<T>::has_infinity;
98+
bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
99+
bool is_finite = !has_inf && !has_nan;
100+
bool is_nan_only = !has_inf && has_nan;
101+
98102
auto expected = LiteralUtil::MakeTupleOwned(
99-
LiteralUtil::CreateR1<bool>(
100-
{true, true, true, true, true, false, false, false, false}),
103+
LiteralUtil::CreateR1<bool>({true, true, true, true, true, is_finite,
104+
is_finite, is_finite, is_finite}),
101105
LiteralUtil::CreateR1<bool>({false, false, false, false, false, has_inf,
102106
has_inf, false, false}),
103107
LiteralUtil::CreateR1<bool>(
104108
{false, false, false, false, false, has_inf, false, false, false}),
105109
LiteralUtil::CreateR1<bool>(
106110
{false, false, false, false, false, false, has_inf, false, false}),
107111
LiteralUtil::CreateR1<bool>({false, false, false, false, false,
108-
!has_inf, !has_inf, true, true}));
112+
is_nan_only, is_nan_only, has_nan,
113+
has_nan}));
109114
ComputeAndCompareLiteral(&b, expected, {});
110115
}
111116

@@ -118,10 +123,11 @@ class MathTypedTest : public MathTest {
118123
LiteralUtil::CreateR1<T>({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}),
119124
&b));
120125

126+
bool is_mx = std::is_same_v<T, tsl::float4_e2m1fn>;
121127
ComputeAndCompareLiteral(
122128
&b,
123129
LiteralUtil::CreateR1<bool>(
124-
{has_negative_zero_v<T>, false, false, false, false, false, false}),
130+
{has_negative_zero_v<T>, false, false, false, false, false, is_mx}),
125131
{}, error_spec_);
126132
}
127133

@@ -136,6 +142,9 @@ class MathTypedTest : public MathTest {
136142
// For good measure, we also check pow with an exponent other than 0.5.
137143
void TestSqrtPowInequivalence() {
138144
SetFastMathDisabled(true);
145+
if (std::is_same_v<T, tsl::float4_e2m1fn>) {
146+
GTEST_SKIP() << "Skipping due to low precision";
147+
}
139148

140149
// Tests disable constant folding by default, but this test needs it
141150
// enabled, otherwise we don't tickle the bug we're trying to catch.
@@ -181,19 +190,24 @@ class MathTypedTest : public MathTest {
181190
&b);
182191
Erf(x);
183192

184-
bool has_inf = std::numeric_limits<T>::has_infinity;
185-
std::vector<T> expected = {
186-
has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)};
193+
bool inf_as_nan = !std::numeric_limits<T>::has_infinity &&
194+
std::numeric_limits<T>::has_quiet_NaN;
195+
std::vector<T> expected = {inf_as_nan ? nan : T(-1),
196+
inf_as_nan ? nan : T(1),
197+
T(-0),
198+
T(0),
199+
T(-1),
200+
T(1)};
187201

188202
ComputeAndCompareR1<T>(&b, expected, {}, error_spec_);
189203
}
190204
};
191205

192206
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
193207
using TestTypes =
194-
::testing::Types<tsl::float8_e3m4, tsl::float8_e4m3, tsl::float8_e4m3fnuz,
195-
tsl::float8_e4m3b11fnuz, tsl::float8_e5m2,
196-
tsl::float8_e5m2fnuz,
208+
::testing::Types<tsl::float4_e2m1fn, tsl::float8_e3m4, tsl::float8_e4m3,
209+
tsl::float8_e4m3fnuz, tsl::float8_e4m3b11fnuz,
210+
tsl::float8_e5m2, tsl::float8_e5m2fnuz,
197211
#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
198212
Eigen::half,
199213
#endif

xla/hlo/transforms/simplifiers/float_normalization.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
217217
hlo->mutable_shape(), [&](Shape* subshape, const xla::ShapeIndex& index) {
218218
if (subshape->element_type() == from) {
219219
subshape->set_element_type(to);
220+
if (subshape->has_layout() && from == F4E2M1FN) {
221+
subshape->mutable_layout()->set_element_size_in_bits(0);
222+
}
220223
}
221224
});
222225
float_normalization_->UpdateLayout(hlo->mutable_shape());

xla/hlo/transforms/simplifiers/float_normalization_test.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ class FloatNormalizationF8Test
150150
public ::testing::WithParamInterface<PrimitiveType> {};
151151

152152
INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test,
153-
::testing::Values(F8E3M4, F8E4M3, F8E5M2));
153+
::testing::Values(F4E2M1FN, F8E3M4, F8E4M3,
154+
F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ,
155+
F8E5M2, F8E5M2FNUZ));
154156

155157
TEST_F(FloatNormalizationTest, NoopIfSupported) {
156158
auto builder = HloComputation::Builder(TestName());

xla/mlir/utils/type_util.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ absl::StatusOr<mlir::Type> ConvertPrimitiveTypeToMlirType(
3232
switch (type) {
3333
case xla::PrimitiveType::PRED:
3434
return b.getI1Type();
35+
case xla::PrimitiveType::F4E2M1FN:
36+
return b.getFloat4E2M1FNType();
3537
case xla::PrimitiveType::F8E5M2:
3638
return b.getFloat8E5M2Type();
3739
case xla::PrimitiveType::F8E4M3:
@@ -78,7 +80,9 @@ absl::StatusOr<mlir::Type> ConvertPrimitiveTypeToMlirType(
7880
}
7981

8082
xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) {
81-
if (type.isFloat8E5M2()) {
83+
if (type.isFloat4E2M1FN()) {
84+
return xla::PrimitiveType::F4E2M1FN;
85+
} else if (type.isFloat8E5M2()) {
8286
return xla::PrimitiveType::F8E5M2;
8387
} else if (type.isFloat8E4M3()) {
8488
return xla::PrimitiveType::F8E4M3;

xla/mlir/utils/type_util_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ INSTANTIATE_TEST_SUITE_P(
101101
Execute, TypeUtilTest,
102102
::testing::ValuesIn(std::vector<TypeUtilTestParam>(
103103
{{PRED, [](mlir::Builder b) { return b.getI1Type(); }},
104+
{F4E2M1FN, [](mlir::Builder b) { return b.getFloat4E2M1FNType(); }},
104105
{F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }},
105106
{F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }},
106107
{F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }},

xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6832,6 +6832,13 @@ func.func @invalid_dimension_attr(%arg0: tensor<?x?xf32, #mhlo.type_extensions<b
68326832

68336833
// -----
68346834

6835+
func.func @f4e2m1fn(%arg0: tensor<f16>) -> tensor<f4E2M1FN> {
6836+
%0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f4E2M1FN>
6837+
func.return %0 : tensor<f4E2M1FN>
6838+
}
6839+
6840+
// -----
6841+
68356842
func.func @f8e3m4(%arg0: tensor<f16>) -> tensor<f8E3M4> {
68366843
%0 = "mhlo.convert"(%arg0) : (tensor<f16>) -> tensor<f8E3M4>
68376844
func.return %0 : tensor<f8E3M4>

xla/python/ifrt/dtype_test.cc

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -66,36 +66,21 @@ TEST(DTypeTest, ByteSize) {
6666
TEST(DTypeTest, BitSize) {
6767
for (const auto& [kind, bit_size] :
6868
std::vector<std::tuple<DType::Kind, int>>({
69-
{DType::kS2, 2},
70-
{DType::kU2, 2},
71-
{DType::kS4, 4},
72-
{DType::kU4, 4},
73-
{DType::kPred, 8},
74-
{DType::kS8, 8},
75-
{DType::kU8, 8},
76-
{DType::kF4E2M1FN, 4},
77-
{DType::kF8E3M4, 8},
78-
{DType::kF8E4M3, 8},
79-
{DType::kF8E4M3FN, 8},
80-
{DType::kF8E4M3B11FNUZ, 8},
81-
{DType::kF8E4M3FNUZ, 8},
82-
{DType::kF8E5M2, 8},
83-
{DType::kF8E5M2FNUZ, 8},
84-
{DType::kS16, 16},
85-
{DType::kU16, 16},
86-
{DType::kF16, 16},
87-
{DType::kBF16, 16},
88-
{DType::kS32, 32},
89-
{DType::kU32, 32},
90-
{DType::kF32, 32},
91-
{DType::kS64, 64},
92-
{DType::kU64, 64},
93-
{DType::kF64, 64},
94-
{DType::kC64, 64},
95-
{DType::kC128, 128},
96-
{DType::kToken, -1},
97-
{DType::kInvalid, -1},
98-
{DType::kString, -1},
69+
{DType::kS2, 2}, {DType::kU2, 2},
70+
{DType::kS4, 4}, {DType::kU4, 4},
71+
{DType::kPred, 8}, {DType::kS8, 8},
72+
{DType::kU8, 8}, {DType::kF4E2M1FN, 4},
73+
{DType::kF8E3M4, 8}, {DType::kF8E4M3, 8},
74+
{DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8},
75+
{DType::kF8E4M3FNUZ, 8}, {DType::kF8E5M2, 8},
76+
{DType::kF8E5M2FNUZ, 8}, {DType::kS16, 16},
77+
{DType::kU16, 16}, {DType::kF16, 16},
78+
{DType::kBF16, 16}, {DType::kS32, 32},
79+
{DType::kU32, 32}, {DType::kF32, 32},
80+
{DType::kS64, 64}, {DType::kU64, 64},
81+
{DType::kF64, 64}, {DType::kC64, 64},
82+
{DType::kC128, 128}, {DType::kToken, -1},
83+
{DType::kInvalid, -1}, {DType::kString, -1},
9984
})) {
10085
EXPECT_EQ(DType(kind).bit_size(),
10186
bit_size == -1 ? std::nullopt : std::make_optional(bit_size));

xla/service/cpu/cpu_compiler.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
607607
pipeline.AddPass<FloatNormalization>(&f8e4m3fnuz_support);
608608
FloatSupport f8e3m4_support(F8E3M4, F16);
609609
pipeline.AddPass<FloatNormalization>(&f8e3m4_support);
610+
FloatSupport f4e2m1fn_support(F4E2M1FN, F16);
611+
pipeline.AddPass<FloatNormalization>(&f4e2m1fn_support);
610612
// After canonicalization, there may be more batch dots that can be
611613
// simplified.
612614
pipeline.AddPass<BatchDotSimplification>();

xla/service/cpu/onednn_memory_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) {
7373

7474
// TODO(intel-tf): properly handle not supported types:
7575
// S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4,
76-
// F8E4M3B11FNUZ, F8E4M3, F8E3M4
76+
// F8E4M3B11FNUZ, F8E4M3, F8E3M4, F4E2M1FN
7777
default:
7878
return dt::undef;
7979
}

0 commit comments

Comments
 (0)