[llvm] [SDAG] Add partial_reduce_sumla node (PR #141267)
Nicholas Guy via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 2 22:12:42 PDT 2025
================
@@ -12702,26 +12704,46 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
return SDValue();
SDValue RHSExtOp = RHS->getOperand(0);
- if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
+ if (LHSExtOpVT != RHSExtOp.getValueType())
+ return SDValue();
+
+ unsigned NewOpc;
+ if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::SIGN_EXTEND)
+ NewOpc = ISD::PARTIAL_REDUCE_SMLA;
+ else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+ NewOpc = ISD::PARTIAL_REDUCE_UMLA;
+ else if (LHSOpcode == ISD::SIGN_EXTEND && RHSOpcode == ISD::ZERO_EXTEND)
+ NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
+ else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
+ NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
+ std::swap(LHSExtOp, RHSExtOp);
+ } else
return SDValue();
-
- // For a 2-stage extend the signedness of both of the extends must be the
- // same. This is so the node can be folded into only a signed or unsigned
- // node.
- bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+ // For a 2-stage extend the signedness of both of the extends must match
+ // If the mul has the same type, there is no outer extend, and thus we
+ // can simply use the inner extends to pick the result node.
+ // TODO: extend to handle nonneg zext as sext
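For reference, here is a minimal sketch of the opcode selection the hunk introduces, lifted out of SelectionDAG. `ExtKind`, `ReduceOp`, `Operands` and `selectPartialReduceOp` are hypothetical stand-ins, not LLVM's real types. The swap in the zext/sext branch suggests that PARTIAL_REDUCE_SUMLA expects its signed input as the first operand. That reading is inferred from the diff, not confirmed elsewhere.

```cpp
#include <cassert>
#include <optional>
#include <utility>

// Hypothetical stand-ins for ISD::SIGN_EXTEND / ISD::ZERO_EXTEND and the
// PARTIAL_REDUCE_* opcodes; these are not LLVM's real enums.
enum class ExtKind { SignExtend, ZeroExtend, Other };
enum class ReduceOp { SMLA, UMLA, SUMLA };

// Placeholders for LHSExtOp / RHSExtOp (the operands of the two extends).
struct Operands {
  int LHS;
  int RHS;
};

// Returns the partial-reduce opcode to build, or std::nullopt when the
// combine should bail out (the `return SDValue();` path in the hunk).
// For a zext/sext pair the operands are swapped so the signed one is first.
std::optional<ReduceOp> selectPartialReduceOp(ExtKind L, ExtKind R,
                                              Operands &Ops) {
  if (L == ExtKind::SignExtend && R == ExtKind::SignExtend)
    return ReduceOp::SMLA;
  if (L == ExtKind::ZeroExtend && R == ExtKind::ZeroExtend)
    return ReduceOp::UMLA;
  if (L == ExtKind::SignExtend && R == ExtKind::ZeroExtend)
    return ReduceOp::SUMLA;
  if (L == ExtKind::ZeroExtend && R == ExtKind::SignExtend) {
    std::swap(Ops.LHS, Ops.RHS); // put the sign-extended input first
    return ReduceOp::SUMLA;
  }
  return std::nullopt; // any other extend combination is not handled
}
```

Mixed signedness no longer forces a bail-out. Only an operand that is neither a sext nor a zext does.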
----------------
NickGuy-Arm wrote:
Can you elaborate on what you mean by this TODO? I'm not sure I follow why we'd want to handle a `zext` as a `sext` in this case.
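One possible reading of the TODO assumes it refers to the `nneg` flag that LLVM IR permits on `zext`. That flag asserts that the source value's sign bit is clear. If so, zero- and sign-extension produce identical results, so a (sext, zext nneg) pair could in principle select SMLA instead of SUMLA. The sketch below only checks that arithmetic fact on 8-bit inputs. The helper names are hypothetical and this is not the PR author's stated intent.

```cpp
#include <cassert>
#include <cstdint>

// 8-bit -> 32-bit extensions, mirroring ISD::ZERO_EXTEND / ISD::SIGN_EXTEND.
// Hypothetical helpers for illustration only.
int32_t zext8(uint8_t V) { return static_cast<int32_t>(V); }
int32_t sext8(uint8_t V) {
  return static_cast<int32_t>(static_cast<int8_t>(V));
}

// Over every 8-bit input, zext and sext agree exactly when the sign bit is
// clear, i.e. when a "non-negative" flag on the zext would be truthful.
bool zextMatchesSextWhenNonNeg() {
  for (int I = 0; I < 256; ++I) {
    uint8_t V = static_cast<uint8_t>(I);
    bool NonNeg = (V & 0x80) == 0;
    if (NonNeg != (zext8(V) == sext8(V)))
      return false;
  }
  return true;
}

// Consequently (sext A) * (zext B) == (sext A) * (sext B) whenever B is
// non-negative, so a signed-by-unsigned product becomes signed-by-signed.
int32_t mixedProduct(uint8_t A, uint8_t B) { return sext8(A) * zext8(B); }
int32_t signedProduct(uint8_t A, uint8_t B) { return sext8(A) * sext8(B); }
```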
https://github.com/llvm/llvm-project/pull/141267