Skip to content

Commit 51db6eb

Browse files
[mlir][IR] Remove isF...() type API for low-precision FP types
1 parent 8c85f1f commit 51db6eb

File tree

11 files changed

+75
-95
lines changed

11 files changed

+75
-95
lines changed

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -329,31 +329,31 @@ def F64 : F<64>;
329329
def F80 : F<80>;
330330
def F128 : F<128>;
331331

332-
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
332+
def BF16 : Type<CPred<"::llvm::isa<BFloat16Type>($_self)">, "bfloat16 type">,
333333
BuildableType<"$_builder.getType<BFloat16Type>()">;
334-
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
334+
def TF32 : Type<CPred<"::llvm::isa<FloatTF32Type>($_self)">, "tf32 type">,
335335
BuildableType<"$_builder.getType<FloatTF32Type>()">;
336-
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
336+
def F8E4M3FN : Type<CPred<"::llvm::isa<Float8E4M3FNType>($_self)">, "f8E4M3FN type">,
337337
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
338-
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
338+
def F8E5M2 : Type<CPred<"::llvm::isa<Float8E5M2Type>($_self)">, "f8E5M2 type">,
339339
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
340-
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
340+
def F8E4M3 : Type<CPred<"::llvm::isa<Float8E4M3Type>($_self)">, "f8E4M3 type">,
341341
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
342-
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
342+
def F8E4M3FNUZ : Type<CPred<"::llvm::isa<Float8E4M3FNUZType>($_self)">, "f8E4M3FNUZ type">,
343343
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
344-
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
344+
def F8E4M3B11FNUZ : Type<CPred<"::llvm::isa<Float8E4M3B11FNUZType>($_self)">, "f8E4M3B11FNUZ type">,
345345
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
346-
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
346+
def F8E5M2FNUZ : Type<CPred<"::llvm::isa<Float8E5M2FNUZType>($_self)">, "f8E5M2FNUZ type">,
347347
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
348-
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
348+
def F8E3M4 : Type<CPred<"::llvm::isa<Float8E3M4Type>($_self)">, "f8E3M4 type">,
349349
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
350-
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
350+
def F4E2M1FN : Type<CPred<"::llvm::isa<Float4E2M1FNType>($_self)">, "f4E2M1FN type">,
351351
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
352-
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
352+
def F6E2M3FN : Type<CPred<"::llvm::isa<Float6E2M3FNType>($_self)">, "f6E2M3FN type">,
353353
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
354-
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
354+
def F6E3M2FN : Type<CPred<"::llvm::isa<Float6E3M2FNType($_self)">, "f6E3M2FN type">,
355355
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
356-
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
356+
def F8E8M0FNU : Type<CPred<"::llvm::isa<Float8E8M0FNUType>($_self)">, "f8E8M0FNU type">,
357357
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;
358358

359359
def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,

mlir/include/mlir/IR/Types.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,17 +125,6 @@ class Type {
125125
// Convenience predicates. This is only for floating point types,
126126
// derived types should use isa/dyn_cast.
127127
bool isIndex() const;
128-
bool isFloat4E2M1FN() const;
129-
bool isFloat6E2M3FN() const;
130-
bool isFloat6E3M2FN() const;
131-
bool isFloat8E5M2() const;
132-
bool isFloat8E4M3() const;
133-
bool isFloat8E4M3FN() const;
134-
bool isFloat8E5M2FNUZ() const;
135-
bool isFloat8E4M3FNUZ() const;
136-
bool isFloat8E4M3B11FNUZ() const;
137-
bool isFloat8E3M4() const;
138-
bool isFloat8E8M0FNU() const;
139128
bool isBF16() const;
140129
bool isF16() const;
141130
bool isTF32() const;

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
9090
}
9191

9292
bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
93-
return unwrap(type).isFloat4E2M1FN();
93+
return llvm::isa<Float4E2M1FNType>(unwrap(type));
9494
}
9595

9696
MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
@@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
102102
}
103103

104104
bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
105-
return unwrap(type).isFloat6E2M3FN();
105+
return llvm::isa<Float6E2M3FNType>(unwrap(type));
106106
}
107107

108108
MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
@@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
114114
}
115115

116116
bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
117-
return unwrap(type).isFloat6E3M2FN();
117+
return llvm::isa<Float6E3M2FNType>(unwrap(type));
118118
}
119119

120120
MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
@@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
126126
}
127127

128128
bool mlirTypeIsAFloat8E5M2(MlirType type) {
129-
return unwrap(type).isFloat8E5M2();
129+
return llvm::isa<Float8E5M2Type>(unwrap(type));
130130
}
131131

132132
MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
@@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
138138
}
139139

140140
bool mlirTypeIsAFloat8E4M3(MlirType type) {
141-
return unwrap(type).isFloat8E4M3();
141+
return llvm::isa<Float8E4M3Type>(unwrap(type));
142142
}
143143

144144
MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
@@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
150150
}
151151

152152
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
153-
return unwrap(type).isFloat8E4M3FN();
153+
return llvm::isa<Float8E4M3FNType>(unwrap(type));
154154
}
155155

156156
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
@@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
162162
}
163163

164164
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
165-
return unwrap(type).isFloat8E5M2FNUZ();
165+
return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
166166
}
167167

168168
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
@@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
174174
}
175175

176176
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
177-
return unwrap(type).isFloat8E4M3FNUZ();
177+
return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
178178
}
179179

180180
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
@@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
186186
}
187187

188188
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
189-
return unwrap(type).isFloat8E4M3B11FNUZ();
189+
return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
190190
}
191191

192192
MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
@@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
198198
}
199199

200200
bool mlirTypeIsAFloat8E3M4(MlirType type) {
201-
return unwrap(type).isFloat8E3M4();
201+
return llvm::isa<Float8E3M4Type>(unwrap(type));
202202
}
203203

204204
MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
@@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
210210
}
211211

212212
bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
213-
return unwrap(type).isFloat8E8M0FNU();
213+
return llvm::isa<Float8E8M0FNUType>(unwrap(type));
214214
}
215215

216216
MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
@@ -221,15 +221,19 @@ MlirTypeID mlirBFloat16TypeGetTypeID() {
221221
return wrap(BFloat16Type::getTypeID());
222222
}
223223

224-
bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
224+
bool mlirTypeIsABF16(MlirType type) {
225+
return llvm::isa<BFloat16Type>(unwrap(type));
226+
}
225227

226228
MlirType mlirBF16TypeGet(MlirContext ctx) {
227229
return wrap(BFloat16Type::get(unwrap(ctx)));
228230
}
229231

230232
MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
231233

232-
bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
234+
bool mlirTypeIsAF16(MlirType type) {
235+
return llvm::isa<Float16Type>(unwrap(type));
236+
}
233237

234238
MlirType mlirF16TypeGet(MlirContext ctx) {
235239
return wrap(Float16Type::get(unwrap(ctx)));
@@ -239,23 +243,29 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() {
239243
return wrap(FloatTF32Type::getTypeID());
240244
}
241245

242-
bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); }
246+
bool mlirTypeIsATF32(MlirType type) {
247+
return llvm::isa<FloatTF32Type>(unwrap(type));
248+
}
243249

244250
MlirType mlirTF32TypeGet(MlirContext ctx) {
245251
return wrap(FloatTF32Type::get(unwrap(ctx)));
246252
}
247253

248254
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
249255

250-
bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
256+
bool mlirTypeIsAF32(MlirType type) {
257+
return llvm::isa<Float32Type>(unwrap(type));
258+
}
251259

252260
MlirType mlirF32TypeGet(MlirContext ctx) {
253261
return wrap(Float32Type::get(unwrap(ctx)));
254262
}
255263

256264
MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
257265

258-
bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
266+
bool mlirTypeIsAF64(MlirType type) {
267+
return llvm::isa<Float64Type>(unwrap(type));
268+
}
259269

260270
MlirType mlirF64TypeGet(MlirContext ctx) {
261271
return wrap(Float64Type::get(unwrap(ctx)));

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -564,38 +564,40 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
564564
return ROCDL::mfma_f64_4x4x4f64::getOperationName();
565565
}
566566

567-
if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
567+
if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
568+
chipset >= kGfx940) {
568569
// Known to be correct because there are no scalar f8 instructions and
569570
// because a length mismatch will have been caught by the verifier.
570571
Type sourceBElem =
571572
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
572573
if (m == 16 && n == 16 && k == 32 && b == 1) {
573-
if (sourceBElem.isFloat8E5M2FNUZ())
574+
if (isa<Float8E5M2FNUZType>(sourceBElem))
574575
return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
575-
if (sourceBElem.isFloat8E4M3FNUZ())
576+
if (isa<Float8E4M3FNUZType>(sourceBElem))
576577
return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
577578
}
578579
if (m == 32 && n == 32 && k == 16 && b == 1) {
579-
if (sourceBElem.isFloat8E5M2FNUZ())
580+
if (isa<Float8E5M2FNUZType>(sourceBElem))
580581
return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
581-
if (sourceBElem.isFloat8E4M3FNUZ())
582+
if (isa<Float8E4M3FNUZType>(sourceBElem))
582583
return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
583584
}
584585
}
585586

586-
if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
587+
if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
588+
chipset >= kGfx940) {
587589
Type sourceBElem =
588590
cast<VectorType>(mfma.getSourceB().getType()).getElementType();
589591
if (m == 16 && n == 16 && k == 32 && b == 1) {
590-
if (sourceBElem.isFloat8E5M2FNUZ())
592+
if (isa<Float8E5M2FNUZType>(sourceBElem))
591593
return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
592-
if (sourceBElem.isFloat8E4M3FNUZ())
594+
if (isa<Float8E4M3FNUZType>(sourceBElem))
593595
return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
594596
}
595597
if (m == 32 && n == 32 && k == 16 && b == 1) {
596-
if (sourceBElem.isFloat8E5M2FNUZ())
598+
if (isa<Float8E5M2FNUZType>(sourceBElem))
597599
return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
598-
if (sourceBElem.isFloat8E4M3FNUZ())
600+
if (isa<Float8E4M3FNUZType>(sourceBElem))
599601
return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
600602
}
601603
}
@@ -623,9 +625,9 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
623625
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
624626
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
625627
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
626-
if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
628+
if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
627629
return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
628-
if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
630+
if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
629631
return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
630632
return std::nullopt;
631633
}
@@ -803,10 +805,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
803805
}
804806
Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
805807
Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
806-
if (sourceElemType.isFloat8E5M2FNUZ()) {
808+
if (isa<Float8E5M2FNUZType>(sourceElemType)) {
807809
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
808810
wordSel);
809-
} else if (sourceElemType.isFloat8E4M3FNUZ()) {
811+
} else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
810812
rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
811813
wordSel);
812814
}
@@ -838,10 +840,10 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
838840
Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());
839841

840842
Value result;
841-
if (resultElemType.isFloat8E5M2FNUZ())
843+
if (isa<Float8E5M2FNUZType>(resultElemType))
842844
result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
843845
existing, wordSel);
844-
else if (resultElemType.isFloat8E4M3FNUZ())
846+
else if (isa<Float8E4M3FNUZType>(resultElemType))
845847
result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
846848
existing, wordSel);
847849

@@ -873,10 +875,10 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
873875
Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());
874876

875877
Value result;
876-
if (resultElemType.isFloat8E5M2FNUZ())
878+
if (isa<Float8E5M2FNUZType>(resultElemType))
877879
result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
878880
existing, byteSel);
879-
else if (resultElemType.isFloat8E4M3FNUZ())
881+
else if (isa<Float8E4M3FNUZType>(resultElemType))
880882
result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
881883
existing, byteSel);
882884

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
8686
return failure();
8787
inType = inVecType.getElementType();
8888
}
89-
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
89+
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(inType));
9090
}
9191

9292
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
@@ -216,7 +216,7 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
216216
if (inType && inType.getWidth() <= 8 && saturateFP8)
217217
// Conversion between 8-bit floats is not supported with truncation enabled.
218218
return failure();
219-
return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
219+
return success(isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(outType));
220220
}
221221

222222
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,10 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
299299
return type;
300300

301301
// F4, F6, F8 types are converted to integer types with the same bit width.
302-
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
303-
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
304-
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
305-
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
306-
type.isFloat8E8M0FNU())
302+
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
303+
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
304+
Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
305+
Float8E8M0FNUType>(type))
307306
return IntegerType::get(&getContext(), type.getWidth());
308307

309308
// Other floating-point types: A custom type conversion rule must be

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,8 +1254,8 @@ struct NVGPUWarpgroupMmaOpLowering
12541254
wgmmaK = 8;
12551255
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
12561256
wgmmaK = 16;
1257-
} else if (inputElemType.isFloat8E4M3FN() ||
1258-
inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
1257+
} else if (isa<Float8E4M3FNType, Float8E5M2Type>(inputElemType) ||
1258+
inputElemType.isInteger(16)) {
12591259
wgmmaK = 32;
12601260
} else if (inputElemType.isInteger(1)) {
12611261
wgmmaK = 256;
@@ -1276,9 +1276,9 @@ struct NVGPUWarpgroupMmaOpLowering
12761276
return NVVM::WGMMATypes::f16;
12771277
if (elemType.isBF16())
12781278
return NVVM::WGMMATypes::bf16;
1279-
if (elemType.isFloat8E4M3FN())
1279+
if (isa<Float8E4M3FNType>(elemType))
12801280
return NVVM::WGMMATypes::e4m3;
1281-
if (elemType.isFloat8E5M2())
1281+
if (isa<Float8E5M2Type>(elemType))
12821282
return NVVM::WGMMATypes::e5m2;
12831283
if (elemType.isInteger(1))
12841284
return NVVM::WGMMATypes::b1;

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,14 @@ LogicalResult MFMAOp::verify() {
272272
}
273273

274274
Type sourceBType = getSourceB().getType();
275-
if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
275+
if (isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceElem)) {
276276
int64_t sourceBLen = 1;
277277
Type sourceBElem = sourceBType;
278278
if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
279279
sourceBLen = sourceBVector.getNumElements();
280280
sourceBElem = sourceBVector.getElementType();
281281
}
282-
if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
282+
if (!isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(sourceBElem))
283283
return emitOpError("expected both source operands to have f8 elements");
284284
if (sourceLen != sourceBLen)
285285
return emitOpError(

0 commit comments

Comments
 (0)