[llvm] [InstCombine] Reuse common code between foldSelectICmpAndBinOp and foldSelectICmpAnd. (PR #131902)
Andreas Jonson via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 18 12:48:21 PDT 2025
https://github.com/andjo403 created https://github.com/llvm/llvm-project/pull/131902
This is the commit that was dropped from https://github.com/llvm/llvm-project/pull/127905 because it conflicted with https://github.com/llvm/llvm-project/pull/128741.
Because the matching code is now shared, foldSelectICmpAndBinOp also uses known bits in the same way that https://github.com/llvm/llvm-project/pull/128741 added for foldSelectICmpAnd.
Proof for the use of known bits in foldSelectICmpAndBinOp: https://alive2.llvm.org/ce/z/RYXr_k
>From 6904d6913af5e15116b150f7de1bf0cf65032e33 Mon Sep 17 00:00:00 2001
From: Andreas Jonson <andjo403 at hotmail.com>
Date: Thu, 20 Feb 2025 21:14:07 +0100
Subject: [PATCH] [InstCombine] Reuse common matches between
foldSelectICmpAndBinOp and foldSelectICmpAnd.
---
.../InstCombine/InstCombineSelect.cpp | 184 ++++++++----------
.../InstCombine/select-with-bitwise-ops.ll | 6 +-
2 files changed, 84 insertions(+), 106 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c3163f70b847e..4bba2f406b4c1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -119,63 +119,15 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC
/// With some variations depending if FC is larger than TC, or the shift
/// isn't needed, or the bit widths don't match.
-static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal,
- InstCombiner::BuilderTy &Builder,
- const SimplifyQuery &SQ) {
+static Value *foldSelectICmpAnd(SelectInst &Sel, Value *CondVal, Value *TrueVal,
+ Value *FalseVal, Value *V, const APInt &AndMask,
+ bool CreateAnd,
+ InstCombiner::BuilderTy &Builder) {
const APInt *SelTC, *SelFC;
- if (!match(Sel.getTrueValue(), m_APInt(SelTC)) ||
- !match(Sel.getFalseValue(), m_APInt(SelFC)))
+ if (!match(TrueVal, m_APInt(SelTC)) || !match(FalseVal, m_APInt(SelFC)))
return nullptr;
- // If this is a vector select, we need a vector compare.
Type *SelType = Sel.getType();
- if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
- return nullptr;
-
- Value *V;
- APInt AndMask;
- bool CreateAnd = false;
- CmpPredicate Pred;
- Value *CmpLHS, *CmpRHS;
-
- if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
- if (ICmpInst::isEquality(Pred)) {
- if (!match(CmpRHS, m_Zero()))
- return nullptr;
-
- V = CmpLHS;
- const APInt *AndRHS;
- if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
- return nullptr;
-
- AndMask = *AndRHS;
- } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
- assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
- AndMask = Res->Mask;
- V = Res->X;
- KnownBits Known =
- computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
- AndMask &= Known.getMaxValue();
- if (!AndMask.isPowerOf2())
- return nullptr;
-
- Pred = Res->Pred;
- CreateAnd = true;
- } else {
- return nullptr;
- }
-
- } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
- V = Trunc->getOperand(0);
- AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
- Pred = ICmpInst::ICMP_NE;
- CreateAnd = !Trunc->hasNoUnsignedWrap();
- } else {
- return nullptr;
- }
- if (Pred == ICmpInst::ICMP_NE)
- std::swap(SelTC, SelFC);
-
// In general, when both constants are non-zero, we would need an offset to
// replace the select. This would require more instructions than we started
// with. But there's one special-case that we handle here because it can
@@ -762,60 +714,26 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
/// 2. The select operands are reversed
/// 3. The magnitude of C2 and C1 are flipped
static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
- Value *FalseVal,
+ Value *FalseVal, Value *V,
+ const APInt &AndMask, bool CreateAnd,
InstCombiner::BuilderTy &Builder) {
- // Only handle integer compares. Also, if this is a vector select, we need a
- // vector compare.
- if (!TrueVal->getType()->isIntOrIntVectorTy() ||
- TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
- return nullptr;
-
- unsigned C1Log;
- bool NeedAnd = false;
- CmpPredicate Pred;
- Value *CmpLHS, *CmpRHS;
-
- if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
- if (ICmpInst::isEquality(Pred)) {
- if (!match(CmpRHS, m_Zero()))
- return nullptr;
-
- const APInt *C1;
- if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
- return nullptr;
-
- C1Log = C1->logBase2();
- } else {
- auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
- if (!Res || !Res->Mask.isPowerOf2())
- return nullptr;
-
- CmpLHS = Res->X;
- Pred = Res->Pred;
- C1Log = Res->Mask.logBase2();
- NeedAnd = true;
- }
- } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
- CmpLHS = Trunc->getOperand(0);
- C1Log = 0;
- Pred = ICmpInst::ICMP_NE;
- NeedAnd = !Trunc->hasNoUnsignedWrap();
- } else {
+ // Only handle integer compares.
+ if (!TrueVal->getType()->isIntOrIntVectorTy())
return nullptr;
- }
- Value *Y, *V = CmpLHS;
+ unsigned C1Log = AndMask.logBase2();
+ Value *Y;
BinaryOperator *BinOp;
const APInt *C2;
bool NeedXor;
if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) {
Y = TrueVal;
BinOp = cast<BinaryOperator>(FalseVal);
- NeedXor = Pred == ICmpInst::ICMP_NE;
+ NeedXor = false;
} else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) {
Y = FalseVal;
BinOp = cast<BinaryOperator>(TrueVal);
- NeedXor = Pred == ICmpInst::ICMP_EQ;
+ NeedXor = true;
} else {
return nullptr;
}
@@ -834,14 +752,13 @@ static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
V->getType()->getScalarSizeInBits();
// Make sure we don't create more instructions than we save.
- if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
+ if ((NeedShift + NeedXor + NeedZExtTrunc + CreateAnd) >
(CondVal->hasOneUse() + BinOp->hasOneUse()))
return nullptr;
- if (NeedAnd) {
+ if (CreateAnd) {
// Insert the AND instruction on the input to the truncate.
- APInt C1 = APInt::getOneBitSet(V->getType()->getScalarSizeInBits(), C1Log);
- V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), C1));
+ V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask));
}
if (C2Log > C1Log) {
@@ -3797,6 +3714,70 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
return nullptr;
}
+/// Recognize a select whose condition tests a single bit of some value, then
+/// try the bit-test select folds on it. Accepted conditions:
+///  * icmp eq/ne (and X, Pow2), 0
+///  * any other icmp that decomposeBitTestICmp turns into a masked equality
+///    test, provided the mask is a single bit once bits known to be zero in
+///    the tested value are dropped
+///  * a trunc used directly as the condition (treated as testing bit 0 of
+///    its operand, with predicate ne)
+/// For `ne` tests TrueVal/FalseVal are swapped, so that TrueVal is always the
+/// arm chosen when the tested bit is clear. Then foldSelectICmpAnd and
+/// foldSelectICmpAndBinOp are tried, in that order.
+/// \p Sel is the select being visited and is the context for known-bits
+/// queries. \p CondVal, \p TrueVal and \p FalseVal are its condition and arms.
+/// Returns the replacement value, or nullptr if no fold applies.
+static Value *foldSelectBitTest(SelectInst &Sel, Value *CondVal, Value *TrueVal,
+ Value *FalseVal,
+ InstCombiner::BuilderTy &Builder,
+ const SimplifyQuery &SQ) {
+ // If this is a vector select, we need a vector compare.
+ Type *SelType = Sel.getType();
+ if (SelType->isVectorTy() != CondVal->getType()->isVectorTy())
+ return nullptr;
+
+ // Facts extracted from the condition:
+ //  * V: the value that holds the tested bit.
+ //  * AndMask: the single-bit mask.
+ //  * CreateAnd: whether the fold must materialize `V & AndMask` itself.
+ Value *V;
+ APInt AndMask;
+ bool CreateAnd = false;
+ CmpPredicate Pred;
+ Value *CmpLHS, *CmpRHS;
+
+ if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
+ if (ICmpInst::isEquality(Pred)) {
+ if (!match(CmpRHS, m_Zero()))
+ return nullptr;
+
+ // V is the existing `and` instruction itself, so CreateAnd stays false.
+ V = CmpLHS;
+ const APInt *AndRHS;
+ if (!match(CmpLHS, m_And(m_Value(), m_Power2(AndRHS))))
+ return nullptr;
+
+ AndMask = *AndRHS;
+ } else if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) {
+ // NOTE(review): Res->C is not consulted. This code assumes the decomposed
+ // test compares (X & Mask) against zero. Confirm this against
+ // decomposeBitTestICmp's defaults.
+ assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
+ AndMask = Res->Mask;
+ V = Res->X;
+ // Mask bits that are known to be zero in V cannot change the outcome of
+ // the test, so drop them. The folds below need a single-bit mask.
+ KnownBits Known =
+ computeKnownBits(V, /*Depth=*/0, SQ.getWithInstruction(&Sel));
+ AndMask &= Known.getMaxValue();
+ if (!AndMask.isPowerOf2())
+ return nullptr;
+
+ Pred = Res->Pred;
+ // V is the unmasked operand, so the fold must create the `and`.
+ CreateAnd = true;
+ } else {
+ return nullptr;
+ }
+ } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
+ // A truncated condition is treated as a test of bit 0 of its operand.
+ // With `nuw`, the truncated-away bits must be zero (LangRef: otherwise the
+ // result is poison), so no explicit mask is needed.
+ V = Trunc->getOperand(0);
+ AndMask = APInt(V->getType()->getScalarSizeInBits(), 1);
+ Pred = ICmpInst::ICMP_NE;
+ CreateAnd = !Trunc->hasNoUnsignedWrap();
+ } else {
+ return nullptr;
+ }
+
+ // Canonicalize so that TrueVal is the arm selected when the tested bit is
+ // clear. Both callees below assume this orientation and no longer inspect
+ // the predicate.
+ if (Pred == ICmpInst::ICMP_NE)
+ std::swap(TrueVal, FalseVal);
+
+ // First try the fold where both arms are constants.
+ if (Value *X = foldSelectICmpAnd(Sel, CondVal, TrueVal, FalseVal, V, AndMask,
+ CreateAnd, Builder))
+ return X;
+
+ // Then try the fold where one arm is a power-of-2 binop of the other arm.
+ if (Value *X = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, V, AndMask,
+ CreateAnd, Builder))
+ return X;
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
@@ -3969,10 +3950,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
return Result;
- if (Value *V = foldSelectICmpAnd(SI, CondVal, Builder, SQ))
- return replaceInstUsesWith(SI, V);
-
- if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
+ if (Value *V = foldSelectBitTest(SI, CondVal, TrueVal, FalseVal, Builder, SQ))
return replaceInstUsesWith(SI, V);
if (Instruction *Add = foldAddSubSelect(SI, Builder))
diff --git a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
index a424247b676e4..771fad66e961e 100644
--- a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
+++ b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
@@ -1832,9 +1832,9 @@ define i8 @neg_select_trunc_or_2(i8 %x, i8 %y) {
define i8 @select_icmp_bittest_range(i8 range(i8 0, 64) %a, i8 %y) {
; CHECK-LABEL: @select_icmp_bittest_range(
-; CHECK-NEXT: [[CMP:%.*]] = icmp samesign ult i8 [[A:%.*]], 32
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT: [[RES:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[OR]]
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i8 [[A:%.*]], 4
+; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT: [[RES:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
; CHECK-NEXT: ret i8 [[RES]]
;
%cmp = icmp ult i8 %a, 32
More information about the llvm-commits
mailing list