[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 28 04:13:04 PST 2024


https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/117185

>From 12769a1581ad3c434696d838c606ee83778f6032 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 21 Nov 2024 16:19:19 +0000
Subject: [PATCH 1/3] [AArch64][SVE] Add partial reduction SDNodes

Add the opcode 'ISD::PARTIAL_REDUCE_ADD' and use it when building the
SelectionDAG. When the input and output types allow lowering to wide
add or dot product instruction(s), convert the
`llvm.experimental.vector.partial.reduce.add` intrinsic to this SDNode
instead of a target intrinsic node. This will make legalisation, which
will be added in a future patch, easier to implement.
---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  5 ++
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  5 ++
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 16 ++++++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  6 ++
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 12 +++-
 .../SelectionDAG/SelectionDAGBuilder.h        |  1 +
 .../SelectionDAG/SelectionDAGDumper.cpp       |  3 +
 .../Target/AArch64/AArch64ISelLowering.cpp    | 56 +++++++++----------
 8 files changed, 74 insertions(+), 30 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 0b6d155b6d161e..7809a1b26dd7cd 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,11 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
+  // The `llvm.experimental.vector.partial.reduce.add` intrinsic
+  // Operands: Accumulator, Input
+  // Outputs: Output
+  PARTIAL_REDUCE_ADD,
+
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
   // Outputs: output chain, glue
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 2e3507386df309..5d0a976fbb66a9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1595,6 +1595,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Get a partial reduction SD node for the DAG. This is done when the input
+  /// and output types can be legalised for wide add(s) or dot product(s)
+  SDValue getPartialReduceAddSDNode(SDLoc DL, SDValue Chain, SDValue Acc,
+                                    SDValue Input);
+
   /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
   /// its operands and ReducedTY is the intrinsic's return type.
   SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 677b59e0c8fbeb..6cdf87fd7895c7 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3010,6 +3010,22 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
   }
 };
 
+class PartialReduceAddSDNode : public SDNode {
+public:
+  friend class SelectionDAG;
+
+  PartialReduceAddSDNode(const DebugLoc &dl, SDVTList VTs)
+      : SDNode(ISD::PARTIAL_REDUCE_ADD, 0, dl, VTs) {}
+
+  const SDValue &getChain() const { return getOperand(0); }
+  const SDValue &getAcc() const { return getOperand(1); }
+  const SDValue &getInput() const { return getOperand(2); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::PARTIAL_REDUCE_ADD;
+  }
+};
+
 class FPStateAccessSDNode : public MemSDNode {
 public:
   friend class SelectionDAG;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3a8ec3c6105bc0..a7a208b6af6f9a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2452,6 +2452,12 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::getPartialReduceAddSDNode(SDLoc DL, SDValue Chain,
+                                                SDValue Acc, SDValue Input) {
+  return getNode(ISD::PARTIAL_REDUCE_ADD, DL, Acc.getValueType(), Chain, Acc,
+                 Input);
+}
+
 SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
                                           SDValue Op2) {
   EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9d729d448502d8..0480d99767bb75 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6415,6 +6415,16 @@ void SelectionDAGBuilder::visitVectorHistogram(const CallInst &I,
   DAG.setRoot(Histogram);
 }
 
+void SelectionDAGBuilder::visitPartialReduceAdd(const CallInst &I,
+                                                unsigned IntrinsicID) {
+  SDLoc dl = getCurSDLoc();
+  SDValue Acc = getValue(I.getOperand(0));
+  SDValue Input = getValue(I.getOperand(1));
+  SDValue Chain = getRoot();
+
+  setValue(&I, DAG.getPartialReduceAddSDNode(dl, Chain, Acc, Input));
+}
+
 void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
                                                        unsigned Intrinsic) {
   assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
@@ -8128,7 +8138,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   case Intrinsic::experimental_vector_partial_reduce_add: {
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
-      visitTargetIntrinsic(I, Intrinsic);
+      visitPartialReduceAdd(I, Intrinsic);
       return;
     }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index 3a8dc25e98700e..a9e0c8f1ea10c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -629,6 +629,7 @@ class SelectionDAGBuilder {
   void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
   void visitConvergenceControl(const CallInst &I, unsigned Intrinsic);
   void visitVectorHistogram(const CallInst &I, unsigned IntrinsicID);
+  void visitPartialReduceAdd(const CallInst &, unsigned IntrinsicID);
   void visitVectorExtractLastActive(const CallInst &I, unsigned Intrinsic);
   void visitVPLoad(const VPIntrinsic &VPIntrin, EVT VT,
                    const SmallVectorImpl<SDValue> &OpValues);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 580ff19065557b..8ce03b14bda46c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -567,6 +567,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return "histogram";
 
+  case ISD::PARTIAL_REDUCE_ADD:
+    return "partial_reduce_add";
+
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
   case ISD::SDID:                                                              \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7ab3fc06715ec8..6a3fbf3a8b596f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1124,6 +1124,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(
       {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
+  setTargetDAGCombine(ISD::PARTIAL_REDUCE_ADD);
+
   setTargetDAGCombine(ISD::FP_EXTEND);
 
   setTargetDAGCombine(ISD::GlobalAddress);
@@ -21718,26 +21720,21 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
+SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
                                       const AArch64Subtarget *Subtarget,
                                       SelectionDAG &DAG) {
 
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
-  bool Scalable = N->getValueType(0).isScalableVector();
+  bool Scalable = PR->getValueType(0).isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(N);
+  SDLoc DL(PR);
 
   // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = N->getOperand(1);
-  auto MulOp = N->getOperand(2);
+  auto NarrowOp = PR->getAcc();
+  auto MulOp = PR->getInput();
   if (MulOp->getOpcode() != ISD::MUL)
     return SDValue();
 
@@ -21755,7 +21752,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  EVT ReducedType = N->getValueType(0);
+  EVT ReducedType = PR->getValueType(0);
   EVT MulSrcType = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
@@ -21774,7 +21771,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
-    bool Scalable = N->getValueType(0).isScalableVT();
+    bool Scalable = PR->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
       return SDValue();
@@ -21805,22 +21802,17 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
                                           const AArch64Subtarget *Subtarget,
                                           SelectionDAG &DAG) {
 
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
 
-  SDLoc DL(N);
+  SDLoc DL(PR);
 
-  auto Acc = N->getOperand(1);
-  auto ExtInput = N->getOperand(2);
+  auto Acc = PR->getAcc();
+  auto ExtInput = PR->getInput();
 
   EVT AccVT = Acc.getValueType();
   EVT AccElemVT = AccVT.getVectorElementType();
@@ -21847,6 +21839,18 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
   return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
 }
 
+static SDValue
+performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+                               const AArch64Subtarget *Subtarget) {
+  auto *PR = cast<PartialReduceAddSDNode>(N);
+  if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+    return Dot;
+  if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+    return WideAdd;
+  return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
+                                 PR->getInput());
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21855,14 +21859,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
-  case Intrinsic::experimental_vector_partial_reduce_add: {
-    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
-      return Dot;
-    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
-      return WideAdd;
-    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
-                                   N->getOperand(1), N->getOperand(2));
-  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
@@ -26148,6 +26144,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::MSCATTER:
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
+  case ISD::PARTIAL_REDUCE_ADD:
+    return performPartialReduceAddCombine(N, DAG, Subtarget);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
   case AArch64ISD::BRCOND:

>From 13ff10bf06ac2716d590516267870195ef344113 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 27 Nov 2024 09:48:06 +0000
Subject: [PATCH 2/3] Rework previous patch: remove
 PartialReduceAddSDNode and change how the intrinsic is transformed into the
 SD node.

---
 llvm/include/llvm/CodeGen/SelectionDAG.h      | 13 +++-----
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 16 ----------
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 10 ++----
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 21 ++++---------
 .../SelectionDAG/SelectionDAGBuilder.h        |  1 -
 .../Target/AArch64/AArch64ISelLowering.cpp    | 31 +++++++++----------
 6 files changed, 27 insertions(+), 65 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 5d0a976fbb66a9..cda6f6192ca743 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1595,15 +1595,10 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
-  /// Get a partial reduction SD node for the DAG. This is done when the input
-  /// and output types can be legalised for wide add(s) or dot product(s)
-  SDValue getPartialReduceAddSDNode(SDLoc DL, SDValue Chain, SDValue Acc,
-                                    SDValue Input);
-
-  /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
-  /// its operands and ReducedTY is the intrinsic's return type.
-  SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
-                              SDValue Op2);
+  /// Expands partial reduce node which can't be lowered to wide add or dot
+  /// product instruction(s)
+  SDValue expandPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                                 SDValue Op2);
 
   /// Expands a node with multiple results to an FP or vector libcall. The
   /// libcall is expected to take all the operands of the \p Node followed by
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 6cdf87fd7895c7..677b59e0c8fbeb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3010,22 +3010,6 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
   }
 };
 
-class PartialReduceAddSDNode : public SDNode {
-public:
-  friend class SelectionDAG;
-
-  PartialReduceAddSDNode(const DebugLoc &dl, SDVTList VTs)
-      : SDNode(ISD::PARTIAL_REDUCE_ADD, 0, dl, VTs) {}
-
-  const SDValue &getChain() const { return getOperand(0); }
-  const SDValue &getAcc() const { return getOperand(1); }
-  const SDValue &getInput() const { return getOperand(2); }
-
-  static bool classof(const SDNode *N) {
-    return N->getOpcode() == ISD::PARTIAL_REDUCE_ADD;
-  }
-};
-
 class FPStateAccessSDNode : public MemSDNode {
 public:
   friend class SelectionDAG;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index a7a208b6af6f9a..e308a630692768 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2452,14 +2452,8 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
-SDValue SelectionDAG::getPartialReduceAddSDNode(SDLoc DL, SDValue Chain,
-                                                SDValue Acc, SDValue Input) {
-  return getNode(ISD::PARTIAL_REDUCE_ADD, DL, Acc.getValueType(), Chain, Acc,
-                 Input);
-}
-
-SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
-                                          SDValue Op2) {
+SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, EVT ReducedTy,
+                                             SDValue Op1, SDValue Op2) {
   EVT FullTy = Op2.getValueType();
 
   unsigned Stride = ReducedTy.getVectorMinNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 0480d99767bb75..39e4425636eeb6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6415,16 +6415,6 @@ void SelectionDAGBuilder::visitVectorHistogram(const CallInst &I,
   DAG.setRoot(Histogram);
 }
 
-void SelectionDAGBuilder::visitPartialReduceAdd(const CallInst &I,
-                                                unsigned IntrinsicID) {
-  SDLoc dl = getCurSDLoc();
-  SDValue Acc = getValue(I.getOperand(0));
-  SDValue Input = getValue(I.getOperand(1));
-  SDValue Chain = getRoot();
-
-  setValue(&I, DAG.getPartialReduceAddSDNode(dl, Chain, Acc, Input));
-}
-
 void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
                                                        unsigned Intrinsic) {
   assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
@@ -8136,15 +8126,16 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+    SDLoc dl = getCurSDLoc();
+    SDValue Acc = getValue(I.getOperand(0));
+    EVT AccVT = Acc.getValueType();
+    SDValue Input = getValue(I.getOperand(1));
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
-      visitPartialReduceAdd(I, Intrinsic);
+      setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_ADD, dl, AccVT, Acc, Input));
       return;
     }
-
-    setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
-                                         getValue(I.getOperand(0)),
-                                         getValue(I.getOperand(1))));
+    setValue(&I, DAG.expandPartialReduceAdd(dl, AccVT, Acc, Input));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index a9e0c8f1ea10c1..3a8dc25e98700e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -629,7 +629,6 @@ class SelectionDAGBuilder {
   void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
   void visitConvergenceControl(const CallInst &I, unsigned Intrinsic);
   void visitVectorHistogram(const CallInst &I, unsigned IntrinsicID);
-  void visitPartialReduceAdd(const CallInst &, unsigned IntrinsicID);
   void visitVectorExtractLastActive(const CallInst &I, unsigned Intrinsic);
   void visitVPLoad(const VPIntrinsic &VPIntrin, EVT VT,
                    const SmallVectorImpl<SDValue> &OpValues);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6a3fbf3a8b596f..6ef5a711c62520 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21720,21 +21720,21 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
+SDValue tryLowerPartialReductionToDot(SDNode *N,
                                       const AArch64Subtarget *Subtarget,
                                       SelectionDAG &DAG) {
 
-  bool Scalable = PR->getValueType(0).isScalableVector();
+  bool Scalable = N->getValueType(0).isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(PR);
+  SDLoc DL(N);
 
   // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = PR->getAcc();
-  auto MulOp = PR->getInput();
+  auto NarrowOp = N->getOperand(0);
+  auto MulOp = N->getOperand(1);
   if (MulOp->getOpcode() != ISD::MUL)
     return SDValue();
 
@@ -21752,7 +21752,7 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  EVT ReducedType = PR->getValueType(0);
+  EVT ReducedType = N->getValueType(0);
   EVT MulSrcType = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
@@ -21771,7 +21771,7 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
-    bool Scalable = PR->getValueType(0).isScalableVT();
+    bool Scalable = N->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
       return SDValue();
@@ -21802,17 +21802,17 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
-SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
                                           const AArch64Subtarget *Subtarget,
                                           SelectionDAG &DAG) {
 
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
 
-  SDLoc DL(PR);
+  SDLoc DL(N);
 
-  auto Acc = PR->getAcc();
-  auto ExtInput = PR->getInput();
+  auto Acc = N->getOperand(0);
+  auto ExtInput = N->getOperand(1);
 
   EVT AccVT = Acc.getValueType();
   EVT AccElemVT = AccVT.getVectorElementType();
@@ -21842,13 +21842,12 @@ SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
 static SDValue
 performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
                                const AArch64Subtarget *Subtarget) {
-  auto *PR = cast<PartialReduceAddSDNode>(N);
-  if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+  if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
     return Dot;
-  if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+  if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
     return WideAdd;
-  return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
-                                 PR->getInput());
+  return DAG.expandPartialReduceAdd(SDLoc(N), N->getValueType(0),
+                                    N->getOperand(0), N->getOperand(1));
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,

>From ea225c5b9833e297f77b2052456ff5bc6eb417d5 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 28 Nov 2024 12:11:36 +0000
Subject: [PATCH 3/3] Remove unnecessary function parameter and update comments

---
 llvm/include/llvm/CodeGen/ISDOpcodes.h                | 3 ++-
 llvm/include/llvm/CodeGen/SelectionDAG.h              | 9 +++++----
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp        | 5 +++--
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 +-
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp       | 3 +--
 5 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 7809a1b26dd7cd..63a9e322400753 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,7 +1451,8 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
-  // The `llvm.experimental.vector.partial.reduce.add` intrinsic
+  // This corresponds to the `llvm.experimental.vector.partial.reduce.add`
+  // intrinsic
   // Operands: Accumulator, Input
   // Outputs: Output
   PARTIAL_REDUCE_ADD,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index cda6f6192ca743..8d59097c6ada1a 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1595,10 +1595,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
-  /// Expands partial reduce node which can't be lowered to wide add or dot
-  /// product instruction(s)
-  SDValue expandPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
-                                 SDValue Op2);
+  /// Expands a PARTIAL_REDUCE_ADD node that cannot be lowered directly.
+  /// @param Op1 Accumulator of the partial reduction; its value type is
+  /// used as the reduced type of the expansion
+  /// @param Op2 Input of the partial reduction operation
+  SDValue expandPartialReduceAdd(SDLoc DL, SDValue Op1, SDValue Op2);
 
   /// Expands a node with multiple results to an FP or vector libcall. The
   /// libcall is expected to take all the operands of the \p Node followed by
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e308a630692768..bcfc8665144da4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2452,8 +2452,9 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
-SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, EVT ReducedTy,
-                                             SDValue Op1, SDValue Op2) {
+SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Op1,
+                                             SDValue Op2) {
+  EVT ReducedTy = Op1.getValueType();
   EVT FullTy = Op2.getValueType();
 
   unsigned Stride = ReducedTy.getVectorMinNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 39e4425636eeb6..ef52fd7775a0db 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8135,7 +8135,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
       setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_ADD, dl, AccVT, Acc, Input));
       return;
     }
-    setValue(&I, DAG.expandPartialReduceAdd(dl, AccVT, Acc, Input));
+    setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 6ef5a711c62520..a34d7c6d8173e6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21846,8 +21846,7 @@ performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
     return Dot;
   if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
     return WideAdd;
-  return DAG.expandPartialReduceAdd(SDLoc(N), N->getValueType(0),
-                                    N->getOperand(0), N->getOperand(1));
+  return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0), N->getOperand(1));
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,



More information about the llvm-commits mailing list