[llvm] [SelectionDAG] Improve type legalisation for PARTIAL_REDUCE_MLA (PR #130935)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 23 06:17:35 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Nicholas Guy (NickGuy-Arm)
<details>
<summary>Changes</summary>
Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes.
When the new ISD nodes are used, this makes the udot_8to64 and sdot_8to64 tests generate dot-product instructions.
James Chesterman (@JamesChesterman) is the original author.
---
Full diff: https://github.com/llvm/llvm-project/pull/130935.diff
4 Files Affected:
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+6)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+23-3)
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+28)
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+4)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index abe261728a3e6..7b0e15f951681 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1668,6 +1668,12 @@ class TargetLoweringBase {
return Action == Legal || Action == Custom;
}
+ /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+ /// exactly Legal for this target (pairs marked Custom or Expand return false).
+ bool isPartialReduceMLALegal(EVT AccVT, EVT InputVT) const {
+ return getPartialReduceMLAAction(AccVT, InputVT) == Legal;
+ }
+
/// If the action for this operation is to promote, this method returns the
/// ValueType to promote to.
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index a01e1cff74564..d0ae436a8758f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3220,8 +3220,26 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
SDValue &Hi) {
SDLoc DL(N);
- SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
- std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
+ SDValue Acc = N->getOperand(0);
+ SDValue Input1 = N->getOperand(1);
+
+ // If the target does not report this (Acc, Input) type pair as Legal, expand
+ // the node and split the expanded result instead of splitting the operands.
+ if (!TLI.isPartialReduceMLALegal(Acc.getValueType(), Input1.getValueType())) {
+ SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
+ std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
+ return;
+ }
+
+ SDValue AccLo, AccHi, Input1Lo, Input1Hi, Input2Lo, Input2Hi;
+ std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
+ std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(Input1, DL);
+ std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(N->getOperand(2), DL);
+ unsigned Opcode = N->getOpcode();
+ EVT ResultVT = AccLo.getValueType();
+
+ Lo = DAG.getNode(Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
+ Hi = DAG.getNode(Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
}
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4501,7 +4519,19 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
}
SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
- return TLI.expandPartialReduceMLA(N, DAG);
+ // Split only the two input operands and chain the two halves through the
+ // accumulator, so the accumulator keeps its original type and no
+ // CONCAT_VECTORS of half-width partial results is built.
+ SDLoc DL(N);
+ SDValue Acc = N->getOperand(0);
+ SDValue Input1Lo, Input1Hi, Input2Lo, Input2Hi;
+ std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(N->getOperand(1), DL);
+ std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(N->getOperand(2), DL);
+ unsigned Opcode = N->getOpcode();
+ EVT ResultVT = Acc.getValueType();
+
+ SDValue Lo = DAG.getNode(Opcode, DL, ResultVT, Acc, Input1Lo, Input2Lo);
+ return DAG.getNode(Opcode, DL, ResultVT, Lo, Input1Hi, Input2Hi);
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 447794cc2b744..810d42635e7b2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1604,6 +1604,26 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MSTORE, VT, Custom);
}
+ if (EnablePartialReduceNodes) {
+ for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
+ for (MVT InnerVT : MVT::integer_scalable_vector_valuetypes()) {
+ // 1. Set all combinations where a type is illegal to "Legal"
+ // - These will be legalized to a legal type pair
+ // - Avoid expanding them too early (or preventing folds)
+ if (!isTypeLegal(VT) || !isTypeLegal(InnerVT)) {
+ setPartialReduceMLAAction(VT, InnerVT, Legal);
+ continue;
+ }
+ // 2. Set all legal combinations to "Expand"
+ // - Not all of these can be lowered (via a Legal or Custom lowering).
+ setPartialReduceMLAAction(VT, InnerVT, Expand);
+ }
+ }
+ // 3. Mark known legal pairs as 'Legal' (these will expand to USDOT).
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+ }
+
// Firstly, exclude all scalable vector extending loads/truncating stores,
// include both integer and floating scalable vector.
for (MVT VT : MVT::scalable_vector_valuetypes()) {
@@ -1856,6 +1876,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Other pairs will default to 'Expand'.
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
+
+ setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
+ setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
+ setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
+ setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
}
// Handle operations that are only available in non-streaming SVE mode.
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index ed27f40aba774..71936b686be15 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -259,6 +259,8 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEWLOWERING-NEXT: add z1.d, z3.d, z1.d
; CHECK-NEWLOWERING-NEXT: addvl sp, sp, #2
; CHECK-NEWLOWERING-NEXT: ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
; CHECK-NEWLOWERING-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -293,6 +295,8 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z2.b
; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z3.b
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: sdot z1.d, z2.h, z3.h
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
``````````
</details>
https://github.com/llvm/llvm-project/pull/130935
More information about the llvm-commits
mailing list