[llvm] [DAGCombiner] Add generic DAG combine for ISD::PARTIAL_REDUCE_MLA (PR #127083)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 28 01:01:17 PST 2025
================
@@ -12497,6 +12501,54 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+ // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(MulOpLHS), ZEXT(MulOpRHS)),
+ // Splat(1)) into
+ // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS).
+ // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(MulOpLHS), SEXT(MulOpRHS)),
+ // Splat(1)) into
+ // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS).
+ SDLoc DL(N);
+
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+
+ if (Op1->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ return SDValue();
+
+ SDValue ExtMulOpLHS = Op1->getOperand(0);
+ SDValue ExtMulOpRHS = Op1->getOperand(1);
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+ return SDValue();
+
+ SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+ SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
+ if (MulOpLHSVT != MulOpRHS.getValueType())
+ return SDValue();
+
+ // FIXME: Add a check to only perform the DAG combine if there is lowering
+ // provided by the target
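The scalar C++17 model below is a rough illustration of why the combine in this hunk preserves values. It is a sketch, not LLVM code. `Ext`, `extend`, `partialReduceMLA`, `beforeCombine` and `sameLanes` are hypothetical names. The model assumes 8-bit inputs and a 32-bit accumulator. It also assumes that input lane `I` is added into accumulator lane `I % NA`. That lane mapping belongs to the sketch only; the argument depends on the per-lane products, not on which accumulator lane receives them.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical: which extension the matcher found on both MUL operands.
enum class Ext { Zero, Sign };

// Widen one 8-bit lane to the 32-bit accumulator element type.
constexpr int32_t extend(uint8_t V, Ext K) {
  int32_t W = V;                                      // zero extension
  return (K == Ext::Sign && W >= 128) ? W - 256 : W;  // sign extension
}

// Post-combine form, PARTIAL_REDUCE_[U|S]MLA(Acc, In1, In2): the node itself
// extends In1/In2, multiplies lane-wise and adds each product into Acc.
// Assumption (sketch only): input lane I feeds accumulator lane I % NA.
template <std::size_t NA, std::size_t NI>
constexpr std::array<int32_t, NA>
partialReduceMLA(std::array<int32_t, NA> Acc, const std::array<uint8_t, NI> &In1,
                 const std::array<uint8_t, NI> &In2, Ext K) {
  static_assert(NI % NA == 0, "input lanes must be a multiple of acc lanes");
  for (std::size_t I = 0; I < NI; ++I)
    Acc[I % NA] += extend(In1[I], K) * extend(In2[I], K);
  return Acc;
}

// Pre-combine form, PARTIAL_REDUCE_*MLA(Acc, MUL(EXT(LHS), EXT(RHS)), Splat(1)):
// the products are materialized at 32 bits, then multiplied by the splat.
template <std::size_t NA, std::size_t NI>
constexpr std::array<int32_t, NA>
beforeCombine(std::array<int32_t, NA> Acc, const std::array<uint8_t, NI> &LHS,
              const std::array<uint8_t, NI> &RHS, Ext K) {
  for (std::size_t I = 0; I < NI; ++I) {
    int32_t Mul = extend(LHS[I], K) * extend(RHS[I], K);  // MUL(EXT, EXT)
    Acc[I % NA] += Mul * 1;                               // times Splat(1)
  }
  return Acc;
}

// Lane-by-lane equality (std::array's operator== is not constexpr in C++17).
template <std::size_t N>
constexpr bool sameLanes(const std::array<int32_t, N> &A,
                         const std::array<int32_t, N> &B) {
  for (std::size_t I = 0; I < N; ++I)
    if (A[I] != B[I])
      return false;
  return true;
}

// Example data: values >= 128 make zero- and sign-extension disagree.
constexpr std::array<int32_t, 2> Acc{10, 20};
constexpr std::array<uint8_t, 4> MulOpLHS{200, 3, 255, 7};
constexpr std::array<uint8_t, 4> MulOpRHS{2, 250, 1, 9};

// Dropping the explicit extends, the MUL and the Splat(1) changes nothing.
static_assert(sameLanes(beforeCombine(Acc, MulOpLHS, MulOpRHS, Ext::Zero),
                        partialReduceMLA(Acc, MulOpLHS, MulOpRHS, Ext::Zero)),
              "ZEXT pattern folds to UMLA");
static_assert(sameLanes(beforeCombine(Acc, MulOpLHS, MulOpRHS, Ext::Sign),
                        partialReduceMLA(Acc, MulOpLHS, MulOpRHS, Ext::Sign)),
              "SEXT pattern folds to SMLA");
```

Multiplying by `Splat(1)` is an identity, and the UMLA/SMLA node already extends its inputs. Folding the explicit `ZEXT`/`SEXT` and `MUL` into the node therefore only needs to record which extension was used. That is why the zero-extended pattern maps to UMLA and the sign-extended pattern maps to SMLA.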
----------------
JamesChesterman wrote:
I think this is better done in a follow-up patch. If I mark a type combination as Legal, the combined node will be sent to lowering, and no lowering exists for it yet. I would then need to set the action for every type to Expand. But that would stop the DAG combine from firing, so this patch would have no test changes showing that the combine works. I'll address this in a future patch.
https://github.com/llvm/llvm-project/pull/127083