Skip to content

Commit b0f0676

Browse files
authored
[AArch64] Implement intrinsics for SME FP8 F1CVT/F2CVT and BF1CVT/BF2CVT (#118027)
This patch implements the following intrinsics: 8-bit floating-point convert to half-precision or BFloat16 (in-order). ``` c // Variant is also available for: _bf16[_mf8]_x2 svfloat16x2_t svcvt1_f16[_mf8]_x2_fpm(svmfloat8_t zn, fpm_t fpm) __arm_streaming; svfloat16x2_t svcvt2_f16[_mf8]_x2_fpm(svmfloat8_t zn, fpm_t fpm) __arm_streaming; ``` In accordance with ARM-software/acle#323. Co-authored-by: Marin Lukac [email protected] Co-authored-by: Caroline Concatto [email protected]
1 parent 2ab687e commit b0f0676

File tree

6 files changed

+151
-10
lines changed

6 files changed

+151
-10
lines changed

clang/include/clang/Basic/arm_sve.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2429,6 +2429,10 @@ let SVETargetGuard = InvalidMode, SMETargetGuard = "sme2,fp8" in {
24292429
def FSCALE_X2 : Inst<"svscale[_{d}_x2]", "222.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x2", [IsStreaming],[]>;
24302430
def FSCALE_X4 : Inst<"svscale[_{d}_x4]", "444.x", "fhd", MergeNone, "aarch64_sme_fp8_scale_x4", [IsStreaming],[]>;
24312431

2432+
// Convert from FP8 to half-precision/BFloat16 multi-vector
2433+
def SVF1CVT : Inst<"svcvt1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt1_x2", [IsStreaming, SetsFPMR], []>;
2434+
def SVF2CVT : Inst<"svcvt2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvt2_x2", [IsStreaming, SetsFPMR], []>;
2435+
24322436
// Convert from FP8 to deinterleaved half-precision/BFloat16 multi-vector
24332437
def SVF1CVTL : Inst<"svcvtl1_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl1_x2", [IsStreaming, SetsFPMR], []>;
24342438
def SVF2CVTL : Inst<"svcvtl2_{d}[_mf8]_x2_fpm", "2~>", "bh", MergeNone, "aarch64_sve_fp8_cvtl2_x2", [IsStreaming, SetsFPMR], []>;

clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_cvt.c

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,70 @@
1616
#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
1717
#endif
1818

19+
// CHECK-LABEL: @test_cvt1_f16_x2(
20+
// CHECK-NEXT: entry:
21+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
22+
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
23+
// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
24+
//
25+
// CPP-CHECK-LABEL: @_Z16test_cvt1_f16_x2u13__SVMfloat8_tm(
26+
// CPP-CHECK-NEXT: entry:
27+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
28+
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
29+
// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
30+
//
31+
svfloat16x2_t test_cvt1_f16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
32+
return SVE_ACLE_FUNC(svcvt1_f16,_mf8,_x2_fpm)(zn, fpmr);
33+
}
34+
35+
// CHECK-LABEL: @test_cvt2_f16_x2(
36+
// CHECK-NEXT: entry:
37+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
38+
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
39+
// CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
40+
//
41+
// CPP-CHECK-LABEL: @_Z16test_cvt2_f16_x2u13__SVMfloat8_tm(
42+
// CPP-CHECK-NEXT: entry:
43+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
44+
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> [[ZN:%.*]])
45+
// CPP-CHECK-NEXT: ret { <vscale x 8 x half>, <vscale x 8 x half> } [[TMP0]]
46+
//
47+
svfloat16x2_t test_cvt2_f16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
48+
return SVE_ACLE_FUNC(svcvt2_f16,_mf8,_x2_fpm)(zn, fpmr);
49+
}
50+
51+
// CHECK-LABEL: @test_cvt1_bf16_x2(
52+
// CHECK-NEXT: entry:
53+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
54+
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
55+
// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
56+
//
57+
// CPP-CHECK-LABEL: @_Z17test_cvt1_bf16_x2u13__SVMfloat8_tm(
58+
// CPP-CHECK-NEXT: entry:
59+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
60+
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
61+
// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
62+
//
63+
svbfloat16x2_t test_cvt1_bf16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
64+
return SVE_ACLE_FUNC(svcvt1_bf16,_mf8,_x2_fpm)(zn, fpmr);
65+
}
66+
67+
// CHECK-LABEL: @test_cvt2_bf16_x2(
68+
// CHECK-NEXT: entry:
69+
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
70+
// CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
71+
// CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
72+
//
73+
// CPP-CHECK-LABEL: @_Z17test_cvt2_bf16_x2u13__SVMfloat8_tm(
74+
// CPP-CHECK-NEXT: entry:
75+
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])
76+
// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> [[ZN:%.*]])
77+
// CPP-CHECK-NEXT: ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } [[TMP0]]
78+
//
79+
svbfloat16x2_t test_cvt2_bf16_x2(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
80+
return SVE_ACLE_FUNC(svcvt2_bf16,_mf8,_x2_fpm)(zn, fpmr);
81+
}
82+
1983
// CHECK-LABEL: @test_cvtl1_f16_x2(
2084
// CHECK-NEXT: entry:
2185
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR:%.*]])

clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_cvt.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,13 @@ void test_features_sme2_fp8(svmfloat8_t zn, fpm_t fpmr) __arm_streaming {
1414
svcvtl1_bf16_mf8_x2_fpm(zn, fpmr);
1515
// expected-error@+1 {{'svcvtl2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
1616
svcvtl2_bf16_mf8_x2_fpm(zn, fpmr);
17+
18+
// expected-error@+1 {{'svcvt1_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
19+
svcvt1_f16_mf8_x2_fpm(zn, fpmr);
20+
// expected-error@+1 {{'svcvt2_f16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
21+
svcvt2_f16_mf8_x2_fpm(zn, fpmr);
22+
// expected-error@+1 {{'svcvt1_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
23+
svcvt1_bf16_mf8_x2_fpm(zn, fpmr);
24+
// expected-error@+1 {{'svcvt2_bf16_mf8_x2_fpm' needs target feature sme,sme2,fp8}}
25+
svcvt2_bf16_mf8_x2_fpm(zn, fpmr);
1726
}

llvm/include/llvm/IR/IntrinsicsAArch64.td

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3812,16 +3812,6 @@ let TargetPrefix = "aarch64" in {
38123812
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>,
38133813
LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>, LLVMVectorOfBitcastsToInt<0>],
38143814
[IntrNoMem]>;
3815-
3816-
class SME2_FP8_CVT_X2_Single_Intrinsic
3817-
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
3818-
[llvm_nxv16i8_ty],
3819-
[IntrReadMem, IntrInaccessibleMemOnly]>;
3820-
//
3821-
// CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector
3822-
//
3823-
def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
3824-
def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
38253815
}
38263816

38273817
// SVE2.1 - ZIPQ1, ZIPQ2, UZPQ1, UZPQ2
@@ -3864,3 +3854,25 @@ def int_aarch64_sve_famin_u : AdvSIMD_Pred2VectorArg_Intrinsic;
38643854
// Neon absolute maximum and minimum
38653855
def int_aarch64_neon_famax : AdvSIMD_2VectorArg_Intrinsic;
38663856
def int_aarch64_neon_famin : AdvSIMD_2VectorArg_Intrinsic;
3857+
3858+
//
3859+
// FP8 Intrinsics
3860+
//
3861+
let TargetPrefix = "aarch64" in {
3862+
3863+
class SME2_FP8_CVT_X2_Single_Intrinsic
3864+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty, LLVMMatchType<0>],
3865+
[llvm_nxv16i8_ty],
3866+
[IntrReadMem, IntrInaccessibleMemOnly]>;
3867+
//
3868+
// CVT from FP8 to half-precision/BFloat16 multi-vector
3869+
//
3870+
def int_aarch64_sve_fp8_cvt1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
3871+
def int_aarch64_sve_fp8_cvt2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
3872+
3873+
//
3874+
// CVT from FP8 to deinterleaved half-precision/BFloat16 multi-vector
3875+
//
3876+
def int_aarch64_sve_fp8_cvtl1_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
3877+
def int_aarch64_sve_fp8_cvtl2_x2 : SME2_FP8_CVT_X2_Single_Intrinsic;
3878+
}

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5581,6 +5581,18 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) {
55815581
{AArch64::BF2CVTL_2ZZ_BtoH, AArch64::F2CVTL_2ZZ_BtoH}))
55825582
SelectCVTIntrinsicFP8(Node, 2, Opc);
55835583
return;
5584+
case Intrinsic::aarch64_sve_fp8_cvt1_x2:
5585+
if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::FP>(
5586+
Node->getValueType(0),
5587+
{AArch64::BF1CVT_2ZZ_BtoH, AArch64::F1CVT_2ZZ_BtoH}))
5588+
SelectCVTIntrinsicFP8(Node, 2, Opc);
5589+
return;
5590+
case Intrinsic::aarch64_sve_fp8_cvt2_x2:
5591+
if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::FP>(
5592+
Node->getValueType(0),
5593+
{AArch64::BF2CVT_2ZZ_BtoH, AArch64::F2CVT_2ZZ_BtoH}))
5594+
SelectCVTIntrinsicFP8(Node, 2, Opc);
5595+
return;
55845596
}
55855597
} break;
55865598
case ISD::INTRINSIC_WO_CHAIN: {

llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-cvt.ll

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,46 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
22
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme2,+fp8 -verify-machineinstrs -force-streaming < %s | FileCheck %s
33

4+
; F1CVT / F2CVT
5+
6+
define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvt(<vscale x 16 x i8> %zm) {
7+
; CHECK-LABEL: f1cvt:
8+
; CHECK: // %bb.0:
9+
; CHECK-NEXT: f1cvt { z0.h, z1.h }, z0.b
10+
; CHECK-NEXT: ret
11+
%res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8f16(<vscale x 16 x i8> %zm)
12+
ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
13+
}
14+
15+
define { <vscale x 8 x half>, <vscale x 8 x half> } @f2cvt(<vscale x 16 x i8> %zm) {
16+
; CHECK-LABEL: f2cvt:
17+
; CHECK: // %bb.0:
18+
; CHECK-NEXT: f2cvt { z0.h, z1.h }, z0.b
19+
; CHECK-NEXT: ret
20+
%res = call { <vscale x 8 x half>, <vscale x 8 x half> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8f16(<vscale x 16 x i8> %zm)
21+
ret { <vscale x 8 x half>, <vscale x 8 x half> } %res
22+
}
23+
24+
; BF1CVT / BF2CVT
25+
26+
define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf1cvt(<vscale x 16 x i8> %zm) {
27+
; CHECK-LABEL: bf1cvt:
28+
; CHECK: // %bb.0:
29+
; CHECK-NEXT: bf1cvt { z0.h, z1.h }, z0.b
30+
; CHECK-NEXT: ret
31+
%res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt1.x2.nxv8bf16(<vscale x 16 x i8> %zm)
32+
ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
33+
}
34+
35+
define { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @bf2cvt(<vscale x 16 x i8> %zm) {
36+
; CHECK-LABEL: bf2cvt:
37+
; CHECK: // %bb.0:
38+
; CHECK-NEXT: bf2cvt { z0.h, z1.h }, z0.b
39+
; CHECK-NEXT: ret
40+
%res = call { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } @llvm.aarch64.sve.fp8.cvt2.x2.nxv8bf16(<vscale x 16 x i8> %zm)
41+
ret { <vscale x 8 x bfloat>, <vscale x 8 x bfloat> } %res
42+
}
43+
444
; F1CVTL / F2CVTL
545

646
define { <vscale x 8 x half>, <vscale x 8 x half> } @f1cvtl(<vscale x 16 x i8> %zm) {

0 commit comments

Comments
 (0)