[llvm] [DAGCombiner] Add generic DAG combine for ISD::PARTIAL_REDUCE_MLA (PR #127083)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 19 07:17:11 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: James Chesterman (JamesChesterman)

<details>
<summary>Changes</summary>

Add a generic DAG combine for ISD::PARTIAL_REDUCE_UMLA and ISD::PARTIAL_REDUCE_SMLA nodes. The combine transforms the DAG from:
PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat(1)) to
PARTIAL_REDUCE_MLA(Acc, MulOpLHS, MulOpRHS).

---
Full diff: https://github.com/llvm/llvm-project/pull/127083.diff


4 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+35) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+69) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+5-2) 
- (modified) llvm/lib/CodeGen/TargetLoweringBase.cpp (+2-3) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a4c3d042fe3a4..52e57365dceab 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1639,6 +1639,25 @@ class TargetLoweringBase {
            getCondCodeAction(CC, VT) == Custom;
   }
 
+  /// Return how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input type
+  /// InputVT should be treated. Either it's legal, needs to be promoted to a
+  /// larger size, needs to be expanded to some other code sequence, or the
+  /// target has a custom expander for it.
+  LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
+    unsigned AccI = (unsigned)AccVT.getSimpleVT().SimpleTy;
+    unsigned InputI = (unsigned)InputVT.getSimpleVT().SimpleTy;
+    assert(AccI < MVT::VALUETYPE_SIZE && InputI < MVT::VALUETYPE_SIZE &&
+           "Table isn't big enough!");
+    return PartialReduceMLAActions[AccI][InputI];
+  }
+
+  /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
+  /// legal or custom for this target.
+  bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
+    return getPartialReduceMLAAction(AccVT, InputVT) == Legal ||
+           getPartialReduceMLAAction(AccVT, InputVT) == Custom;
+  }
+
   /// If the action for this operation is to promote, this method returns the
   /// ValueType to promote to.
   MVT getTypeToPromoteTo(unsigned Op, MVT VT) const {
@@ -2704,6 +2723,16 @@ class TargetLoweringBase {
       setCondCodeAction(CCs, VT, Action);
   }
 
+  /// Indicate how a PARTIAL_REDUCE_U/SMLA node with Acc type AccVT and Input
+  /// type InputVT should be treated by the target. Either it's legal, needs to
+  /// be promoted to a larger size, needs to be expanded to some other code
+  /// sequence, or the target has a custom expander for it.
+  void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+                                 LegalizeAction Action) {
+    assert(AccVT.isValid() && InputVT.isValid() && "Table isn't big enough!");
+    PartialReduceMLAActions[AccVT.SimpleTy][InputVT.SimpleTy] = Action;
+  }
+
   /// If Opc/OrigVT is specified as being promoted, the promotion code defaults
   /// to trying a larger integer/fp until it can find one that works. If that
   /// default is insufficient, this method can be used by the target to override
@@ -3650,6 +3679,12 @@ class TargetLoweringBase {
   /// up the MVT::VALUETYPE_SIZE value to the next multiple of 8.
   uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
 
+  /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
+  /// nodes, keep a LegalizeAction which indicates how instruction selection
+  /// should deal with this operation.
+  LegalizeAction PartialReduceMLAActions[MVT::VALUETYPE_SIZE]
+                                        [MVT::VALUETYPE_SIZE];
+
   ValueTypeActionImpl ValueTypeActions;
 
 private:
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bc7cdf38dbc2a..223260c43a38e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -545,6 +545,7 @@ namespace {
     SDValue visitMGATHER(SDNode *N);
     SDValue visitMSCATTER(SDNode *N);
     SDValue visitMHISTOGRAM(SDNode *N);
+    SDValue visitPARTIAL_REDUCE_MLA(SDNode *N);
     SDValue visitVPGATHER(SDNode *N);
     SDValue visitVPSCATTER(SDNode *N);
     SDValue visitVP_STRIDED_LOAD(SDNode *N);
@@ -621,6 +622,8 @@ namespace {
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI);
+    SDValue foldMulPARTIAL_REDUCE_MLA(SDNode *N);
+    SDValue foldExtendPARTIAL_REDUCE_MLA(SDNode *N);
 
     SDValue CombineExtLoad(SDNode *N);
     SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -1972,6 +1975,9 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::MSCATTER:           return visitMSCATTER(N);
   case ISD::MSTORE:             return visitMSTORE(N);
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
+  case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_UMLA:
+                                return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
@@ -12497,6 +12503,69 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+  // Only perform the DAG combine if the target marks this node as legal or
+  // provides custom lowering for it.
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0),
+                                           N->getOperand(1).getValueType()))
+    return SDValue();
+
+  if (SDValue Res = foldMulPARTIAL_REDUCE_MLA(N))
+    return Res;
+  if (SDValue Res = foldExtendPARTIAL_REDUCE_MLA(N))
+    return Res;
+  return SDValue();
+}
+
+SDValue DAGCombiner::foldMulPARTIAL_REDUCE_MLA(SDNode *N) {
+  // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(MulOpLHS, MulOpRHS), Splat(1)) into
+  // PARTIAL_REDUCE_*MLA(Acc, MulOpLHS, MulOpRHS)
+  SDLoc DL(N);
+
+  SDValue Op1 = N->getOperand(1);
+  if (Op1->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(N->getOperand(2).getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  return DAG.getNode(N->getOpcode(), DL, N->getValueType(0), N->getOperand(0),
+                     Op1->getOperand(0), Op1->getOperand(1));
+}
+
+SDValue DAGCombiner::foldExtendPARTIAL_REDUCE_MLA(SDNode *N) {
+  // Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(MulOpLHS), ZEXT(MulOpRHS)) into
+  // PARTIAL_REDUCE_UMLA(Acc, MulOpLHS, MulOpRHS) and
+  // PARTIAL_REDUCE_*MLA(Acc, SEXT(MulOpLHS), SEXT(MulOpRHS)) into
+  // PARTIAL_REDUCE_SMLA(Acc, MulOpLHS, MulOpRHS)
+  SDLoc DL(N);
+  SDValue ExtMulOpLHS = N->getOperand(1);
+  SDValue ExtMulOpRHS = N->getOperand(2);
+  unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+  unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+  if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+      !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+    return SDValue();
+
+  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+  EVT MulOpLHSVT = MulOpLHS.getValueType();
+  if (MulOpLHSVT != MulOpRHS.getValueType())
+    return SDValue();
+
+  bool LHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+  bool RHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+  if (LHSIsSigned != RHSIsSigned)
+    return SDValue();
+
+  unsigned NewOpcode =
+      LHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), N->getOperand(0),
+                     MulOpLHS, MulOpRHS);
+}
+
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
   auto *SLD = cast<VPStridedLoadSDNode>(N);
   EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index de4447fb0cf1a..e43b14a47e565 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -469,8 +469,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECTOR_COMPRESS:
   case ISD::SCMP:
   case ISD::UCMP:
-  case ISD::PARTIAL_REDUCE_UMLA:
-  case ISD::PARTIAL_REDUCE_SMLA:
     Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     break;
   case ISD::SMULFIX:
@@ -524,6 +522,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       Action = TLI.getOperationAction(Node->getOpcode(), OpVT);
     break;
   }
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
+                                           Node->getOperand(1).getValueType());
+    break;
 
 #define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...)                          \
   case ISD::VPID: {                                                            \
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f5ea3c0b47d6a..af97ce20fdb10 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -836,9 +836,8 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::SET_FPENV, VT, Expand);
     setOperationAction(ISD::RESET_FPENV, VT, Expand);
 
-    // PartialReduceMLA operations default to expand.
-    setOperationAction({ISD::PARTIAL_REDUCE_UMLA, ISD::PARTIAL_REDUCE_SMLA}, VT,
-                       Expand);
+    for (MVT InputVT : MVT::all_valuetypes())
+      setPartialReduceMLAAction(VT, InputVT, Expand);
   }
 
   // Most targets ignore the @llvm.prefetch intrinsic.

``````````

</details>


https://github.com/llvm/llvm-project/pull/127083


More information about the llvm-commits mailing list