[llvm] [InstCombine] Move foldLogOpOfMaskedICmps to make it possible to handle trunc to i1. (PR #122179)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 8 14:09:00 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Andreas Jonson (andjo403)
<details>
<summary>Changes</summary>
This PR makes the move of foldLogOpOfMaskedICmps as a separate step, before adding the handling of `trunc to i1` in getMaskedTypeForICmpPair, because the move by itself caused some diffs in llvm-opt-benchmark.
---
Full diff: https://github.com/llvm/llvm-project/pull/122179.diff
1 File Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+126-106)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8bfa3d0f6c5ea1..0aeb025ea44840 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -199,113 +199,132 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
/// the right hand side as a pair.
/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
/// and PredR are their predicates, respectively.
-static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
- Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
- ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
- // Don't allow pointers. Splat vectors are fine.
- if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
- !RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
- return std::nullopt;
+static std::optional<std::pair<unsigned, unsigned>>
+getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
+ Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
+ ICmpInst::Predicate &PredR) {
- // Here comes the tricky part:
- // LHS might be of the form L11 & L12 == X, X == L21 & L22,
- // and L11 & L12 == L21 & L22. The same goes for RHS.
- // Now we must find those components L** and R**, that are equal, so
- // that we can extract the parameters A, B, C, D, and E for the canonical
- // above.
- Value *L1 = LHS->getOperand(0);
- Value *L2 = LHS->getOperand(1);
- Value *L11, *L12, *L21, *L22;
- // Check whether the icmp can be decomposed into a bit test.
- if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
- L21 = L22 = L1 = nullptr;
- } else {
- // Look for ANDs in the LHS icmp.
- if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
- // Any icmp can be viewed as being trivially masked; if it allows us to
- // remove one, it's worth it.
- L11 = L1;
- L12 = Constant::getAllOnesValue(L1->getType());
- }
+ Value *L1, *L11, *L12, *L2, *L21, *L22;
+ if (auto *LHSCMP = dyn_cast<ICmpInst>(LHS)) {
+
+ // Don't allow pointers. Splat vectors are fine.
+ if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
+ return std::nullopt;
+
+ PredL = LHSCMP->getPredicate();
+
+ // Here comes the tricky part:
+ // LHS might be of the form L11 & L12 == X, X == L21 & L22,
+ // and L11 & L12 == L21 & L22. The same goes for RHS.
+ // Now we must find those components L** and R**, that are equal, so
+ // that we can extract the parameters A, B, C, D, and E for the canonical
+ // above.
+ L1 = LHSCMP->getOperand(0);
+ L2 = LHSCMP->getOperand(1);
+ // Check whether the icmp can be decomposed into a bit test.
+ if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
+ L21 = L22 = L1 = nullptr;
+ } else {
+ // Look for ANDs in the LHS icmp.
+ if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
+ // Any icmp can be viewed as being trivially masked; if it allows us to
+ // remove one, it's worth it.
+ L11 = L1;
+ L12 = Constant::getAllOnesValue(L1->getType());
+ }
- if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
- L21 = L2;
- L22 = Constant::getAllOnesValue(L2->getType());
+ if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) {
+ L21 = L2;
+ L22 = Constant::getAllOnesValue(L2->getType());
+ }
}
- }
+ // Bail if LHS was a icmp that can't be decomposed into an equality.
+ if (!ICmpInst::isEquality(PredL))
+ return std::nullopt;
- // Bail if LHS was a icmp that can't be decomposed into an equality.
- if (!ICmpInst::isEquality(PredL))
+ } else {
return std::nullopt;
+ }
- Value *R1 = RHS->getOperand(0);
- Value *R2 = RHS->getOperand(1);
Value *R11, *R12;
- bool Ok = false;
- if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11;
- D = R12;
- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12;
- D = R11;
- } else {
+ if (auto *RHSCMP = dyn_cast<ICmpInst>(RHS)) {
+
+ // Don't allow pointers. Splat vectors are fine.
+ if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
return std::nullopt;
- }
- E = R2;
- R1 = nullptr;
- Ok = true;
- } else {
- if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
- // As before, model no mask as a trivial mask if it'll let us do an
- // optimization.
- R11 = R1;
- R12 = Constant::getAllOnesValue(R1->getType());
- }
- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11;
- D = R12;
- E = R2;
- Ok = true;
- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12;
- D = R11;
+ PredR = RHSCMP->getPredicate();
+
+ Value *R1 = RHSCMP->getOperand(0);
+ Value *R2 = RHSCMP->getOperand(1);
+ bool Ok = false;
+ if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+ A = R11;
+ D = R12;
+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+ A = R12;
+ D = R11;
+ } else {
+ return std::nullopt;
+ }
E = R2;
+ R1 = nullptr;
Ok = true;
- }
-
- // Avoid matching against the -1 value we created for unmasked operand.
- if (Ok && match(A, m_AllOnes()))
- Ok = false;
- }
+ } else {
+ if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
+ // As before, model no mask as a trivial mask if it'll let us do an
+ // optimization.
+ R11 = R1;
+ R12 = Constant::getAllOnesValue(R1->getType());
+ }
- // Bail if RHS was a icmp that can't be decomposed into an equality.
- if (!ICmpInst::isEquality(PredR))
- return std::nullopt;
+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+ A = R11;
+ D = R12;
+ E = R2;
+ Ok = true;
+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+ A = R12;
+ D = R11;
+ E = R2;
+ Ok = true;
+ }
- // Look for ANDs on the right side of the RHS icmp.
- if (!Ok) {
- if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
- R11 = R2;
- R12 = Constant::getAllOnesValue(R2->getType());
+ // Avoid matching against the -1 value we created for unmasked operand.
+ if (Ok && match(A, m_AllOnes()))
+ Ok = false;
}
- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11;
- D = R12;
- E = R1;
- Ok = true;
- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12;
- D = R11;
- E = R1;
- Ok = true;
- } else {
+ // Bail if RHS was a icmp that can't be decomposed into an equality.
+ if (!ICmpInst::isEquality(PredR))
return std::nullopt;
- }
- assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+ // Look for ANDs on the right side of the RHS icmp.
+ if (!Ok) {
+ if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
+ R11 = R2;
+ R12 = Constant::getAllOnesValue(R2->getType());
+ }
+
+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+ A = R11;
+ D = R12;
+ E = R1;
+ Ok = true;
+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+ A = R12;
+ D = R11;
+ E = R1;
+ Ok = true;
+ } else {
+ return std::nullopt;
+ }
+
+ assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
+ }
+ } else {
+ return std::nullopt;
}
if (L11 == A) {
@@ -334,8 +353,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
/// Also used for logical and/or, must be poison safe.
static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
- Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
+ ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
InstCombiner::BuilderTy &Builder) {
// We are given the canonical form:
// (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -458,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
// (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
if (IsSuperSetOrEqual(BCst, DCst)) {
// We can't guarantee that samesign hold after this fold.
- RHS->setSameSign(false);
+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+ ICmp->setSameSign(false);
return RHS;
}
// Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -467,7 +487,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
if ((*BCst & ECst) != 0) {
// We can't guarantee that samesign hold after this fold.
- RHS->setSameSign(false);
+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+ ICmp->setSameSign(false);
return RHS;
}
// Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -482,8 +503,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
/// aren't of the common mask pattern type.
/// Also used for logical and/or, must be poison safe.
static Value *foldLogOpOfMaskedICmpsAsymmetric(
- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
- Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
+ Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
"Expected equality predicates for masked type of icmps.");
@@ -512,12 +533,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
/// into a single (icmp(A & X) ==/!= Y).
-static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
bool IsLogical,
InstCombiner::BuilderTy &Builder,
const SimplifyQuery &Q) {
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
- ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+ ICmpInst::Predicate PredL, PredR;
std::optional<std::pair<unsigned, unsigned>> MaskPair =
getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
if (!MaskPair)
@@ -1067,8 +1088,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
if (!JoinedByAnd)
return nullptr;
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
- ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
- CmpPred1 = Cmp1->getPredicate();
+ ICmpInst::Predicate CmpPred0, CmpPred1;
// Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
// 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
// SignMask) == 0).
@@ -3325,12 +3345,6 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
}
}
- // handle (roughly):
- // (icmp ne (A & B), C) | (icmp ne (A & D), E)
- // (icmp eq (A & B), C) & (icmp eq (A & D), E)
- if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder, Q))
- return V;
-
if (Value *V =
foldAndOrOfICmpEqConstantAndICmp(LHS, RHS, IsAnd, IsLogical, Builder))
return V;
@@ -3510,6 +3524,12 @@ Value *InstCombinerImpl::foldBooleanAndOr(Value *LHS, Value *RHS,
if (Value *Res = foldAndOrOfICmps(LHSCmp, RHSCmp, I, IsAnd, IsLogical))
return Res;
+ /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
+ /// into a single (icmp(A & X) ==/!= Y).
+ if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder,
+ SQ.getWithInstruction(&I)))
+ return V;
+
if (auto *LHSCmp = dyn_cast<FCmpInst>(LHS))
if (auto *RHSCmp = dyn_cast<FCmpInst>(RHS))
if (Value *Res = foldLogicOfFCmps(LHSCmp, RHSCmp, IsAnd, IsLogical))
``````````
</details>
https://github.com/llvm/llvm-project/pull/122179
More information about the llvm-commits mailing list