[llvm] 2a2c35a - [InstCombine] Fold `icmp spred (mul nsw X, Z), (mul nsw Y, Z)` into `icmp spred X, Y` (#110630)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 1 07:16:08 PDT 2024
Author: Yingwei Zheng
Date: 2024-10-01T22:16:05+08:00
New Revision: 2a2c35a9a652ba8562884ec76008979c761df207
URL: https://github.com/llvm/llvm-project/commit/2a2c35a9a652ba8562884ec76008979c761df207
DIFF: https://github.com/llvm/llvm-project/commit/2a2c35a9a652ba8562884ec76008979c761df207.diff
LOG: [InstCombine] Fold `icmp spred (mul nsw X, Z), (mul nsw Y, Z)` into `icmp spred X, Y` (#110630)
```
icmp spred (mul nsw X, Z), (mul nsw Y, Z) -> icmp spred X, Y iff Z > 0
icmp spred (mul nsw X, Z), (mul nsw Y, Z) -> icmp spred Y, X iff Z < 0
```
Alive2: https://alive2.llvm.org/ce/z/9fXFfn
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/test/Transforms/InstCombine/icmp-mul.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e3f4925024e65c..d4d45384ec90e3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5273,37 +5273,46 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
{
// Try to remove shared multiplier from comparison:
- // X * Z u{lt/le/gt/ge}/eq/ne Y * Z
+ // X * Z pred Y * Z
Value *X, *Y, *Z;
- if (Pred == ICmpInst::getUnsignedPredicate(Pred) &&
- ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
- match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
- (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
- match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) {
- bool NonZero;
- if (ICmpInst::isEquality(Pred)) {
- // If X != Y, fold (X *nw Z) eq/ne (Y *nw Z) -> Z eq/ne 0
- if (((Op0HasNSW && Op1HasNSW) || (Op0HasNUW && Op1HasNUW)) &&
- isKnownNonEqual(X, Y, DL, &AC, &I, &DT))
- return new ICmpInst(Pred, Z, Constant::getNullValue(Z->getType()));
-
- KnownBits ZKnown = computeKnownBits(Z, 0, &I);
- // if Z % 2 != 0
- // X * Z eq/ne Y * Z -> X eq/ne Y
- if (ZKnown.countMaxTrailingZeros() == 0)
- return new ICmpInst(Pred, X, Y);
- NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q);
- // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
- // X * Z eq/ne Y * Z -> X eq/ne Y
- if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
+ if ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
+ match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
+ (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
+ match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y))))) {
+ if (ICmpInst::isSigned(Pred)) {
+ if (Op0HasNSW && Op1HasNSW) {
+ KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+ if (ZKnown.isStrictlyPositive())
+ return new ICmpInst(Pred, X, Y);
+ if (ZKnown.isNegative())
+ return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), X, Y);
+ }
+ } else {
+ bool NonZero;
+ if (ICmpInst::isEquality(Pred)) {
+ // If X != Y, fold (X *nw Z) eq/ne (Y *nw Z) -> Z eq/ne 0
+ if (((Op0HasNSW && Op1HasNSW) || (Op0HasNUW && Op1HasNUW)) &&
+ isKnownNonEqual(X, Y, DL, &AC, &I, &DT))
+ return new ICmpInst(Pred, Z, Constant::getNullValue(Z->getType()));
+
+ KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+ // if Z % 2 != 0
+ // X * Z eq/ne Y * Z -> X eq/ne Y
+ if (ZKnown.countMaxTrailingZeros() == 0)
+ return new ICmpInst(Pred, X, Y);
+ NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q);
+ // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
+ // X * Z eq/ne Y * Z -> X eq/ne Y
+ if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
+ return new ICmpInst(Pred, X, Y);
+ } else
+ NonZero = isKnownNonZero(Z, Q);
+
+ // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
+ // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
+ if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
return new ICmpInst(Pred, X, Y);
- } else
- NonZero = isKnownNonZero(Z, Q);
-
- // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
- // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
- if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
- return new ICmpInst(Pred, X, Y);
+ }
}
}
diff --git a/llvm/test/Transforms/InstCombine/icmp-mul.ll b/llvm/test/Transforms/InstCombine/icmp-mul.ll
index 7ce43908c62cd0..a14f342ae2482b 100644
--- a/llvm/test/Transforms/InstCombine/icmp-mul.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-mul.ll
@@ -1330,3 +1330,137 @@ entry:
%cmp = icmp ult i8 %mul1, %mul2
ret i1 %cmp
}
+
+define i1 @icmp_mul_nsw_slt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_slt(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %mul1 = mul nsw i8 %x, 7
+ %mul2 = mul nsw i8 %y, 7
+ %cmp = icmp slt i8 %mul1, %mul2
+ ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sle(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sle(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp sle i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %mul1 = mul nsw i8 %x, 7
+ %mul2 = mul nsw i8 %y, 7
+ %cmp = icmp sle i8 %mul1, %mul2
+ ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sgt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sgt(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %mul1 = mul nsw i8 %x, 7
+ %mul2 = mul nsw i8 %y, 7
+ %cmp = icmp sgt i8 %mul1, %mul2
+ ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_sge(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_sge(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp sge i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %mul1 = mul nsw i8 %x, 7
+ %mul2 = mul nsw i8 %y, 7
+ %cmp = icmp sge i8 %mul1, %mul2
+ ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_neg(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_neg(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %mul1 = mul nsw i8 %x, -7
+ %mul2 = mul nsw i8 %y, -7
+ %cmp = icmp slt i8 %mul1, %mul2
+ ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_neg_var(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_neg_var(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[COND:%.*]] = icmp slt i8 [[Z:%.*]], 0
+; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %cond = icmp slt i8 %z, 0
+ call void @llvm.assume(i1 %cond)
+ %mul1 = mul nsw i8 %x, %z
+ %mul2 = mul nsw i8 %y, %z
+ %cmp = icmp slt i8 %mul1, %mul2
+ ret i1 %cmp
+}
+
+; Negative tests
+
+define i1 @icmp_mul_nonsw_slt(i8 %x, i8 %y) {
+; CHECK-LABEL: @icmp_mul_nonsw_slt(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[MUL1:%.*]] = mul i8 [[X:%.*]], 7
+; CHECK-NEXT: [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], 7
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %mul1 = mul i8 %x, 7
+ %mul2 = mul nsw i8 %y, 7
+ %cmp = icmp slt i8 %mul1, %mul2
+ ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_unknown_sign(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_unknown_sign(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], [[Z]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %mul1 = mul nsw i8 %x, %z
+ %mul2 = mul nsw i8 %y, %z
+ %cmp = icmp slt i8 %mul1, %mul2
+ ret i1 %cmp
+}
+
+define i1 @icmp_mul_nsw_slt_may_be_zero(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @icmp_mul_nsw_slt_may_be_zero(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[COND:%.*]] = icmp sgt i8 [[Z:%.*]], -1
+; CHECK-NEXT: call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[MUL1:%.*]] = mul nsw i8 [[X:%.*]], [[Z]]
+; CHECK-NEXT: [[MUL2:%.*]] = mul nsw i8 [[Y:%.*]], [[Z]]
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[MUL1]], [[MUL2]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ %cond = icmp sgt i8 %z, -1
+ call void @llvm.assume(i1 %cond)
+
+ %mul1 = mul nsw i8 %x, %z
+ %mul2 = mul nsw i8 %y, %z
+ %cmp = icmp slt i8 %mul1, %mul2
+ ret i1 %cmp
+}
More information about the llvm-commits mailing list