@@ -811,118 +811,217 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value,
811
811
812
812
absl::StatusOr<llvm::Value*> EmitF16ToF4e2m1fn (llvm::Value* f16_value,
813
813
llvm::IRBuilder<>* b) {
814
+ auto i8_const = [&](int val) {
815
+ return llvm::ConstantInt::get (b->getInt8Ty (), val);
816
+ };
817
+ auto i16_const = [&](int val) {
818
+ return llvm::ConstantInt::get (b->getInt16Ty (), val);
819
+ };
820
+ constexpr int mantissa_diff = 9 ; // 10 for F16, 1 for F4
821
+ constexpr int bias_diff = 14 ; // 15 for F16, 1 for F4
822
+
823
+ // Cast the input value to an integer for bitwise manipulation.
824
+ // Get the absolute value of the input (discard the sign).
825
+ // f16_bits = bitcast(f16_value, int)
826
+ // f16_abs_bits = f16_bits & 0x7FFF
827
+ llvm::Value* f16_bits = b->CreateBitCast (f16_value, b->getInt16Ty ());
828
+ llvm::Value* f16_abs_bits = b->CreateAnd (f16_bits, i16_const (0x7FFF ));
829
+
830
+ // If the input absolute value is >= 7.0 or an infinity, the result saturates
831
+ // to max value (6.0). If (0.75 <= input < 1), the result is rounded to 1.0.
832
+ // If (0 <= input <= 0.25), the result is rounded to 0.0.
833
+ // If the input is NaN, the result is undefined (implemented as minus zero).
834
+ // The rest of the cases are handled by the "happy path".
835
+ // is_overflow = f16_abs_bits >= 0x1.Cp2
836
+ // is_one = f16_abs_bits >= 0x1.8p-1 (used only if exponent underflows)
837
+ // is_zero = f16_abs_bits <= 0x1p-2 (used only if exponent underflows)
838
+ // is_nan = f16_abs_bits > 0x7C00 (F16 NaN threshold)
839
+ llvm::Value* is_overflow =
840
+ b->CreateICmpUGE (f16_abs_bits, i16_const (0x4700 )); // 7.0
841
+ llvm::Value* is_one =
842
+ b->CreateICmpUGE (f16_abs_bits, i16_const (0x3A00 )); // 0.75
843
+ llvm::Value* is_zero =
844
+ b->CreateICmpULE (f16_abs_bits, i16_const (0x3400 )); // 0.25
845
+ llvm::Value* is_nan =
846
+ b->CreateICmpUGT (f16_abs_bits, i16_const (0x7C00 )); // inf
847
+
848
+ // Truncate the mantissa to 1 bit and the exponent to 3 bits (not 2 bits, as
849
+ // the type doesn't have Inf/NaN and can represent unbiased exponent 2).
850
+ // This case, as well as the denormal, is handled below.
814
851
TF_ASSIGN_OR_RETURN (
815
852
llvm::Value * reduced_precision,
816
853
EmitReducePrecisionIR (
817
854
/* src_ty=*/ F16, f16_value,
818
855
/* dest_exponent_bits=*/ primitive_util::ExponentWidth (F4E2M1FN) + 1 ,
819
856
/* dest_mantissa_bits=*/ primitive_util::SignificandWidth (F4E2M1FN) - 1 ,
820
857
/* quiet_nans=*/ false , b));
858
+
859
+ // Cast the reduced precision value to an integer for bitwise manipulation.
860
+ // Discard the least significant (9) mantissa bits leaving 1 bit.
861
+ // Truncate to
862
+ // as_int16 = bitcast(reduced_precision, int)
863
+ // as_int8 = as_int16 >> (f16_mantissa - f4_mantissa)
821
864
llvm::Value* as_int16 = b->CreateBitCast (reduced_precision, b->getInt16Ty ());
822
865
llvm::Value* as_int8 =
823
- b->CreateTrunc (b->CreateLShr (as_int16, 9 ), b->getInt8Ty ());
866
+ b->CreateTrunc (b->CreateLShr (as_int16, mantissa_diff ), b->getInt8Ty ());
824
867
825
- // Extract sign, exponent and mantissa from reduced precision value.
826
- auto i8_const = [&](int val) {
827
- return llvm::ConstantInt::get (b->getInt8Ty (), val);
828
- };
868
+ // Get the sign (0 or 1).
869
+ // f4_sign = as_int8 >> 6
829
870
llvm::Value* f4_sign = b->CreateLShr (as_int8, 6 );
871
+
872
+ // Get exponent and mantissa bits without the sign.
873
+ // Important: the mask is 0x3F (not 0x7F), discard bit #6.
874
+ // f4_bits = as_int8 & 0x3F
830
875
llvm::Value* f4_bits = b->CreateAnd (as_int8, i8_const (0x3F ));
831
- llvm::Value* f4_normal = b->CreateSub (f4_bits, i8_const (28 ));
832
876
833
- // Special case for exponent overflow.
834
- auto i16_const = [&](int val) {
835
- return llvm::ConstantInt::get (b->getInt16Ty (), val);
836
- };
837
- llvm::Value* f16_bits = b->CreateAnd (
838
- b->CreateBitCast (f16_value, b->getInt16Ty ()), i16_const (0x7FFF ));
839
- llvm::Value* is_overflow =
840
- b->CreateICmpUGE (f16_bits, i16_const (0x4700 )); // 7.0
841
- llvm::Value* is_nan = b->CreateICmpUGT (f16_bits, i16_const (0x7C00 )); // inf
842
- llvm::Value* max_or_nan =
843
- b->CreateSelect (is_nan, i8_const (0x8 ), i8_const (0x7 ));
844
- llvm::Value* f4_normal_or_overflow =
845
- b->CreateSelect (is_overflow, max_or_nan, f4_normal);
846
-
847
- // Special case for exponent underflow.
877
+ // Convert F16 exponent to F4 exponent by readjusting the exponent bias.
878
+ // This produces the "normal" result, i.e. not Inf or NaN or denormal.
879
+ // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa)
880
+ constexpr int f4_exponent_offset = bias_diff << 1 ;
881
+ llvm::Value* f4_normal = b->CreateSub (f4_bits, i8_const (f4_exponent_offset));
882
+
883
+ // If the rounding resulted in zero exponent, the value is incorrect.
884
+ // This happens when the input is < 1.0
885
+ // is_underflow = f4_normal <= 1
848
886
llvm::Value* is_underflow = b->CreateICmpSLE (f4_normal, i8_const (1 ));
849
- llvm::Value* is_one = b->CreateICmpUGE (f16_bits, i16_const (0x3A00 )); // 0.75
850
- llvm::Value* is_zero = b->CreateICmpULE (f16_bits, i16_const (0x3400 )); // 0.25
851
- llvm::Value* denorm_or_zero =
852
- b->CreateSelect (is_zero, i8_const (0x0 ), i8_const (0x1 ));
853
- llvm::Value* f4_small =
854
- b->CreateSelect (is_one, i8_const (0x2 ), denorm_or_zero);
855
- llvm::Value* f4_result =
856
- b->CreateSelect (is_underflow, f4_small, f4_normal_or_overflow);
887
+
888
+ // Chain of selects that handles the special cases.
889
+ // f4_result =
890
+ // is_underflow ? (is_one ? 1.0 : (is_zero ? 0.0 : 0.5)) :
891
+ // is_overflow ? (is_nan ? -0.0 : 6.0) :
892
+ // f4_normal
893
+ llvm::Value* f4_result = b->CreateSelect (
894
+ is_underflow,
895
+ // If underflow, the input is < 1.0; the result is either 0.0, 0.5 or 1.0
896
+ b->CreateSelect (is_one, i8_const (0x2 ),
897
+ b->CreateSelect (is_zero, i8_const (0x0 ), i8_const (0x1 ))),
898
+ // If overflow, the input is >= 7.0 or infinity or NaN.
899
+ b->CreateSelect (is_overflow,
900
+ b->CreateSelect (is_nan, i8_const (0x8 ), i8_const (0x7 )),
901
+ f4_normal));
857
902
858
903
// Add sign to the resulting value.
904
+ // f4_signed_result = (f4_sign << 3) | f4_result
859
905
return b->CreateOr (f4_result, b->CreateShl (f4_sign, 3 ));
860
906
}
861
907
862
908
llvm::Value* EmitF4e2m1fnToF16 (llvm::Value* f8_value, llvm::IRBuilder<>* b) {
863
- llvm::Value* as_int16 = b->CreateZExt (f8_value, b->getInt16Ty ());
864
-
865
- // Extract sign, exponent and mantissa from reduced precision value.
866
909
auto i16_const = [&](int val) {
867
910
return llvm::ConstantInt::get (b->getInt16Ty (), val);
868
911
};
869
- llvm::Value* sign = b->CreateLShr (as_int16, 3 );
870
- llvm::Value* sign_shifted = b->CreateShl (sign, 15 );
871
- llvm::Value* bits = b->CreateAnd (as_int16, i16_const (0x7 ));
872
- llvm::Value* bits_shifted = b->CreateShl (bits, 9 );
873
-
874
- // Re-bias the exponent and handle denormals.
875
- llvm::Value* f16_normal = b->CreateAdd (bits_shifted, i16_const (14 << 10 ));
876
- llvm::Value* is_denorm_or_zero = b->CreateICmpULE (bits, i16_const (1 ));
877
- llvm::Value* is_zero = b->CreateICmpEQ (bits, i16_const (0 ));
878
- llvm::Value* denorm_or_zero =
879
- b->CreateSelect (is_zero, i16_const (0x0000 ), i16_const (0x3800 ));
880
- llvm::Value* f16_result =
881
- b->CreateSelect (is_denorm_or_zero, denorm_or_zero, f16_normal);
912
+ constexpr int mantissa_diff = 9 ; // 10 for F16, 1 for F4
913
+ constexpr int bias_diff = 14 ; // 15 for F16, 1 for F4
914
+
915
+ // The input value is a 8-bit integer, extend it to 16-bit integer.
916
+ // as_int16 = bitcast(f8_value, int)
917
+ llvm::Value* as_int16 = b->CreateZExt (f8_value, b->getInt16Ty ());
918
+
919
+ // Get the sign and shift it to F16 position.
920
+ // f4_sign = as_int16 >> 3
921
+ // f16_sign_bit = f4_sign << 15
922
+ llvm::Value* f4_sign = b->CreateLShr (as_int16, 3 );
923
+ llvm::Value* f16_sign_bit = b->CreateShl (f4_sign, 15 );
924
+
925
+ // Get exponent and mantissa bits without the sign.
926
+ // f4_bits = as_int16 & 0x7
927
+ // f16_bits = f4_bits << (f16_mantissa - f4_mantissa)
928
+ llvm::Value* f4_bits = b->CreateAnd (as_int16, i16_const (0x7 ));
929
+ llvm::Value* f16_bits = b->CreateShl (f4_bits, mantissa_diff);
930
+
931
+ // Convert F16 exponent to F4 exponent by readjusting the exponent bias.
932
+ // f4_normal = f4_bits - ((f16_bias - f4_bias) << f4_mantissa)
933
+ constexpr int f16_exponent_offset = bias_diff << 10 ;
934
+ llvm::Value* f16_normal =
935
+ b->CreateAdd (f16_bits, i16_const (f16_exponent_offset));
936
+
937
+ // For denormal and zero, the exponent is different. Handle these cases
938
+ // separately below.
939
+ // is_denorm_or_zero = f4_bits <= 1
940
+ // is_zero = f4_bits == 0
941
+ llvm::Value* is_denorm_or_zero = b->CreateICmpULE (f4_bits, i16_const (1 ));
942
+ llvm::Value* is_zero = b->CreateICmpEQ (f4_bits, i16_const (0 ));
943
+
944
+ // Chain of selects that handles the special cases.
945
+ // f16_result = is_denorm_or_zero ? (is_zero ? 0.0 : 0.5) : f16_normal
946
+ llvm::Value* f16_result = b->CreateSelect (
947
+ is_denorm_or_zero,
948
+ b->CreateSelect (is_zero, i16_const (0x0000 ), i16_const (0x3800 )),
949
+ f16_normal);
882
950
883
951
// Add sign to the resulting value.
884
- llvm::Value* f16_signed = b->CreateOr (f16_result, sign_shifted);
885
- return b->CreateBitCast (f16_signed, b->getHalfTy ());
952
+ // f16_signed_result = f16_sign_bit | f16_result
953
+ llvm::Value* f16_signed_result = b->CreateOr (f16_result, f16_sign_bit);
954
+ return b->CreateBitCast (f16_signed_result, b->getHalfTy ());
886
955
}
887
956
888
957
llvm::Value* EmitF32ToF8e8m0fnu (llvm::Value* f32_value, llvm::IRBuilder<>* b) {
889
- llvm::Value* as_int32 = b->CreateBitCast (f32_value, b->getInt32Ty ());
890
-
891
- // Result is NaN if input is zero, negative, infinity or NaN.
892
958
auto i32_const = [&](int val) {
893
959
return llvm::ConstantInt::get (b->getInt32Ty (), val);
894
960
};
895
- llvm::Value* is_denorm = b->CreateICmpULE (as_int32, i32_const (0x400000 ));
896
- llvm::Value* is_nan =
897
- b->CreateOr (b->CreateICmpSLE (as_int32, i32_const (0 )),
898
- b->CreateICmpSGE (as_int32, i32_const (0x7F400000 )));
899
961
900
- // Round the value and extract exponent.
901
- llvm::Value* rounded = b->CreateAdd (as_int32, i32_const (0x400000 ));
902
- llvm::Value* shifted = b->CreateAShr (rounded, 23 );
903
- llvm::Value* finite = b->CreateSelect (is_denorm, i32_const (0x00 ), shifted);
904
- llvm::Value* f32_result = b->CreateSelect (is_nan, i32_const (0xFF ), finite );
962
+ // Cast the input value to an integer for bitwise manipulation.
963
+ // as_int32 = bitcast(f32_value, int)
964
+ llvm::Value* as_int32 = b->CreateBitCast (f32_value, b->getInt32Ty ());
965
+
966
+ // Check if the input is zero, negative, overflow, infinity or NaN.
967
+ // All of these cases cannot be represented in the E8M0 format.
968
+ // is_zero_or_negative = as_int32 <= 0
969
+ // is_overflow_or_nan = as_int32 >= 0x1.8p127
970
+ // is_nan = is_zero_or_negative | is_overflow_or_nan
971
+ llvm::Value* is_zero_or_negative = b->CreateICmpSLE (as_int32, i32_const (0 ));
972
+ llvm::Value* is_overflow_or_nan =
973
+ b->CreateICmpSGE (as_int32, i32_const (0x7F400000 )); // 1.5 * 2^127
974
+ llvm::Value* is_nan = b->CreateOr (is_zero_or_negative, is_overflow_or_nan);
975
+
976
+ // Check if the input is a denormal which should round to the minimum value
977
+ // (2^-127), as there is no zero value.
978
+ // is_denorm = as_int32 <= 0x1p-127
979
+ llvm::Value* is_denorm =
980
+ b->CreateICmpULE (as_int32, i32_const (0x400000 )); // 1.0 * 2^-127
981
+
982
+ // Round the value (always up) and discard the mantissa.
983
+ // rounded = as_int32 + 0x1p-127
984
+ // f8_normal = as_int32 >> f32_mantissa
985
+ llvm::Value* rounded =
986
+ b->CreateAdd (as_int32, i32_const (0x400000 )); // 1.0 * 2^-127
987
+ llvm::Value* f8_normal = b->CreateAShr (rounded, 23 );
988
+
989
+ // Chain of selects that handles the special cases.
990
+ // f8_result = is_nan ? 0xFF : (is_denorm ? 0x00 : f8_normal)
991
+ llvm::Value* f8_result =
992
+ b->CreateSelect (is_nan, i32_const (0xFF ),
993
+ b->CreateSelect (is_denorm, i32_const (0x00 ), f8_normal));
905
994
906
995
// Truncate to the result type.
907
- return b->CreateTrunc (f32_result , b->getInt8Ty ());
996
+ return b->CreateTrunc (f8_result , b->getInt8Ty ());
908
997
}
909
998
910
999
llvm::Value* EmitF8e8m0fnuToF32 (llvm::Value* f8_value, llvm::IRBuilder<>* b) {
911
- // Shift exponent to the left for the normal case.
912
- llvm::Value* as_int32 = b->CreateZExt (f8_value, b->getInt32Ty ());
913
- llvm::Value* shifted = b->CreateShl (as_int32, 23 );
914
-
915
- // Special values for 0x00 (denorm) and 0xFF (NaN).
916
1000
auto i32_const = [&](int val) {
917
1001
return llvm::ConstantInt::get (b->getInt32Ty (), val);
918
1002
};
1003
+
1004
+ // The input value is a 8-bit integer, extend it to 32-bit integer.
1005
+ // as_int32 = bitcast(f8_value, int)
1006
+ llvm::Value* as_int32 = b->CreateZExt (f8_value, b->getInt32Ty ());
1007
+
1008
+ // Check if the input is a denormal or NaN.
1009
+ // is_zero = as_int32 == 0x00
1010
+ // is_nan = as_int32 == 0xFF
919
1011
llvm::Value* is_zero = b->CreateICmpEQ (as_int32, i32_const (0 ));
920
1012
llvm::Value* is_nan = b->CreateICmpEQ (as_int32, i32_const (0xFF ));
921
- llvm::Value* denorm_or_shifted =
922
- b->CreateSelect (is_zero, i32_const (0x00400000 ), shifted);
923
- llvm::Value* f32_result =
924
- b->CreateSelect (is_nan, i32_const (0x7FC00000 ), denorm_or_shifted);
925
1013
1014
+ // Shift exponent to the left for the normal case.
1015
+ // f32_normal = as_int32 << mantissa_diff
1016
+ llvm::Value* f32_normal = b->CreateShl (as_int32, 23 );
1017
+
1018
+ // Chain of selects that handles the special cases.
1019
+ // f32_result = is_nan ? 0x7FC00000 : (is_zero ? 0x1p-127 : f32_normal)
1020
+ llvm::Value* f32_result = b->CreateSelect (
1021
+ is_nan, i32_const (0x7FC00000 ),
1022
+ b->CreateSelect (is_zero, i32_const (0x400000 ), f32_normal));
1023
+
1024
+ // Bitcast integer bits to the result type.
926
1025
return b->CreateBitCast (f32_result, b->getFloatTy ());
927
1026
}
928
1027
0 commit comments