[llvm] [DAGCombiner] Add support for scalarising extracts of a vector setcc (PR #116031)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 21 06:02:41 PST 2024


https://github.com/david-arm updated https://github.com/llvm/llvm-project/pull/116031

>From b2063b8ed643056a4410dde10946d7a16e3a5df8 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Wed, 13 Nov 2024 11:42:14 +0000
Subject: [PATCH 1/4] Add tests

---
 .../CodeGen/AArch64/extract-vector-cmp.ll     | 247 ++++++++++++++++++
 1 file changed, 247 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/extract-vector-cmp.ll

diff --git a/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
new file mode 100644
index 00000000000000..6143d99c8380be
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
@@ -0,0 +1,247 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+
+define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v1.4s, #5
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_lhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v1.4s, #7
+; CHECK-NEXT:    cmhi v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> splat(i32 7), %a
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_vec_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    adrp x8, .LCPI2_0
+; CHECK-NEXT:    ldr q1, [x8, :lo12:.LCPI2_0]
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, <i32 5, i32 234, i32 -1, i32 7>
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_fcmp_v4f32_const_splat_rhs(<4 x float> %a) {
+; CHECK-LABEL: extract_fcmp_v4f32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fmov v1.4s, #4.00000000
+; CHECK-NEXT:    fcmge v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    mvn v0.16b, v0.16b
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %fcmp = fcmp ult <4 x float> %a, splat(float 4.0e+0)
+  %ext = extractelement <4 x i1> %fcmp, i32 1
+  ret i1 %ext
+}
+
+define void @vector_loop_with_icmp(ptr nocapture noundef writeonly %dest) {
+; CHECK-LABEL: vector_loop_with_icmp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.d, #0, #1
+; CHECK-NEXT:    mov w8, #15 // =0xf
+; CHECK-NEXT:    mov w9, #4 // =0x4
+; CHECK-NEXT:    dup v2.2d, x8
+; CHECK-NEXT:    dup v3.2d, x9
+; CHECK-NEXT:    add x9, x0, #8
+; CHECK-NEXT:    mov w10, #16 // =0x10
+; CHECK-NEXT:    mov w11, #1 // =0x1
+; CHECK-NEXT:    mov z1.d, z0.d
+; CHECK-NEXT:    add z1.d, z1.d, #2 // =0x2
+; CHECK-NEXT:    b .LBB4_2
+; CHECK-NEXT:  .LBB4_1: // %pred.store.continue18
+; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
+; CHECK-NEXT:    add v1.2d, v1.2d, v3.2d
+; CHECK-NEXT:    add v0.2d, v0.2d, v3.2d
+; CHECK-NEXT:    subs x10, x10, #4
+; CHECK-NEXT:    add x9, x9, #16
+; CHECK-NEXT:    b.eq .LBB4_10
+; CHECK-NEXT:  .LBB4_2: // %vector.body
+; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    cmhi v4.2d, v2.2d, v0.2d
+; CHECK-NEXT:    xtn v4.2s, v4.2d
+; CHECK-NEXT:    uzp1 v4.4h, v4.4h, v0.4h
+; CHECK-NEXT:    umov w12, v4.h[0]
+; CHECK-NEXT:    tbz w12, #0, .LBB4_4
+; CHECK-NEXT:  // %bb.3: // %pred.store.if
+; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
+; CHECK-NEXT:    stur w11, [x9, #-8]
+; CHECK-NEXT:  .LBB4_4: // %pred.store.continue
+; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
+; CHECK-NEXT:    dup v4.2d, x8
+; CHECK-NEXT:    cmhi v4.2d, v4.2d, v0.2d
+; CHECK-NEXT:    xtn v4.2s, v4.2d
+; CHECK-NEXT:    uzp1 v4.4h, v4.4h, v0.4h
+; CHECK-NEXT:    umov w12, v4.h[1]
+; CHECK-NEXT:    tbz w12, #0, .LBB4_6
+; CHECK-NEXT:  // %bb.5: // %pred.store.if5
+; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
+; CHECK-NEXT:    stur w11, [x9, #-4]
+; CHECK-NEXT:  .LBB4_6: // %pred.store.continue6
+; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
+; CHECK-NEXT:    dup v4.2d, x8
+; CHECK-NEXT:    cmhi v4.2d, v4.2d, v1.2d
+; CHECK-NEXT:    xtn v4.2s, v4.2d
+; CHECK-NEXT:    uzp1 v4.4h, v0.4h, v4.4h
+; CHECK-NEXT:    umov w12, v4.h[2]
+; CHECK-NEXT:    tbz w12, #0, .LBB4_8
+; CHECK-NEXT:  // %bb.7: // %pred.store.if7
+; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
+; CHECK-NEXT:    str w11, [x9]
+; CHECK-NEXT:  .LBB4_8: // %pred.store.continue8
+; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
+; CHECK-NEXT:    dup v4.2d, x8
+; CHECK-NEXT:    cmhi v4.2d, v4.2d, v1.2d
+; CHECK-NEXT:    xtn v4.2s, v4.2d
+; CHECK-NEXT:    uzp1 v4.4h, v0.4h, v4.4h
+; CHECK-NEXT:    umov w12, v4.h[3]
+; CHECK-NEXT:    tbz w12, #0, .LBB4_1
+; CHECK-NEXT:  // %bb.9: // %pred.store.if9
+; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
+; CHECK-NEXT:    str w11, [x9, #4]
+; CHECK-NEXT:    b .LBB4_1
+; CHECK-NEXT:  .LBB4_10: // %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %pred.store.continue18 ]
+  %vec.ind = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %entry ], [ %vec.ind.next, %pred.store.continue18 ]
+  %0 = icmp ult <4 x i64> %vec.ind, <i64 15, i64 15, i64 15, i64 15>
+  %1 = extractelement <4 x i1> %0, i64 0
+  br i1 %1, label %pred.store.if, label %pred.store.continue
+
+pred.store.if:
+  %2 = getelementptr inbounds i32, ptr %dest, i64 %index
+  store i32 1, ptr %2, align 4
+  br label %pred.store.continue
+
+pred.store.continue:
+  %3 = extractelement <4 x i1> %0, i64 1
+  br i1 %3, label %pred.store.if5, label %pred.store.continue6
+
+pred.store.if5:
+  %4 = or disjoint i64 %index, 1
+  %5 = getelementptr inbounds i32, ptr %dest, i64 %4
+  store i32 1, ptr %5, align 4
+  br label %pred.store.continue6
+
+pred.store.continue6:
+  %6 = extractelement <4 x i1> %0, i64 2
+  br i1 %6, label %pred.store.if7, label %pred.store.continue8
+
+pred.store.if7:
+  %7 = or disjoint i64 %index, 2
+  %8 = getelementptr inbounds i32, ptr %dest, i64 %7
+  store i32 1, ptr %8, align 4
+  br label %pred.store.continue8
+
+pred.store.continue8:
+  %9 = extractelement <4 x i1> %0, i64 3
+  br i1 %9, label %pred.store.if9, label %pred.store.continue18
+
+pred.store.if9:
+  %10 = or disjoint i64 %index, 3
+  %11 = getelementptr inbounds i32, ptr %dest, i64 %10
+  store i32 1, ptr %11, align 4
+  br label %pred.store.continue18
+
+pred.store.continue18:
+  %index.next = add i64 %index, 4
+  %vec.ind.next = add <4 x i64> %vec.ind, <i64 4, i64 4, i64 4, i64 4>
+  %24 = icmp eq i64 %index.next, 16
+  br i1 %24, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:
+  ret void
+}
+
+
+; Negative tests
+
+define i1 @extract_icmp_v4i32_splat_rhs(<4 x i32> %a, i32 %b) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    dup v1.4s, w0
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %ins = insertelement <4 x i32> poison, i32 %b, i32 0
+  %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
+  %icmp = icmp ult <4 x i32> %a, %splat
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_mul_use(<4 x i32> %a, ptr %p) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_mul_use:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v1.4s, #235
+; CHECK-NEXT:    adrp x9, .LCPI6_0
+; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI6_0]
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v1.4h, v0.4s
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    umov w9, v1.h[1]
+; CHECK-NEXT:    fmov w10, s0
+; CHECK-NEXT:    and w0, w9, #0x1
+; CHECK-NEXT:    strb w10, [x8]
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 235)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  store <4 x i1> %icmp, ptr %p, align 4
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_unknown_idx(<4 x i32> %a, i32 %c) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_unknown_idx:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    movi v1.4s, #127
+; CHECK-NEXT:    add x8, sp, #8
+; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-NEXT:    bfi x8, x0, #1, #2
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    str d0, [sp, #8]
+; CHECK-NEXT:    ldrh w8, [x8]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 127)
+  %ext = extractelement <4 x i1> %icmp, i32 %c
+  ret i1 %ext
+}

>From d8f9f074f101a0d69d5829696243e12111731095 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Wed, 13 Nov 2024 11:42:28 +0000
Subject: [PATCH 2/4] [DAGCombiner] Add support for scalarising extracts of a
 vector setcc

For IR like this:

  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
  %res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp we can take a similar approach
to what we already do for binary ops such as add, sub, etc. and convert
this into

  %ext = extractelement <4 x i32> %a, i32 1
  %res = icmp ult i32 %ext, 5

For AArch64 targets at least, the scalar boolean result will almost
certainly need to be in a GPR anyway, since it will probably be
used by branches for control flow. I've tried to reuse existing code
in scalarizeExtractedBinop to also work for setcc.

NOTE: The optimisations don't apply for tests such as
extract_icmp_v4i32_splat_rhs in the file

CodeGen/AArch64/extract-vector-cmp.ll

because scalarizeExtractedBinOp only works if one of the input
operands is a constant.
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |  4 +
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 66 +++++++++----
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  2 +
 .../AArch64/dag-combine-concat-vectors.ll     | 49 ----------
 .../CodeGen/AArch64/extract-vector-cmp.ll     | 96 +++++++------------
 5 files changed, 93 insertions(+), 124 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 6a41094ff933b0..ddedb08b44c583 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3342,6 +3342,10 @@ class TargetLoweringBase {
     return false;
   }
 
+  /// Try to convert an extract element of a vector setcc operation into an
+  /// extract element followed by a scalar operation.
+  virtual bool shouldScalarizeSetCC(SDValue VecOp) const { return false; }
+
   /// Return true if extraction of a scalar element from the given vector type
   /// at the given index is cheap. For example, if scalar operations occur on
   /// the same register file as vector operations, then an extract element may
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 521829675ae7c3..2aa54566a0345b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22746,19 +22746,15 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
 
 /// Transform a vector binary operation into a scalar binary operation by moving
 /// the math/logic after an extract element of a vector.
-static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
-                                       const SDLoc &DL, bool LegalOperations) {
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+static bool scalarizeExtractedBinOpCommon(SDNode *ExtElt, SelectionDAG &DAG,
+                                          const SDLoc &DL, bool IsSetCC,
+                                          SDValue &ScalarOp1,
+                                          SDValue &ScalarOp2) {
   SDValue Vec = ExtElt->getOperand(0);
   SDValue Index = ExtElt->getOperand(1);
   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
-  if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
-      Vec->getNumValues() != 1)
-    return SDValue();
-
-  // Targets may want to avoid this to prevent an expensive register transfer.
-  if (!TLI.shouldScalarizeBinop(Vec))
-    return SDValue();
+  if (!IndexC || !Vec.hasOneUse() || Vec->getNumValues() != 1)
+    return false;
 
   // Extracting an element of a vector constant is constant-folded, so this
   // transform is just replacing a vector op with a scalar op while moving the
@@ -22772,13 +22768,46 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
       ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
     // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
     // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
-    EVT VT = ExtElt->getValueType(0);
-    SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
-    SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
-    return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
+    // extractelt (setcc X, C, op), IndexC -> setcc (extractelt X, IndexC)), C
+    // extractelt (setcc C, X, op), IndexC -> setcc (extractelt IndexC, X)), C
+    EVT VT = Op0->getValueType(0).getVectorElementType();
+    ScalarOp1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
+    ScalarOp2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
+    return true;
   }
 
-  return SDValue();
+  return false;
+}
+
+static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
+                                       const SDLoc &DL) {
+  SDValue Op1, Op2;
+  SDValue Vec = ExtElt->getOperand(0);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (!TLI.isBinOp(Vec.getOpcode()) || !TLI.shouldScalarizeBinop(Vec))
+    return SDValue();
+
+  if (!scalarizeExtractedBinOpCommon(ExtElt, DAG, DL, false, Op1, Op2))
+    return SDValue();
+
+  EVT VT = ExtElt->getValueType(0);
+  return DAG.getNode(Vec.getOpcode(), DL, VT, Op1, Op2);
+}
+
+static SDValue scalarizeExtractedSetCC(SDNode *ExtElt, SelectionDAG &DAG,
+                                       const SDLoc &DL) {
+  SDValue Op1, Op2;
+  SDValue Vec = ExtElt->getOperand(0);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (Vec.getOpcode() != ISD::SETCC || !TLI.shouldScalarizeSetCC(Vec))
+    return SDValue();
+
+  if (!scalarizeExtractedBinOpCommon(ExtElt, DAG, DL, true, Op1, Op2))
+    return SDValue();
+
+  EVT VT = ExtElt->getValueType(0);
+  return DAG.getSetCC(DL, VT, Op1, Op2,
+                      cast<CondCodeSDNode>(Vec->getOperand(2))->get());
 }
 
 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23011,9 +23040,14 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
     }
   }
 
-  if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
+  if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL))
     return BO;
 
+  // extract (setcc x, splat(y)), i -> setcc (extract x, i)), y
+  if (ScalarVT == VecVT.getVectorElementType())
+    if (SDValue SetCC = scalarizeExtractedSetCC(N, DAG, DL))
+      return SetCC;
+
   if (VecVT.isScalableVector())
     return SDValue();
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index cb0b9e965277aa..a1fff8a62a28f4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1348,6 +1348,8 @@ class AArch64TargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 
   bool softPromoteHalfType() const override { return true; }
+
+  bool shouldScalarizeSetCC(SDValue VecOp) const override { return true; }
 };
 
 namespace AArch64 {
diff --git a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
index 5a5dee0b53d439..1f1164698826f7 100644
--- a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
+++ b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
@@ -8,56 +8,7 @@ declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x
 define fastcc i8 @allocno_reload_assign() {
 ; CHECK-LABEL: allocno_reload_assign:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmov d0, xzr
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    mov z16.d, #0 // =0x0
-; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
-; CHECK-NEXT:    uzp1 p0.s, p0.s, p0.s
-; CHECK-NEXT:    uzp1 p0.h, p0.h, p0.h
-; CHECK-NEXT:    uzp1 p0.b, p0.b, p0.b
-; CHECK-NEXT:    mov z0.b, p0/z, #1 // =0x1
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    mov z0.b, #0 // =0x0
-; CHECK-NEXT:    uunpklo z1.h, z0.b
-; CHECK-NEXT:    uunpkhi z0.h, z0.b
-; CHECK-NEXT:    mvn w8, w8
-; CHECK-NEXT:    sbfx x8, x8, #0, #1
-; CHECK-NEXT:    whilelo p0.b, xzr, x8
-; CHECK-NEXT:    uunpklo z2.s, z1.h
-; CHECK-NEXT:    uunpkhi z3.s, z1.h
-; CHECK-NEXT:    uunpklo z5.s, z0.h
-; CHECK-NEXT:    uunpkhi z7.s, z0.h
-; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    punpkhi p0.h, p0.b
-; CHECK-NEXT:    punpklo p2.h, p1.b
-; CHECK-NEXT:    punpkhi p3.h, p1.b
-; CHECK-NEXT:    uunpklo z0.d, z2.s
-; CHECK-NEXT:    uunpkhi z1.d, z2.s
-; CHECK-NEXT:    punpklo p5.h, p0.b
-; CHECK-NEXT:    uunpklo z2.d, z3.s
-; CHECK-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEXT:    punpkhi p7.h, p0.b
-; CHECK-NEXT:    uunpklo z4.d, z5.s
-; CHECK-NEXT:    uunpkhi z5.d, z5.s
-; CHECK-NEXT:    uunpklo z6.d, z7.s
-; CHECK-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEXT:    punpklo p0.h, p2.b
-; CHECK-NEXT:    punpkhi p1.h, p2.b
-; CHECK-NEXT:    punpklo p2.h, p3.b
-; CHECK-NEXT:    punpkhi p3.h, p3.b
-; CHECK-NEXT:    punpklo p4.h, p5.b
-; CHECK-NEXT:    punpkhi p5.h, p5.b
-; CHECK-NEXT:    punpklo p6.h, p7.b
-; CHECK-NEXT:    punpkhi p7.h, p7.b
 ; CHECK-NEXT:  .LBB0_1: // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    st1b { z0.d }, p0, [z16.d]
-; CHECK-NEXT:    st1b { z1.d }, p1, [z16.d]
-; CHECK-NEXT:    st1b { z2.d }, p2, [z16.d]
-; CHECK-NEXT:    st1b { z3.d }, p3, [z16.d]
-; CHECK-NEXT:    st1b { z4.d }, p4, [z16.d]
-; CHECK-NEXT:    st1b { z5.d }, p5, [z16.d]
-; CHECK-NEXT:    st1b { z6.d }, p6, [z16.d]
-; CHECK-NEXT:    st1b { z7.d }, p7, [z16.d]
 ; CHECK-NEXT:    b .LBB0_1
   br label %1
 
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
index 6143d99c8380be..2388a0f206a51c 100644
--- a/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
+++ b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
@@ -7,11 +7,9 @@ target triple = "aarch64-unknown-linux-gnu"
 define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
 ; CHECK-LABEL: extract_icmp_v4i32_const_splat_rhs:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    movi v1.4s, #5
-; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
-; CHECK-NEXT:    xtn v0.4h, v0.4s
-; CHECK-NEXT:    umov w8, v0.h[1]
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #5
+; CHECK-NEXT:    cset w0, lo
 ; CHECK-NEXT:    ret
   %icmp = icmp ult <4 x i32> %a, splat (i32 5)
   %ext = extractelement <4 x i1> %icmp, i32 1
@@ -21,11 +19,9 @@ define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
 define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
 ; CHECK-LABEL: extract_icmp_v4i32_const_splat_lhs:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    movi v1.4s, #7
-; CHECK-NEXT:    cmhi v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    xtn v0.4h, v0.4s
-; CHECK-NEXT:    umov w8, v0.h[1]
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #7
+; CHECK-NEXT:    cset w0, hi
 ; CHECK-NEXT:    ret
   %icmp = icmp ult <4 x i32> splat(i32 7), %a
   %ext = extractelement <4 x i1> %icmp, i32 1
@@ -35,12 +31,9 @@ define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
 define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
 ; CHECK-LABEL: extract_icmp_v4i32_const_vec_rhs:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    adrp x8, .LCPI2_0
-; CHECK-NEXT:    ldr q1, [x8, :lo12:.LCPI2_0]
-; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
-; CHECK-NEXT:    xtn v0.4h, v0.4s
-; CHECK-NEXT:    umov w8, v0.h[1]
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #234
+; CHECK-NEXT:    cset w0, lo
 ; CHECK-NEXT:    ret
   %icmp = icmp ult <4 x i32> %a, <i32 5, i32 234, i32 -1, i32 7>
   %ext = extractelement <4 x i1> %icmp, i32 1
@@ -50,12 +43,10 @@ define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
 define i1 @extract_fcmp_v4f32_const_splat_rhs(<4 x float> %a) {
 ; CHECK-LABEL: extract_fcmp_v4f32_const_splat_rhs:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmov v1.4s, #4.00000000
-; CHECK-NEXT:    fcmge v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mvn v0.16b, v0.16b
-; CHECK-NEXT:    xtn v0.4h, v0.4s
-; CHECK-NEXT:    umov w8, v0.h[1]
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    mov s0, v0.s[1]
+; CHECK-NEXT:    fmov s1, #4.00000000
+; CHECK-NEXT:    fcmp s0, s1
+; CHECK-NEXT:    cset w0, lt
 ; CHECK-NEXT:    ret
   %fcmp = fcmp ult <4 x float> %a, splat(float 4.0e+0)
   %ext = extractelement <4 x i1> %fcmp, i32 1
@@ -66,66 +57,53 @@ define void @vector_loop_with_icmp(ptr nocapture noundef writeonly %dest) {
 ; CHECK-LABEL: vector_loop_with_icmp:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    index z0.d, #0, #1
-; CHECK-NEXT:    mov w8, #15 // =0xf
-; CHECK-NEXT:    mov w9, #4 // =0x4
+; CHECK-NEXT:    mov w8, #4 // =0x4
+; CHECK-NEXT:    mov w9, #16 // =0x10
 ; CHECK-NEXT:    dup v2.2d, x8
-; CHECK-NEXT:    dup v3.2d, x9
-; CHECK-NEXT:    add x9, x0, #8
-; CHECK-NEXT:    mov w10, #16 // =0x10
-; CHECK-NEXT:    mov w11, #1 // =0x1
+; CHECK-NEXT:    add x8, x0, #8
+; CHECK-NEXT:    mov w10, #1 // =0x1
 ; CHECK-NEXT:    mov z1.d, z0.d
 ; CHECK-NEXT:    add z1.d, z1.d, #2 // =0x2
 ; CHECK-NEXT:    b .LBB4_2
 ; CHECK-NEXT:  .LBB4_1: // %pred.store.continue18
 ; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
-; CHECK-NEXT:    add v1.2d, v1.2d, v3.2d
-; CHECK-NEXT:    add v0.2d, v0.2d, v3.2d
-; CHECK-NEXT:    subs x10, x10, #4
-; CHECK-NEXT:    add x9, x9, #16
+; CHECK-NEXT:    add v1.2d, v1.2d, v2.2d
+; CHECK-NEXT:    add v0.2d, v0.2d, v2.2d
+; CHECK-NEXT:    subs x9, x9, #4
+; CHECK-NEXT:    add x8, x8, #16
 ; CHECK-NEXT:    b.eq .LBB4_10
 ; CHECK-NEXT:  .LBB4_2: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    cmhi v4.2d, v2.2d, v0.2d
-; CHECK-NEXT:    xtn v4.2s, v4.2d
-; CHECK-NEXT:    uzp1 v4.4h, v4.4h, v0.4h
-; CHECK-NEXT:    umov w12, v4.h[0]
-; CHECK-NEXT:    tbz w12, #0, .LBB4_4
+; CHECK-NEXT:    fmov x11, d0
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB4_4
 ; CHECK-NEXT:  // %bb.3: // %pred.store.if
 ; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
-; CHECK-NEXT:    stur w11, [x9, #-8]
+; CHECK-NEXT:    stur w10, [x8, #-8]
 ; CHECK-NEXT:  .LBB4_4: // %pred.store.continue
 ; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
-; CHECK-NEXT:    dup v4.2d, x8
-; CHECK-NEXT:    cmhi v4.2d, v4.2d, v0.2d
-; CHECK-NEXT:    xtn v4.2s, v4.2d
-; CHECK-NEXT:    uzp1 v4.4h, v4.4h, v0.4h
-; CHECK-NEXT:    umov w12, v4.h[1]
-; CHECK-NEXT:    tbz w12, #0, .LBB4_6
+; CHECK-NEXT:    mov x11, v0.d[1]
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB4_6
 ; CHECK-NEXT:  // %bb.5: // %pred.store.if5
 ; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
-; CHECK-NEXT:    stur w11, [x9, #-4]
+; CHECK-NEXT:    stur w10, [x8, #-4]
 ; CHECK-NEXT:  .LBB4_6: // %pred.store.continue6
 ; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
-; CHECK-NEXT:    dup v4.2d, x8
-; CHECK-NEXT:    cmhi v4.2d, v4.2d, v1.2d
-; CHECK-NEXT:    xtn v4.2s, v4.2d
-; CHECK-NEXT:    uzp1 v4.4h, v0.4h, v4.4h
-; CHECK-NEXT:    umov w12, v4.h[2]
-; CHECK-NEXT:    tbz w12, #0, .LBB4_8
+; CHECK-NEXT:    fmov x11, d1
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB4_8
 ; CHECK-NEXT:  // %bb.7: // %pred.store.if7
 ; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
-; CHECK-NEXT:    str w11, [x9]
+; CHECK-NEXT:    str w10, [x8]
 ; CHECK-NEXT:  .LBB4_8: // %pred.store.continue8
 ; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
-; CHECK-NEXT:    dup v4.2d, x8
-; CHECK-NEXT:    cmhi v4.2d, v4.2d, v1.2d
-; CHECK-NEXT:    xtn v4.2s, v4.2d
-; CHECK-NEXT:    uzp1 v4.4h, v0.4h, v4.4h
-; CHECK-NEXT:    umov w12, v4.h[3]
-; CHECK-NEXT:    tbz w12, #0, .LBB4_1
+; CHECK-NEXT:    mov x11, v1.d[1]
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB4_1
 ; CHECK-NEXT:  // %bb.9: // %pred.store.if9
 ; CHECK-NEXT:    // in Loop: Header=BB4_2 Depth=1
-; CHECK-NEXT:    str w11, [x9, #4]
+; CHECK-NEXT:    str w10, [x8, #4]
 ; CHECK-NEXT:    b .LBB4_1
 ; CHECK-NEXT:  .LBB4_10: // %for.cond.cleanup
 ; CHECK-NEXT:    ret

>From 5ec0ccaf6bfb935ee34389a55c82ba610ce2f9d9 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Thu, 21 Nov 2024 12:10:35 +0000
Subject: [PATCH 3/4] Address review comments

---
 llvm/include/llvm/CodeGen/TargetLowering.h    |  4 -
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 84 +++++++------------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  4 +-
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  2 +-
 .../WebAssembly/WebAssemblyISelLowering.cpp   |  2 +-
 llvm/lib/Target/X86/X86ISelLowering.cpp       |  2 +-
 .../AArch64/dag-combine-concat-vectors.ll     | 53 +++++++++++-
 7 files changed, 88 insertions(+), 63 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index ddedb08b44c583..6a41094ff933b0 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3342,10 +3342,6 @@ class TargetLoweringBase {
     return false;
   }
 
-  /// Try to convert an extract element of a vector setcc operation into an
-  /// extract element followed by a scalar operation.
-  virtual bool shouldScalarizeSetCC(SDValue VecOp) const { return false; }
-
   /// Return true if extraction of a scalar element from the given vector type
   /// at the given index is cheap. For example, if scalar operations occur on
   /// the same register file as vector operations, then an extract element may
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2aa54566a0345b..6281d11eb42a1e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22746,15 +22746,25 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
 
 /// Transform a vector binary operation into a scalar binary operation by moving
 /// the math/logic after an extract element of a vector.
-static bool scalarizeExtractedBinOpCommon(SDNode *ExtElt, SelectionDAG &DAG,
-                                          const SDLoc &DL, bool IsSetCC,
-                                          SDValue &ScalarOp1,
-                                          SDValue &ScalarOp2) {
+static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
+                                       const SDLoc &DL) {
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDValue Vec = ExtElt->getOperand(0);
   SDValue Index = ExtElt->getOperand(1);
   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
-  if (!IndexC || !Vec.hasOneUse() || Vec->getNumValues() != 1)
-    return false;
+  if (!IndexC ||
+      (!TLI.isBinOp(Vec.getOpcode()) && Vec.getOpcode() != ISD::SETCC) ||
+      !Vec.hasOneUse() || Vec->getNumValues() != 1)
+    return SDValue();
+
+  EVT ResVT = ExtElt->getValueType(0);
+  if (Vec.getOpcode() == ISD::SETCC &&
+      ResVT != Vec.getValueType().getVectorElementType())
+    return SDValue();
+
+  // Targets may want to avoid this to prevent an expensive register transfer.
+  if (!TLI.shouldScalarizeBinop(Vec))
+    return SDValue();
 
   // Extracting an element of a vector constant is constant-folded, so this
   // transform is just replacing a vector op with a scalar op while moving the
@@ -22762,52 +22772,23 @@ static bool scalarizeExtractedBinOpCommon(SDNode *ExtElt, SelectionDAG &DAG,
   SDValue Op0 = Vec.getOperand(0);
   SDValue Op1 = Vec.getOperand(1);
   APInt SplatVal;
-  if (isAnyConstantBuildVector(Op0, true) ||
-      ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
-      isAnyConstantBuildVector(Op1, true) ||
-      ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
-    // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
-    // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
-    // extractelt (setcc X, C, op), IndexC -> setcc (extractelt X, IndexC)), C
-    // extractelt (setcc C, X, op), IndexC -> setcc (extractelt IndexC, X)), C
-    EVT VT = Op0->getValueType(0).getVectorElementType();
-    ScalarOp1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
-    ScalarOp2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
-    return true;
-  }
-
-  return false;
-}
-
-static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
-                                       const SDLoc &DL) {
-  SDValue Op1, Op2;
-  SDValue Vec = ExtElt->getOperand(0);
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  if (!TLI.isBinOp(Vec.getOpcode()) || !TLI.shouldScalarizeBinop(Vec))
-    return SDValue();
-
-  if (!scalarizeExtractedBinOpCommon(ExtElt, DAG, DL, false, Op1, Op2))
+  if (!isAnyConstantBuildVector(Op0, true) &&
+      !ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
+      !isAnyConstantBuildVector(Op1, true) &&
+      !ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
     return SDValue();
 
-  EVT VT = ExtElt->getValueType(0);
-  return DAG.getNode(Vec.getOpcode(), DL, VT, Op1, Op2);
-}
+  // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
+  // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
+  EVT OpVT = Op0->getValueType(0).getVectorElementType();
+  Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
+  Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
 
-static SDValue scalarizeExtractedSetCC(SDNode *ExtElt, SelectionDAG &DAG,
-                                       const SDLoc &DL) {
-  SDValue Op1, Op2;
-  SDValue Vec = ExtElt->getOperand(0);
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  if (Vec.getOpcode() != ISD::SETCC || !TLI.shouldScalarizeSetCC(Vec))
-    return SDValue();
-
-  if (!scalarizeExtractedBinOpCommon(ExtElt, DAG, DL, true, Op1, Op2))
-    return SDValue();
-
-  EVT VT = ExtElt->getValueType(0);
-  return DAG.getSetCC(DL, VT, Op1, Op2,
-                      cast<CondCodeSDNode>(Vec->getOperand(2))->get());
+  if (Vec.getOpcode() == ISD::SETCC)
+    return DAG.getSetCC(DL, ResVT, Op0, Op1,
+                        cast<CondCodeSDNode>(Vec->getOperand(2))->get());
+  else
+    return DAG.getNode(Vec.getOpcode(), DL, ResVT, Op0, Op1);
 }
 
 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23043,11 +23024,6 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
   if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL))
     return BO;
 
-  // extract (setcc x, splat(y)), i -> setcc (extract x, i)), y
-  if (ScalarVT == VecVT.getVectorElementType())
-    if (SDValue SetCC = scalarizeExtractedSetCC(N, DAG, DL))
-      return SetCC;
-
   if (VecVT.isScalableVector())
     return SDValue();
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index a1fff8a62a28f4..d51b36f7e49946 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1349,7 +1349,9 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool softPromoteHalfType() const override { return true; }
 
-  bool shouldScalarizeSetCC(SDValue VecOp) const override { return true; }
+  bool shouldScalarizeBinop(SDValue VecOp) const override {
+    return VecOp.getOpcode() == ISD::SETCC;
+  }
 };
 
 namespace AArch64 {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 976b2478b433e5..0453b7f2e691aa 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2093,7 +2093,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 2d00889407ff48..a52af6832d583f 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 934654a09c1724..e3ef57550ec181 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3300,7 +3300,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
index 1f1164698826f7..4cb1d5b2fb345d 100644
--- a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
+++ b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
@@ -5,10 +5,60 @@
 
 declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)
 
-define fastcc i8 @allocno_reload_assign() {
+define fastcc i8 @allocno_reload_assign(ptr %p) {
 ; CHECK-LABEL: allocno_reload_assign:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    fmov d0, xzr
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z16.d, #0 // =0x0
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
+; CHECK-NEXT:    uzp1 p0.s, p0.s, p0.s
+; CHECK-NEXT:    uzp1 p0.h, p0.h, p0.h
+; CHECK-NEXT:    uzp1 p8.b, p0.b, p0.b
+; CHECK-NEXT:    mov z0.b, p8/z, #1 // =0x1
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    mov z0.b, #0 // =0x0
+; CHECK-NEXT:    uunpklo z1.h, z0.b
+; CHECK-NEXT:    uunpkhi z0.h, z0.b
+; CHECK-NEXT:    mvn w8, w8
+; CHECK-NEXT:    sbfx x8, x8, #0, #1
+; CHECK-NEXT:    whilelo p0.b, xzr, x8
+; CHECK-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z3.s, z1.h
+; CHECK-NEXT:    uunpklo z5.s, z0.h
+; CHECK-NEXT:    uunpkhi z7.s, z0.h
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    punpklo p2.h, p1.b
+; CHECK-NEXT:    punpkhi p4.h, p1.b
+; CHECK-NEXT:    uunpklo z0.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z2.s
+; CHECK-NEXT:    punpklo p6.h, p0.b
+; CHECK-NEXT:    uunpklo z2.d, z3.s
+; CHECK-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uunpklo z4.d, z5.s
+; CHECK-NEXT:    uunpkhi z5.d, z5.s
+; CHECK-NEXT:    uunpklo z6.d, z7.s
+; CHECK-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEXT:    punpklo p1.h, p2.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    punpklo p3.h, p4.b
+; CHECK-NEXT:    punpkhi p4.h, p4.b
+; CHECK-NEXT:    punpklo p5.h, p6.b
+; CHECK-NEXT:    punpkhi p6.h, p6.b
+; CHECK-NEXT:    punpklo p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:  .LBB0_1: // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    st1b { z0.d }, p1, [z16.d]
+; CHECK-NEXT:    st1b { z1.d }, p2, [z16.d]
+; CHECK-NEXT:    st1b { z2.d }, p3, [z16.d]
+; CHECK-NEXT:    st1b { z3.d }, p4, [z16.d]
+; CHECK-NEXT:    st1b { z4.d }, p5, [z16.d]
+; CHECK-NEXT:    st1b { z5.d }, p6, [z16.d]
+; CHECK-NEXT:    st1b { z6.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z7.d }, p0, [z16.d]
+; CHECK-NEXT:    str p8, [x0]
 ; CHECK-NEXT:    b .LBB0_1
   br label %1
 
@@ -17,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
   %constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
   %constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
   call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
+  store <vscale x 16 x i1> %constexpr, ptr %p, align 16
   br label %1
 }
 

>From 65b02f336722df35f8ab8631e1730b3831fc746f Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Thu, 21 Nov 2024 14:01:51 +0000
Subject: [PATCH 4/4] Address review comment

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6281d11eb42a1e..57cce9f7852636 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22752,14 +22752,13 @@ static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
   SDValue Vec = ExtElt->getOperand(0);
   SDValue Index = ExtElt->getOperand(1);
   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
-  if (!IndexC ||
-      (!TLI.isBinOp(Vec.getOpcode()) && Vec.getOpcode() != ISD::SETCC) ||
-      !Vec.hasOneUse() || Vec->getNumValues() != 1)
+  unsigned Opc = Vec.getOpcode();
+  if (!IndexC || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) || !Vec.hasOneUse() ||
+      Vec->getNumValues() != 1)
     return SDValue();
 
   EVT ResVT = ExtElt->getValueType(0);
-  if (Vec.getOpcode() == ISD::SETCC &&
-      ResVT != Vec.getValueType().getVectorElementType())
+  if (Opc == ISD::SETCC && ResVT != Vec.getValueType().getVectorElementType())
     return SDValue();
 
   // Targets may want to avoid this to prevent an expensive register transfer.
@@ -22784,11 +22783,11 @@ static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
   Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
   Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
 
-  if (Vec.getOpcode() == ISD::SETCC)
+  if (Opc == ISD::SETCC)
     return DAG.getSetCC(DL, ResVT, Op0, Op1,
                         cast<CondCodeSDNode>(Vec->getOperand(2))->get());
   else
-    return DAG.getNode(Vec.getOpcode(), DL, ResVT, Op0, Op1);
+    return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
 }
 
 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,



More information about the llvm-commits mailing list