[llvm] [CodeGen] Implement widening for partial.reduce.add (PR #161834)
Sam Tebbs via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 8 08:48:04 PDT 2025
================
@@ -7008,10 +7008,34 @@ SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
EVT VT = N->getValueType(0);
EVT WideAccVT = TLI.getTypeToTransformTo(*DAG.getContext(),
N->getOperand(0).getValueType());
- SDValue Zero = DAG.getConstant(0, DL, WideAccVT);
+ ElementCount WideAccEC = WideAccVT.getVectorElementCount();
+
+ // Widen mul-operands if needed, otherwise we'll end up with a
+ // node that isn't legal because the accumulator vector will not
+ // be a known multiple of the input vector.
SDValue MulOp1 = N->getOperand(1);
SDValue MulOp2 = N->getOperand(2);
- SDValue Acc = DAG.getInsertSubvector(DL, Zero, N->getOperand(0), 0);
+ EVT MulOpVT = MulOp1.getValueType();
+ ElementCount MulOpEC = MulOpVT.getVectorElementCount();
+ if (getTypeAction(MulOpVT) == TargetLowering::TypeWidenVector) {
+ EVT WideMulVT = GetWidenedVector(MulOp1).getValueType();
+ assert(WideMulVT.getVectorElementCount().isKnownMultipleOf(WideAccEC) &&
+ "Widening to a vector with less elements than accumulator?");
+ SDValue Zero = DAG.getConstant(0, DL, WideMulVT);
+ MulOp1 = DAG.getInsertSubvector(DL, Zero, MulOp1, 0);
+ MulOp2 = DAG.getInsertSubvector(DL, Zero, MulOp2, 0);
+ } else if (!MulOpEC.isKnownMultipleOf(WideAccEC)) {
+ assert(getTypeAction(MulOpVT) != TargetLowering::TypeLegal &&
+ "Expected Mul operands to need legalisation");
+ EVT WideMulVT = EVT::getVectorVT(*DAG.getContext(),
+ MulOpVT.getVectorElementType(), WideAccEC);
+ SDValue Zero = DAG.getConstant(0, DL, WideMulVT);
+ MulOp1 = DAG.getInsertSubvector(DL, Zero, MulOp1, 0);
+ MulOp2 = DAG.getInsertSubvector(DL, Zero, MulOp2, 0);
+ }
----------------
SamTebbs33 wrote:
It looks like the only differences between these two blocks are their assertions and how they assign `WideMulVT`, so they could share the zero vector construction and subvector insertion as:
```
bool NeedsWidening = getTypeAction(MulOpVT) == TargetLowering::TypeWidenVector;
bool NarrowMultipleOfWide = MulOpEC.isKnownMultipleOf(WideAccEC);
if (NeedsWidening || !NarrowMultipleOfWide) {
EVT WideMulVT;
if (NeedsWidening) {
assert(...)
...
} else {
assert(...)
...
}
SDValue Zero = ...
MulOp1 = ...
MulOp2 = ...
}
```
https://github.com/llvm/llvm-project/pull/161834
More information about the llvm-commits
mailing list