[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 21 08:30:53 PST 2024


https://github.com/JamesChesterman created https://github.com/llvm/llvm-project/pull/117185

Add the opcode 'ISD::PARTIAL_REDUCE_ADD' and use it when building SDNodes for the `llvm.experimental.vector.partial.reduce.add` intrinsic. When the input and output types allow lowering to wide add or dot product instruction(s), the intrinsic is converted to a PARTIAL_REDUCE_ADD SDNode instead of a target intrinsic node. This will make legalisation, which will be added in a future patch, easier to implement.

>From 12769a1581ad3c434696d838c606ee83778f6032 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 21 Nov 2024 16:19:19 +0000
Subject: [PATCH] [AArch64][SVE] Add partial reduction SDNodes

Add the opcode 'ISD::PARTIAL_REDUCE_ADD' and use it when building
SDNodes for the partial reduction add intrinsic. When the input and
output types allow lowering to wide add or dot product instruction(s),
convert the intrinsic to a PARTIAL_REDUCE_ADD SDNode instead of a
target intrinsic node. This will make legalisation, which will be
added in a future patch, easier to implement.
---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  5 ++
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  5 ++
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 16 ++++++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  6 ++
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 12 +++-
 .../SelectionDAG/SelectionDAGBuilder.h        |  1 +
 .../SelectionDAG/SelectionDAGDumper.cpp       |  3 +
 .../Target/AArch64/AArch64ISelLowering.cpp    | 56 +++++++++----------
 8 files changed, 74 insertions(+), 30 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 0b6d155b6d161e..7809a1b26dd7cd 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,11 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
+  // The `llvm.experimental.vector.partial.reduce.add` intrinsic
+  // Operands: input chain, Accumulator, Input
+  // Outputs: Output
+  PARTIAL_REDUCE_ADD,
+
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
   // Outputs: output chain, glue
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 2e3507386df309..5d0a976fbb66a9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1595,6 +1595,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Get a partial reduction SD node for the DAG. This is done when the input
+  /// and output types can be legalised for wide add(s) or dot product(s).
+  SDValue getPartialReduceAddSDNode(SDLoc DL, SDValue Chain, SDValue Acc,
+                                    SDValue Input);
+
   /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
   /// its operands and ReducedTY is the intrinsic's return type.
   SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 677b59e0c8fbeb..6cdf87fd7895c7 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3010,6 +3010,22 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
   }
 };
 
+class PartialReduceAddSDNode : public SDNode {
+public:
+  friend class SelectionDAG;
+
+  PartialReduceAddSDNode(const DebugLoc &dl, SDVTList VTs)
+      : SDNode(ISD::PARTIAL_REDUCE_ADD, 0, dl, VTs) {}
+
+  const SDValue &getChain() const { return getOperand(0); }
+  const SDValue &getAcc() const { return getOperand(1); }
+  const SDValue &getInput() const { return getOperand(2); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::PARTIAL_REDUCE_ADD;
+  }
+};
+
 class FPStateAccessSDNode : public MemSDNode {
 public:
   friend class SelectionDAG;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3a8ec3c6105bc0..a7a208b6af6f9a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2452,6 +2452,12 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::getPartialReduceAddSDNode(SDLoc DL, SDValue Chain,
+                                                SDValue Acc, SDValue Input) {
+  return getNode(ISD::PARTIAL_REDUCE_ADD, DL, Acc.getValueType(), Chain, Acc,
+                 Input);
+}
+
 SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
                                           SDValue Op2) {
   EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9d729d448502d8..0480d99767bb75 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6415,6 +6415,16 @@ void SelectionDAGBuilder::visitVectorHistogram(const CallInst &I,
   DAG.setRoot(Histogram);
 }
 
+void SelectionDAGBuilder::visitPartialReduceAdd(const CallInst &I,
+                                                unsigned IntrinsicID) {
+  SDLoc dl = getCurSDLoc();
+  SDValue Acc = getValue(I.getOperand(0));
+  SDValue Input = getValue(I.getOperand(1));
+  SDValue Chain = getRoot();
+
+  setValue(&I, DAG.getPartialReduceAddSDNode(dl, Chain, Acc, Input));
+}
+
 void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
                                                        unsigned Intrinsic) {
   assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
@@ -8128,7 +8138,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   case Intrinsic::experimental_vector_partial_reduce_add: {
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
-      visitTargetIntrinsic(I, Intrinsic);
+      visitPartialReduceAdd(I, Intrinsic);
       return;
     }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index 3a8dc25e98700e..a9e0c8f1ea10c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -629,6 +629,7 @@ class SelectionDAGBuilder {
   void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
   void visitConvergenceControl(const CallInst &I, unsigned Intrinsic);
   void visitVectorHistogram(const CallInst &I, unsigned IntrinsicID);
+  void visitPartialReduceAdd(const CallInst &, unsigned IntrinsicID);
   void visitVectorExtractLastActive(const CallInst &I, unsigned Intrinsic);
   void visitVPLoad(const VPIntrinsic &VPIntrin, EVT VT,
                    const SmallVectorImpl<SDValue> &OpValues);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 580ff19065557b..8ce03b14bda46c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -567,6 +567,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return "histogram";
 
+  case ISD::PARTIAL_REDUCE_ADD:
+    return "partial_reduce_add";
+
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
   case ISD::SDID:                                                              \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7ab3fc06715ec8..6a3fbf3a8b596f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1124,6 +1124,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(
       {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
+  setTargetDAGCombine(ISD::PARTIAL_REDUCE_ADD);
+
   setTargetDAGCombine(ISD::FP_EXTEND);
 
   setTargetDAGCombine(ISD::GlobalAddress);
@@ -21718,26 +21720,21 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
+SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
                                       const AArch64Subtarget *Subtarget,
                                       SelectionDAG &DAG) {
 
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
-  bool Scalable = N->getValueType(0).isScalableVector();
+  bool Scalable = PR->getValueType(0).isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(N);
+  SDLoc DL(PR);
 
   // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = N->getOperand(1);
-  auto MulOp = N->getOperand(2);
+  auto NarrowOp = PR->getAcc();
+  auto MulOp = PR->getInput();
   if (MulOp->getOpcode() != ISD::MUL)
     return SDValue();
 
@@ -21755,7 +21752,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  EVT ReducedType = N->getValueType(0);
+  EVT ReducedType = PR->getValueType(0);
   EVT MulSrcType = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
@@ -21774,7 +21771,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
-    bool Scalable = N->getValueType(0).isScalableVT();
+    bool Scalable = PR->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
       return SDValue();
@@ -21805,22 +21802,17 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
                                           const AArch64Subtarget *Subtarget,
                                           SelectionDAG &DAG) {
 
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
 
-  SDLoc DL(N);
+  SDLoc DL(PR);
 
-  auto Acc = N->getOperand(1);
-  auto ExtInput = N->getOperand(2);
+  auto Acc = PR->getAcc();
+  auto ExtInput = PR->getInput();
 
   EVT AccVT = Acc.getValueType();
   EVT AccElemVT = AccVT.getVectorElementType();
@@ -21847,6 +21839,18 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
   return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
 }
 
+static SDValue
+performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+                               const AArch64Subtarget *Subtarget) {
+  auto *PR = cast<PartialReduceAddSDNode>(N);
+  if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+    return Dot;
+  if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+    return WideAdd;
+  return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
+                                 PR->getInput());
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21855,14 +21859,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
-  case Intrinsic::experimental_vector_partial_reduce_add: {
-    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
-      return Dot;
-    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
-      return WideAdd;
-    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
-                                   N->getOperand(1), N->getOperand(2));
-  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
@@ -26148,6 +26144,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::MSCATTER:
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
+  case ISD::PARTIAL_REDUCE_ADD:
+    return performPartialReduceAddCombine(N, DAG, Subtarget);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
   case AArch64ISD::BRCOND:



More information about the llvm-commits mailing list