[llvm] [AArch64] Lower partial add reduction to udot or sdot (PR #101010)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 30 02:02:48 PDT 2024


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/101010

>From 561706ff1ed18f9b1924df417dbe1c2a4ff65432 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 10:46:16 +0100
Subject: [PATCH 01/27] [AArch64] Lower add partial reduction to udot

This patch introduces lowering of the partial add reduction intrinsic to
a udot or sdot for AArch64.
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |   6 +
 .../SelectionDAG/SelectionDAGBuilder.cpp      |   6 +
 .../Target/AArch64/AArch64ISelLowering.cpp    |  77 +++++++++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   2 +
 .../AArch64/AArch64TargetTransformInfo.cpp    |  30 +++++
 .../AArch64/AArch64TargetTransformInfo.h      |   6 +
 .../AArch64/partial-reduce-dot-product.ll     | 109 ++++++++++++++++++
 7 files changed, 236 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index eda38cd8a564d6..883a2252f7ffee 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,12 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 60dcb118542785..5ddbf9f414d218 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8005,6 +8005,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+
+    if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
+
     SDValue OpNode = getValue(I.getOperand(1));
     EVT ReducedTy = EVT::getEVT(I.getType());
     EVT FullTy = OpNode.getValueType();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 215f30128e7038..987a7290274e73 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1988,6 +1988,57 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
   return false;
 }
 
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+    const CallInst *CI) const {
+  const bool TargetLowers = false;
+  const bool GenericLowers = true;
+
+  auto *I = dyn_cast<IntrinsicInst>(CI);
+  if (!I)
+    return GenericLowers;
+
+  ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
+
+  if (!RetTy)
+    return GenericLowers;
+
+  ScalableVectorType *InputTy = nullptr;
+
+  auto RetScalarTy = RetTy->getScalarType();
+  if (RetScalarTy->isIntegerTy(64)) {
+    InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+  } else if (RetScalarTy->isIntegerTy(32)) {
+    InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+  }
+
+  if (!InputTy)
+    return GenericLowers;
+
+  Value *InputA;
+  Value *InputB;
+
+  auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+      m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+                                m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
+
+  if (!match(I, Pattern))
+    return GenericLowers;
+
+  auto Mul = cast<Instruction>(I->getOperand(1));
+
+  auto getOpcodeOfOperand = [&](unsigned Idx) {
+    return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
+  };
+
+  if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
+    return GenericLowers;
+
+  if (InputA->getType() != InputTy || InputB->getType() != InputTy)
+    return GenericLowers;
+
+  return TargetLowers;
+}
+
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21765,6 +21816,32 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
+  case Intrinsic::experimental_vector_partial_reduce_add: {
+    SDLoc DL(N);
+
+    auto NarrowOp = N->getOperand(1);
+    auto MulOp = N->getOperand(2);
+
+    auto ExtA = MulOp->getOperand(0);
+    auto ExtB = MulOp->getOperand(1);
+
+    unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+    if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+      DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+    else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+      DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+    assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+           "Unexpected dot product case encountered.");
+
+    auto A = ExtA->getOperand(0);
+    auto B = ExtB->getOperand(0);
+
+    auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+                       {IntrinsicId, NarrowOp, A, B});
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 39d5df0de0eec7..9fe95ddaca32c8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,6 +998,8 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
+  bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index dc748290f2e21e..5871134e60985b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3670,6 +3670,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
   return Cost;
 }
 
+bool AArch64TTIImpl::isPartialReductionSupported(
+    const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+    bool IsInputASignExtended, bool IsInputBSignExtended,
+    const Instruction *BinOp) const {
+  if (ReductionInstr->getOpcode() != Instruction::Add)
+    return false;
+
+  // Check that both extends are of the same type
+  if (IsInputASignExtended != IsInputBSignExtended)
+    return false;
+
+  if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
+    return false;
+
+  // Dot product only supports a scale factor of 4
+  if (ScaleFactor != 4)
+    return false;
+
+  Type *ReductionType = ReductionInstr->getType();
+  if (ReductionType->isIntegerTy(32)) {
+    if (!InputType->isIntegerTy(8))
+      return false;
+  } else if (ReductionType->isIntegerTy(64)) {
+    if (!InputType->isIntegerTy(16))
+      return false;
+  }
+
+  return true;
+}
+
 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 4a6457d7a7dbf5..af7e8e8e497dd8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,6 +155,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return VF.getKnownMinValue() * ST->getVScaleForTuning();
   }
 
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const;
+
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
   bool prefersVectorizedAddressing() const;
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..23b39387fb7a0c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,109 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-unknwon-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    udot z2.s, z0.b, z1.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.d, #0 // =0x0
+; CHECK-NEXT:    udot z2.d, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    sdot z2.s, z0.b, z1.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.d, #0 // =0x0
+; CHECK-NEXT:    sdot z2.d, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+; CHECK-LABEL: not_dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z0.h, z0.h, #0xff
+; CHECK-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpkhi z2.s, z0.h
+; CHECK-NEXT:    uunpkhi z3.s, z1.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEXT:    mul z2.s, z2.s, z3.s
+; CHECK-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z0.s, z0.s, #0xffff
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpkhi z2.d, z0.s
+; CHECK-NEXT:    uunpkhi z3.d, z1.s
+; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    uunpklo z1.d, z1.s
+; CHECK-NEXT:    mul z2.d, z2.d, z3.d
+; CHECK-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+  %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+  %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+attributes #0 = { "target-features"="+sve2" }

>From 563d025161d0b0e00253daf8943926b8e6ebfc75 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:08:52 +0100
Subject: [PATCH 02/27] Remove TargetLowers and GenericLowers

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 987a7290274e73..f24e24bfe2d454 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1990,17 +1990,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const CallInst *CI) const {
-  const bool TargetLowers = false;
-  const bool GenericLowers = true;
 
   auto *I = dyn_cast<IntrinsicInst>(CI);
   if (!I)
-    return GenericLowers;
+    return true;
 
   ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
 
   if (!RetTy)
-    return GenericLowers;
+    return true;
 
   ScalableVectorType *InputTy = nullptr;
 
@@ -2012,7 +2010,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   }
 
   if (!InputTy)
-    return GenericLowers;
+    return true;
 
   Value *InputA;
   Value *InputB;
@@ -2022,7 +2020,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
                                 m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
 
   if (!match(I, Pattern))
-    return GenericLowers;
+    return true;
 
   auto Mul = cast<Instruction>(I->getOperand(1));
 
@@ -2031,12 +2029,12 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   };
 
   if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
-    return GenericLowers;
+    return true;
 
   if (InputA->getType() != InputTy || InputB->getType() != InputTy)
-    return GenericLowers;
+    return true;
 
-  return TargetLowers;
+  return false;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {

>From e604e4570412adee32b2ce363f8ceea9c7438349 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:16:16 +0100
Subject: [PATCH 03/27] Assert that shouldExpandPartialReductionIntrinsic sees
 an intrinsic

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f24e24bfe2d454..8811f2ef94fb09 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1992,8 +1992,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const CallInst *CI) const {
 
   auto *I = dyn_cast<IntrinsicInst>(CI);
-  if (!I)
-    return true;
+  assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc");
 
   ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
 

>From 9b23c96f5cd7a5ce10229682e61b1d8c5464b01e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:31:24 +0100
Subject: [PATCH 04/27] Allow non-scalable vector types

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8811f2ef94fb09..ad38410ba3f275 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1994,18 +1994,17 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   auto *I = dyn_cast<IntrinsicInst>(CI);
   assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc");
 
-  ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
-
+  VectorType *RetTy = dyn_cast<VectorType>(I->getType());
   if (!RetTy)
     return true;
 
-  ScalableVectorType *InputTy = nullptr;
+  VectorType *InputTy = nullptr;
 
   auto RetScalarTy = RetTy->getScalarType();
   if (RetScalarTy->isIntegerTy(64)) {
-    InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+    InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
   } else if (RetScalarTy->isIntegerTy(32)) {
-    InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+    InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
   }
 
   if (!InputTy)

>From 45692dff2a6498d3ebc53a5862ef68239a772ab9 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 19:10:08 +0100
Subject: [PATCH 05/27] Clean up type checking

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ad38410ba3f275..f96be2a9f55e62 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2001,13 +2001,11 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   VectorType *InputTy = nullptr;
 
   auto RetScalarTy = RetTy->getScalarType();
-  if (RetScalarTy->isIntegerTy(64)) {
+  if (RetScalarTy->isIntegerTy(64))
     InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
-  } else if (RetScalarTy->isIntegerTy(32)) {
+  else if (RetScalarTy->isIntegerTy(32))
     InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
-  }
-
-  if (!InputTy)
+  else
     return true;
 
   Value *InputA;
@@ -2021,7 +2019,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   auto Mul = cast<Instruction>(I->getOperand(1));
-
   auto getOpcodeOfOperand = [&](unsigned Idx) {
     return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
   };

>From d305452d3cf171b320cdd25720247a4d61fb1a31 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 1 Aug 2024 11:04:37 +0100
Subject: [PATCH 06/27] Restrict to scalable vector types and clean up type
 checking

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp       |  2 +-
 .../lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 11 +++--------
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f96be2a9f55e62..31c3e208356bf9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1995,7 +1995,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc");
 
   VectorType *RetTy = dyn_cast<VectorType>(I->getType());
-  if (!RetTy)
+  if (!RetTy || !RetTy->isScalableTy())
     return true;
 
   VectorType *InputTy = nullptr;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 5871134e60985b..8cd2ba17b7d79d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3689,15 +3689,10 @@ bool AArch64TTIImpl::isPartialReductionSupported(
     return false;
 
   Type *ReductionType = ReductionInstr->getType();
-  if (ReductionType->isIntegerTy(32)) {
-    if (!InputType->isIntegerTy(8))
-      return false;
-  } else if (ReductionType->isIntegerTy(64)) {
-    if (!InputType->isIntegerTy(16))
-      return false;
-  }
 
-  return true;
+  return ((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
+          (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) &&
+         ReductionType->isScalableTy();
 }
 
 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {

>From 4738a204c93ba2386783afb375b95c806e111522 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 1 Aug 2024 11:36:50 +0100
Subject: [PATCH 07/27] Simplify instruction matching in
 shouldExpandPartialReduction

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 56 +++++++++----------
 1 file changed, 27 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 31c3e208356bf9..c4d00abfc190f0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1998,38 +1998,36 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   if (!RetTy || !RetTy->isScalableTy())
     return true;
 
-  VectorType *InputTy = nullptr;
-
-  auto RetScalarTy = RetTy->getScalarType();
-  if (RetScalarTy->isIntegerTy(64))
-    InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
-  else if (RetScalarTy->isIntegerTy(32))
-    InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
-  else
-    return true;
-
   Value *InputA;
   Value *InputB;
+  if (match(I, m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+                   m_Value(),
+                   m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+                                  m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
+    VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
+    VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
+    if (!InputAType || !InputBType)
+      return true;
+    ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
+    ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
+    if ((RetTy->getScalarType()->isIntegerTy(64) &&
+         InputAType->getElementType()->isIntegerTy(16) &&
+         InputAType->getElementCount() == ExpectedCount8 &&
+         InputAType == InputBType) ||
+
+        (RetTy->getScalarType()->isIntegerTy(32) &&
+         InputAType->getElementType()->isIntegerTy(8) &&
+         InputAType->getElementCount() == ExpectedCount16 &&
+         InputAType == InputBType)) {
+      auto *Mul = cast<Instruction>(I->getOperand(1));
+      auto *Mul0 = cast<Instruction>(Mul->getOperand(0));
+      auto *Mul1 = cast<Instruction>(Mul->getOperand(1));
+      if (Mul0->getOpcode() == Mul1->getOpcode())
+        return false;
+    }
+  }
 
-  auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
-      m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
-                                m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
-
-  if (!match(I, Pattern))
-    return true;
-
-  auto Mul = cast<Instruction>(I->getOperand(1));
-  auto getOpcodeOfOperand = [&](unsigned Idx) {
-    return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
-  };
-
-  if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
-    return true;
-
-  if (InputA->getType() != InputTy || InputB->getType() != InputTy)
-    return true;
-
-  return false;
+  return true;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {

>From 4dbf99e959674b91c743340560816d926480e2a3 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 9 Aug 2024 16:38:22 +0100
Subject: [PATCH 08/27] Add fallback in case the nodes aren't as we expect at
 lowering time

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 67 ++++++++++++++++---
 1 file changed, 59 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c4d00abfc190f0..a5ea612fb38996 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21810,28 +21810,79 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::experimental_vector_partial_reduce_add: {
     SDLoc DL(N);
 
+    bool IsValidDotProduct = false;
+
     auto NarrowOp = N->getOperand(1);
     auto MulOp = N->getOperand(2);
+    if (MulOp->getOpcode() == ISD::MUL)
+      IsValidDotProduct = true;
 
     auto ExtA = MulOp->getOperand(0);
     auto ExtB = MulOp->getOperand(1);
+    bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+    bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+    if (ExtA->getOpcode() == ExtB->getOpcode() && (IsSExt || IsZExt))
+      IsValidDotProduct = true;
 
     unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
 
-    if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+    if (IsSExt && IsValidDotProduct)
       DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
-    else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+    else if (IsZExt && IsValidDotProduct)
       DotIntrinsicId = Intrinsic::aarch64_sve_udot;
 
-    assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+    assert((!IsValidDotProduct || DotIntrinsicId != Intrinsic::not_intrinsic) &&
            "Unexpected dot product case encountered.");
 
-    auto A = ExtA->getOperand(0);
-    auto B = ExtB->getOperand(0);
+    if (IsValidDotProduct) {
+      auto A = ExtA->getOperand(0);
+      auto B = ExtB->getOperand(0);
+
+      auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+                         {IntrinsicId, NarrowOp, A, B});
+    } else {
+      // If the node doesn't match a dot product, lower to a series of ADDs
+      // instead.
+      SDValue Op0 = N->getOperand(0);
+      SDValue Op1 = N->getOperand(1);
+      EVT Type0 = Op0->getValueType(0);
+      EVT Type1 = Op1->getValueType(0);
+
+      // Canonicalise so that Op1 has the larger type
+      if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) {
+        std::swap(Op0, Op1);
+        std::swap(Type0, Type1);
+      }
 
-    auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
-                       {IntrinsicId, NarrowOp, A, B});
+      auto Type0Elements = Type0.getVectorNumElements();
+      auto Type1Elements = Type1.getVectorNumElements();
+      auto Type0ElementSize =
+          Type0.getVectorElementType().getScalarSizeInBits();
+      auto Type1ElementSize =
+          Type1.getVectorElementType().getScalarSizeInBits();
+
+      // If the types are equal then a single ADD is fine
+      if (Type0 == Type1)
+        return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1});
+
+      // Otherwise, we need to add each subvector together so that the output is
+      // the intrinsic's return type. For example, <4 x i32>
+      // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] +
+      // b[4..7] + b[8..11] + b[12..15]
+      SDValue Add = Op0;
+      for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) {
+        SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
+                                     DAG.getConstant(i, DL, MVT::i64));
+
+        if (Type1ElementSize < Type0ElementSize)
+          Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
+        else if (Type1ElementSize > Type0ElementSize)
+          Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
+        Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec});
+      }
+      return Add;
+    }
   }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:

>From c068775b5ddb75a9cd31b6443551c5a0e9cab496 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 12 Aug 2024 11:02:28 +0100
Subject: [PATCH 09/27] Fix logic error with fallback case

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a5ea612fb38996..56327cc074a47b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21810,19 +21810,19 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::experimental_vector_partial_reduce_add: {
     SDLoc DL(N);
 
-    bool IsValidDotProduct = false;
+    bool IsValidDotProduct = true;
 
     auto NarrowOp = N->getOperand(1);
     auto MulOp = N->getOperand(2);
-    if (MulOp->getOpcode() == ISD::MUL)
-      IsValidDotProduct = true;
+    if (MulOp->getOpcode() != ISD::MUL)
+      IsValidDotProduct = false;
 
     auto ExtA = MulOp->getOperand(0);
     auto ExtB = MulOp->getOperand(1);
     bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
     bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
-    if (ExtA->getOpcode() == ExtB->getOpcode() && (IsSExt || IsZExt))
-      IsValidDotProduct = true;
+    if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+      IsValidDotProduct = false;
 
     unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
 
@@ -21844,8 +21844,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
     } else {
       // If the node doesn't match a dot product, lower to a series of ADDs
       // instead.
-      SDValue Op0 = N->getOperand(0);
-      SDValue Op1 = N->getOperand(1);
+      SDValue Op0 = N->getOperand(1);
+      SDValue Op1 = N->getOperand(2);
       EVT Type0 = Op0->getValueType(0);
       EVT Type1 = Op1->getValueType(0);
 

>From 636652d0ad53a8476270ffa01cee73483081ca4a Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:18:53 +0100
Subject: [PATCH 10/27] Pass IntrinsicInst to
 shouldExpandPartialReductionIntrinsic

---
 llvm/include/llvm/CodeGen/TargetLowering.h            | 3 ++-
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 +-
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp       | 5 +----
 llvm/lib/Target/AArch64/AArch64ISelLowering.h         | 3 ++-
 4 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 883a2252f7ffee..e17d68d2690c86 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -455,7 +455,8 @@ class TargetLoweringBase {
 
   /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
   /// should be expanded using generic code in SelectionDAGBuilder.
-  virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+  virtual bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
     return true;
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5ddbf9f414d218..05cbe384cb5ed5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8006,7 +8006,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
 
-    if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
       visitTargetIntrinsic(I, Intrinsic);
       return;
     }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 56327cc074a47b..916eccd52e9396 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1989,10 +1989,7 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 }
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
-    const CallInst *CI) const {
-
-  auto *I = dyn_cast<IntrinsicInst>(CI);
-  assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc");
+    const IntrinsicInst *I) const {
 
   VectorType *RetTy = dyn_cast<VectorType>(I->getType());
   if (!RetTy || !RetTy->isScalableTy())
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 9fe95ddaca32c8..f9d45b02d30e30 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,7 +998,8 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
-  bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+  bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
 
   bool shouldExpandCttzElements(EVT VT) const override;
 

>From 83015b7e08e5c0ccfa55cf20423292a9e29d2a2a Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:31:22 +0100
Subject: [PATCH 11/27] Remove one-use restriction

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 916eccd52e9396..cd70d2e3cdfa7f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1997,10 +1997,10 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
 
   Value *InputA;
   Value *InputB;
-  if (match(I, m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
-                   m_Value(),
-                   m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
-                                  m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
+  if (match(I,
+            m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+                m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)),
+                                          m_ZExtOrSExt(m_Value(InputB))))))) {
     VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
     VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
     if (!InputAType || !InputBType)

>From ed6efd6e7e705cd8b932b66f9e818a0e6c204884 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:32:10 +0100
Subject: [PATCH 12/27] Remove new line

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cd70d2e3cdfa7f..856d9b0eeadd12 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2011,7 +2011,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
          InputAType->getElementType()->isIntegerTy(16) &&
          InputAType->getElementCount() == ExpectedCount8 &&
          InputAType == InputBType) ||
-
         (RetTy->getScalarType()->isIntegerTy(32) &&
          InputAType->getElementType()->isIntegerTy(8) &&
          InputAType->getElementCount() == ExpectedCount16 &&

>From 63648378be615a28ea75fd53d27eabc0a3658c06 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 20:21:43 +0100
Subject: [PATCH 13/27] Remove extending/truncating for fallback case

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 856d9b0eeadd12..c150db8d8947d1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21853,10 +21853,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
 
       auto Type0Elements = Type0.getVectorNumElements();
       auto Type1Elements = Type1.getVectorNumElements();
-      auto Type0ElementSize =
-          Type0.getVectorElementType().getScalarSizeInBits();
-      auto Type1ElementSize =
-          Type1.getVectorElementType().getScalarSizeInBits();
 
       // If the types are equal then a single ADD is fine
       if (Type0 == Type1)
@@ -21871,10 +21867,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
         SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
                                      DAG.getConstant(i, DL, MVT::i64));
 
-        if (Type1ElementSize < Type0ElementSize)
-          Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
-        else if (Type1ElementSize > Type0ElementSize)
-          Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
         Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec});
       }
       return Add;

>From 9da416b52c88340bbd9611ad016532c58f2e8417 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 20:27:04 +0100
Subject: [PATCH 14/27] Clean up test target

---
 llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 23b39387fb7a0c..0facb2049135f6 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,8 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64-unknwon-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
-
-target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-none-unknown-elf"
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
 
 define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
 ; CHECK-LABEL: dotp:

>From 0d231096e4f971f379a56c21e961caae7db3080e Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 14 Aug 2024 09:42:32 +0100
Subject: [PATCH 15/27] Remove #0 attribute from test

---
 .../CodeGen/AArch64/partial-reduce-dot-product.ll  | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 0facb2049135f6..16ef219a93c9bf 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
 
-define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: dotp:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov z2.s, #0 // =0x0
@@ -16,7 +16,7 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: dotp_wide:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov z2.d, #0 // =0x0
@@ -31,7 +31,7 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }
 
-define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: dotp_sext:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov z2.s, #0 // =0x0
@@ -46,7 +46,7 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: dotp_wide_sext:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    mov z2.d, #0 // =0x0
@@ -61,7 +61,7 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }
 
-define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: not_dotp:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    and z0.h, z0.h, #0xff
@@ -82,7 +82,7 @@ entry:
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
 ; CHECK-LABEL: not_dotp_wide:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    and z0.s, z0.s, #0xffff
@@ -102,5 +102,3 @@ entry:
   %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
   ret <vscale x 2 x i64> %partial.reduce
 }
-
-attributes #0 = { "target-features"="+sve2" }

>From bc86de6d93b9166c6ac4ca10cddae0199607fd15 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 14 Aug 2024 10:55:12 +0100
Subject: [PATCH 16/27] Allow i8 to i64 dot products

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 34 ++++++++-
 .../AArch64/partial-reduce-dot-product.ll     | 72 +++++++++++++++++++
 2 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c150db8d8947d1..13e664f5bf27f5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2007,11 +2007,15 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
       return true;
     ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
     ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
+    // Check that the input type is 4 times smaller than the output type. If the
+    // output type is 64 bit then we can accept 8 bit inputs if we do a 32 bit
+    // dot product and add a zext/sext.
     if ((RetTy->getScalarType()->isIntegerTy(64) &&
          InputAType->getElementType()->isIntegerTy(16) &&
          InputAType->getElementCount() == ExpectedCount8 &&
          InputAType == InputBType) ||
-        (RetTy->getScalarType()->isIntegerTy(32) &&
+        ((RetTy->getScalarType()->isIntegerTy(32) ||
+          RetTy->getScalarType()->isIntegerTy(64)) &&
          InputAType->getElementType()->isIntegerTy(8) &&
          InputAType->getElementCount() == ExpectedCount16 &&
          InputAType == InputBType)) {
@@ -21833,10 +21837,34 @@ static SDValue performIntrinsicCombine(SDNode *N,
     if (IsValidDotProduct) {
       auto A = ExtA->getOperand(0);
       auto B = ExtB->getOperand(0);
+      EVT Type = NarrowOp.getValueType();
+
+      // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
+      // and extending the output
+      bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
+                    Type.getScalarSizeInBits() == 64;
+      SDValue Accumulator = NarrowOp;
+      if (Extend) {
+        Type = Type.changeVectorElementType(
+            EVT::getIntegerVT(*DAG.getContext(), 32));
+        // The accumulator is of the wider type so we insert a 0 accumulator and
+        // add the proper one after extending
+        Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
+                                  DAG.getConstant(0, DL, MVT::i32));
+      }
 
       auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
-      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
-                         {IntrinsicId, NarrowOp, A, B});
+      auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
+                                    {IntrinsicId, Accumulator, A, B});
+      if (Extend) {
+        auto Extended =
+            DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
+                        NarrowOp.getValueType(), {DotProduct});
+        auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
+                                  {NarrowOp, Extended});
+        DotProduct = AccAdd;
+      }
+      return DotProduct;
     } else {
       // If the node doesn't match a dot product, lower to a series of ADDs
       // instead.
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 16ef219a93c9bf..c1cf9026d693ce 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -61,6 +61,78 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }
 
+define <vscale x 4 x i64> @dotp_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_8to64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    udot z2.s, z0.b, z1.b
+; CHECK-NEXT:    uunpklo z0.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+  <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_sext_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext_8to64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    sdot z2.s, z0.b, z1.b
+; CHECK-NEXT:    sunpklo z0.d, z2.s
+; CHECK-NEXT:    sunpkhi z1.d, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+  <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
+; CHECK-LABEL: dotp_8to64_accumulator:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEXT:    udot z4.s, z0.b, z1.b
+; CHECK-NEXT:    uunpklo z0.d, z4.s
+; CHECK-NEXT:    uunpkhi z1.d, z4.s
+; CHECK-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_sext_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
+; CHECK-LABEL: dotp_sext_8to64_accumulator:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z4.s, #0 // =0x0
+; CHECK-NEXT:    sdot z4.s, z0.b, z1.b
+; CHECK-NEXT:    sunpklo z0.d, z4.s
+; CHECK-NEXT:    sunpkhi z1.d, z4.s
+; CHECK-NEXT:    add z0.d, z2.d, z0.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+  ret <vscale x 4 x i64> %partial.reduce
+}
+
 define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: not_dotp:
 ; CHECK:       // %bb.0: // %entry

>From aa7957faeb3adb8beda7443c4a1b06bddb9b01d4 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 20 Aug 2024 13:53:11 +0100
Subject: [PATCH 17/27] Remove isPartialReductionSupported

---
 .../AArch64/AArch64TargetTransformInfo.cpp    | 25 -------------------
 .../AArch64/AArch64TargetTransformInfo.h      |  6 -----
 2 files changed, 31 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 8cd2ba17b7d79d..dc748290f2e21e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3670,31 +3670,6 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
   return Cost;
 }
 
-bool AArch64TTIImpl::isPartialReductionSupported(
-    const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
-    bool IsInputASignExtended, bool IsInputBSignExtended,
-    const Instruction *BinOp) const {
-  if (ReductionInstr->getOpcode() != Instruction::Add)
-    return false;
-
-  // Check that both extends are of the same type
-  if (IsInputASignExtended != IsInputBSignExtended)
-    return false;
-
-  if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
-    return false;
-
-  // Dot product only supports a scale factor of 4
-  if (ScaleFactor != 4)
-    return false;
-
-  Type *ReductionType = ReductionInstr->getType();
-
-  return ((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
-          (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) &&
-         ReductionType->isScalableTy();
-}
-
 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index af7e8e8e497dd8..4a6457d7a7dbf5 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,12 +155,6 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return VF.getKnownMinValue() * ST->getVScaleForTuning();
   }
 
-  bool isPartialReductionSupported(const Instruction *ReductionInstr,
-                                   Type *InputType, unsigned ScaleFactor,
-                                   bool IsInputASignExtended,
-                                   bool IsInputBSignExtended,
-                                   const Instruction *BinOp = nullptr) const;
-
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
   bool prefersVectorizedAddressing() const;

>From a58ac297afd726a536b6687748d9a196473a4618 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 21 Aug 2024 15:02:17 +0100
Subject: [PATCH 18/27] Share expansion code in SelectionDAG

---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |   4 +
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  30 +++
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  29 +--
 .../Target/AArch64/AArch64ISelLowering.cpp    | 217 ++++++++----------
 4 files changed, 130 insertions(+), 150 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1514d92b36b3c2..2235db5d93b5ad 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1594,6 +1594,10 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Expand a partial reduction intrinsic call.
+  /// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type.
+  SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL);
+
   /// Expand the specified \c ISD::VAARG node as the Legalize pass would.
   SDValue expandVAArg(SDNode *Node);
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 27675dce70c260..2510c1828c909f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -74,6 +74,7 @@
 #include <cassert>
 #include <cstdint>
 #include <cstdlib>
+#include <deque>
 #include <limits>
 #include <optional>
 #include <set>
@@ -2426,6 +2427,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL) {
+    EVT FullTy = Op2.getValueType();
+
+    unsigned Stride = ReducedTy.getVectorMinNumElements();
+    unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+    // Collect all of the subvectors
+    std::deque<SDValue> Subvectors = {Op1};
+    for (unsigned I = 0; I < ScaleFactor; I++) {
+      auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+      Subvectors.push_back(getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
+                                       {Op2, SourceIndex}));
+    }
+
+    // Flatten the subvector tree
+    while (Subvectors.size() > 1) {
+      Subvectors.push_back(getNode(ISD::ADD, DL, ReducedTy,
+                                       {Subvectors[0], Subvectors[1]}));
+      Subvectors.pop_front();
+      Subvectors.pop_front();
+    }
+
+    assert(Subvectors.size() == 1 &&
+           "There should only be one subvector after tree flattening");
+
+    return Subvectors[0];
+
+}
+
 SDValue SelectionDAG::expandVAArg(SDNode *Node) {
   SDLoc dl(Node);
   const TargetLowering &TLI = getTargetLoweringInfo();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 05cbe384cb5ed5..33de8747fb7e56 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8011,34 +8011,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
       return;
     }
 
-    SDValue OpNode = getValue(I.getOperand(1));
-    EVT ReducedTy = EVT::getEVT(I.getType());
-    EVT FullTy = OpNode.getValueType();
-
-    unsigned Stride = ReducedTy.getVectorMinNumElements();
-    unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
-    // Collect all of the subvectors
-    std::deque<SDValue> Subvectors;
-    Subvectors.push_back(getValue(I.getOperand(0)));
-    for (unsigned i = 0; i < ScaleFactor; i++) {
-      auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
-      Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
-                                       {OpNode, SourceIndex}));
-    }
-
-    // Flatten the subvector tree
-    while (Subvectors.size() > 1) {
-      Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
-                                       {Subvectors[0], Subvectors[1]}));
-      Subvectors.pop_front();
-      Subvectors.pop_front();
-    }
-
-    assert(Subvectors.size() == 1 &&
-           "There should only be one subvector after tree flattening");
-
-    setValue(&I, Subvectors[0]);
+    setValue(&I, DAG.expandPartialReductionIntrinsic(EVT::getEVT(I.getType()), getValue(I.getOperand(0)), getValue(I.getOperand(1)), sdl));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 13e664f5bf27f5..df89806ca057e1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1995,37 +1995,12 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   if (!RetTy || !RetTy->isScalableTy())
     return true;
 
-  Value *InputA;
-  Value *InputB;
-  if (match(I,
-            m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
-                m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)),
-                                          m_ZExtOrSExt(m_Value(InputB))))))) {
-    VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
-    VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
-    if (!InputAType || !InputBType)
-      return true;
-    ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
-    ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
-    // Check that the input type is 4 times smaller than the output type. If the
-    // output type is 64 bit then we can accept 8 bit inputs if we do a 32 bit
-    // dot product and add a zext/sext.
-    if ((RetTy->getScalarType()->isIntegerTy(64) &&
-         InputAType->getElementType()->isIntegerTy(16) &&
-         InputAType->getElementCount() == ExpectedCount8 &&
-         InputAType == InputBType) ||
-        ((RetTy->getScalarType()->isIntegerTy(32) ||
-          RetTy->getScalarType()->isIntegerTy(64)) &&
-         InputAType->getElementType()->isIntegerTy(8) &&
-         InputAType->getElementCount() == ExpectedCount16 &&
-         InputAType == InputBType)) {
-      auto *Mul = cast<Instruction>(I->getOperand(1));
-      auto *Mul0 = cast<Instruction>(Mul->getOperand(0));
-      auto *Mul1 = cast<Instruction>(Mul->getOperand(1));
-      if (Mul0->getOpcode() == Mul1->getOpcode())
-        return false;
-    }
-  }
+  if (RetTy->getScalarType()->isIntegerTy(32) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
+    return false;
+  if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy()))
+    return false;
+  if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
+    return false;
 
   return true;
 }
@@ -21799,6 +21774,92 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
+SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarget, SelectionDAG &DAG) {
+  SDLoc DL(N);
+
+  // The narrower of the two operands. Used as the accumulator
+  auto NarrowOp = N->getOperand(1);
+  auto MulOp = N->getOperand(2);
+  if (MulOp->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  auto ExtA = MulOp->getOperand(0);
+  auto ExtB = MulOp->getOperand(1);
+  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+    return SDValue();
+
+  auto A = ExtA->getOperand(0);
+  auto B = ExtB->getOperand(0);
+  if (A.getValueType() != B.getValueType())
+    return SDValue();
+
+  // The fully-reduced type. Should be a vector of i32 or i64
+  EVT FullType = N->getValueType(0);
+  // The type that is extended to the wide type. Should be an i8 or i16
+  EVT ExtendedType = A.getValueType();
+  // The wide type with four times as many elements as the reduced type. Should be a vector of i32 or i64, the same as the fully-reduced type
+  EVT WideType = MulOp.getValueType();
+  if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits())
+    return SDValue();
+  // Dot products operate on chunks of four elements so there must be four times as many elements in the wide type
+  if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() != 4)
+    return SDValue();
+  switch (FullType.getScalarSizeInBits()) {
+    case 32:
+      if (ExtendedType.getScalarSizeInBits() != 8)
+        return SDValue();
+      break;
+    case 64:
+      // i8 to i64 can be done with an extended i32 dot product
+      if (ExtendedType.getScalarSizeInBits() != 8 && ExtendedType.getScalarSizeInBits() != 16)
+        return SDValue();
+      break;
+    default:
+      return SDValue();
+  }
+
+  unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+  if (IsSExt)
+    DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+  else if (IsZExt)
+    DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+  assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+         "Unexpected dot product case encountered.");
+
+  EVT Type = NarrowOp.getValueType();
+
+  // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
+  // and extending the output
+  bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
+                Type.getScalarSizeInBits() == 64;
+  SDValue Accumulator = NarrowOp;
+  if (Extend) {
+    Type = Type.changeVectorElementType(
+        EVT::getIntegerVT(*DAG.getContext(), 32));
+    // The accumulator is of the wider type so we insert a 0 accumulator and
+    // add the proper one after extending
+    Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
+                              DAG.getConstant(0, DL, MVT::i32));
+  }
+
+  auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+  auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
+                                {IntrinsicId, Accumulator, A, B});
+  if (Extend) {
+    auto Extended =
+        DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
+                    NarrowOp.getValueType(), {DotProduct});
+    auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
+                              {NarrowOp, Extended});
+    DotProduct = AccAdd;
+  }
+  return DotProduct;
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21808,97 +21869,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
   default:
     break;
   case Intrinsic::experimental_vector_partial_reduce_add: {
-    SDLoc DL(N);
-
-    bool IsValidDotProduct = true;
-
-    auto NarrowOp = N->getOperand(1);
-    auto MulOp = N->getOperand(2);
-    if (MulOp->getOpcode() != ISD::MUL)
-      IsValidDotProduct = false;
-
-    auto ExtA = MulOp->getOperand(0);
-    auto ExtB = MulOp->getOperand(1);
-    bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-    bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
-    if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
-      IsValidDotProduct = false;
-
-    unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
-
-    if (IsSExt && IsValidDotProduct)
-      DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
-    else if (IsZExt && IsValidDotProduct)
-      DotIntrinsicId = Intrinsic::aarch64_sve_udot;
-
-    assert((!IsValidDotProduct || DotIntrinsicId != Intrinsic::not_intrinsic) &&
-           "Unexpected dot product case encountered.");
-
-    if (IsValidDotProduct) {
-      auto A = ExtA->getOperand(0);
-      auto B = ExtB->getOperand(0);
-      EVT Type = NarrowOp.getValueType();
-
-      // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
-      // and extending the output
-      bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
-                    Type.getScalarSizeInBits() == 64;
-      SDValue Accumulator = NarrowOp;
-      if (Extend) {
-        Type = Type.changeVectorElementType(
-            EVT::getIntegerVT(*DAG.getContext(), 32));
-        // The accumulator is of the wider type so we insert a 0 accumulator and
-        // add the proper one after extending
-        Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
-                                  DAG.getConstant(0, DL, MVT::i32));
-      }
-
-      auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
-      auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
-                                    {IntrinsicId, Accumulator, A, B});
-      if (Extend) {
-        auto Extended =
-            DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
-                        NarrowOp.getValueType(), {DotProduct});
-        auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
-                                  {NarrowOp, Extended});
-        DotProduct = AccAdd;
-      }
-      return DotProduct;
-    } else {
-      // If the node doesn't match a dot product, lower to a series of ADDs
-      // instead.
-      SDValue Op0 = N->getOperand(1);
-      SDValue Op1 = N->getOperand(2);
-      EVT Type0 = Op0->getValueType(0);
-      EVT Type1 = Op1->getValueType(0);
-
-      // Canonicalise so that Op1 has the larger type
-      if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) {
-        std::swap(Op0, Op1);
-        std::swap(Type0, Type1);
-      }
-
-      auto Type0Elements = Type0.getVectorNumElements();
-      auto Type1Elements = Type1.getVectorNumElements();
-
-      // If the types are equal then a single ADD is fine
-      if (Type0 == Type1)
-        return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1});
-
-      // Otherwise, we need to add each subvector together so that the output is
-      // the intrinsic's return type. For example, <4 x i32>
-      // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] +
-      // b[4..7] + b[8..11] + b[12..15]
-      SDValue Add = Op0;
-      for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) {
-        SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
-                                     DAG.getConstant(i, DL, MVT::i64));
-
-        Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec});
-      }
-      return Add;
-    }
+    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+        return Dot;
+    return DAG.expandPartialReductionIntrinsic(N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N));
   }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:

>From 5f31079756eb54679f0dd64da0d425084069a929 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 21 Aug 2024 16:04:52 +0100
Subject: [PATCH 19/27] Check for NEON or SVE

---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  3 +-
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 45 +++++++-------
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  4 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    | 61 +++++++++++--------
 4 files changed, 65 insertions(+), 48 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 2235db5d93b5ad..2c1d4bf259699e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1596,7 +1596,8 @@ class SelectionDAG {
 
   /// Expand a partial reduction intrinsic call.
   /// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type.
-  SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL);
+  SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1,
+                                          SDValue Op2, SDLoc DL);
 
   /// Expand the specified \c ISD::VAARG node as the Legalize pass would.
   SDValue expandVAArg(SDNode *Node);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2510c1828c909f..cd09760b1f24e9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2427,33 +2427,34 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
-SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL) {
-    EVT FullTy = Op2.getValueType();
+SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy,
+                                                      SDValue Op1, SDValue Op2,
+                                                      SDLoc DL) {
+  EVT FullTy = Op2.getValueType();
 
-    unsigned Stride = ReducedTy.getVectorMinNumElements();
-    unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+  unsigned Stride = ReducedTy.getVectorMinNumElements();
+  unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
 
-    // Collect all of the subvectors
-    std::deque<SDValue> Subvectors = {Op1};
-    for (unsigned I = 0; I < ScaleFactor; I++) {
-      auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
-      Subvectors.push_back(getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
-                                       {Op2, SourceIndex}));
-    }
-
-    // Flatten the subvector tree
-    while (Subvectors.size() > 1) {
-      Subvectors.push_back(getNode(ISD::ADD, DL, ReducedTy,
-                                       {Subvectors[0], Subvectors[1]}));
-      Subvectors.pop_front();
-      Subvectors.pop_front();
-    }
+  // Collect all of the subvectors
+  std::deque<SDValue> Subvectors = {Op1};
+  for (unsigned I = 0; I < ScaleFactor; I++) {
+    auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+    Subvectors.push_back(
+        getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+  }
 
-    assert(Subvectors.size() == 1 &&
-           "There should only be one subvector after tree flattening");
+  // Flatten the subvector tree
+  while (Subvectors.size() > 1) {
+    Subvectors.push_back(
+        getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+    Subvectors.pop_front();
+    Subvectors.pop_front();
+  }
 
-    return Subvectors[0];
+  assert(Subvectors.size() == 1 &&
+         "There should only be one subvector after tree flattening");
 
+  return Subvectors[0];
 }
 
 SDValue SelectionDAG::expandVAArg(SDNode *Node) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 33de8747fb7e56..ce5ef78eba15db 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8011,7 +8011,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
       return;
     }
 
-    setValue(&I, DAG.expandPartialReductionIntrinsic(EVT::getEVT(I.getType()), getValue(I.getOperand(0)), getValue(I.getOperand(1)), sdl));
+    setValue(&I, DAG.expandPartialReductionIntrinsic(
+                     EVT::getEVT(I.getType()), getValue(I.getOperand(0)),
+                     getValue(I.getOperand(1)), sdl));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index df89806ca057e1..b849ddb2a86d67 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1995,11 +1995,14 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   if (!RetTy || !RetTy->isScalableTy())
     return true;
 
-  if (RetTy->getScalarType()->isIntegerTy(32) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
+  if (RetTy->getScalarType()->isIntegerTy(32) &&
+      RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
     return false;
-  if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy()))
+  if (RetTy->getScalarType()->isIntegerTy(64) &&
+      RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy()))
     return false;
-  if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
+  if (RetTy->getScalarType()->isIntegerTy(64) &&
+      RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
     return false;
 
   return true;
@@ -21774,7 +21777,13 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarget, SelectionDAG &DAG) {
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+                                      const AArch64Subtarget *Subtarget,
+                                      SelectionDAG &DAG) {
+
+  if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable())
+    return SDValue();
+
   SDLoc DL(N);
 
   // The narrower of the two operands. Used as the accumulator
@@ -21799,25 +21808,29 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg
   EVT FullType = N->getValueType(0);
   // The type that is extended to the wide type. Should be an i8 or i16
   EVT ExtendedType = A.getValueType();
-  // The wide type with four times as many elements as the reduced type. Should be a vector of i32 or i64, the same as the fully-reduced type
+  // The wide type with four times as many elements as the reduced type. Should
+  // be a vector of i32 or i64, the same as the fully-reduced type
   EVT WideType = MulOp.getValueType();
   if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits())
     return SDValue();
-  // Dot products operate on chunks of four elements so there must be four times as many elements in the wide type
-  if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() != 4)
+  // Dot products operate on chunks of four elements so there must be four times
+  // as many elements in the wide type
+  if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() !=
+      4)
     return SDValue();
   switch (FullType.getScalarSizeInBits()) {
-    case 32:
-      if (ExtendedType.getScalarSizeInBits() != 8)
-        return SDValue();
-      break;
-    case 64:
-      // i8 to i64 can be done with an extended i32 dot product
-      if (ExtendedType.getScalarSizeInBits() != 8 && ExtendedType.getScalarSizeInBits() != 16)
-        return SDValue();
-      break;
-    default:
+  case 32:
+    if (ExtendedType.getScalarSizeInBits() != 8)
+      return SDValue();
+    break;
+  case 64:
+    // i8 to i64 can be done with an extended i32 dot product
+    if (ExtendedType.getScalarSizeInBits() != 8 &&
+        ExtendedType.getScalarSizeInBits() != 16)
       return SDValue();
+    break;
+  default:
+    return SDValue();
   }
 
   unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
@@ -21838,8 +21851,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg
                 Type.getScalarSizeInBits() == 64;
   SDValue Accumulator = NarrowOp;
   if (Extend) {
-    Type = Type.changeVectorElementType(
-        EVT::getIntegerVT(*DAG.getContext(), 32));
+    Type =
+        Type.changeVectorElementType(EVT::getIntegerVT(*DAG.getContext(), 32));
     // The accumulator is of the wider type so we insert a 0 accumulator and
     // add the proper one after extending
     Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
@@ -21850,9 +21863,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg
   auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
                                 {IntrinsicId, Accumulator, A, B});
   if (Extend) {
-    auto Extended =
-        DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
-                    NarrowOp.getValueType(), {DotProduct});
+    auto Extended = DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
+                                DL, NarrowOp.getValueType(), {DotProduct});
     auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
                               {NarrowOp, Extended});
     DotProduct = AccAdd;
@@ -21870,8 +21882,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
     break;
   case Intrinsic::experimental_vector_partial_reduce_add: {
     if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
-        return Dot;
-    return DAG.expandPartialReductionIntrinsic(N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N));
+      return Dot;
+    return DAG.expandPartialReductionIntrinsic(
+        N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N));
   }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:

>From 2f3a0dc8d581efd0687d0eedf3c346ffc6a60716 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 28 Aug 2024 09:27:15 +0100
Subject: [PATCH 20/27] Rename expansion function

---
 llvm/include/llvm/CodeGen/SelectionDAG.h              | 4 ++--
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp        | 5 ++---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 6 +++---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp       | 4 ++--
 4 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 2c1d4bf259699e..227616c37e004d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1596,8 +1596,8 @@ class SelectionDAG {
 
   /// Expand a partial reduction intrinsic call.
   /// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type.
-  SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1,
-                                          SDValue Op2, SDLoc DL);
+  SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                              SDValue Op2);
 
   /// Expand the specified \c ISD::VAARG node as the Legalize pass would.
   SDValue expandVAArg(SDNode *Node);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index cd09760b1f24e9..d5e61183d0e25d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2427,9 +2427,8 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
-SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy,
-                                                      SDValue Op1, SDValue Op2,
-                                                      SDLoc DL) {
+SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                                          SDValue Op2) {
   EVT FullTy = Op2.getValueType();
 
   unsigned Stride = ReducedTy.getVectorMinNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ce5ef78eba15db..98c2e703c39ed1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8011,9 +8011,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
       return;
     }
 
-    setValue(&I, DAG.expandPartialReductionIntrinsic(
-                     EVT::getEVT(I.getType()), getValue(I.getOperand(0)),
-                     getValue(I.getOperand(1)), sdl));
+    setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
+                                         getValue(I.getOperand(0)),
+                                         getValue(I.getOperand(1))));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b849ddb2a86d67..a25c09ade370e2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21883,8 +21883,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::experimental_vector_partial_reduce_add: {
     if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
       return Dot;
-    return DAG.expandPartialReductionIntrinsic(
-        N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N));
+    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+                                   N->getOperand(1), N->getOperand(2));
   }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:

>From 00a1be219912ca53f24dda3ee410229fa6286736 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 28 Aug 2024 10:22:45 +0100
Subject: [PATCH 21/27] Simplify shouldExpandPartialReductionIntrinsic

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 15 +++------------
 1 file changed, 3 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a25c09ade370e2..9f5d5a61397d10 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1991,21 +1991,12 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const IntrinsicInst *I) const {
 
-  VectorType *RetTy = dyn_cast<VectorType>(I->getType());
-  if (!RetTy || !RetTy->isScalableTy())
+  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
     return true;
 
-  if (RetTy->getScalarType()->isIntegerTy(32) &&
-      RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
-    return false;
-  if (RetTy->getScalarType()->isIntegerTy(64) &&
-      RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy()))
-    return false;
-  if (RetTy->getScalarType()->isIntegerTy(64) &&
-      RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
-    return false;
+  EVT VT = EVT::getEVT(I->getType());
 
-  return true;
+  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::nxv4i64;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {

>From f6c58393ddc03ae7328f7a1fe8236bce55924999 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 09:36:24 +0100
Subject: [PATCH 22/27] Remove nxv4i64 case

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 77 ++++++-------------
 .../AArch64/partial-reduce-dot-product.ll     | 72 -----------------
 2 files changed, 23 insertions(+), 126 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9f5d5a61397d10..7193303c7084ab 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21795,35 +21795,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  // The fully-reduced type. Should be a vector of i32 or i64
-  EVT FullType = N->getValueType(0);
-  // The type that is extended to the wide type. Should be an i8 or i16
-  EVT ExtendedType = A.getValueType();
-  // The wide type with four times as many elements as the reduced type. Should
-  // be a vector of i32 or i64, the same as the fully-reduced type
-  EVT WideType = MulOp.getValueType();
-  if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits())
-    return SDValue();
-  // Dot products operate on chunks of four elements so there must be four times
-  // as many elements in the wide type
-  if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() !=
-      4)
-    return SDValue();
-  switch (FullType.getScalarSizeInBits()) {
-  case 32:
-    if (ExtendedType.getScalarSizeInBits() != 8)
-      return SDValue();
-    break;
-  case 64:
-    // i8 to i64 can be done with an extended i32 dot product
-    if (ExtendedType.getScalarSizeInBits() != 8 &&
-        ExtendedType.getScalarSizeInBits() != 16)
-      return SDValue();
-    break;
-  default:
-    return SDValue();
-  }
-
   unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
 
   if (IsSExt)
@@ -21834,33 +21805,31 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
          "Unexpected dot product case encountered.");
 
-  EVT Type = NarrowOp.getValueType();
+  auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
 
-  // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
-  // and extending the output
-  bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
-                Type.getScalarSizeInBits() == 64;
-  SDValue Accumulator = NarrowOp;
-  if (Extend) {
-    Type =
-        Type.changeVectorElementType(EVT::getIntegerVT(*DAG.getContext(), 32));
-    // The accumulator is of the wider type so we insert a 0 accumulator and
-    // add the proper one after extending
-    Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
-                              DAG.getConstant(0, DL, MVT::i32));
-  }
+  // The fully-reduced type. Should be a vector of i32 or i64
+  EVT ReducedType = N->getValueType(0);
+  // The type that is extended to the wide type. Should be an i8 or i16
+  EVT ExtendedType = A.getValueType();
+  // The wide type with four times as many elements as the reduced type. Should
+  // be a vector of i32 or i64, the same as the fully-reduced type
+  EVT WideType = MulOp.getValueType();
+  if (WideType.getScalarSizeInBits() != ReducedType.getScalarSizeInBits())
+    return SDValue();
 
-  auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
-  auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
-                                {IntrinsicId, Accumulator, A, B});
-  if (Extend) {
-    auto Extended = DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
-                                DL, NarrowOp.getValueType(), {DotProduct});
-    auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
-                              {NarrowOp, Extended});
-    DotProduct = AccAdd;
-  }
-  return DotProduct;
+  // Dot products operate on chunks of four elements so there must be four times
+  // as many elements in the wide type
+  if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 &&
+      ExtendedType == MVT::nxv16i8)
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
+                       {IntrinsicId, NarrowOp, A, B});
+
+  if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 &&
+      ExtendedType == MVT::nxv8i16)
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i64,
+                       {IntrinsicId, NarrowOp, A, B});
+
+  return SDValue();
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index c1cf9026d693ce..16ef219a93c9bf 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -61,78 +61,6 @@ entry:
   ret <vscale x 2 x i64> %partial.reduce
 }
 
-define <vscale x 4 x i64> @dotp_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp_8to64:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z2.s, #0 // =0x0
-; CHECK-NEXT:    udot z2.s, z0.b, z1.b
-; CHECK-NEXT:    uunpklo z0.d, z2.s
-; CHECK-NEXT:    uunpkhi z1.d, z2.s
-; CHECK-NEXT:    ret
-entry:
-  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
-  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
-  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
-  <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
-  ret <vscale x 4 x i64> %partial.reduce
-}
-
-define <vscale x 4 x i64> @dotp_sext_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp_sext_8to64:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z2.s, #0 // =0x0
-; CHECK-NEXT:    sdot z2.s, z0.b, z1.b
-; CHECK-NEXT:    sunpklo z0.d, z2.s
-; CHECK-NEXT:    sunpkhi z1.d, z2.s
-; CHECK-NEXT:    ret
-entry:
-  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
-  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
-  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
-  <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
-  ret <vscale x 4 x i64> %partial.reduce
-}
-
-define <vscale x 4 x i64> @dotp_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
-; CHECK-LABEL: dotp_8to64_accumulator:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z4.s, #0 // =0x0
-; CHECK-NEXT:    udot z4.s, z0.b, z1.b
-; CHECK-NEXT:    uunpklo z0.d, z4.s
-; CHECK-NEXT:    uunpkhi z1.d, z4.s
-; CHECK-NEXT:    add z0.d, z2.d, z0.d
-; CHECK-NEXT:    add z1.d, z3.d, z1.d
-; CHECK-NEXT:    ret
-entry:
-  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
-  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
-  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
-  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
-  ret <vscale x 4 x i64> %partial.reduce
-}
-
-define <vscale x 4 x i64> @dotp_sext_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
-; CHECK-LABEL: dotp_sext_8to64_accumulator:
-; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z4.s, #0 // =0x0
-; CHECK-NEXT:    sdot z4.s, z0.b, z1.b
-; CHECK-NEXT:    sunpklo z0.d, z4.s
-; CHECK-NEXT:    sunpkhi z1.d, z4.s
-; CHECK-NEXT:    add z0.d, z2.d, z0.d
-; CHECK-NEXT:    add z1.d, z3.d, z1.d
-; CHECK-NEXT:    ret
-entry:
-  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
-  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
-  %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
-  <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
-  ret <vscale x 4 x i64> %partial.reduce
-}
-
 define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: not_dotp:
 ; CHECK:       // %bb.0: // %entry

>From da20b2a03ef35f0f1172f0b4826ac3d792671da5 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 09:47:18 +0100
Subject: [PATCH 23/27] Add assertion

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7193303c7084ab..a100d033f50d00 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21772,6 +21772,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
                                       const AArch64Subtarget *Subtarget,
                                       SelectionDAG &DAG) {
 
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
   if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable())
     return SDValue();
 

>From 4697fc13d67421bd364ca907659a4896f629d0e3 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 09:48:31 +0100
Subject: [PATCH 24/27] Fix subtarget check

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a100d033f50d00..e303f6b693da18 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21777,7 +21777,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
              Intrinsic::experimental_vector_partial_reduce_add &&
          "Expected a partial reduction node");
 
-  if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable())
+  if (!Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
 
   SDLoc DL(N);
@@ -21819,8 +21819,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   // The wide type with four times as many elements as the reduced type. Should
   // be a vector of i32 or i64, the same as the fully-reduced type
   EVT WideType = MulOp.getValueType();
-  if (WideType.getScalarSizeInBits() != ReducedType.getScalarSizeInBits())
-    return SDValue();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type

>From 31b75674f414b740c3770377cb8b498aaa607225 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 09:50:33 +0100
Subject: [PATCH 25/27] Emit a node instead of an intrinsic

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e303f6b693da18..ded99e5ec8c440 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21800,17 +21800,14 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+  unsigned Opcode = 0;
 
   if (IsSExt)
-    DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+    Opcode = AArch64ISD::SDOT;
   else if (IsZExt)
-    DotIntrinsicId = Intrinsic::aarch64_sve_udot;
-
-  assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
-         "Unexpected dot product case encountered.");
+    Opcode = AArch64ISD::UDOT;
 
-  auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+  assert(Opcode != 0 && "Unexpected dot product case encountered.");
 
   // The fully-reduced type. Should be a vector of i32 or i64
   EVT ReducedType = N->getValueType(0);
@@ -21824,13 +21821,13 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   // as many elements in the wide type
   if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 &&
       ExtendedType == MVT::nxv16i8)
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
-                       {IntrinsicId, NarrowOp, A, B});
+    return DAG.getNode(Opcode, DL, MVT::nxv4i32,
+                              NarrowOp, A, B);
 
   if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 &&
       ExtendedType == MVT::nxv8i16)
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i64,
-                       {IntrinsicId, NarrowOp, A, B});
+    return DAG.getNode(Opcode, DL, MVT::nxv2i64,
+                              NarrowOp, A, B);
 
   return SDValue();
 }

>From 76296792b2cd94c78daaa03e8416e5bc86346ccb Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 10:11:18 +0100
Subject: [PATCH 26/27] Pass accumulator from function in tests

---
 .../AArch64/partial-reduce-dot-product.ll     | 68 ++++++++-----------
 1 file changed, 30 insertions(+), 38 deletions(-)

diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 16ef219a93c9bf..b1354ab210f727 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,104 +1,96 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
 ; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
 
-define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: dotp:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z2.s, #0 // =0x0
-; CHECK-NEXT:    udot z2.s, z0.b, z1.b
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b
 ; CHECK-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
   %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: dotp_wide:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z2.d, #0 // =0x0
-; CHECK-NEXT:    udot z2.d, z0.h, z1.h
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    udot z0.d, z1.h, z2.h
 ; CHECK-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
   %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
   ret <vscale x 2 x i64> %partial.reduce
 }
 
-define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: dotp_sext:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z2.s, #0 // =0x0
-; CHECK-NEXT:    sdot z2.s, z0.b, z1.b
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
 ; CHECK-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
   %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %accc, <vscale x 16 x i32> %mult)
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
 ; CHECK-LABEL: dotp_wide_sext:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov z2.d, #0 // =0x0
-; CHECK-NEXT:    sdot z2.d, z0.h, z1.h
-; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
 ; CHECK-NEXT:    ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
   %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
   %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
   ret <vscale x 2 x i64> %partial.reduce
 }
 
-define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: not_dotp:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    and z0.h, z0.h, #0xff
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEXT:    and z2.h, z2.h, #0xff
 ; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    uunpkhi z2.s, z0.h
-; CHECK-NEXT:    uunpkhi z3.s, z1.h
-; CHECK-NEXT:    uunpklo z0.s, z0.h
-; CHECK-NEXT:    uunpklo z1.s, z1.h
-; CHECK-NEXT:    mul z2.s, z2.s, z3.s
-; CHECK-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
 ; CHECK-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
   %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
   %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 8 x i32> %mult)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
   ret <vscale x 4 x i32> %partial.reduce
 }
 
-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
 ; CHECK-LABEL: not_dotp_wide:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    and z0.s, z0.s, #0xffff
 ; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    and z2.s, z2.s, #0xffff
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    uunpkhi z2.d, z0.s
-; CHECK-NEXT:    uunpkhi z3.d, z1.s
-; CHECK-NEXT:    uunpklo z0.d, z0.s
-; CHECK-NEXT:    uunpklo z1.d, z1.s
-; CHECK-NEXT:    mul z2.d, z2.d, z3.d
-; CHECK-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NEXT:    uunpklo z4.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
 ; CHECK-NEXT:    ret
 entry:
   %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
   %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
   %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
-  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
   ret <vscale x 2 x i64> %partial.reduce
 }

>From 830df7624894cca7c00db2263129433cf0e1a9d7 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 30 Aug 2024 10:00:37 +0100
Subject: [PATCH 27/27] Remove nxv4i64 case from
 shouldExpandPartialReductionIntrinsic

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ded99e5ec8c440..9ec25a4074a0a3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1996,7 +1996,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
 
   EVT VT = EVT::getEVT(I->getType());
 
-  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::nxv4i64;
+  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {



More information about the llvm-commits mailing list