[llvm] [DAGCombiner] Add generic DAG combine for ISD::PARTIAL_REDUCE_MLA (PR #127083)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 24 13:34:37 PST 2025


================
@@ -12497,6 +12501,54 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
+  // Splat(1)) into
+  // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
+  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
+  // Splat(1)) into
+  // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
+  SDLoc DL(N);
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+
+  if (Op1->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  SDValue ExtMulOpLHS = Op1->getOperand(0);
+  SDValue ExtMulOpRHS = Op1->getOperand(1);
+  unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+  unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+  if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+      !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+    return SDValue();
+
+  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+  EVT MulOpLHSVT = MulOpLHS.getValueType();
+  if (MulOpLHSVT != MulOpRHS.getValueType())
+    return SDValue();
+
+  // FIXME: Add a check to only perform the DAG combine if there is lowering
+  // provided by the target
+
+  bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+  bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+  if (LHSIsSigned != RHSIsSigned)
----------------
sdesmalen-arm wrote:

This needs a check to ensure that we **don't** fold for the following case:

```
partial_reduce_umla v2i64 %acc, mul(sext(v2i16 %lhs) to v2i32, sext(v2i16 %rhs) to v2i32), splat(1)
->
partial_reduce_smla v2i64 %acc, v2i16 %lhs, v2i16 %rhs
```

This is because the original expression mixes sign-extension (the `sext` nodes on the multiplicands) with zero-extension (`partial_reduce_umla` zero-extends the i32 product to i64), so rewriting it to use only sign-extension would change its semantics. There is currently no way to write an LLVM IR test case for this, but I want to prevent this combine from producing incorrect code when such a node has been created elsewhere in SelectionDAG.

https://github.com/llvm/llvm-project/pull/127083

