[llvm] [SelectionDAG] Improve type legalisation for PARTIAL_REDUCE_MLA (PR #130935)
Nicholas Guy via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 23 06:00:15 PDT 2025
https://github.com/NickGuy-Arm updated https://github.com/llvm/llvm-project/pull/130935
>From a27a811233d6248a95d830c4ea6b6370c1305d7b Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 28 Feb 2025 17:10:50 +0000
Subject: [PATCH 1/2] [SelectionDAG] Improve type legalisation for
PARTIAL_REDUCE_MLA
Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD
nodes.
This makes the udot_8to64 and sdot_8to64 tests generate dot product
instructions for when the new ISD nodes are used.
---
llvm/include/llvm/CodeGen/TargetLowering.h | 6 +++++
.../SelectionDAG/LegalizeVectorTypes.cpp | 26 ++++++++++++++++---
.../AArch64/sve-partial-reduce-dot-product.ll | 4 +++
3 files changed, 33 insertions(+), 3 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index abe261728a3e6..7b0e15f951681 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1668,6 +1668,12 @@ class TargetLoweringBase {
return Action == Legal || Action == Custom;
}
+ /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+ /// legal for this target.
+ bool isPartialReduceMLALegal(EVT AccVT, EVT InputVT) const {
+ return getPartialReduceMLAAction(AccVT, InputVT) == Legal;
+ }
+
/// If the action for this operation is to promote, this method returns the
/// ValueType to promote to.
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index a01e1cff74564..d0ae436a8758f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3220,8 +3220,26 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
SDValue &Hi) {
SDLoc DL(N);
- SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
- std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
+ SDValue Acc = N->getOperand(0);
+ SDValue Input1 = N->getOperand(1);
+
+ // If the node has not gone through the DAG combine, then do not attempt to
+ // legalise, just expand.
+ if (!TLI.isPartialReduceMLALegal(Acc.getValueType(), Input1.getValueType())) {
+ SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
+ std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
+ return;
+ }
+
+ SDValue AccLo, AccHi, Input1Lo, Input1Hi, Input2Lo, Input2Hi;
+ std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
+ std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(Input1, DL);
+ std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(N->getOperand(2), DL);
+ unsigned Opcode = N->getOpcode();
+ EVT ResultVT = AccLo.getValueType();
+
+ Lo = DAG.getNode(Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
+ Hi = DAG.getNode(Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
}
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4501,7 +4519,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
}
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
- return TLI.expandPartialReduceMLA(N, DAG);
+ SDValue Lo, Hi;
+ SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
+ return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), N->getValueType(0), Lo, Hi);
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index ed27f40aba774..71936b686be15 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -259,6 +259,8 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -293,6 +295,8 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: sdot z1.d, z2.h, z3.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
>From 4ee499082d5f1058dc192eb5eabee7662ad7a866 Mon Sep 17 00:00:00 2001
From: Nick Guy <nicholas.guy at arm.com>
Date: Tue, 15 Apr 2025 13:12:49 +0100
Subject: [PATCH 2/2] Explicitly set PartialReduceMLAActions
---
.../Target/AArch64/AArch64ISelLowering.cpp | 28 +++++++++++++++++++
1 file changed, 28 insertions(+)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 447794cc2b744..810d42635e7b2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1604,6 +1604,26 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MSTORE, VT, Custom);
}
+ if (EnablePartialReduceNodes) {
+ for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
+ for (MVT InnerVT : MVT::integer_scalable_vector_valuetypes()) {
+ // 1. Set all combinations where a type is illegal to "Legal"
+ // - These will be legalized to a legal type pair
+ // - Avoid expanding them too early (or preventing folds)
+ if (!isTypeLegal(VT) || !isTypeLegal(InnerVT)) {
+ setPartialReduceMLAAction(VT, InnerVT, Legal);
+ continue;
+ }
+ // 2. Set all legal combinations to "Expand"
+ // - Not all of these can be lowered (via a Legal or Custom lowering).
+ setPartialReduceMLAAction(VT, InnerVT, Expand);
+ }
+ }
+ // 3. Mark known legal pairs as 'Legal' (these will expand to USDOT).
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+ }
+
// Firstly, exclude all scalable vector extending loads/truncating stores,
// include both integer and floating scalable vector.
for (MVT VT : MVT::scalable_vector_valuetypes()) {
@@ -1856,6 +1876,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Other pairs will default to 'Expand'.
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
+
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
+ setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
+ setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
}
// Handle operations that are only available in non-streaming SVE mode.
More information about the llvm-commits
mailing list