[llvm] [SDAG] Split the partial reduce legalize table by opcode [nfc] (PR #141970)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 29 08:55:23 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: Philip Reames (preames)
<details>
<summary>Changes</summary>
On its own, this change should be non-functional. This is a preparatory change for https://github.com/llvm/llvm-project/pull/141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.
---
Full diff: https://github.com/llvm/llvm-project/pull/141970.diff
5 Files Affected:
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+22-13)
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+14-13)
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+3-2)
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+14-9)
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+8-6)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index b818f4768c2c3..9c453f51e129d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1659,17 +1659,20 @@ class LLVM_ABI TargetLoweringBase {
/// InputVT should be treated. Either it's legal, needs to be promoted to a
/// larger size, needs to be expanded to some other code sequence, or the
/// target has a custom expander for it.
- LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
- PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy,
- InputVT.getSimpleVT().SimpleTy};
- auto It = PartialReduceMLAActions.find(TypePair);
+ LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
+ EVT InputVT) const {
+ assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
+ PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
+ InputVT.getSimpleVT().SimpleTy};
+ auto It = PartialReduceMLAActions.find(Key);
return It != PartialReduceMLAActions.end() ? It->second : Expand;
}
/// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
/// legal or custom for this target.
- bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
- LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
+ bool isPartialReduceMLALegalOrCustom(unsigned Opc, EVT AccVT,
+ EVT InputVT) const {
+ LegalizeAction Action = getPartialReduceMLAAction(Opc, AccVT, InputVT);
return Action == Legal || Action == Custom;
}
@@ -2754,12 +2757,18 @@ class LLVM_ABI TargetLoweringBase {
/// type InputVT should be treated by the target. Either it's legal, needs to
/// be promoted to a larger size, needs to be expanded to some other code
/// sequence, or the target has a custom expander for it.
- void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+ void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
LegalizeAction Action) {
+ assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
assert(AccVT.isValid() && InputVT.isValid() &&
"setPartialReduceMLAAction types aren't valid");
- PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy};
- PartialReduceMLAActions[TypePair] = Action;
+ PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
+ PartialReduceMLAActions[Key] = Action;
+ }
+ void setPartialReduceMLAAction(ArrayRef<unsigned> Opcodes, MVT AccVT,
+ MVT InputVT, LegalizeAction Action) {
+ for (unsigned Opc : Opcodes)
+ setPartialReduceMLAAction(Opc, AccVT, InputVT, Action);
}
/// If Opc/OrigVT is specified as being promoted, the promotion code defaults
@@ -3751,10 +3760,10 @@ class LLVM_ABI TargetLoweringBase {
uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
using PartialReduceActionTypes =
- std::pair<MVT::SimpleValueType, MVT::SimpleValueType>;
- /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
- /// nodes, keep a LegalizeAction which indicates how instruction selection
- /// should deal with this operation.
+ std::tuple<unsigned, MVT::SimpleValueType, MVT::SimpleValueType>;
+ /// For each partial reduce opcode, result type and input type combination,
+ /// keep a LegalizeAction which indicates how instruction selection should
+ /// deal with this operation.
DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions;
ValueTypeActionImpl ValueTypeActions;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e05f85ea3bd8e..be2209a2f8faf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12673,17 +12673,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
+ bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+ unsigned NewOpcode =
+ ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
// Only perform these combines if the target supports folding
// the extends into the operation.
if (!TLI.isPartialReduceMLALegalOrCustom(
- TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+ NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
- bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
- unsigned NewOpcode =
- ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
@@ -12737,14 +12737,6 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!ISD::isExtOpcode(Op1Opcode))
return SDValue();
- SDValue UnextOp1 = Op1.getOperand(0);
- EVT UnextOp1VT = UnextOp1.getValueType();
- auto *Context = DAG.getContext();
- if (!TLI.isPartialReduceMLALegalOrCustom(
- TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
- TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
- return SDValue();
-
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
@@ -12754,6 +12746,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
unsigned NewOpcode =
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+ SDValue UnextOp1 = Op1.getOperand(0);
+ EVT UnextOp1VT = UnextOp1.getValueType();
+ auto *Context = DAG.getContext();
+ if (!TLI.isPartialReduceMLALegalOrCustom(
+ NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+ TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
+ return SDValue();
+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
DAG.getConstant(1, DL, UnextOp1VT));
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index affcd78ea61b0..910a40e5b5141 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -530,8 +530,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
- Node->getOperand(1).getValueType());
+ Action =
+ TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
+ Node->getOperand(1).getValueType());
break;
#define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...) \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a07afea963e20..f18d325148742 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1458,9 +1458,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FADD, VT, Custom);
if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
- setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
- setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
- setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+ ISD::PARTIAL_REDUCE_UMLA};
+
+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
}
} else /* !isNeonAvailable */ {
@@ -1881,16 +1884,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
// Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
// Other pairs will default to 'Expand'.
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
- setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+ ISD::PARTIAL_REDUCE_UMLA};
+ setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv8i16, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Legal);
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
// Wide add types
if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
- setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
- setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
- setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 43c81b97a0e05..567f4c5b47d30 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1573,11 +1573,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// zve32x is broken for partial_reduce_umla, but let's not make it worse.
if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
- setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
- setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
- setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
- setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
- setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+ ISD::PARTIAL_REDUCE_UMLA};
+ setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv8i32, MVT::nxv32i8, Custom);
+ setPartialReduceMLAAction(MLAOps, MVT::nxv16i32, MVT::nxv64i8, Custom);
if (Subtarget.useRVVForFixedLengthVectors()) {
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
@@ -1586,7 +1588,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
continue;
ElementCount EC = VT.getVectorElementCount();
MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
- setPartialReduceMLAAction(VT, ArgVT, Custom);
+ setPartialReduceMLAAction(MLAOps, VT, ArgVT, Custom);
}
}
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/141970
More information about the llvm-commits
mailing list