[llvm] [InstCombine] Combine or-disjoint (and->mul), (and->mul) to and->mul (PR #136013)

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 28 10:17:30 PDT 2025


https://github.com/jrbyrnes updated https://github.com/llvm/llvm-project/pull/136013

>From 514554965bd91246a963a8d7da311226c8854490 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Mon, 28 Apr 2025 10:10:59 -0700
Subject: [PATCH 1/2] [InstCombine] Combine and->cmp->sel->or-disjoint into
 and->mul

Change-Id: I08cfc8c494bce343bf09e3110186e1a8553e2473
---
 .../InstCombine/InstCombineAndOrXor.cpp       |  48 +++
 .../test/Transforms/InstCombine/or-bitmask.ll | 328 ++++++++++++++++++
 2 files changed, 376 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/or-bitmask.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index b74bce391aa56..18d3f551adef4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3641,6 +3641,54 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
             foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
+
+    Value *Cond0 = nullptr, *Cond1 = nullptr;
+    const APInt *Op0Eq = nullptr, *Op0Ne = nullptr;
+    const APInt *Op1Eq = nullptr, *Op1Ne = nullptr;
+
+    //  (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
+    if (match(I.getOperand(0),
+              m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) &&
+        match(I.getOperand(1),
+              m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) {
+      CmpPredicate Pred0, Pred1;
+
+      auto LHSDecompose =
+          decomposeBitTest(Cond0, /*LookThruTrunc=*/true,
+                           /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
+      auto RHSDecompose =
+          decomposeBitTest(Cond1, /*LookThruTrunc=*/true,
+                           /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
+
+      if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X &&
+          (ICmpInst::isEquality(LHSDecompose->Pred)) &&
+          !RHSDecompose->Mask.isNegative() &&
+          !LHSDecompose->Mask.isNegative() && RHSDecompose->Mask.isPowerOf2() &&
+          LHSDecompose->Mask.isPowerOf2() &&
+          LHSDecompose->Mask != RHSDecompose->Mask &&
+          LHSDecompose->C.isZero() && RHSDecompose->C.isZero()) {
+        if (LHSDecompose->Pred == ICmpInst::ICMP_NE)
+          std::swap(Op0Eq, Op0Ne);
+        if (RHSDecompose->Pred == ICmpInst::ICMP_NE)
+          std::swap(Op1Eq, Op1Ne);
+
+        if (Op0Ne->isStrictlyPositive() && Op1Ne->isStrictlyPositive() &&
+            Op0Eq->isZero() && Op1Eq->isZero() &&
+            Op0Ne->urem(LHSDecompose->Mask).isZero() &&
+            Op1Ne->urem(RHSDecompose->Mask).isZero() &&
+            Op0Ne->udiv(LHSDecompose->Mask) ==
+                Op1Ne->udiv(RHSDecompose->Mask)) {
+          auto NewAnd = Builder.CreateAnd(
+              LHSDecompose->X,
+              ConstantInt::get(LHSDecompose->X->getType(),
+                               (LHSDecompose->Mask + RHSDecompose->Mask)));
+
+          return BinaryOperator::CreateMul(
+              NewAnd, ConstantInt::get(NewAnd->getType(),
+                                       Op0Ne->udiv(LHSDecompose->Mask)));
+        }
+      }
+    }
   }
 
   Value *X, *Y;
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
new file mode 100644
index 0000000000000..18acc42f2dc40
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -0,0 +1,328 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefixes=CHECK,CONSTVEC
+; RUN: opt < %s -passes=instcombine -S -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=CHECK,CONSTSPLAT
+
+define i32 @add_select_cmp_and1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 5
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 4
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 288
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and3(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and3(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[BITOP2:%.*]] = and i32 [[IN]], 4
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %temp = or disjoint i32 %sel0, %sel1
+  %bitop2 = and i32 %in, 4
+  %cmp2 = icmp eq i32 %bitop2, 0
+  %sel2 = select i1 %cmp2, i32 0, i32 288
+  %out = or disjoint i32 %temp, %sel2
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and4(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and4(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN]], 12
+; CHECK-NEXT:    [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
+; CHECK-NEXT:    [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]]
+; CHECK-NEXT:    ret i32 [[OUT1]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %temp = or disjoint i32 %sel0, %sel1
+  %bitop2 = and i32 %in, 4
+  %cmp2 = icmp eq i32 %bitop2, 0
+  %bitop3 = and i32 %in, 8
+  %cmp3 = icmp eq i32 %bitop3, 0
+  %sel2 = select i1 %cmp2, i32 0, i32 288
+  %sel3 = select i1 %cmp3, i32 0, i32 576
+  %temp2 = or disjoint i32 %sel2, %sel3
+  %out = or disjoint i32 %temp, %temp2
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_pred1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_pred1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp ne i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 72, i32 0
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_pred2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_pred2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp ne i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 144, i32 0
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_pred3(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_pred3(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp ne i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp ne i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 72, i32 0
+  %sel1 = select i1 %cmp1, i32 144, i32 0
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_trunc(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_trunc(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %cmp0 = trunc i32 %in to i1
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 72, i32 0
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_trunc1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_trunc1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %cmp0 = trunc i32 %in to i1
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp ne i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 72, i32 0
+  %sel1 = select i1 %cmp1, i32 144, i32 0
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+
+define i32 @add_select_cmp_and_const_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_const_mismatch(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[BITOP1:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 288
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 288
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_value_mismatch(i32 %in, i32 %in1) {
+; CHECK-LABEL: @add_select_cmp_and_value_mismatch(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[BITOP1:%.*]] = and i32 [[IN1:%.*]], 2
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 144
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in1, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_negative(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_negative(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i32 [[IN]], 2
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 -144
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, -2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 -144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_bitsel_overlap(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_bitsel_overlap(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 144
+; CHECK-NEXT:    ret i32 [[SEL0]]
+;
+  %bitop0 = and i32 %in, 2
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 144
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+; We cannot combine into and-mul, as %bitop1 may not be exactly 6
+
+define i32 @add_select_cmp_and_multbit_mask(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_multbit_mask(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[BITOP1:%.*]] = and i32 [[IN]], 6
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 432
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %bitop1 = and i32 %in, 6
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = select i1 %cmp1, i32 0, i32 432
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+
+define <2 x i32> @add_select_cmp_vec(<2 x i32> %in) {
+; CHECK-LABEL: @add_select_cmp_vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[IN:%.*]], splat (i32 3)
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw <2 x i32> [[TMP1]], splat (i32 72)
+; CHECK-NEXT:    ret <2 x i32> [[OUT]]
+;
+  %bitop0 = and <2 x i32> %in, <i32 1, i32 1>
+  %cmp0 = icmp eq <2 x i32> %bitop0, <i32 0, i32 0>
+  %bitop1 = and <2 x i32> %in, <i32 2, i32 2>
+  %cmp1 = icmp eq <2 x i32> %bitop1, <i32 0, i32 0>
+  %sel0 = select <2 x i1> %cmp0, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 72, i32 72>
+  %sel1 = select <2 x i1> %cmp1, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 144, i32 144>
+  %out = or disjoint <2 x i32> %sel0, %sel1
+  ret <2 x i32> %out
+}
+
+define <2 x i32> @add_select_cmp_vec_poison(<2 x i32> %in) {
+; CHECK-LABEL: @add_select_cmp_vec_poison(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and <2 x i32> [[IN:%.*]], splat (i32 1)
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq <2 x i32> [[BITOP0]], zeroinitializer
+; CHECK-NEXT:    [[BITOP1:%.*]] = and <2 x i32> [[IN]], splat (i32 2)
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq <2 x i32> [[BITOP1]], zeroinitializer
+; CHECK-NEXT:    [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i32> zeroinitializer, <2 x i32> <i32 poison, i32 144>
+; CHECK-NEXT:    [[OUT:%.*]] = select <2 x i1> [[CMP0]], <2 x i32> [[SEL1]], <2 x i32> <i32 72, i32 poison>
+; CHECK-NEXT:    ret <2 x i32> [[OUT]]
+;
+  %bitop0 = and <2 x i32> %in, <i32 1, i32 1>
+  %cmp0 = icmp eq <2 x i32> %bitop0, <i32 0, i32 0>
+  %bitop1 = and <2 x i32> %in, <i32 2, i32 2>
+  %cmp1 = icmp eq <2 x i32> %bitop1, <i32 0, i32 0>
+  %sel0 = select <2 x i1> %cmp0, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 72, i32 poison>
+  %sel1 = select <2 x i1> %cmp1, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 poison, i32 144>
+  %out = or disjoint <2 x i32> %sel0, %sel1
+  ret <2 x i32> %out
+}
+
+define <2 x i32> @add_select_cmp_vec_nonunique(<2 x i32> %in) {
+; CHECK-LABEL: @add_select_cmp_vec_nonunique(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and <2 x i32> [[IN:%.*]], <i32 1, i32 2>
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq <2 x i32> [[BITOP0]], zeroinitializer
+; CHECK-NEXT:    [[BITOP1:%.*]] = and <2 x i32> [[IN]], <i32 4, i32 8>
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq <2 x i32> [[BITOP1]], zeroinitializer
+; CHECK-NEXT:    [[SEL0:%.*]] = select <2 x i1> [[CMP0]], <2 x i32> zeroinitializer, <2 x i32> <i32 72, i32 144>
+; CHECK-NEXT:    [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i32> zeroinitializer, <2 x i32> <i32 288, i32 576>
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint <2 x i32> [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret <2 x i32> [[OUT]]
+;
+  %bitop0 = and <2 x i32> %in, <i32 1, i32 2>
+  %cmp0 = icmp eq <2 x i32> %bitop0, <i32 0, i32 0>
+  %bitop1 = and <2 x i32> %in, <i32 4, i32 8>
+  %cmp1 = icmp eq <2 x i32> %bitop1, <i32 0, i32 0>
+  %sel0 = select <2 x i1> %cmp0, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 72, i32 144>
+  %sel1 = select <2 x i1> %cmp1, <2 x i32> <i32 0, i32 0>, <2 x i32> <i32 288, i32 576>
+  %out = or disjoint <2 x i32> %sel0, %sel1
+  ret <2 x i32> %out
+}
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CONSTSPLAT: {{.*}}
+; CONSTVEC: {{.*}}
\ No newline at end of file

>From e2f011e4de86fa8ef84f6f304f70cfa85de90f26 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Thu, 17 Apr 2025 10:11:18 -0700
Subject: [PATCH 2/2] [InstCombine] Extend bitmask->select combine to match
 and->mul

Change-Id: I1cc2acd3804dde50636518f3ef2c9581848ae9f6
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 124 +++++++++++-------
 .../test/Transforms/InstCombine/or-bitmask.ll |  96 ++++++++++++--
 2 files changed, 164 insertions(+), 56 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 18d3f551adef4..b3ec2bb2d5f3f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3560,6 +3560,72 @@ static Value *foldOrOfInversions(BinaryOperator &I,
   return nullptr;
 }
 
+struct DecomposedBitMaskMul { // A value of the form (X & Mask) * Factor.
+  Value *X;     // Operand whose bits are selected by Mask.
+  APInt Factor; // Constant multiplier of the masked value.
+  APInt Mask;   // Constant bitmask applied to X.
+};
+/// Decompose V into (X & Mask) * Factor if possible, else std::nullopt.
+static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
+  Instruction *Op = dyn_cast<Instruction>(V);
+  if (!Op)
+    return std::nullopt;
+  // Form 1: V is (Original & Mask) * MulConst with both constants positive.
+  Value *MulOp = nullptr;
+  const APInt *MulConst = nullptr;
+  if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) {
+    Value *Original = nullptr;
+    const APInt *Mask = nullptr;
+    if (!MulConst->isStrictlyPositive())
+      return std::nullopt;
+
+    if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) {
+      if (!Mask->isStrictlyPositive())
+        return std::nullopt;
+      DecomposedBitMaskMul Ret;
+      Ret.X = Original;
+      Ret.Mask = *Mask;
+      Ret.Factor = *MulConst;
+      return Ret;
+    }
+    return std::nullopt;
+  }
+  // Form 2: V selects between 0 and a positive constant on a bit test of X.
+  Value *Cond = nullptr;
+  const APInt *EqZero = nullptr, *NeZero = nullptr;
+  // The fold this enables when two matched operands are combined:
+  //  (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> (A & (N + M)) * C
+  if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) {
+    auto ICmpDecompose =
+        decomposeBitTest(Cond, /*LookThruTrunc=*/true,
+                         /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true);
+    if (!ICmpDecompose.has_value())
+      return std::nullopt;
+    // Normalize so EqZero is the arm selected when (X & Mask) == 0.
+    if (ICmpDecompose->Pred == ICmpInst::ICMP_NE)
+      std::swap(EqZero, NeZero);
+
+    if (!EqZero->isZero() || !NeZero->isStrictlyPositive())
+      return std::nullopt;
+    // Require an equality test against zero of a single, non-sign bit of X.
+    if (!ICmpInst::isEquality(ICmpDecompose->Pred) ||
+        !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() ||
+        ICmpDecompose->Mask.isNegative())
+      return std::nullopt;
+    // NeZero must be an exact multiple of Mask so that Factor is integral.
+    if (!NeZero->urem(ICmpDecompose->Mask).isZero())
+      return std::nullopt;
+
+    DecomposedBitMaskMul Ret;
+    Ret.X = ICmpDecompose->X;
+    Ret.Mask = ICmpDecompose->Mask;
+    Ret.Factor = NeZero->udiv(ICmpDecompose->Mask);
+    return Ret;
+  }
+
+  return std::nullopt;
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -3642,51 +3708,19 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                    /*NSW=*/true, /*NUW=*/true))
       return R;
 
-    Value *Cond0 = nullptr, *Cond1 = nullptr;
-    const APInt *Op0Eq = nullptr, *Op0Ne = nullptr;
-    const APInt *Op1Eq = nullptr, *Op1Ne = nullptr;
-
-    //  (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C
-    if (match(I.getOperand(0),
-              m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) &&
-        match(I.getOperand(1),
-              m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) {
-      CmpPredicate Pred0, Pred1;
-
-      auto LHSDecompose =
-          decomposeBitTest(Cond0, /*LookThruTrunc=*/true,
-                           /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
-      auto RHSDecompose =
-          decomposeBitTest(Cond1, /*LookThruTrunc=*/true,
-                           /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true);
-
-      if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X &&
-          (ICmpInst::isEquality(LHSDecompose->Pred)) &&
-          !RHSDecompose->Mask.isNegative() &&
-          !LHSDecompose->Mask.isNegative() && RHSDecompose->Mask.isPowerOf2() &&
-          LHSDecompose->Mask.isPowerOf2() &&
-          LHSDecompose->Mask != RHSDecompose->Mask &&
-          LHSDecompose->C.isZero() && RHSDecompose->C.isZero()) {
-        if (LHSDecompose->Pred == ICmpInst::ICMP_NE)
-          std::swap(Op0Eq, Op0Ne);
-        if (RHSDecompose->Pred == ICmpInst::ICMP_NE)
-          std::swap(Op1Eq, Op1Ne);
-
-        if (Op0Ne->isStrictlyPositive() && Op1Ne->isStrictlyPositive() &&
-            Op0Eq->isZero() && Op1Eq->isZero() &&
-            Op0Ne->urem(LHSDecompose->Mask).isZero() &&
-            Op1Ne->urem(RHSDecompose->Mask).isZero() &&
-            Op0Ne->udiv(LHSDecompose->Mask) ==
-                Op1Ne->udiv(RHSDecompose->Mask)) {
-          auto NewAnd = Builder.CreateAnd(
-              LHSDecompose->X,
-              ConstantInt::get(LHSDecompose->X->getType(),
-                               (LHSDecompose->Mask + RHSDecompose->Mask)));
-
-          return BinaryOperator::CreateMul(
-              NewAnd, ConstantInt::get(NewAnd->getType(),
-                                       Op0Ne->udiv(LHSDecompose->Mask)));
-        }
+    auto Decomp0 = matchBitmaskMul(I.getOperand(0));
+    auto Decomp1 = matchBitmaskMul(I.getOperand(1));
+
+    if (Decomp0 && Decomp1) {
+      if (Decomp0->X == Decomp1->X &&
+          (Decomp0->Mask & Decomp1->Mask).isZero() &&
+          Decomp0->Factor == Decomp1->Factor) {
+        auto NewAnd = Builder.CreateAnd(
+            Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
+                                         (Decomp0->Mask + Decomp1->Mask)));
+
+        return BinaryOperator::CreateMul(
+            NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
       }
     }
   }
diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll
index 18acc42f2dc40..4b989733afbb9 100644
--- a/llvm/test/Transforms/InstCombine/or-bitmask.ll
+++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll
@@ -36,13 +36,9 @@ define i32 @add_select_cmp_and2(i32 %in) {
 
 define i32 @add_select_cmp_and3(i32 %in) {
 ; CHECK-LABEL: @add_select_cmp_and3(
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 7
 ; CHECK-NEXT:    [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT:    [[BITOP2:%.*]] = and i32 [[IN]], 4
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0
-; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288
-; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]]
-; CHECK-NEXT:    ret i32 [[OUT]]
+; CHECK-NEXT:    ret i32 [[TEMP]]
 ;
   %bitop0 = and i32 %in, 1
   %cmp0 = icmp eq i32 %bitop0, 0
@@ -60,12 +56,9 @@ define i32 @add_select_cmp_and3(i32 %in) {
 
 define i32 @add_select_cmp_and4(i32 %in) {
 ; CHECK-LABEL: @add_select_cmp_and4(
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
-; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
-; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN]], 12
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[IN:%.*]], 15
 ; CHECK-NEXT:    [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72
-; CHECK-NEXT:    [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]]
-; CHECK-NEXT:    ret i32 [[OUT1]]
+; CHECK-NEXT:    ret i32 [[TEMP3]]
 ;
   %bitop0 = and i32 %in, 1
   %cmp0 = icmp eq i32 %bitop0, 0
@@ -323,6 +316,87 @@ define <2 x i32> @add_select_cmp_vec_nonunique(<2 x i32> %in) {
   %out = or disjoint <2 x i32> %sel0, %sel1
   ret <2 x i32> %out
 }
+
+define i32 @add_select_cmp_mixed1(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask = and i32 %in, 1
+  %sel0 = mul i32 %mask, 72
+  %bitop1 = and i32 %in, 2
+  %cmp1 = icmp eq i32 %bitop1, 0
+  %sel1 = select i1 %cmp1, i32 0, i32 144
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_mixed2(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed2(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %mask = and i32 %in, 2
+  %sel0 = select i1 %cmp0, i32 0, i32 72
+  %sel1 = mul i32 %mask, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_mul(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mul(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[IN:%.*]], 3
+; CHECK-NEXT:    [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask0 = and i32 %in, 1
+  %sel0 = mul i32 %mask0, 72
+  %mask1 = and i32 %in, 2
+  %sel1 = mul i32 %mask1, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_mixed2_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_mixed2_mismatch(
+; CHECK-NEXT:    [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1
+; CHECK-NEXT:    [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 73
+; CHECK-NEXT:    [[SEL1:%.*]] = mul nuw nsw i32 [[MASK]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %bitop0 = and i32 %in, 1
+  %cmp0 = icmp eq i32 %bitop0, 0
+  %mask = and i32 %in, 2
+  %sel0 = select i1 %cmp0, i32 0, i32 73
+  %sel1 = mul i32 %mask, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
+define i32 @add_select_cmp_and_mul_mismatch(i32 %in) {
+; CHECK-LABEL: @add_select_cmp_and_mul_mismatch(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[IN:%.*]] to i1
+; CHECK-NEXT:    [[SEL0:%.*]] = select i1 [[TMP1]], i32 73, i32 0
+; CHECK-NEXT:    [[MASK1:%.*]] = and i32 [[IN]], 2
+; CHECK-NEXT:    [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72
+; CHECK-NEXT:    [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]]
+; CHECK-NEXT:    ret i32 [[OUT]]
+;
+  %mask0 = and i32 %in, 1
+  %sel0 = mul i32 %mask0, 73
+  %mask1 = and i32 %in, 2
+  %sel1 = mul i32 %mask1, 72
+  %out = or disjoint i32 %sel0, %sel1
+  ret i32 %out
+}
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; CONSTSPLAT: {{.*}}
 ; CONSTVEC: {{.*}}
\ No newline at end of file



More information about the llvm-commits mailing list