[llvm] [DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op (PR #131326)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri May 2 08:06:50 PDT 2025


================
@@ -12669,6 +12679,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
                      RHSExtOp);
 }
 
+// Makes partial.reduce.umla(acc, zext(op1), splat(1)) into
+// partial.reduce.umla(acc, op, splat(trunc(1)))
+// Makes partial.reduce.smla(acc, sext(op1), splat(1)) into
+// partial.reduce.smla(acc, op, splat(trunc(1)))
+SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  unsigned Op1Opcode = Op1.getOpcode();
+  if (!ISD::isExtOpcode(Op1Opcode))
+    return SDValue();
+
+  SDValue UnextOp1 = Op1.getOperand(0);
+  EVT UnextOp1VT = UnextOp1.getValueType();
+
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+    return SDValue();
+
+  SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
+
+  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+
+  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  EVT AccElemVT = Acc.getValueType().getVectorElementType();
+  if (Op1IsSigned != NodeIsSigned &&
+      (Op1.getValueType().getVectorElementType() != AccElemVT ||
+       Op2.getValueType().getVectorElementType() != AccElemVT))
----------------
sdesmalen-arm wrote:

I don't think you need to test the `Op2` case here, because the type of `Op1` must match that of `Op2`.
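
To illustrate, here is a small standalone sketch (a model, not LLVM code) of why the `Op2` comparison looks redundant. `ElemVT`, `mismatchOriginal`, `mismatchSimplified` and `conditionsAgree` are hypothetical names made up for this example, and it assumes, as stated above, that `Op1` and `Op2` always share the same element type.

```cpp
#include <cassert>

// Hypothetical stand-in for a vector element type; only the bit width matters.
struct ElemVT {
  unsigned Bits;
  bool operator==(const ElemVT &O) const { return Bits == O.Bits; }
  bool operator!=(const ElemVT &O) const { return !(*this == O); }
};

// Condition as written in the patch: compares both Op1 and Op2 against the
// accumulator element type.
bool mismatchOriginal(bool Op1IsSigned, bool NodeIsSigned, ElemVT Op1VT,
                      ElemVT Op2VT, ElemVT AccVT) {
  return Op1IsSigned != NodeIsSigned && (Op1VT != AccVT || Op2VT != AccVT);
}

// Possible simplification: only Op1 is compared.
bool mismatchSimplified(bool Op1IsSigned, bool NodeIsSigned, ElemVT Op1VT,
                        ElemVT AccVT) {
  return Op1IsSigned != NodeIsSigned && Op1VT != AccVT;
}

// Checks that both conditions agree for every combination of signedness and
// a few element widths, under the assumption that Op2's type equals Op1's.
bool conditionsAgree() {
  const bool Flags[] = {false, true};
  const unsigned Widths[] = {8, 16, 32, 64};
  for (bool Op1IsSigned : Flags)
    for (bool NodeIsSigned : Flags)
      for (unsigned OpBits : Widths)
        for (unsigned AccBits : Widths) {
          ElemVT OpVT{OpBits}, AccVT{AccBits};
          if (mismatchOriginal(Op1IsSigned, NodeIsSigned, OpVT, OpVT, AccVT) !=
              mismatchSimplified(Op1IsSigned, NodeIsSigned, OpVT, AccVT))
            return false;
        }
  return true;
}
```

Under that assumption the two predicates agree on every input, so the condition in the patch could probably be reduced to the `Op1` comparison alone.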

https://github.com/llvm/llvm-project/pull/131326

