[flang-commits] [flang] 83ba00b - [flang] Handle special case for SHIFTA intrinsic

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Thu Sep 1 07:28:18 PDT 2022


Author: Valentin Clement
Date: 2022-09-01T16:28:08+02:00
New Revision: 83ba00bd2eb7ac13fe18edecb306c6c9e14d4d67

URL: https://github.com/llvm/llvm-project/commit/83ba00bd2eb7ac13fe18edecb306c6c9e14d4d67
DIFF: https://github.com/llvm/llvm-project/commit/83ba00bd2eb7ac13fe18edecb306c6c9e14d4d67.diff

LOG: [flang] Handle special case for SHIFTA intrinsic

This patch updates the lowering of the SHIFTA intrinsic to match
the behavior of gfortran. When the SHIFT value is equal to the
integer bit width, it is now handled as a special case.
This is needed because the operation used in lowering (`mlir::arith::ShRSIOp`)
lowers to `ashr`, which is undefined when the shift amount is equal to the
bit width of the operand.

Before this patch, flang produced the following results:

```
SHIFTA(  -1, 8) =  0
SHIFTA(  -2, 8) =  0
SHIFTA( -30, 8) =  0
SHIFTA( -31, 8) =  0
SHIFTA( -32, 8) =  0
SHIFTA( -33, 8) =  0
SHIFTA(-126, 8) =  0
SHIFTA(-127, 8) =  0
SHIFTA(-128, 8) =  0
```

gfortran, on the other hand, gives:

```
SHIFTA(  -1, 8) = -1
SHIFTA(  -2, 8) = -1
SHIFTA( -30, 8) = -1
SHIFTA( -31, 8) = -1
SHIFTA( -32, 8) = -1
SHIFTA( -33, 8) = -1
SHIFTA(-126, 8) = -1
SHIFTA(-127, 8) = -1
SHIFTA(-128, 8) = -1
```

With this patch flang and gfortran have the same behavior.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D133104

Added: 
    

Modified: 
    flang/lib/Lower/IntrinsicCall.cpp
    flang/test/Lower/Intrinsics/shifta.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp
index c5a1121985257..545c2d96a28f2 100644
--- a/flang/lib/Lower/IntrinsicCall.cpp
+++ b/flang/lib/Lower/IntrinsicCall.cpp
@@ -557,6 +557,7 @@ struct IntrinsicLibrary {
                              llvm::ArrayRef<mlir::Value> args);
   template <typename Shift>
   mlir::Value genShift(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genShiftA(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
   mlir::Value genSign(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genSize(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genSpacing(mlir::Type resultType,
@@ -958,7 +959,7 @@ static constexpr IntrinsicHandler handlers[]{
        {"radix", asAddr, handleDynamicOptional}}},
      /*isElemental=*/false},
     {"set_exponent", &I::genSetExponent},
-    {"shifta", &I::genShift<mlir::arith::ShRSIOp>},
+    {"shifta", &I::genShiftA},
     {"shiftl", &I::genShift<mlir::arith::ShLIOp>},
     {"shiftr", &I::genShift<mlir::arith::ShRUIOp>},
     {"sign", &I::genSign},
@@ -4015,7 +4016,7 @@ mlir::Value IntrinsicLibrary::genSetExponent(mlir::Type resultType,
                                    fir::getBase(args[1])));
 }
 
-// SHIFTA, SHIFTL, SHIFTR
+// SHIFTL, SHIFTR
 template <typename Shift>
 mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
                                        llvm::ArrayRef<mlir::Value> args) {
@@ -4041,6 +4042,31 @@ mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
   return builder.create<mlir::arith::SelectOp>(loc, outOfBounds, zero, shifted);
 }
 
+// SHIFTA
+mlir::Value IntrinsicLibrary::genShiftA(mlir::Type resultType,
+                                        llvm::ArrayRef<mlir::Value> args) {
+  unsigned bits = resultType.getIntOrFloatBitWidth();
+  mlir::Value bitSize = builder.createIntegerConstant(loc, resultType, bits);
+  mlir::Value shift = builder.createConvert(loc, resultType, args[1]);
+  mlir::Value shiftEqBitSize = builder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::eq, shift, bitSize);
+
+  // Lowering of mlir::arith::ShRSIOp is using `ashr`. `ashr` is undefined when
+  // the shift amount is equal to the element size.
+  // So if SHIFT is equal to the bit width then it is handled as a special case.
+  mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0);
+  mlir::Value minusOne = builder.createIntegerConstant(loc, resultType, -1);
+  mlir::Value valueIsNeg = builder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::slt, args[0], zero);
+  mlir::Value specialRes =
+      builder.create<mlir::arith::SelectOp>(loc, valueIsNeg, minusOne, zero);
+
+  mlir::Value shifted =
+      builder.create<mlir::arith::ShRSIOp>(loc, args[0], shift);
+  return builder.create<mlir::arith::SelectOp>(loc, shiftEqBitSize, specialRes,
+                                               shifted);
+}
+
 // SIGN
 mlir::Value IntrinsicLibrary::genSign(mlir::Type resultType,
                                       llvm::ArrayRef<mlir::Value> args) {

diff  --git a/flang/test/Lower/Intrinsics/shifta.f90 b/flang/test/Lower/Intrinsics/shifta.f90
index 311206b213d22..92fda39e84bf8 100644
--- a/flang/test/Lower/Intrinsics/shifta.f90
+++ b/flang/test/Lower/Intrinsics/shifta.f90
@@ -12,13 +12,14 @@ subroutine shifta1_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 8 : i8
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i8
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i8
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i8
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i8
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i8
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i8
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i8
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i8
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i8
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i8
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i8
 end subroutine shifta1_test
 
 ! CHECK-LABEL: shifta2_test
@@ -32,13 +33,14 @@ subroutine shifta2_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 16 : i16
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i16
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i16
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i16
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i16
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i16
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i16
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i16
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i16
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i16
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i16
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i16
 end subroutine shifta2_test
 
 ! CHECK-LABEL: shifta4_test
@@ -52,12 +54,13 @@ subroutine shifta4_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 32 : i32
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i32
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_VAL]], %[[C_0]] : i32
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_VAL]], %[[C_BITS]] : i32
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i32
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_VAL]], %[[C_BITS]] : i32
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i32
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i32
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i32
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i32
 end subroutine shifta4_test
 
 ! CHECK-LABEL: shifta8_test
@@ -71,13 +74,14 @@ subroutine shifta8_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 64 : i64
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i64
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i64
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i64
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i64
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i64
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i64
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i64
+  ! CHECK: %[[CM1:.*]] = arith.constant -1 : i64
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i64
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i64
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i64
 end subroutine shifta8_test
 
 ! CHECK-LABEL: shifta16_test
@@ -91,11 +95,12 @@ subroutine shifta16_test(a, b, c)
   ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
   c = shifta(a, b)
   ! CHECK: %[[C_BITS:.*]] = arith.constant 128 : i128
-  ! CHECK: %[[C_0:.*]] = arith.constant 0 : i128
   ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i128
-  ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i128
-  ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i128
-  ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
-  ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
-  ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i128
+  ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i128
+  ! CHECK: %[[C0:.*]] = arith.constant 0 : i128
+  ! CHECK: %[[CM1:.*]] = arith.constant {{.*}} : i128
+  ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i128
+  ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i128
+  ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
+  ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i128
 end subroutine shifta16_test


        


More information about the flang-commits mailing list