[llvm] 44cfbef - [AArch64] Lower partial add reduction to udot or svdot (#101010)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 2 06:06:17 PDT 2024


Author: Sam Tebbs
Date: 2024-09-02T14:06:14+01:00
New Revision: 44cfbef1b3cb0dd33886cc27441930008a245963

URL: https://github.com/llvm/llvm-project/commit/44cfbef1b3cb0dd33886cc27441930008a245963
DIFF: https://github.com/llvm/llvm-project/commit/44cfbef1b3cb0dd33886cc27441930008a245963.diff

LOG: [AArch64] Lower partial add reduction to udot or svdot (#101010)

This patch introduces lowering of the partial add reduction intrinsic to
a udot or sdot instruction for AArch64. It also adds a
`shouldExpandPartialReductionIntrinsic` target hook, from which AArch64
returns false in the cases where the intrinsic can be lowered directly.
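
As background (not part of the patch): SVE UDOT/SDOT accumulate, for each
wide accumulator lane, the sum of four products of the corresponding four
narrow source elements, which is exactly a partial add reduction of the
widened products. A minimal scalar sketch of that behaviour, using an 8-bit
to 32-bit example and an illustrative function name of my own choosing:

#include <cstddef>
#include <cstdint>

// Scalar model of "udot z0.s, z1.b, z2.b": each 32-bit accumulator lane gains
// the sum of four unsigned 8-bit products taken from its own group of four
// source elements.
void udot_model(uint32_t *acc, const uint8_t *a, const uint8_t *b,
                size_t lanes) {
  for (size_t lane = 0; lane < lanes; ++lane)
    for (size_t j = 0; j < 4; ++j)
      acc[lane] += uint32_t(a[4 * lane + j]) * uint32_t(b[4 * lane + j]);
}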

Added: 
    llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAG.h
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1514d92b36b3c2..7ee8ca18c2c1de 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1594,6 +1594,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Create the DAG equivalent of the partial reduction intrinsic, where Op1
+  /// and Op2 are its operands and ReducedTy is the intrinsic's return type.
+  SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                              SDValue Op2);
+
   /// Expand the specified \c ISD::VAARG node as the Legalize pass would.
   SDValue expandVAArg(SDNode *Node);
 

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index eda38cd8a564d6..e17d68d2690c86 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,13 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7f57b6db40ef49..aa468fa9ebb4c3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -74,6 +74,7 @@
 #include <cassert>
 #include <cstdint>
 #include <cstdlib>
+#include <deque>
 #include <limits>
 #include <optional>
 #include <set>
@@ -2439,6 +2440,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                                          SDValue Op2) {
+  EVT FullTy = Op2.getValueType();
+
+  unsigned Stride = ReducedTy.getVectorMinNumElements();
+  unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+  // Collect all of the subvectors
+  std::deque<SDValue> Subvectors = {Op1};
+  for (unsigned I = 0; I < ScaleFactor; I++) {
+    auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+    Subvectors.push_back(
+        getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+  }
+
+  // Flatten the subvector tree
+  while (Subvectors.size() > 1) {
+    Subvectors.push_back(
+        getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+    Subvectors.pop_front();
+    Subvectors.pop_front();
+  }
+
+  assert(Subvectors.size() == 1 &&
+         "There should only be one subvector after tree flattening");
+
+  return Subvectors[0];
+}
+
 SDValue SelectionDAG::expandVAArg(SDNode *Node) {
   SDLoc dl(Node);
   const TargetLowering &TLI = getTargetLoweringInfo();
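
The generic expansion added above splits the wide operand into chunks as wide
as the accumulator and adds each chunk element-wise into it. A scalar model of
that computation, illustrative only (the std::vector-based helper below is not
part of the patch):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative scalar model of getPartialReduceAdd: the wide vector is cut
// into ScaleFactor chunks, each as wide as the accumulator, and every chunk
// is added element-wise into the accumulator.
std::vector<uint64_t> partialReduceAddModel(std::vector<uint64_t> Acc,
                                            const std::vector<uint64_t> &Wide) {
  size_t Stride = Acc.size();                // elements in ReducedTy
  size_t ScaleFactor = Wide.size() / Stride; // chunks to fold in
  for (size_t C = 0; C < ScaleFactor; ++C)
    for (size_t J = 0; J < Stride; ++J)
      Acc[J] += Wide[C * Stride + J];
  return Acc;
}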

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 4b326ba76f97f2..382a555aa656f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8038,34 +8038,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
-    SDValue OpNode = getValue(I.getOperand(1));
-    EVT ReducedTy = EVT::getEVT(I.getType());
-    EVT FullTy = OpNode.getValueType();
 
-    unsigned Stride = ReducedTy.getVectorMinNumElements();
-    unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
-    // Collect all of the subvectors
-    std::deque<SDValue> Subvectors;
-    Subvectors.push_back(getValue(I.getOperand(0)));
-    for (unsigned i = 0; i < ScaleFactor; i++) {
-      auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
-      Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
-                                       {OpNode, SourceIndex}));
-    }
-
-    // Flatten the subvector tree
-    while (Subvectors.size() > 1) {
-      Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
-                                       {Subvectors[0], Subvectors[1]}));
-      Subvectors.pop_front();
-      Subvectors.pop_front();
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
     }
 
-    assert(Subvectors.size() == 1 &&
-           "There should only be one subvector after tree flattening");
-
-    setValue(&I, Subvectors[0]);
+    setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
+                                         getValue(I.getOperand(0)),
+                                         getValue(I.getOperand(1))));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 11aca69db0a148..1735ff5cd69748 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1988,6 +1988,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
   return false;
 }
 
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+    const IntrinsicInst *I) const {
+  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
+    return true;
+
+  EVT VT = EVT::getEVT(I->getType());
+  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
+}
+
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21763,6 +21772,61 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+                                      const AArch64Subtarget *Subtarget,
+                                      SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  if (!Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // The narrower of the two operands. Used as the accumulator
+  auto NarrowOp = N->getOperand(1);
+  auto MulOp = N->getOperand(2);
+  if (MulOp->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  auto ExtA = MulOp->getOperand(0);
+  auto ExtB = MulOp->getOperand(1);
+  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+    return SDValue();
+
+  auto A = ExtA->getOperand(0);
+  auto B = ExtB->getOperand(0);
+  if (A.getValueType() != B.getValueType())
+    return SDValue();
+
+  unsigned Opcode = 0;
+
+  if (IsSExt)
+    Opcode = AArch64ISD::SDOT;
+  else if (IsZExt)
+    Opcode = AArch64ISD::UDOT;
+
+  assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+  EVT ReducedType = N->getValueType(0);
+  EVT MulSrcType = A.getValueType();
+
+  // Dot products operate on chunks of four elements so there must be four times
+  // as many elements in the wide type
+  if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
+    return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
+
+  if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
+    return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
+
+  return SDValue();
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21771,6 +21835,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
+  case Intrinsic::experimental_vector_partial_reduce_add: {
+    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+      return Dot;
+    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+                                   N->getOperand(1), N->getOperand(2));
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
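
Note that the dot-product lowering and the generic expansion associate the wide
elements differently: UDOT/SDOT fold consecutive groups of four into a lane,
while getPartialReduceAdd folds strided chunks. The intrinsic leaves that
association unspecified, so both are acceptable as long as a subsequent full
reduction sees the same total. A small check of that claim (names and setup
are my own, not from the patch):

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Compare the two associations of the same 16 values: per-lane sums differ,
// but the overall reduction is identical.
int main() {
  std::vector<uint32_t> Wide(16);
  std::iota(Wide.begin(), Wide.end(), 1); // stands in for the widened products

  uint32_t Dot[4] = {}, Chunk[4] = {};
  for (int Lane = 0; Lane < 4; ++Lane)
    for (int J = 0; J < 4; ++J) {
      Dot[Lane] += Wide[4 * Lane + J];   // consecutive groups, as UDOT does
      Chunk[Lane] += Wide[4 * J + Lane]; // strided chunks, as the expansion does
    }

  assert(std::accumulate(Dot, Dot + 4, 0u) ==
         std::accumulate(Chunk, Chunk + 4, 0u));
  return 0;
}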

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 39d5df0de0eec7..f9d45b02d30e30 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,6 +998,9 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
+  bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
+
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..b1354ab210f727
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,96 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NEXT:    uunpklo z4.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+  %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+  %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}


        

