[llvm] [AArch64] Convert UADDV(add(zext, zext)) into UADDLV(concat). (PR #78301)
Rin Dobrescu via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 17 04:17:44 PST 2024
https://github.com/Rin18 updated https://github.com/llvm/llvm-project/pull/78301
>From 714df50db8950ee8f9727893538671e5db68cd57 Mon Sep 17 00:00:00 2001
From: Rin Dobrescu <rin.dobrescu at arm.com>
Date: Tue, 16 Jan 2024 15:42:32 +0000
Subject: [PATCH 1/2] [AArch64] Convert UADDV(add(zext(64-bit source),
zext(64-bit source))) into UADDLV(concat).
---
.../Target/AArch64/AArch64ISelLowering.cpp | 47 +++++++++--
.../AArch64/aarch64-combine-add-zext.ll | 20 +++++
llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll | 12 +--
llvm/test/CodeGen/AArch64/neon-dotreduce.ll | 53 +++++++-----
llvm/test/CodeGen/AArch64/vecreduce-add.ll | 82 ++++++++++++-------
5 files changed, 153 insertions(+), 61 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/aarch64-combine-add-zext.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dffe69bdb900db..0f7047c84d48c9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16490,11 +16490,14 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
VecReudceAdd8);
}
-// Given an (integer) vecreduce, we know the order of the inputs does not
-// matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
-// into UADDV(UADDLP(x)). This can also happen through an extra add, where we
-// transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
+// Turn UADDV(add(zext(extract_lo(x)), zext(extract_hi(x)))) into
+// UADDV(UADDLP(x)). If that fails, then convert UADDV(add(zext(64-bit source),
+// zext(64-bit source))) into UADDLV(concat).
static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
+ // Given an (integer) vecreduce, we know the order of the inputs does not
+ // matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
+ // into UADDV(UADDLP(x)). This can also happen through an extra add, where we
+ // transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
auto DetectAddExtract = [&](SDValue A) {
// Look for add(zext(extract_lo(x)), zext(extract_hi(x))), returning
// UADDLP(x) if found.
@@ -16528,6 +16531,34 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
return DAG.getNode(Opcode, SDLoc(A), VT, Ext0.getOperand(0));
};
+ // We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
+ // UADDLV(concat), where the concat represents the 64-bit zext sources.
+ auto DetectZextConcat = [&](SDValue A, SelectionDAG &DAG) {
+ // Look for add(zext(64-bit source), zext(64-bit source)), returning
+ // UADDLV(concat(zext, zext)) if found.
+ if (A.getOpcode() != ISD::ADD)
+ return SDValue();
+ EVT VT = A.getValueType();
+ if (VT != MVT::v4i32)
+ return SDValue();
+ SDValue Op0 = A.getOperand(0);
+ SDValue Op1 = A.getOperand(1);
+ if (Op0.getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
+ SDValue Ext0 = Op0.getOperand(0);
+ SDValue Ext1 = Op1.getOperand(0);
+ EVT ExtVT0 = Ext0.getValueType();
+ EVT ExtVT1 = Ext1.getValueType();
+ // Check zext VTs are the same and 64-bit length.
+ if (ExtVT0 != ExtVT1 || !(ExtVT0 == MVT::v8i8 || ExtVT0 == MVT::v4i16))
+ return SDValue();
+ // Get VT for concat of zext sources.
+ EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
+ SDValue Concat =
+ DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
+ return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
+ };
+
if (SDValue R = DetectAddExtract(A))
return R;
@@ -16539,6 +16570,10 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
if (SDValue R = performUADDVAddCombine(A.getOperand(1), DAG))
return DAG.getNode(ISD::ADD, SDLoc(A), A.getValueType(), R,
A.getOperand(0));
+
+ if (SDValue R = DetectZextConcat(A, DAG))
+ return R;
+
return SDValue();
}
@@ -16546,7 +16581,9 @@ static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
SDValue A = N->getOperand(0);
if (A.getOpcode() == ISD::ADD)
if (SDValue R = performUADDVAddCombine(A, DAG))
- return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
+ return R.getOpcode() == AArch64ISD::UADDLV
+ ? R
+ : DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
return SDValue();
}
diff --git a/llvm/test/CodeGen/AArch64/aarch64-combine-add-zext.ll b/llvm/test/CodeGen/AArch64/aarch64-combine-add-zext.ll
new file mode 100644
index 00000000000000..1cb0206c3fdbc1
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-combine-add-zext.ll
@@ -0,0 +1,20 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu | FileCheck %s
+
+define i32 @test_add_zext(<4 x i16> %a, <4 x i16> %b) local_unnamed_addr #0 {
+; CHECK-LABEL: test_add_zext:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-NEXT: mov v0.d[1], v1.d[0]
+; CHECK-NEXT: uaddlv s0, v0.8h
+; CHECK-NEXT: fmov w0, s0
+; CHECK-NEXT: ret
+ %z1 = zext <4 x i16> %a to <4 x i32>
+ %z2 = zext <4 x i16> %b to <4 x i32>
+ %z = add <4 x i32> %z1, %z2
+ %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %z)
+ ret i32 %r
+}
+
+declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
diff --git a/llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll b/llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll
index 1fc177f034975d..24cce9a2b26b58 100644
--- a/llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll
+++ b/llvm/test/CodeGen/AArch64/avoid-pre-trunc.ll
@@ -18,14 +18,14 @@ define i32 @lower_lshr(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c, <4 x i32> %d, <
; CHECK-NEXT: mov v4.s[2], v6.s[0]
; CHECK-NEXT: mov v0.s[3], v1.s[0]
; CHECK-NEXT: mov v4.s[3], v3.s[0]
-; CHECK-NEXT: xtn v2.4h, v0.4s
+; CHECK-NEXT: xtn v1.4h, v0.4s
; CHECK-NEXT: shrn v0.4h, v0.4s, #16
-; CHECK-NEXT: xtn v1.4h, v4.4s
+; CHECK-NEXT: xtn v2.4h, v4.4s
; CHECK-NEXT: shrn v3.4h, v4.4s, #16
-; CHECK-NEXT: uhadd v0.4h, v2.4h, v0.4h
-; CHECK-NEXT: uhadd v1.4h, v1.4h, v3.4h
-; CHECK-NEXT: uaddl v0.4s, v0.4h, v1.4h
-; CHECK-NEXT: addv s0, v0.4s
+; CHECK-NEXT: uhadd v0.4h, v1.4h, v0.4h
+; CHECK-NEXT: uhadd v1.4h, v2.4h, v3.4h
+; CHECK-NEXT: mov v0.d[1], v1.d[0]
+; CHECK-NEXT: uaddlv s0, v0.8h
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: ret
%l87 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
index 706aa4ad1b4665..e4767594851eae 100644
--- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
@@ -1039,17 +1039,21 @@ define i32 @test_udot_v25i8_nomla(ptr nocapture readonly %a1) {
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldp q2, q1, [x0]
; CHECK-NEXT: movi v0.2d, #0000000000000000
-; CHECK-NEXT: ushll v3.8h, v1.8b, #0
-; CHECK-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NEXT: ushll v3.8h, v2.8b, #0
+; CHECK-NEXT: ushll v4.8h, v1.8b, #0
; CHECK-NEXT: ushll2 v1.8h, v1.16b, #0
-; CHECK-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-NEXT: uaddl2 v5.4s, v4.8h, v3.8h
+; CHECK-NEXT: ext v5.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT: ext v6.16b, v4.16b, v4.16b, #8
; CHECK-NEXT: ushll v1.4s, v1.4h, #0
-; CHECK-NEXT: uaddl v3.4s, v4.4h, v3.4h
+; CHECK-NEXT: mov v3.d[1], v4.d[0]
; CHECK-NEXT: mov v0.s[0], v1.s[0]
-; CHECK-NEXT: uaddw2 v1.4s, v5.4s, v2.8h
-; CHECK-NEXT: uaddw v0.4s, v0.4s, v2.4h
-; CHECK-NEXT: add v1.4s, v3.4s, v1.4s
+; CHECK-NEXT: ushll2 v1.8h, v2.16b, #0
+; CHECK-NEXT: mov v5.d[1], v6.d[0]
+; CHECK-NEXT: uaddlv s2, v3.8h
+; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT: uaddlv s3, v5.8h
+; CHECK-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NEXT: uaddw2 v1.4s, v3.4s, v1.8h
; CHECK-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-NEXT: addv s0, v0.4s
; CHECK-NEXT: fmov w0, s0
@@ -1631,23 +1635,30 @@ define i32 @test_udot_v33i8_nomla(ptr nocapture readonly %a1) {
; CHECK-LABEL: test_udot_v33i8_nomla:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr b1, [x0, #32]
-; CHECK-NEXT: ldp q3, q2, [x0]
+; CHECK-NEXT: ldp q2, q3, [x0]
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-NEXT: ushll v5.8h, v3.8b, #0
-; CHECK-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-NEXT: ushll2 v3.8h, v3.16b, #0
+; CHECK-NEXT: ushll2 v4.8h, v2.16b, #0
+; CHECK-NEXT: ushll2 v5.8h, v3.16b, #0
+; CHECK-NEXT: ushll v3.8h, v3.8b, #0
+; CHECK-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NEXT: ushll v1.4s, v1.4h, #0
-; CHECK-NEXT: uaddl2 v6.4s, v3.8h, v2.8h
-; CHECK-NEXT: uaddl v2.4s, v3.4h, v2.4h
+; CHECK-NEXT: ext v6.16b, v4.16b, v4.16b, #8
+; CHECK-NEXT: ext v7.16b, v5.16b, v5.16b, #8
+; CHECK-NEXT: mov v4.d[1], v5.d[0]
+; CHECK-NEXT: ext v16.16b, v2.16b, v2.16b, #8
; CHECK-NEXT: mov v0.s[0], v1.s[0]
-; CHECK-NEXT: uaddl2 v1.4s, v5.8h, v4.8h
-; CHECK-NEXT: add v1.4s, v1.4s, v6.4s
-; CHECK-NEXT: uaddw v0.4s, v0.4s, v5.4h
-; CHECK-NEXT: uaddw v0.4s, v0.4s, v4.4h
-; CHECK-NEXT: add v1.4s, v2.4s, v1.4s
-; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: ext v1.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT: mov v6.d[1], v7.d[0]
+; CHECK-NEXT: uaddlv s4, v4.8h
+; CHECK-NEXT: mov v16.d[1], v1.d[0]
+; CHECK-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NEXT: uaddlv s1, v6.8h
+; CHECK-NEXT: uaddlv s2, v16.8h
+; CHECK-NEXT: uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-NEXT: add v0.4s, v4.4s, v0.4s
+; CHECK-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-NEXT: addv s0, v0.4s
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: ret
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index b24967d5a130ed..ad9cf2e1d99365 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -513,11 +513,15 @@ entry:
define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
; CHECK-SD-BASE-LABEL: add_v16i8_v16i32_zext:
; CHECK-SD-BASE: // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT: ushll2 v1.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT: ushll v0.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT: uaddl2 v2.4s, v0.8h, v1.8h
-; CHECK-SD-BASE-NEXT: uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-BASE-NEXT: add v0.4s, v0.4s, v2.4s
+; CHECK-SD-BASE-NEXT: ushll v1.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT: ext v2.16b, v1.16b, v1.16b, #8
+; CHECK-SD-BASE-NEXT: ext v3.16b, v0.16b, v0.16b, #8
+; CHECK-SD-BASE-NEXT: mov v1.d[1], v0.d[0]
+; CHECK-SD-BASE-NEXT: mov v2.d[1], v3.d[0]
+; CHECK-SD-BASE-NEXT: uaddlv s0, v1.8h
+; CHECK-SD-BASE-NEXT: uaddlv s1, v2.8h
+; CHECK-SD-BASE-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-SD-BASE-NEXT: addv s0, v0.4s
; CHECK-SD-BASE-NEXT: fmov w0, s0
; CHECK-SD-BASE-NEXT: ret
@@ -1910,11 +1914,15 @@ entry:
define i32 @add_v16i8_v16i32_acc_zext(<16 x i8> %x, i32 %a) {
; CHECK-SD-BASE-LABEL: add_v16i8_v16i32_acc_zext:
; CHECK-SD-BASE: // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT: ushll2 v1.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT: ushll v0.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT: uaddl2 v2.4s, v0.8h, v1.8h
-; CHECK-SD-BASE-NEXT: uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-BASE-NEXT: add v0.4s, v0.4s, v2.4s
+; CHECK-SD-BASE-NEXT: ushll v1.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT: ext v2.16b, v1.16b, v1.16b, #8
+; CHECK-SD-BASE-NEXT: ext v3.16b, v0.16b, v0.16b, #8
+; CHECK-SD-BASE-NEXT: mov v1.d[1], v0.d[0]
+; CHECK-SD-BASE-NEXT: mov v2.d[1], v3.d[0]
+; CHECK-SD-BASE-NEXT: uaddlv s0, v1.8h
+; CHECK-SD-BASE-NEXT: uaddlv s1, v2.8h
+; CHECK-SD-BASE-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-SD-BASE-NEXT: addv s0, v0.4s
; CHECK-SD-BASE-NEXT: fmov w8, s0
; CHECK-SD-BASE-NEXT: add w0, w8, w0
@@ -3200,15 +3208,19 @@ entry:
define i32 @add_pair_v4i16_v4i32_zext(<4 x i16> %x, <4 x i16> %y) {
; CHECK-SD-BASE-LABEL: add_pair_v4i16_v4i32_zext:
; CHECK-SD-BASE: // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT: uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-BASE-NEXT: addv s0, v0.4s
+; CHECK-SD-BASE-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-SD-BASE-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-SD-BASE-NEXT: mov v0.d[1], v1.d[0]
+; CHECK-SD-BASE-NEXT: uaddlv s0, v0.8h
; CHECK-SD-BASE-NEXT: fmov w0, s0
; CHECK-SD-BASE-NEXT: ret
;
; CHECK-SD-DOT-LABEL: add_pair_v4i16_v4i32_zext:
; CHECK-SD-DOT: // %bb.0: // %entry
-; CHECK-SD-DOT-NEXT: uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-DOT-NEXT: addv s0, v0.4s
+; CHECK-SD-DOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-SD-DOT-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-SD-DOT-NEXT: mov v0.d[1], v1.d[0]
+; CHECK-SD-DOT-NEXT: uaddlv s0, v0.8h
; CHECK-SD-DOT-NEXT: fmov w0, s0
; CHECK-SD-DOT-NEXT: ret
;
@@ -4781,17 +4793,25 @@ entry:
define i32 @add_pair_v16i8_v16i32_zext(<16 x i8> %x, <16 x i8> %y) {
; CHECK-SD-BASE-LABEL: add_pair_v16i8_v16i32_zext:
; CHECK-SD-BASE: // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT: ushll2 v2.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT: ushll v0.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT: ushll2 v3.8h, v1.16b, #0
-; CHECK-SD-BASE-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-SD-BASE-NEXT: uaddl2 v4.4s, v0.8h, v2.8h
-; CHECK-SD-BASE-NEXT: uaddl v0.4s, v0.4h, v2.4h
-; CHECK-SD-BASE-NEXT: uaddl2 v2.4s, v1.8h, v3.8h
-; CHECK-SD-BASE-NEXT: uaddl v1.4s, v1.4h, v3.4h
-; CHECK-SD-BASE-NEXT: add v0.4s, v0.4s, v4.4s
-; CHECK-SD-BASE-NEXT: add v1.4s, v1.4s, v2.4s
-; CHECK-SD-BASE-NEXT: add v0.4s, v0.4s, v1.4s
+; CHECK-SD-BASE-NEXT: ushll v2.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT: ushll2 v0.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-SD-BASE-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-SD-BASE-NEXT: ext v4.16b, v2.16b, v2.16b, #8
+; CHECK-SD-BASE-NEXT: ext v5.16b, v0.16b, v0.16b, #8
+; CHECK-SD-BASE-NEXT: ext v6.16b, v3.16b, v3.16b, #8
+; CHECK-SD-BASE-NEXT: ext v7.16b, v1.16b, v1.16b, #8
+; CHECK-SD-BASE-NEXT: mov v2.d[1], v0.d[0]
+; CHECK-SD-BASE-NEXT: mov v3.d[1], v1.d[0]
+; CHECK-SD-BASE-NEXT: mov v4.d[1], v5.d[0]
+; CHECK-SD-BASE-NEXT: mov v6.d[1], v7.d[0]
+; CHECK-SD-BASE-NEXT: uaddlv s0, v2.8h
+; CHECK-SD-BASE-NEXT: uaddlv s2, v3.8h
+; CHECK-SD-BASE-NEXT: uaddlv s1, v4.8h
+; CHECK-SD-BASE-NEXT: uaddlv s3, v6.8h
+; CHECK-SD-BASE-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-SD-BASE-NEXT: add v1.4s, v3.4s, v2.4s
+; CHECK-SD-BASE-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-SD-BASE-NEXT: addv s0, v0.4s
; CHECK-SD-BASE-NEXT: fmov w0, s0
; CHECK-SD-BASE-NEXT: ret
@@ -5043,19 +5063,23 @@ entry:
define i32 @add_pair_v4i8_v4i32_zext(<4 x i8> %x, <4 x i8> %y) {
; CHECK-SD-BASE-LABEL: add_pair_v4i8_v4i32_zext:
; CHECK-SD-BASE: // %bb.0: // %entry
+; CHECK-SD-BASE-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-SD-BASE-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-SD-BASE-NEXT: bic v0.4h, #255, lsl #8
; CHECK-SD-BASE-NEXT: bic v1.4h, #255, lsl #8
-; CHECK-SD-BASE-NEXT: uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-BASE-NEXT: addv s0, v0.4s
+; CHECK-SD-BASE-NEXT: mov v0.d[1], v1.d[0]
+; CHECK-SD-BASE-NEXT: uaddlv s0, v0.8h
; CHECK-SD-BASE-NEXT: fmov w0, s0
; CHECK-SD-BASE-NEXT: ret
;
; CHECK-SD-DOT-LABEL: add_pair_v4i8_v4i32_zext:
; CHECK-SD-DOT: // %bb.0: // %entry
+; CHECK-SD-DOT-NEXT: // kill: def $d1 killed $d1 def $q1
+; CHECK-SD-DOT-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-SD-DOT-NEXT: bic v0.4h, #255, lsl #8
; CHECK-SD-DOT-NEXT: bic v1.4h, #255, lsl #8
-; CHECK-SD-DOT-NEXT: uaddl v0.4s, v0.4h, v1.4h
-; CHECK-SD-DOT-NEXT: addv s0, v0.4s
+; CHECK-SD-DOT-NEXT: mov v0.d[1], v1.d[0]
+; CHECK-SD-DOT-NEXT: uaddlv s0, v0.8h
; CHECK-SD-DOT-NEXT: fmov w0, s0
; CHECK-SD-DOT-NEXT: ret
;
>From ad9d5cf78d4c1e4311e7b1a881e0c407b2201262 Mon Sep 17 00:00:00 2001
From: Rin Dobrescu <rin.dobrescu at arm.com>
Date: Wed, 17 Jan 2024 12:14:53 +0000
Subject: [PATCH 2/2] Move combine function and check for same Opcode
---
.../Target/AArch64/AArch64ISelLowering.cpp | 80 +++++++++----------
llvm/test/CodeGen/AArch64/neon-dotreduce.ll | 53 +++++-------
llvm/test/CodeGen/AArch64/vecreduce-add.ll | 58 +++++---------
3 files changed, 79 insertions(+), 112 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0f7047c84d48c9..0687dfae671da8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16490,14 +16490,11 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
VecReudceAdd8);
}
-// Turn UADDV(add(zext(extract_lo(x)), zext(extract_hi(x)))) into
-// UADDV(UADDLP(x)). If that fails, then convert UADDV(add(zext(64-bit source),
-// zext(64-bit source))) into UADDLV(concat).
+// Given an (integer) vecreduce, we know the order of the inputs does not
+// matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
+// into UADDV(UADDLP(x)). This can also happen through an extra add, where we
+// transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
- // Given an (integer) vecreduce, we know the order of the inputs does not
- // matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
- // into UADDV(UADDLP(x)). This can also happen through an extra add, where we
- // transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
auto DetectAddExtract = [&](SDValue A) {
// Look for add(zext(extract_lo(x)), zext(extract_hi(x))), returning
// UADDLP(x) if found.
@@ -16531,34 +16528,6 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
return DAG.getNode(Opcode, SDLoc(A), VT, Ext0.getOperand(0));
};
- // We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
- // UADDLV(concat), where the concat represents the 64-bit zext sources.
- auto DetectZextConcat = [&](SDValue A, SelectionDAG &DAG) {
- // Look for add(zext(64-bit source), zext(64-bit source)), returning
- // UADDLV(concat(zext, zext)) if found.
- if (A.getOpcode() != ISD::ADD)
- return SDValue();
- EVT VT = A.getValueType();
- if (VT != MVT::v4i32)
- return SDValue();
- SDValue Op0 = A.getOperand(0);
- SDValue Op1 = A.getOperand(1);
- if (Op0.getOpcode() != ISD::ZERO_EXTEND)
- return SDValue();
- SDValue Ext0 = Op0.getOperand(0);
- SDValue Ext1 = Op1.getOperand(0);
- EVT ExtVT0 = Ext0.getValueType();
- EVT ExtVT1 = Ext1.getValueType();
- // Check zext VTs are the same and 64-bit length.
- if (ExtVT0 != ExtVT1 || !(ExtVT0 == MVT::v8i8 || ExtVT0 == MVT::v4i16))
- return SDValue();
- // Get VT for concat of zext sources.
- EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
- SDValue Concat =
- DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
- return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
- };
-
if (SDValue R = DetectAddExtract(A))
return R;
@@ -16570,20 +16539,45 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
if (SDValue R = performUADDVAddCombine(A.getOperand(1), DAG))
return DAG.getNode(ISD::ADD, SDLoc(A), A.getValueType(), R,
A.getOperand(0));
-
- if (SDValue R = DetectZextConcat(A, DAG))
- return R;
-
return SDValue();
}
+// We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
+// UADDLV(concat), where the concat represents the 64-bit zext sources.
+static SDValue performUADDVZextCombine(SDValue A, SelectionDAG &DAG) {
+ // Look for add(zext(64-bit source), zext(64-bit source)), returning
+ // UADDLV(concat(64-bit source, 64-bit source)) if found.
+ if (A.getOpcode() != ISD::ADD)
+ return SDValue();
+ EVT VT = A.getValueType();
+ if (VT != MVT::v4i32)
+ return SDValue();
+ SDValue Op0 = A.getOperand(0);
+ SDValue Op1 = A.getOperand(1);
+ if (Op0.getOpcode() != ISD::ZERO_EXTEND || Op0.getOpcode() != Op1.getOpcode())
+ return SDValue();
+ SDValue Ext0 = Op0.getOperand(0);
+ SDValue Ext1 = Op1.getOperand(0);
+ EVT ExtVT0 = Ext0.getValueType();
+ EVT ExtVT1 = Ext1.getValueType();
+ // Check zext VTs are the same and 64-bit length.
+ if (ExtVT0 != ExtVT1 || !(ExtVT0 == MVT::v8i8 || ExtVT0 == MVT::v4i16))
+ return SDValue();
+ // Get VT for concat of zext sources.
+ EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
+ SDValue Concat =
+ DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
+ return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
+}
+
static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
SDValue A = N->getOperand(0);
- if (A.getOpcode() == ISD::ADD)
+ if (A.getOpcode() == ISD::ADD) {
if (SDValue R = performUADDVAddCombine(A, DAG))
- return R.getOpcode() == AArch64ISD::UADDLV
- ? R
- : DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
+ return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
+ else if (SDValue R = performUADDVZextCombine(A, DAG))
+ return R;
+ }
return SDValue();
}
diff --git a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
index e4767594851eae..706aa4ad1b4665 100644
--- a/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
+++ b/llvm/test/CodeGen/AArch64/neon-dotreduce.ll
@@ -1039,21 +1039,17 @@ define i32 @test_udot_v25i8_nomla(ptr nocapture readonly %a1) {
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldp q2, q1, [x0]
; CHECK-NEXT: movi v0.2d, #0000000000000000
-; CHECK-NEXT: ushll v3.8h, v2.8b, #0
-; CHECK-NEXT: ushll v4.8h, v1.8b, #0
+; CHECK-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NEXT: ushll v4.8h, v2.8b, #0
; CHECK-NEXT: ushll2 v1.8h, v1.16b, #0
-; CHECK-NEXT: ext v5.16b, v3.16b, v3.16b, #8
-; CHECK-NEXT: ext v6.16b, v4.16b, v4.16b, #8
+; CHECK-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NEXT: uaddl2 v5.4s, v4.8h, v3.8h
; CHECK-NEXT: ushll v1.4s, v1.4h, #0
-; CHECK-NEXT: mov v3.d[1], v4.d[0]
+; CHECK-NEXT: uaddl v3.4s, v4.4h, v3.4h
; CHECK-NEXT: mov v0.s[0], v1.s[0]
-; CHECK-NEXT: ushll2 v1.8h, v2.16b, #0
-; CHECK-NEXT: mov v5.d[1], v6.d[0]
-; CHECK-NEXT: uaddlv s2, v3.8h
-; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
-; CHECK-NEXT: uaddlv s3, v5.8h
-; CHECK-NEXT: add v0.4s, v2.4s, v0.4s
-; CHECK-NEXT: uaddw2 v1.4s, v3.4s, v1.8h
+; CHECK-NEXT: uaddw2 v1.4s, v5.4s, v2.8h
+; CHECK-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NEXT: add v1.4s, v3.4s, v1.4s
; CHECK-NEXT: add v0.4s, v1.4s, v0.4s
; CHECK-NEXT: addv s0, v0.4s
; CHECK-NEXT: fmov w0, s0
@@ -1635,30 +1631,23 @@ define i32 @test_udot_v33i8_nomla(ptr nocapture readonly %a1) {
; CHECK-LABEL: test_udot_v33i8_nomla:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ldr b1, [x0, #32]
-; CHECK-NEXT: ldp q2, q3, [x0]
+; CHECK-NEXT: ldp q3, q2, [x0]
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-NEXT: ushll2 v4.8h, v2.16b, #0
-; CHECK-NEXT: ushll2 v5.8h, v3.16b, #0
-; CHECK-NEXT: ushll v3.8h, v3.8b, #0
-; CHECK-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NEXT: ushll v5.8h, v3.8b, #0
+; CHECK-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NEXT: ushll2 v3.8h, v3.16b, #0
; CHECK-NEXT: ushll v1.4s, v1.4h, #0
-; CHECK-NEXT: ext v6.16b, v4.16b, v4.16b, #8
-; CHECK-NEXT: ext v7.16b, v5.16b, v5.16b, #8
-; CHECK-NEXT: mov v4.d[1], v5.d[0]
-; CHECK-NEXT: ext v16.16b, v2.16b, v2.16b, #8
+; CHECK-NEXT: uaddl2 v6.4s, v3.8h, v2.8h
+; CHECK-NEXT: uaddl v2.4s, v3.4h, v2.4h
; CHECK-NEXT: mov v0.s[0], v1.s[0]
-; CHECK-NEXT: ext v1.16b, v3.16b, v3.16b, #8
-; CHECK-NEXT: mov v6.d[1], v7.d[0]
-; CHECK-NEXT: uaddlv s4, v4.8h
-; CHECK-NEXT: mov v16.d[1], v1.d[0]
-; CHECK-NEXT: uaddw v0.4s, v0.4s, v2.4h
-; CHECK-NEXT: uaddlv s1, v6.8h
-; CHECK-NEXT: uaddlv s2, v16.8h
-; CHECK-NEXT: uaddw v0.4s, v0.4s, v3.4h
-; CHECK-NEXT: add v1.4s, v1.4s, v2.4s
-; CHECK-NEXT: add v0.4s, v4.4s, v0.4s
-; CHECK-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-NEXT: uaddl2 v1.4s, v5.8h, v4.8h
+; CHECK-NEXT: add v1.4s, v1.4s, v6.4s
+; CHECK-NEXT: uaddw v0.4s, v0.4s, v5.4h
+; CHECK-NEXT: uaddw v0.4s, v0.4s, v4.4h
+; CHECK-NEXT: add v1.4s, v2.4s, v1.4s
+; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
; CHECK-NEXT: addv s0, v0.4s
; CHECK-NEXT: fmov w0, s0
; CHECK-NEXT: ret
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-add.ll b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
index ad9cf2e1d99365..e431f007faf397 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-add.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-add.ll
@@ -513,15 +513,11 @@ entry:
define i32 @add_v16i8_v16i32_zext(<16 x i8> %x) {
; CHECK-SD-BASE-LABEL: add_v16i8_v16i32_zext:
; CHECK-SD-BASE: // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT: ushll v1.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT: ext v2.16b, v1.16b, v1.16b, #8
-; CHECK-SD-BASE-NEXT: ext v3.16b, v0.16b, v0.16b, #8
-; CHECK-SD-BASE-NEXT: mov v1.d[1], v0.d[0]
-; CHECK-SD-BASE-NEXT: mov v2.d[1], v3.d[0]
-; CHECK-SD-BASE-NEXT: uaddlv s0, v1.8h
-; CHECK-SD-BASE-NEXT: uaddlv s1, v2.8h
-; CHECK-SD-BASE-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-SD-BASE-NEXT: ushll2 v1.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT: uaddl2 v2.4s, v0.8h, v1.8h
+; CHECK-SD-BASE-NEXT: uaddl v0.4s, v0.4h, v1.4h
+; CHECK-SD-BASE-NEXT: add v0.4s, v0.4s, v2.4s
; CHECK-SD-BASE-NEXT: addv s0, v0.4s
; CHECK-SD-BASE-NEXT: fmov w0, s0
; CHECK-SD-BASE-NEXT: ret
@@ -1914,15 +1910,11 @@ entry:
define i32 @add_v16i8_v16i32_acc_zext(<16 x i8> %x, i32 %a) {
; CHECK-SD-BASE-LABEL: add_v16i8_v16i32_acc_zext:
; CHECK-SD-BASE: // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT: ushll v1.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT: ext v2.16b, v1.16b, v1.16b, #8
-; CHECK-SD-BASE-NEXT: ext v3.16b, v0.16b, v0.16b, #8
-; CHECK-SD-BASE-NEXT: mov v1.d[1], v0.d[0]
-; CHECK-SD-BASE-NEXT: mov v2.d[1], v3.d[0]
-; CHECK-SD-BASE-NEXT: uaddlv s0, v1.8h
-; CHECK-SD-BASE-NEXT: uaddlv s1, v2.8h
-; CHECK-SD-BASE-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-SD-BASE-NEXT: ushll2 v1.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT: uaddl2 v2.4s, v0.8h, v1.8h
+; CHECK-SD-BASE-NEXT: uaddl v0.4s, v0.4h, v1.4h
+; CHECK-SD-BASE-NEXT: add v0.4s, v0.4s, v2.4s
; CHECK-SD-BASE-NEXT: addv s0, v0.4s
; CHECK-SD-BASE-NEXT: fmov w8, s0
; CHECK-SD-BASE-NEXT: add w0, w8, w0
@@ -4793,25 +4785,17 @@ entry:
define i32 @add_pair_v16i8_v16i32_zext(<16 x i8> %x, <16 x i8> %y) {
; CHECK-SD-BASE-LABEL: add_pair_v16i8_v16i32_zext:
; CHECK-SD-BASE: // %bb.0: // %entry
-; CHECK-SD-BASE-NEXT: ushll v2.8h, v0.8b, #0
-; CHECK-SD-BASE-NEXT: ushll2 v0.8h, v0.16b, #0
-; CHECK-SD-BASE-NEXT: ushll v3.8h, v1.8b, #0
-; CHECK-SD-BASE-NEXT: ushll2 v1.8h, v1.16b, #0
-; CHECK-SD-BASE-NEXT: ext v4.16b, v2.16b, v2.16b, #8
-; CHECK-SD-BASE-NEXT: ext v5.16b, v0.16b, v0.16b, #8
-; CHECK-SD-BASE-NEXT: ext v6.16b, v3.16b, v3.16b, #8
-; CHECK-SD-BASE-NEXT: ext v7.16b, v1.16b, v1.16b, #8
-; CHECK-SD-BASE-NEXT: mov v2.d[1], v0.d[0]
-; CHECK-SD-BASE-NEXT: mov v3.d[1], v1.d[0]
-; CHECK-SD-BASE-NEXT: mov v4.d[1], v5.d[0]
-; CHECK-SD-BASE-NEXT: mov v6.d[1], v7.d[0]
-; CHECK-SD-BASE-NEXT: uaddlv s0, v2.8h
-; CHECK-SD-BASE-NEXT: uaddlv s2, v3.8h
-; CHECK-SD-BASE-NEXT: uaddlv s1, v4.8h
-; CHECK-SD-BASE-NEXT: uaddlv s3, v6.8h
-; CHECK-SD-BASE-NEXT: add v0.4s, v1.4s, v0.4s
-; CHECK-SD-BASE-NEXT: add v1.4s, v3.4s, v2.4s
-; CHECK-SD-BASE-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-SD-BASE-NEXT: ushll2 v2.8h, v0.16b, #0
+; CHECK-SD-BASE-NEXT: ushll v0.8h, v0.8b, #0
+; CHECK-SD-BASE-NEXT: ushll2 v3.8h, v1.16b, #0
+; CHECK-SD-BASE-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-SD-BASE-NEXT: uaddl2 v4.4s, v0.8h, v2.8h
+; CHECK-SD-BASE-NEXT: uaddl v0.4s, v0.4h, v2.4h
+; CHECK-SD-BASE-NEXT: uaddl2 v2.4s, v1.8h, v3.8h
+; CHECK-SD-BASE-NEXT: uaddl v1.4s, v1.4h, v3.4h
+; CHECK-SD-BASE-NEXT: add v0.4s, v0.4s, v4.4s
+; CHECK-SD-BASE-NEXT: add v1.4s, v1.4s, v2.4s
+; CHECK-SD-BASE-NEXT: add v0.4s, v0.4s, v1.4s
; CHECK-SD-BASE-NEXT: addv s0, v0.4s
; CHECK-SD-BASE-NEXT: fmov w0, s0
; CHECK-SD-BASE-NEXT: ret
More information about the llvm-commits
mailing list