[llvm] [AArch64][SelectionDAG] Add type legalization for partial reduce wide adds (PR #141075)
Nicholas Guy via llvm-commits
llvm-commits at lists.llvm.org
Wed May 28 02:53:08 PDT 2025
https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/141075
From f5ce921cb4e8ed49c07f02b0aadea09b5816e08c Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 20 May 2025 14:25:37 +0100
Subject: [PATCH 1/4] [AArch64][SelectionDAG] Add type legalization for partial
reduce wide adds
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 12 +-
.../Target/AArch64/AArch64ISelLowering.cpp | 36 ++
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 13 +
.../AArch64/sve-partial-reduce-dot-product.ll | 206 +++++++----
.../AArch64/sve-partial-reduce-wide-add.ll | 322 +++++++++++++-----
5 files changed, 427 insertions(+), 162 deletions(-)
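For context, the "wide add" shape being legalized here is a partial reduction whose input is only sign- or zero-extended, with no multiply. A minimal IR sketch of the nxv2i64 <- nxv4i32 case (mirroring signed_wide_add_nxv4i32 in sve-partial-reduce-wide-add.ll below):

  define <vscale x 2 x i64> @wide_add(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input) {
    %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
    %r = call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
    ret <vscale x 2 x i64> %r
  }

With this series, on SVE2 that is expected to select to an saddwb/saddwt pair rather than the sunpklo/sunpkhi-plus-add sequence.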
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d6e288a59b2ee..0ac8f6f3a8171 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12644,6 +12644,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ EVT ResultVT = N->getValueType(0);
+
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
@@ -12657,7 +12659,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
return SDValue();
- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+ return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp,
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
}
@@ -12678,8 +12680,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- RHSExtOp);
+ return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, RHSExtOp);
}
// partial.reduce.umla(acc, zext(op), splat(1))
@@ -12703,7 +12704,10 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
- if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+ auto *Context = DAG.getContext();
+ if (!TLI.isPartialReduceMLALegalOrCustom(
+ TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+ TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
return SDValue();
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
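The point of routing this legality query through getTypeToTransformTo is that accumulator/input types which are only illegal because they are too wide (e.g. an nxv4i64 accumulator, which type legalization splits into two nxv2i64 halves) can still take the PARTIAL_REDUCE_*MLA path, instead of falling back to the generic unpack-and-add expansion before the split happens.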
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 13fb6a32233fe..d602a62eaaf84 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1870,6 +1870,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+
+ // Wide add types
+ if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Custom);
+ setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Custom);
+ }
}
// Handle operations that are only available in non-streaming SVE mode.
@@ -29530,6 +29537,35 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
+
+ // Recognise Op as a wide add, if it is then we leave it as-is
+ // Base: nxv2i64, Subdivision: nxv4i32
+ auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
+ assert(Base.isVector() && Subdivision.isVector());
+ assert(Base.isScalableVector() == Subdivision.isScalableVector());
+
+ ElementCount BaseCount = Base.getVectorElementCount();
+ ElementCount SubCount = Subdivision.getVectorElementCount();
+ if (BaseCount * 2 != SubCount)
+ return false;
+
+ uint64_t BaseScalarSize = Base.getScalarSizeInBits();
+ uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
+ if (BaseScalarSize != SubScalarSize * 2)
+ return false;
+
+ return true;
+ };
+ if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
+ // If it looks like a real wide add, we can leave it as-is and treat it as
+ // Legal
+ APInt C;
+ if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
+ return Op;
+ // If it doesn't, then we need to expand it.
+ return SDValue();
+ }
+
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d6bd59adef03b..b15caa25b604e 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3787,6 +3787,19 @@ let Predicates = [HasSVE2_or_SME] in {
defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>;
defm USUBWT_ZZZ : sve2_wide_int_arith_wide<0b111, "usubwt", int_aarch64_sve_usubwt>;
+ def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
+ (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
+ (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
+ (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
+ (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
+ (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
+ (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+
// SVE2 integer multiply long
defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>;
defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>;
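These patterns key off the splat-of-1 multiplicand that the combiner uses to represent a plain extend-and-accumulate, and select it to a bottom/top wide-add pair; for the unsigned nxv2i64 case the updated tests below check

  uaddwb z0.d, z0.d, z1.s
  uaddwt z0.d, z0.d, z1.s

(At this point all three element widths use the _D register forms; the last patch in the series switches the narrower cases to the _S/_H forms.)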
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 5bc9a101b1e44..baa63a4ca31a2 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -561,31 +561,34 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
; CHECK-NEXT: add z1.d, z1.d, z3.d
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z5.d, z5.d, z25.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z3.d, z7.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z5.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z4.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE: // %bb.0:
+; CHECK-NEWLOWERING-SVE-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT: mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE-NEXT: udot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE2: // %bb.0:
+; CHECK-NEWLOWERING-SVE2-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT: mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE2-NEXT: udot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SME: // %bb.0:
+; CHECK-NEWLOWERING-SME-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT: uaddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: uaddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
ret <vscale x 4 x i64> %partial.reduce
@@ -603,31 +606,34 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
; CHECK-NEXT: add z1.d, z1.d, z3.d
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z5.d, z5.d, z25.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z3.d, z7.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z5.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z4.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE: // %bb.0:
+; CHECK-NEWLOWERING-SVE-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE-NEXT: mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE-NEXT: sdot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z2.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SVE2: // %bb.0:
+; CHECK-NEWLOWERING-SVE2-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEWLOWERING-SVE2-NEXT: mov z4.b, #1 // =0x1
+; CHECK-NEWLOWERING-SVE2-NEXT: sdot z3.s, z2.b, z4.b
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING-SME: // %bb.0:
+; CHECK-NEWLOWERING-SME-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEWLOWERING-SME-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEWLOWERING-SME-NEXT: saddwb z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: saddwt z0.d, z0.d, z4.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
ret <vscale x 4 x i64> %partial.reduce
@@ -647,18 +653,44 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: not_udot:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: and z1.h, z1.h, #0xff
-; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_udot:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SVE-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_udot:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_udot:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z4.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z4.s, z3.s
+; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
@@ -681,18 +713,44 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: not_udot_wide:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: and z1.s, z1.s, #0xffff
-; CHECK-NEWLOWERING-NEXT: and z2.s, z2.s, #0xffff
-; CHECK-NEWLOWERING-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT: and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z4.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT: and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
+;
+; CHECK-NEWLOWERING-SME-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SME-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT: and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SME-NEXT: mla z0.d, p0/m, z4.d, z3.d
+; CHECK-NEWLOWERING-SME-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
%b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index 5148d3da6c737..8f9f26a5d5b23 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -1,7 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE2
-; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK-SVE2
+; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING-SVE2
define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
; CHECK-SVE2-LABEL: signed_wide_add_nxv4i32:
@@ -18,13 +19,19 @@ define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vsc
; CHECK-SVE-NEXT: add z0.d, z0.d, z1.d
; CHECK-SVE-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv4i32:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z1.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
@@ -46,13 +53,19 @@ define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <v
; CHECK-SVE-NEXT: add z0.d, z0.d, z1.d
; CHECK-SVE-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv4i32:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z1.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
@@ -74,13 +87,19 @@ define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vsc
; CHECK-SVE-NEXT: add z0.s, z0.s, z1.s
; CHECK-SVE-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv8i16:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv8i16:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv8i16:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
@@ -102,13 +121,19 @@ define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <v
; CHECK-SVE-NEXT: add z0.s, z0.s, z1.s
; CHECK-SVE-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv8i16:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.s, z0.s, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
@@ -130,13 +155,19 @@ define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vsc
; CHECK-SVE-NEXT: add z0.h, z0.h, z1.h
; CHECK-SVE-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv16i8:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: add z0.h, z0.h, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.h, z0.h, z1.h
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv16i8:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.h, z0.h, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.h, z0.h, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv16i8:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
%partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
@@ -158,13 +189,19 @@ define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <v
; CHECK-SVE-NEXT: add z0.h, z0.h, z1.h
; CHECK-SVE-NEXT: ret
;
-; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv16i8:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: add z0.h, z0.h, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.h, z0.h, z1.h
-; CHECK-NEWLOWERING-NEXT: ret
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.h, z0.h, z2.h
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.h, z0.h, z1.h
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
%partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
@@ -172,15 +209,43 @@ entry:
}
define <vscale x 2 x i32> @signed_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
-; CHECK-LABEL: signed_wide_add_nxv4i16:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: ptrue p0.s
-; CHECK-NEXT: sxth z1.s, p0/m, z1.s
-; CHECK-NEXT: uunpklo z2.d, z1.s
-; CHECK-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z0.d, z1.d, z0.d
-; CHECK-NEXT: ret
+; CHECK-SVE2-LABEL: signed_wide_add_nxv4i16:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: ptrue p0.s
+; CHECK-SVE2-NEXT: sxth z1.s, p0/m, z1.s
+; CHECK-SVE2-NEXT: uunpklo z2.d, z1.s
+; CHECK-SVE2-NEXT: uunpkhi z1.d, z1.s
+; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT: add z0.d, z1.d, z0.d
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-SVE-LABEL: signed_wide_add_nxv4i16:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: ptrue p0.s
+; CHECK-SVE-NEXT: sxth z1.s, p0/m, z1.s
+; CHECK-SVE-NEXT: uunpklo z2.d, z1.s
+; CHECK-SVE-NEXT: uunpkhi z1.d, z1.s
+; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE-NEXT: add z0.d, z1.d, z0.d
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv4i16:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SVE-NEXT: sxth z1.s, p0/m, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv4i16:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-SVE2-NEXT: sxth z1.s, p0/m, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = sext <vscale x 4 x i16> %input to <vscale x 4 x i32>
%partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
@@ -188,14 +253,39 @@ entry:
}
define <vscale x 2 x i32> @unsigned_wide_add_nxv4i16(<vscale x 2 x i32> %acc, <vscale x 4 x i16> %input){
-; CHECK-LABEL: unsigned_wide_add_nxv4i16:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: and z1.s, z1.s, #0xffff
-; CHECK-NEXT: uunpklo z2.d, z1.s
-; CHECK-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z0.d, z1.d, z0.d
-; CHECK-NEXT: ret
+; CHECK-SVE2-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-SVE2-NEXT: uunpklo z2.d, z1.s
+; CHECK-SVE2-NEXT: uunpkhi z1.d, z1.s
+; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT: add z0.d, z1.d, z0.d
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-SVE-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-SVE-NEXT: uunpklo z2.d, z1.s
+; CHECK-SVE-NEXT: uunpkhi z1.d, z1.s
+; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE-NEXT: add z0.d, z1.d, z0.d
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv4i16:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z1.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = zext <vscale x 4 x i16> %input to <vscale x 4 x i32>
%partial.reduce = tail call <vscale x 2 x i32> @llvm.experimental.vector.partial.reduce.add.nxv2i32.nxv4i32(<vscale x 2 x i32> %acc, <vscale x 4 x i32> %input.wide)
@@ -203,17 +293,49 @@ entry:
}
define <vscale x 4 x i64> @signed_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
-; CHECK-LABEL: signed_wide_add_nxv8i32:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sunpklo z4.d, z3.s
-; CHECK-NEXT: sunpklo z5.d, z2.s
-; CHECK-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEXT: add z0.d, z0.d, z5.d
-; CHECK-NEXT: add z1.d, z1.d, z4.d
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEXT: ret
+; CHECK-SVE2-LABEL: signed_wide_add_nxv8i32:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: sunpklo z4.d, z3.s
+; CHECK-SVE2-NEXT: sunpklo z5.d, z2.s
+; CHECK-SVE2-NEXT: sunpkhi z3.d, z3.s
+; CHECK-SVE2-NEXT: sunpkhi z2.d, z2.s
+; CHECK-SVE2-NEXT: add z0.d, z0.d, z5.d
+; CHECK-SVE2-NEXT: add z1.d, z1.d, z4.d
+; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT: add z1.d, z1.d, z3.d
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-SVE-LABEL: signed_wide_add_nxv8i32:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: sunpklo z4.d, z3.s
+; CHECK-SVE-NEXT: sunpklo z5.d, z2.s
+; CHECK-SVE-NEXT: sunpkhi z3.d, z3.s
+; CHECK-SVE-NEXT: sunpkhi z2.d, z2.s
+; CHECK-SVE-NEXT: add z0.d, z0.d, z5.d
+; CHECK-SVE-NEXT: add z1.d, z1.d, z4.d
+; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE-NEXT: add z1.d, z1.d, z3.d
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE-LABEL: signed_wide_add_nxv8i32:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z4.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z5.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z5.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z1.d, z1.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: signed_wide_add_nxv8i32:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z1.d, z1.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z1.d, z1.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = sext <vscale x 8 x i32> %input to <vscale x 8 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
@@ -221,17 +343,49 @@ entry:
}
define <vscale x 4 x i64> @unsigned_wide_add_nxv8i32(<vscale x 4 x i64> %acc, <vscale x 8 x i32> %input){
-; CHECK-LABEL: unsigned_wide_add_nxv8i32:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: uunpklo z4.d, z3.s
-; CHECK-NEXT: uunpklo z5.d, z2.s
-; CHECK-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEXT: add z0.d, z0.d, z5.d
-; CHECK-NEXT: add z1.d, z1.d, z4.d
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEXT: ret
+; CHECK-SVE2-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK-SVE2: // %bb.0: // %entry
+; CHECK-SVE2-NEXT: uunpklo z4.d, z3.s
+; CHECK-SVE2-NEXT: uunpklo z5.d, z2.s
+; CHECK-SVE2-NEXT: uunpkhi z3.d, z3.s
+; CHECK-SVE2-NEXT: uunpkhi z2.d, z2.s
+; CHECK-SVE2-NEXT: add z0.d, z0.d, z5.d
+; CHECK-SVE2-NEXT: add z1.d, z1.d, z4.d
+; CHECK-SVE2-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE2-NEXT: add z1.d, z1.d, z3.d
+; CHECK-SVE2-NEXT: ret
+;
+; CHECK-SVE-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK-SVE: // %bb.0: // %entry
+; CHECK-SVE-NEXT: uunpklo z4.d, z3.s
+; CHECK-SVE-NEXT: uunpklo z5.d, z2.s
+; CHECK-SVE-NEXT: uunpkhi z3.d, z3.s
+; CHECK-SVE-NEXT: uunpkhi z2.d, z2.s
+; CHECK-SVE-NEXT: add z0.d, z0.d, z5.d
+; CHECK-SVE-NEXT: add z1.d, z1.d, z4.d
+; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-SVE-NEXT: add z1.d, z1.d, z3.d
+; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z4.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z5.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z5.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z1.d, z1.d, z4.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-SVE-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEWLOWERING-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-SVE2-LABEL: unsigned_wide_add_nxv8i32:
+; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z1.d, z1.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z1.d, z1.d, z3.s
+; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: ret
entry:
%input.wide = zext <vscale x 8 x i32> %input to <vscale x 8 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv8i64(<vscale x 4 x i64> %acc, <vscale x 8 x i64> %input.wide)
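One thing worth noting in the updated checks above: the nxv4i64 <- nxv8i32 cases, whose types are not legal as-is, are first split by type legalization and each half then selects to its own [s|u]addwb/[s|u]addwt pair; this is what the getTypeToTransformTo-based legality check enables.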
From a31f871526cfdc829c3c6916c13ecf91d101776f Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 27 May 2025 17:29:00 +0100
Subject: [PATCH 2/4] Replace custom lowering with tablegen patterns.
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 7 +--
.../Target/AArch64/AArch64ISelLowering.cpp | 63 ++++++++++---------
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 13 ++++
.../AArch64/sve-partial-reduce-dot-product.ll | 44 ++++---------
4 files changed, 60 insertions(+), 67 deletions(-)
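Besides moving the wide-add handling into patterns, this patch adds patterns for the form with a real multiply, selected to a [us]mlalb/[us]mlalt pair. A sketch of the kind of input that corresponds to (not one of the existing tests in this series), for the nxv8i16 accumulator case:

  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
  %mult = mul <vscale x 16 x i16> %a.wide, %b.wide
  %r = call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %mult)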
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0ac8f6f3a8171..c9644b0bf7909 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12644,8 +12644,6 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
- EVT ResultVT = N->getValueType(0);
-
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
unsigned NewOpcode =
ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
@@ -12659,7 +12657,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
(LHSOpcode != ISD::SIGN_EXTEND || CTrunc.sext(LHSBits) != C))
return SDValue();
- return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp,
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
DAG.getConstant(CTrunc, DL, LHSExtOpVT));
}
@@ -12680,7 +12678,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- return DAG.getNode(NewOpcode, DL, ResultVT, Acc, LHSExtOp, RHSExtOp);
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
+ RHSExtOp);
}
// partial.reduce.umla(acc, zext(op), splat(1))
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d602a62eaaf84..7bb7dede4d32b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1873,9 +1873,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Wide add types
if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Custom);
- setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Custom);
- setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Custom);
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
+ setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
}
}
@@ -29537,34 +29537,35 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
-
- // Recognise Op as a wide add, if it is then we leave it as-is
- // Base: nxv2i64, Subdivision: nxv4i32
- auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
- assert(Base.isVector() && Subdivision.isVector());
- assert(Base.isScalableVector() == Subdivision.isScalableVector());
-
- ElementCount BaseCount = Base.getVectorElementCount();
- ElementCount SubCount = Subdivision.getVectorElementCount();
- if (BaseCount * 2 != SubCount)
- return false;
-
- uint64_t BaseScalarSize = Base.getScalarSizeInBits();
- uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
- if (BaseScalarSize != SubScalarSize * 2)
- return false;
-
- return true;
- };
- if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
- // If it looks like a real wide add, we can leave it as-is and treat it as
- // Legal
- APInt C;
- if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
- return Op;
- // If it doesn't, then we need to expand it.
- return SDValue();
- }
+ //
+ // // Recognise Op as a wide add, if it is then we leave it as-is
+ // // Base: nxv2i64, Subdivision: nxv4i32
+ // auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
+ // assert(Base.isVector() && Subdivision.isVector());
+ // assert(Base.isScalableVector() == Subdivision.isScalableVector());
+ //
+ // ElementCount BaseCount = Base.getVectorElementCount();
+ // ElementCount SubCount = Subdivision.getVectorElementCount();
+ // if (BaseCount * 2 != SubCount)
+ // return false;
+ //
+ // uint64_t BaseScalarSize = Base.getScalarSizeInBits();
+ // uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
+ // if (BaseScalarSize != SubScalarSize * 2)
+ // return false;
+ //
+ // return true;
+ // };
+ // if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
+ // // If it looks like a real wide add, we can leave it as-is and treat it
+ // as
+ // // Legal
+ // APInt C;
+ // if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
+ // return Op;
+ // // If it doesn't, then we need to expand it.
+ // return SDValue();
+ // }
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index b15caa25b604e..ecfca793d5862 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3800,6 +3800,19 @@ let Predicates = [HasSVE2_or_SME] in {
def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
(SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
+ (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
+ (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
+ (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
+ (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+ (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+ (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+
// SVE2 integer multiply long
defm SQDMULLB_ZZZ : sve2_wide_int_arith_long<0b11000, "sqdmullb", int_aarch64_sve_sqdmullb>;
defm SQDMULLT_ZZZ : sve2_wide_int_arith_long<0b11001, "sqdmullt", int_aarch64_sve_sqdmullt>;
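The bottom/top split is what makes the MLAL pair a valid implementation here: umlalb/smlalb multiply and accumulate the even-numbered narrow lanes and umlalt/smlalt the odd-numbered ones, so together they consume every input lane exactly once while widening into the accumulator, which matches the partial-reduce semantics.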
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index baa63a4ca31a2..889b2e73868a4 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -668,28 +668,18 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
;
; CHECK-NEWLOWERING-SVE2-LABEL: not_udot:
; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
-; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEWLOWERING-SVE2-NEXT: and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z4.s, z3.s
-; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SVE2-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SVE2-NEXT: umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT: umlalt z0.h, z1.b, z2.b
; CHECK-NEWLOWERING-SVE2-NEXT: ret
;
; CHECK-NEWLOWERING-SME-LABEL: not_udot:
; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
-; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEWLOWERING-SME-NEXT: and z2.h, z2.h, #0xff
-; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.s
-; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-SME-NEXT: uunpklo z4.s, z1.h
-; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z4.s, z3.s
-; CHECK-NEWLOWERING-SME-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-SME-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-SME-NEXT: umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT: umlalt z0.h, z1.b, z2.b
; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
@@ -728,28 +718,18 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
;
; CHECK-NEWLOWERING-SVE2-LABEL: not_udot_wide:
; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
-; CHECK-NEWLOWERING-SVE2-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEWLOWERING-SVE2-NEXT: and z2.s, z2.s, #0xffff
-; CHECK-NEWLOWERING-SVE2-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-SVE2-NEXT: uunpklo z4.d, z1.s
-; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-SVE2-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.d, p0/m, z4.d, z3.d
-; CHECK-NEWLOWERING-SVE2-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SVE2-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SVE2-NEXT: umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SVE2-NEXT: umlalt z0.h, z1.b, z2.b
; CHECK-NEWLOWERING-SVE2-NEXT: ret
;
; CHECK-NEWLOWERING-SME-LABEL: not_udot_wide:
; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
-; CHECK-NEWLOWERING-SME-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEWLOWERING-SME-NEXT: and z2.s, z2.s, #0xffff
-; CHECK-NEWLOWERING-SME-NEXT: ptrue p0.d
-; CHECK-NEWLOWERING-SME-NEXT: uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-SME-NEXT: uunpklo z4.d, z1.s
-; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-SME-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-SME-NEXT: mla z0.d, p0/m, z4.d, z3.d
-; CHECK-NEWLOWERING-SME-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-SME-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-SME-NEXT: umlalb z0.h, z1.b, z2.b
+; CHECK-NEWLOWERING-SME-NEXT: umlalt z0.h, z1.b, z2.b
; CHECK-NEWLOWERING-SME-NEXT: ret
entry:
%a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
From e1274991074323c4fc974e76b45cf56918d77e9d Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 27 May 2025 18:08:52 +0100
Subject: [PATCH 3/4] Remove dead code
---
.../Target/AArch64/AArch64ISelLowering.cpp | 29 -------------------
1 file changed, 29 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7bb7dede4d32b..789ddc0f289f7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29537,35 +29537,6 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
- //
- // // Recognise Op as a wide add, if it is then we leave it as-is
- // // Base: nxv2i64, Subdivision: nxv4i32
- // auto IsEVTSubdivision = [](EVT Base, EVT Subdivision) -> bool {
- // assert(Base.isVector() && Subdivision.isVector());
- // assert(Base.isScalableVector() == Subdivision.isScalableVector());
- //
- // ElementCount BaseCount = Base.getVectorElementCount();
- // ElementCount SubCount = Subdivision.getVectorElementCount();
- // if (BaseCount * 2 != SubCount)
- // return false;
- //
- // uint64_t BaseScalarSize = Base.getScalarSizeInBits();
- // uint64_t SubScalarSize = Subdivision.getScalarSizeInBits();
- // if (BaseScalarSize != SubScalarSize * 2)
- // return false;
- //
- // return true;
- // };
- // if (IsEVTSubdivision(ResultVT, LHS.getValueType())) {
- // // If it looks like a real wide add, we can leave it as-is and treat it
- // as
- // // Legal
- // APInt C;
- // if (ISD::isConstantSplatVector(RHS.getNode(), C) && C.isOne())
- // return Op;
- // // If it doesn't, then we need to expand it.
- // return SDValue();
- // }
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
From 3010a2bbfcbaa46058b7232cbd4f48d630f42402 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Wed, 28 May 2025 10:51:16 +0100
Subject: [PATCH 4/4] Use correct instructions for types
---
.../Target/AArch64/AArch64ISelLowering.cpp | 1 -
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 24 +++++++++----------
2 files changed, 12 insertions(+), 13 deletions(-)
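For readers less familiar with the SVE tablegen naming: the _H/_S/_D suffixes on the ZZZ records encode the destination element size (i16/i32/i64), so each wide-add and MLAL pattern needs the suffix that matches its accumulator type; the earlier patches used the _D forms throughout, which this patch corrects.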
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 789ddc0f289f7..d71815f1521c7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29537,7 +29537,6 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
-
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index ecfca793d5862..d04cf814b2b89 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3792,25 +3792,25 @@ let Predicates = [HasSVE2_or_SME] in {
def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$Input, (nxv4i32 (splat_vector (i32 1))))),
(SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
- (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ (UADDWT_ZZZ_S (UADDWB_ZZZ_S $Acc, $Input), $Input)>;
def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$Input, (nxv8i16 (splat_vector (i32 1))))),
- (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ (SADDWT_ZZZ_S (SADDWB_ZZZ_S $Acc, $Input), $Input)>;
def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
- (UADDWT_ZZZ_D (UADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ (UADDWT_ZZZ_H (UADDWB_ZZZ_H $Acc, $Input), $Input)>;
def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$Input, (nxv16i8 (splat_vector (i32 1))))),
- (SADDWT_ZZZ_D (SADDWB_ZZZ_D $Acc, $Input), $Input)>;
+ (SADDWT_ZZZ_H (SADDWB_ZZZ_H $Acc, $Input), $Input)>;
- def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
- (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
- def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
- (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+ (UMLALT_ZZZ_D (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+ (SMLALT_ZZZ_D (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
def : Pat<(nxv4i32 (partial_reduce_umla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
- (UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ (UMLALT_ZZZ_S (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
def : Pat<(nxv4i32 (partial_reduce_smla nxv4i32:$Acc, nxv8i16:$LHS, nxv8i16:$RHS)),
- (SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
- def : Pat<(nxv2i64 (partial_reduce_umla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+ (SMLALT_ZZZ_S (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
+ def : Pat<(nxv8i16 (partial_reduce_umla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
(UMLALT_ZZZ_H (UMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
- def : Pat<(nxv2i64 (partial_reduce_smla nxv2i64:$Acc, nxv4i32:$LHS, nxv4i32:$RHS)),
+ def : Pat<(nxv8i16 (partial_reduce_smla nxv8i16:$Acc, nxv16i8:$LHS, nxv16i8:$RHS)),
(SMLALT_ZZZ_H (SMLALB_ZZZ_H $Acc, $LHS, $RHS), $LHS, $RHS)>;
// SVE2 integer multiply long