[llvm] [InstCombine] Extend bitmask mul combine to handle independent operands (PR #142503)

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 12 08:50:32 PDT 2025


https://github.com/jrbyrnes updated https://github.com/llvm/llvm-project/pull/142503

>From e9996d1d98980da1b1dea67f3fa5d5aef760570a Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Mon, 2 Jun 2025 12:29:39 -0700
Subject: [PATCH 1/2] [InstCombine] Extend bitmask mul combine to handle
 independent operands

Change-Id: Ife1a010d2ae6df40549a6c73f7b893948befa3be
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 92 +++++++++++++++----
 .../test/Transforms/InstCombine/or-bitmask.ll | 50 ++++++++++
 2 files changed, 123 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index dce695a036006..099359021a394 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3602,6 +3602,11 @@ struct DecomposedBitMaskMul {
   APInt Mask;
   bool NUW;
   bool NSW;
+
+  bool isCombineableWith(DecomposedBitMaskMul Other) {
+    return X == Other.X && (Mask & Other.Mask).isZero() &&
+           Factor == Other.Factor;
+  }
 };
 
 static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
@@ -3659,6 +3664,34 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
   return std::nullopt;
 }
 
+using CombinedBitmaskMul =
+    std::pair<std::optional<DecomposedBitMaskMul>, Value *>;
+
+static CombinedBitmaskMul matchCombinedBitmaskMul(Value *V) {
+  auto DecompBitMaskMul = matchBitmaskMul(V);
+  if (DecompBitMaskMul)
+    return {DecompBitMaskMul, nullptr};
+
+  // Otherwise, check the operands of V for bitmaskmul pattern
+  auto BOp = dyn_cast<BinaryOperator>(V);
+  if (!BOp)
+    return {std::nullopt, nullptr};
+
+  auto Disj = dyn_cast<PossiblyDisjointInst>(BOp);
+  if (!Disj || !Disj->isDisjoint())
+    return {std::nullopt, nullptr};
+
+  auto DecompBitMaskMul0 = matchBitmaskMul(BOp->getOperand(0));
+  if (DecompBitMaskMul0)
+    return {DecompBitMaskMul0, BOp->getOperand(1)};
+
+  auto DecompBitMaskMul1 = matchBitmaskMul(BOp->getOperand(1));
+  if (DecompBitMaskMul1)
+    return {DecompBitMaskMul1, BOp->getOperand(0)};
+
+  return {std::nullopt, nullptr};
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -3741,25 +3774,46 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
 
-    // (A & N) * C + (A & M) * C -> (A & (N + M)) & C
-    // This also accepts the equivalent select form of (A & N) * C
-    // expressions i.e. !(A & N) ? 0 : N * C)
-    auto Decomp1 = matchBitmaskMul(I.getOperand(1));
-    if (Decomp1) {
-      auto Decomp0 = matchBitmaskMul(I.getOperand(0));
-      if (Decomp0 && Decomp0->X == Decomp1->X &&
-          (Decomp0->Mask & Decomp1->Mask).isZero() &&
-          Decomp0->Factor == Decomp1->Factor) {
-
-        Value *NewAnd = Builder.CreateAnd(
-            Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
-                                         (Decomp0->Mask + Decomp1->Mask)));
-
-        auto *Combined = BinaryOperator::CreateMul(
-            NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
-
-        Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
-        Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
+    // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+    // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
+    // expressions i.e. (A & N) * C
+    CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul(I.getOperand(1));
+    auto BMDecomp1 = Decomp1.first;
+
+    if (BMDecomp1) {
+      CombinedBitmaskMul Decomp0 = matchCombinedBitmaskMul(I.getOperand(0));
+      auto BMDecomp0 = Decomp0.first;
+
+      if (BMDecomp0 && BMDecomp0->isCombineableWith(*BMDecomp1)) {
+        auto NewAnd = Builder.CreateAnd(
+            BMDecomp0->X,
+            ConstantInt::get(BMDecomp0->X->getType(),
+                             (BMDecomp0->Mask + BMDecomp1->Mask)));
+
+        BinaryOperator *Combined = cast<BinaryOperator>(Builder.CreateMul(
+            NewAnd, ConstantInt::get(NewAnd->getType(), BMDecomp1->Factor)));
+
+        Combined->setHasNoUnsignedWrap(BMDecomp0->NUW && BMDecomp1->NUW);
+        Combined->setHasNoSignedWrap(BMDecomp0->NSW && BMDecomp1->NSW);
+
+        // If our tree has independent or-disjoint operands, bring them in.
+        auto OtherOp0 = Decomp0.second;
+        auto OtherOp1 = Decomp1.second;
+
+        if (OtherOp0 || OtherOp1) {
+          Value *OtherOp;
+          if (OtherOp0 && OtherOp1) {
+            OtherOp = Builder.CreateOr(OtherOp0, OtherOp1);
+            cast<PossiblyDisjointInst>(OtherOp)->setIsDisjoint(true);
+          } else {
+            OtherOp = OtherOp0 ? OtherOp0 : OtherOp1;
+          }
+          Combined = cast<BinaryOperator>(Builder.CreateOr(Combined, OtherOp));
+          cast<PossiblyDisjointInst>(Combined)->setIsDisjoint(true);
+        }
+
+        // Caller expects detached instruction
+        Combined->removeFromParent();
         return Combined;
       }
     }
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index 3c992dfea569a..0976b76542f49 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -451,6 +451,56 @@ define i32 @and_mul_non_disjoint(i32 %in) {
   ret i32 %out
 }
 
+define i32 @unrelated_ops(i32 %in, i32 %in2) {
+; CHECK-LABEL: @unrelated_ops(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp3 = or disjoint i32 %in2, %temp2
+  %out = or disjoint i32 %temp, %temp3
+  ret i32 %out
+}
+
+define i32 @unrelated_ops1(i32 %in, i32 %in2) {
+; CHECK-LABEL: @unrelated_ops1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp3 = or disjoint i32 %in2, %temp
+  %out = or disjoint i32 %temp3, %temp2
+  ret i32 %out
+}
+
+define i32 @unrelated_ops2(i32 %in, i32 %in2, i32 %in3) {
+; CHECK-LABEL: @unrelated_ops2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %temp3 = or disjoint i32 %temp, %in3
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp4 = or disjoint i32 %in2, %temp2
+  %out = or disjoint i32 %temp3, %temp4
+  ret i32 %out
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; CONSTSPLAT: {{.*}}
 ; CONSTVEC: {{.*}}

>From 171ca6bf1544359c8e55056e2e0f21d71bb7b6ea Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 12 Jun 2025 08:50:01 -0700
Subject: [PATCH 2/2] Fix comment from bad merge

Change-Id: I879acdf0b17a7110286c6c375410300611c468eb
---
 llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 099359021a394..c6c0a85b06bdd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3774,9 +3774,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
 
-    // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
-    // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
-    // expressions i.e. (A & N) * C
+    // (A & N) * C + (A & M) * C -> (A & (N + M)) * C
+    // This also accepts the equivalent select form of (A & N) * C
+    // expressions, i.e. (!(A & N) ? 0 : N * C).
     CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul(I.getOperand(1));
     auto BMDecomp1 = Decomp1.first;
 



More information about the llvm-commits mailing list