[llvm] [AArch64][Codegen] Improve small shufflevector/concat lowering for SME (PR #116662)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 09:28:42 PST 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/116662

*  Avoid using TBL for small vectors that can instead be lowered with a couple of ZIP1s
*  Fold redundant ZIP1s, i.e. a ZIP1 of lanes 0 and 1 extracted from the same vector folds back to that vector

From 92459deafa5ce16c16de1217d3cd381a35e44bfe Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 18 Nov 2024 17:24:09 +0000
Subject: [PATCH] [AArch64][Codegen] Improve small shufflevector/concat
 lowering for SME

*  Avoid using TBL for small vectors (that can be lowered with a couple
   of ZIP1s)
*  Fold redundant ZIP1s
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 52 +++++++++++++++++++
 .../sve-streaming-mode-fixed-length-concat.ll | 38 ++++----------
 ...streaming-mode-fixed-length-permute-rev.ll | 11 ++--
 3 files changed, 67 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d1c3d4eddc880..30d396f29329e4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24720,6 +24720,49 @@ static SDValue tryToWidenSetCCOperands(SDNode *Op, SelectionDAG &DAG) {
                      Op0ExtV, Op1ExtV, Op->getOperand(2));
 }
 
+static SDValue skipElementSizePreservingCast(SDValue Op, EVT VT) {
+  if (Op->getOpcode() == ISD::BITCAST)
+    Op = Op->getOperand(0);
+  EVT OpVT = Op.getValueType();
+  if (OpVT.isVector() && OpVT.getVectorElementType().getSizeInBits() ==
+                             VT.getVectorElementType().getSizeInBits())
+    return Op;
+  return SDValue();
+}
+
+static SDValue performZIP1Combine(SDNode *N, SelectionDAG &DAG) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  // zip1(insert_vector_elt(undef, extract_vector_elt(vec, 0), 0),
+  //      insert_vector_elt(undef, extract_vector_elt(vec, 1), 0))
+  // -> vec
+  SDValue Op0 = skipElementSizePreservingCast(N->getOperand(0), VT);
+  SDValue Op1 = skipElementSizePreservingCast(N->getOperand(1), VT);
+  if (Op0 && Op1 && Op0->getOpcode() == ISD::INSERT_VECTOR_ELT &&
+      Op1->getOpcode() == ISD::INSERT_VECTOR_ELT) {
+    SDValue Op00 = Op0->getOperand(0);
+    SDValue Op10 = Op1->getOperand(0);
+    if (Op00.isUndef() && Op10.isUndef() &&
+        Op0->getConstantOperandVal(2) == 0 &&
+        Op1->getConstantOperandVal(2) == 0) {
+      SDValue Op01 = Op0->getOperand(1);
+      SDValue Op11 = Op1->getOperand(1);
+      if (Op01->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+          Op11->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+          Op01->getConstantOperandVal(1) == 0 &&
+          Op11->getConstantOperandVal(1) == 1) {
+        SDValue Op010 = skipElementSizePreservingCast(Op01->getOperand(0), VT);
+        SDValue Op110 = skipElementSizePreservingCast(Op11->getOperand(0), VT);
+        if (Op010 && Op010 == Op110)
+          return DAG.getBitcast(VT, Op010);
+      }
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue
 performVecReduceBitwiseCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                SelectionDAG &DAG) {
@@ -26161,6 +26204,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
 
     break;
   }
+  case AArch64ISD::ZIP1:
+    return performZIP1Combine(N, DAG);
   case ISD::XOR:
     return performXorCombine(N, DAG, DCI, Subtarget);
   case ISD::MUL:
@@ -29030,7 +29075,14 @@ static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
   if (!IsSingleOp && !Subtarget.hasSVE2())
     return SDValue();
 
+  // Small vectors (with few extracts) can be lowered more efficiently as a
+  // sequence of ZIPs.
   EVT VTOp1 = Op.getOperand(0).getValueType();
+  unsigned NumElts = VT.getVectorNumElements();
+  if (VT.isPow2VectorType() && VT.getFixedSizeInBits() <= 128 &&
+      (NumElts <= 2 || (NumElts <= 4 && !Op2.isUndef())))
+    return SDValue();
+
   unsigned BitsPerElt = VTOp1.getVectorElementType().getSizeInBits();
   unsigned IndexLen = MinSVESize / BitsPerElt;
   unsigned ElementsPerVectorReg = VTOp1.getVectorNumElements();
diff --git a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-concat.ll b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-concat.ll
index 6e2ecfca9e963e..619840fc6afb28 100644
--- a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-concat.ll
+++ b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-concat.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mattr=+sve2 -force-streaming-compatible  < %s | FileCheck %s --check-prefixes=CHECK,SVE2
-; RUN: llc -mattr=+sme -force-streaming  < %s | FileCheck %s --check-prefixes=CHECK,SME
+; RUN: llc -mattr=+sve2 -force-streaming-compatible  < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mattr=+sme -force-streaming  < %s | FileCheck %s --check-prefixes=CHECK
 ; RUN: llc -force-streaming-compatible < %s | FileCheck %s --check-prefix=NONEON-NOSVE
 
 target triple = "aarch64-unknown-linux-gnu"
@@ -406,33 +406,13 @@ define void @concat_v8i64(ptr %a, ptr %b, ptr %c) {
 ;
 
 define <4 x half> @concat_v4f16(<2 x half> %op1, <2 x half> %op2)  {
-; SVE2-LABEL: concat_v4f16:
-; SVE2:       // %bb.0:
-; SVE2-NEXT:    cnth x8
-; SVE2-NEXT:    adrp x9, .LCPI15_0
-; SVE2-NEXT:    adrp x10, .LCPI15_1
-; SVE2-NEXT:    mov z2.h, w8
-; SVE2-NEXT:    ldr q3, [x9, :lo12:.LCPI15_0]
-; SVE2-NEXT:    ldr q4, [x10, :lo12:.LCPI15_1]
-; SVE2-NEXT:    ptrue p0.h, vl8
-; SVE2-NEXT:    // kill: def $d1 killed $d1 killed $z0_z1 def $z0_z1
-; SVE2-NEXT:    // kill: def $d0 killed $d0 killed $z0_z1 def $z0_z1
-; SVE2-NEXT:    mad z2.h, p0/m, z3.h, z4.h
-; SVE2-NEXT:    tbl z0.h, { z0.h, z1.h }, z2.h
-; SVE2-NEXT:    // kill: def $d0 killed $d0 killed $z0
-; SVE2-NEXT:    ret
-;
-; SME-LABEL: concat_v4f16:
-; SME:       // %bb.0:
-; SME-NEXT:    // kill: def $d1 killed $d1 def $z1
-; SME-NEXT:    // kill: def $d0 killed $d0 def $z0
-; SME-NEXT:    mov z2.h, z1.h[1]
-; SME-NEXT:    mov z3.h, z0.h[1]
-; SME-NEXT:    zip1 z1.h, z1.h, z2.h
-; SME-NEXT:    zip1 z0.h, z0.h, z3.h
-; SME-NEXT:    zip1 z0.s, z0.s, z1.s
-; SME-NEXT:    // kill: def $d0 killed $d0 killed $z0
-; SME-NEXT:    ret
+; CHECK-LABEL: concat_v4f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    zip1 z0.s, z0.s, z1.s
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
 ;
 ; NONEON-NOSVE-LABEL: concat_v4f16:
 ; NONEON-NOSVE:       // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-permute-rev.ll b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-permute-rev.ll
index a33e8537edf4ee..1b083d80ef3e68 100644
--- a/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-permute-rev.ll
+++ b/llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-permute-rev.ll
@@ -643,11 +643,12 @@ define void @test_revhv32i16(ptr %a) {
 define void @test_rev_elts_fail(ptr %a) {
 ; CHECK-LABEL: test_rev_elts_fail:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    index z0.d, #1, #-1
-; CHECK-NEXT:    ldp q1, q2, [x0]
-; CHECK-NEXT:    tbl z1.d, { z1.d }, z0.d
-; CHECK-NEXT:    tbl z0.d, { z2.d }, z0.d
-; CHECK-NEXT:    stp q1, q0, [x0]
+; CHECK-NEXT:    ldp q0, q1, [x0]
+; CHECK-NEXT:    mov z2.d, z0.d[1]
+; CHECK-NEXT:    mov z3.d, z1.d[1]
+; CHECK-NEXT:    zip1 z0.d, z2.d, z0.d
+; CHECK-NEXT:    zip1 z1.d, z3.d, z1.d
+; CHECK-NEXT:    stp q0, q1, [x0]
 ; CHECK-NEXT:    ret
 ;
 ; NONEON-NOSVE-LABEL: test_rev_elts_fail:



More information about the llvm-commits mailing list