[llvm] [InstCombine] Fold `icmp spred (mul nsw X, Z), (mul nsw Y, Z)` into `icmp spred X, Y` (PR #110630)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 30 23:06:16 PDT 2024


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/110630

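This patch folds a signed comparison of two `mul nsw` instructions that share a multiplier `Z` into a direct comparison of the other operands, provided the sign of `Z` is known:
```
icmp spred (mul nsw X, Z), (mul nsw Y, Z) -> icmp spred X, Y iff Z > 0
icmp spred (mul nsw X, Z), (mul nsw Y, Z) -> icmp spred Y, X iff Z < 0
```
The fold is not performed when the sign of `Z` is unknown, when `Z` may be zero, or when either multiplication lacks `nsw`.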
Alive2: https://alive2.llvm.org/ce/z/9fXFfn


>From 2c014e9380095e25fda19ff91f0f7f530105795b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 1 Oct 2024 13:50:18 +0800
Subject: [PATCH 1/2] [InstCombine] Add pre-commit tests. NFC.

---
 llvm/test/Transforms/InstCombine/icmp-mul.ll | 146 +++++++++++++++++++
 1 file changed, 146 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/icmp-mul.ll b/llvm/test/Transforms/InstCombine/icmp-mul.ll
index 7ce43908c62cd0..88a5cf8ba31d38 100644
--- a/llvm/test/Transforms/InstCombine/icmp-mul.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-mul.ll
@@ -1330,3 +1330,149 @@ entry:
   %cmp = icmp ult i8 %mul1, %mul2
   ret i1 %cmp
 }
+
+define i1 @icmp_mul_nsw_slt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_slt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], 7
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sle(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sle(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], 7
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sle i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp sle i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sgt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sgt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], 7
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp sgt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sge(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sge(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], 7
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sge i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp sge i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_neg(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_neg(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], -7
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], -7
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, -7
+  %mul2 = mul nsw i8 %y, -7
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_neg_var(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_neg_var(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND:%.*]] = icmp slt i8 [[Z:%.*]], 0
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], [[Z]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], [[Z]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %cond = icmp slt i8 %z, 0
+  call void @llvm.assume(i1 %cond)
+  %mul1 = mul nsw i8 %x, %z
+  %mul2 = mul nsw i8 %y, %z
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+; Negative tests
+
+define i1 @icmp_mul_nonsw_slt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nonsw_slt(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul i8 [[X:%.*]], 7
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul i8 %x, 7
+  %mul2 = mul nsw i8 %y, 7
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_unknown_sign(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_unknown_sign(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], [[Z]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %mul1 = mul nsw i8 %x, %z
+  %mul2 = mul nsw i8 %y, %z
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_may_be_zero(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_may_be_zero(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND:%.*]] = icmp sgt i8 [[Z:%.*]], -1
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], [[Z]]
+; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], [[Z]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %cond = icmp sgt i8 %z, -1
+  call void @llvm.assume(i1 %cond)
+
+  %mul1 = mul nsw i8 %x, %z
+  %mul2 = mul nsw i8 %y, %z
+  %cmp = icmp slt i8 %mul1, %mul2
+  ret i1 %cmp
+}

>From df173c3c724d8c6281e3eca8f9ae85be8652c142 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 1 Oct 2024 14:02:35 +0800
Subject: [PATCH 2/2] [InstCombine] Fold `icmp spred (mul nsw X, Z), (mul nsw
 Y, Z)` into `icmp spred X, Y`

---
 .../InstCombine/InstCombineCompares.cpp       | 67 +++++++++++--------
 llvm/test/Transforms/InstCombine/icmp-mul.ll  | 24 ++-----
 2 files changed, 44 insertions(+), 47 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e3f4925024e65c..d4d45384ec90e3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5273,37 +5273,46 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
 
   {
     // Try to remove shared multiplier from comparison:
-    // X * Z u{lt/le/gt/ge}/eq/ne Y * Z
+    // X * Z pred Y * Z
     Value *X, *Y, *Z;
-    if (Pred == ICmpInst::getUnsignedPredicate(Pred) &&
-        ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
-          match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
-         (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
-          match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) {
-      bool NonZero;
-      if (ICmpInst::isEquality(Pred)) {
-        // If X != Y, fold (X *nw Z) eq/ne (Y *nw Z) -> Z eq/ne 0
-        if (((Op0HasNSW && Op1HasNSW) || (Op0HasNUW && Op1HasNUW)) &&
-            isKnownNonEqual(X, Y, DL, &AC, &I, &DT))
-          return new ICmpInst(Pred, Z, Constant::getNullValue(Z->getType()));
-
-        KnownBits ZKnown = computeKnownBits(Z, 0, &I);
-        // if Z % 2 != 0
-        //    X * Z eq/ne Y * Z -> X eq/ne Y
-        if (ZKnown.countMaxTrailingZeros() == 0)
-          return new ICmpInst(Pred, X, Y);
-        NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q);
-        // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
-        //    X * Z eq/ne Y * Z -> X eq/ne Y
-        if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
+    if ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
+         match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
+        (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
+         match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y))))) {
+      if (ICmpInst::isSigned(Pred)) {
+        if (Op0HasNSW && Op1HasNSW) {
+          KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+          if (ZKnown.isStrictlyPositive())
+            return new ICmpInst(Pred, X, Y);
+          if (ZKnown.isNegative())
+            return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), X, Y);
+        }
+      } else {
+        bool NonZero;
+        if (ICmpInst::isEquality(Pred)) {
+          // If X != Y, fold (X *nw Z) eq/ne (Y *nw Z) -> Z eq/ne 0
+          if (((Op0HasNSW && Op1HasNSW) || (Op0HasNUW && Op1HasNUW)) &&
+              isKnownNonEqual(X, Y, DL, &AC, &I, &DT))
+            return new ICmpInst(Pred, Z, Constant::getNullValue(Z->getType()));
+
+          KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+          // if Z % 2 != 0
+          //    X * Z eq/ne Y * Z -> X eq/ne Y
+          if (ZKnown.countMaxTrailingZeros() == 0)
+            return new ICmpInst(Pred, X, Y);
+          NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q);
+          // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
+          //    X * Z eq/ne Y * Z -> X eq/ne Y
+          if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
+            return new ICmpInst(Pred, X, Y);
+        } else
+          NonZero = isKnownNonZero(Z, Q);
+
+        // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
+        //    X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
+        if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
           return new ICmpInst(Pred, X, Y);
-      } else
-        NonZero = isKnownNonZero(Z, Q);
-
-      // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
-      //    X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
-      if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
-        return new ICmpInst(Pred, X, Y);
+      }
     }
   }
 
diff --git a/llvm/test/Transforms/InstCombine/icmp-mul.ll b/llvm/test/Transforms/InstCombine/icmp-mul.ll
index 88a5cf8ba31d38..a14f342ae2482b 100644
--- a/llvm/test/Transforms/InstCombine/icmp-mul.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-mul.ll
@@ -1334,9 +1334,7 @@ entry:
 define i1 @icmp_mul_nsw_slt(i8 %x, i8 %y) {
 ; CHECK-LABEL: @icmp_mul_nsw_slt(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], 7
-; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -1349,9 +1347,7 @@ entry:
 define i1 @icmp_mul_nsw_sle(i8 %x, i8 %y) {
 ; CHECK-LABEL: @icmp_mul_nsw_sle(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], 7
-; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sle i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sle i8 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -1364,9 +1360,7 @@ entry:
 define i1 @icmp_mul_nsw_sgt(i8 %x, i8 %y) {
 ; CHECK-LABEL: @icmp_mul_nsw_sgt(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], 7
-; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -1379,9 +1373,7 @@ entry:
 define i1 @icmp_mul_nsw_sge(i8 %x, i8 %y) {
 ; CHECK-LABEL: @icmp_mul_nsw_sge(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], 7
-; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sge i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sge i8 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -1394,9 +1386,7 @@ entry:
 define i1 @icmp_mul_nsw_slt_neg(i8 %x, i8 %y) {
 ; CHECK-LABEL: @icmp_mul_nsw_slt_neg(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], -7
-; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], -7
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -1411,9 +1401,7 @@ define i1 @icmp_mul_nsw_slt_neg_var(i8 %x, i8 %y, i8 %z) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COND:%.*]] = icmp slt i8 [[Z:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
-; CHECK-NEXT:    [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], [[Z]]
-; CHECK-NEXT:    [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:



More information about the llvm-commits mailing list