[llvm] [DAGCombine] Simplify partial_reduce_*mla with constant. (PR #138289)

Benjamin Maxwell via llvm-commits llvm-commits at lists.llvm.org
Fri May 2 09:02:13 PDT 2025


================
@@ -12612,47 +12612,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
+// partial_reduce_*mla(acc, mul(zext(a), zext(b)), splat(1))
+// -> partial_reduce_umla(acc, a, b)
+//
+// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, C)
 SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   SDLoc DL(N);
-
+  auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
 
-  APInt ConstantOne;
+  APInt C;
   if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
   SDValue RHS = Op1->getOperand(1);
   unsigned LHSOpcode = LHS->getOpcode();
-  unsigned RHSOpcode = RHS->getOpcode();
-  if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
+  if (!ISD::isExtOpcode(LHSOpcode))
     return SDValue();
 
   SDValue LHSExtOp = LHS->getOperand(0);
-  SDValue RHSExtOp = RHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
-  if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
-    return SDValue();
 
-  // Only perform the DAG combine if there is custom lowering provided by the
-  // target
-  auto *Context = DAG.getContext();
+  // Only perform these combines if the target supports folding
+  // the extends into the operation.
   if (!TLI.isPartialReduceMLALegalOrCustom(
           TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
   bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+  unsigned NewOpcode =
+      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+  // partial_reduce_*mla(acc, mul(zext(x), splat(C)), splat(1))
----------------
MacDue wrote:

```suggestion
  // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
```

https://github.com/llvm/llvm-project/pull/138289


More information about the llvm-commits mailing list