[llvm] [InstCombine] Combine and->cmp->sel->or-disjoint into and->mul (PR #135274)
Jeffrey Byrnes via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 15 12:05:37 PDT 2025
https://github.com/jrbyrnes updated https://github.com/llvm/llvm-project/pull/135274
>From 79ebe2788159e73c252cdb8d04569bb332b1139f Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Wed, 9 Apr 2025 14:44:11 -0700
Subject: [PATCH 1/3] [InstCombine] Combine and->cmp->sel->or-disjoint into and->mul
Change-Id: Id45315f1e5f71077800d3a8141b85bb3b5d8f38a
---
.../InstCombine/InstCombineAndOrXor.cpp | 42 +++++++
llvm/test/Transforms/InstCombine/or.ll | 114 +++++++++++++++++-
2 files changed, 152 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6cc241781d112..6dc4b97686f97 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3643,6 +3643,48 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
/*NSW=*/true, /*NUW=*/true))
return R;
+
+ Value *Cond0 = nullptr, *Cond1 = nullptr;
+ ConstantInt *Op0True = nullptr, *Op0False = nullptr;
+ ConstantInt *Op1True = nullptr, *Op1False = nullptr;
+
+ // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+ if (match(I.getOperand(0), m_Select(m_Value(Cond0), m_ConstantInt(Op0True),
+ m_ConstantInt(Op0False))) &&
+ match(I.getOperand(1), m_Select(m_Value(Cond1), m_ConstantInt(Op1True),
+ m_ConstantInt(Op1False))) &&
+ Op0True->isZero() && Op1True->isZero() &&
+ Op0False->getValue().tryZExtValue() &&
+ Op1False->getValue().tryZExtValue()) {
+ CmpPredicate Pred0, Pred1;
+ Value *CmpOp0 = nullptr, *CmpOp1 = nullptr;
+ ConstantInt *Op0Cond = nullptr, *Op1Cond = nullptr;
+ if (match(Cond0,
+ m_c_ICmp(Pred0, m_Value(CmpOp0), m_ConstantInt(Op0Cond))) &&
+ match(Cond1,
+ m_c_ICmp(Pred1, m_Value(CmpOp1), m_ConstantInt(Op1Cond))) &&
+ Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_EQ &&
+ Op0Cond->isZero() && Op1Cond->isZero()) {
+ Value *AndSrc0 = nullptr, *AndSrc1 = nullptr;
+ ConstantInt *BitSel0 = nullptr, *BitSel1 = nullptr;
+ if (match(CmpOp0, m_And(m_Value(AndSrc0), m_ConstantInt(BitSel0))) &&
+ match(CmpOp1, m_And(m_Value(AndSrc1), m_ConstantInt(BitSel1))) &&
+ AndSrc0 == AndSrc1 && BitSel0->getValue().tryZExtValue() &&
+ BitSel1->getValue().tryZExtValue()) {
+ unsigned Out0 = Op0False->getValue().getZExtValue();
+ unsigned Out1 = Op1False->getValue().getZExtValue();
+ unsigned Sel0 = BitSel0->getValue().getZExtValue();
+ unsigned Sel1 = BitSel1->getValue().getZExtValue();
+ if (!(Out0 % Sel0) && !(Out1 % Sel1) &&
+ ((Out0 / Sel0) == (Out1 / Sel1))) {
+ auto NewAnd = Builder.CreateAnd(
+ AndSrc0, ConstantInt::get(AndSrc0->getType(), Sel0 + Sel1));
+ return BinaryOperator::CreateMul(
+ NewAnd, ConstantInt::get(NewAnd->getType(), (Out1 / Sel1)));
+ }
+ }
+ }
+ }
}
Value *X, *Y;
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 95f89e4ce11cd..f2b21ca966592 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1281,10 +1281,10 @@ define <16 x i1> @test51(<16 x i1> %arg, <16 x i1> %arg1) {
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <16 x i1> [[ARG:%.*]], <16 x i1> [[ARG1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 20, i32 5, i32 6, i32 23, i32 24, i32 9, i32 10, i32 27, i32 28, i32 29, i32 30, i32 31>
; CHECK-NEXT: ret <16 x i1> [[TMP3]]
;
- %tmp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
- %tmp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
- %tmp3 = or <16 x i1> %tmp, %tmp2
- ret <16 x i1> %tmp3
+ %temp = and <16 x i1> %arg, <i1 true, i1 true, i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false, i1 false>
+ %temp2 = and <16 x i1> %arg1, <i1 false, i1 false, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 false, i1 true, i1 true, i1 true, i1 true, i1 true>
+ %temp3 = or <16 x i1> %temp, %temp2
+ ret <16 x i1> %temp3
}
; This would infinite loop because it reaches a transform
@@ -2035,3 +2035,109 @@ define i32 @or_xor_and_commuted3(i32 %x, i32 %y, i32 %z) {
%or1 = or i32 %xor, %yy
ret i32 %or1
}
+
+define i32 @add_select_cmp_and1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and1(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and2(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 5
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 4
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 288
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and3(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and3(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[BITOP2:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
+; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %temp = or disjoint i32 %sel0, %sel1
+ %bitop2 = and i32 %in, 4
+ %cmp2 = icmp eq i32 %bitop2, 0
+ %sel2 = select i1 %cmp2, i32 0, i32 288
+ %out = or disjoint i32 %temp, %sel2
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and4(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and4(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12
+; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[TEMP2]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %temp = or disjoint i32 %sel0, %sel1
+ %bitop2 = and i32 %in, 4
+ %cmp2 = icmp eq i32 %bitop2, 0
+ %bitop3 = and i32 %in, 8
+ %cmp3 = icmp eq i32 %bitop3, 0
+ %sel2 = select i1 %cmp2, i32 0, i32 288
+ %sel3 = select i1 %cmp3, i32 0, i32 576
+ %temp2 = or disjoint i32 %sel2, %sel3
+ %out = or disjoint i32 %temp, %temp2
+ ret i32 %out
+}
+
+
+
+define i32 @add_select_cmp_and_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mismatch(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 3
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 288
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 3
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 288
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
>From eacb0d8fcef4565e7d9462e0b2feca472227c95c Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Mon, 14 Apr 2025 12:23:06 -0700
Subject: [PATCH 2/3] Address review comments
Change-Id: I630d506375b0eb4b16dad1437bff2da357be2059
---
llvm/include/llvm/Analysis/CmpInstAnalysis.h | 4 +-
llvm/lib/Analysis/CmpInstAnalysis.cpp | 28 +++-
.../InstCombine/InstCombineAndOrXor.cpp | 81 +++++++-----
llvm/test/Transforms/InstCombine/or.ll | 124 +++++++++++++++++-
4 files changed, 195 insertions(+), 42 deletions(-)
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index aeda58ac7535d..e8a9060b8e882 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -105,8 +105,8 @@ namespace llvm {
/// Unless \p AllowNonZeroC is true, C will always be 0.
std::optional<DecomposedBitTest>
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
- bool LookThroughTrunc = true,
- bool AllowNonZeroC = false);
+ bool LookThroughTrunc = true, bool AllowNonZeroC = false,
+ bool LookThruBitSel = false);
/// Decompose an icmp into the form ((X & Mask) pred C) if
/// possible. Unless \p AllowNonZeroC is true, C will always be 0.
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index 5c0d1dd1c74b0..ebbc71ee20ec9 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,11 +75,12 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
std::optional<DecomposedBitTest>
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
- bool LookThruTrunc, bool AllowNonZeroC) {
+ bool LookThruTrunc, bool AllowNonZeroC,
+ bool LookThruBitSel) {
using namespace PatternMatch;
const APInt *OrigC;
- if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC)))
+ if (!match(RHS, m_APIntAllowPoison(OrigC)))
return std::nullopt;
bool Inverted = false;
@@ -96,10 +97,27 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
Pred = ICmpInst::getStrictPredicate(Pred);
}
+ auto decomposeBitMask =
+ [LHS,
+ LookThruBitSel](CmpInst::Predicate Pred,
+ const APInt *OrigC) -> std::optional<DecomposedBitTest> {
+ if (!LookThruBitSel)
+ return std::nullopt;
+
+ const APInt *AndC;
+ Value *AndVal;
+ std::optional<DecomposedBitTest> Result = std::nullopt;
+ if (match(LHS, m_And(m_Value(AndVal), m_APInt(AndC))))
+ Result = {AndVal /*X*/, Pred /*Pred*/, *AndC /*Mask*/, *OrigC /*C*/};
+
+ return Result;
+ };
+
DecomposedBitTest Result;
+
switch (Pred) {
default:
- llvm_unreachable("Unexpected predicate");
+ return decomposeBitMask(Pred, OrigC);
case ICmpInst::ICMP_SLT: {
// X < 0 is equivalent to (X & SignMask) != 0.
if (C.isZero()) {
@@ -126,7 +144,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
break;
}
- return std::nullopt;
+ return decomposeBitMask(Pred, OrigC);
}
case ICmpInst::ICMP_ULT:
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
@@ -145,7 +163,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
break;
}
- return std::nullopt;
+ return decomposeBitMask(Pred, OrigC);
}
if (!AllowNonZeroC && !Result.C.isZero())
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6dc4b97686f97..4fc0df749a36d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3645,42 +3645,61 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
return R;
Value *Cond0 = nullptr, *Cond1 = nullptr;
- ConstantInt *Op0True = nullptr, *Op0False = nullptr;
- ConstantInt *Op1True = nullptr, *Op1False = nullptr;
+ const APInt *Op0True = nullptr, *Op0False = nullptr;
+ const APInt *Op1True = nullptr, *Op1False = nullptr;
// (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
- if (match(I.getOperand(0), m_Select(m_Value(Cond0), m_ConstantInt(Op0True),
- m_ConstantInt(Op0False))) &&
- match(I.getOperand(1), m_Select(m_Value(Cond1), m_ConstantInt(Op1True),
- m_ConstantInt(Op1False))) &&
- Op0True->isZero() && Op1True->isZero() &&
- Op0False->getValue().tryZExtValue() &&
- Op1False->getValue().tryZExtValue()) {
+ if (match(I.getOperand(0),
+ m_Select(m_Value(Cond0), m_APInt(Op0True), m_APInt(Op0False))) &&
+ match(I.getOperand(1),
+ m_Select(m_Value(Cond1), m_APInt(Op1True), m_APInt(Op1False))) &&
+ Op0True->isZero() && Op1True->isZero()) {
CmpPredicate Pred0, Pred1;
- Value *CmpOp0 = nullptr, *CmpOp1 = nullptr;
- ConstantInt *Op0Cond = nullptr, *Op1Cond = nullptr;
- if (match(Cond0,
- m_c_ICmp(Pred0, m_Value(CmpOp0), m_ConstantInt(Op0Cond))) &&
- match(Cond1,
- m_c_ICmp(Pred1, m_Value(CmpOp1), m_ConstantInt(Op1Cond))) &&
- Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_EQ &&
- Op0Cond->isZero() && Op1Cond->isZero()) {
- Value *AndSrc0 = nullptr, *AndSrc1 = nullptr;
- ConstantInt *BitSel0 = nullptr, *BitSel1 = nullptr;
- if (match(CmpOp0, m_And(m_Value(AndSrc0), m_ConstantInt(BitSel0))) &&
- match(CmpOp1, m_And(m_Value(AndSrc1), m_ConstantInt(BitSel1))) &&
- AndSrc0 == AndSrc1 && BitSel0->getValue().tryZExtValue() &&
- BitSel1->getValue().tryZExtValue()) {
- unsigned Out0 = Op0False->getValue().getZExtValue();
- unsigned Out1 = Op1False->getValue().getZExtValue();
- unsigned Sel0 = BitSel0->getValue().getZExtValue();
- unsigned Sel1 = BitSel1->getValue().getZExtValue();
- if (!(Out0 % Sel0) && !(Out1 % Sel1) &&
- ((Out0 / Sel0) == (Out1 / Sel1))) {
+
+ if (ICmpInst *ICL = dyn_cast<ICmpInst>(Cond0);
+ ICmpInst *ICR = dyn_cast<ICmpInst>(Cond1)) {
+ auto LHSDecompose =
+ decomposeBitTestICmp(ICL->getOperand(0), ICL->getOperand(1),
+ ICL->getPredicate(), true, true, true);
+ auto RHSDecompose =
+ decomposeBitTestICmp(ICR->getOperand(0), ICR->getOperand(1),
+ ICR->getPredicate(), true, true, true);
+ if (LHSDecompose && RHSDecompose &&
+ LHSDecompose->Pred == RHSDecompose->Pred &&
+ (LHSDecompose->Pred == ICmpInst::ICMP_EQ ||
+ LHSDecompose->Pred == ICmpInst::ICMP_NE) &&
+ ((LHSDecompose->Mask & RHSDecompose->Mask) ==
+ APInt::getZero(LHSDecompose->Mask.getBitWidth())) &&
+ LHSDecompose->C.isZero() && RHSDecompose->C.isZero() &&
+ !RHSDecompose->Mask.isNegative() &&
+ !LHSDecompose->Mask.isNegative() &&
+ RHSDecompose->Mask.isPowerOf2() &&
+ LHSDecompose->Mask.isPowerOf2()) {
+ std::pair<const APInt *, const APInt *> LHSInts;
+ std::pair<const APInt *, const APInt *> RHSInts;
+
+ if (LHSDecompose->Pred == ICmpInst::ICMP_EQ) {
+ LHSInts = {Op0False, Op0True};
+ RHSInts = {Op1False, Op1True};
+ } else {
+ LHSInts = {Op0True, Op0False};
+ RHSInts = {Op1True, Op1False};
+ }
+
+ if (!LHSInts.first->isNegative() && !RHSInts.first->isNegative() &&
+ LHSInts.second->isZero() && RHSInts.second->isZero() &&
+ LHSInts.first->urem(LHSDecompose->Mask).isZero() &&
+ RHSInts.first->urem(RHSDecompose->Mask).isZero() &&
+ LHSInts.first->udiv(LHSDecompose->Mask) ==
+ RHSInts.first->udiv(RHSDecompose->Mask)) {
auto NewAnd = Builder.CreateAnd(
- AndSrc0, ConstantInt::get(AndSrc0->getType(), Sel0 + Sel1));
+ LHSDecompose->X,
+ ConstantInt::get(LHSDecompose->X->getType(),
+ (LHSDecompose->Mask + RHSDecompose->Mask)));
return BinaryOperator::CreateMul(
- NewAnd, ConstantInt::get(NewAnd->getType(), (Out1 / Sel1)));
+ NewAnd,
+ ConstantInt::get(NewAnd->getType(),
+ LHSInts.first->udiv(LHSDecompose->Mask)));
}
}
}
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index f2b21ca966592..2d6a905ea520b 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -2119,13 +2119,11 @@ define i32 @add_select_cmp_and4(i32 %in) {
ret i32 %out
}
-
-
define i32 @add_select_cmp_and_mismatch(i32 %in) {
; CHECK-LABEL: @add_select_cmp_and_mismatch(
; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
-; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 3
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 2
; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 288
@@ -2134,10 +2132,128 @@ define i32 @add_select_cmp_and_mismatch(i32 %in) {
;
%bitop0 = and i32 %in, 1
%cmp0 = icmp eq i32 %bitop0, 0
- %bitop1 = and i32 %in, 3
+ %bitop1 = and i32 %in, 2
%cmp1 = icmp eq i32 %bitop1, 0
%sel0 = select i1 %cmp0, i32 0, i32 72
%sel1 = select i1 %cmp1, i32 0, i32 288
%out = or disjoint i32 %sel0, %sel1
ret i32 %out
}
+
+define i32 @add_select_cmp_and_negative(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_negative(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i32 [[IN]], 2
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 -144
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, -2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 -144
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and_bitsel_overlap(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_bitsel_overlap(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 144
+; CHECK-NEXT: ret i32 [[SEL0]]
+;
+ %bitop0 = and i32 %in, 2
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 144
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+; We cannot combine into and-mul, as %bitop1 may not be exactly 6
+
+define i32 @add_select_cmp_and_multbit_mask(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_multbit_mask(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 6
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 432
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in, 6
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 432
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+
+define <2 x i32> @add_select_cmp_vec(<2 x i32> %in) {
+; CHECK-LABEL: @add_select_cmp_vec(
+; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[IN:%.*]], splat (i32 3)
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw <2 x i32> [[TMP1]], splat (i32 72)
+; CHECK-NEXT: ret <2 x i32> [[OUT]]
+;
+ %bitop0 = and <2 x i32> %in, <i32 1, i32 1>
+ %cmp0 = icmp eq <2 x i32> %bitop0, <i32 0, i32 0>
+ %bitop1 = and <2 x i32> %in, <i32 2, i32 2>
+ %cmp1 = icmp eq <2 x i32> %bitop1, <i32 0, i32 0>
+ %sel0 = select <2 x i1> %cmp0, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 72, i32 72>
+ %sel1 = select <2 x i1> %cmp1, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 144, i32 144>
+ %out = or disjoint <2 x i32> %sel0, %sel1
+ ret <2 x i32> %out
+}
+
+define <2 x i32> @add_select_cmp_vec_poison(<2 x i32> %in) {
+; CHECK-LABEL: @add_select_cmp_vec_poison(
+; CHECK-NEXT: [[BITOP0:%.*]] = and <2 x i32> [[IN:%.*]], splat (i32 1)
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq <2 x i32> [[BITOP0]], zeroinitializer
+; CHECK-NEXT: [[BITOP1:%.*]] = and <2 x i32> [[IN]], splat (i32 2)
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq <2 x i32> [[BITOP1]], zeroinitializer
+; CHECK-NEXT: [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i32> zeroinitializer, <2 x i32> <i32 poison, i32 144>
+; CHECK-NEXT: [[OUT:%.*]] = select <2 x i1> [[CMP0]], <2 x i32> [[SEL1]], <2 x i32> <i32 72, i32 poison>
+; CHECK-NEXT: ret <2 x i32> [[OUT]]
+;
+ %bitop0 = and <2 x i32> %in, <i32 1, i32 1>
+ %cmp0 = icmp eq <2 x i32> %bitop0, <i32 0, i32 0>
+ %bitop1 = and <2 x i32> %in, <i32 2, i32 2>
+ %cmp1 = icmp eq <2 x i32> %bitop1, <i32 0, i32 0>
+ %sel0 = select <2 x i1> %cmp0, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 72, i32 poison>
+ %sel1 = select <2 x i1> %cmp1, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 poison, i32 144>
+ %out = or disjoint <2 x i32> %sel0, %sel1
+ ret <2 x i32> %out
+}
+
+define <2 x i32> @add_select_cmp_vec_nonunique(<2 x i32> %in) {
+; CHECK-LABEL: @add_select_cmp_vec_nonunique(
+; CHECK-NEXT: [[BITOP0:%.*]] = and <2 x i32> [[IN:%.*]], <i32 1, i32 2>
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq <2 x i32> [[BITOP0]], zeroinitializer
+; CHECK-NEXT: [[BITOP1:%.*]] = and <2 x i32> [[IN]], <i32 4, i32 8>
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq <2 x i32> [[BITOP1]], zeroinitializer
+; CHECK-NEXT: [[SEL0:%.*]] = select <2 x i1> [[CMP0]], <2 x i32> zeroinitializer, <2 x i32> <i32 72, i32 144>
+; CHECK-NEXT: [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i32> zeroinitializer, <2 x i32> <i32 288, i32 576>
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint <2 x i32> [[SEL0]], [[SEL1]]
+; CHECK-NEXT: ret <2 x i32> [[OUT]]
+;
+ %bitop0 = and <2 x i32> %in, <i32 1, i32 2>
+ %cmp0 = icmp eq <2 x i32> %bitop0, <i32 0, i32 0>
+ %bitop1 = and <2 x i32> %in, <i32 4, i32 8>
+ %cmp1 = icmp eq <2 x i32> %bitop1, <i32 0, i32 0>
+ %sel0 = select <2 x i1> %cmp0, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 72, i32 144>
+ %sel1 = select <2 x i1> %cmp1, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 288, i32 576>
+ %out = or disjoint <2 x i32> %sel0, %sel1
+ ret <2 x i32> %out
+}
>From 107ec00350ab3f4010f21e4afda31eacc3a773a0 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Tue, 15 Apr 2025 11:41:05 -0700
Subject: [PATCH 3/3] Review comments 2
Change-Id: I24786ee6dc53a33fc7afbd80d226cda4e4a4df03
---
.../InstCombine/InstCombineAndOrXor.cpp | 49 ++++++++-----------
llvm/test/Transforms/InstCombine/or.ll | 25 +++++++++-
2 files changed, 43 insertions(+), 31 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 4fc0df749a36d..4bd32fba1b702 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3645,61 +3645,52 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
return R;
Value *Cond0 = nullptr, *Cond1 = nullptr;
- const APInt *Op0True = nullptr, *Op0False = nullptr;
- const APInt *Op1True = nullptr, *Op1False = nullptr;
+ const APInt *Op0Eq = nullptr, *Op0Ne = nullptr;
+ const APInt *Op1Eq = nullptr, *Op1Ne = nullptr;
// (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
if (match(I.getOperand(0),
- m_Select(m_Value(Cond0), m_APInt(Op0True), m_APInt(Op0False))) &&
+ m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) &&
match(I.getOperand(1),
- m_Select(m_Value(Cond1), m_APInt(Op1True), m_APInt(Op1False))) &&
- Op0True->isZero() && Op1True->isZero()) {
+ m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) {
CmpPredicate Pred0, Pred1;
if (ICmpInst *ICL = dyn_cast<ICmpInst>(Cond0);
ICmpInst *ICR = dyn_cast<ICmpInst>(Cond1)) {
auto LHSDecompose =
decomposeBitTestICmp(ICL->getOperand(0), ICL->getOperand(1),
- ICL->getPredicate(), true, true, true);
+ ICL->getPredicate(), true, false, true);
auto RHSDecompose =
decomposeBitTestICmp(ICR->getOperand(0), ICR->getOperand(1),
- ICR->getPredicate(), true, true, true);
+ ICR->getPredicate(), true, false, true);
if (LHSDecompose && RHSDecompose &&
+ LHSDecompose->X == RHSDecompose->X &&
LHSDecompose->Pred == RHSDecompose->Pred &&
- (LHSDecompose->Pred == ICmpInst::ICMP_EQ ||
- LHSDecompose->Pred == ICmpInst::ICMP_NE) &&
- ((LHSDecompose->Mask & RHSDecompose->Mask) ==
- APInt::getZero(LHSDecompose->Mask.getBitWidth())) &&
- LHSDecompose->C.isZero() && RHSDecompose->C.isZero() &&
+ (ICmpInst::isEquality(LHSDecompose->Pred)) &&
!RHSDecompose->Mask.isNegative() &&
!LHSDecompose->Mask.isNegative() &&
RHSDecompose->Mask.isPowerOf2() &&
- LHSDecompose->Mask.isPowerOf2()) {
+ LHSDecompose->Mask.isPowerOf2() &&
+ LHSDecompose->Mask != RHSDecompose->Mask) {
std::pair<const APInt *, const APInt *> LHSInts;
std::pair<const APInt *, const APInt *> RHSInts;
-
- if (LHSDecompose->Pred == ICmpInst::ICMP_EQ) {
- LHSInts = {Op0False, Op0True};
- RHSInts = {Op1False, Op1True};
- } else {
- LHSInts = {Op0True, Op0False};
- RHSInts = {Op1True, Op1False};
+ if (LHSDecompose->Pred == ICmpInst::ICMP_NE) {
+ std::swap(Op0Eq, Op0Ne);
+ std::swap(Op1Eq, Op1Ne);
}
- if (!LHSInts.first->isNegative() && !RHSInts.first->isNegative() &&
- LHSInts.second->isZero() && RHSInts.second->isZero() &&
- LHSInts.first->urem(LHSDecompose->Mask).isZero() &&
- RHSInts.first->urem(RHSDecompose->Mask).isZero() &&
- LHSInts.first->udiv(LHSDecompose->Mask) ==
- RHSInts.first->udiv(RHSDecompose->Mask)) {
+ if (!Op0Ne->isNegative() && !Op1Ne->isNegative() && Op0Eq->isZero() &&
+ Op1Eq->isZero() && Op0Ne->urem(LHSDecompose->Mask).isZero() &&
+ Op1Ne->urem(RHSDecompose->Mask).isZero() &&
+ Op0Ne->udiv(LHSDecompose->Mask) ==
+ Op1Ne->udiv(RHSDecompose->Mask)) {
auto NewAnd = Builder.CreateAnd(
LHSDecompose->X,
ConstantInt::get(LHSDecompose->X->getType(),
(LHSDecompose->Mask + RHSDecompose->Mask)));
return BinaryOperator::CreateMul(
- NewAnd,
- ConstantInt::get(NewAnd->getType(),
- LHSInts.first->udiv(LHSDecompose->Mask)));
+ NewAnd, ConstantInt::get(NewAnd->getType(),
+ Op0Ne->udiv(LHSDecompose->Mask)));
}
}
}
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 2d6a905ea520b..02548ac9e7883 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -2119,8 +2119,8 @@ define i32 @add_select_cmp_and4(i32 %in) {
ret i32 %out
}
-define i32 @add_select_cmp_and_mismatch(i32 %in) {
-; CHECK-LABEL: @add_select_cmp_and_mismatch(
+define i32 @add_select_cmp_and_const_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_const_mismatch(
; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 2
@@ -2140,6 +2140,27 @@ define i32 @add_select_cmp_and_mismatch(i32 %in) {
ret i32 %out
}
+define i32 @add_select_cmp_and_value_mismatch(i32 %in, i32 %in1) {
+; CHECK-LABEL: @add_select_cmp_and_value_mismatch(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN1:%.*]], 2
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 144
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %bitop1 = and i32 %in1, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
define i32 @add_select_cmp_and_negative(i32 %in) {
; CHECK-LABEL: @add_select_cmp_and_negative(
; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
More information about the llvm-commits mailing list