Skip to content

Commit 83ba00b

Browse files
committed
[flang] Handle special case for SHIFTA intrinsic
This patch update the lowering of the shifta intrinsic to match the behvior of gfortran. When the SHIFT value is equal to the integer bitwidth then we handle it differently. This is due to the operation used in lowering (`mlir::arith::ShRSIOp`) that lowers to `ashr`. Before this patch we have the following results: ``` SHIFTA( -1, 8) = 0 SHIFTA( -2, 8) = 0 SHIFTA( -30, 8) = 0 SHIFTA( -31, 8) = 0 SHIFTA( -32, 8) = 0 SHIFTA( -33, 8) = 0 SHIFTA(-126, 8) = 0 SHIFTA(-127, 8) = 0 SHIFTA(-128, 8) = 0 ``` While gfortran is giving this: ``` SHIFTA( -1, 8) = -1 SHIFTA( -2, 8) = -1 SHIFTA( -30, 8) = -1 SHIFTA( -31, 8) = -1 SHIFTA( -32, 8) = -1 SHIFTA( -33, 8) = -1 SHIFTA(-126, 8) = -1 SHIFTA(-127, 8) = -1 SHIFTA(-128, 8) = -1 ``` With this patch flang and gfortran have the same behavior. Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D133104
1 parent 4f046bc commit 83ba00b

File tree

2 files changed

+63
-32
lines changed

2 files changed

+63
-32
lines changed

flang/lib/Lower/IntrinsicCall.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ struct IntrinsicLibrary {
557557
llvm::ArrayRef<mlir::Value> args);
558558
template <typename Shift>
559559
mlir::Value genShift(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
560+
mlir::Value genShiftA(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
560561
mlir::Value genSign(mlir::Type, llvm::ArrayRef<mlir::Value>);
561562
fir::ExtendedValue genSize(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
562563
mlir::Value genSpacing(mlir::Type resultType,
@@ -958,7 +959,7 @@ static constexpr IntrinsicHandler handlers[]{
958959
{"radix", asAddr, handleDynamicOptional}}},
959960
/*isElemental=*/false},
960961
{"set_exponent", &I::genSetExponent},
961-
{"shifta", &I::genShift<mlir::arith::ShRSIOp>},
962+
{"shifta", &I::genShiftA},
962963
{"shiftl", &I::genShift<mlir::arith::ShLIOp>},
963964
{"shiftr", &I::genShift<mlir::arith::ShRUIOp>},
964965
{"sign", &I::genSign},
@@ -4015,7 +4016,7 @@ mlir::Value IntrinsicLibrary::genSetExponent(mlir::Type resultType,
40154016
fir::getBase(args[1])));
40164017
}
40174018

4018-
// SHIFTA, SHIFTL, SHIFTR
4019+
// SHIFTL, SHIFTR
40194020
template <typename Shift>
40204021
mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
40214022
llvm::ArrayRef<mlir::Value> args) {
@@ -4041,6 +4042,31 @@ mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
40414042
return builder.create<mlir::arith::SelectOp>(loc, outOfBounds, zero, shifted);
40424043
}
40434044

4045+
// SHIFTA
4046+
mlir::Value IntrinsicLibrary::genShiftA(mlir::Type resultType,
4047+
llvm::ArrayRef<mlir::Value> args) {
4048+
unsigned bits = resultType.getIntOrFloatBitWidth();
4049+
mlir::Value bitSize = builder.createIntegerConstant(loc, resultType, bits);
4050+
mlir::Value shift = builder.createConvert(loc, resultType, args[1]);
4051+
mlir::Value shiftEqBitSize = builder.create<mlir::arith::CmpIOp>(
4052+
loc, mlir::arith::CmpIPredicate::eq, shift, bitSize);
4053+
4054+
// Lowering of mlir::arith::ShRSIOp is using `ashr`. `ashr` is undefined when
4055+
// the shift amount is equal to the element size.
4056+
// So if SHIFT is equal to the bit width then it is handled as a special case.
4057+
mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
4058+
mlir::Value minusOne = builder.createIntegerConstant(loc, resultType, -1);
4059+
mlir::Value valueIsNeg = builder.create<mlir::arith::CmpIOp>(
4060+
loc, mlir::arith::CmpIPredicate::slt, args[0], zero);
4061+
mlir::Value specialRes =
4062+
builder.create<mlir::arith::SelectOp>(loc, valueIsNeg, minusOne, zero);
4063+
4064+
mlir::Value shifted =
4065+
builder.create<mlir::arith::ShRSIOp>(loc, args[0], shift);
4066+
return builder.create<mlir::arith::SelectOp>(loc, shiftEqBitSize, specialRes,
4067+
shifted);
4068+
}
4069+
40444070
// SIGN
40454071
mlir::Value IntrinsicLibrary::genSign(mlir::Type resultType,
40464072
llvm::ArrayRef<mlir::Value> args) {

flang/test/Lower/Intrinsics/shifta.f90

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@ subroutine shifta1_test(a, b, c)
1212
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
1313
c = shifta(a, b)
1414
! CHECK: %[[C_BITS:.*]] = arith.constant 8 : i8
15-
! CHECK: %[[C_0:.*]] = arith.constant 0 : i8
1615
! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i8
17-
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i8
18-
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i8
19-
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
20-
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
21-
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i8
16+
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i8
17+
! CHECK: %[[C0:.*]] = arith.constant 0 : i8
18+
! CHECK: %[[CM1:.*]] = arith.constant -1 : i8
19+
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i8
20+
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i8
21+
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
22+
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i8
2223
end subroutine shifta1_test
2324

2425
! CHECK-LABEL: shifta2_test
@@ -32,13 +33,14 @@ subroutine shifta2_test(a, b, c)
3233
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
3334
c = shifta(a, b)
3435
! CHECK: %[[C_BITS:.*]] = arith.constant 16 : i16
35-
! CHECK: %[[C_0:.*]] = arith.constant 0 : i16
3636
! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i16
37-
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i16
38-
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i16
39-
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
40-
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
41-
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i16
37+
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i16
38+
! CHECK: %[[C0:.*]] = arith.constant 0 : i16
39+
! CHECK: %[[CM1:.*]] = arith.constant -1 : i16
40+
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i16
41+
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i16
42+
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
43+
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i16
4244
end subroutine shifta2_test
4345

4446
! CHECK-LABEL: shifta4_test
@@ -52,12 +54,13 @@ subroutine shifta4_test(a, b, c)
5254
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
5355
c = shifta(a, b)
5456
! CHECK: %[[C_BITS:.*]] = arith.constant 32 : i32
55-
! CHECK: %[[C_0:.*]] = arith.constant 0 : i32
56-
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_VAL]], %[[C_0]] : i32
57-
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_VAL]], %[[C_BITS]] : i32
58-
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
59-
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
60-
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i32
57+
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_VAL]], %[[C_BITS]] : i32
58+
! CHECK: %[[C0:.*]] = arith.constant 0 : i32
59+
! CHECK: %[[CM1:.*]] = arith.constant -1 : i32
60+
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i32
61+
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i32
62+
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
63+
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i32
6164
end subroutine shifta4_test
6265

6366
! CHECK-LABEL: shifta8_test
@@ -71,13 +74,14 @@ subroutine shifta8_test(a, b, c)
7174
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
7275
c = shifta(a, b)
7376
! CHECK: %[[C_BITS:.*]] = arith.constant 64 : i64
74-
! CHECK: %[[C_0:.*]] = arith.constant 0 : i64
7577
! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i64
76-
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i64
77-
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i64
78-
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
79-
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
80-
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i64
78+
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i64
79+
! CHECK: %[[C0:.*]] = arith.constant 0 : i64
80+
! CHECK: %[[CM1:.*]] = arith.constant -1 : i64
81+
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i64
82+
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i64
83+
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
84+
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i64
8185
end subroutine shifta8_test
8286

8387
! CHECK-LABEL: shifta16_test
@@ -91,11 +95,12 @@ subroutine shifta16_test(a, b, c)
9195
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
9296
c = shifta(a, b)
9397
! CHECK: %[[C_BITS:.*]] = arith.constant 128 : i128
94-
! CHECK: %[[C_0:.*]] = arith.constant 0 : i128
9598
! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i128
96-
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i128
97-
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i128
98-
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
99-
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
100-
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i128
99+
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i128
100+
! CHECK: %[[C0:.*]] = arith.constant 0 : i128
101+
! CHECK: %[[CM1:.*]] = arith.constant {{.*}} : i128
102+
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i128
103+
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i128
104+
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
105+
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i128
101106
end subroutine shifta16_test

0 commit comments

Comments
 (0)