[llvm] [AArch64][SVE] Add lowering for PARTIAL_REDUCE_U/SMLA to USDOT (PR #131327)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Thu May 8 07:16:56 PDT 2025


================
@@ -924,8 +924,19 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
 /// illegal ResNo in that case.
 bool DAGTypeLegalizer::CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult) {
   // See if the target wants to custom lower this node.
-  if (TLI.getOperationAction(N->getOpcode(), VT) != TargetLowering::Custom)
-    return false;
+  unsigned Opcode = N->getOpcode();
+  bool IsPRMLAOpcode =
+      Opcode == ISD::PARTIAL_REDUCE_UMLA || Opcode == ISD::PARTIAL_REDUCE_SMLA;
+
+  if (IsPRMLAOpcode) {
+    if (TLI.getPartialReduceMLAAction(N->getValueType(0),
+                                      N->getOperand(1).getValueType()) !=
+        TargetLowering::Custom)
+      return false;
+  } else {
+    if (TLI.getOperationAction(Opcode, VT) != TargetLowering::Custom)
+      return false;
+  }
----------------
NickGuy-Arm wrote:

It isn't strictly a workaround for type legalization; it's to avoid the extended type being split before we can account for it properly. The type passed in via `VT` is the type of the `partial_reduce_umla` node's operand. That operand is an extend, so at this stage it hides the actual (pre-extension) operand type.
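The split can be illustrated with a simplified C++ model of how many pieces type legalization would break a scalable vector type into. This is only a sketch. It assumes SVE's 128-bit granule and ignores types narrower than one granule. `ScalableVT` and `numSplitParts` are hypothetical names, not LLVM APIs.

```cpp
#include <cassert>

// Simplified model of a scalable vector type <vscale x MinElts x iEltBits>.
struct ScalableVT {
  unsigned MinElts;  // known-minimum element count
  unsigned EltBits;  // element width in bits
  unsigned minSizeInBits() const { return MinElts * EltBits; }
};

// Assumption: an SVE register holds a multiple of 128 bits, so the
// fully-packed legal integer types are nxv16i8, nxv8i16, nxv4i32, nxv2i64.
constexpr unsigned SVEGranuleBits = 128;

// How many 128-bit-minimum pieces type legalization would split VT into.
// Types narrower than one granule (which SVE legalizes by promotion rather
// than splitting) are simply reported as one piece in this model.
unsigned numSplitParts(ScalableVT VT) {
  unsigned Bits = VT.minSizeInBits();
  if (Bits <= SVEGranuleBits)
    return 1;
  assert(Bits % SVEGranuleBits == 0 && "model only handles whole granules");
  return Bits / SVEGranuleBits;
}
```

Under this model, `numSplitParts({16, 8})` (nxv16i8) is 1 and `numSplitParts({16, 32})` (nxv16i32) is 4. A check that only ever sees the extended `VT` is therefore looking at a type that would otherwise be broken into four pieces before the extend is accounted for.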
I don't think the pre-legalization DAG combine would work, for the reasons you pointed out. However, when I tried implementing the separate node, I ran into exactly the same issues.
I've added an operation action for `ISD::PARTIAL_REDUCE_UMLA` on `nxv16i32`, which is the type of `nxv16i8` after extension. The existing validation in `LowerPARTIAL_REDUCE_MLAToUSDOT` can then decide whether the node can actually be lowered to USDOT, falling back to unpacks and `mla`s if not.
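The two lowering paths can be compared in a rough scalar C++ sketch for a single 128-bit granule (vscale = 1). `usdotModel` models USDOT's per-lane dot product of four unsigned bytes with four signed bytes. `unpackMlaModel` models an unpack + `mla` expansion of `zext(a) * sext(b)`. The assumption there is that the unpacks yield four contiguous 4-element chunks. Both are hypothetical helpers, not LLVM code. They place products in different accumulator lanes but add the same set of products overall, and `laneSum` makes that visible.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// One 128-bit granule: 16 x i8 inputs reduced into a 4 x i32 accumulator.
using UBytes = std::array<uint8_t, 16>;  // zero-extended input
using SBytes = std::array<int8_t, 16>;   // sign-extended input
using Acc = std::array<int32_t, 4>;

// USDOT model: accumulator lane i gains the dot product of the unsigned
// bytes a[4i..4i+3] with the signed bytes b[4i..4i+3].
Acc usdotModel(Acc acc, const UBytes &a, const SBytes &b) {
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      acc[i] += int32_t(a[4 * i + j]) * int32_t(b[4 * i + j]);
  return acc;
}

// Fallback model: widen both inputs to i32, split the 16 products into four
// contiguous 4-lane chunks (as repeated lo/hi unpacks would), and add each
// chunk to the accumulator lane-wise with a multiply-accumulate.
Acc unpackMlaModel(Acc acc, const UBytes &a, const SBytes &b) {
  for (int chunk = 0; chunk < 4; ++chunk)
    for (int i = 0; i < 4; ++i) {
      int idx = 4 * chunk + i;
      acc[i] += int32_t(a[idx]) * int32_t(b[idx]);
    }
  return acc;
}

// Horizontal sum of the accumulator lanes.
int64_t laneSum(const Acc &acc) {
  int64_t s = 0;
  for (int32_t v : acc)
    s += v;
  return s;
}
```

For example, with `a` filled with 255 and `b` with -128, lane 0 of `usdotModel` starting from a zero accumulator is 4 * 255 * -128 = -130560. In the USDOT model each lane holds one contiguous group of four products. In the fallback model lane i instead collects elements i, i+4, i+8 and i+12.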

https://github.com/llvm/llvm-project/pull/131327

