[llvm] [AArch64][SVE] Add dot product lowering for PARTIAL_REDUCE_MLA node (PR #130933)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 17 07:41:46 PDT 2025
================
@@ -12528,8 +12528,10 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
return SDValue();
- // FIXME: Add a check to only perform the DAG combine if there is lowering
- // provided by the target
+ // Only perform the DAG combine if there is custom lowering provided by the
+ // target
+ if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), LHSExtOpVT))
----------------
sdesmalen-arm wrote:
I'd suggest moving some of the 'anticipation of what the type is after legalisation' to this place here, so that only legal types are passed to `isPartialReduceMLALegalOrCustom`. If the types are not legal, you can query what the legal type will be after type conversion using `TLI.getTypeToTransformTo`.
My understanding is that the operands can be 'legalized' individually: for example, if the acc/res type is promoted, that doesn't require the mul-input types to be promoted too, or vice versa. I think the same holds for splitting.
That means you can just do something like:
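```
// getTypeToTransformTo takes the LLVMContext as its first argument.
if (!TLI.isPartialReduceMLALegalOrCustom(
        TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)),
        TLI.getTypeToTransformTo(*DAG.getContext(), LHSExtOpVT)))
  return SDValue();
```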
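To illustrate why one query per operand can be enough, here is a minimal, self-contained C++ model; it is not LLVM code. `VecTy`, `isLegal`, `transformOnce` and `isDotLegal` are hypothetical stand-ins for `EVT`, type legality, a single `getTypeToTransformTo` step and `isPartialReduceMLALegalOrCustom`. The rule that a legal type fills exactly one 128-bit granule is a simplifying assumption.

```cpp
#include <cassert>

// Hypothetical, simplified model of a scalable integer vector type such as
// nxv16i8 (MinElts = 16, EltBits = 8). This is not LLVM's EVT.
struct VecTy {
  unsigned MinElts;
  unsigned EltBits;
  bool operator==(const VecTy &O) const {
    return MinElts == O.MinElts && EltBits == O.EltBits;
  }
};

// Assumption: a type is legal when it fills exactly one 128-bit SVE granule.
bool isLegal(VecTy T) { return T.MinElts * T.EltBits == 128; }

// Hypothetical single legalisation step, loosely mirroring one call to
// getTypeToTransformTo: legal types map to themselves, oversized types are
// split in half, undersized types have their elements promoted.
VecTy transformOnce(VecTy T) {
  assert(T.MinElts > 0 && T.EltBits > 0 && "empty type");
  unsigned Bits = T.MinElts * T.EltBits;
  if (Bits == 128)
    return T;
  if (Bits > 128)
    return {T.MinElts / 2, T.EltBits};
  return {T.MinElts, T.EltBits * 2};
}

// Hypothetical stand-in for isPartialReduceMLALegalOrCustom: only the two
// SVE udot/sdot shapes are accepted (i32 acc with i8 inputs, i64 acc with
// i16 inputs).
bool isDotLegal(VecTy Acc, VecTy In) {
  if (!isLegal(Acc) || !isLegal(In))
    return false;
  return (Acc == VecTy{4, 32} && In == VecTy{16, 8}) ||
         (Acc == VecTy{2, 64} && In == VecTy{8, 16});
}
```

For example, an nxv8i32 accumulator with nxv32i8 inputs is rejected as-is, but after each operand is split once, independently, the pair becomes nxv4i32 / nxv16i8, which matches the udot/sdot shape.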
Perhaps, if a type needs splitting more than once, you could create a small utility function that calls `getTypeToTransformTo` repeatedly until it reaches the legal type. That might be useful anyway, because a case like this:
```
nxv2i32 partial.reduce.mla(nxv2i32 acc, nxv16i32 ext(nxv16i8), nxv16i32 ext(nxv16i8))
```
is currently legalised to:
```
nxv2i64 partial.reduce.mla(nxv2i64 acc, nxv16i64 ext(nxv16i8), nxv16i64 ext(nxv16i8))
```
This happens even though nxv16i8 -> nxv2i64 is not currently supported by udot/sdot. If legalisation widened the result/accumulator (nxv2i32 -> nxv4i32) instead of promoting it, we would be able to use udot/sdot. We could represent that information in the utility function.
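A sketch of such a utility, again using a simplified, hypothetical model rather than LLVM's type legaliser: `VecTy`, `Policy`, `getLegalType`, `isDotShape` and `canUseDot` are illustrative names, and the 128-bit legality rule is an assumption. `getLegalType` applies single legalisation steps in a loop, and the `Policy` parameter chooses between promoting and widening an undersized accumulator, which is the distinction the nxv2i32 example depends on.

```cpp
#include <cassert>

// Hypothetical model of a scalable integer vector type, e.g. nxv2i32 is
// {MinElts = 2, EltBits = 32}. Not LLVM's EVT.
struct VecTy {
  unsigned MinElts;
  unsigned EltBits;
  bool operator==(const VecTy &O) const {
    return MinElts == O.MinElts && EltBits == O.EltBits;
  }
};

// How an undersized (< 128-bit) type is legalised. Promote matches what the
// review describes for nxv2i32 today; Widen is the alternative it proposes.
enum class Policy { Promote, Widen };

static bool isPow2(unsigned X) { return X != 0 && (X & (X - 1)) == 0; }

// Hypothetical utility: repeat single legalisation steps (like calling
// getTypeToTransformTo in a loop) until the type fills one 128-bit granule,
// which is the assumed legality rule of this model. Power-of-two sizes
// guarantee the loop terminates.
VecTy getLegalType(VecTy T, Policy P) {
  assert(isPow2(T.MinElts) && isPow2(T.EltBits) && T.EltBits <= 128);
  while (T.MinElts * T.EltBits != 128) {
    if (T.MinElts * T.EltBits > 128)
      T.MinElts /= 2; // split
    else if (P == Policy::Promote)
      T.EltBits *= 2; // promote the element type
    else
      T.MinElts *= 2; // widen the element count
  }
  return T;
}

// Only the two SVE udot/sdot shapes are accepted.
bool isDotShape(VecTy Acc, VecTy In) {
  return (Acc == VecTy{4, 32} && In == VecTy{16, 8}) ||
         (Acc == VecTy{2, 64} && In == VecTy{8, 16});
}

// The policy is applied to the accumulator only; the (narrow) inputs are
// legalised with the default promote/split steps.
bool canUseDot(VecTy Acc, VecTy In, Policy AccPolicy) {
  return isDotShape(getLegalType(Acc, AccPolicy),
                    getLegalType(In, Policy::Promote));
}
```

With `Policy::Promote`, an nxv2i32 accumulator becomes nxv2i64 and no dot shape matches nxv16i8 inputs; with `Policy::Widen` it becomes nxv4i32 and pairs with nxv16i8 as udot/sdot expects.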
But for this PR, just calling `getTypeToTransformTo` is sufficient.
https://github.com/llvm/llvm-project/pull/130933