[llvm] [SDAG] Add partial_reduce_sumla node (PR #141267)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 2 22:12:42 PDT 2025


================
@@ -12702,26 +12704,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
     return SDValue();
 
   SDValue RHSExtOp = RHS->getOperand(0);
-  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+  if (LHSExtOpVT != RHSExtOp.getValueType())
+    return SDValue();
+
+  unsigned NewOpc;
+  if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
+    NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+  else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+    NewOpc = ISD::PARTIAL_REDUCE_UMLA;
+  else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+    NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
+  else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
+    NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
+    std::swap(LHSExtOp, RHSExtOp);
+  } else
     return SDValue();
-
-  // For a 2-stage extend the signedness of both of the extends must be the
-  // same. This is so the node can be folded into only a signed or unsigned
-  // node.
-  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  // For a 2-stage extend the signedness of both of the extends must match
+  // If the mul has the same type, there is no outer extend, and thus we
+  // can simply use the inner extends to pick the result node.
+  // TODO: extend to handle nonneg zext as sext
----------------
NickGuy-Arm wrote:

Can you elaborate on what you mean by this TODO? I'm not sure I follow why we'd want to handle a `zext` as a `sext` in this case.
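
For reference while reading the hunk above, the extend-pair selection can be modelled in isolation roughly as follows. This is only a sketch: `Ext`, `Reduce`, `Choice` and `pickOpcode` are hypothetical stand-ins invented for illustration, not LLVM types or functions, and the sketch does not cover the value-type check or the 2-stage extend handling.

```cpp
#include <cassert>

// Hypothetical stand-ins for ISD::SIGN_EXTEND / ISD::ZERO_EXTEND and the
// PARTIAL_REDUCE_* opcodes; these are NOT LLVM's real enums.
enum class Ext { Sign, Zero, Other };
enum class Reduce { SMLA, UMLA, SUMLA, None };

struct Choice {
  Reduce Opc;   // node kind chosen for the pair of extends
  bool Swapped; // true when the two extend operands must be exchanged
};

// Mirrors the if/else chain in the hunk: matching extends pick the signed or
// unsigned node, mixed extends pick the mixed node, and the zext/sext order
// is canonicalised by swapping so the sign-extended operand comes first.
Choice pickOpcode(Ext LHS, Ext RHS) {
  if (LHS == Ext::Sign && RHS == Ext::Sign)
    return {Reduce::SMLA, false};
  if (LHS == Ext::Zero && RHS == Ext::Zero)
    return {Reduce::UMLA, false};
  if (LHS == Ext::Sign && RHS == Ext::Zero)
    return {Reduce::SUMLA, false};
  if (LHS == Ext::Zero && RHS == Ext::Sign)
    return {Reduce::SUMLA, true};
  return {Reduce::None, false}; // any other opcode pair: no fold
}
```

In this model the only case that reorders operands is zext on the left with sext on the right, which matches the `std::swap(LHSExtOp, RHSExtOp)` in the patch.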

https://github.com/llvm/llvm-project/pull/141267

