[llvm] [InstCombine] Extend bitmask mul combine to handle independent operands (PR #142503)
Jeffrey Byrnes via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 12 08:50:32 PDT 2025
https://github.com/jrbyrnes updated https://github.com/llvm/llvm-project/pull/142503
>From e9996d1d98980da1b1dea67f3fa5d5aef760570a Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Mon, 2 Jun 2025 12:29:39 -0700
Subject: [PATCH 1/2] [InstCombine] Extend bitmask mul combine to handle
independent operands
Change-Id: Ife1a010d2ae6df40549a6c73f7b893948befa3be
---
.../InstCombine/InstCombineAndOrXor.cpp | 92 +++++++++++++++----
.../test/Transforms/InstCombine/or-bitmask.ll | 50 ++++++++++
2 files changed, 123 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index dce695a036006..099359021a394 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3602,6 +3602,11 @@ struct DecomposedBitMaskMul {
APInt Mask;
bool NUW;
bool NSW;
+
+ bool isCombineableWith(DecomposedBitMaskMul Other) {
+ return X == Other.X && (Mask & Other.Mask).isZero() &&
+ Factor == Other.Factor;
+ }
};
static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
@@ -3659,6 +3664,34 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
return std::nullopt;
}
+using CombinedBitmaskMul =
+ std::pair<std::optional<DecomposedBitMaskMul>, Value *>;
+
+static CombinedBitmaskMul matchCombinedBitmaskMul(Value *V) {
+ auto DecompBitMaskMul = matchBitmaskMul(V);
+ if (DecompBitMaskMul)
+ return {DecompBitMaskMul, nullptr};
+
+ // Otherwise, V must be a disjoint or; look for the pattern in its operands.
+ auto BOp = dyn_cast<BinaryOperator>(V);
+ if (!BOp)
+ return {std::nullopt, nullptr};
+
+ auto Disj = dyn_cast<PossiblyDisjointInst>(BOp);
+ if (!Disj || !Disj->isDisjoint())
+ return {std::nullopt, nullptr};
+
+ auto DecompBitMaskMul0 = matchBitmaskMul(BOp->getOperand(0));
+ if (DecompBitMaskMul0)
+ return {DecompBitMaskMul0, BOp->getOperand(1)};
+
+ auto DecompBitMaskMul1 = matchBitmaskMul(BOp->getOperand(1));
+ if (DecompBitMaskMul1)
+ return {DecompBitMaskMul1, BOp->getOperand(0)};
+
+ return {std::nullopt, nullptr};
+}
+
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
@@ -3741,25 +3774,46 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*NSW=*/true, /*NUW=*/true))
return R;
- // (A & N) * C + (A & M) * C -> (A & (N + M)) & C
- // This also accepts the equivalent select form of (A & N) * C
- // expressions i.e. !(A & N) ? 0 : N * C)
- auto Decomp1 = matchBitmaskMul(I.getOperand(1));
- if (Decomp1) {
- auto Decomp0 = matchBitmaskMul(I.getOperand(0));
- if (Decomp0 && Decomp0->X == Decomp1->X &&
- (Decomp0->Mask & Decomp1->Mask).isZero() &&
- Decomp0->Factor == Decomp1->Factor) {
-
- Value *NewAnd = Builder.CreateAnd(
- Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
- (Decomp0->Mask + Decomp1->Mask)));
-
- auto *Combined = BinaryOperator::CreateMul(
- NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
-
- Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
- Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
+ // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+ // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
+ // expressions i.e. (A & N) * C
+ CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul(I.getOperand(1));
+ auto BMDecomp1 = Decomp1.first;
+
+ if (BMDecomp1) {
+ CombinedBitmaskMul Decomp0 = matchCombinedBitmaskMul(I.getOperand(0));
+ auto BMDecomp0 = Decomp0.first;
+
+ if (BMDecomp0 && BMDecomp0->isCombineableWith(*BMDecomp1)) {
+ auto NewAnd = Builder.CreateAnd(
+ BMDecomp0->X,
+ ConstantInt::get(BMDecomp0->X->getType(),
+ (BMDecomp0->Mask + BMDecomp1->Mask)));
+
+ BinaryOperator *Combined = cast<BinaryOperator>(Builder.CreateMul(
+ NewAnd, ConstantInt::get(NewAnd->getType(), BMDecomp1->Factor)));
+
+ Combined->setHasNoUnsignedWrap(BMDecomp0->NUW && BMDecomp1->NUW);
+ Combined->setHasNoSignedWrap(BMDecomp0->NSW && BMDecomp1->NSW);
+
+ // If our tree has independent or-disjoint operands, bring them in.
+ auto OtherOp0 = Decomp0.second;
+ auto OtherOp1 = Decomp1.second;
+
+ if (OtherOp0 || OtherOp1) {
+ Value *OtherOp;
+ if (OtherOp0 && OtherOp1) {
+ OtherOp = Builder.CreateOr(OtherOp0, OtherOp1);
+ cast<PossiblyDisjointInst>(OtherOp)->setIsDisjoint(true);
+ } else {
+ OtherOp = OtherOp0 ? OtherOp0 : OtherOp1;
+ }
+ Combined = cast<BinaryOperator>(Builder.CreateOr(Combined, OtherOp));
+ cast<PossiblyDisjointInst>(Combined)->setIsDisjoint(true);
+ }
+
+ // Caller expects detached instruction
+ Combined->removeFromParent();
return Combined;
}
}
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index 3c992dfea569a..0976b76542f49 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -451,6 +451,56 @@ define i32 @and_mul_non_disjoint(i32 %in) {
ret i32 %out
}
+define i32 @unrelated_ops(i32 %in, i32 %in2) {
+; CHECK-LABEL: @unrelated_ops(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %1 = and i32 %in, 3
+ %temp = mul nuw nsw i32 %1, 72
+ %2 = and i32 %in, 12
+ %temp2 = mul nuw nsw i32 %2, 72
+ %temp3 = or disjoint i32 %in2, %temp2
+ %out = or disjoint i32 %temp, %temp3
+ ret i32 %out
+}
+
+define i32 @unrelated_ops1(i32 %in, i32 %in2) {
+; CHECK-LABEL: @unrelated_ops1(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %1 = and i32 %in, 3
+ %temp = mul nuw nsw i32 %1, 72
+ %2 = and i32 %in, 12
+ %temp2 = mul nuw nsw i32 %2, 72
+ %temp3 = or disjoint i32 %in2, %temp
+ %out = or disjoint i32 %temp3, %temp2
+ ret i32 %out
+}
+
+define i32 @unrelated_ops2(i32 %in, i32 %in2, i32 %in3) {
+; CHECK-LABEL: @unrelated_ops2(
+; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
+; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: ret i32 [[OUT]]
+;
+ %1 = and i32 %in, 3
+ %temp = mul nuw nsw i32 %1, 72
+ %temp3 = or disjoint i32 %temp, %in3
+ %2 = and i32 %in, 12
+ %temp2 = mul nuw nsw i32 %2, 72
+ %temp4 = or disjoint i32 %in2, %temp2
+ %out = or disjoint i32 %temp3, %temp4
+ ret i32 %out
+}
+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; CONSTSPLAT: {{.*}}
; CONSTVEC: {{.*}}
>From 171ca6bf1544359c8e55056e2e0f21d71bb7b6ea Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 12 Jun 2025 08:50:01 -0700
Subject: [PATCH 2/2] Fix comment from bad merge
Change-Id: I879acdf0b17a7110286c6c375410300611c468eb
---
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 099359021a394..c6c0a85b06bdd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3774,9 +3774,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*NSW=*/true, /*NUW=*/true))
return R;
- // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
- // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
- // expressions i.e. (A & N) * C
+ // (A & N) * C + (A & M) * C -> (A & (N + M)) * C
+ // This also accepts the equivalent select form of (A & N) * C
+ // expressions i.e. (!(A & N) ? 0 : N * C)
CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul(I.getOperand(1));
auto BMDecomp1 = Decomp1.first;
More information about the llvm-commits
mailing list