[llvm] [AArch64] Convert UADDV(add(zext, zext)) into UADDLV(concat). (PR #78301)

Rin Dobrescu via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 17 04:17:44 PST 2024


https://github.com/Rin18 updated https://github.com/llvm/llvm-project/pull/78301

>From 714df50db8950ee8f9727893538671e5db68cd57 Mon Sep 17 00:00:00 2001
From: Rin Dobrescu <rin.dobrescu at arm.com>
Date: Tue, 16 Jan 2024 15:42:32 +0000
Subject: [PATCH 1/2] [AArch64] Convert UADDV(add(zext(64-bit source),
 zext(64-bit source))) into UADDLV(concat).

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 47 +++++++++--
 .../AArch64/aarch64-combine-add-zext.ll       | 20 +++++
 llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll  | 12 +--
 llvm/test/CodeGen/AArch64/neon-dotreduce.ll   | 53 +++++++-----
 llvm/test/CodeGen/AArch64/vecreduce-add.ll    | 82 ++++++++++++-------
 5 files changed, 153 insertions(+), 61 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/aarch64-combine-add-zext.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dffe69bdb900db..0f7047c84d48c9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16490,11 +16490,14 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
                      VecReudceAdd8);
 }
 
-// Given an (integer) vecreduce, we know the order of the inputs does not
-// matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
-// into UADDV(UADDLP(x)). This can also happen through an extra add, where we
-// transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
+// Turn UADDV(add(zext(extract_lo(x)), zext(extract_hi(x)))) into
+// UADDV(UADDLP(x)). If that fails, then convert UADDV(add(zext(64-bit source),
+// zext(64-bit source))) into UADDLV(concat).
 static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
+  // Given an (integer) vecreduce, we know the order of the inputs does not
+  // matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
+  // into UADDV(UADDLP(x)). This can also happen through an extra add, where we
+  // transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
   auto DetectAddExtract = [&](SDValue A) {
     // Look for add(zext(extract_lo(x)), zext(extract_hi(x))), returning
     // UADDLP(x) if found.
@@ -16528,6 +16531,34 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
     return DAG.getNode(Opcode, SDLoc(A), VT, Ext0.getOperand(0));
   };
 
+  // We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
+  // UADDLV(concat), where the concat represents the 64-bit zext sources.
+  auto DetectZextConcat = [&](SDValue A, SelectionDAG &DAG) {
+    // Look for add(zext(64-bit source), zext(64-bit source)), returning
+    // UADDLV(concat(zext, zext)) if found.
+    if (A.getOpcode() != ISD::ADD)
+      return SDValue();
+    EVT VT = A.getValueType();
+    if (VT != MVT::v4i32)
+      return SDValue();
+    SDValue Op0 = A.getOperand(0);
+    SDValue Op1 = A.getOperand(1);
+    if (Op0.getOpcode() != ISD::ZERO_EXTEND)
+      return SDValue();
+    SDValue Ext0 = Op0.getOperand(0);
+    SDValue Ext1 = Op1.getOperand(0);
+    EVT ExtVT0 = Ext0.getValueType();
+    EVT ExtVT1 = Ext1.getValueType();
+    // Check zext VTs are the same and 64-bit length.
+    if (ExtVT0 != ExtVT1 || !(ExtVT0 == MVT::v8i8 || ExtVT0 == MVT::v4i16))
+      return SDValue();
+    // Get VT for concat of zext sources.
+    EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
+    SDValue Concat =
+        DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
+    return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
+  };
+
   if (SDValue R = DetectAddExtract(A))
     return R;
 
@@ -16539,6 +16570,10 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
     if (SDValue R = performUADDVAddCombine(A.getOperand(1), DAG))
       return DAG.getNode(ISD::ADD, SDLoc(A), A.getValueType(), R,
                          A.getOperand(0));
+
+  if (SDValue R = DetectZextConcat(A, DAG))
+    return R;
+
   return SDValue();
 }
 
@@ -16546,7 +16581,9 @@ static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
   SDValue A = N->getOperand(0);
   if (A.getOpcode() == ISD::ADD)
     if (SDValue R = performUADDVAddCombine(A, DAG))
-      return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
+      return R.getOpcode() == AArch64ISD::UADDLV
+                 ? R
+                 : DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/AArch64/aarch64-combine-add-zext.ll b/llvm/test/CodeGen/AArch64/aarch64-combine-add-zext.ll
new file mode 100644
index 00000000000000..1cb0206c3fdbc1
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-combine-add-zext.ll
@@ -0,0 +1,20 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu | FileCheck %s
+
+define i32 @test_add_zext(<4 x i16> %a, <4 x i16> %b) local_unnamed_addr #0 {
+; CHECK-LABEL: test_add_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-NEXT:    uaddlv s0, v0.8h
+; CHECK-NEXT:    fmov w0, s0
+; CHECK-NEXT:    ret
+  %z1 = zext <4 x i16> %a to <4 x i32>
+  %z2 = zext <4 x i16> %b to <4 x i32>
+  %z = add <4 x i32> %z1, %z2
+  %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %z)
+  ret i32 %r
+}
+
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
diff --git a/llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll b/llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll
index 1fc177f034975d..24cce9a2b26b58 100644
--- a/llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll
+++ b/llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll
@@ -18,14 +18,14 @@ define i32 @lower_lshr(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c, <4 x i32> %d, <
 ; CHECK-NEXT:    mov v4.s[2], v6.s[0]
 ; CHECK-NEXT:    mov v0.s[3], v1.s[0]
 ; CHECK-NEXT:    mov v4.s[3], v3.s[0]
-; CHECK-NEXT:    xtn v2.4h, v0.4s
+; CHECK-NEXT:    xtn v1.4h, v0.4s
 ; CHECK-NEXT:    shrn v0.4h, v0.4s, #16
-; CHECK-NEXT:    xtn v1.4h, v4.4s
+; CHECK-NEXT:    xtn v2.4h, v4.4s
 ; CHECK-NEXT:    shrn v3.4h, v4.4s, #16
-; CHECK-NEXT:    uhadd v0.4h, v2.4h, v0.4h
-; CHECK-NEXT:    uhadd v1.4h, v1.4h, v3.4h
-; CHECK-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    uhadd v0.4h, v1.4h, v0.4h
+; CHECK-NEXT:    uhadd v1.4h, v2.4h, v3.4h
+; CHECK-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-NEXT:    uaddlv s0, v0.8h
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
   %l87  = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
index 706aa4ad1b4665..e4767594851eae 100644
--- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
@@ -1039,17 +1039,21 @@ define i32 @test_udot_v25i8_nomla(ptr nocapture readonly %a1) {
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldp q2, q1, [x0]
 ; CHECK-NEXT:    movi v0.2d, #0000000000000000
-; CHECK-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NEXT:    ushll v3.8h, v2.8b, #0
+; CHECK-NEXT:    ushll v4.8h, v1.8b, #0
 ; CHECK-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NEXT:    uaddl2 v5.4s, v4.8h, v3.8h
+; CHECK-NEXT:    ext v5.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT:    ext v6.16b, v4.16b, v4.16b, #8
 ; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    uaddl v3.4s, v4.4h, v3.4h
+; CHECK-NEXT:    mov v3.d[1], v4.d[0]
 ; CHECK-NEXT:    mov v0.s[0], v1.s[0]
-; CHECK-NEXT:    uaddw2 v1.4s, v5.4s, v2.8h
-; CHECK-NEXT:    uaddw v0.4s, v0.4s, v2.4h
-; CHECK-NEXT:    add v1.4s, v3.4s, v1.4s
+; CHECK-NEXT:    ushll2 v1.8h, v2.16b, #0
+; CHECK-NEXT:    mov v5.d[1], v6.d[0]
+; CHECK-NEXT:    uaddlv s2, v3.8h
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT:    uaddlv s3, v5.8h
+; CHECK-NEXT:    add v0.4s, v2.4s, v0.4s
+; CHECK-NEXT:    uaddw2 v1.4s, v3.4s, v1.8h
 ; CHECK-NEXT:    add v0.4s, v1.4s, v0.4s
 ; CHECK-NEXT:    addv s0, v0.4s
 ; CHECK-NEXT:    fmov w0, s0
@@ -1631,23 +1635,30 @@ define i32 @test_udot_v33i8_nomla(ptr nocapture readonly %a1) {
 ; CHECK-LABEL: test_udot_v33i8_nomla:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr b1, [x0, #32]
-; CHECK-NEXT:    ldp q3, q2, [x0]
+; CHECK-NEXT:    ldp q2, q3, [x0]
 ; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    ushll v4.8h, v2.8b, #0
-; CHECK-NEXT:    ushll v5.8h, v3.8b, #0
-; CHECK-NEXT:    ushll2 v2.8h, v2.16b, #0
-; CHECK-NEXT:    ushll2 v3.8h, v3.16b, #0
+; CHECK-NEXT:    ushll2 v4.8h, v2.16b, #0
+; CHECK-NEXT:    ushll2 v5.8h, v3.16b, #0
+; CHECK-NEXT:    ushll v3.8h, v3.8b, #0
+; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
 ; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    uaddl2 v6.4s, v3.8h, v2.8h
-; CHECK-NEXT:    uaddl v2.4s, v3.4h, v2.4h
+; CHECK-NEXT:    ext v6.16b, v4.16b, v4.16b, #8
+; CHECK-NEXT:    ext v7.16b, v5.16b, v5.16b, #8
+; CHECK-NEXT:    mov v4.d[1], v5.d[0]
+; CHECK-NEXT:    ext v16.16b, v2.16b, v2.16b, #8
 ; CHECK-NEXT:    mov v0.s[0], v1.s[0]
-; CHECK-NEXT:    uaddl2 v1.4s, v5.8h, v4.8h
-; CHECK-NEXT:    add v1.4s, v1.4s, v6.4s
-; CHECK-NEXT:    uaddw v0.4s, v0.4s, v5.4h
-; CHECK-NEXT:    uaddw v0.4s, v0.4s, v4.4h
-; CHECK-NEXT:    add v1.4s, v2.4s, v1.4s
-; CHECK-NEXT:    add v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    ext v1.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT:    mov v6.d[1], v7.d[0]
+; CHECK-NEXT:    uaddlv s4, v4.8h
+; CHECK-NEXT:    mov v16.d[1], v1.d[0]
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NEXT:    uaddlv s1, v6.8h
+; CHECK-NEXT:    uaddlv s2, v16.8h
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-NEXT:    add v0.4s, v4.4s, v0.4s
+; CHECK-NEXT:    add v0.4s, v1.4s, v0.4s
 ; CHECK-NEXT:    addv s0, v0.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index b24967d5a130ed..ad9cf2e1d99365 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -513,11 +513,15 @@ entry:
 define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
 ; CHECK-SD-BASE-LABEL: add_v16i8_v16i32_zext:
 ; CHECK-SD-BASE:       // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT:    ushll2 v1.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT:    uaddl2 v2.4s, v0.8h, v1.8h
-; CHECK-SD-BASE-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-BASE-NEXT:    add v0.4s, v0.4s, v2.4s
+; CHECK-SD-BASE-NEXT:    ushll v1.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT:    ext v2.16b, v1.16b, v1.16b, #8
+; CHECK-SD-BASE-NEXT:    ext v3.16b, v0.16b, v0.16b, #8
+; CHECK-SD-BASE-NEXT:    mov v1.d[1], v0.d[0]
+; CHECK-SD-BASE-NEXT:    mov v2.d[1], v3.d[0]
+; CHECK-SD-BASE-NEXT:    uaddlv s0, v1.8h
+; CHECK-SD-BASE-NEXT:    uaddlv s1, v2.8h
+; CHECK-SD-BASE-NEXT:    add v0.4s, v1.4s, v0.4s
 ; CHECK-SD-BASE-NEXT:    addv s0, v0.4s
 ; CHECK-SD-BASE-NEXT:    fmov w0, s0
 ; CHECK-SD-BASE-NEXT:    ret
@@ -1910,11 +1914,15 @@ entry:
 define i32 @add_v16i8_v16i32_acc_zext(<16 x i8> %x, i32 %a) {
 ; CHECK-SD-BASE-LABEL: add_v16i8_v16i32_acc_zext:
 ; CHECK-SD-BASE:       // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT:    ushll2 v1.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT:    uaddl2 v2.4s, v0.8h, v1.8h
-; CHECK-SD-BASE-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-BASE-NEXT:    add v0.4s, v0.4s, v2.4s
+; CHECK-SD-BASE-NEXT:    ushll v1.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT:    ext v2.16b, v1.16b, v1.16b, #8
+; CHECK-SD-BASE-NEXT:    ext v3.16b, v0.16b, v0.16b, #8
+; CHECK-SD-BASE-NEXT:    mov v1.d[1], v0.d[0]
+; CHECK-SD-BASE-NEXT:    mov v2.d[1], v3.d[0]
+; CHECK-SD-BASE-NEXT:    uaddlv s0, v1.8h
+; CHECK-SD-BASE-NEXT:    uaddlv s1, v2.8h
+; CHECK-SD-BASE-NEXT:    add v0.4s, v1.4s, v0.4s
 ; CHECK-SD-BASE-NEXT:    addv s0, v0.4s
 ; CHECK-SD-BASE-NEXT:    fmov w8, s0
 ; CHECK-SD-BASE-NEXT:    add w0, w8, w0
@@ -3200,15 +3208,19 @@ entry:
 define i32 @add_pair_v4i16_v4i32_zext(<4 x i16> %x, <4 x i16> %y) {
 ; CHECK-SD-BASE-LABEL: add_pair_v4i16_v4i32_zext:
 ; CHECK-SD-BASE:       // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-BASE-NEXT:    addv s0, v0.4s
+; CHECK-SD-BASE-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-SD-BASE-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-SD-BASE-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-SD-BASE-NEXT:    uaddlv s0, v0.8h
 ; CHECK-SD-BASE-NEXT:    fmov w0, s0
 ; CHECK-SD-BASE-NEXT:    ret
 ;
 ; CHECK-SD-DOT-LABEL: add_pair_v4i16_v4i32_zext:
 ; CHECK-SD-DOT:       // %bb.0: // %entry
-; CHECK-SD-DOT-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-DOT-NEXT:    addv s0, v0.4s
+; CHECK-SD-DOT-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-SD-DOT-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-SD-DOT-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-SD-DOT-NEXT:    uaddlv s0, v0.8h
 ; CHECK-SD-DOT-NEXT:    fmov w0, s0
 ; CHECK-SD-DOT-NEXT:    ret
 ;
@@ -4781,17 +4793,25 @@ entry:
 define i32 @add_pair_v16i8_v16i32_zext(<16 x i8> %x, <16 x i8> %y) {
 ; CHECK-SD-BASE-LABEL: add_pair_v16i8_v16i32_zext:
 ; CHECK-SD-BASE:       // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT:    ushll2 v2.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT:    ushll2 v3.8h, v1.16b, #0
-; CHECK-SD-BASE-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-SD-BASE-NEXT:    uaddl2 v4.4s, v0.8h, v2.8h
-; CHECK-SD-BASE-NEXT:    uaddl v0.4s, v0.4h, v2.4h
-; CHECK-SD-BASE-NEXT:    uaddl2 v2.4s, v1.8h, v3.8h
-; CHECK-SD-BASE-NEXT:    uaddl v1.4s, v1.4h, v3.4h
-; CHECK-SD-BASE-NEXT:    add v0.4s, v0.4s, v4.4s
-; CHECK-SD-BASE-NEXT:    add v1.4s, v1.4s, v2.4s
-; CHECK-SD-BASE-NEXT:    add v0.4s, v0.4s, v1.4s
+; CHECK-SD-BASE-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-SD-BASE-NEXT:    ushll2 v1.8h, v1.16b, #0
+; CHECK-SD-BASE-NEXT:    ext v4.16b, v2.16b, v2.16b, #8
+; CHECK-SD-BASE-NEXT:    ext v5.16b, v0.16b, v0.16b, #8
+; CHECK-SD-BASE-NEXT:    ext v6.16b, v3.16b, v3.16b, #8
+; CHECK-SD-BASE-NEXT:    ext v7.16b, v1.16b, v1.16b, #8
+; CHECK-SD-BASE-NEXT:    mov v2.d[1], v0.d[0]
+; CHECK-SD-BASE-NEXT:    mov v3.d[1], v1.d[0]
+; CHECK-SD-BASE-NEXT:    mov v4.d[1], v5.d[0]
+; CHECK-SD-BASE-NEXT:    mov v6.d[1], v7.d[0]
+; CHECK-SD-BASE-NEXT:    uaddlv s0, v2.8h
+; CHECK-SD-BASE-NEXT:    uaddlv s2, v3.8h
+; CHECK-SD-BASE-NEXT:    uaddlv s1, v4.8h
+; CHECK-SD-BASE-NEXT:    uaddlv s3, v6.8h
+; CHECK-SD-BASE-NEXT:    add v0.4s, v1.4s, v0.4s
+; CHECK-SD-BASE-NEXT:    add v1.4s, v3.4s, v2.4s
+; CHECK-SD-BASE-NEXT:    add v0.4s, v1.4s, v0.4s
 ; CHECK-SD-BASE-NEXT:    addv s0, v0.4s
 ; CHECK-SD-BASE-NEXT:    fmov w0, s0
 ; CHECK-SD-BASE-NEXT:    ret
@@ -5043,19 +5063,23 @@ entry:
 define i32 @add_pair_v4i8_v4i32_zext(<4 x i8> %x, <4 x i8> %y) {
 ; CHECK-SD-BASE-LABEL: add_pair_v4i8_v4i32_zext:
 ; CHECK-SD-BASE:       // %bb.0: // %entry
+; CHECK-SD-BASE-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-SD-BASE-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECK-SD-BASE-NEXT:    bic v0.4h, #255, lsl #8
 ; CHECK-SD-BASE-NEXT:    bic v1.4h, #255, lsl #8
-; CHECK-SD-BASE-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-BASE-NEXT:    addv s0, v0.4s
+; CHECK-SD-BASE-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-SD-BASE-NEXT:    uaddlv s0, v0.8h
 ; CHECK-SD-BASE-NEXT:    fmov w0, s0
 ; CHECK-SD-BASE-NEXT:    ret
 ;
 ; CHECK-SD-DOT-LABEL: add_pair_v4i8_v4i32_zext:
 ; CHECK-SD-DOT:       // %bb.0: // %entry
+; CHECK-SD-DOT-NEXT:    // kill: def $d1 killed $d1 def $q1
+; CHECK-SD-DOT-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECK-SD-DOT-NEXT:    bic v0.4h, #255, lsl #8
 ; CHECK-SD-DOT-NEXT:    bic v1.4h, #255, lsl #8
-; CHECK-SD-DOT-NEXT:    uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-DOT-NEXT:    addv s0, v0.4s
+; CHECK-SD-DOT-NEXT:    mov v0.d[1], v1.d[0]
+; CHECK-SD-DOT-NEXT:    uaddlv s0, v0.8h
 ; CHECK-SD-DOT-NEXT:    fmov w0, s0
 ; CHECK-SD-DOT-NEXT:    ret
 ;

>From ad9d5cf78d4c1e4311e7b1a881e0c407b2201262 Mon Sep 17 00:00:00 2001
From: Rin Dobrescu <rin.dobrescu at arm.com>
Date: Wed, 17 Jan 2024 12:14:53 +0000
Subject: [PATCH 2/2] Move combine function and check for same Opcode

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 80 +++++++++----------
 llvm/test/CodeGen/AArch64/neon-dotreduce.ll   | 53 +++++-------
 llvm/test/CodeGen/AArch64/vecreduce-add.ll    | 58 +++++---------
 3 files changed, 79 insertions(+), 112 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0f7047c84d48c9..0687dfae671da8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16490,14 +16490,11 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
                      VecReudceAdd8);
 }
 
-// Turn UADDV(add(zext(extract_lo(x)), zext(extract_hi(x)))) into
-// UADDV(UADDLP(x)). If that fails, then convert UADDV(add(zext(64-bit source),
-// zext(64-bit source))) into UADDLV(concat).
+// Given an (integer) vecreduce, we know the order of the inputs does not
+// matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
+// into UADDV(UADDLP(x)). This can also happen through an extra add, where we
+// transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
 static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
-  // Given an (integer) vecreduce, we know the order of the inputs does not
-  // matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
-  // into UADDV(UADDLP(x)). This can also happen through an extra add, where we
-  // transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
   auto DetectAddExtract = [&](SDValue A) {
     // Look for add(zext(extract_lo(x)), zext(extract_hi(x))), returning
     // UADDLP(x) if found.
@@ -16531,34 +16528,6 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
     return DAG.getNode(Opcode, SDLoc(A), VT, Ext0.getOperand(0));
   };
 
-  // We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
-  // UADDLV(concat), where the concat represents the 64-bit zext sources.
-  auto DetectZextConcat = [&](SDValue A, SelectionDAG &DAG) {
-    // Look for add(zext(64-bit source), zext(64-bit source)), returning
-    // UADDLV(concat(zext, zext)) if found.
-    if (A.getOpcode() != ISD::ADD)
-      return SDValue();
-    EVT VT = A.getValueType();
-    if (VT != MVT::v4i32)
-      return SDValue();
-    SDValue Op0 = A.getOperand(0);
-    SDValue Op1 = A.getOperand(1);
-    if (Op0.getOpcode() != ISD::ZERO_EXTEND)
-      return SDValue();
-    SDValue Ext0 = Op0.getOperand(0);
-    SDValue Ext1 = Op1.getOperand(0);
-    EVT ExtVT0 = Ext0.getValueType();
-    EVT ExtVT1 = Ext1.getValueType();
-    // Check zext VTs are the same and 64-bit length.
-    if (ExtVT0 != ExtVT1 || !(ExtVT0 == MVT::v8i8 || ExtVT0 == MVT::v4i16))
-      return SDValue();
-    // Get VT for concat of zext sources.
-    EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
-    SDValue Concat =
-        DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
-    return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
-  };
-
   if (SDValue R = DetectAddExtract(A))
     return R;
 
@@ -16570,20 +16539,45 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
     if (SDValue R = performUADDVAddCombine(A.getOperand(1), DAG))
       return DAG.getNode(ISD::ADD, SDLoc(A), A.getValueType(), R,
                          A.getOperand(0));
-
-  if (SDValue R = DetectZextConcat(A, DAG))
-    return R;
-
   return SDValue();
 }
 
+// We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
+// UADDLV(concat), where the concat represents the 64-bit zext sources.
+static SDValue performUADDVZextCombine(SDValue A, SelectionDAG &DAG) {
+  // Look for add(zext(64-bit source), zext(64-bit source)), returning
+  // UADDLV(concat(source0, source1)) if found.
+  if (A.getOpcode() != ISD::ADD)
+    return SDValue();
+  EVT VT = A.getValueType();
+  if (VT != MVT::v4i32)
+    return SDValue();
+  SDValue Op0 = A.getOperand(0);
+  SDValue Op1 = A.getOperand(1);
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND || Op0.getOpcode() != Op1.getOpcode())
+    return SDValue();
+  SDValue Ext0 = Op0.getOperand(0);
+  SDValue Ext1 = Op1.getOperand(0);
+  EVT ExtVT0 = Ext0.getValueType();
+  EVT ExtVT1 = Ext1.getValueType();
+  // Check zext VTs are the same and 64-bit length.
+  if (ExtVT0 != ExtVT1 || !(ExtVT0 == MVT::v8i8 || ExtVT0 == MVT::v4i16))
+    return SDValue();
+  // Get VT for concat of zext sources.
+  EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
+  SDValue Concat =
+      DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
+  return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
+}
+
 static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
   SDValue A = N->getOperand(0);
-  if (A.getOpcode() == ISD::ADD)
+  if (A.getOpcode() == ISD::ADD) {
     if (SDValue R = performUADDVAddCombine(A, DAG))
-      return R.getOpcode() == AArch64ISD::UADDLV
-                 ? R
-                 : DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
+      return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
+    else if (SDValue R = performUADDVZextCombine(A, DAG))
+      return R;
+  }
   return SDValue();
 }
 
diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
index e4767594851eae..706aa4ad1b4665 100644
--- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
@@ -1039,21 +1039,17 @@ define i32 @test_udot_v25i8_nomla(ptr nocapture readonly %a1) {
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldp q2, q1, [x0]
 ; CHECK-NEXT:    movi v0.2d, #0000000000000000
-; CHECK-NEXT:    ushll v3.8h, v2.8b, #0
-; CHECK-NEXT:    ushll v4.8h, v1.8b, #0
+; CHECK-NEXT:    ushll v3.8h, v1.8b, #0
+; CHECK-NEXT:    ushll v4.8h, v2.8b, #0
 ; CHECK-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-NEXT:    ext v5.16b, v3.16b, v3.16b, #8
-; CHECK-NEXT:    ext v6.16b, v4.16b, v4.16b, #8
+; CHECK-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NEXT:    uaddl2 v5.4s, v4.8h, v3.8h
 ; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    mov v3.d[1], v4.d[0]
+; CHECK-NEXT:    uaddl v3.4s, v4.4h, v3.4h
 ; CHECK-NEXT:    mov v0.s[0], v1.s[0]
-; CHECK-NEXT:    ushll2 v1.8h, v2.16b, #0
-; CHECK-NEXT:    mov v5.d[1], v6.d[0]
-; CHECK-NEXT:    uaddlv s2, v3.8h
-; CHECK-NEXT:    uaddw v0.4s, v0.4s, v1.4h
-; CHECK-NEXT:    uaddlv s3, v5.8h
-; CHECK-NEXT:    add v0.4s, v2.4s, v0.4s
-; CHECK-NEXT:    uaddw2 v1.4s, v3.4s, v1.8h
+; CHECK-NEXT:    uaddw2 v1.4s, v5.4s, v2.8h
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NEXT:    add v1.4s, v3.4s, v1.4s
 ; CHECK-NEXT:    add v0.4s, v1.4s, v0.4s
 ; CHECK-NEXT:    addv s0, v0.4s
 ; CHECK-NEXT:    fmov w0, s0
@@ -1635,30 +1631,23 @@ define i32 @test_udot_v33i8_nomla(ptr nocapture readonly %a1) {
 ; CHECK-LABEL: test_udot_v33i8_nomla:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    ldr b1, [x0, #32]
-; CHECK-NEXT:    ldp q2, q3, [x0]
+; CHECK-NEXT:    ldp q3, q2, [x0]
 ; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    ushll2 v4.8h, v2.16b, #0
-; CHECK-NEXT:    ushll2 v5.8h, v3.16b, #0
-; CHECK-NEXT:    ushll v3.8h, v3.8b, #0
-; CHECK-NEXT:    ushll v2.8h, v2.8b, #0
+; CHECK-NEXT:    ushll v4.8h, v2.8b, #0
+; CHECK-NEXT:    ushll v5.8h, v3.8b, #0
+; CHECK-NEXT:    ushll2 v2.8h, v2.16b, #0
+; CHECK-NEXT:    ushll2 v3.8h, v3.16b, #0
 ; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    ext v6.16b, v4.16b, v4.16b, #8
-; CHECK-NEXT:    ext v7.16b, v5.16b, v5.16b, #8
-; CHECK-NEXT:    mov v4.d[1], v5.d[0]
-; CHECK-NEXT:    ext v16.16b, v2.16b, v2.16b, #8
+; CHECK-NEXT:    uaddl2 v6.4s, v3.8h, v2.8h
+; CHECK-NEXT:    uaddl v2.4s, v3.4h, v2.4h
 ; CHECK-NEXT:    mov v0.s[0], v1.s[0]
-; CHECK-NEXT:    ext v1.16b, v3.16b, v3.16b, #8
-; CHECK-NEXT:    mov v6.d[1], v7.d[0]
-; CHECK-NEXT:    uaddlv s4, v4.8h
-; CHECK-NEXT:    mov v16.d[1], v1.d[0]
-; CHECK-NEXT:    uaddw v0.4s, v0.4s, v2.4h
-; CHECK-NEXT:    uaddlv s1, v6.8h
-; CHECK-NEXT:    uaddlv s2, v16.8h
-; CHECK-NEXT:    uaddw v0.4s, v0.4s, v3.4h
-; CHECK-NEXT:    add v1.4s, v1.4s, v2.4s
-; CHECK-NEXT:    add v0.4s, v4.4s, v0.4s
-; CHECK-NEXT:    add v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    uaddl2 v1.4s, v5.8h, v4.8h
+; CHECK-NEXT:    add v1.4s, v1.4s, v6.4s
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v5.4h
+; CHECK-NEXT:    uaddw v0.4s, v0.4s, v4.4h
+; CHECK-NEXT:    add v1.4s, v2.4s, v1.4s
+; CHECK-NEXT:    add v0.4s, v0.4s, v1.4s
 ; CHECK-NEXT:    addv s0, v0.4s
 ; CHECK-NEXT:    fmov w0, s0
 ; CHECK-NEXT:    ret
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index ad9cf2e1d99365..e431f007faf397 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -513,15 +513,11 @@ entry:
 define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
 ; CHECK-SD-BASE-LABEL: add_v16i8_v16i32_zext:
 ; CHECK-SD-BASE:       // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT:    ushll v1.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT:    ext v2.16b, v1.16b, v1.16b, #8
-; CHECK-SD-BASE-NEXT:    ext v3.16b, v0.16b, v0.16b, #8
-; CHECK-SD-BASE-NEXT:    mov v1.d[1], v0.d[0]
-; CHECK-SD-BASE-NEXT:    mov v2.d[1], v3.d[0]
-; CHECK-SD-BASE-NEXT:    uaddlv s0, v1.8h
-; CHECK-SD-BASE-NEXT:    uaddlv s1, v2.8h
-; CHECK-SD-BASE-NEXT:    add v0.4s, v1.4s, v0.4s
+; CHECK-SD-BASE-NEXT:    ushll2 v1.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT:    uaddl2 v2.4s, v0.8h, v1.8h
+; CHECK-SD-BASE-NEXT:    uaddl v0.4s, v0.4h, v1.4h
+; CHECK-SD-BASE-NEXT:    add v0.4s, v0.4s, v2.4s
 ; CHECK-SD-BASE-NEXT:    addv s0, v0.4s
 ; CHECK-SD-BASE-NEXT:    fmov w0, s0
 ; CHECK-SD-BASE-NEXT:    ret
@@ -1914,15 +1910,11 @@ entry:
 define i32 @add_v16i8_v16i32_acc_zext(<16 x i8> %x, i32 %a) {
 ; CHECK-SD-BASE-LABEL: add_v16i8_v16i32_acc_zext:
 ; CHECK-SD-BASE:       // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT:    ushll v1.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT:    ext v2.16b, v1.16b, v1.16b, #8
-; CHECK-SD-BASE-NEXT:    ext v3.16b, v0.16b, v0.16b, #8
-; CHECK-SD-BASE-NEXT:    mov v1.d[1], v0.d[0]
-; CHECK-SD-BASE-NEXT:    mov v2.d[1], v3.d[0]
-; CHECK-SD-BASE-NEXT:    uaddlv s0, v1.8h
-; CHECK-SD-BASE-NEXT:    uaddlv s1, v2.8h
-; CHECK-SD-BASE-NEXT:    add v0.4s, v1.4s, v0.4s
+; CHECK-SD-BASE-NEXT:    ushll2 v1.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT:    uaddl2 v2.4s, v0.8h, v1.8h
+; CHECK-SD-BASE-NEXT:    uaddl v0.4s, v0.4h, v1.4h
+; CHECK-SD-BASE-NEXT:    add v0.4s, v0.4s, v2.4s
 ; CHECK-SD-BASE-NEXT:    addv s0, v0.4s
 ; CHECK-SD-BASE-NEXT:    fmov w8, s0
 ; CHECK-SD-BASE-NEXT:    add w0, w8, w0
@@ -4793,25 +4785,17 @@ entry:
 define i32 @add_pair_v16i8_v16i32_zext(<16 x i8> %x, <16 x i8> %y) {
 ; CHECK-SD-BASE-LABEL: add_pair_v16i8_v16i32_zext:
 ; CHECK-SD-BASE:       // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT:    ushll v2.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT:    ushll2 v0.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT:    ushll v3.8h, v1.8b, #0
-; CHECK-SD-BASE-NEXT:    ushll2 v1.8h, v1.16b, #0
-; CHECK-SD-BASE-NEXT:    ext v4.16b, v2.16b, v2.16b, #8
-; CHECK-SD-BASE-NEXT:    ext v5.16b, v0.16b, v0.16b, #8
-; CHECK-SD-BASE-NEXT:    ext v6.16b, v3.16b, v3.16b, #8
-; CHECK-SD-BASE-NEXT:    ext v7.16b, v1.16b, v1.16b, #8
-; CHECK-SD-BASE-NEXT:    mov v2.d[1], v0.d[0]
-; CHECK-SD-BASE-NEXT:    mov v3.d[1], v1.d[0]
-; CHECK-SD-BASE-NEXT:    mov v4.d[1], v5.d[0]
-; CHECK-SD-BASE-NEXT:    mov v6.d[1], v7.d[0]
-; CHECK-SD-BASE-NEXT:    uaddlv s0, v2.8h
-; CHECK-SD-BASE-NEXT:    uaddlv s2, v3.8h
-; CHECK-SD-BASE-NEXT:    uaddlv s1, v4.8h
-; CHECK-SD-BASE-NEXT:    uaddlv s3, v6.8h
-; CHECK-SD-BASE-NEXT:    add v0.4s, v1.4s, v0.4s
-; CHECK-SD-BASE-NEXT:    add v1.4s, v3.4s, v2.4s
-; CHECK-SD-BASE-NEXT:    add v0.4s, v1.4s, v0.4s
+; CHECK-SD-BASE-NEXT:    ushll2 v2.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT:    ushll v0.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT:    ushll2 v3.8h, v1.16b, #0
+; CHECK-SD-BASE-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-SD-BASE-NEXT:    uaddl2 v4.4s, v0.8h, v2.8h
+; CHECK-SD-BASE-NEXT:    uaddl v0.4s, v0.4h, v2.4h
+; CHECK-SD-BASE-NEXT:    uaddl2 v2.4s, v1.8h, v3.8h
+; CHECK-SD-BASE-NEXT:    uaddl v1.4s, v1.4h, v3.4h
+; CHECK-SD-BASE-NEXT:    add v0.4s, v0.4s, v4.4s
+; CHECK-SD-BASE-NEXT:    add v1.4s, v1.4s, v2.4s
+; CHECK-SD-BASE-NEXT:    add v0.4s, v0.4s, v1.4s
 ; CHECK-SD-BASE-NEXT:    addv s0, v0.4s
 ; CHECK-SD-BASE-NEXT:    fmov w0, s0
 ; CHECK-SD-BASE-NEXT:    ret



More information about the llvm-commits mailing list