[llvm] 06499f3 - [InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC) (#122179)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 15 09:08:57 PST 2025
Author: Andreas Jonson
Date: 2025-01-15T18:08:53+01:00
New Revision: 06499f3672afc371b653bf54422c2e80e1e27c90
URL: https://github.com/llvm/llvm-project/commit/06499f3672afc371b653bf54422c2e80e1e27c90
DIFF: https://github.com/llvm/llvm-project/commit/06499f3672afc371b653bf54422c2e80e1e27c90.diff
LOG: [InstCombine] Prepare foldLogOpOfMaskedICmps to handle trunc to i1. (NFC) (#122179)
Added:
Modified:
llvm/include/llvm/Analysis/CmpInstAnalysis.h
llvm/lib/Analysis/CmpInstAnalysis.cpp
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index c7862a6d39d07e..aeda58ac7535d5 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -108,6 +108,12 @@ namespace llvm {
bool LookThroughTrunc = true,
bool AllowNonZeroC = false);
+ /// Decompose an icmp into the form ((X & Mask) pred C) if
+ /// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+ std::optional<DecomposedBitTest>
+ decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
+ bool AllowNonZeroC = false);
+
} // end namespace llvm
#endif
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 2580ea7e972488..3599428c5ff416 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -165,3 +165,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
return Result;
}
+
+std::optional<DecomposedBitTest>
+llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+ if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
+ // Don't allow pointers. Splat vectors are fine.
+ if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
+ return std::nullopt;
+ return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
+ ICmp->getPredicate(), LookThruTrunc,
+ AllowNonZeroC);
+ }
+
+ return std::nullopt;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index f82a557e5760c8..f7d17b1aa3865c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -179,10 +179,10 @@ static unsigned conjugateICmpMask(unsigned Mask) {
}
// Adapts the external decomposeBitTestICmp for local use.
-static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
+static bool decomposeBitTestICmp(Value *Cond, CmpInst::Predicate &Pred,
Value *&X, Value *&Y, Value *&Z) {
- auto Res = llvm::decomposeBitTestICmp(
- LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/true);
+ auto Res = llvm::decomposeBitTest(Cond, /*LookThroughTrunc=*/true,
+ /*AllowNonZeroC=*/true);
if (!Res)
return false;
@@ -198,13 +198,10 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
/// the right hand side as a pair.
/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
/// and PredR are their predicates, respectively.
-static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
- Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
- ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
- // Don't allow pointers. Splat vectors are fine.
- if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
- !RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
- return std::nullopt;
+static std::optional<std::pair<unsigned, unsigned>>
+getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
+ Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
+ ICmpInst::Predicate &PredR) {
// Here comes the tricky part:
// LHS might be of the form L11 & L12 == X, X == L21 & L22,
@@ -212,13 +209,23 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
// Now we must find those components L** and R**, that are equal, so
// that we can extract the parameters A, B, C, D, and E for the canonical
// above.
- Value *L1 = LHS->getOperand(0);
- Value *L2 = LHS->getOperand(1);
- Value *L11, *L12, *L21, *L22;
+
// Check whether the icmp can be decomposed into a bit test.
- if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) {
+ Value *L1, *L11, *L12, *L2, *L21, *L22;
+ if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) {
L21 = L22 = L1 = nullptr;
} else {
+ auto *LHSCMP = dyn_cast<ICmpInst>(LHS);
+ if (!LHSCMP)
+ return std::nullopt;
+
+ // Don't allow pointers. Splat vectors are fine.
+ if (!LHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
+ return std::nullopt;
+
+ PredL = LHSCMP->getPredicate();
+ L1 = LHSCMP->getOperand(0);
+ L2 = LHSCMP->getOperand(1);
// Look for ANDs in the LHS icmp.
if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) {
// Any icmp can be viewed as being trivially masked; if it allows us to
@@ -237,11 +244,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
if (!ICmpInst::isEquality(PredL))
return std::nullopt;
- Value *R1 = RHS->getOperand(0);
- Value *R2 = RHS->getOperand(1);
- Value *R11, *R12;
- bool Ok = false;
- if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) {
+ Value *R11, *R12, *R2;
+ if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) {
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
A = R11;
D = R12;
@@ -252,9 +256,19 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
return std::nullopt;
}
E = R2;
- R1 = nullptr;
- Ok = true;
} else {
+ auto *RHSCMP = dyn_cast<ICmpInst>(RHS);
+ if (!RHSCMP)
+ return std::nullopt;
+ // Don't allow pointers. Splat vectors are fine.
+ if (!RHSCMP->getOperand(0)->getType()->isIntOrIntVectorTy())
+ return std::nullopt;
+
+ PredR = RHSCMP->getPredicate();
+
+ Value *R1 = RHSCMP->getOperand(0);
+ R2 = RHSCMP->getOperand(1);
+ bool Ok = false;
if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
// As before, model no mask as a trivial mask if it'll let us do an
// optimization.
@@ -277,36 +291,32 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
// Avoid matching against the -1 value we created for unmasked operand.
if (Ok && match(A, m_AllOnes()))
Ok = false;
+
+ // Look for ANDs on the right side of the RHS icmp.
+ if (!Ok) {
+ if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
+ R11 = R2;
+ R12 = Constant::getAllOnesValue(R2->getType());
+ }
+
+ if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
+ A = R11;
+ D = R12;
+ E = R1;
+ } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
+ A = R12;
+ D = R11;
+ E = R1;
+ } else {
+ return std::nullopt;
+ }
+ }
}
// Bail if RHS was a icmp that can't be decomposed into an equality.
if (!ICmpInst::isEquality(PredR))
return std::nullopt;
- // Look for ANDs on the right side of the RHS icmp.
- if (!Ok) {
- if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
- R11 = R2;
- R12 = Constant::getAllOnesValue(R2->getType());
- }
-
- if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11;
- D = R12;
- E = R1;
- Ok = true;
- } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12;
- D = R11;
- E = R1;
- Ok = true;
- } else {
- return std::nullopt;
- }
-
- assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
- }
-
if (L11 == A) {
B = L12;
C = L2;
@@ -333,8 +343,8 @@ static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
/// Also used for logical and/or, must be poison safe.
static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *D,
- Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *D, Value *E,
+ ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
InstCombiner::BuilderTy &Builder) {
// We are given the canonical form:
// (icmp ne (A & B), 0) & (icmp eq (A & D), E).
@@ -457,7 +467,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
// (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
if (IsSuperSetOrEqual(BCst, DCst)) {
// We can't guarantee that samesign hold after this fold.
- RHS->setSameSign(false);
+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+ ICmp->setSameSign(false);
return RHS;
}
// Otherwise, B is a subset of D. If B and E have a common bit set,
@@ -466,7 +477,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
if ((*BCst & ECst) != 0) {
// We can't guarantee that samesign hold after this fold.
- RHS->setSameSign(false);
+ if (auto *ICmp = dyn_cast<ICmpInst>(RHS))
+ ICmp->setSameSign(false);
return RHS;
}
// Otherwise, LHS and RHS contradict and the whole expression becomes false
@@ -481,8 +493,8 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
/// aren't of the common mask pattern type.
/// Also used for logical and/or, must be poison safe.
static Value *foldLogOpOfMaskedICmpsAsymmetric(
- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
- Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+ Value *LHS, Value *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D,
+ Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
"Expected equality predicates for masked type of icmps.");
@@ -511,12 +523,12 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
/// into a single (icmp(A & X) ==/!= Y).
-static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+static Value *foldLogOpOfMaskedICmps(Value *LHS, Value *RHS, bool IsAnd,
bool IsLogical,
InstCombiner::BuilderTy &Builder,
const SimplifyQuery &Q) {
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
- ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+ ICmpInst::Predicate PredL, PredR;
std::optional<std::pair<unsigned, unsigned>> MaskPair =
getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
if (!MaskPair)
@@ -1066,8 +1078,7 @@ static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1,
if (!JoinedByAnd)
return nullptr;
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
- ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(),
- CmpPred1 = Cmp1->getPredicate();
+ ICmpInst::Predicate CmpPred0, CmpPred1;
// Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u<
// 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X &
// SignMask) == 0).
More information about the llvm-commits mailing list