[llvm] [DAGCombiner] Add generic DAG combine for ISD::PARTIAL_REDUCE_MLA (PR #127083)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 19 07:17:11 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: James Chesterman (JamesChesterman)
<details>
<summary>Changes</summary>
Add generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA nodes. Transforms the DAG from:
PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1)) to
PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS).
---
Full diff: https://github.com/llvm/llvm-project/pull/127083.diff
4 Files Affected:
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+35)
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+69)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+5-2)
- (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+2-3)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a4c3d042fe3a4..52e57365dceab 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1639,6 +1639,25 @@ class TargetLoweringBase {
getCondCodeAction(CC, VT) == Custom;
}
+ /// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
+ /// InputVT should be treated. Either it's legal, needs to be promoted to a
+ /// larger size, needs to be expanded to some other code sequence, or the
+ /// target has a custom expander for it.
+ LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
+ unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
+ unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
+ assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
+ "Table isn't big enough!");
+ return PartialReduceMLAActions[AccI][InputI];
+ }
+
+ /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+ /// legal or custom for this target.
+ bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
+ return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
+ getPartialReduceMLAAction(AccVT, InputVT) == Custom;
+ }
+
/// If the action for this operation is to promote, this method returns the
/// ValueType to promote to.
MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2704,6 +2723,16 @@ class TargetLoweringBase {
setCondCodeAction(CCs, VT, Action);
}
+ /// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
+ /// type InputVT should be treated by the target. Either it's legal, needs to
+ /// be promoted to a larger size, needs to be expanded to some other code
+ /// sequence, or the target has a custom expander for it.
+ void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+ LegalizeAction Action) {
+ assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
+ PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
+ }
+
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
/// to trying a larger integer/fp until it can find one that works. If that
/// default is insufficient, this method can be used by the target to override
@@ -3650,6 +3679,12 @@ class TargetLoweringBase {
/// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
+ /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
+ /// nodes, keep a LegalizeAction which indicates how instruction selection
+ /// should deal with this operation.
+ LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
+ [MVT::VALUETYPE_SIZE];
+
ValueTypeActionImpl ValueTypeActions;
private:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bc7cdf38dbc2a..223260c43a38e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -545,6 +545,7 @@ namespace {
SDValue visitMGATHER(SDNode *N);
SDValue visitMSCATTER(SDNode *N);
SDValue visitMHISTOGRAM(SDNode *N);
+ SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
SDValue visitVPGATHER(SDNode *N);
SDValue visitVPSCATTER(SDNode *N);
SDValue visitVP_STRIDED_LOAD(SDNode *N);
@@ -621,6 +622,8 @@ namespace {
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
const TargetLowering &TLI);
+ SDValue foldMulPARTIAL_REDUCE_MLA(SDNode *N);
+ SDValue foldExtendPARTIAL_REDUCE_MLA(SDNode *N);
SDValue CombineExtLoad(SDNode *N);
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -1972,6 +1975,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::MSCATTER: return visitMSCATTER(N);
case ISD::MSTORE: return visitMSTORE(N);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
+ case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
@@ -12497,6 +12503,69 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
return SDValue();
}
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+ // Only perform the DAG combine if the target marks this node as Legal or
+ // Custom for its accumulator and input types.
+ if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0),
+ N->getOperand(1).getValueType()))
+ return SDValue();
+
+ if (SDValue Res = foldMulPARTIAL_REDUCE_MLA(N))
+ return Res;
+ if (SDValue Res = foldExtendPARTIAL_REDUCE_MLA(N))
+ return Res;
+ return SDValue();
+}
+
+SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
+ // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(MulOpLHS, MulOpRHS), Splat(1)) into
+ // PARTIAL_REDUCE_*MLA(Acc, MulOpLHS, MulOpRHS)
+ SDLoc DL(N);
+
+ SDValue Op1 = N->getOperand(1);
+ if (Op1->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ APInt ConstantOne;
+ if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
+ !ConstantOne.isOne())
+ return SDValue();
+
+ return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
+ Op1->getOperand(0), Op1->getOperand(1));
+}
+
+SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
+ // Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(MulOpLHS), ZEXT(MulOpRHS)) into
+ // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS) and
+ // PARTIAL_REDUCE_*MLA(Acc, SEXT(MulOpLHS), SEXT(MulOpRHS)) into
+ // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS)
+ SDLoc DL(N);
+ SDValue ExtMulOpLHS = N->getOperand(1);
+ SDValue ExtMulOpRHS = N->getOperand(2);
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+ return SDValue();
+
+ SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+ SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
+ if (MulOpLHSVT != MulOpRHS.getValueType())
+ return SDValue();
+
+ bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+ bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+ if (LHSIsSigned != RHSIsSigned)
+ return SDValue();
+
+ unsigned NewOpcode =
+ LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), N->getOperand(0),
+ MulOpLHS, MulOpRHS);
+}
+
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
auto *SLD = cast<VPStridedLoadSDNode>(N);
EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index de4447fb0cf1a..e43b14a47e565 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECTOR_COMPRESS:
case ISD::SCMP:
case ISD::UCMP:
- case ISD::PARTIAL_REDUCE_UMLA:
- case ISD::PARTIAL_REDUCE_SMLA:
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
break;
case ISD::SMULFIX:
@@ -524,6 +522,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
break;
}
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
+ Node->getOperand(1).getValueType());
+ break;
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
case ISD::VPID: { \
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f5ea3c0b47d6a..af97ce20fdb10 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::SET_FPENV, VT, Expand);
setOperationAction(ISD::RESET_FPENV, VT, Expand);
- // PartialReduceMLA operations default to expand.
- setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
- Expand);
+ for (MVT InputVT : MVT::all_valuetypes())
+ setPartialReduceMLAAction(VT, InputVT, Expand);
}
// Most targets ignore the @llvm.prefetch intrinsic.
``````````
</details>
https://github.com/llvm/llvm-project/pull/127083
More information about the llvm-commits
mailing list