[llvm] b60cf84 - [InstCombine] Add more cases for simplifying `(icmp (and/or x, Mask), y)`
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 19 15:17:49 PDT 2024
Author: Noah Goldstein
Date: 2024-03-19T17:17:35-05:00
New Revision: b60cf84e0965ac12b83494f803ea0dd6dec0db77
URL: https://github.com/llvm/llvm-project/commit/b60cf84e0965ac12b83494f803ea0dd6dec0db77
DIFF: https://github.com/llvm/llvm-project/commit/b60cf84e0965ac12b83494f803ea0dd6dec0db77.diff
LOG: [InstCombine] Add more cases for simplifying `(icmp (and/or x, Mask), y)`
This cleans up basically all of the regressions associated with #84688.
Proof of all new cases: https://alive2.llvm.org/ce/z/5yYWLb
Closes #85445
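For illustration, here is a minimal LLVM IR sketch of one of the newly handled folds, in the style of the alive2 proof (the @src/@tgt function names are illustrative, not from the commit): with a high-bit mask `~Mask = (-1 << y)`, the compare `x & ~Mask == ~Mask` becomes the direct unsigned compare `~Mask u<= x`:

    define i1 @src(i8 %x, i8 %y) {
      %not_mask = shl i8 -1, %y         ; ~Mask: all-ones in the high bits
      %and = and i8 %not_mask, %x
      %r = icmp eq i8 %and, %not_mask   ; x & ~Mask == ~Mask
      ret i1 %r
    }

    define i1 @tgt(i8 %x, i8 %y) {
      %not_mask = shl i8 -1, %y
      %r = icmp ule i8 %not_mask, %x    ; -> ~Mask u<= x
      ret i1 %r
    }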
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf1588..db302d7e526844 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4177,7 +4177,9 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
/// a check for a lossy truncation.
/// Folds:
/// icmp SrcPred (x & Mask), x to icmp DstPred x, Mask
+/// icmp SrcPred (x & ~Mask), ~Mask to icmp DstPred x, ~Mask
/// icmp eq/ne (x & ~Mask), 0 to icmp DstPred x, Mask
+/// icmp eq/ne (~x | Mask), -1 to icmp DstPred x, Mask
/// Where Mask is some pattern that produces all-ones in low bits:
/// (-1 >> y)
/// ((-1 << y) >> y) <- non-canonical, has extra uses
@@ -4189,82 +4191,126 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
Value *Op1, const SimplifyQuery &Q,
InstCombiner &IC) {
- Value *X, *M;
- bool NeedsNot = false;
-
- auto CheckMask = [&](Value *V, bool Not) {
- if (ICmpInst::isSigned(Pred) && !match(V, m_ImmConstant()))
- return false;
- return isMaskOrZero(V, Not, Q);
- };
-
- if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M))) &&
- CheckMask(M, /*Not*/ false)) {
- X = Op1;
- } else if (match(Op1, m_Zero()) && ICmpInst::isEquality(Pred) &&
- match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
- NeedsNot = true;
- if (IC.isFreeToInvert(X, X->hasOneUse()) && CheckMask(X, /*Not*/ true))
- std::swap(X, M);
- else if (!IC.isFreeToInvert(M, M->hasOneUse()) ||
- !CheckMask(M, /*Not*/ true))
- return nullptr;
- } else {
- return nullptr;
- }
ICmpInst::Predicate DstPred;
switch (Pred) {
case ICmpInst::Predicate::ICMP_EQ:
- // x & (-1 >> y) == x -> x u<= (-1 >> y)
+ // x & Mask == x
+ // x & ~Mask == 0
+ // ~x | Mask == -1
+ // -> x u<= Mask
+ // x & ~Mask == ~Mask
+ // -> ~Mask u<= x
DstPred = ICmpInst::Predicate::ICMP_ULE;
break;
case ICmpInst::Predicate::ICMP_NE:
- // x & (-1 >> y) != x -> x u> (-1 >> y)
+ // x & Mask != x
+ // x & ~Mask != 0
+ // ~x | Mask != -1
+ // -> x u> Mask
+ // x & ~Mask != ~Mask
+ // -> ~Mask u> x
DstPred = ICmpInst::Predicate::ICMP_UGT;
break;
case ICmpInst::Predicate::ICMP_ULT:
- // x & (-1 >> y) u< x -> x u> (-1 >> y)
- // x u> x & (-1 >> y) -> x u> (-1 >> y)
+ // x & Mask u< x
+ // -> x u> Mask
+ // x & ~Mask u< ~Mask
+ // -> ~Mask u> x
DstPred = ICmpInst::Predicate::ICMP_UGT;
break;
case ICmpInst::Predicate::ICMP_UGE:
- // x & (-1 >> y) u>= x -> x u<= (-1 >> y)
- // x u<= x & (-1 >> y) -> x u<= (-1 >> y)
+ // x & Mask u>= x
+ // -> x u<= Mask
+ // x & ~Mask u>= ~Mask
+ // -> ~Mask u<= x
DstPred = ICmpInst::Predicate::ICMP_ULE;
break;
case ICmpInst::Predicate::ICMP_SLT:
- // x & (-1 >> y) s< x -> x s> (-1 >> y)
- // x s> x & (-1 >> y) -> x s> (-1 >> y)
- if (!match(M, m_Constant())) // Can not do this fold with non-constant.
- return nullptr;
- if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
- return nullptr;
+ // x & Mask s< x [iff Mask s>= 0]
+ // -> x s> Mask
+ // x & ~Mask s< ~Mask [iff ~Mask != 0]
+ // -> ~Mask s> x
DstPred = ICmpInst::Predicate::ICMP_SGT;
break;
case ICmpInst::Predicate::ICMP_SGE:
- // x & (-1 >> y) s>= x -> x s<= (-1 >> y)
- // x s<= x & (-1 >> y) -> x s<= (-1 >> y)
- if (!match(M, m_Constant())) // Can not do this fold with non-constant.
- return nullptr;
- if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
- return nullptr;
+ // x & Mask s>= x [iff Mask s>= 0]
+ // -> x s<= Mask
+ // x & ~Mask s>= ~Mask [iff ~Mask != 0]
+ // -> ~Mask s<= x
DstPred = ICmpInst::Predicate::ICMP_SLE;
break;
- case ICmpInst::Predicate::ICMP_SGT:
- case ICmpInst::Predicate::ICMP_SLE:
- return nullptr;
- case ICmpInst::Predicate::ICMP_UGT:
- case ICmpInst::Predicate::ICMP_ULE:
- llvm_unreachable("Instsimplify took care of commut. variant");
- break;
default:
- llvm_unreachable("All possible folds are handled.");
+ // We don't support sgt,sle
+ // ult/ugt are simplified to true/false respectively.
+ return nullptr;
}
- // The mask value may be a vector constant that has undefined elements. But it
- // may not be safe to propagate those undefs into the new compare, so replace
- // those elements by copying an existing, defined, and safe scalar constant.
+ Value *X, *M;
+ // Put search code in lambda for early positive returns.
+ auto IsLowBitMask = [&]() {
+ if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M)))) {
+ X = Op1;
+ // Look for: x & Mask pred x
+ if (isMaskOrZero(M, /*Not=*/false, Q)) {
+ return !ICmpInst::isSigned(Pred) ||
+ (match(M, m_NonNegative()) || isKnownNonNegative(M, Q));
+ }
+
+ // Look for: x & ~Mask pred ~Mask
+ if (isMaskOrZero(X, /*Not=*/true, Q)) {
+ return !ICmpInst::isSigned(Pred) ||
+ isKnownNonZero(X, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+ }
+ return false;
+ }
+ if (ICmpInst::isEquality(Pred) && match(Op1, m_AllOnes()) &&
+ match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(M))))) {
+
+ auto Check = [&]() {
+ // Look for: ~x | Mask == -1
+ if (isMaskOrZero(M, /*Not=*/false, Q)) {
+ if (Value *NotX =
+ IC.getFreelyInverted(X, X->hasOneUse(), &IC.Builder)) {
+ X = NotX;
+ return true;
+ }
+ }
+ return false;
+ };
+ if (Check())
+ return true;
+ std::swap(X, M);
+ return Check();
+ }
+ if (ICmpInst::isEquality(Pred) && match(Op1, m_Zero()) &&
+ match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) {
+ auto Check = [&]() {
+ // Look for: x & ~Mask == 0
+ if (isMaskOrZero(M, /*Not=*/true, Q)) {
+ if (Value *NotM =
+ IC.getFreelyInverted(M, M->hasOneUse(), &IC.Builder)) {
+ M = NotM;
+ return true;
+ }
+ }
+ return false;
+ };
+ if (Check())
+ return true;
+ std::swap(X, M);
+ return Check();
+ }
+ return false;
+ };
+
+ if (!IsLowBitMask())
+ return nullptr;
+
+ // The mask value may be a vector constant that has undefined elements. But
+ // it may not be safe to propagate those undefs into the new compare, so
+ // replace those elements by copying an existing, defined, and safe scalar
+ // constant.
Type *OpTy = M->getType();
auto *VecC = dyn_cast<Constant>(M);
auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
@@ -4280,8 +4326,6 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant);
}
- if (NeedsNot)
- M = IC.Builder.CreateNot(M);
return IC.Builder.CreateICmp(DstPred, X, M);
}
diff --git a/llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll b/llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll
index c74b84e9f4518f..5de3e89d7027ab 100644
--- a/llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-and-lowbit-mask.ll
@@ -680,8 +680,7 @@ define i1 @src_x_and_mask_slt(i8 %x, i8 %y, i1 %cond) {
; CHECK-NEXT: [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
; CHECK-NEXT: [[MASK_POS:%.*]] = icmp sgt i8 [[MASK]], -1
; CHECK-NEXT: call void @llvm.assume(i1 [[MASK_POS]])
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[MASK]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[AND]], [[X]]
+; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[MASK]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%mask0 = lshr i8 -1, %y
@@ -699,8 +698,7 @@ define i1 @src_x_and_mask_sge(i8 %x, i8 %y, i1 %cond) {
; CHECK-NEXT: [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
; CHECK-NEXT: [[MASK_POS:%.*]] = icmp sgt i8 [[MASK]], -1
; CHECK-NEXT: call void @llvm.assume(i1 [[MASK_POS]])
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[MASK]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = icmp sge i8 [[AND]], [[X]]
+; CHECK-NEXT: [[R:%.*]] = icmp sge i8 [[MASK]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%mask0 = lshr i8 -1, %y
@@ -745,9 +743,9 @@ define i1 @src_x_and_mask_sge_fail_maybe_neg(i8 %x, i8 %y, i1 %cond) {
define i1 @src_x_and_nmask_eq(i8 %x, i8 %y, i1 %cond) {
; CHECK-LABEL: @src_x_and_nmask_eq(
; CHECK-NEXT: [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
-; CHECK-NEXT: [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[NOT_MASK]], [[AND]]
+; CHECK-NEXT: [[R1:%.*]] = icmp ule i8 [[NOT_MASK0]], [[X:%.*]]
+; CHECK-NEXT: [[NOT_COND:%.*]] = xor i1 [[COND:%.*]], true
+; CHECK-NEXT: [[R:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[R1]]
; CHECK-NEXT: ret i1 [[R]]
;
%not_mask0 = shl i8 -1, %y
@@ -760,9 +758,8 @@ define i1 @src_x_and_nmask_eq(i8 %x, i8 %y, i1 %cond) {
define i1 @src_x_and_nmask_ne(i8 %x, i8 %y, i1 %cond) {
; CHECK-LABEL: @src_x_and_nmask_ne(
; CHECK-NEXT: [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
-; CHECK-NEXT: [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[AND]], [[NOT_MASK]]
+; CHECK-NEXT: [[R1:%.*]] = icmp ugt i8 [[NOT_MASK0]], [[X:%.*]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[COND:%.*]], i1 [[R1]], i1 false
; CHECK-NEXT: ret i1 [[R]]
;
%not_mask0 = shl i8 -1, %y
@@ -775,9 +772,8 @@ define i1 @src_x_and_nmask_ne(i8 %x, i8 %y, i1 %cond) {
define i1 @src_x_and_nmask_ult(i8 %x, i8 %y, i1 %cond) {
; CHECK-LABEL: @src_x_and_nmask_ult(
; CHECK-NEXT: [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
-; CHECK-NEXT: [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[AND]], [[NOT_MASK]]
+; CHECK-NEXT: [[R1:%.*]] = icmp ugt i8 [[NOT_MASK0]], [[X:%.*]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[COND:%.*]], i1 [[R1]], i1 false
; CHECK-NEXT: ret i1 [[R]]
;
%not_mask0 = shl i8 -1, %y
@@ -790,9 +786,9 @@ define i1 @src_x_and_nmask_ult(i8 %x, i8 %y, i1 %cond) {
define i1 @src_x_and_nmask_uge(i8 %x, i8 %y, i1 %cond) {
; CHECK-LABEL: @src_x_and_nmask_uge(
; CHECK-NEXT: [[NOT_MASK0:%.*]] = shl nsw i8 -1, [[Y:%.*]]
-; CHECK-NEXT: [[NOT_MASK:%.*]] = select i1 [[COND:%.*]], i8 [[NOT_MASK0]], i8 0
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], [[NOT_MASK]]
+; CHECK-NEXT: [[R1:%.*]] = icmp ule i8 [[NOT_MASK0]], [[X:%.*]]
+; CHECK-NEXT: [[NOT_COND:%.*]] = xor i1 [[COND:%.*]], true
+; CHECK-NEXT: [[R:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[R1]]
; CHECK-NEXT: ret i1 [[R]]
;
%not_mask0 = shl i8 -1, %y
@@ -805,8 +801,7 @@ define i1 @src_x_and_nmask_uge(i8 %x, i8 %y, i1 %cond) {
define i1 @src_x_and_nmask_slt(i8 %x, i8 %y) {
; CHECK-LABEL: @src_x_and_nmask_slt(
; CHECK-NEXT: [[NOT_MASK:%.*]] = shl nsw i8 -1, [[Y:%.*]]
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = icmp slt i8 [[AND]], [[NOT_MASK]]
+; CHECK-NEXT: [[R:%.*]] = icmp sgt i8 [[NOT_MASK]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%not_mask = shl i8 -1, %y
@@ -818,8 +813,7 @@ define i1 @src_x_and_nmask_slt(i8 %x, i8 %y) {
define i1 @src_x_and_nmask_sge(i8 %x, i8 %y) {
; CHECK-LABEL: @src_x_and_nmask_sge(
; CHECK-NEXT: [[NOT_MASK:%.*]] = shl nsw i8 -1, [[Y:%.*]]
-; CHECK-NEXT: [[AND:%.*]] = and i8 [[NOT_MASK]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = icmp sge i8 [[AND]], [[NOT_MASK]]
+; CHECK-NEXT: [[R:%.*]] = icmp sle i8 [[NOT_MASK]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%not_mask = shl i8 -1, %y
@@ -865,9 +859,8 @@ define i1 @src_x_or_mask_eq(i8 %x, i8 %y, i8 %z, i1 %c2, i1 %cond) {
; CHECK-NEXT: [[TMP1:%.*]] = xor i8 [[X:%.*]], -124
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[C2:%.*]], i8 [[TMP1]], i8 -46
; CHECK-NEXT: [[TMP3:%.*]] = call i8 @llvm.umax.i8(i8 [[Z:%.*]], i8 [[TMP2]])
-; CHECK-NEXT: [[NX_CCC:%.*]] = sub i8 11, [[TMP3]]
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[NX_CCC]], [[MASK]]
-; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[OR]], -1
+; CHECK-NEXT: [[TMP4:%.*]] = add i8 [[TMP3]], -12
+; CHECK-NEXT: [[R:%.*]] = icmp ule i8 [[TMP4]], [[MASK]]
; CHECK-NEXT: ret i1 [[R]]
;
%mask0 = lshr i8 -1, %y
@@ -886,9 +879,7 @@ define i1 @src_x_or_mask_ne(i8 %x, i8 %y, i1 %cond) {
; CHECK-LABEL: @src_x_or_mask_ne(
; CHECK-NEXT: [[MASK0:%.*]] = lshr i8 -1, [[Y:%.*]]
; CHECK-NEXT: [[MASK:%.*]] = select i1 [[COND:%.*]], i8 [[MASK0]], i8 0
-; CHECK-NEXT: [[NX:%.*]] = xor i8 [[X:%.*]], -1
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[MASK]], [[NX]]
-; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[OR]], -1
+; CHECK-NEXT: [[R:%.*]] = icmp ult i8 [[MASK]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[R]]
;
%mask0 = lshr i8 -1, %y