[llvm] [AArch64][Codegen] Improve small shufflevector/concat lowering for SME (PR #116662)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 19 08:34:02 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/116662

>From 92459deafa5ce16c16de1217d3cd381a35e44bfe Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 18 Nov 2024 17:24:09 +0000
Subject: [PATCH 1/3] [AArch64][Codegen] Improve small shufflevector/concat
 lowering for SME

*  Avoid using TBL for small vectors (that can be lowered with a couple
   of ZIP1s)
*  Fold redundant ZIP1s
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 52 +++++++++++++++++++
 .../sve-streaming-mode-fixed-length-concat.ll | 38 ++++----------
 ...streaming-mode-fixed-length-permute-rev.ll | 11 ++--
 3 files changed, 67 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d1c3d4eddc880..30d396f29329e4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24720,6 +24720,49 @@ static SDValue tryToWidenSetCCOperands(SDNode *Op, SelectionDAG &DAG) {
                      Op0ExtV, Op1ExtV, Op->getOperand(2));
 }
 
+static SDValue skipElementSizePreservingCast(SDValue Op, EVT VT) {
+  if (Op->getOpcode() == ISD::BITCAST)
+    Op = Op->getOperand(0);
+  EVT OpVT = Op.getValueType();
+  if (OpVT.isVector() && OpVT.getVectorElementType().getSizeInBits() ==
+                             VT.getVectorElementType().getSizeInBits())
+    return Op;
+  return SDValue();
+}
+
+static SDValue performZIP1Combine(SDNode *N, SelectionDAG &DAG) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  // zip1(insert_vector_elt(undef, extract_vector_elt(vec, 0), 0),
+  //      insert_vector_elt(undef, extract_vector_elt(vec, 1), 0))
+  // -> vec
+  SDValue Op0 = skipElementSizePreservingCast(N->getOperand(0), VT);
+  SDValue Op1 = skipElementSizePreservingCast(N->getOperand(1), VT);
+  if (Op0 && Op1 && Op0->getOpcode() == ISD::INSERT_VECTOR_ELT &&
+      Op1->getOpcode() == ISD::INSERT_VECTOR_ELT) {
+    SDValue Op00 = Op0->getOperand(0);
+    SDValue Op10 = Op1->getOperand(0);
+    if (Op00.isUndef() && Op10.isUndef() &&
+        Op0->getConstantOperandVal(2) == 0 &&
+        Op1->getConstantOperandVal(2) == 0) {
+      SDValue Op01 = Op0->getOperand(1);
+      SDValue Op11 = Op1->getOperand(1);
+      if (Op01->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+          Op11->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+          Op01->getConstantOperandVal(1) == 0 &&
+          Op11->getConstantOperandVal(1) == 1) {
+        SDValue Op010 = skipElementSizePreservingCast(Op01->getOperand(0), VT);
+        SDValue Op110 = skipElementSizePreservingCast(Op11->getOperand(0), VT);
+        if (Op010 && Op010 == Op110)
+          return DAG.getBitcast(VT, Op010);
+      }
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue
 performVecReduceBitwiseCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                SelectionDAG &DAG) {
@@ -26161,6 +26204,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
 
     break;
   }
+  case AArch64ISD::ZIP1:
+    return performZIP1Combine(N, DAG);
   case ISD::XOR:
     return performXorCombine(N, DAG, DCI, Subtarget);
   case ISD::MUL:
@@ -29030,7 +29075,14 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
   if (!IsSingleOp && !Subtarget.hasSVE2())
     return SDValue();
 
+  // Small vectors (with few extracts) can be lowered more efficiently as a
+  // sequence of ZIPs.
   EVT VTOp1 = Op.getOperand(0).getValueType();
+  unsigned NumElts = VT.getVectorNumElements();
+  if (VT.isPow2VectorType() && VT.getFixedSizeInBits() <= 128 &&
+      (NumElts <= 2 || (NumElts <= 4 && !Op2.isUndef())))
+    return SDValue();
+
   unsigned BitsPerElt = VTOp1.getVectorElementType().getSizeInBits();
   unsigned IndexLen = MinSVESize / BitsPerElt;
   unsigned ElementsPerVectorReg = VTOp1.getVectorNumElements();
diff --git a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-concat.ll b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-concat.ll
index 6e2ecfca9e963e..619840fc6afb28 100644
--- a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-concat.ll
+++ b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-concat.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mattr=+sve2 -force-streaming-compatible  < %s | FileCheck %s --check-prefixes=CHECK,SVE2
-; RUN: llc -mattr=+sme -force-streaming  < %s | FileCheck %s --check-prefixes=CHECK,SME
+; RUN: llc -mattr=+sve2 -force-streaming-compatible  < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mattr=+sme -force-streaming  < %s | FileCheck %s --check-prefixes=CHECK
 ; RUN: llc -force-streaming-compatible < %s | FileCheck %s --check-prefix=NONEON-NOSVE
 
 target triple = "aarch64-unknown-linux-gnu"
@@ -406,33 +406,13 @@ define void @concat_v8i64(ptr %a, ptr %b, ptr %c) {
 ;
 
 define <4 x half> @concat_v4f16(<2 x half> %op1, <2 x half> %op2)  {
-; SVE2-LABEL: concat_v4f16:
-; SVE2:       // %bb.0:
-; SVE2-NEXT:    cnth x8
-; SVE2-NEXT:    adrp x9, .LCPI15_0
-; SVE2-NEXT:    adrp x10, .LCPI15_1
-; SVE2-NEXT:    mov z2.h, w8
-; SVE2-NEXT:    ldr q3, [x9, :lo12:.LCPI15_0]
-; SVE2-NEXT:    ldr q4, [x10, :lo12:.LCPI15_1]
-; SVE2-NEXT:    ptrue p0.h, vl8
-; SVE2-NEXT:    // kill: def $d1 killed $d1 killed $z0_z1 def $z0_z1
-; SVE2-NEXT:    // kill: def $d0 killed $d0 killed $z0_z1 def $z0_z1
-; SVE2-NEXT:    mad z2.h, p0/m, z3.h, z4.h
-; SVE2-NEXT:    tbl z0.h, { z0.h, z1.h }, z2.h
-; SVE2-NEXT:    // kill: def $d0 killed $d0 killed $z0
-; SVE2-NEXT:    ret
-;
-; SME-LABEL: concat_v4f16:
-; SME:       // %bb.0:
-; SME-NEXT:    // kill: def $d1 killed $d1 def $z1
-; SME-NEXT:    // kill: def $d0 killed $d0 def $z0
-; SME-NEXT:    mov z2.h, z1.h[1]
-; SME-NEXT:    mov z3.h, z0.h[1]
-; SME-NEXT:    zip1 z1.h, z1.h, z2.h
-; SME-NEXT:    zip1 z0.h, z0.h, z3.h
-; SME-NEXT:    zip1 z0.s, z0.s, z1.s
-; SME-NEXT:    // kill: def $d0 killed $d0 killed $z0
-; SME-NEXT:    ret
+; CHECK-LABEL: concat_v4f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    zip1 z0.s, z0.s, z1.s
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
 ;
 ; NONEON-NOSVE-LABEL: concat_v4f16:
 ; NONEON-NOSVE:       // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-permute-rev.ll b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-permute-rev.ll
index a33e8537edf4ee..1b083d80ef3e68 100644
--- a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-permute-rev.ll
+++ b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-permute-rev.ll
@@ -643,11 +643,12 @@ define void @test_revhv32i16(ptr %a) {
 define void @test_rev_elts_fail(ptr %a) {
 ; CHECK-LABEL: test_rev_elts_fail:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    index z0.d, #1, #-1
-; CHECK-NEXT:    ldp q1, q2, [x0]
-; CHECK-NEXT:    tbl z1.d, { z1.d }, z0.d
-; CHECK-NEXT:    tbl z0.d, { z2.d }, z0.d
-; CHECK-NEXT:    stp q1, q0, [x0]
+; CHECK-NEXT:    ldp q0, q1, [x0]
+; CHECK-NEXT:    mov z2.d, z0.d[1]
+; CHECK-NEXT:    mov z3.d, z1.d[1]
+; CHECK-NEXT:    zip1 z0.d, z2.d, z0.d
+; CHECK-NEXT:    zip1 z1.d, z3.d, z1.d
+; CHECK-NEXT:    stp q0, q1, [x0]
 ; CHECK-NEXT:    ret
 ;
 ; NONEON-NOSVE-LABEL: test_rev_elts_fail:

>From 26fb1e16d29048a378b3498763106c941eca4312 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 19 Nov 2024 11:42:48 +0000
Subject: [PATCH 2/3] Generalize fold a little

---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  51 ++++-
 .../sve-fixed-length-vector-shuffle-tbl.ll    | 190 ++++++++----------
 2 files changed, 125 insertions(+), 116 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 30d396f29329e4..c6b0f5876f4607 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24733,10 +24733,7 @@ static SDValue skipElementSizePreservingCast(SDValue Op, EVT VT) {
 static SDValue performZIP1Combine(SDNode *N, SelectionDAG &DAG) {
   SDLoc DL(N);
   EVT VT = N->getValueType(0);
-
-  // zip1(insert_vector_elt(undef, extract_vector_elt(vec, 0), 0),
-  //      insert_vector_elt(undef, extract_vector_elt(vec, 1), 0))
-  // -> vec
+  EVT EltVT = VT.getVectorElementType();
   SDValue Op0 = skipElementSizePreservingCast(N->getOperand(0), VT);
   SDValue Op1 = skipElementSizePreservingCast(N->getOperand(1), VT);
   if (Op0 && Op1 && Op0->getOpcode() == ISD::INSERT_VECTOR_ELT &&
@@ -24749,17 +24746,51 @@ static SDValue performZIP1Combine(SDNode *N, SelectionDAG &DAG) {
       SDValue Op01 = Op0->getOperand(1);
       SDValue Op11 = Op1->getOperand(1);
       if (Op01->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-          Op11->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-          Op01->getConstantOperandVal(1) == 0 &&
-          Op11->getConstantOperandVal(1) == 1) {
+          Op11->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
         SDValue Op010 = skipElementSizePreservingCast(Op01->getOperand(0), VT);
         SDValue Op110 = skipElementSizePreservingCast(Op11->getOperand(0), VT);
-        if (Op010 && Op010 == Op110)
-          return DAG.getBitcast(VT, Op010);
+        unsigned StartExtractIdx = Op01->getConstantOperandVal(1);
+        if (Op010 && Op010 == Op110 &&
+            Op11->getConstantOperandVal(1) == StartExtractIdx + 1 &&
+            StartExtractIdx % 2 == 0) {
+          //       t0: nxv16i8 = ...
+          //     t1: i32 = extract_vector_elt t0, Constant:i64<n>
+          //     t2: i32 = extract_vector_elt t0, Constant:i64<n + 1>
+          //   t3: nxv16i8 = insert_vector_elt(undef, t1, 0)
+          //   t4: nxv16i8 = insert_vector_elt(undef, t2, 0)
+          // t5: nxv16i8 = zip1(t3, t4)
+          //
+          // ->
+          //         t0: nxv16i8 = ...
+          //       t1: nxv8i16 = bitcast t0
+          //     t2: i32 = extract_vector_elt t1, Constant:i64<n / 2>
+          //   t3: nxv8i16 = insert_vector_elt(undef, t2, 0)
+          // t4: nxv16i8 = bitcast t3
+          //
+          // Where n % 2 == 0
+          SDValue Result;
+          if (StartExtractIdx == 0)
+            Result = Op010;
+          else if (EltVT.getSizeInBits() < 64) {
+            unsigned LargeEltBits = EltVT.getSizeInBits() * 2;
+            EVT LargeEltVT = MVT::getVectorVT(
+                MVT::getIntegerVT(LargeEltBits),
+                VT.getVectorElementCount().divideCoefficientBy(2));
+            EVT ExtractVT = MVT::getIntegerVT(std::max(LargeEltBits, 32U));
+            SDValue Extract =
+                DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT,
+                            DAG.getBitcast(LargeEltVT, Op010),
+                            DAG.getVectorIdxConstant(StartExtractIdx / 2, DL));
+            Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, LargeEltVT,
+                                 DAG.getUNDEF(LargeEltVT), Extract,
+                                 DAG.getVectorIdxConstant(0, DL));
+          }
+          if (Result)
+            return DAG.getBitcast(VT, Result);
+        }
       }
     }
   }
-
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-vector-shuffle-tbl.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-vector-shuffle-tbl.ll
index 20659cde83ee00..45285f5f6b6938 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-vector-shuffle-tbl.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-vector-shuffle-tbl.ll
@@ -140,64 +140,52 @@ define <8 x i8> @shuffle_index_indices_from_both_ops(ptr %a, ptr %b) {
 ;
 ; SVE2_128_NOMAX-LABEL: shuffle_index_indices_from_both_ops:
 ; SVE2_128_NOMAX:       // %bb.0:
-; SVE2_128_NOMAX-NEXT:    ldr d0, [x1]
-; SVE2_128_NOMAX-NEXT:    ldr d1, [x0]
-; SVE2_128_NOMAX-NEXT:    mov z2.b, z0.b[3]
-; SVE2_128_NOMAX-NEXT:    mov z3.b, z0.b[2]
-; SVE2_128_NOMAX-NEXT:    mov z4.b, z0.b[1]
-; SVE2_128_NOMAX-NEXT:    mov z1.b, z1.b[1]
-; SVE2_128_NOMAX-NEXT:    mov z5.b, z0.b[7]
-; SVE2_128_NOMAX-NEXT:    mov z6.b, z0.b[6]
-; SVE2_128_NOMAX-NEXT:    mov z0.b, z0.b[4]
-; SVE2_128_NOMAX-NEXT:    zip1 z2.b, z3.b, z2.b
-; SVE2_128_NOMAX-NEXT:    zip1 z1.b, z1.b, z4.b
-; SVE2_128_NOMAX-NEXT:    zip1 z3.b, z6.b, z5.b
-; SVE2_128_NOMAX-NEXT:    zip1 z0.b, z0.b, z0.b
-; SVE2_128_NOMAX-NEXT:    zip1 z1.h, z1.h, z2.h
-; SVE2_128_NOMAX-NEXT:    zip1 z0.h, z0.h, z3.h
-; SVE2_128_NOMAX-NEXT:    zip1 z0.s, z1.s, z0.s
+; SVE2_128_NOMAX-NEXT:    ldr d0, [x0]
+; SVE2_128_NOMAX-NEXT:    ldr d1, [x1]
+; SVE2_128_NOMAX-NEXT:    mov z2.b, z1.b[4]
+; SVE2_128_NOMAX-NEXT:    mov z3.b, z1.b[1]
+; SVE2_128_NOMAX-NEXT:    mov z0.b, z0.b[1]
+; SVE2_128_NOMAX-NEXT:    mov z4.h, z1.h[3]
+; SVE2_128_NOMAX-NEXT:    mov z1.h, z1.h[1]
+; SVE2_128_NOMAX-NEXT:    zip1 z2.b, z2.b, z2.b
+; SVE2_128_NOMAX-NEXT:    zip1 z0.b, z0.b, z3.b
+; SVE2_128_NOMAX-NEXT:    zip1 z2.h, z2.h, z4.h
+; SVE2_128_NOMAX-NEXT:    zip1 z0.h, z0.h, z1.h
+; SVE2_128_NOMAX-NEXT:    zip1 z0.s, z0.s, z2.s
 ; SVE2_128_NOMAX-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; SVE2_128_NOMAX-NEXT:    ret
 ;
 ; SVE2_NOMIN_NOMAX-LABEL: shuffle_index_indices_from_both_ops:
 ; SVE2_NOMIN_NOMAX:       // %bb.0:
-; SVE2_NOMIN_NOMAX-NEXT:    ldr d0, [x1]
-; SVE2_NOMIN_NOMAX-NEXT:    ldr d1, [x0]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z2.b, z0.b[3]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z3.b, z0.b[2]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z4.b, z0.b[1]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z1.b, z1.b[1]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z5.b, z0.b[7]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z6.b, z0.b[6]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z0.b, z0.b[4]
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z2.b, z3.b, z2.b
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z1.b, z1.b, z4.b
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z3.b, z6.b, z5.b
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.b, z0.b, z0.b
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z1.h, z1.h, z2.h
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.h, z0.h, z3.h
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.s, z1.s, z0.s
+; SVE2_NOMIN_NOMAX-NEXT:    ldr d0, [x0]
+; SVE2_NOMIN_NOMAX-NEXT:    ldr d1, [x1]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z2.b, z1.b[4]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z3.b, z1.b[1]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z0.b, z0.b[1]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z4.h, z1.h[3]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z1.h, z1.h[1]
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z2.b, z2.b, z2.b
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.b, z0.b, z3.b
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z2.h, z2.h, z4.h
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.h, z0.h, z1.h
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.s, z0.s, z2.s
 ; SVE2_NOMIN_NOMAX-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; SVE2_NOMIN_NOMAX-NEXT:    ret
 ;
 ; SVE2_MIN_256_NOMAX-LABEL: shuffle_index_indices_from_both_ops:
 ; SVE2_MIN_256_NOMAX:       // %bb.0:
-; SVE2_MIN_256_NOMAX-NEXT:    ldr d0, [x1]
-; SVE2_MIN_256_NOMAX-NEXT:    ldr d1, [x0]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z2.b, z0.b[3]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z3.b, z0.b[2]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z4.b, z0.b[1]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z1.b, z1.b[1]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z5.b, z0.b[7]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z6.b, z0.b[6]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z0.b, z0.b[4]
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z2.b, z3.b, z2.b
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z1.b, z1.b, z4.b
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z3.b, z6.b, z5.b
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.b, z0.b, z0.b
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z1.h, z1.h, z2.h
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.h, z0.h, z3.h
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.s, z1.s, z0.s
+; SVE2_MIN_256_NOMAX-NEXT:    ldr d0, [x0]
+; SVE2_MIN_256_NOMAX-NEXT:    ldr d1, [x1]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z2.b, z1.b[4]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z3.b, z1.b[1]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z0.b, z0.b[1]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z4.h, z1.h[3]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z1.h, z1.h[1]
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z2.b, z2.b, z2.b
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.b, z0.b, z3.b
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z2.h, z2.h, z4.h
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.h, z0.h, z1.h
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.s, z0.s, z2.s
 ; SVE2_MIN_256_NOMAX-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; SVE2_MIN_256_NOMAX-NEXT:    ret
   %op1 = load <8 x i8>, ptr %a
@@ -230,58 +218,52 @@ define <8 x i8> @shuffle_index_poison_value(ptr %a, ptr %b) {
 ;
 ; SVE2_128_NOMAX-LABEL: shuffle_index_poison_value:
 ; SVE2_128_NOMAX:       // %bb.0:
-; SVE2_128_NOMAX-NEXT:    ldr d0, [x1]
-; SVE2_128_NOMAX-NEXT:    ldr d1, [x0]
-; SVE2_128_NOMAX-NEXT:    mov z2.b, z0.b[3]
-; SVE2_128_NOMAX-NEXT:    mov z3.b, z0.b[2]
-; SVE2_128_NOMAX-NEXT:    mov z4.b, z0.b[1]
-; SVE2_128_NOMAX-NEXT:    mov z1.b, z1.b[1]
-; SVE2_128_NOMAX-NEXT:    mov z5.b, z0.b[4]
-; SVE2_128_NOMAX-NEXT:    mov z0.b, z0.b[6]
-; SVE2_128_NOMAX-NEXT:    zip1 z2.b, z3.b, z2.b
-; SVE2_128_NOMAX-NEXT:    zip1 z1.b, z1.b, z4.b
-; SVE2_128_NOMAX-NEXT:    zip1 z3.b, z5.b, z5.b
-; SVE2_128_NOMAX-NEXT:    zip1 z1.h, z1.h, z2.h
-; SVE2_128_NOMAX-NEXT:    zip1 z0.h, z3.h, z0.h
-; SVE2_128_NOMAX-NEXT:    zip1 z0.s, z1.s, z0.s
+; SVE2_128_NOMAX-NEXT:    ldr d0, [x0]
+; SVE2_128_NOMAX-NEXT:    ldr d1, [x1]
+; SVE2_128_NOMAX-NEXT:    mov z2.b, z1.b[4]
+; SVE2_128_NOMAX-NEXT:    mov z3.b, z1.b[1]
+; SVE2_128_NOMAX-NEXT:    mov z0.b, z0.b[1]
+; SVE2_128_NOMAX-NEXT:    mov z4.b, z1.b[6]
+; SVE2_128_NOMAX-NEXT:    mov z1.h, z1.h[1]
+; SVE2_128_NOMAX-NEXT:    zip1 z2.b, z2.b, z2.b
+; SVE2_128_NOMAX-NEXT:    zip1 z0.b, z0.b, z3.b
+; SVE2_128_NOMAX-NEXT:    zip1 z2.h, z2.h, z4.h
+; SVE2_128_NOMAX-NEXT:    zip1 z0.h, z0.h, z1.h
+; SVE2_128_NOMAX-NEXT:    zip1 z0.s, z0.s, z2.s
 ; SVE2_128_NOMAX-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; SVE2_128_NOMAX-NEXT:    ret
 ;
 ; SVE2_NOMIN_NOMAX-LABEL: shuffle_index_poison_value:
 ; SVE2_NOMIN_NOMAX:       // %bb.0:
-; SVE2_NOMIN_NOMAX-NEXT:    ldr d0, [x1]
-; SVE2_NOMIN_NOMAX-NEXT:    ldr d1, [x0]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z2.b, z0.b[3]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z3.b, z0.b[2]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z4.b, z0.b[1]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z1.b, z1.b[1]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z5.b, z0.b[4]
-; SVE2_NOMIN_NOMAX-NEXT:    mov z0.b, z0.b[6]
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z2.b, z3.b, z2.b
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z1.b, z1.b, z4.b
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z3.b, z5.b, z5.b
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z1.h, z1.h, z2.h
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.h, z3.h, z0.h
-; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.s, z1.s, z0.s
+; SVE2_NOMIN_NOMAX-NEXT:    ldr d0, [x0]
+; SVE2_NOMIN_NOMAX-NEXT:    ldr d1, [x1]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z2.b, z1.b[4]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z3.b, z1.b[1]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z0.b, z0.b[1]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z4.b, z1.b[6]
+; SVE2_NOMIN_NOMAX-NEXT:    mov z1.h, z1.h[1]
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z2.b, z2.b, z2.b
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.b, z0.b, z3.b
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z2.h, z2.h, z4.h
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.h, z0.h, z1.h
+; SVE2_NOMIN_NOMAX-NEXT:    zip1 z0.s, z0.s, z2.s
 ; SVE2_NOMIN_NOMAX-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; SVE2_NOMIN_NOMAX-NEXT:    ret
 ;
 ; SVE2_MIN_256_NOMAX-LABEL: shuffle_index_poison_value:
 ; SVE2_MIN_256_NOMAX:       // %bb.0:
-; SVE2_MIN_256_NOMAX-NEXT:    ldr d0, [x1]
-; SVE2_MIN_256_NOMAX-NEXT:    ldr d1, [x0]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z2.b, z0.b[3]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z3.b, z0.b[2]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z4.b, z0.b[1]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z1.b, z1.b[1]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z5.b, z0.b[4]
-; SVE2_MIN_256_NOMAX-NEXT:    mov z0.b, z0.b[6]
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z2.b, z3.b, z2.b
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z1.b, z1.b, z4.b
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z3.b, z5.b, z5.b
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z1.h, z1.h, z2.h
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.h, z3.h, z0.h
-; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.s, z1.s, z0.s
+; SVE2_MIN_256_NOMAX-NEXT:    ldr d0, [x0]
+; SVE2_MIN_256_NOMAX-NEXT:    ldr d1, [x1]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z2.b, z1.b[4]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z3.b, z1.b[1]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z0.b, z0.b[1]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z4.b, z1.b[6]
+; SVE2_MIN_256_NOMAX-NEXT:    mov z1.h, z1.h[1]
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z2.b, z2.b, z2.b
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.b, z0.b, z3.b
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z2.h, z2.h, z4.h
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.h, z0.h, z1.h
+; SVE2_MIN_256_NOMAX-NEXT:    zip1 z0.s, z0.s, z2.s
 ; SVE2_MIN_256_NOMAX-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; SVE2_MIN_256_NOMAX-NEXT:    ret
   %op1 = load <8 x i8>, ptr %a
@@ -338,22 +320,18 @@ define <8 x i8> @shuffle_op1_poison(ptr %a, ptr %b) {
 define <8 x i8> @negative_test_shuffle_index_size_op_both_maxhw(ptr %a, ptr %b) "target-features"="+sve2" vscale_range(16,16) {
 ; CHECK-LABEL: negative_test_shuffle_index_size_op_both_maxhw:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ldr d0, [x1]
-; CHECK-NEXT:    ldr d1, [x0]
-; CHECK-NEXT:    mov z2.b, z0.b[3]
-; CHECK-NEXT:    mov z3.b, z0.b[2]
-; CHECK-NEXT:    mov z4.b, z0.b[1]
-; CHECK-NEXT:    mov z1.b, z1.b[1]
-; CHECK-NEXT:    mov z5.b, z0.b[7]
-; CHECK-NEXT:    mov z6.b, z0.b[6]
-; CHECK-NEXT:    mov z0.b, z0.b[4]
-; CHECK-NEXT:    zip1 z2.b, z3.b, z2.b
-; CHECK-NEXT:    zip1 z1.b, z1.b, z4.b
-; CHECK-NEXT:    zip1 z3.b, z6.b, z5.b
-; CHECK-NEXT:    zip1 z0.b, z0.b, z0.b
-; CHECK-NEXT:    zip1 z1.h, z1.h, z2.h
-; CHECK-NEXT:    zip1 z0.h, z0.h, z3.h
-; CHECK-NEXT:    zip1 z0.s, z1.s, z0.s
+; CHECK-NEXT:    ldr d0, [x0]
+; CHECK-NEXT:    ldr d1, [x1]
+; CHECK-NEXT:    mov z2.b, z1.b[4]
+; CHECK-NEXT:    mov z3.b, z1.b[1]
+; CHECK-NEXT:    mov z0.b, z0.b[1]
+; CHECK-NEXT:    mov z4.h, z1.h[3]
+; CHECK-NEXT:    mov z1.h, z1.h[1]
+; CHECK-NEXT:    zip1 z2.b, z2.b, z2.b
+; CHECK-NEXT:    zip1 z0.b, z0.b, z3.b
+; CHECK-NEXT:    zip1 z2.h, z2.h, z4.h
+; CHECK-NEXT:    zip1 z0.h, z0.h, z1.h
+; CHECK-NEXT:    zip1 z0.s, z0.s, z2.s
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %op1 = load <8 x i8>, ptr %a

>From 4a7c67fb30df893778d33b9b917a8b73b546f169 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 19 Nov 2024 16:32:59 +0000
Subject: [PATCH 3/3] Use early exits

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 115 +++++++++---------
 1 file changed, 59 insertions(+), 56 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c6b0f5876f4607..b06ebd9a24dbd7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24734,64 +24734,67 @@ static SDValue performZIP1Combine(SDNode *N, SelectionDAG &DAG) {
   SDLoc DL(N);
   EVT VT = N->getValueType(0);
   EVT EltVT = VT.getVectorElementType();
+
   SDValue Op0 = skipElementSizePreservingCast(N->getOperand(0), VT);
   SDValue Op1 = skipElementSizePreservingCast(N->getOperand(1), VT);
-  if (Op0 && Op1 && Op0->getOpcode() == ISD::INSERT_VECTOR_ELT &&
-      Op1->getOpcode() == ISD::INSERT_VECTOR_ELT) {
-    SDValue Op00 = Op0->getOperand(0);
-    SDValue Op10 = Op1->getOperand(0);
-    if (Op00.isUndef() && Op10.isUndef() &&
-        Op0->getConstantOperandVal(2) == 0 &&
-        Op1->getConstantOperandVal(2) == 0) {
-      SDValue Op01 = Op0->getOperand(1);
-      SDValue Op11 = Op1->getOperand(1);
-      if (Op01->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-          Op11->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
-        SDValue Op010 = skipElementSizePreservingCast(Op01->getOperand(0), VT);
-        SDValue Op110 = skipElementSizePreservingCast(Op11->getOperand(0), VT);
-        unsigned StartExtractIdx = Op01->getConstantOperandVal(1);
-        if (Op010 && Op010 == Op110 &&
-            Op11->getConstantOperandVal(1) == StartExtractIdx + 1 &&
-            StartExtractIdx % 2 == 0) {
-          //       t0: nxv16i8 = ...
-          //     t1: i32 = extract_vector_elt t0, Constant:i64<n>
-          //     t2: i32 = extract_vector_elt t0, Constant:i64<n + 1>
-          //   t3: nxv16i8 = insert_vector_elt(undef, t1, 0)
-          //   t4: nxv16i8 = insert_vector_elt(undef, t2, 0)
-          // t5: nxv16i8 = zip1(t3, t4)
-          //
-          // ->
-          //         t0: nxv16i8 = ...
-          //       t1: nxv8i16 = bitcast t0
-          //     t2: i32 = extract_vector_elt t1, Constant:i64<n / 2>
-          //   t3: nxv8i16 = insert_vector_elt(undef, t2, 0)
-          // t4: nxv16i8 = bitcast t3
-          //
-          // Where n % 2 == 0
-          SDValue Result;
-          if (StartExtractIdx == 0)
-            Result = Op010;
-          else if (EltVT.getSizeInBits() < 64) {
-            unsigned LargeEltBits = EltVT.getSizeInBits() * 2;
-            EVT LargeEltVT = MVT::getVectorVT(
-                MVT::getIntegerVT(LargeEltBits),
-                VT.getVectorElementCount().divideCoefficientBy(2));
-            EVT ExtractVT = MVT::getIntegerVT(std::max(LargeEltBits, 32U));
-            SDValue Extract =
-                DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT,
-                            DAG.getBitcast(LargeEltVT, Op010),
-                            DAG.getVectorIdxConstant(StartExtractIdx / 2, DL));
-            Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, LargeEltVT,
-                                 DAG.getUNDEF(LargeEltVT), Extract,
-                                 DAG.getVectorIdxConstant(0, DL));
-          }
-          if (Result)
-            return DAG.getBitcast(VT, Result);
-        }
-      }
-    }
-  }
-  return SDValue();
+  if (!Op0 || !Op1 || Op0->getOpcode() != ISD::INSERT_VECTOR_ELT ||
+      Op1->getOpcode() != ISD::INSERT_VECTOR_ELT)
+    return SDValue();
+
+  SDValue Op00 = Op0->getOperand(0);
+  SDValue Op10 = Op1->getOperand(0);
+  if (!Op00.isUndef() || !Op10.isUndef() ||
+      Op0->getConstantOperandVal(2) != 0 || Op1->getConstantOperandVal(2) != 0)
+    return SDValue();
+
+  SDValue Op01 = Op0->getOperand(1);
+  SDValue Op11 = Op1->getOperand(1);
+  if (Op01->getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+      Op11->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+    return SDValue();
+
+  SDValue Op010 = skipElementSizePreservingCast(Op01->getOperand(0), VT);
+  SDValue Op110 = skipElementSizePreservingCast(Op11->getOperand(0), VT);
+  unsigned StartExtractIdx = Op01->getConstantOperandVal(1);
+  if (!Op010 || Op010 != Op110 ||
+      Op11->getConstantOperandVal(1) != StartExtractIdx + 1 ||
+      StartExtractIdx % 2 != 0)
+    return SDValue();
+
+  //       t0: nxv16i8 = ...
+  //     t1: i32 = extract_vector_elt t0, Constant:i64<n>
+  //     t2: i32 = extract_vector_elt t0, Constant:i64<n + 1>
+  //   t3: nxv16i8 = insert_vector_elt(undef, t1, 0)
+  //   t4: nxv16i8 = insert_vector_elt(undef, t2, 0)
+  // t5: nxv16i8 = zip1(t3, t4)
+  //
+  // ->
+  //         t0: nxv16i8 = ...
+  //       t1: nxv8i16 = bitcast t0
+  //     t2: i32 = extract_vector_elt t1, Constant:i64<n / 2>
+  //   t3: nxv8i16 = insert_vector_elt(undef, t2, 0)
+  // t4: nxv16i8 = bitcast t3
+  //
+  // Where n % 2 == 0
+  SDValue Result;
+  if (StartExtractIdx == 0)
+    Result = Op010;
+  else if (EltVT.getSizeInBits() < 64) {
+    unsigned LargeEltBits = EltVT.getSizeInBits() * 2;
+    EVT LargeEltVT =
+        MVT::getVectorVT(MVT::getIntegerVT(LargeEltBits),
+                         VT.getVectorElementCount().divideCoefficientBy(2));
+    EVT ExtractVT = MVT::getIntegerVT(std::max(LargeEltBits, 32U));
+    SDValue Extract =
+        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractVT,
+                    DAG.getBitcast(LargeEltVT, Op010),
+                    DAG.getVectorIdxConstant(StartExtractIdx / 2, DL));
+    Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, LargeEltVT,
+                         DAG.getUNDEF(LargeEltVT), Extract,
+                         DAG.getVectorIdxConstant(0, DL));
+  }
+
+  return Result ? DAG.getBitcast(VT, Result) : SDValue();
 }
 
 static SDValue



More information about the llvm-commits mailing list