[llvm] [SDAG] Split the partial reduce legalize table by opcode [nfc] (PR #141970)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 08:55:23 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

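On its own, this change should be non-functional. It is a preparatory change for https://github.com/llvm/llvm-project/pull/141267, which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the new PARTIAL_REDUCE_SUMLA variant than for the existing PARTIAL_REDUCE_UMLA/SMLA nodes. To allow that, this change keys the partial-reduce legalize-action table by opcode in addition to the accumulator and input types, and updates all existing callers to pass the opcode.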

---
Full diff: https://github.com/llvm/llvm-project/pull/141970.diff


5 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+22-13) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+14-13) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+3-2) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+14-9) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+8-6) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index b818f4768c2c3..9c453f51e129d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1659,17 +1659,20 @@ class LLVM_ABI TargetLoweringBase {
   /// InputVT should be treated. Either it's legal, needs to be promoted to a
   /// larger size, needs to be expanded to some other code sequence, or the
   /// target has a custom expander for it.
-  LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
-    PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy,
-                                         InputVT.getSimpleVT().SimpleTy};
-    auto It = PartialReduceMLAActions.find(TypePair);
+  LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
+                                           EVT InputVT) const {
+    assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
+    PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
+                                    InputVT.getSimpleVT().SimpleTy};
+    auto It = PartialReduceMLAActions.find(Key);
     return It != PartialReduceMLAActions.end() ? It->second : Expand;
   }
 
   /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
   /// legal or custom for this target.
-  bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
-    LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
+  bool isPartialReduceMLALegalOrCustom(unsigned Opc, EVT AccVT,
+                                       EVT InputVT) const {
+    LegalizeAction Action = getPartialReduceMLAAction(Opc, AccVT, InputVT);
     return Action == Legal || Action == Custom;
   }
 
@@ -2754,12 +2757,18 @@ class LLVM_ABI TargetLoweringBase {
   /// type InputVT should be treated by the target. Either it's legal, needs to
   /// be promoted to a larger size, needs to be expanded to some other code
   /// sequence, or the target has a custom expander for it.
-  void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+  void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
                                  LegalizeAction Action) {
+    assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
     assert(AccVT.isValid() && InputVT.isValid() &&
            "setPartialReduceMLAAction types aren't valid");
-    PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy};
-    PartialReduceMLAActions[TypePair] = Action;
+    PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
+    PartialReduceMLAActions[Key] = Action;
+  }
+  void setPartialReduceMLAAction(ArrayRef<unsigned> Opcodes, MVT AccVT,
+                                 MVT InputVT, LegalizeAction Action) {
+    for (unsigned Opc : Opcodes)
+      setPartialReduceMLAAction(Opc, AccVT, InputVT, Action);
   }
 
   /// If Opc/OrigVT is specified as being promoted, the promotion code defaults
@@ -3751,10 +3760,10 @@ class LLVM_ABI TargetLoweringBase {
   uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
 
   using PartialReduceActionTypes =
-      std::pair<MVT::SimpleValueType, MVT::SimpleValueType>;
-  /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
-  /// nodes, keep a LegalizeAction which indicates how instruction selection
-  /// should deal with this operation.
+      std::tuple<unsigned, MVT::SimpleValueType, MVT::SimpleValueType>;
+  /// For each partial reduce opcode, result type and input type combination,
+  /// keep a LegalizeAction which indicates how instruction selection should
+  /// deal with this operation.
   DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions;
 
   ValueTypeActionImpl ValueTypeActions;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e05f85ea3bd8e..be2209a2f8faf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12673,17 +12673,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHSExtOp = LHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
 
+  bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+  unsigned NewOpcode =
+      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
   // Only perform these combines if the target supports folding
   // the extends into the operation.
   if (!TLI.isPartialReduceMLALegalOrCustom(
-          TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+          NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
-  bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
-  unsigned NewOpcode =
-      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
   // -> partial_reduce_*mla(acc, x, C)
   if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
@@ -12737,14 +12737,6 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   if (!ISD::isExtOpcode(Op1Opcode))
     return SDValue();
 
-  SDValue UnextOp1 = Op1.getOperand(0);
-  EVT UnextOp1VT = UnextOp1.getValueType();
-  auto *Context = DAG.getContext();
-  if (!TLI.isPartialReduceMLALegalOrCustom(
-          TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
-          TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
-    return SDValue();
-
   bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
   bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
@@ -12754,6 +12746,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   unsigned NewOpcode =
       Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+  SDValue UnextOp1 = Op1.getOperand(0);
+  EVT UnextOp1VT = UnextOp1.getValueType();
+  auto *Context = DAG.getContext();
+  if (!TLI.isPartialReduceMLALegalOrCustom(
+          NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+          TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
+    return SDValue();
+
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
                      DAG.getConstant(1, DL, UnextOp1VT));
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index affcd78ea61b0..910a40e5b5141 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -530,8 +530,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
-    Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
-                                           Node->getOperand(1).getValueType());
+    Action =
+        TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
+                                      Node->getOperand(1).getValueType());
     break;
 
 #define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...)                          \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a07afea963e20..f18d325148742 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1458,9 +1458,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FADD, VT, Custom);
 
     if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
-      setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
-      setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
-      setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
+      static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+                                        ISD::PARTIAL_REDUCE_UMLA};
+
+      setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
     }
 
   } else /* !isNeonAvailable */ {
@@ -1881,16 +1884,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
     // Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
     // Other pairs will default to 'Expand'.
-    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
-    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+    static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+                                      ISD::PARTIAL_REDUCE_UMLA};
+    setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv8i16, Legal);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Legal);
 
-    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
 
     // Wide add types
     if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
-      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
-      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
-      setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
     }
   }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 43c81b97a0e05..567f4c5b47d30 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1573,11 +1573,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   // zve32x is broken for partial_reduce_umla, but let's not make it worse.
   if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
-    setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
-    setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
-    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
-    setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
-    setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
+    static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+                                      ISD::PARTIAL_REDUCE_UMLA};
+    setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv8i32, MVT::nxv32i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv16i32, MVT::nxv64i8, Custom);
 
     if (Subtarget.useRVVForFixedLengthVectors()) {
       for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
@@ -1586,7 +1588,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
           continue;
         ElementCount EC = VT.getVectorElementCount();
         MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
-        setPartialReduceMLAAction(VT, ArgVT, Custom);
+        setPartialReduceMLAAction(MLAOps, VT, ArgVT, Custom);
       }
     }
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/141970


More information about the llvm-commits mailing list