[llvm] [CmpInstAnalysis] Decompose icmp eq (and x, C) C2 (PR #136367)
Jeffrey Byrnes via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 18 13:36:05 PDT 2025
https://github.com/jrbyrnes created https://github.com/llvm/llvm-project/pull/136367
This type of decomposition is used in multiple places already. Adding it to `CmpInstAnalysis` reduces code duplication.
From 740a37e616a1a4954efcfa2fa3a739e09c592238 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Fri, 18 Apr 2025 10:38:54 -0700
Subject: [PATCH] [CmpInstAnalysis] Decompose icmp eq (and x, C) C2
Change-Id: I1dd786a4652ccd2e486db6903e16e58ffa1a7959
---
llvm/include/llvm/Analysis/CmpInstAnalysis.h | 14 +++++---
llvm/lib/Analysis/CmpInstAnalysis.cpp | 34 ++++++++++++++-----
.../InstCombine/InstCombineAndOrXor.cpp | 14 +++-----
.../InstCombine/InstCombineSelect.cpp | 32 ++++++++---------
.../Transforms/Scalar/LoopIdiomRecognize.cpp | 14 +++-----
5 files changed, 59 insertions(+), 49 deletions(-)
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index aeda58ac7535d..a796a38feff88 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -95,24 +95,28 @@ namespace llvm {
/// Represents the operation icmp (X & Mask) pred C, where pred can only be
/// eq or ne.
struct DecomposedBitTest {
- Value *X;
+ Value *X = nullptr;
CmpInst::Predicate Pred;
APInt Mask;
APInt C;
};
/// Decompose an icmp into the form ((X & Mask) pred C) if possible.
- /// Unless \p AllowNonZeroC is true, C will always be 0.
+ /// Unless \p AllowNonZeroC is true, C will always be 0. If \p
+ /// DecomposeBitMask is true, equality predicates whose LHS is already
+ /// `and X, Mask` (with a constant Mask) are decomposed as well.
std::optional<DecomposedBitTest>
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
- bool LookThroughTrunc = true,
- bool AllowNonZeroC = false);
+ bool LookThroughTrunc = true, bool AllowNonZeroC = false,
+ bool DecomposeBitMask = false);
/// Decompose an icmp into the form ((X & Mask) pred C) if
/// possible. Unless \p AllowNonZeroC is true, C will always be 0.
+ /// If \p DecomposeBitMask is true, equality predicates whose LHS is already
+ /// `and X, Mask` (with a constant Mask) are decomposed as well.
std::optional<DecomposedBitTest>
decomposeBitTest(Value *Cond, bool LookThroughTrunc = true,
- bool AllowNonZeroC = false);
+ bool AllowNonZeroC = false, bool DecomposeBitMask = false);
} // end namespace llvm
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5c0d1dd1c74b0..6714088097127 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
std::optional<DecomposedBitTest>
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
- bool LookThruTrunc, bool AllowNonZeroC) {
+ bool LookThruTrunc, bool AllowNonZeroC,
+ bool DecomposeBitMask) {
using namespace PatternMatch;
const APInt *OrigC;
- if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
+ if ((ICmpInst::isEquality(Pred) && !DecomposeBitMask) ||
+ !match(RHS, m_APIntAllowPoison(OrigC)))
return std::nullopt;
bool Inverted = false;
@@ -97,9 +99,10 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
}
DecomposedBitTest Result;
+
switch (Pred) {
default:
- llvm_unreachable("Unexpected predicate");
+ return std::nullopt;
case ICmpInst::ICMP_SLT: {
// X < 0 is equivalent to (X & SignMask) != 0.
if (C.isZero()) {
@@ -128,7 +131,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
return std::nullopt;
}
- case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULT: {
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
if (C.isPowerOf2()) {
Result.Mask = -C;
@@ -147,6 +150,19 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
return std::nullopt;
}
+ case ICmpInst::ICMP_EQ:
+ case ICmpInst::ICMP_NE: {
+ assert(DecomposeBitMask);
+ const APInt *AndC;
+ Value *AndVal;
+ if (match(LHS, m_And(m_Value(AndVal), m_APIntAllowPoison(AndC)))) {
+ Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+ break;
+ }
+
+ return std::nullopt;
+ }
+ }
if (!AllowNonZeroC && !Result.C.isZero())
return std::nullopt;
@@ -159,15 +175,17 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
Result.X = X;
Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
- } else {
+ } else if (!Result.X) {
Result.X = LHS;
}
return Result;
}
-std::optional<DecomposedBitTest>
-llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
+std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
+ bool LookThruTrunc,
+ bool AllowNonZeroC,
+ bool DecomposeBitMask) {
using namespace PatternMatch;
if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) {
// Don't allow pointers. Splat vectors are fine.
@@ -175,7 +193,7 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) {
return std::nullopt;
return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
ICmp->getPredicate(), LookThruTrunc,
- AllowNonZeroC);
+ AllowNonZeroC, DecomposeBitMask);
}
Value *X;
if (Cond->getType()->isIntOrIntVectorTy(1) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 19bf81137aab7..60461a6e8b338 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
APInt &UnsetBitsMask) -> bool {
CmpPredicate Pred = ICmp->getPredicate();
// Can it be decomposed into icmp eq (X & Mask), 0 ?
- auto Res =
- llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
- Pred, /*LookThroughTrunc=*/false);
+ auto Res = llvm::decomposeBitTestICmp(
+ ICmp->getOperand(0), ICmp->getOperand(1), Pred,
+ /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/false,
+ /*DecomposeBitMask=*/true);
if (Res && Res->Pred == ICmpInst::ICMP_EQ) {
X = Res->X;
UnsetBitsMask = Res->Mask;
return true;
}
- // Is it icmp eq (X & Mask), 0 already?
- const APInt *Mask;
- if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
- Pred == ICmpInst::ICMP_EQ) {
- UnsetBitsMask = *Mask;
- return true;
- }
return false;
};
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4bba2f406b4c1..8cdb00fb44a48 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3730,31 +3730,29 @@ static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
Value *CmpLHS, *CmpRHS;
if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
- if (ICmpInst::isEquality(Pred)) {
- if (!match(CmpRHS, m_Zero()))
- return nullptr;
+ auto Res = decomposeBitTestICmp(
+ CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+ /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
- V = CmpLHS;
- const APInt *AndRHS;
- if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
- return nullptr;
+ if (!Res)
+ return nullptr;
- AndMask = *AndRHS;
- } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
- assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
- AndMask = Res->Mask;
+ V = CmpLHS;
+ AndMask = Res->Mask;
+
+ if (!ICmpInst::isEquality(Pred)) {
V = Res->X;
KnownBits Known =
computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
AndMask &= Known.getMaxValue();
- if (!AndMask.isPowerOf2())
- return nullptr;
-
- Pred = Res->Pred;
CreateAnd = true;
- } else {
- return nullptr;
}
+
+ Pred = Res->Pred;
+
+ if (!AndMask.isPowerOf2())
+ return nullptr;
+
} else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
V = Trunc->getOperand(0);
AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index eacc67cd5a475..ced1557ec3060 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)),
CurLoop))));
};
- auto MatchConstantBitMask = [&]() {
- return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) &&
- match(CmpLHS, m_And(m_Value(CurrX),
- m_CombineAnd(m_Value(BitMask), m_Power2()))) &&
- (BitPos = ConstantExpr::getExactLogBase2(cast<Constant>(BitMask)));
- };
+
auto MatchDecomposableConstantBitMask = [&]() {
- auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+ auto Res = llvm::decomposeBitTestICmp(
+ CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true,
+ /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
if (Res && Res->Mask.isPowerOf2()) {
assert(ICmpInst::isEquality(Res->Pred));
Pred = Res->Pred;
@@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
return false;
};
- if (!MatchVariableBitMask() && !MatchConstantBitMask() &&
- !MatchDecomposableConstantBitMask()) {
+ if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) {
LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n");
return false;
}
More information about the llvm-commits mailing list