[llvm] [SelectionDAG] Improve type legalisation for PARTIAL_REDUCE_MLA (PR #130935)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 23 06:17:35 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

<details>
<summary>Changes</summary>

Implement proper splitting functions for PARTIAL_REDUCE_MLA ISD nodes.
This makes the udot_8to64 and sdot_8to64 tests generate dot product instructions when the new ISD nodes are used.

@<!-- -->JamesChesterman is the original author

---
Full diff: https://github.com/llvm/llvm-project/pull/130935.diff


4 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+6) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+23-3) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+28) 
- (modified) llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll (+4) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index abe261728a3e6..7b0e15f951681 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1668,6 +1668,12 @@ class TargetLoweringBase {
     return Action == Legal || Action == Custom;
   }
 
+  /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+  /// legal for this target.
+  bool isPartialReduceMLALegal(EVT AccVT, EVT InputVT) const {
+    return getPartialReduceMLAAction(AccVT, InputVT) == Legal;
+  }
+
   /// If the action for this operation is to promote, this method returns the
   /// ValueType to promote to.
   MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index a01e1cff74564..d0ae436a8758f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3220,8 +3220,26 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
 void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
                                                       SDValue &Hi) {
   SDLoc DL(N);
-  SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
-  std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
+  SDValue Acc = N->getOperand(0);
+  SDValue Input1 = N->getOperand(1);
+
+  // If the node has not gone through the DAG combine, then do not attempt to
+  // legalise, just expand.
+  if (!TLI.isPartialReduceMLALegal(Acc.getValueType(), Input1.getValueType())) {
+    SDValue Expanded = TLI.expandPartialReduceMLA(N, DAG);
+    std::tie(Lo, Hi) = DAG.SplitVector(Expanded, DL);
+    return;
+  }
+
+  SDValue AccLo, AccHi, Input1Lo, Input1Hi, Input2Lo, Input2Hi;
+  std::tie(AccLo, AccHi) = DAG.SplitVector(Acc, DL);
+  std::tie(Input1Lo, Input1Hi) = DAG.SplitVector(Input1, DL);
+  std::tie(Input2Lo, Input2Hi) = DAG.SplitVector(N->getOperand(2), DL);
+  unsigned Opcode = N->getOpcode();
+  EVT ResultVT = AccLo.getValueType();
+
+  Lo = DAG.getNode(Opcode, DL, ResultVT, AccLo, Input1Lo, Input2Lo);
+  Hi = DAG.getNode(Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
 }
 
 void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
@@ -4501,7 +4519,9 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
 }
 
 SDValue DAGTypeLegalizer::SplitVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
-  return TLI.expandPartialReduceMLA(N, DAG);
+  SDValue Lo, Hi;
+  SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
+  return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), N->getValueType(0), Lo, Hi);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 447794cc2b744..810d42635e7b2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1604,6 +1604,26 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::MSTORE, VT, Custom);
     }
 
+    if (EnablePartialReduceNodes) {
+      for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
+        for (MVT InnerVT : MVT::integer_scalable_vector_valuetypes()) {
+          // 1. Set all combinations where a type is illegal to "Legal"
+          // - These will be legalized to a legal type pair
+          // - Avoid expanding them too early (or preventing folds)
+          if (!isTypeLegal(VT) || !isTypeLegal(InnerVT)) {
+            setPartialReduceMLAAction(VT, InnerVT, Legal);
+            continue;
+          }
+          //  2. Set all legal combinations to "Expand"
+          // - Not all of these can be lowered (via a Legal or Custom lowering).
+          setPartialReduceMLAAction(VT, InnerVT, Expand);
+        }
+      }
+      // 3. Mark known legal pairs as 'Legal' (these will expand to USDOT).
+      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
+      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+    }
+
     // Firstly, exclude all scalable vector extending loads/truncating stores,
     // include both integer and floating scalable vector.
     for (MVT VT : MVT::scalable_vector_valuetypes()) {
@@ -1856,6 +1876,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     // Other pairs will default to 'Expand'.
     setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
     setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+
+    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i64, Custom);
+    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i32, Custom);
+
+    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i64, Custom);
+    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i32, Custom);
+    setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i16, Custom);
+    setPartialReduceMLAAction(MVT::nxv16i8, MVT::nxv32i8, Custom);
   }
 
   // Handle operations that are only available in non-streaming SVE mode.
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index ed27f40aba774..71936b686be15 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -259,6 +259,8 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ; CHECK-NEWLOWERING-NEXT:    add z1.d, z3.d, z1.d
 ; CHECK-NEWLOWERING-NEXT:    addvl sp, sp, #2
 ; CHECK-NEWLOWERING-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEWLOWERING-NEXT:    udot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT:    udot z1.d, z2.h, z3.h
 ; CHECK-NEWLOWERING-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
@@ -293,6 +295,8 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z4.h, z2.b
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z5.h, z3.b
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z2.h, z2.b
+; CHECK-NEWLOWERING-NEXT:    sdot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT:    sdot z1.d, z2.h, z3.h
 ; CHECK-NEWLOWERING-NEXT:    sunpkhi z3.h, z3.b
 ; CHECK-NEWLOWERING-NEXT:    ptrue p0.d
 ; CHECK-NEWLOWERING-NEXT:    sunpklo z6.s, z4.h

``````````

</details>


https://github.com/llvm/llvm-project/pull/130935


More information about the llvm-commits mailing list