[llvm] [InstCombine] Extend bitmask mul combine to handle independent operands (PR #142503)

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 2 16:29:35 PDT 2025


https://github.com/jrbyrnes created https://github.com/llvm/llvm-project/pull/142503

This extends https://github.com/llvm/llvm-project/pull/136013 to capture cases where the combinable bitmask muls are nested under multiple `or disjoint` instructions.

This PR is meant to be reviewed starting at commit 8c403c912046505ffc10378560c2fc48f214af6a (PATCH 4/4 below); the earlier patches in the series are the changes it builds on.

op1 = or-disjoint mul(and(X, C1), D), reg1
op2 = or-disjoint mul(and(X, C2), D), reg2
out = or-disjoint op1, op2

->   (requires the masks to be disjoint, i.e. C1 & C2 == 0)

temp1 = or-disjoint reg1, reg2
out = or-disjoint mul(and(X, (C1 + C2)), D), temp1


Case1: https://alive2.llvm.org/ce/z/dHApyV
Case2: https://alive2.llvm.org/ce/z/Jz-Nag
Case3: https://alive2.llvm.org/ce/z/3xBnEV

>From 0019711079e7d929b1853748d0f84c22adb04a62 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 17 Apr 2025 10:11:18 -0700
Subject: [PATCH 1/4] [InstCombine] Extend bitmask->select combine to match
 and->mul

Change-Id: I1cc2acd3804dde50636518f3ef2c9581848ae9f6
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 122 ++++++++++++------
 .../test/Transforms/InstCombine/or-bitmask.ll |  95 ++++++++++++--
 2 files changed, 163 insertions(+), 54 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 59b46ebdb72e2..ea166717d5c05 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3593,6 +3593,72 @@ static Value *foldOrOfInversions(BinaryOperator &I,
   return nullptr;
 }
 
+struct DecomposedBitMaskMul {
+  Value *X;
+  APInt Factor;
+  APInt Mask;
+};
+
+static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
+  Instruction *Op = dyn_cast<Instruction>(V);
+  if (!Op)
+    return std::nullopt;
+
+  Value *MulOp = nullptr;
+  const APInt *MulConst = nullptr;
+  if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) {
+    Value *Original = nullptr;
+    const APInt *Mask = nullptr;
+    if (!MulConst->isStrictlyPositive())
+      return std::nullopt;
+
+    if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) {
+      if (!Mask->isStrictlyPositive())
+        return std::nullopt;
+      DecomposedBitMaskMul Ret;
+      Ret.X = Original;
+      Ret.Mask = *Mask;
+      Ret.Factor = *MulConst;
+      return Ret;
+    }
+    return std::nullopt;
+  }
+
+  Value *Cond = nullptr;
+  const APInt *EqZero = nullptr, *NeZero = nullptr;
+
+  //  (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+  if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) {
+    auto ICmpDecompose =
+        decomposeBitTest(Cond, /*LookThruTrunc=*/true,
+                         /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
+    if (!ICmpDecompose.has_value())
+      return std::nullopt;
+
+    if (ICmpDecompose->Pred == ICmpInst::ICMP_NE)
+      std::swap(EqZero, NeZero);
+
+    if (!EqZero->isZero() || !NeZero->isStrictlyPositive())
+      return std::nullopt;
+
+    if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
+        !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
+        ICmpDecompose->Mask.isNegative())
+      return std::nullopt;
+
+    if (!NeZero->urem(ICmpDecompose->Mask).isZero())
+      return std::nullopt;
+
+    DecomposedBitMaskMul Ret;
+    Ret.X = ICmpDecompose->X;
+    Ret.Mask = ICmpDecompose->Mask;
+    Ret.Factor = NeZero->udiv(ICmpDecompose->Mask);
+    return Ret;
+  }
+
+  return std::nullopt;
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -3675,49 +3741,19 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
 
-    Value *Cond0 = nullptr, *Cond1 = nullptr;
-    const APInt *Op0Eq = nullptr, *Op0Ne = nullptr;
-    const APInt *Op1Eq = nullptr, *Op1Ne = nullptr;
-
-    //  (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
-    if (match(I.getOperand(0),
-              m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) &&
-        match(I.getOperand(1),
-              m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) {
-
-      auto LHSDecompose =
-          decomposeBitTest(Cond0, /*LookThruTrunc=*/true,
-                           /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
-      auto RHSDecompose =
-          decomposeBitTest(Cond1, /*LookThruTrunc=*/true,
-                           /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
-
-      if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X &&
-          RHSDecompose->Mask.isPowerOf2() && LHSDecompose->Mask.isPowerOf2() &&
-          LHSDecompose->Mask != RHSDecompose->Mask &&
-          LHSDecompose->Mask.getBitWidth() == Op0Ne->getBitWidth() &&
-          RHSDecompose->Mask.getBitWidth() == Op1Ne->getBitWidth()) {
-        assert(Op0Ne->getBitWidth() == Op1Ne->getBitWidth());
-        assert(ICmpInst::isEquality(LHSDecompose->Pred));
-        if (LHSDecompose->Pred == ICmpInst::ICMP_NE)
-          std::swap(Op0Eq, Op0Ne);
-        if (RHSDecompose->Pred == ICmpInst::ICMP_NE)
-          std::swap(Op1Eq, Op1Ne);
-
-        if (!Op0Ne->isZero() && !Op1Ne->isZero() && Op0Eq->isZero() &&
-            Op1Eq->isZero() && Op0Ne->urem(LHSDecompose->Mask).isZero() &&
-            Op1Ne->urem(RHSDecompose->Mask).isZero() &&
-            Op0Ne->udiv(LHSDecompose->Mask) ==
-                Op1Ne->udiv(RHSDecompose->Mask)) {
-          auto NewAnd = Builder.CreateAnd(
-              LHSDecompose->X,
-              ConstantInt::get(LHSDecompose->X->getType(),
-                               (LHSDecompose->Mask + RHSDecompose->Mask)));
-
-          return BinaryOperator::CreateMul(
-              NewAnd, ConstantInt::get(NewAnd->getType(),
-                                       Op0Ne->udiv(LHSDecompose->Mask)));
-        }
+    auto Decomp0 = matchBitmaskMul(I.getOperand(0));
+    auto Decomp1 = matchBitmaskMul(I.getOperand(1));
+
+    if (Decomp0 && Decomp1) {
+      if (Decomp0->X == Decomp1->X &&
+          (Decomp0->Mask & Decomp1->Mask).isZero() &&
+          Decomp0->Factor == Decomp1->Factor) {
+        auto NewAnd = Builder.CreateAnd(
+            Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
+                                         (Decomp0->Mask + Decomp1->Mask)));
+
+        return BinaryOperator::CreateMul(
+            NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
       }
     }
   }
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index 3b482dc1794db..87f0bbf4d37ab 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -36,13 +36,9 @@ define i32 @add_select_cmp_and2(i32 %in) {
 
 define i32 @add_select_cmp_and3(i32 %in) {
 ; CHECK-LABEL: @add_select_cmp_and3(
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 7
 ; CHECK-NEXT:    [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT:    [[BITOP2:%.*]] = and i32 [[IN]], 4
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
-; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
-; CHECK-NEXT:    ret i32 [[OUT]]
+; CHECK-NEXT:    ret i32 [[TEMP]]
 ;
   %bitop0 = and i32 %in, 1
   %cmp0 = icmp eq i32 %bitop0, 0
@@ -60,12 +56,9 @@ define i32 @add_select_cmp_and3(i32 %in) {
 
 define i32 @add_select_cmp_and4(i32 %in) {
 ; CHECK-LABEL: @add_select_cmp_and4(
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
-; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN]], 12
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN:%.*]], 15
 ; CHECK-NEXT:    [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
-; CHECK-NEXT:    [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]]
-; CHECK-NEXT:    ret i32 [[OUT1]]
+; CHECK-NEXT:    ret i32 [[TEMP3]]
 ;
   %bitop0 = and i32 %in, 1
   %cmp0 = icmp eq i32 %bitop0, 0
@@ -361,6 +354,86 @@ define i64 @mask_select_types_1(i64 %in) {
   ret i64 %out
 }
 
+define i32 @add_select_cmp_mixed1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask = and i32 %in, 1
+  %sel0 = mul i32 %mask, 72
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_mixed2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %mask = and i32 %in, 2
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = mul i32 %mask, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_mul(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mul(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask0 = and i32 %in, 1
+  %sel0 = mul i32 %mask0, 72
+  %mask1 = and i32 %in, 2
+  %sel1 = mul i32 %mask1, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_mixed2_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed2_mismatch(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 73
+; CHECK-NEXT:    [[SEL1:%.*]] = mul nuw nsw i32 [[MASK]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %mask = and i32 %in, 2
+  %sel0 = select i1 %cmp0, i32 0, i32 73
+  %sel1 = mul i32 %mask, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_mul_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mul_mismatch(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[IN:%.*]] to i1
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[TMP1]], i32 73, i32 0
+; CHECK-NEXT:    [[MASK1:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT:    [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask0 = and i32 %in, 1
+  %sel0 = mul i32 %mask0, 73
+  %mask1 = and i32 %in, 2
+  %sel1 = mul i32 %mask1, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; CONSTSPLAT: {{.*}}
 ; CONSTVEC: {{.*}}

>From 7b63d9b172597da44200f8718a2e3816e436e686 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 22 May 2025 11:06:24 -0700
Subject: [PATCH 2/4] Review comments + fix some conditions

Change-Id: I4b71adfd8bffdda4d2b0d1cba85a3fd73a105a28
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 52 ++++++++++++-------
 .../test/Transforms/InstCombine/or-bitmask.ll |  8 +--
 2 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ea166717d5c05..62ff45fb24379 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3593,10 +3593,16 @@ static Value *foldOrOfInversions(BinaryOperator &I,
   return nullptr;
 }
 
+// A decomposition of (!(A & N) ? 0 : N * C), where X = A, Factor = C, Mask = N.
+// NUW / NSW record the no-wrap flags of the matched mul (false for selects).
+// Note that we can decompose equivalent forms of this expression (e.g. ((A & N)
+// * C))
 struct DecomposedBitMaskMul {
   Value *X;
   APInt Factor;
   APInt Mask;
+  bool NUW;
+  bool NSW;
 };
 
 static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
@@ -3606,20 +3612,21 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
 
   Value *MulOp = nullptr;
   const APInt *MulConst = nullptr;
+
+  // Decompose ((A & N) * C) into BitMaskMul
   if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) {
     Value *Original = nullptr;
     const APInt *Mask = nullptr;
-    if (!MulConst->isStrictlyPositive())
+    if (MulConst->isZero())
       return std::nullopt;
 
     if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) {
-      if (!Mask->isStrictlyPositive())
+      if (Mask->isZero())
         return std::nullopt;
-      DecomposedBitMaskMul Ret;
-      Ret.X = Original;
-      Ret.Mask = *Mask;
-      Ret.Factor = *MulConst;
-      return Ret;
+      return std::optional<DecomposedBitMaskMul>(
+          {Original, *MulConst, *Mask,
+           cast<BinaryOperator>(Op)->hasNoUnsignedWrap(),
+           cast<BinaryOperator>(Op)->hasNoSignedWrap()});
     }
     return std::nullopt;
   }
@@ -3627,7 +3634,7 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
   Value *Cond = nullptr;
   const APInt *EqZero = nullptr, *NeZero = nullptr;
 
-  //  (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+  // Decompose (!(A & N) ? 0 : N * C) into BitMaskMul
   if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) {
     auto ICmpDecompose =
         decomposeBitTest(Cond, /*LookThruTrunc=*/true,
@@ -3638,22 +3645,20 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
     if (ICmpDecompose->Pred == ICmpInst::ICMP_NE)
       std::swap(EqZero, NeZero);
 
-    if (!EqZero->isZero() || !NeZero->isStrictlyPositive())
+    if (!EqZero->isZero() || NeZero->isZero())
       return std::nullopt;
 
     if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
         !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
-        ICmpDecompose->Mask.isNegative())
+        ICmpDecompose->Mask.isZero())
       return std::nullopt;
 
     if (!NeZero->urem(ICmpDecompose->Mask).isZero())
       return std::nullopt;
 
-    DecomposedBitMaskMul Ret;
-    Ret.X = ICmpDecompose->X;
-    Ret.Mask = ICmpDecompose->Mask;
-    Ret.Factor = NeZero->udiv(ICmpDecompose->Mask);
-    return Ret;
+    return std::optional<DecomposedBitMaskMul>(
+        {ICmpDecompose->X, NeZero->udiv(ICmpDecompose->Mask),
+         ICmpDecompose->Mask, /*NUW=*/false, /*NSW=*/false});
   }
 
   return std::nullopt;
@@ -3741,19 +3746,26 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
 
-    auto Decomp0 = matchBitmaskMul(I.getOperand(0));
+    // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+    // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
+    // expressions i.e. (A & N) * C
     auto Decomp1 = matchBitmaskMul(I.getOperand(1));
-
-    if (Decomp0 && Decomp1) {
-      if (Decomp0->X == Decomp1->X &&
+    if (Decomp1) {
+      auto Decomp0 = matchBitmaskMul(I.getOperand(0));
+      if (Decomp0 && Decomp0->X == Decomp1->X &&
           (Decomp0->Mask & Decomp1->Mask).isZero() &&
           Decomp0->Factor == Decomp1->Factor) {
+
         auto NewAnd = Builder.CreateAnd(
             Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
                                          (Decomp0->Mask + Decomp1->Mask)));
 
-        return BinaryOperator::CreateMul(
+        auto Combined = BinaryOperator::CreateMul(
             NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
+
+        Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
+        Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
+        return Combined;
       }
     }
   }
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index 87f0bbf4d37ab..dcfbe171dd08f 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -37,8 +37,8 @@ define i32 @add_select_cmp_and2(i32 %in) {
 define i32 @add_select_cmp_and3(i32 %in) {
 ; CHECK-LABEL: @add_select_cmp_and3(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 7
-; CHECK-NEXT:    [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT:    ret i32 [[TEMP]]
+; CHECK-NEXT:    [[TEMP1:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[TEMP1]]
 ;
   %bitop0 = and i32 %in, 1
   %cmp0 = icmp eq i32 %bitop0, 0
@@ -57,8 +57,8 @@ define i32 @add_select_cmp_and3(i32 %in) {
 define i32 @add_select_cmp_and4(i32 %in) {
 ; CHECK-LABEL: @add_select_cmp_and4(
 ; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN:%.*]], 15
-; CHECK-NEXT:    [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
-; CHECK-NEXT:    ret i32 [[TEMP3]]
+; CHECK-NEXT:    [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72
+; CHECK-NEXT:    ret i32 [[TEMP2]]
 ;
   %bitop0 = and i32 %in, 1
   %cmp0 = icmp eq i32 %bitop0, 0

>From 5fa229ba2432d00512a7d58c3ffa7ec610ee4aa6 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Tue, 27 May 2025 11:03:46 -0700
Subject: [PATCH 3/4] Fix crash due to mismatch APInt bitwidth

Change-Id: I12f77aedbf1a2edfe63e4d03cd1e5c1c601365a7
---
 llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 62ff45fb24379..e357e3d296cc1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3650,7 +3650,8 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
 
     if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
         !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
-        ICmpDecompose->Mask.isZero())
+        ICmpDecompose->Mask.isZero() ||
+        NeZero->getBitWidth() != ICmpDecompose->Mask.getBitWidth())
       return std::nullopt;
 
     if (!NeZero->urem(ICmpDecompose->Mask).isZero())

>From 8c403c912046505ffc10378560c2fc48f214af6a Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Mon, 2 Jun 2025 12:29:39 -0700
Subject: [PATCH 4/4] [InstCombine] Extend bitmask mul combine to handle
 independent operands

Change-Id: Ife1a010d2ae6df40549a6c73f7b893948befa3be
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 80 ++++++++++++++++---
 .../test/Transforms/InstCombine/or-bitmask.ll | 50 ++++++++++++
 2 files changed, 117 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index e357e3d296cc1..231b980506744 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3603,6 +3603,11 @@ struct DecomposedBitMaskMul {
   APInt Mask;
   bool NUW;
   bool NSW;
+
+  bool isCombineableWith(DecomposedBitMaskMul Other) {
+    return X == Other.X && (Mask & Other.Mask).isZero() &&
+           Factor == Other.Factor;
+  }
 };
 
 static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
@@ -3665,6 +3670,34 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
   return std::nullopt;
 }
 
+using CombinedBitmaskMul =
+    std::pair<std::optional<DecomposedBitMaskMul>, Value *>;
+
+static CombinedBitmaskMul matchCombinedBitmaskMul(Value *V) {
+  auto DecompBitMaskMul = matchBitmaskMul(V);
+  if (DecompBitMaskMul)
+    return {DecompBitMaskMul, nullptr};
+
+  // Otherwise, check the operands of V for bitmaskmul pattern
+  auto BOp = dyn_cast<BinaryOperator>(V);
+  if (!BOp)
+    return {std::nullopt, nullptr};
+
+  auto Disj = dyn_cast<PossiblyDisjointInst>(BOp);
+  if (!Disj || !Disj->isDisjoint())
+    return {std::nullopt, nullptr};
+
+  auto DecompBitMaskMul0 = matchBitmaskMul(BOp->getOperand(0));
+  if (DecompBitMaskMul0)
+    return {DecompBitMaskMul0, BOp->getOperand(1)};
+
+  auto DecompBitMaskMul1 = matchBitmaskMul(BOp->getOperand(1));
+  if (DecompBitMaskMul1)
+    return {DecompBitMaskMul1, BOp->getOperand(0)};
+
+  return {std::nullopt, nullptr};
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -3750,22 +3783,43 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
     // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
     // This also accepts the equivalent mul form of (A & N) ? 0 : N * C)
     // expressions i.e. (A & N) * C
-    auto Decomp1 = matchBitmaskMul(I.getOperand(1));
-    if (Decomp1) {
-      auto Decomp0 = matchBitmaskMul(I.getOperand(0));
-      if (Decomp0 && Decomp0->X == Decomp1->X &&
-          (Decomp0->Mask & Decomp1->Mask).isZero() &&
-          Decomp0->Factor == Decomp1->Factor) {
+    CombinedBitmaskMul Decomp1 = matchCombinedBitmaskMul(I.getOperand(1));
+    auto BMDecomp1 = Decomp1.first;
 
-        auto NewAnd = Builder.CreateAnd(
-            Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
-                                         (Decomp0->Mask + Decomp1->Mask)));
+    if (BMDecomp1) {
+      CombinedBitmaskMul Decomp0 = matchCombinedBitmaskMul(I.getOperand(0));
+      auto BMDecomp0 = Decomp0.first;
 
-        auto Combined = BinaryOperator::CreateMul(
-            NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
+      if (BMDecomp0 && BMDecomp0->isCombineableWith(*BMDecomp1)) {
+        auto NewAnd = Builder.CreateAnd(
+            BMDecomp0->X,
+            ConstantInt::get(BMDecomp0->X->getType(),
+                             (BMDecomp0->Mask + BMDecomp1->Mask)));
+
+        BinaryOperator *Combined = cast<BinaryOperator>(Builder.CreateMul(
+            NewAnd, ConstantInt::get(NewAnd->getType(), BMDecomp1->Factor)));
+
+        Combined->setHasNoUnsignedWrap(BMDecomp0->NUW && BMDecomp1->NUW);
+        Combined->setHasNoSignedWrap(BMDecomp0->NSW && BMDecomp1->NSW);
+
+        // If our tree has independent or-disjoint operands, bring them in.
+        auto OtherOp0 = Decomp0.second;
+        auto OtherOp1 = Decomp1.second;
+
+        if (OtherOp0 || OtherOp1) {
+          Value *OtherOp;
+          if (OtherOp0 && OtherOp1) {
+            OtherOp = Builder.CreateOr(OtherOp0, OtherOp1);
+            cast<PossiblyDisjointInst>(OtherOp)->setIsDisjoint(true);
+          } else {
+            OtherOp = OtherOp0 ? OtherOp0 : OtherOp1;
+          }
+          Combined = cast<BinaryOperator>(Builder.CreateOr(Combined, OtherOp));
+          cast<PossiblyDisjointInst>(Combined)->setIsDisjoint(true);
+        }
 
-        Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
-        Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
+        // Caller expects detached instruction
+        Combined->removeFromParent();
         return Combined;
       }
     }
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index dcfbe171dd08f..8a2e328fa95cb 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -434,6 +434,56 @@ define i32 @add_select_cmp_and_mul_mismatch(i32 %in) {
   ret i32 %out
 }
 
+define i32 @unrelated_ops(i32 %in, i32 %in2) {
+; CHECK-LABEL: @unrelated_ops(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp3 = or disjoint i32 %in2, %temp2
+  %out = or disjoint i32 %temp, %temp3
+  ret i32 %out
+}
+
+define i32 @unrelated_ops1(i32 %in, i32 %in2) {
+; CHECK-LABEL: @unrelated_ops1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[IN2:%.*]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp3 = or disjoint i32 %in2, %temp
+  %out = or disjoint i32 %temp3, %temp2
+  ret i32 %out
+}
+
+define i32 @unrelated_ops2(i32 %in, i32 %in2, i32 %in3) {
+; CHECK-LABEL: @unrelated_ops2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 15
+; CHECK-NEXT:    [[TMP2:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[TMP3:%.*]] = or disjoint i32 [[IN3:%.*]], [[IN2:%.*]]
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %1 = and i32 %in, 3
+  %temp = mul nuw nsw i32 %1, 72
+  %temp3 = or disjoint i32 %temp, %in3
+  %2 = and i32 %in, 12
+  %temp2 = mul nuw nsw i32 %2, 72
+  %temp4 = or disjoint i32 %in2, %temp2
+  %out = or disjoint i32 %temp3, %temp4
+  ret i32 %out
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; CONSTSPLAT: {{.*}}
 ; CONSTVEC: {{.*}}



More information about the llvm-commits mailing list