[llvm] 0399473 - [InstCombine] add fold for (ShiftC >> X) >u C

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 19 08:04:58 PDT 2022


Author: Sanjay Patel
Date: 2022-06-19T11:03:28-04:00
New Revision: 0399473de886595d8ce3346f2cc99c94267496e5

URL: https://github.com/llvm/llvm-project/commit/0399473de886595d8ce3346f2cc99c94267496e5
DIFF: https://github.com/llvm/llvm-project/commit/0399473de886595d8ce3346f2cc99c94267496e5.diff

LOG: [InstCombine] add fold for (ShiftC >> X) >u C

https://alive2.llvm.org/ce/z/RcdzM-

This fixes a regression noted in issue #56046.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/icmp-shr.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 49f4464a996d..c4bfdcf29525 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2215,22 +2215,35 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
   if (Cmp.isEquality() && Shr->isExact() && C.isZero())
     return new ICmpInst(Pred, X, Cmp.getOperand(1));
 
-  const APInt *ShiftVal;
-  if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal)))
-    return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftVal);
+  bool IsAShr = Shr->getOpcode() == Instruction::AShr;
+  const APInt *ShiftValC;
+  if (match(Shr->getOperand(0), m_APInt(ShiftValC))) {
+    if (Cmp.isEquality())
+      return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftValC);
 
-  const APInt *ShiftAmt;
-  if (!match(Shr->getOperand(1), m_APInt(ShiftAmt)))
+    // If the shifted constant is a power-of-2, test the shift amount directly:
+    // (ShiftValC >> X) >u C --> X <u (LZ(C) - LZ(ShiftValC))
+    // TODO: Handle ult.
+    if (!IsAShr && Pred == CmpInst::ICMP_UGT && ShiftValC->isPowerOf2()) {
+      assert(ShiftValC->ugt(C) && "Expected simplify of compare");
+      unsigned CmpLZ = C.countLeadingZeros();
+      unsigned ShiftLZ = ShiftValC->countLeadingZeros();
+      Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ);
+      return new ICmpInst(ICmpInst::ICMP_ULT, Shr->getOperand(1), NewC);
+    }
+  }
+
+  const APInt *ShiftAmtC;
+  if (!match(Shr->getOperand(1), m_APInt(ShiftAmtC)))
     return nullptr;
 
   // Check that the shift amount is in range. If not, don't perform undefined
   // shifts. When the shift is visited it will be simplified.
   unsigned TypeBits = C.getBitWidth();
-  unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits);
+  unsigned ShAmtVal = ShiftAmtC->getLimitedValue(TypeBits);
   if (ShAmtVal >= TypeBits || ShAmtVal == 0)
     return nullptr;
 
-  bool IsAShr = Shr->getOpcode() == Instruction::AShr;
   bool IsExact = Shr->isExact();
   Type *ShrTy = Shr->getType();
   // TODO: If we could guarantee that InstSimplify would handle all of the

diff  --git a/llvm/test/Transforms/InstCombine/icmp-shr.ll b/llvm/test/Transforms/InstCombine/icmp-shr.ll
index 2d78916356ab..f013e6bdcbe2 100644
--- a/llvm/test/Transforms/InstCombine/icmp-shr.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-shr.ll
@@ -1008,8 +1008,7 @@ define i1 @ashr_exact_ne_0_multiuse(i8 %x) {
 
 define i1 @lshr_pow2_ugt(i8 %x) {
 ; CHECK-LABEL: @lshr_pow2_ugt(
-; CHECK-NEXT:    [[S:%.*]] = lshr i8 2, [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[S]], 1
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[X:%.*]], 0
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = lshr i8 2, %x
@@ -1021,7 +1020,7 @@ define i1 @lshr_pow2_ugt_use(i8 %x) {
 ; CHECK-LABEL: @lshr_pow2_ugt_use(
 ; CHECK-NEXT:    [[S:%.*]] = lshr i8 -128, [[X:%.*]]
 ; CHECK-NEXT:    call void @use(i8 [[S]])
-; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[S]], 5
+; CHECK-NEXT:    [[R:%.*]] = icmp ult i8 [[X]], 5
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = lshr i8 128, %x
@@ -1032,8 +1031,7 @@ define i1 @lshr_pow2_ugt_use(i8 %x) {
 
 define <2 x i1> @lshr_pow2_ugt_vec(<2 x i8> %x) {
 ; CHECK-LABEL: @lshr_pow2_ugt_vec(
-; CHECK-NEXT:    [[S:%.*]] = lshr <2 x i8> <i8 8, i8 8>, [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ugt <2 x i8> [[S]], <i8 6, i8 6>
+; CHECK-NEXT:    [[R:%.*]] = icmp eq <2 x i8> [[X:%.*]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[R]]
 ;
   %s = lshr <2 x i8> <i8 8, i8 8>, %x
@@ -1041,6 +1039,8 @@ define <2 x i1> @lshr_pow2_ugt_vec(<2 x i8> %x) {
   ret <2 x i1> %r
 }
 
+; negative test - need power-of-2
+
 define i1 @lshr_not_pow2_ugt(i8 %x) {
 ; CHECK-LABEL: @lshr_not_pow2_ugt(
 ; CHECK-NEXT:    [[S:%.*]] = lshr i8 3, [[X:%.*]]
@@ -1054,8 +1054,7 @@ define i1 @lshr_not_pow2_ugt(i8 %x) {
 
 define i1 @lshr_pow2_ugt1(i8 %x) {
 ; CHECK-LABEL: @lshr_pow2_ugt1(
-; CHECK-NEXT:    [[S:%.*]] = lshr i8 -128, [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[S]], 1
+; CHECK-NEXT:    [[R:%.*]] = icmp ult i8 [[X:%.*]], 7
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %s = lshr i8 128, %x
@@ -1063,6 +1062,8 @@ define i1 @lshr_pow2_ugt1(i8 %x) {
   ret i1 %r
 }
 
+; negative test - need logical shift
+
 define i1 @ashr_pow2_ugt(i8 %x) {
 ; CHECK-LABEL: @ashr_pow2_ugt(
 ; CHECK-NEXT:    [[S:%.*]] = ashr i8 -128, [[X:%.*]]
@@ -1074,6 +1075,8 @@ define i1 @ashr_pow2_ugt(i8 %x) {
   ret i1 %r
 }
 
+; negative test - need unsigned pred
+
 define i1 @lshr_pow2_sgt(i8 %x) {
 ; CHECK-LABEL: @lshr_pow2_sgt(
 ; CHECK-NEXT:    [[S:%.*]] = lshr i8 -128, [[X:%.*]]


        


More information about the llvm-commits mailing list