[llvm] [DAGCombine] Simplify partial_reduce_*mla with constant. (PR #138289)
Benjamin Maxwell via llvm-commits
llvm-commits at lists.llvm.org
Fri May 2 09:02:13 PDT 2025
================
@@ -12612,47 +12612,63 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
-// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
-// Splat(1)) into
-// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
+// partial_reduce_*mla(acc, mul(zext(a), zext(b)), splat(1))
+// -> partial_reduce_umla(acc, a, b)
+//
+// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, C)
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);
-
+ auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt ConstantOne;
+ APInt C;
if (Op1->getOpcode() != ISD::MUL ||
- !ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
return SDValue();
SDValue LHS = Op1->getOperand(0);
SDValue RHS = Op1->getOperand(1);
unsigned LHSOpcode = LHS->getOpcode();
- unsigned RHSOpcode = RHS->getOpcode();
- if (!ISD::isExtOpcode(LHSOpcode) || !ISD::isExtOpcode(RHSOpcode))
+ if (!ISD::isExtOpcode(LHSOpcode))
return SDValue();
SDValue LHSExtOp = LHS->getOperand(0);
- SDValue RHSExtOp = RHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
- if (LHSExtOpVT != RHSExtOp.getValueType() || LHSOpcode != RHSOpcode)
- return SDValue();
- // Only perform the DAG combine if there is custom lowering provided by the
- // target
- auto *Context = DAG.getContext();
+ // Only perform these combines if the target supports folding
+ // the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+ unsigned NewOpcode =
+ ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+ // partial_reduce_*mla(acc, mul(zext(x), splat(C)), splat(1))
----------------
MacDue wrote:
```suggestion
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
```
https://github.com/llvm/llvm-project/pull/138289
More information about the llvm-commits mailing list