Skip to content

Commit f941908

Browse files
author
Ivy Zhang
authored
Revert "[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf" (#95344)
Reverts #93443
1 parent 90a23d3 commit f941908

File tree

7 files changed

+40
-213
lines changed

7 files changed

+40
-213
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
11991199
// ExtFOp
12001200
//===----------------------------------------------------------------------===//
12011201

1202-
def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
1202+
def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
12031203
let summary = "cast from floating-point to wider floating-point";
12041204
let description = [{
12051205
Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,13 +1208,6 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
12081208
}];
12091209
let hasVerifier = 1;
12101210
let hasFolder = 1;
1211-
1212-
let arguments = (ins FloatLike:$in, DefaultValuedAttr<
1213-
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
1214-
let results = (outs FloatLike:$out);
1215-
1216-
let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
1217-
attr-dict `:` type($in) `to` type($out) }];
12181211
}
12191212

12201213
//===----------------------------------------------------------------------===//
@@ -1253,11 +1246,8 @@ def Arith_TruncFOp :
12531246
Arith_Op<"truncf",
12541247
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
12551248
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
1256-
DeclareOpInterfaceMethods<ArithFastMathInterface>,
12571249
DeclareOpInterfaceMethods<CastOpInterface>]>,
12581250
Arguments<(ins FloatLike:$in,
1259-
DefaultValuedAttr<
1260-
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
12611251
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
12621252
Results<(outs FloatLike:$out)> {
12631253
let summary = "cast from floating-point to narrower floating-point";
@@ -1277,9 +1267,7 @@ def Arith_TruncFOp :
12771267

12781268
let hasFolder = 1;
12791269
let hasVerifier = 1;
1280-
let assemblyFormat = [{ $in ($roundingmode^)?
1281-
(`fastmath` `` $fastmath^)?
1282-
attr-dict `:` type($in) `to` type($out) }];
1270+
let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
12831271
}
12841272

12851273
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,20 +1390,6 @@ LogicalResult arith::ExtSIOp::verify() {
13901390
/// Fold extension of float constants when there is no information loss due the
13911391
/// difference in fp semantics.
13921392
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1393-
if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1394-
if (truncFOp.getOperand().getType() == getType()) {
1395-
arith::FastMathFlags truncFMF = truncFOp.getFastmath();
1396-
bool isTruncContract =
1397-
bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1398-
arith::FastMathFlags extFMF = getFastmath();
1399-
bool isExtContract =
1400-
bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1401-
if (isTruncContract && isExtContract) {
1402-
return truncFOp.getOperand();
1403-
}
1404-
}
1405-
}
1406-
14071393
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
14081394
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
14091395
return constFoldCastOp<FloatAttr, FloatAttr>(

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,8 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
9494
SmallVector<Value> newResults(expandedOp->getResults());
9595
for (auto [res, oldType, newType] : llvm::zip_equal(
9696
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
97-
if (oldType != newType) {
98-
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
99-
truncFOp.setFastmath(arith::FastMathFlags::contract);
100-
res = truncFOp.getResult();
101-
}
97+
if (oldType != newType)
98+
res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
10299
}
103100
rewriter.replaceOp(op, newResults);
104101
}
@@ -117,9 +114,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
117114
});
118115
converter.addTargetMaterialization(
119116
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
120-
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
121-
extFOp.setFastmath(arith::FastMathFlags::contract);
122-
return extFOp;
117+
return b.create<arith::ExtFOp>(loc, target, input);
123118
});
124119
}
125120

mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,7 @@ void mlir::math::populateLegalizeToF32TypeConverter(
5757
});
5858
typeConverter.addTargetMaterialization(
5959
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
60-
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
61-
extFOp.setFastmath(arith::FastMathFlags::contract);
62-
return extFOp;
60+
return b.create<arith::ExtFOp>(loc, target, input);
6361
});
6462
}
6563

@@ -86,11 +84,8 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
8684
SmallVector<Value> results = (*legalized)->getResults();
8785
for (auto [result, newType, origType] : llvm::zip_equal(
8886
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
89-
if (newType != origType) {
90-
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
91-
truncFOp.setFastmath(arith::FastMathFlags::contract);
92-
result = truncFOp.getResult();
93-
}
87+
if (newType != origType)
88+
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
9489
}
9590
rewriter.replaceOp(op, results);
9691
return success();

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -162,23 +162,23 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
162162
// Checking conversion of integer types to floating point.
163163
// CHECK-LABEL: @fpext
164164
func.func @fpext(%arg0 : f16, %arg1 : f32) {
165-
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f32
165+
// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f32
166166
%0 = arith.extf %arg0: f16 to f32
167-
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f64
167+
// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f64
168168
%1 = arith.extf %arg0: f16 to f64
169-
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f64
169+
// CHECK-NEXT: = llvm.fpext {{.*}} : f32 to f64
170170
%2 = arith.extf %arg1: f32 to f64
171171
return
172172
}
173173

174174
// Checking conversion of integer types to floating point.
175175
// CHECK-LABEL: @fpext
176176
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
177-
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf32>
177+
// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf32>
178178
%0 = arith.extf %arg0: vector<2xf16> to vector<2xf32>
179-
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf64>
179+
// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf64>
180180
%1 = arith.extf %arg0: vector<2xf16> to vector<2xf64>
181-
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf64>
181+
// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf32> to vector<2xf64>
182182
%2 = arith.extf %arg1: vector<2xf32> to vector<2xf64>
183183
return
184184
}
@@ -268,38 +268,38 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
268268
// Checking conversion of integer types to floating point.
269269
// CHECK-LABEL: @fptrunc
270270
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
271-
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f16
271+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : f32 to f16
272272
%0 = arith.truncf %arg0: f32 to f16
273-
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f16
273+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f16
274274
%1 = arith.truncf %arg1: f64 to f16
275-
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f32
275+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f32
276276
%2 = arith.truncf %arg1: f64 to f32
277277
return
278278
}
279279

280280
// Checking conversion of integer types to floating point.
281281
// CHECK-LABEL: @fptrunc
282282
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
283-
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf16>
283+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf32> to vector<2xf16>
284284
%0 = arith.truncf %arg0: vector<2xf32> to vector<2xf16>
285-
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf16>
285+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf16>
286286
%1 = arith.truncf %arg1: vector<2xf64> to vector<2xf16>
287-
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf32>
287+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf32>
288288
%2 = arith.truncf %arg1: vector<2xf64> to vector<2xf32>
289289
return
290290
}
291291

292292
// CHECK-LABEL: experimental_constrained_fptrunc
293293
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
294-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore {fastmath = #arith.fastmath<none>} : f64 to f32
294+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
295295
%0 = arith.truncf %arg0 to_nearest_even : f64 to f32
296-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
296+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
297297
%1 = arith.truncf %arg0 downward : f64 to f32
298-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
298+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
299299
%2 = arith.truncf %arg0 upward : f64 to f32
300-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore {fastmath = #arith.fastmath<none>} : f64 to f32
300+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
301301
%3 = arith.truncf %arg0 toward_zero : f64 to f32
302-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore {fastmath = #arith.fastmath<none>} : f64 to f32
302+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
303303
%4 = arith.truncf %arg0 to_nearest_away : f64 to f32
304304
return
305305
}

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 0 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -3031,143 +3031,6 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
30313031
return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
30323032
}
30333033

3034-
// CHECK-LABEL: @sequences_fastmath_contract
3035-
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
3036-
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3037-
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3038-
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3039-
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3040-
// CHECK: return [[TRUNCF]] : bf16
3041-
func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
3042-
%0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
3043-
%1 = math.absf %0 : f32
3044-
%2 = arith.truncf %1 fastmath<contract> : f32 to bf16
3045-
%3 = arith.extf %2 fastmath<contract> : bf16 to f32
3046-
%4 = math.sin %3 : f32
3047-
%5 = arith.truncf %4 fastmath<contract> : f32 to bf16
3048-
return %5 : bf16
3049-
}
3050-
3051-
// CHECK-LABEL: @sequences_no_fastmath
3052-
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
3053-
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3054-
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3055-
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
3056-
// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
3057-
// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
3058-
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3059-
// CHECK: return [[TRUNCF]] : bf16
3060-
func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
3061-
%0 = arith.extf %arg0 : bf16 to f32
3062-
%1 = math.absf %0 : f32
3063-
%2 = arith.truncf %1 : f32 to bf16
3064-
%3 = arith.extf %2 : bf16 to f32
3065-
%4 = math.sin %3 : f32
3066-
%5 = arith.truncf %4 : f32 to bf16
3067-
return %5 : bf16
3068-
}
3069-
3070-
// CHECK-LABEL: @eliminate_cast_to_f16
3071-
// CHECK: return [[arg0:%.+]] : f32
3072-
func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
3073-
%0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
3074-
%1 = arith.extf %0 fastmath<contract> : f16 to f32
3075-
return %1 : f32
3076-
}
3077-
3078-
// CHECK-LABEL: @eliminate_cast_to_bf16
3079-
// CHECK: return [[arg0:%.+]] : f32
3080-
func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
3081-
%0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
3082-
%1 = arith.extf %0 fastmath<contract> : bf16 to f32
3083-
return %1 : f32
3084-
}
3085-
3086-
// CHECK-LABEL: @bf16_sin_vector
3087-
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
3088-
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3089-
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3090-
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3091-
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3092-
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
3093-
func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3094-
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3095-
%1 = math.absf %0 : vector<32x32x32xf32>
3096-
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3097-
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3098-
%4 = math.sin %3 : vector<32x32x32xf32>
3099-
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3100-
return %5 : vector<32x32x32xbf16>
3101-
}
3102-
3103-
// CHECK-LABEL: @f16_sin_vector
3104-
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
3105-
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3106-
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3107-
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3108-
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3109-
// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
3110-
func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
3111-
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
3112-
%1 = math.absf %0 : vector<32x32x32xf32>
3113-
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
3114-
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
3115-
%4 = math.sin %3 : vector<32x32x32xf32>
3116-
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
3117-
return %5 : vector<32x32x32xf16>
3118-
}
3119-
3120-
// CHECK-LABEL: @bf16_branch_vector
3121-
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
3122-
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3123-
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3124-
// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
3125-
// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
3126-
// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
3127-
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
3128-
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
3129-
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3130-
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3131-
%1 = math.absf %0 : vector<32x32x32xf32>
3132-
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3133-
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3134-
%4 = math.sin %3 : vector<32x32x32xf32>
3135-
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3136-
%6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3137-
%7 = math.cos %3 : vector<32x32x32xf32>
3138-
%8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3139-
%9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3140-
%10 = arith.addf %6, %9 : vector<32x32x32xf32>
3141-
%11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3142-
return %11 : vector<32x32x32xbf16>
3143-
}
3144-
3145-
// CHECK-LABEL: @bf16_fma
3146-
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
3147-
// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
3148-
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
3149-
// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
3150-
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
3151-
// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
3152-
// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
3153-
// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
3154-
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
3155-
// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
3156-
func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3157-
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3158-
%1 = math.absf %0 : vector<32x32x32xf32>
3159-
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3160-
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3161-
%4 = math.sin %3 : vector<32x32x32xf32>
3162-
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3163-
%6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3164-
%7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
3165-
%8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3166-
%9 = arith.addf %8, %6 : vector<32x32x32xf32>
3167-
%10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3168-
return %10 : vector<32x32x32xbf16>
3169-
}
3170-
31713034
{-#
31723035
dialect_resources: {
31733036
builtin: {

0 commit comments

Comments
 (0)