[llvm] [AArch64][SVE] Add dot product lowering for PARTIAL_REDUCE_MLA node (PR #130933)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 17 07:41:46 PDT 2025


================
@@ -12528,8 +12528,10 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
     return SDValue();
 
-  // FIXME: Add a check to only perform the DAG combine if there is lowering
-  // provided by the target
+  // Only perform the DAG combine if there is custom lowering provided by the
+  // target
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), LHSExtOpVT))
----------------
sdesmalen-arm wrote:

I'd suggest moving some of the 'anticipation of what the type will be after legalisation' to this place here, so that only legal types are passed to `isPartialReduceMLALegalOrCustom`. If the types are not legal, you can query what they will be converted to during type legalisation using `TLI.getTypeToTransformTo`.

My understanding is that the operands can be legalised individually. For example, if the acc/res type is promoted, that doesn't require the mul-input types to be promoted too, or vice versa. I think the same holds for splitting.

That means you can just do something like:

```
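// Query the types that the result/accumulator and the mul inputs will have
// after one step of type legalisation. getTypeToTransformTo takes the
// LLVMContext as its first argument.
LLVMContext &Ctx = *DAG.getContext();
if (!TLI.isPartialReduceMLALegalOrCustom(
        TLI.getTypeToTransformTo(Ctx, N->getValueType(0)),
        TLI.getTypeToTransformTo(Ctx, LHSExtOpVT)))
  return SDValue();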
```

Perhaps, if a type needs splitting more than once, you can create a small utility function that calls `getTypeToTransformTo` repeatedly until it reaches the legal type. That might be useful anyway, because a case like this:
```
nxv2i32 partial.reduce.mla(nxv2i32 acc, nxv16i32 ext(nxv16i8), nxv16i32 ext(nxv16i8))
```
is currently legalised to:
```
nxv2i64 partial.reduce.mla(nxv2i64 acc, nxv16i64 ext(nxv16i8), nxv16i64 ext(nxv16i8))
```
This happens even though nxv16i8 -> nxv2i64 is currently not supported by udot/sdot. If the result/accumulator were widened (nxv2i32 -> nxv4i32) instead of promoted, we would be able to use udot/sdot, since nxv16i8 -> nxv4i32 is a form they do support. We could represent that information in the utility function.

But for this PR, just calling `getTypeToTransformTo` is sufficient.

https://github.com/llvm/llvm-project/pull/130933

