[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 1 06:03:05 PDT 2024


https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/101010

From 0b9ce21c0019fea07188ffd142bd9cf580d09f35 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 10:46:16 +0100
Subject: [PATCH 1/7] [AArch64] Lower add partial reduction to udot

This patch introduces lowering of the partial add reduction intrinsic to
an SVE udot or sdot instruction for AArch64.
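
For example, given IR like the following (taken from the dotp test added
below), the zero-extended multiply feeding the partial reduction selects
to a single udot:

  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
  ; => udot z2.s, z0.b, z1.b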
---
 llvm/include/llvm/CodeGen/TargetLowering.h    |   6 +
 .../SelectionDAG/SelectionDAGBuilder.cpp      |   6 +
 .../Target/AArch64/AArch64ISelLowering.cpp    |  77 +++++++++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   2 +
 .../AArch64/AArch64TargetTransformInfo.cpp    |  30 +++++
 .../AArch64/AArch64TargetTransformInfo.h      |   6 +
 .../AArch64/partial-reduce-dot-product.ll     | 109 ++++++++++++++++++
 7 files changed, 236 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9d9886f4920a2..07d99aec47122 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,12 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1791f1b503379..c70ab253c1aab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7985,6 +7985,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+
+    if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
+
     SDValue OpNode = getValue(I.getOperand(1));
     EVT ReducedTy = EVT::getEVT(I.getType());
     EVT FullTy = OpNode.getValueType();
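
When the hook returns true, the builder keeps the generic expansion,
which splits the wide operand into accumulator-sized subvectors and adds
them to the accumulator. A rough fixed-width IR sketch of that behaviour
(illustrative only, not the exact node sequence the builder emits):

  ; %acc and %wide stand in for the intrinsic's two operands
  %e0 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 0)
  %e1 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 4)
  %e2 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 8)
  %e3 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 12)
  %a0 = add <4 x i32> %acc, %e0
  %a1 = add <4 x i32> %a0, %e1
  %a2 = add <4 x i32> %a1, %e2
  %res = add <4 x i32> %a2, %e3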
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d86e52d49000a..d1ee58668ecbd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1971,6 +1971,57 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
   return false;
 }
 
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+    const CallInst *CI) const {
+  const bool TargetLowers = false;
+  const bool GenericLowers = true;
+
+  auto *I = dyn_cast<IntrinsicInst>(CI);
+  if (!I)
+    return GenericLowers;
+
+  ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
+
+  if (!RetTy)
+    return GenericLowers;
+
+  ScalableVectorType *InputTy = nullptr;
+
+  auto RetScalarTy = RetTy->getScalarType();
+  if (RetScalarTy->isIntegerTy(64)) {
+    InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+  } else if (RetScalarTy->isIntegerTy(32)) {
+    InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+  }
+
+  if (!InputTy)
+    return GenericLowers;
+
+  Value *InputA;
+  Value *InputB;
+
+  auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+      m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+                                m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
+
+  if (!match(I, Pattern))
+    return GenericLowers;
+
+  auto Mul = cast<Instruction>(I->getOperand(1));
+
+  auto getOpcodeOfOperand = [&](unsigned Idx) {
+    return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
+  };
+
+  if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
+    return GenericLowers;
+
+  if (InputA->getType() != InputTy || InputB->getType() != InputTy)
+    return GenericLowers;
+
+  return TargetLowers;
+}
+
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21237,6 +21288,32 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
+  case Intrinsic::experimental_vector_partial_reduce_add: {
+    SDLoc DL(N);
+
+    auto NarrowOp = N->getOperand(1);
+    auto MulOp = N->getOperand(2);
+
+    auto ExtA = MulOp->getOperand(0);
+    auto ExtB = MulOp->getOperand(1);
+
+    unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+    if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+      DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+    else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+      DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+    assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+           "Unexpected dot product case encountered.");
+
+    auto A = ExtA->getOperand(0);
+    auto B = ExtB->getOperand(0);
+
+    auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+                       {IntrinsicId, NarrowOp, A, B});
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
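
The combine relies on shouldExpandPartialReductionIntrinsic having
already rejected everything except the matched extend-multiply pattern,
hence the assert rather than a graceful bail-out. For example, a
mixed-extension case like this hypothetical IR fails the same-opcode
check in the hook and is expanded generically instead:

  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul <vscale x 16 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)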
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81e15185f985d..fc79d9766719b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -991,6 +991,8 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
+  bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 45148449dfb82..792bd54601919 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3533,6 +3533,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
   return Cost;
 }
 
+bool AArch64TTIImpl::isPartialReductionSupported(
+    const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+    bool IsInputASignExtended, bool IsInputBSignExtended,
+    const Instruction *BinOp) const {
+  if (ReductionInstr->getOpcode() != Instruction::Add)
+    return false;
+
+  // Check that both extends are of the same type
+  if (IsInputASignExtended != IsInputBSignExtended)
+    return false;
+
+  if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
+    return false;
+
+  // Dot product only supports a scale factor of 4
+  if (ScaleFactor != 4)
+    return false;
+
+  Type *ReductionType = ReductionInstr->getType();
+  if (ReductionType->isIntegerTy(32)) {
+    if (!InputType->isIntegerTy(8))
+      return false;
+  } else if (ReductionType->isIntegerTy(64)) {
+    if (!InputType->isIntegerTy(16))
+      return false;
+  }
+
+  return true;
+}
+
 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
   return ST->getMaxInterleaveFactor();
 }
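
The scale factor of 4 corresponds to the two dot-product shapes: sixteen
i8 products accumulating into four i32 lanes, or eight i16 products
accumulating into two i64 lanes. The i16 form, as exercised by the
dotp_wide test below:

  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
  ; => udot z2.d, z0.h, z1.h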
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a9189fd53f40b..592b452134e77 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,6 +155,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
     return VF.getKnownMinValue() * ST->getVScaleForTuning();
   }
 
+  bool isPartialReductionSupported(const Instruction *ReductionInstr,
+                                   Type *InputType, unsigned ScaleFactor,
+                                   bool IsInputASignExtended,
+                                   bool IsInputBSignExtended,
+                                   const Instruction *BinOp = nullptr) const;
+
   unsigned getMaxInterleaveFactor(ElementCount VF);
 
   bool prefersVectorizedAddressing() const;
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 0000000000000..23b39387fb7a0
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,109 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    udot z2.s, z0.b, z1.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.d, #0 // =0x0
+; CHECK-NEXT:    udot z2.d, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    sdot z2.s, z0.b, z1.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    mov z2.d, #0 // =0x0
+; CHECK-NEXT:    sdot z2.d, z0.h, z1.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+; CHECK-LABEL: not_dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z0.h, z0.h, #0xff
+; CHECK-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpkhi z2.s, z0.h
+; CHECK-NEXT:    uunpkhi z3.s, z1.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    uunpklo z1.s, z1.h
+; CHECK-NEXT:    mul z2.s, z2.s, z3.s
+; CHECK-NEXT:    mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> zeroinitializer, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z0.s, z0.s, #0xffff
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpkhi z2.d, z0.s
+; CHECK-NEXT:    uunpkhi z3.d, z1.s
+; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    uunpklo z1.d, z1.s
+; CHECK-NEXT:    mul z2.d, z2.d, z3.d
+; CHECK-NEXT:    mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+  %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+  %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+attributes #0 = { "target-features"="+sve2" }

From 7a5661702155198e6e4f9eed4d83bbc2cd0cee6e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:08:52 +0100
Subject: [PATCH 2/7] Remove TargetLowers and GenericLowers

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1ee58668ecbd..b936e4cb4dccf 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1973,17 +1973,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const CallInst *CI) const {
-  const bool TargetLowers = false;
-  const bool GenericLowers = true;
 
   auto *I = dyn_cast<IntrinsicInst>(CI);
   if (!I)
-    return GenericLowers;
+    return true;
 
   ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
 
   if (!RetTy)
-    return GenericLowers;
+    return true;
 
   ScalableVectorType *InputTy = nullptr;
 
@@ -1995,7 +1993,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   }
 
   if (!InputTy)
-    return GenericLowers;
+    return true;
 
   Value *InputA;
   Value *InputB;
@@ -2005,7 +2003,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
                                 m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
 
   if (!match(I, Pattern))
-    return GenericLowers;
+    return true;
 
   auto Mul = cast<Instruction>(I->getOperand(1));
 
@@ -2014,12 +2012,12 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   };
 
   if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
-    return GenericLowers;
+    return true;
 
   if (InputA->getType() != InputTy || InputB->getType() != InputTy)
-    return GenericLowers;
+    return true;
 
-  return TargetLowers;
+  return false;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {

From 7bddd3b71c858d21ea626eb7f453126ad1a1b65b Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:16:16 +0100
Subject: [PATCH 3/7] Assert that shouldExpandPartialReductionIntrinsic sees an
 intrinsic

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b936e4cb4dccf..fdcce7a9d124b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1975,8 +1975,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const CallInst *CI) const {
 
   auto *I = dyn_cast<IntrinsicInst>(CI);
-  if (!I)
-    return true;
+  assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
 
   ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
 

From 48798c1e4ec770f6a47c69e841c048a83bb9bff6 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:31:24 +0100
Subject: [PATCH 4/7] Allow non-scalable vector types

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fdcce7a9d124b..28e13e2e2841c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1977,18 +1977,17 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   auto *I = dyn_cast<IntrinsicInst>(CI);
   assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
 
-  ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
-
+  VectorType *RetTy = dyn_cast<VectorType>(I->getType());
   if (!RetTy)
     return true;
 
-  ScalableVectorType *InputTy = nullptr;
+  VectorType *InputTy = nullptr;
 
   auto RetScalarTy = RetTy->getScalarType();
   if (RetScalarTy->isIntegerTy(64)) {
-    InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+    InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
   } else if (RetScalarTy->isIntegerTy(32)) {
-    InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+    InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
   }
 
   if (!InputTy)

From 955c84e7aa1f811f2a78585d9dbf985672d3e21e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 19:10:08 +0100
Subject: [PATCH 5/7] Clean up type checking

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 28e13e2e2841c..8cf997cf0a3f2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1984,13 +1984,11 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   VectorType *InputTy = nullptr;
 
   auto RetScalarTy = RetTy->getScalarType();
-  if (RetScalarTy->isIntegerTy(64)) {
+  if (RetScalarTy->isIntegerTy(64))
     InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
-  } else if (RetScalarTy->isIntegerTy(32)) {
+  else if (RetScalarTy->isIntegerTy(32))
     InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
-  }
-
-  if (!InputTy)
+  else
     return true;
 
   Value *InputA;
@@ -2004,7 +2002,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   auto Mul = cast<Instruction>(I->getOperand(1));
-
   auto getOpcodeOfOperand = [&](unsigned Idx) {
     return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
   };

From 7acadbc84967f045e803ce0d9e9008ab1554cdf8 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 1 Aug 2024 11:04:37 +0100
Subject: [PATCH 6/7] Restrict to scalable vector types and clean up type
 checking

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp       |  2 +-
 .../lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 11 +++--------
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8cf997cf0a3f2..510a92dc8d734 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1978,7 +1978,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
 
   VectorType *RetTy = dyn_cast<VectorType>(I->getType());
-  if (!RetTy)
+  if (!RetTy || !RetTy->isScalableTy())
     return true;
 
   VectorType *InputTy = nullptr;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 792bd54601919..afa9acf5c8de3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3552,15 +3552,10 @@ bool AArch64TTIImpl::isPartialReductionSupported(
     return false;
 
   Type *ReductionType = ReductionInstr->getType();
-  if (ReductionType->isIntegerTy(32)) {
-    if (!InputType->isIntegerTy(8))
-      return false;
-  } else if (ReductionType->isIntegerTy(64)) {
-    if (!InputType->isIntegerTy(16))
-      return false;
-  }
 
-  return true;
+  return ((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
+          (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) &&
+         ReductionType->isScalableTy();
 }
 
 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {

From 0213f5d0be6f27888b150e2f3b8af8d6ae64ccee Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 1 Aug 2024 11:36:50 +0100
Subject: [PATCH 7/7] Simplify instruction matching in
 shouldExpandPartialReduction

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 56 +++++++++----------
 1 file changed, 27 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 510a92dc8d734..4fbe323117071 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1981,38 +1981,36 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   if (!RetTy || !RetTy->isScalableTy())
     return true;
 
-  VectorType *InputTy = nullptr;
-
-  auto RetScalarTy = RetTy->getScalarType();
-  if (RetScalarTy->isIntegerTy(64))
-    InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
-  else if (RetScalarTy->isIntegerTy(32))
-    InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
-  else
-    return true;
-
   Value *InputA;
   Value *InputB;
+  if (match(I, m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+                   m_Value(),
+                   m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+                                  m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
+    VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
+    VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
+    if (!InputAType || !InputBType)
+      return true;
+    ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
+    ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
+    if ((RetTy->getScalarType()->isIntegerTy(64) &&
+         InputAType->getElementType()->isIntegerTy(16) &&
+         InputAType->getElementCount() == ExpectedCount8 &&
+         InputAType == InputBType) ||
+
+        (RetTy->getScalarType()->isIntegerTy(32) &&
+         InputAType->getElementType()->isIntegerTy(8) &&
+         InputAType->getElementCount() == ExpectedCount16 &&
+         InputAType == InputBType)) {
+      auto *Mul = cast<Instruction>(I->getOperand(1));
+      auto *Mul0 = cast<Instruction>(Mul->getOperand(0));
+      auto *Mul1 = cast<Instruction>(Mul->getOperand(1));
+      if (Mul0->getOpcode() == Mul1->getOpcode())
+        return false;
+    }
+  }
 
-  auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
-      m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
-                                m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
-
-  if (!match(I, Pattern))
-    return true;
-
-  auto Mul = cast<Instruction>(I->getOperand(1));
-  auto getOpcodeOfOperand = [&](unsigned Idx) {
-    return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
-  };
-
-  if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
-    return true;
-
-  if (InputA->getType() != InputTy || InputB->getType() != InputTy)
-    return true;
-
-  return false;
+  return true;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
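
With the scalable-only restriction from the previous commit,
RetTy->isScalableTy() is always true at this point, so ExpectedCount8 and
ExpectedCount16 are the scalable element counts. A fixed-width form such
as the following hypothetical IR trips the early bail-out and keeps the
generic expansion; this series does not attempt a NEON lowering for it:

  %a.wide = zext <16 x i8> %a to <16 x i32>
  %b.wide = zext <16 x i8> %b to <16 x i32>
  %mult = mul <16 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> zeroinitializer, <16 x i32> %mult)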


