[llvm] 2a2c35a - [InstCombine] Fold `icmp spred (mul nsw X, Z), (mul nsw Y, Z)` into `icmp spred X, Y` (#110630)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 1 07:16:08 PDT 2024


Author: Yingwei Zheng
Date: 2024-10-01T22:16:05+08:00
New Revision: 2a2c35a9a652ba8562884ec76008979c761df207

URL: https://github.com/llvm/llvm-project/commit/2a2c35a9a652ba8562884ec76008979c761df207
DIFF: https://github.com/llvm/llvm-project/commit/2a2c35a9a652ba8562884ec76008979c761df207.diff

LOG: [InstCombine] Fold `icmp spred (mul nsw X, Z), (mul nsw Y, Z)` into `icmp spred X, Y` (#110630)

```
icmp spred (mul nsw X, Z), (mul nsw Y, Z) -> icmp spred X, Y iff Z > 0
icmp spred (mul nsw X, Z), (mul nsw Y, Z) -> icmp spred Y, X iff Z < 0
```
Alive2: https://alive2.llvm.org/ce/z/9fXFfn

Added: 
    (none)

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/icmp-mul.ll

Removed: 
    (none)


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e3f4925024e65c..d4d45384ec90e3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5273,37 +5273,46 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
 
   {
     // Try to remove shared multiplier from comparison:
-    // X * Z u{lt/le/gt/ge}/eq/ne Y * Z
+    // X * Z pred Y * Z
     Value *X, *Y, *Z;
-    if (Pred == ICmpInst::getUnsignedPredicate(Pred) &&
-        ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
-          match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
-         (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
-          match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) {
-      bool NonZero;
-      if (ICmpInst::isEquality(Pred)) {
-        // If X != Y, fold (X *nw Z) eq/ne (Y *nw Z) -> Z eq/ne 0
-        if (((Op0HasNSW && Op1HasNSW) || (Op0HasNUW && Op1HasNUW)) &&
-            isKnownNonEqual(X, Y, DL, &AC, &I, &DT))
-          return new ICmpInst(Pred, Z, Constant::getNullValue(Z->getType()));
-
-        KnownBits ZKnown = computeKnownBits(Z, 0, &I);
-        // if Z % 2 != 0
-        //    X * Z eq/ne Y * Z -> X eq/ne Y
-        if (ZKnown.countMaxTrailingZeros() == 0)
-          return new ICmpInst(Pred, X, Y);
-        NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q);
-        // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
-        //    X * Z eq/ne Y * Z -> X eq/ne Y
-        if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
+    if ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
+         match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
+        (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
+         match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y))))) {
+      if (ICmpInst::isSigned(Pred)) {
+        if (Op0HasNSW && Op1HasNSW) {
+          KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+          if (ZKnown.isStrictlyPositive())
+            return new ICmpInst(Pred, X, Y);
+          if (ZKnown.isNegative())
+            return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), X, Y);
+        }
+      } else {
+        bool NonZero;
+        if (ICmpInst::isEquality(Pred)) {
+          // If X != Y, fold (X *nw Z) eq/ne (Y *nw Z) -> Z eq/ne 0
+          if (((Op0HasNSW && Op1HasNSW) || (Op0HasNUW && Op1HasNUW)) &&
+              isKnownNonEqual(X, Y, DL, &AC, &I, &DT))
+            return new ICmpInst(Pred, Z, Constant::getNullValue(Z->getType()));
+
+          KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+          // if Z % 2 != 0
+          //    X * Z eq/ne Y * Z -> X eq/ne Y
+          if (ZKnown.countMaxTrailingZeros() == 0)
+            return new ICmpInst(Pred, X, Y);
+          NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q);
+          // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
+          //    X * Z eq/ne Y * Z -> X eq/ne Y
+          if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
+            return new ICmpInst(Pred, X, Y);
+        } else
+          NonZero = isKnownNonZero(Z, Q);
+
+        // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
+        //    X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
+        if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
           return new ICmpInst(Pred, X, Y);
-      } else
-        NonZero = isKnownNonZero(Z, Q);
-
-      // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
-      //    X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
-      if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
-        return new ICmpInst(Pred, X, Y);
+      }
     }
   }
 

diff  --git a/llvm/test/Transforms/InstCombine/icmp-mul.ll b/llvm/test/Transforms/InstCombine/icmp-mul.ll
index 7ce43908c62cd0..a14f342ae2482b 100644
--- a/llvm/test/Transforms/InstCombine/icmp-mul.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-mul.ll
@@ -1330,3 +1330,137 @@ entry:
   %cmp = icmp ult i8 %mul1, %mul2
   ret i1 %cmp
 }
+
+define i1 @icmp_mul_nsw_slt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_slt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sle(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sle(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sle i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp sle i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sgt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sgt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp sgt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sge(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sge(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sge i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp sge i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_neg(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_neg(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, -7
+  %mul2 = mul nsw i8 %y, -7
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_neg_var(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_neg_var(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND:%.*]] = icmp slt i8 [[Z:%.*]], 0
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %cond = icmp slt i8 %z, 0
+  call void @llvm.assume(i1 %cond)
+  %mul1 = mul nsw i8 %x, %z
+  %mul2 = mul nsw i8 %y, %z
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+; Negative tests
+
+define i1 @icmp_mul_nonsw_slt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nonsw_slt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul i8 [[X:%.*]], 7
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_unknown_sign(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_unknown_sign(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], [[Z]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, %z
+  %mul2 = mul nsw i8 %y, %z
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_may_be_zero(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_may_be_zero(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND:%.*]] = icmp sgt i8 [[Z:%.*]], -1
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], [[Z]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], [[Z]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %cond = icmp sgt i8 %z, -1
+  call void @llvm.assume(i1 %cond)
+
+  %mul1 = mul nsw i8 %x, %z
+  %mul2 = mul nsw i8 %y, %z
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}


        


More information about the llvm-commits mailing list