[llvm] [DAGCombiner] Add generic DAG combine for ISD::PARTIAL_REDUCE_MLA (PR #127083)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 13 07:50:39 PST 2025
https://github.com/JamesChesterman created https://github.com/llvm/llvm-project/pull/127083
Add generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA nodes. Transforms the DAG from:
PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1)) to
PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS).
From db5999d03bbd305cebf9edb633de5114a8a518f7 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 31 Jan 2025 11:44:55 +0000
Subject: [PATCH 01/12] [AArch64] Add PARTIAL_REDUCE_U/SMLA ISD Nodes
Add signed and unsigned PARTIAL_REDUCE_MLA ISD nodes.
Add a command line argument (new-partial-reduce-lowering) that
indicates whether the intrinsic experimental_vector_partial_reduce_add
will be transformed into a new PARTIAL_REDUCE_MLA ISD node.
Lowering with the new ISD nodes will, for now, always be done by
expansion.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 14 +
llvm/include/llvm/CodeGen/SelectionDAG.h | 7 +
.../SelectionDAG/LegalizeIntegerTypes.cpp | 21 +
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 4 +
.../SelectionDAG/LegalizeVectorTypes.cpp | 17 +
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 17 +
.../SelectionDAG/SelectionDAGBuilder.cpp | 27 +
.../SelectionDAG/SelectionDAGDumper.cpp | 5 +
.../AArch64/sve-partial-reduce-dot-product.ll | 709 ++++++++++++++++--
.../AArch64/sve-partial-reduce-wide-add.ll | 49 ++
10 files changed, 796 insertions(+), 74 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index fd8784a4c1003..3f235ee358e0e 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,20 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,
+ // PARTIAL_REDUCE_*MLA (Accumulator, Input1, Input2)
+ // Partial reduction nodes. Input1 and Input2 are multiplied element-wise,
+ // and the products are reduced, by addition, to the number of elements in
+ // Accumulator's type before being accumulated into Accumulator.
+ // Input1 and Input2 must be the same type. Accumulator and the output must be
+ // the same type.
+ // The number of elements in Input1 and Input2 must be a positive integer
+ // multiple of the number of elements in the Accumulator / output type.
+ // All operands, as well as the output, must have the same element type.
+ // Operands: Accumulator, Input1, Input2
+ // Outputs: Output
+ PARTIAL_REDUCE_SMLA,
+ PARTIAL_REDUCE_UMLA,
+
// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
// Outputs: output chain, glue
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 461c0c1ead16d..0fc6f6ccf85bd 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1607,6 +1607,13 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
+ /// Expands a PARTIAL_REDUCE_SMLA/UMLA node \p N into generic DAG nodes.
+ /// Operand 0 of \p N is the accumulator and operands 1 and 2 are the inputs
+ /// to multiply; the multiply is skipped when operand 2 is a splat of one.
+ /// Returns a value of the accumulator's type holding the partially reduced
+ /// result.
+ SDValue expandPartialReduceMLA(SDNode *N);
+
/// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
/// its operands and ReducedTY is the intrinsic's return type.
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index f1a91a782bbf9..be22143a8e9d9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -159,6 +159,11 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
Res = PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(N);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
+ break;
+
case ISD::SIGN_EXTEND:
case ISD::VP_SIGN_EXTEND:
case ISD::ZERO_EXTEND:
@@ -2099,6 +2104,10 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::VECTOR_FIND_LAST_ACTIVE:
Res = PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(N, OpNo);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
+ break;
}
// If the result is null, the sub-method took care of registering results etc.
@@ -2881,6 +2890,12 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
+SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDValue Res = DAG.expandPartialReduceMLA(N);
+ ReplaceValueWith(SDValue(N, 0), Res);
+ return SDValue();
+}
+
//===----------------------------------------------------------------------===//
// Integer Result Expansion
//===----------------------------------------------------------------------===//
@@ -6196,6 +6211,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
return DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, SDLoc(N), NVT, N->ops());
}
+SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDValue Res = DAG.expandPartialReduceMLA(N);
+ ReplaceValueWith(SDValue(N, 0), Res);
+ return SDValue();
+}
+
SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
EVT OutVT = N->getValueType(0);
EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index f13f70e66cfaa..cb9c1b239c0fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -379,6 +379,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
SDValue PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N);
+ SDValue PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N);
// Integer Operand Promotion.
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
@@ -430,6 +431,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
+ SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);
void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
@@ -968,6 +970,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
+ void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
// Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
bool SplitVectorOperand(SDNode *N, unsigned OpNo);
@@ -999,6 +1002,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
SDValue SplitVecOp_VP_CttzElements(SDNode *N);
SDValue SplitVecOp_VECTOR_HISTOGRAM(SDNode *N);
+ SDValue SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
//===--------------------------------------------------------------------===//
// Vector Widening Support: LegalizeVectorTypes.cpp
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 1000235ab4061..b01470028981e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1373,6 +1373,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::EXPERIMENTAL_VP_REVERSE:
SplitVecRes_VP_REVERSE(N, Lo, Hi);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ SplitVecRes_PARTIAL_REDUCE_MLA(N);
}
// If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -3182,6 +3185,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
}
+void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDValue Res = DAG.expandPartialReduceMLA(N);
+ ReplaceValueWith(SDValue(N, 0), Res);
+}
+
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
SDValue Op0Lo, Op0Hi, Op1Lo, Op1Hi;
@@ -3381,6 +3389,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
Res = SplitVecOp_VECTOR_HISTOGRAM(N);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
}
// If the result is null, the sub-method took care of registering results etc.
@@ -4435,6 +4446,12 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
MMO, IndexType);
}
+SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDValue Res = DAG.expandPartialReduceMLA(N);
+ ReplaceValueWith(SDValue(N, 0), Res);
+ return SDValue();
+}
+
//===----------------------------------------------------------------------===//
// Result Vector Widening
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 16c3b295426c6..32fca9028c7af 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2474,6 +2474,23 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
+SDValue SelectionDAG::expandPartialReduceMLA(SDNode *N) {
+ SDLoc DL(N);
+ SDValue Acc = N->getOperand(0);
+ SDValue Input1 = N->getOperand(1);
+ SDValue Input2 = N->getOperand(2);
+
+ EVT FullTy = Input1.getValueType();
+
+ SDValue Input = Input1;
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
+
+ return getPartialReduceAdd(DL, Acc.getValueType(), Acc, Input);
+}
+
SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2) {
EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 428e7a316d247..144439f136ff1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -135,6 +135,10 @@ static cl::opt<unsigned> SwitchPeelThreshold(
"switch statement. A value greater than 100 will void this "
"optimization"));
+static cl::opt<bool> NewPartialReduceLowering(
+ "new-partial-reduce-lowering", cl::init(false), cl::ReallyHidden,
+ cl::desc("Use the new method of lowering partial reductions."));
+
// Limit the width of DAG chains. This is important in general to prevent
// DAG-based analysis from blowing up. For example, alias analysis and
// load clustering may not complete in reasonable time. It is difficult to
@@ -8118,6 +8122,29 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
+ if (NewPartialReduceLowering) {
+ SDValue Acc = getValue(I.getOperand(0));
+ EVT AccVT = Acc.getValueType();
+ SDValue Input = getValue(I.getOperand(1));
+ EVT InputVT = Input.getValueType();
+
+ assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
+ "Expected operands to have the same vector element type!");
+ assert(
+ InputVT.getVectorElementCount().getKnownMinValue() %
+ AccVT.getVectorElementCount().getKnownMinValue() ==
+ 0 &&
+ "Expected the element count of the Input operand to be a positive "
+ "integer multiple of the element count of the Accumulator operand!");
+
+ // ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
+ // same if ISD::PARTIAL_REDUCE_SMLA was chosen instead. It should be
+ // changed to its correct signedness when combining or expanding,
+ // according to extends being performed on Input.
+ setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc, Input,
+ DAG.getConstant(1, sdl, InputVT)));
+ return;
+ }
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index f63c8dd3df1c8..a387c10679261 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -570,6 +570,11 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::VECTOR_FIND_LAST_ACTIVE:
return "find_last_active";
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return "partial_reduce_umla";
+ case ISD::PARTIAL_REDUCE_SMLA:
+ return "partial_reduce_smla";
+
// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
case ISD::SDID: \
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 66f83c658ff4f..16c0001dbdb83 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,12 +1,41 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: udot:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: udot z0.s, z1.b, z2.b
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: udot:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: udot:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -16,10 +45,38 @@ entry:
}
define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: udot_wide:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: udot z0.d, z1.h, z2.h
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: udot_wide:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: udot_wide:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_wide:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -29,10 +86,38 @@ entry:
}
define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: sdot:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sdot z0.s, z1.b, z2.b
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: sdot:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sdot:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -42,10 +127,38 @@ entry:
}
define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-LABEL: sdot_wide:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: sdot z0.d, z1.h, z2.h
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: sdot_wide:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sdot_wide:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_wide:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -82,6 +195,29 @@ define <vscale x 4 x i32> @usdot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
; CHECK-NOI8MM-NEXT: mla z1.s, p0/m, z7.s, z24.s
; CHECK-NOI8MM-NEXT: add z0.s, z1.s, z0.s
; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -118,6 +254,29 @@ define <vscale x 4 x i32> @sudot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
; CHECK-NOI8MM-NEXT: mla z1.s, p0/m, z7.s, z24.s
; CHECK-NOI8MM-NEXT: add z0.s, z1.s, z0.s
; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
+; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
@@ -127,15 +286,82 @@ entry:
}
define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: udot_8to64:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NEXT: udot z4.s, z2.b, z3.b
-; CHECK-NEXT: sunpklo z2.d, z4.s
-; CHECK-NEXT: sunpkhi z3.d, z4.s
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: udot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: udot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NOI8MM-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NOI8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NOI8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NOI8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_8to64:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #-2
+; CHECK-NEWLOWERING-NEXT: str z9, [sp] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NEWLOWERING-NEXT: .cfi_offset w29, -16
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z30.d, z24.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z8.d, z25.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z9.d, z3.s
+; CHECK-NEWLOWERING-NEXT: mul z27.d, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: mul z4.d, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NEWLOWERING-NEXT: movprfx z2, z27
+; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT: ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NEWLOWERING-NEXT: movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT: mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NEWLOWERING-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
+; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -146,15 +372,82 @@ entry:
}
define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
-; CHECK-LABEL: sdot_8to64:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NEXT: sdot z4.s, z2.b, z3.b
-; CHECK-NEXT: sunpklo z2.d, z4.s
-; CHECK-NEXT: sunpkhi z3.d, z4.s
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: sdot_8to64:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sdot_8to64:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NOI8MM-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NOI8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NOI8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NOI8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #-2
+; CHECK-NEWLOWERING-NEXT: str z9, [sp] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NEWLOWERING-NEXT: .cfi_offset w29, -16
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z30.d, z24.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z8.d, z25.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z9.d, z3.s
+; CHECK-NEWLOWERING-NEXT: mul z27.d, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: mul z4.d, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NEWLOWERING-NEXT: movprfx z2, z27
+; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT: ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NEWLOWERING-NEXT: movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT: mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NEWLOWERING-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
+; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -231,6 +524,63 @@ define <vscale x 4 x i64> @usdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
; CHECK-NOI8MM-NEXT: addvl sp, sp, #2
; CHECK-NOI8MM-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #-2
+; CHECK-NEWLOWERING-NEXT: str z9, [sp] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NEWLOWERING-NEXT: .cfi_offset w29, -16
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z30.d, z24.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z8.d, z25.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z9.d, z3.s
+; CHECK-NEWLOWERING-NEXT: mul z27.d, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: mul z4.d, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NEWLOWERING-NEXT: movprfx z2, z27
+; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT: ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NEWLOWERING-NEXT: movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT: mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NEWLOWERING-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
+; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -307,6 +657,63 @@ define <vscale x 4 x i64> @sudot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i
; CHECK-NOI8MM-NEXT: addvl sp, sp, #2
; CHECK-NOI8MM-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #-2
+; CHECK-NEWLOWERING-NEXT: str z9, [sp] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: str z8, [sp, #1, mul vl] // 16-byte Folded Spill
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x10, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 16 * VG
+; CHECK-NEWLOWERING-NEXT: .cfi_offset w29, -16
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x48, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x78, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d8 @ cfa - 16 - 8 * VG
+; CHECK-NEWLOWERING-NEXT: .cfi_escape 0x10, 0x49, 0x0a, 0x11, 0x70, 0x22, 0x11, 0x70, 0x92, 0x2e, 0x00, 0x1e, 0x22 // $d9 @ cfa - 16 - 16 * VG
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.h, z3.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z5.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z26.d, z6.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z6.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z7.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z5.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z30.d, z24.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z31.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z24.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z8.d, z25.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z25.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z9.d, z3.s
+; CHECK-NEWLOWERING-NEXT: mul z27.d, z27.d, z29.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z28.d
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: mul z4.d, z4.d, z5.d
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z7.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z9.d
+; CHECK-NEWLOWERING-NEXT: movprfx z2, z27
+; CHECK-NEWLOWERING-NEXT: mla z2.d, p0/m, z24.d, z25.d
+; CHECK-NEWLOWERING-NEXT: ldr z9, [sp] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z31.d, z3.d
+; CHECK-NEWLOWERING-NEXT: movprfx z3, z4
+; CHECK-NEWLOWERING-NEXT: mla z3.d, p0/m, z30.d, z8.d
+; CHECK-NEWLOWERING-NEXT: ldr z8, [sp, #1, mul vl] // 16-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
+; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -317,33 +724,93 @@ entry:
}
define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
-; CHECK-LABEL: udot_no_bin_op:
-; CHECK: // %bb.0:
-; CHECK-NEXT: mov z2.b, #1 // =0x1
-; CHECK-NEXT: udot z0.s, z1.b, z2.b
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: udot_no_bin_op:
+; CHECK-I8MM: // %bb.0:
+; CHECK-I8MM-NEXT: mov z2.b, #1 // =0x1
+; CHECK-I8MM-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: udot_no_bin_op:
+; CHECK-NOI8MM: // %bb.0:
+; CHECK-NOI8MM-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NOI8MM-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT: add z1.s, z2.s, z1.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z4.s, z0.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
ret <vscale x 4 x i32> %partial.reduce
}
define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
-; CHECK-LABEL: sdot_no_bin_op:
-; CHECK: // %bb.0:
-; CHECK-NEXT: mov z2.b, #1 // =0x1
-; CHECK-NEXT: sdot z0.s, z1.b, z2.b
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: sdot_no_bin_op:
+; CHECK-I8MM: // %bb.0:
+; CHECK-I8MM-NEXT: mov z2.b, #1 // =0x1
+; CHECK-I8MM-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sdot_no_bin_op:
+; CHECK-NOI8MM: // %bb.0:
+; CHECK-NOI8MM-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NOI8MM-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
+; CHECK-NEWLOWERING-NEXT: add z1.s, z2.s, z1.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z4.s, z0.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
ret <vscale x 4 x i32> %partial.reduce
}
define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
-; CHECK-LABEL: udot_no_bin_op_wide:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z2.h, #1 // =0x1
-; CHECK-NEXT: udot z0.d, z1.h, z2.h
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: udot_no_bin_op_wide:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: mov z2.h, #1 // =0x1
+; CHECK-I8MM-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: udot_no_bin_op_wide:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NOI8MM-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
@@ -351,11 +818,31 @@ entry:
}
define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
-; CHECK-LABEL: sdot_no_bin_op_wide:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z2.h, #1 // =0x1
-; CHECK-NEXT: sdot z0.d, z1.h, z2.h
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: sdot_no_bin_op_wide:
+; CHECK-I8MM: // %bb.0: // %entry
+; CHECK-I8MM-NEXT: mov z2.h, #1 // =0x1
+; CHECK-I8MM-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sdot_no_bin_op_wide:
+; CHECK-NOI8MM: // %bb.0: // %entry
+; CHECK-NOI8MM-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NOI8MM-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
@@ -363,32 +850,106 @@ entry:
}
define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
-; CHECK-LABEL: udot_no_bin_op_8to64:
-; CHECK: // %bb.0:
-; CHECK-NEXT: mov z3.b, #1 // =0x1
-; CHECK-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NEXT: udot z4.s, z2.b, z3.b
-; CHECK-NEXT: sunpklo z2.d, z4.s
-; CHECK-NEXT: sunpkhi z3.d, z4.s
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: udot_no_bin_op_8to64:
+; CHECK-I8MM: // %bb.0:
+; CHECK-I8MM-NEXT: mov z3.b, #1 // =0x1
+; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: udot_no_bin_op_8to64:
+; CHECK-NOI8MM: // %bb.0:
+; CHECK-NOI8MM-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NOI8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NOI8MM-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NOI8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NOI8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NOI8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z5.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z4.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z6.d
+; CHECK-NEWLOWERING-NEXT: add z4.d, z25.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z7.d, z1.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: ret
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
ret <vscale x 4 x i64> %partial.reduce
}
define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
-; CHECK-LABEL: sdot_no_bin_op_8to64:
-; CHECK: // %bb.0:
-; CHECK-NEXT: mov z3.b, #1 // =0x1
-; CHECK-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NEXT: sdot z4.s, z2.b, z3.b
-; CHECK-NEXT: sunpklo z2.d, z4.s
-; CHECK-NEXT: sunpkhi z3.d, z4.s
-; CHECK-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NEXT: ret
+; CHECK-I8MM-LABEL: sdot_no_bin_op_8to64:
+; CHECK-I8MM: // %bb.0:
+; CHECK-I8MM-NEXT: mov z3.b, #1 // =0x1
+; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-I8MM-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NOI8MM-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NOI8MM: // %bb.0:
+; CHECK-NOI8MM-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NOI8MM-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NOI8MM-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NOI8MM-NEXT: sunpklo z2.d, z4.s
+; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NOI8MM-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NOI8MM-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NOI8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z5.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z5.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z4.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z6.d
+; CHECK-NEWLOWERING-NEXT: add z4.d, z25.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z7.d, z1.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: ret
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
ret <vscale x 4 x i64> %partial.reduce
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index b4b946c68566e..62b5039259392 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -1,6 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE2
; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
; CHECK-SVE2-LABEL: signed_wide_add_nxv4i32:
@@ -16,6 +17,14 @@ define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vsc
; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d
; CHECK-SVE-NEXT: add z0.d, z1.d, z0.d
; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv4i32:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
@@ -36,6 +45,14 @@ define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <v
; CHECK-SVE-NEXT: add z0.d, z0.d, z2.d
; CHECK-SVE-NEXT: add z0.d, z1.d, z0.d
; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
@@ -56,6 +73,14 @@ define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vsc
; CHECK-SVE-NEXT: add z0.s, z0.s, z2.s
; CHECK-SVE-NEXT: add z0.s, z1.s, z0.s
; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv8i16:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
@@ -76,6 +101,14 @@ define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <v
; CHECK-SVE-NEXT: add z0.s, z0.s, z2.s
; CHECK-SVE-NEXT: add z0.s, z1.s, z0.s
; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
+; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
@@ -96,6 +129,14 @@ define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vsc
; CHECK-SVE-NEXT: add z0.h, z0.h, z2.h
; CHECK-SVE-NEXT: add z0.h, z1.h, z0.h
; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: signed_wide_add_nxv16i8:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: add z0.h, z0.h, z2.h
+; CHECK-NEWLOWERING-NEXT: add z0.h, z1.h, z0.h
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
%partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
@@ -116,6 +157,14 @@ define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <v
; CHECK-SVE-NEXT: add z0.h, z0.h, z2.h
; CHECK-SVE-NEXT: add z0.h, z1.h, z0.h
; CHECK-SVE-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: add z0.h, z0.h, z2.h
+; CHECK-NEWLOWERING-NEXT: add z0.h, z1.h, z0.h
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
%partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
>From c036082b13c502c04b10577554880b31291f4d3b Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 3 Feb 2025 15:17:31 +0000
Subject: [PATCH 02/12] Address comments on PR.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 11 +++---
llvm/include/llvm/CodeGen/SelectionDAG.h | 7 ----
llvm/include/llvm/CodeGen/TargetLowering.h | 4 ++
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 5 +++
.../SelectionDAG/LegalizeIntegerTypes.cpp | 21 ++++++----
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 2 +-
.../SelectionDAG/LegalizeVectorOps.cpp | 6 +++
.../SelectionDAG/LegalizeVectorTypes.cpp | 4 +-
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 17 ---------
.../SelectionDAG/SelectionDAGBuilder.cpp | 3 ++
.../CodeGen/SelectionDAG/TargetLowering.cpp | 38 +++++++++++++++++++
.../Target/AArch64/AArch64ISelLowering.cpp | 5 +++
12 files changed, 82 insertions(+), 41 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 3f235ee358e0e..422a70bb6641b 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,17 +1451,16 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,
- // PARTIAL_REDUCE_*MLA (Accumulator, Input1, Input2)
- // Partial reduction nodes. Input1 and Input2 are multiplied together before
- // being reduced, by addition to the number of elements that Accumulator's
- // type has.
+ // PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
+ // The partial reduction nodes sign- or zero-extend Input1 and Input2 to
+ // the element type of Accumulator before multiplying them together.
+ // The products are then reduced, using addition, to the result type of
+ // Accumulator. The reduced value is added to Accumulator and returned.
// Input1 and Input2 must be the same type. Accumulator and the output must be
// the same type.
// The number of elements in Input1 and Input2 must be a positive integer
// multiple of the number of elements in the Accumulator / output type.
- // All operands, as well as the output, must have the same element type.
+ // The element type of Input1 and Input2 must be no wider than Accumulator's.
- // Operands: Accumulator, Input1, Input2
- // Outputs: Output
PARTIAL_REDUCE_SMLA,
PARTIAL_REDUCE_UMLA,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 0fc6f6ccf85bd..461c0c1ead16d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1607,13 +1607,6 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
- // Expands PARTIAL_REDUCE_S/UMLA nodes
- // \p Acc Accumulator for where the result is stored for the partial reduction
- // operation.
- // \p Input1 First input for the partial reduction operation
- // \p Input2 Second input for the partial reduction operation
- SDValue expandPartialReduceMLA(SDNode *N);
-
/// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
/// its operands and ReducedTY is the intrinsic's return type.
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 04ee24c0916e5..91c13cd507542 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5634,6 +5634,10 @@ class TargetLowering : public TargetLoweringBase {
LoadSDNode *OriginalLoad,
SelectionDAG &DAG) const;
+ // Expands PARTIAL_REDUCE_S/UMLA nodes to a series of simpler operations,
+ // consisting of zext/sext, extract_subvector, mul and add operations.
+ SDValue expandPartialReduceMLA(SDNode *N, SelectionDAG &DAG) const;
+
private:
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
const SDLoc &DL, DAGCombinerInfo &DCI) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 6c9c96ceaa4ba..53e8019a2e0ec 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1245,6 +1245,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Node->getOpcode(),
cast<MaskedHistogramSDNode>(Node)->getIndex().getValueType());
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Action = TLI.getOperationAction(Node->getOpcode(),
+ Node->getOperand(0).getValueType());
+ break;
default:
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
Action = TLI.getCustomOperationAction(*Node);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index be22143a8e9d9..cbb061c9e9403 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2106,7 +2106,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
+ Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N, OpNo);
break;
}
@@ -2890,10 +2890,11 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
-SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDValue Res = DAG.expandPartialReduceMLA(N);
- ReplaceValueWith(SDValue(N, 0), Res);
- return SDValue();
+SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N,
+ unsigned OpNo) {
+ SmallVector<SDValue, 1> NewOps(N->ops());
+ NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
+ return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
//===----------------------------------------------------------------------===//
@@ -6212,9 +6213,13 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
}
SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDValue Res = DAG.expandPartialReduceMLA(N);
- ReplaceValueWith(SDValue(N, 0), Res);
- return SDValue();
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
+ SDValue ExtAcc = GetPromotedInteger(N->getOperand(0));
+ SDValue V = DAG.getNode(N->getOpcode(), DL, NVT, ExtAcc, N->getOperand(1),
+ N->getOperand(2));
+ return V;
}
SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index cb9c1b239c0fa..e48aa35773780 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -431,7 +431,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
- SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);
+ SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N, unsigned OpNo);
void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 6ad08bce44b0a..ab74e8e1bac00 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -503,6 +503,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMINIMUM:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
@@ -1195,6 +1197,10 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VECREDUCE_FMINIMUM:
Results.push_back(TLI.expandVecReduce(Node, DAG));
return;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
+ return;
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VECREDUCE_SEQ_FMUL:
Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index b01470028981e..97396627888f4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3186,7 +3186,7 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
}
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDValue Res = DAG.expandPartialReduceMLA(N);
+ SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
}
@@ -4447,7 +4447,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
}
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDValue Res = DAG.expandPartialReduceMLA(N);
+ SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 32fca9028c7af..16c3b295426c6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2474,23 +2474,6 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
-SDValue SelectionDAG::expandPartialReduceMLA(SDNode *N) {
- SDLoc DL(N);
- SDValue Acc = N->getOperand(0);
- SDValue Input1 = N->getOperand(1);
- SDValue Input2 = N->getOperand(2);
-
- EVT FullTy = Input1.getValueType();
-
- SDValue Input = Input1;
- APInt ConstantOne;
- if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
- Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
-
- return getPartialReduceAdd(DL, Acc.getValueType(), Acc, Input);
-}
-
SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2) {
EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 144439f136ff1..530e1ff3d0af0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -135,6 +135,9 @@ static cl::opt<unsigned> SwitchPeelThreshold(
"switch statement. A value greater than 100 will void this "
"optimization"));
+// FIXME : This is a temporary flag, and is used to help transition to
+// performing lowering the proper way using the new PARTIAL_REDUCE_MLA ISD
+// nodes.
static cl::opt<bool> NewPartialReduceLowering(
"new-partial-reduce-lowering", cl::init(false), cl::ReallyHidden,
cl::desc("Use the new method of lowering partial reductions."));
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index adfb96041c5c0..32115e98e15b0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12188,3 +12188,41 @@ SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
return Load;
}
+
+SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
+ SelectionDAG &DAG) const {
+ SDLoc DL(N);
+ SDValue Acc = N->getOperand(0);
+ SDValue Input1 = N->getOperand(1);
+ SDValue Input2 = N->getOperand(2);
+
+ EVT ReducedTy = Acc.getValueType();
+ EVT FullTy = Input1.getValueType();
+
+ auto ExtendToAccEltVT = [&](SDValue V) {
+ unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
+ ? ISD::ZERO_EXTEND
+ : ISD::SIGN_EXTEND;
+ EVT ExtVT = V.getValueType().changeVectorElementType(
+ Acc.getValueType().getVectorElementType());
+ if (ExtVT != FullTy)
+ return DAG.getNode(ExtOpc, DL, ExtVT, V);
+ return V;
+ };
+
+ SDValue Input;
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
+ !ConstantOne.isOne()) {
+ EVT NewVT =
+ EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
+ FullTy.getVectorElementCount());
+ Input1 = ExtendToAccEltVT(Input1);
+ Input2 = ExtendToAccEltVT(Input2);
+ Input = DAG.getNode(ISD::MUL, DL, NewVT, Input1, Input2);
+ } else {
+ Input = ExtendToAccEltVT(Input1);
+ }
+
+ return DAG.getPartialReduceAdd(DL, ReducedTy, Acc, Input);
+}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 84f6d421b70f9..8ea5995f69108 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1571,6 +1571,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MSTORE, VT, Custom);
}
+ for (MVT VT : MVT::scalable_vector_valuetypes()) {
+ setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Expand);
+ setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Expand);
+ }
+
// Firstly, exclude all scalable vector extending loads/truncating stores,
// include both integer and floating scalable vector.
for (MVT VT : MVT::scalable_vector_valuetypes()) {
>From 4cd617a292628c67eadc7ada0f399d8c0b97766d Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 4 Feb 2025 13:55:24 +0000
Subject: [PATCH 03/12] Address comments
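Move getPartialReduceAdd out of SelectionDAG and into TargetLowering as a
static helper that takes the full input type explicitly, and make
expandPartialReduceMLA take its operands directly instead of an SDNode.
Make the new codepath work for fixed-length NEON vectors too.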
---
llvm/include/llvm/CodeGen/SelectionDAG.h | 5 -
llvm/include/llvm/CodeGen/TargetLowering.h | 9 +-
.../SelectionDAG/LegalizeVectorOps.cpp | 8 +-
.../SelectionDAG/LegalizeVectorTypes.cpp | 8 +-
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 29 --
.../SelectionDAG/SelectionDAGBuilder.cpp | 31 +-
.../CodeGen/SelectionDAG/TargetLowering.cpp | 45 ++-
.../Target/AArch64/AArch64ISelLowering.cpp | 9 +-
.../neon-partial-reduce-dot-product.ll | 289 ++++++++++++++++++
9 files changed, 365 insertions(+), 68 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 461c0c1ead16d..cf8e4a3d2513b 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1607,11 +1607,6 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
- /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
- /// its operands and ReducedTY is the intrinsic's return type.
- SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
- SDValue Op2);
-
/// Expands a node with multiple results to an FP or vector libcall. The
/// libcall is expected to take all the operands of the \p Node followed by
/// output pointers for each of the results. \p CallRetResNo can be optionally
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 91c13cd507542..cefd12639c5a8 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5636,7 +5636,14 @@ class TargetLowering : public TargetLoweringBase {
// Expands PARTIAL_REDUCE_S/UMLA nodes to a series of simpler operations,
// consisting of zext/sext, extract_subvector, mul and add operations.
- SDValue expandPartialReduceMLA(SDNode *N, SelectionDAG &DAG) const;
+ SDValue expandPartialReduceMLA(SDLoc DL, SDValue Acc, SDValue Input1,
+ SDValue Input2, SelectionDAG &DAG) const;
+
+ // Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
+ // its operands and ReducedTY is the return type.
+ static SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, EVT FullTy,
+ SDValue Op1, SDValue Op2,
+ SelectionDAG &DAG);
private:
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index ab74e8e1bac00..da86996e88ed8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1198,9 +1198,13 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
Results.push_back(TLI.expandVecReduce(Node, DAG));
return;
case ISD::PARTIAL_REDUCE_UMLA:
- case ISD::PARTIAL_REDUCE_SMLA:
- Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
+ case ISD::PARTIAL_REDUCE_SMLA: {
+ SDLoc DL(Node);
+ Results.push_back(TLI.expandPartialReduceMLA(DL, Node->getOperand(0),
+ Node->getOperand(1),
+ Node->getOperand(2), DAG));
return;
+ }
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VECREDUCE_SEQ_FMUL:
Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 97396627888f4..ed6e1b57542af 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3186,7 +3186,9 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
}
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
+ SDLoc DL(N);
+ SDValue Res = TLI.expandPartialReduceMLA(
+ DL, N->getOperand(0), N->getOperand(1), N->getOperand(2), DAG);
ReplaceValueWith(SDValue(N, 0), Res);
}
@@ -4447,7 +4449,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
}
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
+ SDLoc DL(N);
+ SDValue Res = TLI.expandPartialReduceMLA(
+ DL, N->getOperand(0), N->getOperand(1), N->getOperand(2), DAG);
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 16c3b295426c6..b3ab68f14df5c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2474,35 +2474,6 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
-SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
- SDValue Op2) {
- EVT FullTy = Op2.getValueType();
-
- unsigned Stride = ReducedTy.getVectorMinNumElements();
- unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
- // Collect all of the subvectors
- std::deque<SDValue> Subvectors = {Op1};
- for (unsigned I = 0; I < ScaleFactor; I++) {
- auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
- Subvectors.push_back(
- getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
- }
-
- // Flatten the subvector tree
- while (Subvectors.size() > 1) {
- Subvectors.push_back(
- getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
- Subvectors.pop_front();
- Subvectors.pop_front();
- }
-
- assert(Subvectors.size() == 1 &&
- "There should only be one subvector after tree flattening");
-
- return Subvectors[0];
-}
-
/// Given a store node \p StoreNode, return true if it is safe to fold that node
/// into \p FPNode, which expands to a library call with output pointers.
static bool canFoldStoreIntoLibCallOutputPointers(StoreSDNode *StoreNode,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 530e1ff3d0af0..bc08b6c5580f7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8125,21 +8125,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
+ SDValue Acc = getValue(I.getOperand(0));
+ EVT AccVT = Acc.getValueType();
+ SDValue Input = getValue(I.getOperand(1));
+ EVT InputVT = Input.getValueType();
+
+ assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
+ "Expected operands to have the same vector element type!");
+ assert(InputVT.getVectorElementCount().getKnownMinValue() %
+ AccVT.getVectorElementCount().getKnownMinValue() ==
+ 0 &&
+ "Expected the element count of the Input operand to be a positive "
+ "integer multiple of the element count of the Accumulator operand!");
if (NewPartialReduceLowering) {
- SDValue Acc = getValue(I.getOperand(0));
- EVT AccVT = Acc.getValueType();
- SDValue Input = getValue(I.getOperand(1));
- EVT InputVT = Input.getValueType();
-
- assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
- "Expected operands to have the same vector element type!");
- assert(
- InputVT.getVectorElementCount().getKnownMinValue() %
- AccVT.getVectorElementCount().getKnownMinValue() ==
- 0 &&
- "Expected the element count of the Input operand to be a positive "
- "integer multiple of the element count of the Accumulator operand!");
-
// ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
// same if ISD::PARTIAL_REDUCE_SMLA was chosen instead. It should be
// changed to its correct signedness when combining or expanding,
@@ -8154,9 +8152,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
- setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
- getValue(I.getOperand(0)),
- getValue(I.getOperand(1))));
+ setValue(&I, TLI.expandPartialReduceMLA(
+ sdl, Acc, Input, DAG.getConstant(1, sdl, InputVT), DAG));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 32115e98e15b0..8f2164c12958c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -34,6 +34,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include <cctype>
+#include <deque>
using namespace llvm;
/// NOTE: The TargetMachine owns TLOF.
@@ -12189,20 +12190,15 @@ SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
return Load;
}
-SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
+SDValue TargetLowering::expandPartialReduceMLA(SDLoc DL, SDValue Acc,
+ SDValue Input1, SDValue Input2,
SelectionDAG &DAG) const {
- SDLoc DL(N);
- SDValue Acc = N->getOperand(0);
- SDValue Input1 = N->getOperand(1);
- SDValue Input2 = N->getOperand(2);
-
EVT ReducedTy = Acc.getValueType();
EVT FullTy = Input1.getValueType();
auto ExtendToAccEltVT = [&](SDValue V) {
- unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
- ? ISD::ZERO_EXTEND
- : ISD::SIGN_EXTEND;
+ unsigned ExtOpc = V->getOpcode() == ISD::SIGN_EXTEND ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND;
EVT ExtVT = V.getValueType().changeVectorElementType(
Acc.getValueType().getVectorElementType());
if (ExtVT != FullTy)
@@ -12224,5 +12220,34 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
Input = ExtendToAccEltVT(Input1);
}
- return DAG.getPartialReduceAdd(DL, ReducedTy, Acc, Input);
+ return TargetLowering::getPartialReduceAdd(DL, ReducedTy, FullTy, Acc, Input,
+ DAG);
+}
+
+SDValue TargetLowering::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, EVT FullTy,
+ SDValue Op1, SDValue Op2,
+ SelectionDAG &DAG) {
+ unsigned Stride = ReducedTy.getVectorMinNumElements();
+ unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+ // Collect all of the subvectors
+ std::deque<SDValue> Subvectors = {Op1};
+ for (unsigned I = 0; I < ScaleFactor; I++) {
+ auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
+ Subvectors.push_back(
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+ }
+
+ // Flatten the subvector tree
+ while (Subvectors.size() > 1) {
+ Subvectors.push_back(
+ DAG.getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+ Subvectors.pop_front();
+ Subvectors.pop_front();
+ }
+
+ assert(Subvectors.size() == 1 &&
+ "There should only be one subvector after tree flattening");
+
+ return Subvectors[0];
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8ea5995f69108..e2bfb6df8cbb0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1358,6 +1358,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BSWAP, VT, Expand);
setOperationAction(ISD::CTTZ, VT, Expand);
+ setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Expand);
+ setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Expand);
+
for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
setTruncStoreAction(VT, InnerVT, Expand);
setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
@@ -22015,8 +22018,10 @@ static SDValue performIntrinsicCombine(SDNode *N,
return Dot;
if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
- return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2));
+ SDValue Input = N->getOperand(2);
+ return TargetLowering::getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+ Input.getValueType(),
+ N->getOperand(1), Input, DAG);
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 9ece9edb84343..4fa10d0e9c0f9 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -2,6 +2,7 @@
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -new-partial-reduce-lowering < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-DOT-LABEL: udot:
@@ -19,6 +20,17 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: umull v3.8h, v2.8b, v1.8b
+; CHECK-NEWLOWERING-NEXT: umull2 v1.8h, v2.16b, v1.16b
+; CHECK-NEWLOWERING-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT: uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NEWLOWERING-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NEWLOWERING-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEWLOWERING-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -46,6 +58,21 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_narrow:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: umull v1.8h, v2.8b, v1.8b
+; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-NEWLOWERING-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -69,6 +96,17 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: smull v3.8h, v2.8b, v1.8b
+; CHECK-NEWLOWERING-NEXT: smull2 v1.8h, v2.16b, v1.16b
+; CHECK-NEWLOWERING-NEXT: sshll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT: saddw v0.4s, v0.4s, v3.4h
+; CHECK-NEWLOWERING-NEXT: saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NEWLOWERING-NEXT: saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEWLOWERING-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -96,6 +134,21 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_narrow:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: smull v1.8h, v2.8b, v1.8b
+; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT: sshll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v3.4s, v1.8h, #0
+; CHECK-NEWLOWERING-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT: saddw v0.4s, v0.4s, v1.4h
+; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: saddw v1.4s, v2.4s, v4.4h
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -121,6 +174,19 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NEWLOWERING-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NEWLOWERING-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -150,6 +216,23 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot_narrow:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT: sshll v2.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NEWLOWERING-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -175,6 +258,19 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NEWLOWERING-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NEWLOWERING-NEXT: add v0.4s, v5.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -204,6 +300,23 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot_narrow:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NEWLOWERING-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT: smlal v0.4s, v2.4h, v1.4h
+; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NEWLOWERING-NEXT: smlal v3.4s, v6.4h, v5.4h
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -237,6 +350,24 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_8to64:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: umull v4.8h, v2.8b, v3.8b
+; CHECK-NEWLOWERING-NEXT: umull2 v2.8h, v2.16b, v3.16b
+; CHECK-NEWLOWERING-NEXT: ushll v3.4s, v4.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll v5.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v4.4s, v4.8h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v2.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NEWLOWERING-NEXT: uaddw v0.2d, v0.2d, v3.2s
+; CHECK-NEWLOWERING-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT: uaddl v4.2d, v4.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NEWLOWERING-NEXT: uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NEWLOWERING-NEXT: add v1.2d, v3.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT: add v0.2d, v4.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
%b.wide = zext <16 x i8> %b to <16 x i64>
@@ -272,6 +403,24 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_8to64:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: smull v4.8h, v2.8b, v3.8b
+; CHECK-NEWLOWERING-NEXT: smull2 v2.8h, v2.16b, v3.16b
+; CHECK-NEWLOWERING-NEXT: sshll v3.4s, v4.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll v5.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v4.4s, v4.8h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v2.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT: saddw2 v1.2d, v1.2d, v3.4s
+; CHECK-NEWLOWERING-NEXT: saddw v0.2d, v0.2d, v3.2s
+; CHECK-NEWLOWERING-NEXT: saddl2 v3.2d, v4.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT: saddl v4.2d, v4.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT: saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NEWLOWERING-NEXT: saddw v0.2d, v0.2d, v2.2s
+; CHECK-NEWLOWERING-NEXT: add v1.2d, v3.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT: add v0.2d, v4.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <16 x i8> %a to <16 x i64>
%b.wide = sext <16 x i8> %b to <16 x i64>
@@ -315,6 +464,32 @@ define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot_8to64:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT: sshll v5.8h, v3.8b, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-NEXT: ushll v6.4s, v4.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll v7.4s, v5.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v4.4s, v4.8h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v5.4s, v5.8h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v16.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v17.4s, v3.8h, #0
+; CHECK-NEWLOWERING-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll v3.4s, v3.4h, #0
+; CHECK-NEWLOWERING-NEXT: smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NEWLOWERING-NEXT: smlal v0.2d, v6.2s, v7.2s
+; CHECK-NEWLOWERING-NEXT: smull v18.2d, v4.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT: smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT: smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NEWLOWERING-NEXT: smlal v0.2d, v16.2s, v17.2s
+; CHECK-NEWLOWERING-NEXT: smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NEWLOWERING-NEXT: smlal v18.2d, v2.2s, v3.2s
+; CHECK-NEWLOWERING-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT: add v0.2d, v18.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
%b.wide = sext <16 x i8> %b to <16 x i64>
@@ -358,6 +533,32 @@ define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-I8MM-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot_8to64:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT: ushll v5.8h, v3.8b, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v3.8h, v3.16b, #0
+; CHECK-NEWLOWERING-NEXT: sshll v6.4s, v4.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll v7.4s, v5.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v4.4s, v4.8h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v5.4s, v5.8h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v16.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v17.4s, v3.8h, #0
+; CHECK-NEWLOWERING-NEXT: sshll v2.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll v3.4s, v3.4h, #0
+; CHECK-NEWLOWERING-NEXT: smlal2 v1.2d, v6.4s, v7.4s
+; CHECK-NEWLOWERING-NEXT: smlal v0.2d, v6.2s, v7.2s
+; CHECK-NEWLOWERING-NEXT: smull v18.2d, v4.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT: smull2 v4.2d, v4.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT: smlal2 v1.2d, v16.4s, v17.4s
+; CHECK-NEWLOWERING-NEXT: smlal v0.2d, v16.2s, v17.2s
+; CHECK-NEWLOWERING-NEXT: smlal2 v4.2d, v2.4s, v3.4s
+; CHECK-NEWLOWERING-NEXT: smlal v18.2d, v2.2s, v3.2s
+; CHECK-NEWLOWERING-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT: add v0.2d, v18.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <16 x i8> %a to <16 x i64>
%b.wide = zext <16 x i8> %b to <16 x i64>
@@ -384,6 +585,17 @@ define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: ushll v2.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-NEXT: ushll v3.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NEWLOWERING-NEXT: uaddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NEWLOWERING-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEWLOWERING-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT: ret
%a.wide = zext <16 x i8> %a to <16 x i32>
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
ret <4 x i32> %partial.reduce
@@ -406,6 +618,17 @@ define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: sshll v2.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NEWLOWERING-NEXT: sshll v3.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT: saddw v0.4s, v0.4s, v2.4h
+; CHECK-NEWLOWERING-NEXT: saddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NEWLOWERING-NEXT: saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEWLOWERING-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NEWLOWERING-NEXT: ret
%a.wide = sext <16 x i8> %a to <16 x i32>
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
ret <4 x i32> %partial.reduce
@@ -432,6 +655,21 @@ define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_narrow:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-NEWLOWERING-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: ret
%a.wide = zext <8 x i8> %a to <8 x i32>
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
ret <2 x i32> %partial.reduce
@@ -458,6 +696,21 @@ define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_narrow:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEWLOWERING-NEXT: sshll v2.4s, v1.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v3.4s, v1.8h, #0
+; CHECK-NEWLOWERING-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEWLOWERING-NEXT: saddw v0.4s, v0.4s, v1.4h
+; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NEWLOWERING-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: saddw v1.4s, v2.4s, v4.4h
+; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEWLOWERING-NEXT: ret
%a.wide = sext <8 x i8> %a to <8 x i32>
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
ret <2 x i32> %partial.reduce
@@ -490,6 +743,24 @@ define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: ushll v3.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT: ushll v4.4s, v3.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll v5.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v3.4s, v3.8h, #0
+; CHECK-NEWLOWERING-NEXT: ushll2 v2.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT: uaddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NEWLOWERING-NEXT: uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-NEXT: uaddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT: uaddl v3.2d, v3.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NEWLOWERING-NEXT: uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NEWLOWERING-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT: add v0.2d, v3.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT: ret
%a.wide = zext <16 x i8> %a to <16 x i64>
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
ret <4 x i64> %partial.reduce
@@ -522,6 +793,24 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
; CHECK-NODOT-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NEWLOWERING: // %bb.0:
+; CHECK-NEWLOWERING-NEXT: sshll v3.8h, v2.8b, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NEWLOWERING-NEXT: sshll v4.4s, v3.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll v5.4s, v2.4h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v3.4s, v3.8h, #0
+; CHECK-NEWLOWERING-NEXT: sshll2 v2.4s, v2.8h, #0
+; CHECK-NEWLOWERING-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NEWLOWERING-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-NEWLOWERING-NEXT: saddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NEWLOWERING-NEXT: saddl v3.2d, v3.2s, v5.2s
+; CHECK-NEWLOWERING-NEXT: saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NEWLOWERING-NEXT: saddw v0.2d, v0.2d, v2.2s
+; CHECK-NEWLOWERING-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NEWLOWERING-NEXT: add v0.2d, v3.2d, v0.2d
+; CHECK-NEWLOWERING-NEXT: ret
%a.wide = sext <16 x i8> %a to <16 x i64>
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
ret <4 x i64> %partial.reduce
>From 3161c0e24237f403e1b2f116394ce2b60bc959fb Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 4 Feb 2025 14:54:40 +0000
Subject: [PATCH 04/12] Combine two functions together
---
llvm/include/llvm/CodeGen/TargetLowering.h | 6 ----
.../CodeGen/SelectionDAG/TargetLowering.cpp | 32 ++++++-------------
.../Target/AArch64/AArch64ISelLowering.cpp | 8 +++--
3 files changed, 14 insertions(+), 32 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index cefd12639c5a8..74380f41a5cec 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5639,12 +5639,6 @@ class TargetLowering : public TargetLoweringBase {
SDValue expandPartialReduceMLA(SDLoc DL, SDValue Acc, SDValue Input1,
SDValue Input2, SelectionDAG &DAG) const;
- // Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
- // its operands and ReducedTY is the return type.
- static SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, EVT FullTy,
- SDValue Op1, SDValue Op2,
- SelectionDAG &DAG);
-
private:
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
const SDLoc &DL, DAGCombinerInfo &DCI) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 8f2164c12958c..5b330487a9285 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12206,36 +12206,22 @@ SDValue TargetLowering::expandPartialReduceMLA(SDLoc DL, SDValue Acc,
return V;
};
- SDValue Input;
- APInt ConstantOne;
- if (!ISD::isConstantSplatVector(Input2.getNode(), ConstantOne) ||
- !ConstantOne.isOne()) {
- EVT NewVT =
- EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
- FullTy.getVectorElementCount());
- Input1 = ExtendToAccEltVT(Input1);
- Input2 = ExtendToAccEltVT(Input2);
- Input = DAG.getNode(ISD::MUL, DL, NewVT, Input1, Input2);
- } else {
- Input = ExtendToAccEltVT(Input1);
- }
+ EVT NewVT =
+ EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
+ FullTy.getVectorElementCount());
+ Input1 = ExtendToAccEltVT(Input1);
+ Input2 = ExtendToAccEltVT(Input2);
+ SDValue Input = DAG.getNode(ISD::MUL, DL, NewVT, Input1, Input2);
- return TargetLowering::getPartialReduceAdd(DL, ReducedTy, FullTy, Acc, Input,
- DAG);
-}
-
-SDValue TargetLowering::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, EVT FullTy,
- SDValue Op1, SDValue Op2,
- SelectionDAG &DAG) {
unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
// Collect all of the subvectors
- std::deque<SDValue> Subvectors = {Op1};
+ std::deque<SDValue> Subvectors = {Acc};
for (unsigned I = 0; I < ScaleFactor; I++) {
auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
- Subvectors.push_back(
- DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+ Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
+ {Input, SourceIndex}));
}
// Flatten the subvector tree
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e2bfb6df8cbb0..4c66690b11c70 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22018,10 +22018,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
return Dot;
if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ SDLoc DL(N);
SDValue Input = N->getOperand(2);
- return TargetLowering::getPartialReduceAdd(SDLoc(N), N->getValueType(0),
- Input.getValueType(),
- N->getOperand(1), Input, DAG);
+ return TLI.expandPartialReduceMLA(
+ DL, N->getOperand(1), Input,
+ DAG.getConstant(1, DL, Input.getValueType()), DAG);
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
>From a9b13c7b5db3f30d2ec942b28c4458afc8a92fd9 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 5 Feb 2025 13:06:56 +0000
Subject: [PATCH 05/12] Address comments on PR. Involves changing arguments on
the function.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 2 +-
llvm/include/llvm/CodeGen/TargetLowering.h | 9 +-
.../SelectionDAG/LegalizeIntegerTypes.cpp | 3 +-
.../SelectionDAG/LegalizeVectorOps.cpp | 4 +-
.../SelectionDAG/LegalizeVectorTypes.cpp | 6 +-
.../SelectionDAG/SelectionDAGBuilder.cpp | 10 +-
.../CodeGen/SelectionDAG/TargetLowering.cpp | 100 ++---
.../Target/AArch64/AArch64ISelLowering.cpp | 7 +-
.../neon-partial-reduce-dot-product.ll | 290 +------------
.../AArch64/sve-partial-reduce-dot-product.ll | 406 +++++++++++-------
10 files changed, 316 insertions(+), 521 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 422a70bb6641b..efeb63df4f143 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1452,7 +1452,7 @@ enum NodeType {
VECREDUCE_UMIN,
// PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
- // The partial reduction nodes sign-or zero extend Input1 and Input2 to
+ // The partial reduction nodes sign or zero extend Input1 and Input2 to
// the element type of Accumulator before multiplying their results.
// The multiplied result is then reduced using addition to the result
// type of Accumulator. The result is added to Accumulator and returned.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 74380f41a5cec..b6358776d9891 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -5539,6 +5539,10 @@ class TargetLowering : public TargetLoweringBase {
/// temporarily, advance store position, before re-loading the final vector.
SDValue expandVECTOR_COMPRESS(SDNode *Node, SelectionDAG &DAG) const;
+ /// Expands PARTIAL_REDUCE_S/UMLA nodes to a series of simpler operations,
+ /// consisting of zext/sext, extract_subvector, mul and add operations.
+ SDValue expandPartialReduceMLA(SDNode *Node, SelectionDAG &DAG) const;
+
/// Legalize a SETCC or VP_SETCC with given LHS and RHS and condition code CC
/// on the current target. A VP_SETCC will additionally be given a Mask
/// and/or EVL not equal to SDValue().
@@ -5634,11 +5638,6 @@ class TargetLowering : public TargetLoweringBase {
LoadSDNode *OriginalLoad,
SelectionDAG &DAG) const;
- // Expands PARTIAL_REDUCE_S/UMLA nodes to a series of simpler operations,
- // consisting of zext/sext, extract_subvector, mul and add operations.
- SDValue expandPartialReduceMLA(SDLoc DL, SDValue Acc, SDValue Input1,
- SDValue Input2, SelectionDAG &DAG) const;
-
private:
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
const SDLoc &DL, DAGCombinerInfo &DCI) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index cbb061c9e9403..8e4f90ff9c6c4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -6217,9 +6217,8 @@ SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
EVT VT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
SDValue ExtAcc = GetPromotedInteger(N->getOperand(0));
- SDValue V = DAG.getNode(N->getOpcode(), DL, NVT, ExtAcc, N->getOperand(1),
+ return DAG.getNode(N->getOpcode(), DL, NVT, ExtAcc, N->getOperand(1),
N->getOperand(2));
- return V;
}
SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index da86996e88ed8..85524f1d83896 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1200,9 +1200,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
SDLoc DL(Node);
- Results.push_back(TLI.expandPartialReduceMLA(DL, Node->getOperand(0),
- Node->getOperand(1),
- Node->getOperand(2), DAG));
+ Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
}
case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index ed6e1b57542af..b1b2957efc4e3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3187,8 +3187,7 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);
- SDValue Res = TLI.expandPartialReduceMLA(
- DL, N->getOperand(0), N->getOperand(1), N->getOperand(2), DAG);
+ SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
}
@@ -4450,8 +4449,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);
- SDValue Res = TLI.expandPartialReduceMLA(
- DL, N->getOperand(0), N->getOperand(1), N->getOperand(2), DAG);
+ SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index bc08b6c5580f7..5188b85adade1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8137,13 +8137,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
0 &&
"Expected the element count of the Input operand to be a positive "
"integer multiple of the element count of the Accumulator operand!");
+ SDValue PRVal = DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc,
+ Input, DAG.getConstant(1, sdl, InputVT));
+
if (NewPartialReduceLowering) {
// ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
// same if ISD::PARTIAL_REDUCE_SMLA was chosen instead. It should be
// changed to its correct signedness when combining or expanding,
// according to extends being performed on Input.
- setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc, Input,
- DAG.getConstant(1, sdl, InputVT)));
+ setValue(&I, PRVal);
return;
}
@@ -8151,9 +8153,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
visitTargetIntrinsic(I, Intrinsic);
return;
}
-
- setValue(&I, TLI.expandPartialReduceMLA(
- sdl, Acc, Input, DAG.getConstant(1, sdl, InputVT), DAG));
+ setValue(&I, TLI.expandPartialReduceMLA(PRVal.getNode(), DAG));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 5b330487a9285..c2fb6a07cbd18 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11891,6 +11891,58 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
return DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
}
+SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Input1 = N->getOperand(1);
+  SDValue Input2 = N->getOperand(2);
+  EVT ReducedTy = Acc.getValueType();
+  EVT FullTy = Input1.getValueType();
+
+  // The node's signedness (not the opcode of each operand) selects how the
+  // inputs are widened to Acc's element type, per the ISD node definition.
+  unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
+                        ? ISD::SIGN_EXTEND
+                        : ISD::ZERO_EXTEND;
+  auto ExtendToAccEltVT = [&](SDValue V) {
+    EVT ExtVT = V.getValueType().changeVectorElementType(
+        ReducedTy.getVectorElementType());
+    if (ExtVT != V.getValueType())
+      return DAG.getNode(ExtOpc, DL, ExtVT, V);
+    return V;
+  };
+
+  EVT NewVT =
+      EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
+                       FullTy.getVectorElementCount());
+  Input1 = ExtendToAccEltVT(Input1);
+  Input2 = ExtendToAccEltVT(Input2);
+  SDValue Input = DAG.getNode(ISD::MUL, DL, NewVT, Input1, Input2);
+
+  // Split the product into Acc-sized subvectors; they are summed with Acc.
+  unsigned Stride = ReducedTy.getVectorMinNumElements();
+  unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+  std::deque<SDValue> Subvectors = {Acc};
+  for (unsigned I = 0; I < ScaleFactor; I++) {
+    auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
+    Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
+                                     {Input, SourceIndex}));
+  }
+
+  // Flatten the subvector tree by repeatedly adding the two oldest entries.
+  while (Subvectors.size() > 1) {
+    Subvectors.push_back(
+        DAG.getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+    Subvectors.pop_front();
+    Subvectors.pop_front();
+  }
+
+  assert(Subvectors.size() == 1 &&
+         "There should only be one subvector after tree flattening");
+  return Subvectors[0];
+}
+
bool TargetLowering::LegalizeSetCCCondCode(SelectionDAG &DAG, EVT VT,
SDValue &LHS, SDValue &RHS,
SDValue &CC, SDValue Mask,
@@ -12189,51 +12241,3 @@ SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
return Load;
}
-
-SDValue TargetLowering::expandPartialReduceMLA(SDLoc DL, SDValue Acc,
- SDValue Input1, SDValue Input2,
- SelectionDAG &DAG) const {
- EVT ReducedTy = Acc.getValueType();
- EVT FullTy = Input1.getValueType();
-
- auto ExtendToAccEltVT = [&](SDValue V) {
- unsigned ExtOpc = V->getOpcode() == ISD::SIGN_EXTEND ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND;
- EVT ExtVT = V.getValueType().changeVectorElementType(
- Acc.getValueType().getVectorElementType());
- if (ExtVT != FullTy)
- return DAG.getNode(ExtOpc, DL, ExtVT, V);
- return V;
- };
-
- EVT NewVT =
- EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
- FullTy.getVectorElementCount());
- Input1 = ExtendToAccEltVT(Input1);
- Input2 = ExtendToAccEltVT(Input2);
- SDValue Input = DAG.getNode(ISD::MUL, DL, NewVT, Input1, Input2);
-
- unsigned Stride = ReducedTy.getVectorMinNumElements();
- unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
- // Collect all of the subvectors
- std::deque<SDValue> Subvectors = {Acc};
- for (unsigned I = 0; I < ScaleFactor; I++) {
- auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
- Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
- {Input, SourceIndex}));
- }
-
- // Flatten the subvector tree
- while (Subvectors.size() > 1) {
- Subvectors.push_back(
- DAG.getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
- Subvectors.pop_front();
- Subvectors.pop_front();
- }
-
- assert(Subvectors.size() == 1 &&
- "There should only be one subvector after tree flattening");
-
- return Subvectors[0];
-}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4c66690b11c70..0baa873295ced 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22021,9 +22021,10 @@ static SDValue performIntrinsicCombine(SDNode *N,
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDLoc DL(N);
SDValue Input = N->getOperand(2);
- return TLI.expandPartialReduceMLA(
- DL, N->getOperand(1), Input,
- DAG.getConstant(1, DL, Input.getValueType()), DAG);
+ SDValue PRVal = DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL,
+ N->getValueType(0), N->getOperand(1), Input,
+ DAG.getConstant(1, DL, Input.getValueType()));
+ return TLI.expandPartialReduceMLA(PRVal.getNode(), DAG);
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 4fa10d0e9c0f9..60dd825b2d093 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -2,7 +2,7 @@
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -new-partial-reduce-lowering < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -new-partial-reduce-lowering < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-DOT-LABEL: udot:
@@ -20,17 +20,6 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: udot:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: umull v3.8h, v2.8b, v1.8b
-; CHECK-NEWLOWERING-NEXT: umull2 v1.8h, v2.16b, v1.16b
-; CHECK-NEWLOWERING-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-NEWLOWERING-NEXT: uaddw v0.4s, v0.4s, v3.4h
-; CHECK-NEWLOWERING-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
-; CHECK-NEWLOWERING-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NEWLOWERING-NEXT: add v0.4s, v2.4s, v0.4s
-; CHECK-NEWLOWERING-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -58,21 +47,6 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: udot_narrow:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: umull v1.8h, v2.8b, v1.8b
-; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v3.4s, v1.8h, #0
-; CHECK-NEWLOWERING-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-NEXT: uaddw v0.4s, v0.4s, v1.4h
-; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-NEXT: ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: uaddw v1.4s, v2.4s, v4.4h
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -96,17 +70,6 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sdot:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: smull v3.8h, v2.8b, v1.8b
-; CHECK-NEWLOWERING-NEXT: smull2 v1.8h, v2.16b, v1.16b
-; CHECK-NEWLOWERING-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-NEWLOWERING-NEXT: saddw v0.4s, v0.4s, v3.4h
-; CHECK-NEWLOWERING-NEXT: saddw2 v2.4s, v2.4s, v3.8h
-; CHECK-NEWLOWERING-NEXT: saddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NEWLOWERING-NEXT: add v0.4s, v2.4s, v0.4s
-; CHECK-NEWLOWERING-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -134,21 +97,6 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sdot_narrow:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: smull v1.8h, v2.8b, v1.8b
-; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v3.4s, v1.8h, #0
-; CHECK-NEWLOWERING-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-NEXT: saddw v0.4s, v0.4s, v1.4h
-; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-NEXT: ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: saddw v1.4s, v2.4s, v4.4h
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -174,19 +122,6 @@ define <4 x i32> @usdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.4s, v1.16b, v2.16b
; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: usdot:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: ushll v3.8h, v1.8b, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v1.8h, v1.16b, #0
-; CHECK-NEWLOWERING-NEXT: sshll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-NEXT: smlal v0.4s, v4.4h, v3.4h
-; CHECK-NEWLOWERING-NEXT: smull v5.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-NEXT: smlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NEWLOWERING-NEXT: smlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NEWLOWERING-NEXT: add v0.4s, v5.4s, v0.4s
-; CHECK-NEWLOWERING-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -216,23 +151,6 @@ define <2 x i32> @usdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.2s, v1.8b, v2.8b
; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: usdot_narrow:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-NEWLOWERING-NEXT: sshll v2.8h, v2.8b, #0
-; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-NEXT: smull v3.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-NEXT: smull2 v4.4s, v2.8h, v1.8h
-; CHECK-NEWLOWERING-NEXT: ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-NEXT: ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-NEXT: ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NEWLOWERING-NEXT: smlal v3.4s, v6.4h, v5.4h
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -258,19 +176,6 @@ define <4 x i32> @sudot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.4s, v2.16b, v1.16b
; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sudot:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sshll v3.8h, v1.8b, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v1.8h, v1.16b, #0
-; CHECK-NEWLOWERING-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-NEXT: smlal v0.4s, v4.4h, v3.4h
-; CHECK-NEWLOWERING-NEXT: smull v5.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-NEXT: smlal2 v0.4s, v2.8h, v1.8h
-; CHECK-NEWLOWERING-NEXT: smlal2 v5.4s, v4.8h, v3.8h
-; CHECK-NEWLOWERING-NEXT: add v0.4s, v5.4s, v0.4s
-; CHECK-NEWLOWERING-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
@@ -300,23 +205,6 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-I8MM: // %bb.0:
; CHECK-I8MM-NEXT: usdot v0.2s, v2.8b, v1.8b
; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sudot_narrow:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sshll v1.8h, v1.8b, #0
-; CHECK-NEWLOWERING-NEXT: ushll v2.8h, v2.8b, #0
-; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-NEXT: smull v3.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-NEXT: smull2 v4.4s, v2.8h, v1.8h
-; CHECK-NEWLOWERING-NEXT: ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-NEXT: ext v6.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-NEXT: smlal v0.4s, v2.4h, v1.4h
-; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-NEXT: ext v1.16b, v4.16b, v4.16b, #8
-; CHECK-NEWLOWERING-NEXT: smlal v3.4s, v6.4h, v5.4h
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
@@ -350,24 +238,6 @@ define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: udot_8to64:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: umull v4.8h, v2.8b, v3.8b
-; CHECK-NEWLOWERING-NEXT: umull2 v2.8h, v2.16b, v3.16b
-; CHECK-NEWLOWERING-NEXT: ushll v3.4s, v4.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll v5.4s, v2.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v4.4s, v4.8h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
-; CHECK-NEWLOWERING-NEXT: uaddw v0.2d, v0.2d, v3.2s
-; CHECK-NEWLOWERING-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
-; CHECK-NEWLOWERING-NEXT: uaddl v4.2d, v4.2s, v5.2s
-; CHECK-NEWLOWERING-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
-; CHECK-NEWLOWERING-NEXT: uaddw v0.2d, v0.2d, v2.2s
-; CHECK-NEWLOWERING-NEXT: add v1.2d, v3.2d, v1.2d
-; CHECK-NEWLOWERING-NEXT: add v0.2d, v4.2d, v0.2d
-; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
%b.wide = zext <16 x i8> %b to <16 x i64>
@@ -403,24 +273,6 @@ define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sdot_8to64:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: smull v4.8h, v2.8b, v3.8b
-; CHECK-NEWLOWERING-NEXT: smull2 v2.8h, v2.16b, v3.16b
-; CHECK-NEWLOWERING-NEXT: sshll v3.4s, v4.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll v5.4s, v2.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v4.4s, v4.8h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-NEXT: saddw2 v1.2d, v1.2d, v3.4s
-; CHECK-NEWLOWERING-NEXT: saddw v0.2d, v0.2d, v3.2s
-; CHECK-NEWLOWERING-NEXT: saddl2 v3.2d, v4.4s, v5.4s
-; CHECK-NEWLOWERING-NEXT: saddl v4.2d, v4.2s, v5.2s
-; CHECK-NEWLOWERING-NEXT: saddw2 v1.2d, v1.2d, v2.4s
-; CHECK-NEWLOWERING-NEXT: saddw v0.2d, v0.2d, v2.2s
-; CHECK-NEWLOWERING-NEXT: add v1.2d, v3.2d, v1.2d
-; CHECK-NEWLOWERING-NEXT: add v0.2d, v4.2d, v0.2d
-; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <16 x i8> %a to <16 x i64>
%b.wide = sext <16 x i8> %b to <16 x i64>
@@ -464,32 +316,6 @@ define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: usdot_8to64:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: ushll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-NEXT: sshll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v3.8h, v3.16b, #0
-; CHECK-NEWLOWERING-NEXT: ushll v6.4s, v4.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll v7.4s, v5.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v4.4s, v4.8h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v5.4s, v5.8h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v16.4s, v2.8h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v17.4s, v3.8h, #0
-; CHECK-NEWLOWERING-NEXT: ushll v2.4s, v2.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll v3.4s, v3.4h, #0
-; CHECK-NEWLOWERING-NEXT: smlal2 v1.2d, v6.4s, v7.4s
-; CHECK-NEWLOWERING-NEXT: smlal v0.2d, v6.2s, v7.2s
-; CHECK-NEWLOWERING-NEXT: smull v18.2d, v4.2s, v5.2s
-; CHECK-NEWLOWERING-NEXT: smull2 v4.2d, v4.4s, v5.4s
-; CHECK-NEWLOWERING-NEXT: smlal2 v1.2d, v16.4s, v17.4s
-; CHECK-NEWLOWERING-NEXT: smlal v0.2d, v16.2s, v17.2s
-; CHECK-NEWLOWERING-NEXT: smlal2 v4.2d, v2.4s, v3.4s
-; CHECK-NEWLOWERING-NEXT: smlal v18.2d, v2.2s, v3.2s
-; CHECK-NEWLOWERING-NEXT: add v1.2d, v4.2d, v1.2d
-; CHECK-NEWLOWERING-NEXT: add v0.2d, v18.2d, v0.2d
-; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
%b.wide = sext <16 x i8> %b to <16 x i64>
@@ -533,32 +359,6 @@ define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sudot_8to64:
-; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sshll v4.8h, v2.8b, #0
-; CHECK-NEWLOWERING-NEXT: ushll v5.8h, v3.8b, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v3.8h, v3.16b, #0
-; CHECK-NEWLOWERING-NEXT: sshll v6.4s, v4.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll v7.4s, v5.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v4.4s, v4.8h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v5.4s, v5.8h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v16.4s, v2.8h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v17.4s, v3.8h, #0
-; CHECK-NEWLOWERING-NEXT: sshll v2.4s, v2.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll v3.4s, v3.4h, #0
-; CHECK-NEWLOWERING-NEXT: smlal2 v1.2d, v6.4s, v7.4s
-; CHECK-NEWLOWERING-NEXT: smlal v0.2d, v6.2s, v7.2s
-; CHECK-NEWLOWERING-NEXT: smull v18.2d, v4.2s, v5.2s
-; CHECK-NEWLOWERING-NEXT: smull2 v4.2d, v4.4s, v5.4s
-; CHECK-NEWLOWERING-NEXT: smlal2 v1.2d, v16.4s, v17.4s
-; CHECK-NEWLOWERING-NEXT: smlal v0.2d, v16.2s, v17.2s
-; CHECK-NEWLOWERING-NEXT: smlal2 v4.2d, v2.4s, v3.4s
-; CHECK-NEWLOWERING-NEXT: smlal v18.2d, v2.2s, v3.2s
-; CHECK-NEWLOWERING-NEXT: add v1.2d, v4.2d, v1.2d
-; CHECK-NEWLOWERING-NEXT: add v0.2d, v18.2d, v0.2d
-; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <16 x i8> %a to <16 x i64>
%b.wide = zext <16 x i8> %b to <16 x i64>
@@ -585,17 +385,6 @@ define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: ushll v2.8h, v1.8b, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v1.8h, v1.16b, #0
-; CHECK-NEWLOWERING-NEXT: ushll v3.4s, v1.4h, #0
-; CHECK-NEWLOWERING-NEXT: uaddw v0.4s, v0.4s, v2.4h
-; CHECK-NEWLOWERING-NEXT: uaddw2 v2.4s, v3.4s, v2.8h
-; CHECK-NEWLOWERING-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NEWLOWERING-NEXT: add v0.4s, v2.4s, v0.4s
-; CHECK-NEWLOWERING-NEXT: ret
%a.wide = zext <16 x i8> %a to <16 x i32>
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
ret <4 x i32> %partial.reduce
@@ -618,17 +407,6 @@ define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sshll v2.8h, v1.8b, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v1.8h, v1.16b, #0
-; CHECK-NEWLOWERING-NEXT: sshll v3.4s, v1.4h, #0
-; CHECK-NEWLOWERING-NEXT: saddw v0.4s, v0.4s, v2.4h
-; CHECK-NEWLOWERING-NEXT: saddw2 v2.4s, v3.4s, v2.8h
-; CHECK-NEWLOWERING-NEXT: saddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NEWLOWERING-NEXT: add v0.4s, v2.4s, v0.4s
-; CHECK-NEWLOWERING-NEXT: ret
%a.wide = sext <16 x i8> %a to <16 x i32>
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
ret <4 x i32> %partial.reduce
@@ -655,21 +433,6 @@ define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_narrow:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: ushll v1.8h, v1.8b, #0
-; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v3.4s, v1.8h, #0
-; CHECK-NEWLOWERING-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-NEXT: uaddw v0.4s, v0.4s, v1.4h
-; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-NEXT: ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: uaddw v1.4s, v2.4s, v4.4h
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: ret
%a.wide = zext <8 x i8> %a to <8 x i32>
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
ret <2 x i32> %partial.reduce
@@ -696,21 +459,6 @@ define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_narrow:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sshll v1.8h, v1.8b, #0
-; CHECK-NEWLOWERING-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NEWLOWERING-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v3.4s, v1.8h, #0
-; CHECK-NEWLOWERING-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEWLOWERING-NEXT: saddw v0.4s, v0.4s, v1.4h
-; CHECK-NEWLOWERING-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NEWLOWERING-NEXT: ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: saddw v1.4s, v2.4s, v4.4h
-; CHECK-NEWLOWERING-NEXT: add v0.2s, v1.2s, v0.2s
-; CHECK-NEWLOWERING-NEXT: ret
%a.wide = sext <8 x i8> %a to <8 x i32>
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
ret <2 x i32> %partial.reduce
@@ -743,24 +491,6 @@ define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: ushll v3.8h, v2.8b, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-NEXT: ushll v4.4s, v3.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll v5.4s, v2.4h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v3.4s, v3.8h, #0
-; CHECK-NEWLOWERING-NEXT: ushll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-NEXT: uaddw2 v1.2d, v1.2d, v4.4s
-; CHECK-NEWLOWERING-NEXT: uaddw v0.2d, v0.2d, v4.2s
-; CHECK-NEWLOWERING-NEXT: uaddl2 v4.2d, v3.4s, v5.4s
-; CHECK-NEWLOWERING-NEXT: uaddl v3.2d, v3.2s, v5.2s
-; CHECK-NEWLOWERING-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
-; CHECK-NEWLOWERING-NEXT: uaddw v0.2d, v0.2d, v2.2s
-; CHECK-NEWLOWERING-NEXT: add v1.2d, v4.2d, v1.2d
-; CHECK-NEWLOWERING-NEXT: add v0.2d, v3.2d, v0.2d
-; CHECK-NEWLOWERING-NEXT: ret
%a.wide = zext <16 x i8> %a to <16 x i64>
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
ret <4 x i64> %partial.reduce
@@ -793,24 +523,6 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
; CHECK-NODOT-NEXT: ret
-;
-; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
-; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sshll v3.8h, v2.8b, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v2.8h, v2.16b, #0
-; CHECK-NEWLOWERING-NEXT: sshll v4.4s, v3.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll v5.4s, v2.4h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v3.4s, v3.8h, #0
-; CHECK-NEWLOWERING-NEXT: sshll2 v2.4s, v2.8h, #0
-; CHECK-NEWLOWERING-NEXT: saddw2 v1.2d, v1.2d, v4.4s
-; CHECK-NEWLOWERING-NEXT: saddw v0.2d, v0.2d, v4.2s
-; CHECK-NEWLOWERING-NEXT: saddl2 v4.2d, v3.4s, v5.4s
-; CHECK-NEWLOWERING-NEXT: saddl v3.2d, v3.2s, v5.2s
-; CHECK-NEWLOWERING-NEXT: saddw2 v1.2d, v1.2d, v2.4s
-; CHECK-NEWLOWERING-NEXT: saddw v0.2d, v0.2d, v2.2s
-; CHECK-NEWLOWERING-NEXT: add v1.2d, v4.2d, v1.2d
-; CHECK-NEWLOWERING-NEXT: add v0.2d, v3.2d, v0.2d
-; CHECK-NEWLOWERING-NEXT: ret
%a.wide = sext <16 x i8> %a to <16 x i64>
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
ret <4 x i64> %partial.reduce
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 16c0001dbdb83..b66b952b224ab 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,18 +1,13 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
-; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-I8MM-LABEL: udot:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: udot z0.s, z1.b, z2.b
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: udot:
-; CHECK-NOI8MM: // %bb.0: // %entry
-; CHECK-NOI8MM-NEXT: udot z0.s, z1.b, z2.b
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: udot:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: udot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
@@ -45,15 +40,10 @@ entry:
}
define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-I8MM-LABEL: udot_wide:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: udot z0.d, z1.h, z2.h
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: udot_wide:
-; CHECK-NOI8MM: // %bb.0: // %entry
-; CHECK-NOI8MM-NEXT: udot z0.d, z1.h, z2.h
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: udot_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: udot_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
@@ -86,15 +76,10 @@ entry:
}
define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-I8MM-LABEL: sdot:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: sdot z0.s, z1.b, z2.b
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: sdot:
-; CHECK-NOI8MM: // %bb.0: // %entry
-; CHECK-NOI8MM-NEXT: sdot z0.s, z1.b, z2.b
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: sdot:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: sdot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
@@ -127,15 +112,10 @@ entry:
}
define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
-; CHECK-I8MM-LABEL: sdot_wide:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: sdot z0.d, z1.h, z2.h
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: sdot_wide:
-; CHECK-NOI8MM: // %bb.0: // %entry
-; CHECK-NOI8MM-NEXT: sdot z0.d, z1.h, z2.h
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: sdot_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: sdot_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
@@ -286,25 +266,15 @@ entry:
}
define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-I8MM-LABEL: udot_8to64:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
-; CHECK-I8MM-NEXT: udot z4.s, z2.b, z3.b
-; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
-; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
-; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
-; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: udot_8to64:
-; CHECK-NOI8MM: // %bb.0: // %entry
-; CHECK-NOI8MM-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NOI8MM-NEXT: udot z4.s, z2.b, z3.b
-; CHECK-NOI8MM-NEXT: sunpklo z2.d, z4.s
-; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z4.s
-; CHECK-NOI8MM-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NOI8MM-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: udot_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: udot_8to64:
; CHECK-NEWLOWERING: // %bb.0: // %entry
@@ -372,25 +342,15 @@ entry:
}
define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b){
-; CHECK-I8MM-LABEL: sdot_8to64:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
-; CHECK-I8MM-NEXT: sdot z4.s, z2.b, z3.b
-; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
-; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
-; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
-; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: sdot_8to64:
-; CHECK-NOI8MM: // %bb.0: // %entry
-; CHECK-NOI8MM-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NOI8MM-NEXT: sdot z4.s, z2.b, z3.b
-; CHECK-NOI8MM-NEXT: sunpklo z2.d, z4.s
-; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z4.s
-; CHECK-NOI8MM-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NOI8MM-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: sdot_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: sdot_8to64:
; CHECK-NEWLOWERING: // %bb.0: // %entry
@@ -724,17 +684,11 @@ entry:
}
define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
-; CHECK-I8MM-LABEL: udot_no_bin_op:
-; CHECK-I8MM: // %bb.0:
-; CHECK-I8MM-NEXT: mov z2.b, #1 // =0x1
-; CHECK-I8MM-NEXT: udot z0.s, z1.b, z2.b
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: udot_no_bin_op:
-; CHECK-NOI8MM: // %bb.0:
-; CHECK-NOI8MM-NEXT: mov z2.b, #1 // =0x1
-; CHECK-NOI8MM-NEXT: udot z0.s, z1.b, z2.b
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: udot_no_bin_op:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
; CHECK-NEWLOWERING: // %bb.0:
@@ -755,17 +709,11 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
}
define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
-; CHECK-I8MM-LABEL: sdot_no_bin_op:
-; CHECK-I8MM: // %bb.0:
-; CHECK-I8MM-NEXT: mov z2.b, #1 // =0x1
-; CHECK-I8MM-NEXT: sdot z0.s, z1.b, z2.b
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: sdot_no_bin_op:
-; CHECK-NOI8MM: // %bb.0:
-; CHECK-NOI8MM-NEXT: mov z2.b, #1 // =0x1
-; CHECK-NOI8MM-NEXT: sdot z0.s, z1.b, z2.b
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: sdot_no_bin_op:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
; CHECK-NEWLOWERING: // %bb.0:
@@ -786,17 +734,11 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
}
define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
-; CHECK-I8MM-LABEL: udot_no_bin_op_wide:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: mov z2.h, #1 // =0x1
-; CHECK-I8MM-NEXT: udot z0.d, z1.h, z2.h
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: udot_no_bin_op_wide:
-; CHECK-NOI8MM: // %bb.0: // %entry
-; CHECK-NOI8MM-NEXT: mov z2.h, #1 // =0x1
-; CHECK-NOI8MM-NEXT: udot z0.d, z1.h, z2.h
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: udot_no_bin_op_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
@@ -818,17 +760,11 @@ entry:
}
define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
-; CHECK-I8MM-LABEL: sdot_no_bin_op_wide:
-; CHECK-I8MM: // %bb.0: // %entry
-; CHECK-I8MM-NEXT: mov z2.h, #1 // =0x1
-; CHECK-I8MM-NEXT: sdot z0.d, z1.h, z2.h
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: sdot_no_bin_op_wide:
-; CHECK-NOI8MM: // %bb.0: // %entry
-; CHECK-NOI8MM-NEXT: mov z2.h, #1 // =0x1
-; CHECK-NOI8MM-NEXT: sdot z0.d, z1.h, z2.h
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: sdot_no_bin_op_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
@@ -850,27 +786,16 @@ entry:
}
define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
-; CHECK-I8MM-LABEL: udot_no_bin_op_8to64:
-; CHECK-I8MM: // %bb.0:
-; CHECK-I8MM-NEXT: mov z3.b, #1 // =0x1
-; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
-; CHECK-I8MM-NEXT: udot z4.s, z2.b, z3.b
-; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
-; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
-; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
-; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: udot_no_bin_op_8to64:
-; CHECK-NOI8MM: // %bb.0:
-; CHECK-NOI8MM-NEXT: mov z3.b, #1 // =0x1
-; CHECK-NOI8MM-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NOI8MM-NEXT: udot z4.s, z2.b, z3.b
-; CHECK-NOI8MM-NEXT: sunpklo z2.d, z4.s
-; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z4.s
-; CHECK-NOI8MM-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NOI8MM-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: udot_no_bin_op_8to64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
; CHECK-NEWLOWERING: // %bb.0:
@@ -903,27 +828,16 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
}
define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
-; CHECK-I8MM-LABEL: sdot_no_bin_op_8to64:
-; CHECK-I8MM: // %bb.0:
-; CHECK-I8MM-NEXT: mov z3.b, #1 // =0x1
-; CHECK-I8MM-NEXT: mov z4.s, #0 // =0x0
-; CHECK-I8MM-NEXT: sdot z4.s, z2.b, z3.b
-; CHECK-I8MM-NEXT: sunpklo z2.d, z4.s
-; CHECK-I8MM-NEXT: sunpkhi z3.d, z4.s
-; CHECK-I8MM-NEXT: add z0.d, z0.d, z2.d
-; CHECK-I8MM-NEXT: add z1.d, z1.d, z3.d
-; CHECK-I8MM-NEXT: ret
-;
-; CHECK-NOI8MM-LABEL: sdot_no_bin_op_8to64:
-; CHECK-NOI8MM: // %bb.0:
-; CHECK-NOI8MM-NEXT: mov z3.b, #1 // =0x1
-; CHECK-NOI8MM-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NOI8MM-NEXT: sdot z4.s, z2.b, z3.b
-; CHECK-NOI8MM-NEXT: sunpklo z2.d, z4.s
-; CHECK-NOI8MM-NEXT: sunpkhi z3.d, z4.s
-; CHECK-NOI8MM-NEXT: add z0.d, z0.d, z2.d
-; CHECK-NOI8MM-NEXT: add z1.d, z1.d, z3.d
-; CHECK-NOI8MM-NEXT: ret
+; CHECK-LABEL: sdot_no_bin_op_8to64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
;
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
; CHECK-NEWLOWERING: // %bb.0:
@@ -968,6 +882,19 @@ define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %
; CHECK-NEXT: mla z0.s, p0/m, z3.s, z4.s
; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: not_udot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-NEXT: ptrue p0.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
@@ -989,6 +916,19 @@ define <vscale x 2 x i64> @not_udot_wide(<vscale x 2 x i64> %acc, <vscale x 4 x
; CHECK-NEXT: mla z0.d, p0/m, z3.d, z4.d
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: not_udot_wide:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEWLOWERING-NEXT: and z2.s, z2.s, #0xffff
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
%b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
@@ -1020,6 +960,29 @@ define <vscale x 2 x i64> @not_usdot(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
; CHECK-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: not_usdot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -1051,6 +1014,29 @@ define <vscale x 2 x i64> @not_sudot(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
; CHECK-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: not_sudot:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
@@ -1083,6 +1069,30 @@ define <vscale x 2 x i64> @udot_different_types(<vscale x 2 x i64> %acc, <vscale
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
; CHECK-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_different_types:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i64>
@@ -1116,6 +1126,31 @@ define <vscale x 2 x i64> @sdot_different_types(<vscale x 2 x i64> %acc, <vscale
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
; CHECK-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_different_types:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: ptrue p0.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sxtb z2.h, p0/m, z2.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i64>
@@ -1149,6 +1184,31 @@ define <vscale x 2 x i64> @usdot_different_types(<vscale x 2 x i64> %acc, <vscal
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
; CHECK-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: usdot_different_types:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: ptrue p0.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sxtb z2.h, p0/m, z2.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i64>
@@ -1181,6 +1241,30 @@ define <vscale x 2 x i64> @sudot_different_types(<vscale x 2 x i64> %acc, <vscal
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
; CHECK-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sudot_different_types:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: ptrue p0.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
+; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i64>
>From 263c40e4226b83dfc4ae7d0ff0fbd51cafad2d6c Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 5 Feb 2025 14:34:41 +0000
Subject: [PATCH 06/12] Rename variables and ensure conformity to coding
standards.
---
.../SelectionDAG/LegalizeIntegerTypes.cpp | 2 +-
.../lib/CodeGen/SelectionDAG/TargetLowering.cpp | 17 ++++++++---------
2 files changed, 9 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 8e4f90ff9c6c4..8a2b58a9610ae 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -6218,7 +6218,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
SDValue ExtAcc = GetPromotedInteger(N->getOperand(0));
return DAG.getNode(N->getOpcode(), DL, NVT, ExtAcc, N->getOperand(1),
- N->getOperand(2));
+ N->getOperand(2));
}
SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index c2fb6a07cbd18..cc57a17878d12 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11891,15 +11891,14 @@ SDValue TargetLowering::expandVECTOR_COMPRESS(SDNode *Node,
return DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
}
-SDValue TargetLowering::expandPartialReduceMLA(SDNode *N, SelectionDAG &DAG) const {
+SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
+ SelectionDAG &DAG) const {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
- SDValue Input1 = N->getOperand(1);
- SDValue Input2 = N->getOperand(2);
-
-
+ SDValue MulLHS = N->getOperand(1);
+ SDValue MulRHS = N->getOperand(2);
EVT ReducedTy = Acc.getValueType();
- EVT FullTy = Input1.getValueType();
+ EVT FullTy = MulLHS.getValueType();
auto ExtendToAccEltVT = [&](SDValue V) {
unsigned ExtOpc = V->getOpcode() == ISD::SIGN_EXTEND ? ISD::SIGN_EXTEND
@@ -11914,9 +11913,9 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N, SelectionDAG &DAG) con
EVT NewVT =
EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
FullTy.getVectorElementCount());
- Input1 = ExtendToAccEltVT(Input1);
- Input2 = ExtendToAccEltVT(Input2);
- SDValue Input = DAG.getNode(ISD::MUL, DL, NewVT, Input1, Input2);
+ MulLHS = ExtendToAccEltVT(MulLHS);
+ MulRHS = ExtendToAccEltVT(MulRHS);
+ SDValue Input = DAG.getNode(ISD::MUL, DL, NewVT, MulLHS, MulRHS);
unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
>From 9368985ec1268c98b9f82e88dbc2215273b72cd7 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 6 Feb 2025 13:21:19 +0000
Subject: [PATCH 07/12] Address comments. Includes making cli option
target-specific
Make the comment describing the node broader.
Promote both inputs at the same time.
Move assert statements to the getNode() function.
Make the command line option target-specific.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 10 +++----
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 5 ----
.../SelectionDAG/LegalizeIntegerTypes.cpp | 8 ++---
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 2 +-
.../SelectionDAG/LegalizeVectorOps.cpp | 4 +--
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 24 +++++++++++++++
.../SelectionDAG/SelectionDAGBuilder.cpp | 29 ++-----------------
.../CodeGen/SelectionDAG/TargetLowering.cpp | 6 +++-
.../Target/AArch64/AArch64ISelLowering.cpp | 9 ++++++
.../neon-partial-reduce-dot-product.ll | 2 +-
.../AArch64/sve-partial-reduce-dot-product.ll | 2 +-
.../AArch64/sve-partial-reduce-wide-add.ll | 2 +-
12 files changed, 55 insertions(+), 48 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index efeb63df4f143..9fe66c24cd96f 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1452,15 +1452,15 @@ enum NodeType {
VECREDUCE_UMIN,
// PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
- // The partial reduction nodes sign or zero extend Input1 and Input2 to
- // the element type of Accumulator before multiplying their results.
- // The multiplied result is then reduced using addition to the result
- // type of Accumulator. The result is added to Accumulator and returned.
+ // Input1 and Input2 are multiplied together. This result is concatenated to
+ // the accumulator, and this is then reduced, using addition, to the result
+ // type.
// Input1 and Input2 must be the same type. Accumulator and the output must be
// the same type.
// The number of elements in Input1 and Input2 must be a positive integer
// multiple of the number of elements in the Accumulator / output type.
- // All operands, as well as the output, must have the same element type.
+ // Input1 and Input2 must have an element type which is the same as or smaller
+ // than the element type of the Accumulator and output.
PARTIAL_REDUCE_SMLA,
PARTIAL_REDUCE_UMLA,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 53e8019a2e0ec..6c9c96ceaa4ba 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1245,11 +1245,6 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Node->getOpcode(),
cast<MaskedHistogramSDNode>(Node)->getIndex().getValueType());
break;
- case ISD::PARTIAL_REDUCE_UMLA:
- case ISD::PARTIAL_REDUCE_SMLA:
- Action = TLI.getOperationAction(Node->getOpcode(),
- Node->getOperand(0).getValueType());
- break;
default:
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
Action = TLI.getCustomOperationAction(*Node);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 8a2b58a9610ae..2ded0600a7559 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2106,7 +2106,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N, OpNo);
+ Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
break;
}
@@ -2890,10 +2890,10 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
-SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N,
- unsigned OpNo) {
+SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
SmallVector<SDValue, 1> NewOps(N->ops());
- NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
+ NewOps[1] = GetPromotedInteger(N->getOperand(1));
+ NewOps[2] = GetPromotedInteger(N->getOperand(2));
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index e48aa35773780..cb9c1b239c0fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -431,7 +431,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
- SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N, unsigned OpNo);
+ SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);
void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 85524f1d83896..4d8cc914857a7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -467,6 +467,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECTOR_COMPRESS:
case ISD::SCMP:
case ISD::UCMP:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
break;
case ISD::SMULFIX:
@@ -503,8 +505,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMINIMUM:
- case ISD::PARTIAL_REDUCE_UMLA:
- case ISD::PARTIAL_REDUCE_SMLA:
case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b3ab68f14df5c..cb75f8fc543f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7854,6 +7854,30 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
break;
}
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA: {
+ EVT AccVT = N1.getValueType();
+ EVT Input1VT = N2.getValueType();
+ EVT Input2VT = N3.getValueType();
+ assert(Input1VT == Input2VT &&
+ "Expected the second and third operands of the PARTIAL_REDUCE_MLA "
+ "node to have the same type!");
+ assert(VT == AccVT &&
+ "Expected the first operand of the PARTIAL_REDUCE_MLA node to have "
+ "the same type as its result!");
+ assert(Input1VT.getVectorElementCount().getKnownMinValue() %
+ AccVT.getVectorElementCount().getKnownMinValue() ==
+ 0 &&
+ "Expected the element count of the second and third operands of the "
+ "PARTIAL_REDUCE_MLA node to be a positive integer multiple of the "
+ "element count of the first operand and result!");
+ assert(Input1VT.getVectorElementType().getSizeInBits() <=
+ AccVT.getVectorElementType().getSizeInBits() &&
+ "Expected the second and third operands of the PARTIAL_REDUCE_MLA "
+ "node to have an element type which is the same as or smaller than "
+ "the element type of the first operand and result!");
+ break;
+ }
}
// Memoize node if it doesn't produce a glue result.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5188b85adade1..9c158f7b054d1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -135,13 +135,6 @@ static cl::opt<unsigned> SwitchPeelThreshold(
"switch statement. A value greater than 100 will void this "
"optimization"));
-// FIXME : This is a temporary flag, and is used to help transition to
-// performing lowering the proper way using the new PARTIAL_REDUCE_MLA ISD
-// nodes.
-static cl::opt<bool> NewPartialReduceLowering(
- "new-partial-reduce-lowering", cl::init(false), cl::ReallyHidden,
- cl::desc("Use the new method of lowering partial reductions."));
-
// Limit the width of DAG chains. This is important in general to prevent
// DAG-based analysis from blowing up. For example, alias analysis and
// load clustering may not complete in reasonable time. It is difficult to
@@ -8130,30 +8123,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
SDValue Input = getValue(I.getOperand(1));
EVT InputVT = Input.getValueType();
- assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
- "Expected operands to have the same vector element type!");
- assert(InputVT.getVectorElementCount().getKnownMinValue() %
- AccVT.getVectorElementCount().getKnownMinValue() ==
- 0 &&
- "Expected the element count of the Input operand to be a positive "
- "integer multiple of the element count of the Accumulator operand!");
- SDValue PRVal = DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc,
- Input, DAG.getConstant(1, sdl, InputVT));
-
- if (NewPartialReduceLowering) {
- // ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
- // same if ISD::PARTIAL_REDUCE_SMLA was chosen instead. It should be
- // changed to its correct signedness when combining or expanding,
- // according to extends being performed on Input.
- setValue(&I, PRVal);
- return;
- }
-
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
return;
}
- setValue(&I, TLI.expandPartialReduceMLA(PRVal.getNode(), DAG));
+ setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc, Input,
+ DAG.getConstant(1, sdl, InputVT)));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cc57a17878d12..a313ae7254f13 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11915,7 +11915,11 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
FullTy.getVectorElementCount());
MulLHS = ExtendToAccEltVT(MulLHS);
MulRHS = ExtendToAccEltVT(MulRHS);
- SDValue Input = DAG.getNode(ISD::MUL, DL, NewVT, MulLHS, MulRHS);
+ SDValue Input = MulLHS;
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ Input = DAG.getNode(ISD::MUL, DL, NewVT, MulLHS, MulRHS);
unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0baa873295ced..896007d0d4270 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -154,6 +154,13 @@ cl::opt<bool> EnableSVEGISel(
cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
cl::init(false));
+// FIXME : This is a temporary flag, and is used to help transition to
+// performing lowering the proper way using the new PARTIAL_REDUCE_MLA ISD
+// nodes.
+static cl::opt<bool> EnablePartialReduceNodes(
+ "aarch64-enable-partial-reduce-nodes", cl::init(false), cl::ReallyHidden,
+ cl::desc("Use the new method of lowering partial reductions."));
+
/// Value type used for condition codes.
static const MVT MVT_CC = MVT::i32;
@@ -2055,6 +2062,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const IntrinsicInst *I) const {
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
return true;
+ if (EnablePartialReduceNodes)
+ return true;
EVT VT = EVT::getEVT(I->getType());
auto Op1 = I->getOperand(1);
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 60dd825b2d093..40daf8ffb63ea 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -2,7 +2,7 @@
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-NOI8MM
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT,CHECK-I8MM
-; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -new-partial-reduce-lowering < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod,+i8mm -aarch64-enable-partial-reduce-nodes < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM,CHECK-NODOT
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
; CHECK-DOT-LABEL: udot:
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index b66b952b224ab..a46c97d8605cd 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1,7 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
-; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: udot:
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
index 62b5039259392..11fb60ead4fb2 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -1,7 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE2
; RUN: llc -mtriple=aarch64 -mattr=+sve %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SVE
-; RUN: llc -mtriple=aarch64 -mattr=+sve2 -new-partial-reduce-lowering %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NEWLOWERING
define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
; CHECK-SVE2-LABEL: signed_wide_add_nxv4i32:
>From ccc6f48820f56ad14697a6f1e4bc8ab25232c64d Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 7 Feb 2025 11:10:58 +0000
Subject: [PATCH 08/12] Address comments. Includes improving assert statements.
Improve assert statements.
Remove unnecessary variable creation.
Change how operation actions are set for the nodes.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 7 ++++---
.../SelectionDAG/LegalizeVectorOps.cpp | 4 +---
.../SelectionDAG/LegalizeVectorTypes.cpp | 4 ++--
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 20 +++++++++----------
.../SelectionDAG/SelectionDAGBuilder.cpp | 12 +++++------
llvm/lib/CodeGen/TargetLoweringBase.cpp | 4 ++++
.../Target/AArch64/AArch64ISelLowering.cpp | 16 +++------------
7 files changed, 28 insertions(+), 39 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 9fe66c24cd96f..ea8b2d580f195 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1452,9 +1452,10 @@ enum NodeType {
VECREDUCE_UMIN,
// PARTIAL_REDUCE_[U|S]MLA(Accumulator, Input1, Input2)
- // Input1 and Input2 are multiplied together. This result is concatenated to
- // the accumulator, and this is then reduced, using addition, to the result
- // type.
+ // The partial reduction nodes sign or zero extend Input1 and Input2 to the
+ // element type of Accumulator before multiplying their results.
+ // This result is concatenated to the Accumulator, and this is then reduced,
+ // using addition, to the result type.
// Input1 and Input2 must be the same type. Accumulator and the output must be
// the same type.
// The number of elements in Input1 and Input2 must be a positive integer
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 4d8cc914857a7..915e0e4ec8a53 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -1198,11 +1198,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
Results.push_back(TLI.expandVecReduce(Node, DAG));
return;
case ISD::PARTIAL_REDUCE_UMLA:
- case ISD::PARTIAL_REDUCE_SMLA: {
- SDLoc DL(Node);
+ case ISD::PARTIAL_REDUCE_SMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
- }
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VECREDUCE_SEQ_FMUL:
Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index b1b2957efc4e3..7f0cbc51d6aa6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1376,6 +1376,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N);
+ break;
}
// If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -3186,7 +3187,6 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
}
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDLoc DL(N);
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
}
@@ -3393,6 +3393,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
+ break;
}
// If the result is null, the sub-method took care of registering results etc.
@@ -4448,7 +4449,6 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
}
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDLoc DL(N);
SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index cb75f8fc543f2..71d6a66775863 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7856,23 +7856,21 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
- EVT AccVT = N1.getValueType();
- EVT Input1VT = N2.getValueType();
- EVT Input2VT = N3.getValueType();
- assert(Input1VT == Input2VT &&
+ [[maybe_unused]] EVT AccVT = N1.getValueType();
+ [[maybe_unused]] EVT Input1VT = N2.getValueType();
+ [[maybe_unused]] EVT Input2VT = N3.getValueType();
+ assert(Input1VT.isVector() && Input1VT == Input2VT &&
"Expected the second and third operands of the PARTIAL_REDUCE_MLA "
"node to have the same type!");
- assert(VT == AccVT &&
+ assert(VT.isVector() && VT == AccVT &&
"Expected the first operand of the PARTIAL_REDUCE_MLA node to have "
"the same type as its result!");
- assert(Input1VT.getVectorElementCount().getKnownMinValue() %
- AccVT.getVectorElementCount().getKnownMinValue() ==
- 0 &&
+ assert(Input1VT.getVectorElementCount().hasKnownScalarFactor(
+ AccVT.getVectorElementCount()) &&
"Expected the element count of the second and third operands of the "
"PARTIAL_REDUCE_MLA node to be a positive integer multiple of the "
- "element count of the first operand and result!");
- assert(Input1VT.getVectorElementType().getSizeInBits() <=
- AccVT.getVectorElementType().getSizeInBits() &&
+ "element count of the first operand and the result!");
+ assert(N2.getScalarValueSizeInBits() <= N1.getScalarValueSizeInBits() &&
"Expected the second and third operands of the PARTIAL_REDUCE_MLA "
"node to have an element type which is the same as or smaller than "
"the element type of the first operand and result!");
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9c158f7b054d1..552222fdf581c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8118,17 +8118,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
- SDValue Acc = getValue(I.getOperand(0));
- EVT AccVT = Acc.getValueType();
- SDValue Input = getValue(I.getOperand(1));
- EVT InputVT = Input.getValueType();
-
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
return;
}
- setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, AccVT, Acc, Input,
- DAG.getConstant(1, sdl, InputVT)));
+ SDValue Acc = getValue(I.getOperand(0));
+ SDValue Input = getValue(I.getOperand(1));
+ setValue(&I,
+ DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, sdl, Acc.getValueType(), Acc,
+ Input, DAG.getConstant(1, sdl, Input.getValueType())));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 9c56912aa6ba0..194aae19f45b0 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -825,6 +825,10 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::GET_FPENV, VT, Expand);
setOperationAction(ISD::SET_FPENV, VT, Expand);
setOperationAction(ISD::RESET_FPENV, VT, Expand);
+
+ // PartialReduceMLA operations default to expand.
+ setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
+ Expand);
}
// Most targets ignore the @llvm.prefetch intrinsic.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 896007d0d4270..52f05ef952896 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1365,9 +1365,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BSWAP, VT, Expand);
setOperationAction(ISD::CTTZ, VT, Expand);
- setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Expand);
- setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Expand);
-
for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
setTruncStoreAction(VT, InnerVT, Expand);
setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
@@ -1581,11 +1578,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MSTORE, VT, Custom);
}
- for (MVT VT : MVT::scalable_vector_valuetypes()) {
- setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Expand);
- setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Expand);
- }
-
// Firstly, exclude all scalable vector extending loads/truncating stores,
// include both integer and floating scalable vector.
for (MVT VT : MVT::scalable_vector_valuetypes()) {
@@ -22027,13 +22019,11 @@ static SDValue performIntrinsicCombine(SDNode *N,
return Dot;
if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SDLoc DL(N);
SDValue Input = N->getOperand(2);
- SDValue PRVal = DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL,
- N->getValueType(0), N->getOperand(1), Input,
- DAG.getConstant(1, DL, Input.getValueType()));
- return TLI.expandPartialReduceMLA(PRVal.getNode(), DAG);
+ return DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, N->getValueType(0),
+ N->getOperand(1), Input,
+ DAG.getConstant(1, DL, Input.getValueType()));
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
>From 5e29cdf8d06c7701de433027006409ab427fd4c7 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 7 Feb 2025 16:44:47 +0000
Subject: [PATCH 09/12] Address comments, includes fix for using wrong
signedness
Fix wrong signedness in promotion function.
Make the expand code depend on the node's signedness rather than on the operands' extend opcode.
---
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp | 9 +++++++--
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 5 +++--
2 files changed, 10 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 2ded0600a7559..d811176052e3e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2892,8 +2892,13 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
SmallVector<SDValue, 1> NewOps(N->ops());
- NewOps[1] = GetPromotedInteger(N->getOperand(1));
- NewOps[2] = GetPromotedInteger(N->getOperand(2));
+ if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {
+ NewOps[1] = SExtPromotedInteger(N->getOperand(1));
+ NewOps[2] = SExtPromotedInteger(N->getOperand(2));
+ } else {
+ NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
+ NewOps[2] = ZExtPromotedInteger(N->getOperand(2));
+ }
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index a313ae7254f13..8f4d1838ec92a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11901,8 +11901,9 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
EVT FullTy = MulLHS.getValueType();
auto ExtendToAccEltVT = [&](SDValue V) {
- unsigned ExtOpc = V->getOpcode() == ISD::SIGN_EXTEND ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND;
+ unsigned ExtOpc = V->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
+ ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND;
EVT ExtVT = V.getValueType().changeVectorElementType(
Acc.getValueType().getVectorElementType());
if (ExtVT != FullTy)
>From d5719f99377b7467f34b94a3bae4c8b6c4e35e71 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 7 Feb 2025 17:04:12 +0000
Subject: [PATCH 10/12] Fix issue with expand function.
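The expand function was checking whether the operands' opcodes were
PARTIAL_REDUCE_SMLA; it now checks the opcode of the node itself (the
operands' parent) to choose between sign and zero extension.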
---
.../CodeGen/SelectionDAG/TargetLowering.cpp | 26 +++++++++----------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 8f4d1838ec92a..da7d3dbc9ce52 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11900,22 +11900,22 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
EVT ReducedTy = Acc.getValueType();
EVT FullTy = MulLHS.getValueType();
- auto ExtendToAccEltVT = [&](SDValue V) {
- unsigned ExtOpc = V->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
- ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND;
- EVT ExtVT = V.getValueType().changeVectorElementType(
- Acc.getValueType().getVectorElementType());
- if (ExtVT != FullTy)
- return DAG.getNode(ExtOpc, DL, ExtVT, V);
- return V;
- };
-
EVT NewVT =
EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
FullTy.getVectorElementCount());
- MulLHS = ExtendToAccEltVT(MulLHS);
- MulRHS = ExtendToAccEltVT(MulRHS);
+ unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
+ ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND;
+ EVT MulLHSVT = MulLHS.getValueType();
+ assert(MulLHSVT == MulRHS.getValueType() &&
+ "The second and third operands of a PARTIAL_REDUCE_MLA node must have "
+ "the same value type!");
+ EVT ExtVT = MulLHSVT.changeVectorElementType(
+ Acc.getValueType().getVectorElementType());
+ if (ExtVT != FullTy) {
+ MulLHS = DAG.getNode(ExtOpc, DL, ExtVT, MulLHS);
+ MulRHS = DAG.getNode(ExtOpc, DL, ExtVT, MulRHS);
+ }
SDValue Input = MulLHS;
APInt ConstantOne;
if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
>From 638c0ca4e12e9cf76bfcc6658c616042bd4d7e71 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 12 Feb 2025 11:59:46 +0000
Subject: [PATCH 11/12] Change splitting functions.
Adjust ISDOpcode description.
Rename variables in expand function.
Remove unnecessary assert statement.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 3 +
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 2 +-
.../SelectionDAG/LegalizeVectorTypes.cpp | 14 +--
.../CodeGen/SelectionDAG/TargetLowering.cpp | 35 ++++----
.../AArch64/sve-partial-reduce-dot-product.ll | 86 +++++++++++++++++++
5 files changed, 112 insertions(+), 28 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index ea8b2d580f195..6e26e8354473f 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1456,6 +1456,9 @@ enum NodeType {
// element type of Accumulator before multiplying their results.
// This result is concatenated to the Accumulator, and this is then reduced,
// using addition, to the result type.
+ // The output is only expected to either be given to another partial reduction
+ // operation or an equivalent vector reduce operation, so the order in which
+ // the elements are reduced is deliberately not specified.
// Input1 and Input2 must be the same type. Accumulator and the output must be
// the same type.
// The number of elements in Input1 and Input2 must be a positive integer
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index cb9c1b239c0fa..376f1e73719cc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -970,7 +970,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
- void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
+ void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo, SDValue &Hi);
// Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
bool SplitVectorOperand(SDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 7f0cbc51d6aa6..3c39d3eee0816 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1375,7 +1375,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
break;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- SplitVecRes_PARTIAL_REDUCE_MLA(N);
+ SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
break;
}
@@ -3186,9 +3186,11 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
}
-void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
- ReplaceValueWith(SDValue(N, 0), Res);
+void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
+ SDValue &Hi) {
+ // Expand N at its original (to-be-split) vector type, then split the
+ // expanded value into the Lo/Hi halves returned through the out-parameters.
+ SDLoc DL(N);
+ SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
+ std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
}
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4449,9 +4451,7 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
}
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
- SDValue Res = TLI.expandPartialReduceMLA(N, DAG);
- ReplaceValueWith(SDValue(N, 0), Res);
- return SDValue();
+ // Return the expansion instead of calling ReplaceValueWith here; presumably
+ // SplitVectorOperand replaces N's result with the returned value (confirm).
+ return TLI.expandPartialReduceMLA(N, DAG);
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index da7d3dbc9ce52..7771958f5adc9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11897,46 +11897,41 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
SDValue Acc = N->getOperand(0);
SDValue MulLHS = N->getOperand(1);
SDValue MulRHS = N->getOperand(2);
- EVT ReducedTy = Acc.getValueType();
- EVT FullTy = MulLHS.getValueType();
+ EVT AccVT = Acc.getValueType();
+ EVT MulOpVT = MulLHS.getValueType();
- EVT NewVT =
- EVT::getVectorVT(*DAG.getContext(), ReducedTy.getVectorElementType(),
- FullTy.getVectorElementCount());
+ EVT ExtMulOpVT =
+ EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
+ MulOpVT.getVectorElementCount());
unsigned ExtOpc = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
? ISD::SIGN_EXTEND
: ISD::ZERO_EXTEND;
- EVT MulLHSVT = MulLHS.getValueType();
- assert(MulLHSVT == MulRHS.getValueType() &&
- "The second and third operands of a PARTIAL_REDUCE_MLA node must have "
- "the same value type!");
- EVT ExtVT = MulLHSVT.changeVectorElementType(
- Acc.getValueType().getVectorElementType());
- if (ExtVT != FullTy) {
- MulLHS = DAG.getNode(ExtOpc, DL, ExtVT, MulLHS);
- MulRHS = DAG.getNode(ExtOpc, DL, ExtVT, MulRHS);
+
+ if (ExtMulOpVT != MulOpVT) {
+ MulLHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulLHS);
+ MulRHS = DAG.getNode(ExtOpc, DL, ExtMulOpVT, MulRHS);
}
SDValue Input = MulLHS;
APInt ConstantOne;
if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
!ConstantOne.isOne())
- Input = DAG.getNode(ISD::MUL, DL, NewVT, MulLHS, MulRHS);
+ Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
- unsigned Stride = ReducedTy.getVectorMinNumElements();
- unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+ unsigned Stride = AccVT.getVectorMinNumElements();
+ unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride;
// Collect all of the subvectors
std::deque<SDValue> Subvectors = {Acc};
for (unsigned I = 0; I < ScaleFactor; I++) {
auto SourceIndex = DAG.getVectorIdxConstant(I * Stride, DL);
- Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
- {Input, SourceIndex}));
+ Subvectors.push_back(
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, AccVT, {Input, SourceIndex}));
}
// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(
- DAG.getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+ DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
}
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index a46c97d8605cd..455231dd37be6 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -1272,3 +1272,89 @@ entry:
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}
+
+; The nxv2i16 accumulator is widened to 64-bit lanes (note the .d adds), and
+; both lowerings are expected to produce the same code for this case.
+define <vscale x 2 x i16> @udot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: udot_nxv8i8_promote:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEXT: mul z1.h, z1.h, z2.h
+; CHECK-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEXT: uunpklo z4.d, z1.s
+; CHECK-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: add z0.d, z0.d, z3.d
+; CHECK-NEXT: add z2.d, z2.d, z4.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: udot_nxv8i8_promote:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEWLOWERING-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEWLOWERING-NEXT: mul z1.h, z1.h, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-NEXT: add z2.d, z2.d, z4.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i16>
+ %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i16>
+ %mult = mul nuw <vscale x 8 x i16> %a.wide, %b.wide ; no nsw: 255 * 255 = 65025 exceeds the signed i16 range
+ %partial.reduce = tail call <vscale x 2 x i16> @llvm.experimental.vector.partial.reduce.add.nxv2i16.nxv8i16(<vscale x 2 x i16> %acc, <vscale x 8 x i16> %mult)
+ ret <vscale x 2 x i16> %partial.reduce
+}
+
+; Signed counterpart of udot_nxv8i8_promote: the i8 inputs are sign-extended
+; (sxtb) before the i16 multiply, and the accumulator uses 64-bit (.d) lanes.
+define <vscale x 2 x i16> @sdot_nxv8i8_promote (<vscale x 2 x i16> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: sdot_nxv8i8_promote:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: sxtb z1.h, p0/m, z1.h
+; CHECK-NEXT: sxtb z2.h, p0/m, z2.h
+; CHECK-NEXT: mul z1.h, z1.h, z2.h
+; CHECK-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEXT: uunpklo z4.d, z1.s
+; CHECK-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: add z0.d, z0.d, z3.d
+; CHECK-NEXT: add z2.d, z2.d, z4.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEXT: ret
+;
+; CHECK-NEWLOWERING-LABEL: sdot_nxv8i8_promote:
+; CHECK-NEWLOWERING: // %bb.0: // %entry
+; CHECK-NEWLOWERING-NEXT: ptrue p0.h
+; CHECK-NEWLOWERING-NEXT: sxtb z1.h, p0/m, z1.h
+; CHECK-NEWLOWERING-NEXT: sxtb z2.h, p0/m, z2.h
+; CHECK-NEWLOWERING-NEXT: mul z1.h, z1.h, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z1.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
+; CHECK-NEWLOWERING-NEXT: add z2.d, z2.d, z4.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEWLOWERING-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i8> %a to <vscale x 8 x i16>
+ %b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i16>
+ %mult = mul nsw <vscale x 8 x i16> %a.wide, %b.wide ; no nuw: sext(-1) * sext(-1) = 0xFFFF * 0xFFFF wraps as unsigned i16
+ %partial.reduce = tail call <vscale x 2 x i16> @llvm.experimental.vector.partial.reduce.add.nxv2i16.nxv8i16(<vscale x 2 x i16> %acc, <vscale x 8 x i16> %mult)
+ ret <vscale x 2 x i16> %partial.reduce
+}
>From 17fe9bddbf8ae416fc84e6f7e2746d4a0f5d9dc1 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 13 Feb 2025 15:35:55 +0000
Subject: [PATCH 12/12] [DAGCombiner] Add generic DAG combine for
ISD::PARTIAL_REDUCE_MLA
Add generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA nodes.
Transforms the DAG from:
PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1))
to
PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS).
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 48 +++++++++
.../neon-partial-reduce-dot-product.ll | 75 +++++++------
.../AArch64/sve-partial-reduce-dot-product.ll | 100 +++++++++---------
3 files changed, 138 insertions(+), 85 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8858c2012c706..f6d5f08762151 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -545,6 +545,7 @@ namespace {
SDValue visitMGATHER(SDNode *N);
SDValue visitMSCATTER(SDNode *N);
SDValue visitMHISTOGRAM(SDNode *N);
+ SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
SDValue visitVPGATHER(SDNode *N);
SDValue visitVPSCATTER(SDNode *N);
SDValue visitVP_STRIDED_LOAD(SDNode *N);
@@ -1972,6 +1973,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::MSCATTER: return visitMSCATTER(N);
case ISD::MSTORE: return visitMSTORE(N);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
+ case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
@@ -12497,6 +12501,50 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+  // Fold PARTIAL_REDUCE_*MLA(Acc, MUL(EXT(A), EXT(B)), Splat(1)) into
+  // PARTIAL_REDUCE_*MLA(Acc, A, B), with the signedness of the new node taken
+  // from the inner extends. Only matching SIGN_EXTEND or ZERO_EXTEND pairs
+  // are folded; ANY_EXTEND has no signedness to pick an opcode from.
+  SDValue Acc = N->getOperand(0);
+  SDValue Mul = N->getOperand(1);
+  APInt ConstantOne;
+  if (Mul->getOpcode() != ISD::MUL ||
+      !ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  SDValue ExtMulOpLHS = Mul->getOperand(0);
+  SDValue ExtMulOpRHS = Mul->getOperand(1);
+  unsigned ExtOpc = ExtMulOpLHS->getOpcode();
+  if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
+      ExtMulOpRHS->getOpcode() != ExtOpc)
+    return SDValue();
+
+  // Only fold when the narrow inputs and the result already have legal types.
+  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+  EVT MulOpVT = MulOpLHS.getValueType();
+  if (MulOpVT != MulOpRHS.getValueType() || !TLI.isTypeLegal(MulOpVT) ||
+      !TLI.isTypeLegal(N->getValueType(0)))
+    return SDValue();
+
+  unsigned NewOpcode = ExtOpc == ISD::SIGN_EXTEND ? ISD::PARTIAL_REDUCE_SMLA
+                                                  : ISD::PARTIAL_REDUCE_UMLA;
+  // When the MUL is narrower than the accumulator elements, N extends the
+  // product a second time. That equals one extend of the narrow inputs only
+  // if N extends with the same signedness and the MUL is at least twice as
+  // wide as its inputs, so the product cannot have wrapped.
+  unsigned MulEltBits = Mul.getValueType().getScalarSizeInBits();
+  if (MulEltBits != Acc.getValueType().getScalarSizeInBits() &&
+      (NewOpcode != N->getOpcode() ||
+       MulEltBits < 2 * MulOpVT.getScalarSizeInBits()))
+    return SDValue();
+
+  return DAG.getNode(NewOpcode, SDLoc(N), Acc.getValueType(), Acc, MulOpLHS,
+                     MulOpRHS);
+}
+
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index 40daf8ffb63ea..7ec166aa8ed36 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -12,13 +12,15 @@ define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
;
; CHECK-NODOT-LABEL: udot:
; CHECK-NODOT: // %bb.0:
-; CHECK-NODOT-NEXT: umull v3.8h, v2.8b, v1.8b
-; CHECK-NODOT-NEXT: umull2 v1.8h, v2.16b, v1.16b
-; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v3.4h
-; CHECK-NODOT-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
-; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT: ushll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: ushll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: umlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT: umull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: umlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: umlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
; CHECK-NODOT-NEXT: ret
%u.wide = zext <16 x i8> %u to <16 x i32>
%s.wide = zext <16 x i8> %s to <16 x i32>
@@ -35,17 +37,19 @@ define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
;
; CHECK-NODOT-LABEL: udot_narrow:
; CHECK-NODOT: // %bb.0:
-; CHECK-NODOT-NEXT: umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
-; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: umull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: umull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: umlal v0.4s, v2.4h, v1.4h
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT: umlal v3.4s, v6.4h, v5.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
@@ -62,13 +66,15 @@ define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
;
; CHECK-NODOT-LABEL: sdot:
; CHECK-NODOT: // %bb.0:
-; CHECK-NODOT-NEXT: smull v3.8h, v2.8b, v1.8b
-; CHECK-NODOT-NEXT: smull2 v1.8h, v2.16b, v1.16b
-; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v3.4h
-; CHECK-NODOT-NEXT: saddw2 v2.4s, v2.4s, v3.8h
-; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
-; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT: sshll v3.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: sshll v4.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: smlal v0.4s, v4.4h, v3.4h
+; CHECK-NODOT-NEXT: smull v5.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smlal2 v0.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: smlal2 v5.4s, v4.8h, v3.8h
+; CHECK-NODOT-NEXT: add v0.4s, v5.4s, v0.4s
; CHECK-NODOT-NEXT: ret
%u.wide = sext <16 x i8> %u to <16 x i32>
%s.wide = sext <16 x i8> %s to <16 x i32>
@@ -85,17 +91,19 @@ define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
;
; CHECK-NODOT-LABEL: sdot_narrow:
; CHECK-NODOT: // %bb.0:
-; CHECK-NODOT-NEXT: smull v1.8h, v2.8b, v1.8b
+; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: sshll v2.8h, v2.8b, #0
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
-; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
-; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
-; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: smull v3.4s, v2.4h, v1.4h
+; CHECK-NODOT-NEXT: smull2 v4.4s, v2.8h, v1.8h
+; CHECK-NODOT-NEXT: ext v5.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: ext v6.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: smlal v0.4s, v2.4h, v1.4h
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
-; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
-; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
-; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT: ext v1.16b, v4.16b, v4.16b, #8
+; CHECK-NODOT-NEXT: smlal v3.4s, v6.4h, v5.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: ret
%u.wide = sext <8 x i8> %u to <8 x i32>
%s.wide = sext <8 x i8> %s to <8 x i32>
@@ -531,9 +539,10 @@ define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
-; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
-; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
-; CHECK-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: ushll v2.8h, v2.8b, #0
+; CHECK-NEXT: umlal v0.4s, v2.4h, v1.4h
+; CHECK-NEXT: umlal2 v0.4s, v2.8h, v1.8h
; CHECK-NEXT: ret
%u.wide = zext <8 x i8> %u to <8 x i32>
%s.wide = zext <8 x i8> %s to <8 x i32>
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 455231dd37be6..c6dc0ed5651ec 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -11,24 +11,23 @@ define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a,
;
; CHECK-NEWLOWERING-LABEL: udot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z1.b
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
; CHECK-NEWLOWERING-NEXT: uunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-NEXT: mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT: mad z1.s, p0/m, z2.s, z3.s
; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
; CHECK-NEWLOWERING-NEXT: ret
entry:
@@ -47,24 +46,23 @@ define <vscale x 2 x i64> @udot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
;
; CHECK-NEWLOWERING-LABEL: udot_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z1.h
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
; CHECK-NEWLOWERING-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z7.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
-; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: uunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT: mad z1.d, p0/m, z2.d, z3.d
; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEWLOWERING-NEXT: ret
entry:
@@ -83,24 +81,23 @@ define <vscale x 4 x i32> @sdot(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a,
;
; CHECK-NEWLOWERING-LABEL: sdot:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z1.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.s
; CHECK-NEWLOWERING-NEXT: sunpklo z5.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z5.s, z6.s
-; CHECK-NEWLOWERING-NEXT: mul z3.s, z3.s, z4.s
-; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z1.s, z2.s
-; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT: mla z1.s, p0/m, z7.s, z24.s
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z6.s, z1.h
+; CHECK-NEWLOWERING-NEXT: mul z3.s, z4.s, z3.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: mla z0.s, p0/m, z6.s, z5.s
+; CHECK-NEWLOWERING-NEXT: mad z1.s, p0/m, z2.s, z3.s
; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
; CHECK-NEWLOWERING-NEXT: ret
entry:
@@ -119,24 +116,23 @@ define <vscale x 2 x i64> @sdot_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16>
;
; CHECK-NEWLOWERING-LABEL: sdot_wide:
; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z1.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
; CHECK-NEWLOWERING-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z7.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z5.d, z6.d
-; CHECK-NEWLOWERING-NEXT: mul z3.d, z3.d, z4.d
-; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z1.d, z2.d
-; CHECK-NEWLOWERING-NEXT: movprfx z1, z3
-; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z1.s
+; CHECK-NEWLOWERING-NEXT: mul z3.d, z4.d, z3.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z2.d, z2.s
+; CHECK-NEWLOWERING-NEXT: sunpklo z1.d, z1.s
+; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z6.d, z5.d
+; CHECK-NEWLOWERING-NEXT: mad z1.d, p0/m, z2.d, z3.d
; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
; CHECK-NEWLOWERING-NEXT: ret
entry:
More information about the llvm-commits
mailing list