[llvm] [InstCombine] Combine or-disjoint (and->mul), (and->mul) to and->mul (PR #136013)
Jeffrey Byrnes via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 5 09:51:41 PDT 2025
https://github.com/jrbyrnes updated https://github.com/llvm/llvm-project/pull/136013
>From 0019711079e7d929b1853748d0f84c22adb04a62 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 17 Apr 2025 10:11:18 -0700
Subject: [PATCH 1/4] [InstCombine] Extend bitmask->select combine to match
and->mul
Change-Id: I1cc2acd3804dde50636518f3ef2c9581848ae9f6
---
.../InstCombine/InstCombineAndOrXor.cpp | 122 ++++++++++++------
.../test/Transforms/InstCombine/or-bitmask.ll | 95 ++++++++++++--
2 files changed, 163 insertions(+), 54 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 59b46ebdb72e2..ea166717d5c05 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3593,6 +3593,72 @@ static Value *foldOrOfInversions(BinaryOperator &I,
return nullptr;
}
+struct DecomposedBitMaskMul {
+ Value *X;
+ APInt Factor;
+ APInt Mask;
+};
+
+static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
+ Instruction *Op = dyn_cast<Instruction>(V);
+ if (!Op)
+ return std::nullopt;
+
+ Value *MulOp = nullptr;
+ const APInt *MulConst = nullptr;
+ if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) {
+ Value *Original = nullptr;
+ const APInt *Mask = nullptr;
+ if (!MulConst->isStrictlyPositive())
+ return std::nullopt;
+
+ if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) {
+ if (!Mask->isStrictlyPositive())
+ return std::nullopt;
+ DecomposedBitMaskMul Ret;
+ Ret.X = Original;
+ Ret.Mask = *Mask;
+ Ret.Factor = *MulConst;
+ return Ret;
+ }
+ return std::nullopt;
+ }
+
+ Value *Cond = nullptr;
+ const APInt *EqZero = nullptr, *NeZero = nullptr;
+
+ // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+ if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) {
+ auto ICmpDecompose =
+ decomposeBitTest(Cond, /*LookThruTrunc=*/true,
+ /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
+ if (!ICmpDecompose.has_value())
+ return std::nullopt;
+
+ if (ICmpDecompose->Pred == ICmpInst::ICMP_NE)
+ std::swap(EqZero, NeZero);
+
+ if (!EqZero->isZero() || !NeZero->isStrictlyPositive())
+ return std::nullopt;
+
+ if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
+ !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
+ ICmpDecompose->Mask.isNegative())
+ return std::nullopt;
+
+ if (!NeZero->urem(ICmpDecompose->Mask).isZero())
+ return std::nullopt;
+
+ DecomposedBitMaskMul Ret;
+ Ret.X = ICmpDecompose->X;
+ Ret.Mask = ICmpDecompose->Mask;
+ Ret.Factor = NeZero->udiv(ICmpDecompose->Mask);
+ return Ret;
+ }
+
+ return std::nullopt;
+}
+
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
@@ -3675,49 +3741,19 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*NSW=*/true, /*NUW=*/true))
return R;
- Value *Cond0 = nullptr, *Cond1 = nullptr;
- const APInt *Op0Eq = nullptr, *Op0Ne = nullptr;
- const APInt *Op1Eq = nullptr, *Op1Ne = nullptr;
-
- // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
- if (match(I.getOperand(0),
- m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) &&
- match(I.getOperand(1),
- m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) {
-
- auto LHSDecompose =
- decomposeBitTest(Cond0, /*LookThruTrunc=*/true,
- /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
- auto RHSDecompose =
- decomposeBitTest(Cond1, /*LookThruTrunc=*/true,
- /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
-
- if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X &&
- RHSDecompose->Mask.isPowerOf2() && LHSDecompose->Mask.isPowerOf2() &&
- LHSDecompose->Mask != RHSDecompose->Mask &&
- LHSDecompose->Mask.getBitWidth() == Op0Ne->getBitWidth() &&
- RHSDecompose->Mask.getBitWidth() == Op1Ne->getBitWidth()) {
- assert(Op0Ne->getBitWidth() == Op1Ne->getBitWidth());
- assert(ICmpInst::isEquality(LHSDecompose->Pred));
- if (LHSDecompose->Pred == ICmpInst::ICMP_NE)
- std::swap(Op0Eq, Op0Ne);
- if (RHSDecompose->Pred == ICmpInst::ICMP_NE)
- std::swap(Op1Eq, Op1Ne);
-
- if (!Op0Ne->isZero() && !Op1Ne->isZero() && Op0Eq->isZero() &&
- Op1Eq->isZero() && Op0Ne->urem(LHSDecompose->Mask).isZero() &&
- Op1Ne->urem(RHSDecompose->Mask).isZero() &&
- Op0Ne->udiv(LHSDecompose->Mask) ==
- Op1Ne->udiv(RHSDecompose->Mask)) {
- auto NewAnd = Builder.CreateAnd(
- LHSDecompose->X,
- ConstantInt::get(LHSDecompose->X->getType(),
- (LHSDecompose->Mask + RHSDecompose->Mask)));
-
- return BinaryOperator::CreateMul(
- NewAnd, ConstantInt::get(NewAnd->getType(),
- Op0Ne->udiv(LHSDecompose->Mask)));
- }
+ auto Decomp0 = matchBitmaskMul(I.getOperand(0));
+ auto Decomp1 = matchBitmaskMul(I.getOperand(1));
+
+ if (Decomp0 && Decomp1) {
+ if (Decomp0->X == Decomp1->X &&
+ (Decomp0->Mask & Decomp1->Mask).isZero() &&
+ Decomp0->Factor == Decomp1->Factor) {
+ auto NewAnd = Builder.CreateAnd(
+ Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
+ (Decomp0->Mask + Decomp1->Mask)));
+
+ return BinaryOperator::CreateMul(
+ NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
}
}
}
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index 3b482dc1794db..87f0bbf4d37ab 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -36,13 +36,9 @@ define i32 @add_select_cmp_and2(i32 %in) {
define i32 @add_select_cmp_and3(i32 %in) {
; CHECK-LABEL: @add_select_cmp_and3(
-; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 7
; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT: [[BITOP2:%.*]] = and i32 [[IN]], 4
-; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
-; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
-; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
-; CHECK-NEXT: ret i32 [[OUT]]
+; CHECK-NEXT: ret i32 [[TEMP]]
;
%bitop0 = and i32 %in, 1
%cmp0 = icmp eq i32 %bitop0, 0
@@ -60,12 +56,9 @@ define i32 @add_select_cmp_and3(i32 %in) {
define i32 @add_select_cmp_and4(i32 %in) {
; CHECK-LABEL: @add_select_cmp_and4(
-; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
-; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12
+; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN:%.*]], 15
; CHECK-NEXT: [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
-; CHECK-NEXT: [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]]
-; CHECK-NEXT: ret i32 [[OUT1]]
+; CHECK-NEXT: ret i32 [[TEMP3]]
;
%bitop0 = and i32 %in, 1
%cmp0 = icmp eq i32 %bitop0, 0
@@ -361,6 +354,86 @@ define i64 @mask_select_types_1(i64 %in) {
ret i64 %out
}
+define i32 @add_select_cmp_mixed1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed1(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %mask = and i32 %in, 1
+ %sel0 = mul i32 %mask, 72
+ %bitop1 = and i32 %in, 2
+ %cmp1 = icmp eq i32 %bitop1, 0
+ %sel1 = select i1 %cmp1, i32 0, i32 144
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_mixed2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed2(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %mask = and i32 %in, 2
+ %sel0 = select i1 %cmp0, i32 0, i32 72
+ %sel1 = mul i32 %mask, 72
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and_mul(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mul(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %mask0 = and i32 %in, 1
+ %sel0 = mul i32 %mask0, 72
+ %mask1 = and i32 %in, 2
+ %sel1 = mul i32 %mask1, 72
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_mixed2_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed2_mismatch(
+; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 73
+; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %bitop0 = and i32 %in, 1
+ %cmp0 = icmp eq i32 %bitop0, 0
+ %mask = and i32 %in, 2
+ %sel0 = select i1 %cmp0, i32 0, i32 73
+ %sel1 = mul i32 %mask, 72
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
+define i32 @add_select_cmp_and_mul_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mul_mismatch(
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[IN:%.*]] to i1
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[TMP1]], i32 73, i32 0
+; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %mask0 = and i32 %in, 1
+ %sel0 = mul i32 %mask0, 73
+ %mask1 = and i32 %in, 2
+ %sel1 = mul i32 %mask1, 72
+ %out = or disjoint i32 %sel0, %sel1
+ ret i32 %out
+}
+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; CONSTSPLAT: {{.*}}
; CONSTVEC: {{.*}}
>From 7b63d9b172597da44200f8718a2e3816e436e686 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 22 May 2025 11:06:24 -0700
Subject: [PATCH 2/4] Review comments + fix some conditions
Change-Id: I4b71adfd8bffdda4d2b0d1cba85a3fd73a105a28
---
.../InstCombine/InstCombineAndOrXor.cpp | 52 ++++++++++++-------
.../test/Transforms/InstCombine/or-bitmask.ll | 8 +--
2 files changed, 36 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ea166717d5c05..62ff45fb24379 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3593,10 +3593,16 @@ static Value *foldOrOfInversions(BinaryOperator &I,
return nullptr;
}
+// A decomposition of ((A & N) ? 0 : N * C) . Where X = A, Factor = C, Mask = N.
+// The NUW / NSW bools
+// Note that we can decompose equivalent forms of this expression (e.g. ((A & N)
+// * C))
struct DecomposedBitMaskMul {
Value *X;
APInt Factor;
APInt Mask;
+ bool NUW;
+ bool NSW;
};
static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
@@ -3606,20 +3612,21 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
Value *MulOp = nullptr;
const APInt *MulConst = nullptr;
+
+ // Decompose (A & N) * C) into BitMaskMul
if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) {
Value *Original = nullptr;
const APInt *Mask = nullptr;
- if (!MulConst->isStrictlyPositive())
+ if (MulConst->isZero())
return std::nullopt;
if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) {
- if (!Mask->isStrictlyPositive())
+ if (Mask->isZero())
return std::nullopt;
- DecomposedBitMaskMul Ret;
- Ret.X = Original;
- Ret.Mask = *Mask;
- Ret.Factor = *MulConst;
- return Ret;
+ return std::optional<DecomposedBitMaskMul>(
+ {Original, *MulConst, *Mask,
+ cast<BinaryOperator>(Op)->hasNoUnsignedWrap(),
+ cast<BinaryOperator>(Op)->hasNoSignedWrap()});
}
return std::nullopt;
}
@@ -3627,7 +3634,7 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
Value *Cond = nullptr;
const APInt *EqZero = nullptr, *NeZero = nullptr;
- // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+ // Decompose ((A & N) ? N * C : 0) into BitMaskMul
if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) {
auto ICmpDecompose =
decomposeBitTest(Cond, /*LookThruTrunc=*/true,
@@ -3638,22 +3645,20 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
if (ICmpDecompose->Pred == ICmpInst::ICMP_NE)
std::swap(EqZero, NeZero);
- if (!EqZero->isZero() || !NeZero->isStrictlyPositive())
+ if (!EqZero->isZero() || NeZero->isZero())
return std::nullopt;
if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
!ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
- ICmpDecompose->Mask.isNegative())
+ ICmpDecompose->Mask.isZero())
return std::nullopt;
if (!NeZero->urem(ICmpDecompose->Mask).isZero())
return std::nullopt;
- DecomposedBitMaskMul Ret;
- Ret.X = ICmpDecompose->X;
- Ret.Mask = ICmpDecompose->Mask;
- Ret.Factor = NeZero->udiv(ICmpDecompose->Mask);
- return Ret;
+ return std::optional<DecomposedBitMaskMul>(
+ {ICmpDecompose->X, NeZero->udiv(ICmpDecompose->Mask),
+ ICmpDecompose->Mask, /*NUW=*/false, /*NSW=*/false});
}
return std::nullopt;
@@ -3741,19 +3746,26 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*NSW=*/true, /*NUW=*/true))
return R;
- auto Decomp0 = matchBitmaskMul(I.getOperand(0));
+ // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> (A & (N + M)) * C
+ // This also accepts the equivalent mul form of (!(A & N) ? 0 : N * C)
+ // expressions, i.e. (A & N) * C, for either operand.
auto Decomp1 = matchBitmaskMul(I.getOperand(1));
-
- if (Decomp0 && Decomp1) {
- if (Decomp0->X == Decomp1->X &&
+ if (Decomp1) {
+ auto Decomp0 = matchBitmaskMul(I.getOperand(0));
+ if (Decomp0 && Decomp0->X == Decomp1->X &&
(Decomp0->Mask & Decomp1->Mask).isZero() &&
Decomp0->Factor == Decomp1->Factor) {
+
auto NewAnd = Builder.CreateAnd(
Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
(Decomp0->Mask + Decomp1->Mask)));
- return BinaryOperator::CreateMul(
+ auto Combined = BinaryOperator::CreateMul(
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
+
+ Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
+ Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
+ return Combined;
}
}
}
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index 87f0bbf4d37ab..dcfbe171dd08f 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -37,8 +37,8 @@ define i32 @add_select_cmp_and2(i32 %in) {
define i32 @add_select_cmp_and3(i32 %in) {
; CHECK-LABEL: @add_select_cmp_and3(
; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 7
-; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT: ret i32 [[TEMP]]
+; CHECK-NEXT: [[TEMP1:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: ret i32 [[TEMP1]]
;
%bitop0 = and i32 %in, 1
%cmp0 = icmp eq i32 %bitop0, 0
@@ -57,8 +57,8 @@ define i32 @add_select_cmp_and3(i32 %in) {
define i32 @add_select_cmp_and4(i32 %in) {
; CHECK-LABEL: @add_select_cmp_and4(
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN:%.*]], 15
-; CHECK-NEXT: [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
-; CHECK-NEXT: ret i32 [[TEMP3]]
+; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
+; CHECK-NEXT: ret i32 [[TEMP2]]
;
%bitop0 = and i32 %in, 1
%cmp0 = icmp eq i32 %bitop0, 0
>From 5fa229ba2432d00512a7d58c3ffa7ec610ee4aa6 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Tue, 27 May 2025 11:03:46 -0700
Subject: [PATCH 3/4] Fix crash due to mismatch APInt bitwidth
Change-Id: I12f77aedbf1a2edfe63e4d03cd1e5c1c601365a7
---
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 62ff45fb24379..e357e3d296cc1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3650,7 +3650,8 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
!ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
- ICmpDecompose->Mask.isZero())
+ ICmpDecompose->Mask.isZero() ||
+ NeZero->getBitWidth() != ICmpDecompose->Mask.getBitWidth())
return std::nullopt;
if (!NeZero->urem(ICmpDecompose->Mask).isZero())
>From 9ccf1fa021df9068cd071493942b7c718dd8ad29 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 5 Jun 2025 09:40:16 -0700
Subject: [PATCH 4/4] Review comments
Change-Id: I56a280990a9bae36e59f784a7f48bdbc9f7ca539
---
.../InstCombine/InstCombineAndOrXor.cpp | 37 ++++++++-----------
.../test/Transforms/InstCombine/or-bitmask.ll | 17 +++++++++
2 files changed, 33 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index e357e3d296cc1..de029be1d28ce 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3593,10 +3593,9 @@ static Value *foldOrOfInversions(BinaryOperator &I,
return nullptr;
}
-// A decomposition of ((A & N) ? 0 : N * C) . Where X = A, Factor = C, Mask = N.
-// The NUW / NSW bools
-// Note that we can decompose equivalent forms of this expression (e.g. ((A & N)
-// * C))
+// A decomposition of ((X & Mask) ? Mask * Factor : 0). The NUW / NSW bools
+// track these properties for preservation. Note that we can decompose
+// equivalent forms of this expression (e.g. ((X & Mask) * Factor)).
struct DecomposedBitMaskMul {
Value *X;
APInt Factor;
@@ -3610,25 +3609,20 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
if (!Op)
return std::nullopt;
- Value *MulOp = nullptr;
const APInt *MulConst = nullptr;
// Decompose (A & N) * C) into BitMaskMul
- if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) {
- Value *Original = nullptr;
- const APInt *Mask = nullptr;
- if (MulConst->isZero())
+ Value *Original = nullptr;
+ const APInt *Mask = nullptr;
+ if (match(Op, m_Mul(m_And(m_Value(Original), m_APInt(Mask)),
+ m_APInt(MulConst)))) {
+ if (MulConst->isZero() || Mask->isZero())
return std::nullopt;
- if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) {
- if (Mask->isZero())
- return std::nullopt;
- return std::optional<DecomposedBitMaskMul>(
- {Original, *MulConst, *Mask,
- cast<BinaryOperator>(Op)->hasNoUnsignedWrap(),
- cast<BinaryOperator>(Op)->hasNoSignedWrap()});
- }
- return std::nullopt;
+ return std::optional<DecomposedBitMaskMul>(
+ {Original, *MulConst, *Mask,
+ cast<BinaryOperator>(Op)->hasNoUnsignedWrap(),
+ cast<BinaryOperator>(Op)->hasNoSignedWrap()});
}
Value *Cond = nullptr;
@@ -3642,15 +3636,16 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
if (!ICmpDecompose.has_value())
return std::nullopt;
+ assert(ICmpInst::isEquality(ICmpDecompose->Pred) &&
+ ICmpDecompose->C.isZero());
+
if (ICmpDecompose->Pred == ICmpInst::ICMP_NE)
std::swap(EqZero, NeZero);
if (!EqZero->isZero() || NeZero->isZero())
return std::nullopt;
- if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
- !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
- ICmpDecompose->Mask.isZero() ||
+ if (!ICmpDecompose->Mask.isPowerOf2() || ICmpDecompose->Mask.isZero() ||
NeZero->getBitWidth() != ICmpDecompose->Mask.getBitWidth())
return std::nullopt;
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index dcfbe171dd08f..3c992dfea569a 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -434,6 +434,23 @@ define i32 @add_select_cmp_and_mul_mismatch(i32 %in) {
ret i32 %out
}
+define i32 @and_mul_non_disjoint(i32 %in) {
+; CHECK-LABEL: @and_mul_non_disjoint(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72
+; CHECK-NEXT: [[OUT1:%.*]] = or i32 [[OUT]], [[SEL1]]
+; CHECK-NEXT: ret i32 [[OUT1]]
+;
+ %mask0 = and i32 %in, 2
+ %sel0 = mul i32 %mask0, 72
+ %mask1 = and i32 %in, 4
+ %sel1 = mul i32 %mask1, 72
+ %out = or i32 %sel0, %sel1
+ ret i32 %out
+}
+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; CONSTSPLAT: {{.*}}
; CONSTVEC: {{.*}}
More information about the llvm-commits
mailing list