[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)
Sam Tebbs via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 30 02:02:48 PDT 2024
https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/101010
>From 561706ff1ed18f9b1924df417dbe1c2a4ff65432 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 10:46:16 +0100
Subject: [PATCH 01/27] [AArch64] Lower add partial reduction to udot
This patch introduces lowering of the partial add reduction intrinsic to
a udot or svdot for AArch64.
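For illustration (this mirrors the tests added below), IR of the
following shape, where a mul of two matching zero- or sign-extends
feeds the partial reduction:

  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
  %partial.reduce = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)

is now selected to a single SVE dot product (udot z2.s, z0.b, z1.b for
the zero-extended case, sdot for the sign-extended case) instead of
being expanded by the generic SelectionDAGBuilder code.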
---
llvm/include/llvm/CodeGen/TargetLowering.h | 6 +
.../SelectionDAG/SelectionDAGBuilder.cpp | 6 +
.../Target/AArch64/AArch64ISelLowering.cpp | 77 +++++++++++++
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 +
.../AArch64/AArch64TargetTransformInfo.cpp | 30 +++++
.../AArch64/AArch64TargetTransformInfo.h | 6 +
.../AArch64/partial-reduce-dot-product.ll | 109 ++++++++++++++++++
7 files changed, 236 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index eda38cd8a564d6..883a2252f7ffee 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,12 @@ class TargetLoweringBase {
return true;
}
+ /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+ /// should be expanded using generic code in SelectionDAGBuilder.
+ virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+ return true;
+ }
+
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 60dcb118542785..5ddbf9f414d218 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8005,6 +8005,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
+
+ if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+ visitTargetIntrinsic(I, Intrinsic);
+ return;
+ }
+
SDValue OpNode = getValue(I.getOperand(1));
EVT ReducedTy = EVT::getEVT(I.getType());
EVT FullTy = OpNode.getValueType();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 215f30128e7038..987a7290274e73 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1988,6 +1988,57 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+ const CallInst *CI) const {
+ const bool TargetLowers = false;
+ const bool GenericLowers = true;
+
+ auto *I = dyn_cast<IntrinsicInst>(CI);
+ if (!I)
+ return GenericLowers;
+
+ ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
+
+ if (!RetTy)
+ return GenericLowers;
+
+ ScalableVectorType *InputTy = nullptr;
+
+ auto RetScalarTy = RetTy->getScalarType();
+ if (RetScalarTy->isIntegerTy(64)) {
+ InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+ } else if (RetScalarTy->isIntegerTy(32)) {
+ InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+ }
+
+ if (!InputTy)
+ return GenericLowers;
+
+ Value *InputA;
+ Value *InputB;
+
+ auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+ m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+ m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
+
+ if (!match(I, Pattern))
+ return GenericLowers;
+
+ auto Mul = cast<Instruction>(I->getOperand(1));
+
+ auto getOpcodeOfOperand = [&](unsigned Idx) {
+ return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
+ };
+
+ if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
+ return GenericLowers;
+
+ if (InputA->getType() != InputTy || InputB->getType() != InputTy)
+ return GenericLowers;
+
+ return TargetLowers;
+}
+
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -21765,6 +21816,32 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
+ case Intrinsic::experimental_vector_partial_reduce_add: {
+ SDLoc DL(N);
+
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+
+ unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+ if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+ DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+ else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+ DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+ assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+ "Unexpected dot product case encountered.");
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+
+ auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+ {IntrinsicId, NarrowOp, A, B});
+ }
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 39d5df0de0eec7..9fe95ddaca32c8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,6 +998,8 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
+ bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+
bool shouldExpandCttzElements(EVT VT) const override;
/// If a change in streaming mode is required on entry to/return from a
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index dc748290f2e21e..5871134e60985b 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3670,6 +3670,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
return Cost;
}
+bool AArch64TTIImpl::isPartialReductionSupported(
+ const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+ bool IsInputASignExtended, bool IsInputBSignExtended,
+ const Instruction *BinOp) const {
+ if (ReductionInstr->getOpcode() != Instruction::Add)
+ return false;
+
+ // Check that both extends are of the same type
+ if (IsInputASignExtended != IsInputBSignExtended)
+ return false;
+
+ if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
+ return false;
+
+ // Dot product only supports a scale factor of 4
+ if (ScaleFactor != 4)
+ return false;
+
+ Type *ReductionType = ReductionInstr->getType();
+ if (ReductionType->isIntegerTy(32)) {
+ if (!InputType->isIntegerTy(8))
+ return false;
+ } else if (ReductionType->isIntegerTy(64)) {
+ if (!InputType->isIntegerTy(16))
+ return false;
+ }
+
+ return true;
+}
+
unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
return ST->getMaxInterleaveFactor();
}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 4a6457d7a7dbf5..af7e8e8e497dd8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,6 +155,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return VF.getKnownMinValue() * ST->getVScaleForTuning();
}
+ bool isPartialReductionSupported(const Instruction *ReductionInstr,
+ Type *InputType, unsigned ScaleFactor,
+ bool IsInputASignExtended,
+ bool IsInputBSignExtended,
+ const Instruction *BinOp = nullptr) const;
+
unsigned getMaxInterleaveFactor(ElementCount VF);
bool prefersVectorizedAddressing() const;
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..23b39387fb7a0c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,109 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: udot z2.s, z0.b, z1.b
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.d, #0 // =0x0
+; CHECK-NEXT: udot z2.d, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp_sext:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: sdot z2.s, z0.b, z1.b
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.d, #0 // =0x0
+; CHECK-NEXT: sdot z2.d, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+; CHECK-LABEL: not_dotp:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z0.h, z0.h, #0xff
+; CHECK-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: uunpkhi z2.s, z0.h
+; CHECK-NEXT: uunpkhi z3.s, z1.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: uunpklo z1.s, z1.h
+; CHECK-NEXT: mul z2.s, z2.s, z3.s
+; CHECK-NEXT: mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+ %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+ %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 8 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z0.s, z0.s, #0xffff
+; CHECK-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpkhi z2.d, z0.s
+; CHECK-NEXT: uunpkhi z3.d, z1.s
+; CHECK-NEXT: uunpklo z0.d, z0.s
+; CHECK-NEXT: uunpklo z1.d, z1.s
+; CHECK-NEXT: mul z2.d, z2.d, z3.d
+; CHECK-NEXT: mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+ %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+ %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+attributes #0 = { "target-features"="+sve2" }
>From 563d025161d0b0e00253daf8943926b8e6ebfc75 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:08:52 +0100
Subject: [PATCH 02/27] Remove TargetLowers and GenericLowers
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 16 +++++++---------
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 987a7290274e73..f24e24bfe2d454 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1990,17 +1990,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const CallInst *CI) const {
- const bool TargetLowers = false;
- const bool GenericLowers = true;
auto *I = dyn_cast<IntrinsicInst>(CI);
if (!I)
- return GenericLowers;
+ return true;
ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
if (!RetTy)
- return GenericLowers;
+ return true;
ScalableVectorType *InputTy = nullptr;
@@ -2012,7 +2010,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
}
if (!InputTy)
- return GenericLowers;
+ return true;
Value *InputA;
Value *InputB;
@@ -2022,7 +2020,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
if (!match(I, Pattern))
- return GenericLowers;
+ return true;
auto Mul = cast<Instruction>(I->getOperand(1));
@@ -2031,12 +2029,12 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
};
if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
- return GenericLowers;
+ return true;
if (InputA->getType() != InputTy || InputB->getType() != InputTy)
- return GenericLowers;
+ return true;
- return TargetLowers;
+ return false;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
>From e604e4570412adee32b2ce363f8ceea9c7438349 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:16:16 +0100
Subject: [PATCH 03/27] Assert that shouldExpandPartialReductionIntrinsic sees
an intrinsic
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f24e24bfe2d454..8811f2ef94fb09 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1992,8 +1992,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const CallInst *CI) const {
auto *I = dyn_cast<IntrinsicInst>(CI);
- if (!I)
- return true;
+ assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
>From 9b23c96f5cd7a5ce10229682e61b1d8c5464b01e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:31:24 +0100
Subject: [PATCH 04/27] Allow non-scalable vector types
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8811f2ef94fb09..ad38410ba3f275 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1994,18 +1994,17 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
auto *I = dyn_cast<IntrinsicInst>(CI);
assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc");
- ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
-
+ VectorType *RetTy = dyn_cast<VectorType>(I->getType());
if (!RetTy)
return true;
- ScalableVectorType *InputTy = nullptr;
+ VectorType *InputTy = nullptr;
auto RetScalarTy = RetTy->getScalarType();
if (RetScalarTy->isIntegerTy(64)) {
- InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+ InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
} else if (RetScalarTy->isIntegerTy(32)) {
- InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+ InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
}
if (!InputTy)
>From 45692dff2a6498d3ebc53a5862ef68239a772ab9 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 19:10:08 +0100
Subject: [PATCH 05/27] Clean up type checking
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 +++------
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ad38410ba3f275..f96be2a9f55e62 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2001,13 +2001,11 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
VectorType *InputTy = nullptr;
auto RetScalarTy = RetTy->getScalarType();
- if (RetScalarTy->isIntegerTy(64)) {
+ if (RetScalarTy->isIntegerTy(64))
InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
- } else if (RetScalarTy->isIntegerTy(32)) {
+ else if (RetScalarTy->isIntegerTy(32))
InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
- }
-
- if (!InputTy)
+ else
return true;
Value *InputA;
@@ -2021,7 +2019,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
auto Mul = cast<Instruction>(I->getOperand(1));
-
auto getOpcodeOfOperand = [&](unsigned Idx) {
return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
};
>From d305452d3cf171b320cdd25720247a4d61fb1a31 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 1 Aug 2024 11:04:37 +0100
Subject: [PATCH 06/27] Restrict to scalable vector types and clean up type
checking
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +-
.../lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 11 +++--------
2 files changed, 4 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f96be2a9f55e62..31c3e208356bf9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1995,7 +1995,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc");
VectorType *RetTy = dyn_cast<VectorType>(I->getType());
- if (!RetTy)
+ if (!RetTy || !RetTy->isScalableTy())
return true;
VectorType *InputTy = nullptr;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 5871134e60985b..8cd2ba17b7d79d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3689,15 +3689,10 @@ bool AArch64TTIImpl::isPartialReductionSupported(
return false;
Type *ReductionType = ReductionInstr->getType();
- if (ReductionType->isIntegerTy(32)) {
- if (!InputType->isIntegerTy(8))
- return false;
- } else if (ReductionType->isIntegerTy(64)) {
- if (!InputType->isIntegerTy(16))
- return false;
- }
- return true;
+ return ((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
+ (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) &&
+ ReductionType->isScalableTy();
}
unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
>From 4738a204c93ba2386783afb375b95c806e111522 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 1 Aug 2024 11:36:50 +0100
Subject: [PATCH 07/27] Simplify instruction matching in
shouldExpandPartialReduction
---
.../Target/AArch64/AArch64ISelLowering.cpp | 56 +++++++++----------
1 file changed, 27 insertions(+), 29 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 31c3e208356bf9..c4d00abfc190f0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1998,38 +1998,36 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
if (!RetTy || !RetTy->isScalableTy())
return true;
- VectorType *InputTy = nullptr;
-
- auto RetScalarTy = RetTy->getScalarType();
- if (RetScalarTy->isIntegerTy(64))
- InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
- else if (RetScalarTy->isIntegerTy(32))
- InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
- else
- return true;
-
Value *InputA;
Value *InputB;
+ if (match(I, m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+ m_Value(),
+ m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+ m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
+ VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
+ VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
+ if (!InputAType || !InputBType)
+ return true;
+ ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
+ ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
+ if ((RetTy->getScalarType()->isIntegerTy(64) &&
+ InputAType->getElementType()->isIntegerTy(16) &&
+ InputAType->getElementCount() == ExpectedCount8 &&
+ InputAType == InputBType) ||
+
+ (RetTy->getScalarType()->isIntegerTy(32) &&
+ InputAType->getElementType()->isIntegerTy(8) &&
+ InputAType->getElementCount() == ExpectedCount16 &&
+ InputAType == InputBType)) {
+ auto *Mul = cast<Instruction>(I->getOperand(1));
+ auto *Mul0 = cast<Instruction>(Mul->getOperand(0));
+ auto *Mul1 = cast<Instruction>(Mul->getOperand(1));
+ if (Mul0->getOpcode() == Mul1->getOpcode())
+ return false;
+ }
+ }
- auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
- m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
- m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
-
- if (!match(I, Pattern))
- return true;
-
- auto Mul = cast<Instruction>(I->getOperand(1));
- auto getOpcodeOfOperand = [&](unsigned Idx) {
- return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
- };
-
- if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
- return true;
-
- if (InputA->getType() != InputTy || InputB->getType() != InputTy)
- return true;
-
- return false;
+ return true;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
>From 4dbf99e959674b91c743340560816d926480e2a3 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 9 Aug 2024 16:38:22 +0100
Subject: [PATCH 08/27] Add fallback in case the nodes aren't as we expect at
lowering time
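The fallback is essentially the same subvector-sum expansion that the
generic code performs: the wide operand is split into chunks of the
result type and added to the accumulator one by one. Roughly, in IR
terms (a sketch; the combine itself builds the equivalent
EXTRACT_SUBVECTOR/ADD nodes, and %acc/%v are illustrative names):

  ; <vscale x 4 x i32> partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %v)
  %v0 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %v, i64 0)
  %v1 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %v, i64 4)
  %v2 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %v, i64 8)
  %v3 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %v, i64 12)
  %a0 = add <vscale x 4 x i32> %acc, %v0
  %a1 = add <vscale x 4 x i32> %a0, %v1
  %a2 = add <vscale x 4 x i32> %a1, %v2
  %res = add <vscale x 4 x i32> %a2, %v3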
---
.../Target/AArch64/AArch64ISelLowering.cpp | 67 ++++++++++++++++---
1 file changed, 59 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c4d00abfc190f0..a5ea612fb38996 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21810,28 +21810,79 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::experimental_vector_partial_reduce_add: {
SDLoc DL(N);
+ bool IsValidDotProduct = false;
+
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() == ISD::MUL)
+ IsValidDotProduct = true;
auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() == ExtB->getOpcode() && (IsSExt || IsZExt))
+ IsValidDotProduct = true;
unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
- if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+ if (IsSExt && IsValidDotProduct)
DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
- else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+ else if (IsZExt && IsValidDotProduct)
DotIntrinsicId = Intrinsic::aarch64_sve_udot;
- assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+ assert((!IsValidDotProduct || DotIntrinsicId != Intrinsic::not_intrinsic) &&
"Unexpected dot product case encountered.");
- auto A = ExtA->getOperand(0);
- auto B = ExtB->getOperand(0);
+ if (IsValidDotProduct) {
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+
+ auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+ {IntrinsicId, NarrowOp, A, B});
+ } else {
+ // If the node doesn't match a dot product, lower to a series of ADDs
+ // instead.
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ EVT Type0 = Op0->getValueType(0);
+ EVT Type1 = Op1->getValueType(0);
+
+ // Canonicalise so that Op1 has the larger type
+ if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) {
+ std::swap(Op0, Op1);
+ std::swap(Type0, Type1);
+ }
- auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
- {IntrinsicId, NarrowOp, A, B});
+ auto Type0Elements = Type0.getVectorNumElements();
+ auto Type1Elements = Type1.getVectorNumElements();
+ auto Type0ElementSize =
+ Type0.getVectorElementType().getScalarSizeInBits();
+ auto Type1ElementSize =
+ Type1.getVectorElementType().getScalarSizeInBits();
+
+ // If the types are equal then a single ADD is fine
+ if (Type0 == Type1)
+ return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1});
+
+ // Otherwise, we need to add each subvector together so that the output is
+ // the intrinsic's return type. For example, <4 x i32>
+ // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] +
+ // b[4..7] + b[8..11] + b[12..15]
+ SDValue Add = Op0;
+ for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) {
+ SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
+ DAG.getConstant(i, DL, MVT::i64));
+
+ if (Type1ElementSize < Type0ElementSize)
+ Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
+ else if (Type1ElementSize > Type0ElementSize)
+ Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
+ Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec});
+ }
+ return Add;
+ }
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
>From c068775b5ddb75a9cd31b6443551c5a0e9cab496 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 12 Aug 2024 11:02:28 +0100
Subject: [PATCH 09/27] Fix logic error with fallback case
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a5ea612fb38996..56327cc074a47b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21810,19 +21810,19 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::experimental_vector_partial_reduce_add: {
SDLoc DL(N);
- bool IsValidDotProduct = false;
+ bool IsValidDotProduct = true;
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
- if (MulOp->getOpcode() == ISD::MUL)
- IsValidDotProduct = true;
+ if (MulOp->getOpcode() != ISD::MUL)
+ IsValidDotProduct = false;
auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
- if (ExtA->getOpcode() == ExtB->getOpcode() && (IsSExt || IsZExt))
- IsValidDotProduct = true;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ IsValidDotProduct = false;
unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
@@ -21844,8 +21844,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
} else {
// If the node doesn't match a dot product, lower to a series of ADDs
// instead.
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
+ SDValue Op0 = N->getOperand(1);
+ SDValue Op1 = N->getOperand(2);
EVT Type0 = Op0->getValueType(0);
EVT Type1 = Op1->getValueType(0);
>From 636652d0ad53a8476270ffa01cee73483081ca4a Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:18:53 +0100
Subject: [PATCH 10/27] Pass IntrinsicInst to
shouldExpandPartialReductionIntrinsic
---
llvm/include/llvm/CodeGen/TargetLowering.h | 3 ++-
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 +-
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 +----
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 3 ++-
4 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 883a2252f7ffee..e17d68d2690c86 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -455,7 +455,8 @@ class TargetLoweringBase {
/// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
/// should be expanded using generic code in SelectionDAGBuilder.
- virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+ virtual bool
+ shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
return true;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 5ddbf9f414d218..05cbe384cb5ed5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8006,7 +8006,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
}
case Intrinsic::experimental_vector_partial_reduce_add: {
- if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+ if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
return;
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 56327cc074a47b..916eccd52e9396 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1989,10 +1989,7 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
}
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
- const CallInst *CI) const {
-
- auto *I = dyn_cast<IntrinsicInst>(CI);
- assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
+ const IntrinsicInst *I) const {
VectorType *RetTy = dyn_cast<VectorType>(I->getType());
if (!RetTy || !RetTy->isScalableTy())
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 9fe95ddaca32c8..f9d45b02d30e30 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,7 +998,8 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
- bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+ bool
+ shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
bool shouldExpandCttzElements(EVT VT) const override;
>From 83015b7e08e5c0ccfa55cf20423292a9e29d2a2a Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:31:22 +0100
Subject: [PATCH 11/27] Remove one-use restriction
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 916eccd52e9396..cd70d2e3cdfa7f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1997,10 +1997,10 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
Value *InputA;
Value *InputB;
- if (match(I, m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
- m_Value(),
- m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
- m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
+ if (match(I,
+ m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+ m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)),
+ m_ZExtOrSExt(m_Value(InputB))))))) {
VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
if (!InputAType || !InputBType)
>From ed6efd6e7e705cd8b932b66f9e818a0e6c204884 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:32:10 +0100
Subject: [PATCH 12/27] Remove new line
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cd70d2e3cdfa7f..856d9b0eeadd12 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2011,7 +2011,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
InputAType->getElementType()->isIntegerTy(16) &&
InputAType->getElementCount() == ExpectedCount8 &&
InputAType == InputBType) ||
-
(RetTy->getScalarType()->isIntegerTy(32) &&
InputAType->getElementType()->isIntegerTy(8) &&
InputAType->getElementCount() == ExpectedCount16 &&
>From 63648378be615a28ea75fd53d27eabc0a3658c06 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 20:21:43 +0100
Subject: [PATCH 13/27] Remove extending/truncating for fallback case
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 --------
1 file changed, 8 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 856d9b0eeadd12..c150db8d8947d1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21853,10 +21853,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
auto Type0Elements = Type0.getVectorNumElements();
auto Type1Elements = Type1.getVectorNumElements();
- auto Type0ElementSize =
- Type0.getVectorElementType().getScalarSizeInBits();
- auto Type1ElementSize =
- Type1.getVectorElementType().getScalarSizeInBits();
// If the types are equal then a single ADD is fine
if (Type0 == Type1)
@@ -21871,10 +21867,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
DAG.getConstant(i, DL, MVT::i64));
- if (Type1ElementSize < Type0ElementSize)
- Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
- else if (Type1ElementSize > Type0ElementSize)
- Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec});
}
return Add;
>From 9da416b52c88340bbd9611ad016532c58f2e8417 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 20:27:04 +0100
Subject: [PATCH 14/27] Clean up test target
---
llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 23b39387fb7a0c..0facb2049135f6 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,8 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
-
-target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-none-unknown-elf"
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
; CHECK-LABEL: dotp:
>From 0d231096e4f971f379a56c21e961caae7db3080e Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 14 Aug 2024 09:42:32 +0100
Subject: [PATCH 15/27] Remove #0 attribute from test
---
.../CodeGen/AArch64/partial-reduce-dot-product.ll | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 0facb2049135f6..16ef219a93c9bf 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,7 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
-define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z2.s, #0 // =0x0
@@ -16,7 +16,7 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z2.d, #0 // =0x0
@@ -31,7 +31,7 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z2.s, #0 // =0x0
@@ -46,7 +46,7 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z2.d, #0 // =0x0
@@ -61,7 +61,7 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z0.h, z0.h, #0xff
@@ -82,7 +82,7 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
; CHECK-LABEL: not_dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z0.s, z0.s, #0xffff
@@ -102,5 +102,3 @@ entry:
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}
-
-attributes #0 = { "target-features"="+sve2" }
>From bc86de6d93b9166c6ac4ca10cddae0199607fd15 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 14 Aug 2024 10:55:12 +0100
Subject: [PATCH 16/27] Allow i8 to i64 dot products
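SVE has no i8-to-i64 dot product instruction, so this case is handled
by doing the i32 dot product into a zero accumulator and then
extending the partial sums and adding the real accumulator afterwards.
In IR terms the lowering is roughly equivalent to the following sketch
for the zero-extended case, where %a and %b are the <vscale x 16 x i8>
inputs and %acc is the <vscale x 4 x i64> accumulator:

  %dot = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
  %ext = zext <vscale x 4 x i32> %dot to <vscale x 4 x i64>
  %res = add <vscale x 4 x i64> %acc, %ext

This matches the dotp_8to64* tests added below: udot/sdot into a
zeroed register, an unpack of the i32 results to i64, and adds of the
original accumulator when it is non-zero.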
---
.../Target/AArch64/AArch64ISelLowering.cpp | 34 ++++++++-
.../AArch64/partial-reduce-dot-product.ll | 72 +++++++++++++++++++
2 files changed, 103 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c150db8d8947d1..13e664f5bf27f5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2007,11 +2007,15 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
+ // Check that the input type is 4 times smaller than the output type. If the
+ // output type is 64 bit then we can accept 8 bit inputs if we do a 32 bit
+ // dot product and add a zext/sext.
if ((RetTy->getScalarType()->isIntegerTy(64) &&
InputAType->getElementType()->isIntegerTy(16) &&
InputAType->getElementCount() == ExpectedCount8 &&
InputAType == InputBType) ||
- (RetTy->getScalarType()->isIntegerTy(32) &&
+ ((RetTy->getScalarType()->isIntegerTy(32) ||
+ RetTy->getScalarType()->isIntegerTy(64)) &&
InputAType->getElementType()->isIntegerTy(8) &&
InputAType->getElementCount() == ExpectedCount16 &&
InputAType == InputBType)) {
@@ -21833,10 +21837,34 @@ static SDValue performIntrinsicCombine(SDNode *N,
if (IsValidDotProduct) {
auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
+ EVT Type = NarrowOp.getValueType();
+
+ // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
+ // and extending the output
+ bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
+ Type.getScalarSizeInBits() == 64;
+ SDValue Accumulator = NarrowOp;
+ if (Extend) {
+ Type = Type.changeVectorElementType(
+ EVT::getIntegerVT(*DAG.getContext(), 32));
+ // The accumulator is of the wider type so we insert a 0 accumulator and
+ // add the proper one after extending
+ Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
+ DAG.getConstant(0, DL, MVT::i32));
+ }
auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
- {IntrinsicId, NarrowOp, A, B});
+ auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
+ {IntrinsicId, Accumulator, A, B});
+ if (Extend) {
+ auto Extended =
+ DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
+ NarrowOp.getValueType(), {DotProduct});
+ auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
+ {NarrowOp, Extended});
+ DotProduct = AccAdd;
+ }
+ return DotProduct;
} else {
// If the node doesn't match a dot product, lower to a series of ADDs
// instead.
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 16ef219a93c9bf..c1cf9026d693ce 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -61,6 +61,78 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
+define <vscale x 4 x i64> @dotp_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: udot z2.s, z0.b, z1.b
+; CHECK-NEXT: uunpklo z0.d, z2.s
+; CHECK-NEXT: uunpkhi z1.d, z2.s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+ <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_sext_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: sdot z2.s, z0.b, z1.b
+; CHECK-NEXT: sunpklo z0.d, z2.s
+; CHECK-NEXT: sunpkhi z1.d, z2.s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+ <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
+; CHECK-LABEL: dotp_8to64_accumulator:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z0.b, z1.b
+; CHECK-NEXT: uunpklo z0.d, z4.s
+; CHECK-NEXT: uunpkhi z1.d, z4.s
+; CHECK-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_sext_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
+; CHECK-LABEL: dotp_sext_8to64_accumulator:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z0.b, z1.b
+; CHECK-NEXT: sunpklo z0.d, z4.s
+; CHECK-NEXT: sunpkhi z1.d, z4.s
+; CHECK-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
>From aa7957faeb3adb8beda7443c4a1b06bddb9b01d4 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 20 Aug 2024 13:53:11 +0100
Subject: [PATCH 17/27] Remove isPartialReductionSupported
---
.../AArch64/AArch64TargetTransformInfo.cpp | 25 -------------------
.../AArch64/AArch64TargetTransformInfo.h | 6 -----
2 files changed, 31 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 8cd2ba17b7d79d..dc748290f2e21e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3670,31 +3670,6 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
return Cost;
}
-bool AArch64TTIImpl::isPartialReductionSupported(
- const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
- bool IsInputASignExtended, bool IsInputBSignExtended,
- const Instruction *BinOp) const {
- if (ReductionInstr->getOpcode() != Instruction::Add)
- return false;
-
- // Check that both extends are of the same type
- if (IsInputASignExtended != IsInputBSignExtended)
- return false;
-
- if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
- return false;
-
- // Dot product only supports a scale factor of 4
- if (ScaleFactor != 4)
- return false;
-
- Type *ReductionType = ReductionInstr->getType();
-
- return ((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
- (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) &&
- ReductionType->isScalableTy();
-}
-
unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
return ST->getMaxInterleaveFactor();
}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index af7e8e8e497dd8..4a6457d7a7dbf5 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,12 +155,6 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return VF.getKnownMinValue() * ST->getVScaleForTuning();
}
- bool isPartialReductionSupported(const Instruction *ReductionInstr,
- Type *InputType, unsigned ScaleFactor,
- bool IsInputASignExtended,
- bool IsInputBSignExtended,
- const Instruction *BinOp = nullptr) const;
-
unsigned getMaxInterleaveFactor(ElementCount VF);
bool prefersVectorizedAddressing() const;
>From a58ac297afd726a536b6687748d9a196473a4618 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 21 Aug 2024 15:02:17 +0100
Subject: [PATCH 18/27] Share expansion code in SelectionDAG
---
llvm/include/llvm/CodeGen/SelectionDAG.h | 4 +
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 30 +++
.../SelectionDAG/SelectionDAGBuilder.cpp | 29 +--
.../Target/AArch64/AArch64ISelLowering.cpp | 217 ++++++++----------
4 files changed, 130 insertions(+), 150 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1514d92b36b3c2..2235db5d93b5ad 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1594,6 +1594,10 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
+ /// Expand a partial reduction intrinsic call.
+ /// Op1 and Op2 are its operands and ReducedTy is the intrinsic's return type.
+ SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL);
+
/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
SDValue expandVAArg(SDNode *Node);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 27675dce70c260..2510c1828c909f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -74,6 +74,7 @@
#include <cassert>
#include <cstdint>
#include <cstdlib>
+#include <deque>
#include <limits>
#include <optional>
#include <set>
@@ -2426,6 +2427,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
+SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL) {
+ EVT FullTy = Op2.getValueType();
+
+ unsigned Stride = ReducedTy.getVectorMinNumElements();
+ unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+ // Collect all of the subvectors
+ std::deque<SDValue> Subvectors = {Op1};
+ for (unsigned I = 0; I < ScaleFactor; I++) {
+ auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+ Subvectors.push_back(getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
+ {Op2, SourceIndex}));
+ }
+
+ // Flatten the subvector tree
+ while (Subvectors.size() > 1) {
+ Subvectors.push_back(getNode(ISD::ADD, DL, ReducedTy,
+ {Subvectors[0], Subvectors[1]}));
+ Subvectors.pop_front();
+ Subvectors.pop_front();
+ }
+
+ assert(Subvectors.size() == 1 &&
+ "There should only be one subvector after tree flattening");
+
+ return Subvectors[0];
+
+}
+
SDValue SelectionDAG::expandVAArg(SDNode *Node) {
SDLoc dl(Node);
const TargetLowering &TLI = getTargetLoweringInfo();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 05cbe384cb5ed5..33de8747fb7e56 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8011,34 +8011,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
- SDValue OpNode = getValue(I.getOperand(1));
- EVT ReducedTy = EVT::getEVT(I.getType());
- EVT FullTy = OpNode.getValueType();
-
- unsigned Stride = ReducedTy.getVectorMinNumElements();
- unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
- // Collect all of the subvectors
- std::deque<SDValue> Subvectors;
- Subvectors.push_back(getValue(I.getOperand(0)));
- for (unsigned i = 0; i < ScaleFactor; i++) {
- auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
- Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
- {OpNode, SourceIndex}));
- }
-
- // Flatten the subvector tree
- while (Subvectors.size() > 1) {
- Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
- {Subvectors[0], Subvectors[1]}));
- Subvectors.pop_front();
- Subvectors.pop_front();
- }
-
- assert(Subvectors.size() == 1 &&
- "There should only be one subvector after tree flattening");
-
- setValue(&I, Subvectors[0]);
+ setValue(&I, DAG.expandPartialReductionIntrinsic(EVT::getEVT(I.getType()), getValue(I.getOperand(0)), getValue(I.getOperand(1)), sdl));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 13e664f5bf27f5..df89806ca057e1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1995,37 +1995,12 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
if (!RetTy || !RetTy->isScalableTy())
return true;
- Value *InputA;
- Value *InputB;
- if (match(I,
- m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
- m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)),
- m_ZExtOrSExt(m_Value(InputB))))))) {
- VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
- VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
- if (!InputAType || !InputBType)
- return true;
- ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
- ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
- // Check that the input type is 4 times smaller than the output type. If the
- // output type is 64 bit then we can accept 8 bit inputs if we do a 32 bit
- // dot product and add a zext/sext.
- if ((RetTy->getScalarType()->isIntegerTy(64) &&
- InputAType->getElementType()->isIntegerTy(16) &&
- InputAType->getElementCount() == ExpectedCount8 &&
- InputAType == InputBType) ||
- ((RetTy->getScalarType()->isIntegerTy(32) ||
- RetTy->getScalarType()->isIntegerTy(64)) &&
- InputAType->getElementType()->isIntegerTy(8) &&
- InputAType->getElementCount() == ExpectedCount16 &&
- InputAType == InputBType)) {
- auto *Mul = cast<Instruction>(I->getOperand(1));
- auto *Mul0 = cast<Instruction>(Mul->getOperand(0));
- auto *Mul1 = cast<Instruction>(Mul->getOperand(1));
- if (Mul0->getOpcode() == Mul1->getOpcode())
- return false;
- }
- }
+ if (RetTy->getScalarType()->isIntegerTy(32) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
+ return false;
+ if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy()))
+ return false;
+ if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
+ return false;
return true;
}
@@ -21799,6 +21774,92 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
+SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarget, SelectionDAG &DAG) {
+ SDLoc DL(N);
+
+ // The narrower of the two operands. Used as the accumulator
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ return SDValue();
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
+
+ // The fully-reduced type. Should be a vector of i32 or i64
+ EVT FullType = N->getValueType(0);
+ // The type that is extended to the wide type. Should be an i8 or i16
+ EVT ExtendedType = A.getValueType();
+ // The wide type with four times as many elements as the reduced type. Should be a vector of i32 or i64, the same as the fully-reduced type
+ EVT WideType = MulOp.getValueType();
+ if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits())
+ return SDValue();
+ // Dot products operate on chunks of four elements so there must be four times as many elements in the wide type
+ if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() != 4)
+ return SDValue();
+ switch (FullType.getScalarSizeInBits()) {
+ case 32:
+ if (ExtendedType.getScalarSizeInBits() != 8)
+ return SDValue();
+ break;
+ case 64:
+ // i8 to i64 can be done with an extended i32 dot product
+ if (ExtendedType.getScalarSizeInBits() != 8 && ExtendedType.getScalarSizeInBits() != 16)
+ return SDValue();
+ break;
+ default:
+ return SDValue();
+ }
+
+ unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+ if (IsSExt)
+ DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+ else if (IsZExt)
+ DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+ assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+ "Unexpected dot product case encountered.");
+
+ EVT Type = NarrowOp.getValueType();
+
+ // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
+ // and extending the output
+ bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
+ Type.getScalarSizeInBits() == 64;
+ SDValue Accumulator = NarrowOp;
+ if (Extend) {
+ Type = Type.changeVectorElementType(
+ EVT::getIntegerVT(*DAG.getContext(), 32));
+ // The accumulator is of the wider type so we insert a 0 accumulator and
+ // add the proper one after extending
+ Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
+ DAG.getConstant(0, DL, MVT::i32));
+ }
+
+ auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+ auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
+ {IntrinsicId, Accumulator, A, B});
+ if (Extend) {
+ auto Extended =
+ DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
+ NarrowOp.getValueType(), {DotProduct});
+ auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
+ {NarrowOp, Extended});
+ DotProduct = AccAdd;
+ }
+ return DotProduct;
+}
+
static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -21808,97 +21869,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
default:
break;
case Intrinsic::experimental_vector_partial_reduce_add: {
- SDLoc DL(N);
-
- bool IsValidDotProduct = true;
-
- auto NarrowOp = N->getOperand(1);
- auto MulOp = N->getOperand(2);
- if (MulOp->getOpcode() != ISD::MUL)
- IsValidDotProduct = false;
-
- auto ExtA = MulOp->getOperand(0);
- auto ExtB = MulOp->getOperand(1);
- bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
- if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
- IsValidDotProduct = false;
-
- unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
-
- if (IsSExt && IsValidDotProduct)
- DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
- else if (IsZExt && IsValidDotProduct)
- DotIntrinsicId = Intrinsic::aarch64_sve_udot;
-
- assert((!IsValidDotProduct || DotIntrinsicId != Intrinsic::not_intrinsic) &&
- "Unexpected dot product case encountered.");
-
- if (IsValidDotProduct) {
- auto A = ExtA->getOperand(0);
- auto B = ExtB->getOperand(0);
- EVT Type = NarrowOp.getValueType();
-
- // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
- // and extending the output
- bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
- Type.getScalarSizeInBits() == 64;
- SDValue Accumulator = NarrowOp;
- if (Extend) {
- Type = Type.changeVectorElementType(
- EVT::getIntegerVT(*DAG.getContext(), 32));
- // The accumulator is of the wider type so we insert a 0 accumulator and
- // add the proper one after extending
- Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
- DAG.getConstant(0, DL, MVT::i32));
- }
-
- auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
- auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
- {IntrinsicId, Accumulator, A, B});
- if (Extend) {
- auto Extended =
- DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
- NarrowOp.getValueType(), {DotProduct});
- auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
- {NarrowOp, Extended});
- DotProduct = AccAdd;
- }
- return DotProduct;
- } else {
- // If the node doesn't match a dot product, lower to a series of ADDs
- // instead.
- SDValue Op0 = N->getOperand(1);
- SDValue Op1 = N->getOperand(2);
- EVT Type0 = Op0->getValueType(0);
- EVT Type1 = Op1->getValueType(0);
-
- // Canonicalise so that Op1 has the larger type
- if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) {
- std::swap(Op0, Op1);
- std::swap(Type0, Type1);
- }
-
- auto Type0Elements = Type0.getVectorNumElements();
- auto Type1Elements = Type1.getVectorNumElements();
-
- // If the types are equal then a single ADD is fine
- if (Type0 == Type1)
- return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1});
-
- // Otherwise, we need to add each subvector together so that the output is
- // the intrinsic's return type. For example, <4 x i32>
- // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] +
- // b[4..7] + b[8..11] + b[12..15]
- SDValue Add = Op0;
- for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) {
- SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
- DAG.getConstant(i, DL, MVT::i64));
-
- Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec});
- }
- return Add;
- }
+ if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+ return Dot;
+ return DAG.expandPartialReductionIntrinsic(N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N));
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
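For reference, the dot-product lowering above maps each lane of the accumulator to the sum of four adjacent products of the zero- or sign-extended inputs, which is why the wide type must have four times as many elements as the reduced type. The intrinsic itself only requires that every wide lane ends up added into some lane of the accumulator, which is what lets the target pick the udot grouping. A minimal scalar sketch of the unsigned case (plain C++, names illustrative, not the SelectionDAG API):

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Reference model: Acc has N lanes, A and B have 4*N lanes; result lane I
// accumulates the products of input lanes 4*I .. 4*I+3, as udot does.
template <std::size_t N>
std::array<uint32_t, N> udot_ref(std::array<uint32_t, N> Acc,
                                 const std::array<uint8_t, 4 * N> &A,
                                 const std::array<uint8_t, 4 * N> &B) {
  for (std::size_t I = 0; I < N; ++I)
    for (std::size_t J = 0; J < 4; ++J)
      Acc[I] += uint32_t(A[4 * I + J]) * uint32_t(B[4 * I + J]);
  return Acc;
}

int main() {
  std::array<uint8_t, 16> A{}, B{};
  for (std::size_t I = 0; I < 16; ++I) { A[I] = uint8_t(I); B[I] = 2; }
  std::array<uint32_t, 4> R = udot_ref<4>({0, 0, 0, 0}, A, B);
  std::printf("%u %u %u %u\n", R[0], R[1], R[2], R[3]); // 12 44 76 108
}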
>From 5f31079756eb54679f0dd64da0d425084069a929 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 21 Aug 2024 16:04:52 +0100
Subject: [PATCH 19/27] Check for NEON or SVE
---
llvm/include/llvm/CodeGen/SelectionDAG.h | 3 +-
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 45 +++++++-------
.../SelectionDAG/SelectionDAGBuilder.cpp | 4 +-
.../Target/AArch64/AArch64ISelLowering.cpp | 61 +++++++++++--------
4 files changed, 65 insertions(+), 48 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 2235db5d93b5ad..2c1d4bf259699e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1596,7 +1596,8 @@ class SelectionDAG {
/// Expand a partial reduction intrinsic call.
/// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type.
- SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL);
+ SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1,
+ SDValue Op2, SDLoc DL);
/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
SDValue expandVAArg(SDNode *Node);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2510c1828c909f..cd09760b1f24e9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2427,33 +2427,34 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
-SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1, SDValue Op2, SDLoc DL) {
- EVT FullTy = Op2.getValueType();
+SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy,
+ SDValue Op1, SDValue Op2,
+ SDLoc DL) {
+ EVT FullTy = Op2.getValueType();
- unsigned Stride = ReducedTy.getVectorMinNumElements();
- unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+ unsigned Stride = ReducedTy.getVectorMinNumElements();
+ unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
- // Collect all of the subvectors
- std::deque<SDValue> Subvectors = {Op1};
- for (unsigned I = 0; I < ScaleFactor; I++) {
- auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
- Subvectors.push_back(getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy,
- {Op2, SourceIndex}));
- }
-
- // Flatten the subvector tree
- while (Subvectors.size() > 1) {
- Subvectors.push_back(getNode(ISD::ADD, DL, ReducedTy,
- {Subvectors[0], Subvectors[1]}));
- Subvectors.pop_front();
- Subvectors.pop_front();
- }
+ // Collect all of the subvectors
+ std::deque<SDValue> Subvectors = {Op1};
+ for (unsigned I = 0; I < ScaleFactor; I++) {
+ auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+ Subvectors.push_back(
+ getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+ }
- assert(Subvectors.size() == 1 &&
- "There should only be one subvector after tree flattening");
+ // Flatten the subvector tree
+ while (Subvectors.size() > 1) {
+ Subvectors.push_back(
+ getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+ Subvectors.pop_front();
+ Subvectors.pop_front();
+ }
- return Subvectors[0];
+ assert(Subvectors.size() == 1 &&
+ "There should only be one subvector after tree flattening");
+ return Subvectors[0];
}
SDValue SelectionDAG::expandVAArg(SDNode *Node) {
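The generic expansion above is just "slice and add": the wide operand is cut into accumulator-sized subvectors, which are then added together pairwise until one vector remains. A stand-alone sketch of the same loop structure on plain vectors (assumed names, not the DAG helper itself):

#include <cstddef>
#include <cstdio>
#include <deque>
#include <vector>

// Op1 has Stride elements, Op2 has ScaleFactor * Stride elements; the result
// has Stride elements and contains Op1 plus every Stride-sized slice of Op2.
std::vector<long> partialReduceAddRef(std::vector<long> Op1,
                                      const std::vector<long> &Op2) {
  std::size_t Stride = Op1.size();
  std::size_t ScaleFactor = Op2.size() / Stride;

  // Collect the accumulator and all of the subvectors.
  std::deque<std::vector<long>> Subvectors = {Op1};
  for (std::size_t I = 0; I < ScaleFactor; ++I)
    Subvectors.emplace_back(Op2.begin() + I * Stride,
                            Op2.begin() + (I + 1) * Stride);

  // Flatten the subvector tree: keep adding the two front entries.
  while (Subvectors.size() > 1) {
    std::vector<long> Sum(Stride);
    for (std::size_t I = 0; I < Stride; ++I)
      Sum[I] = Subvectors[0][I] + Subvectors[1][I];
    Subvectors.pop_front();
    Subvectors.pop_front();
    Subvectors.push_back(Sum);
  }
  return Subvectors.front();
}

int main() {
  std::vector<long> Acc = {1, 2, 3, 4};
  std::vector<long> Wide(16, 1); // four wide lanes fold into each result lane
  std::vector<long> R = partialReduceAddRef(Acc, Wide);
  std::printf("%ld %ld %ld %ld\n", R[0], R[1], R[2], R[3]); // 5 6 7 8
}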
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 33de8747fb7e56..ce5ef78eba15db 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8011,7 +8011,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
- setValue(&I, DAG.expandPartialReductionIntrinsic(EVT::getEVT(I.getType()), getValue(I.getOperand(0)), getValue(I.getOperand(1)), sdl));
+ setValue(&I, DAG.expandPartialReductionIntrinsic(
+ EVT::getEVT(I.getType()), getValue(I.getOperand(0)),
+ getValue(I.getOperand(1)), sdl));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index df89806ca057e1..b849ddb2a86d67 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1995,11 +1995,14 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
if (!RetTy || !RetTy->isScalableTy())
return true;
- if (RetTy->getScalarType()->isIntegerTy(32) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
+ if (RetTy->getScalarType()->isIntegerTy(32) &&
+ RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
return false;
- if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy()))
+ if (RetTy->getScalarType()->isIntegerTy(64) &&
+ RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy()))
return false;
- if (RetTy->getScalarType()->isIntegerTy(64) && RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
+ if (RetTy->getScalarType()->isIntegerTy(64) &&
+ RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
return false;
return true;
@@ -21774,7 +21777,13 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarget, SelectionDAG &DAG) {
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+ const AArch64Subtarget *Subtarget,
+ SelectionDAG &DAG) {
+
+ if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable())
+ return SDValue();
+
SDLoc DL(N);
// The narrower of the two operands. Used as the accumulator
@@ -21799,25 +21808,29 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg
EVT FullType = N->getValueType(0);
// The type that is extended to the wide type. Should be an i8 or i16
EVT ExtendedType = A.getValueType();
- // The wide type with four times as many elements as the reduced type. Should be a vector of i32 or i64, the same as the fully-reduced type
+ // The wide type with four times as many elements as the reduced type. Should
+ // be a vector of i32 or i64, the same as the fully-reduced type
EVT WideType = MulOp.getValueType();
if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits())
return SDValue();
- // Dot products operate on chunks of four elements so there must be four times as many elements in the wide type
- if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() != 4)
+ // Dot products operate on chunks of four elements so there must be four times
+ // as many elements in the wide type
+ if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() !=
+ 4)
return SDValue();
switch (FullType.getScalarSizeInBits()) {
- case 32:
- if (ExtendedType.getScalarSizeInBits() != 8)
- return SDValue();
- break;
- case 64:
- // i8 to i64 can be done with an extended i32 dot product
- if (ExtendedType.getScalarSizeInBits() != 8 && ExtendedType.getScalarSizeInBits() != 16)
- return SDValue();
- break;
- default:
+ case 32:
+ if (ExtendedType.getScalarSizeInBits() != 8)
+ return SDValue();
+ break;
+ case 64:
+ // i8 to i64 can be done with an extended i32 dot product
+ if (ExtendedType.getScalarSizeInBits() != 8 &&
+ ExtendedType.getScalarSizeInBits() != 16)
return SDValue();
+ break;
+ default:
+ return SDValue();
}
unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
@@ -21838,8 +21851,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg
Type.getScalarSizeInBits() == 64;
SDValue Accumulator = NarrowOp;
if (Extend) {
- Type = Type.changeVectorElementType(
- EVT::getIntegerVT(*DAG.getContext(), 32));
+ Type =
+ Type.changeVectorElementType(EVT::getIntegerVT(*DAG.getContext(), 32));
// The accumulator is of the wider type so we insert a 0 accumulator and
// add the proper one after extending
Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
@@ -21850,9 +21863,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, const AArch64Subtarget *Subtarg
auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
{IntrinsicId, Accumulator, A, B});
if (Extend) {
- auto Extended =
- DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
- NarrowOp.getValueType(), {DotProduct});
+ auto Extended = DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
+ DL, NarrowOp.getValueType(), {DotProduct});
auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
{NarrowOp, Extended});
DotProduct = AccAdd;
@@ -21870,8 +21882,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
break;
case Intrinsic::experimental_vector_partial_reduce_add: {
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
- return Dot;
- return DAG.expandPartialReductionIntrinsic(N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N));
+ return Dot;
+ return DAG.expandPartialReductionIntrinsic(
+ N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N));
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
>From 2f3a0dc8d581efd0687d0eedf3c346ffc6a60716 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 28 Aug 2024 09:27:15 +0100
Subject: [PATCH 20/27] Rename expansion function
---
llvm/include/llvm/CodeGen/SelectionDAG.h | 4 ++--
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 5 ++---
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 6 +++---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 ++--
4 files changed, 9 insertions(+), 10 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 2c1d4bf259699e..227616c37e004d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1596,8 +1596,8 @@ class SelectionDAG {
/// Expand a partial reduction intrinsic call.
/// Op1 and Op2 are its operands and ReducedTY is the intrinsic's return type.
- SDValue expandPartialReductionIntrinsic(EVT ReducedTy, SDValue Op1,
- SDValue Op2, SDLoc DL);
+ SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+ SDValue Op2);
/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
SDValue expandVAArg(SDNode *Node);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index cd09760b1f24e9..d5e61183d0e25d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2427,9 +2427,8 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
-SDValue SelectionDAG::expandPartialReductionIntrinsic(EVT ReducedTy,
- SDValue Op1, SDValue Op2,
- SDLoc DL) {
+SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+ SDValue Op2) {
EVT FullTy = Op2.getValueType();
unsigned Stride = ReducedTy.getVectorMinNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ce5ef78eba15db..98c2e703c39ed1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8011,9 +8011,9 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
- setValue(&I, DAG.expandPartialReductionIntrinsic(
- EVT::getEVT(I.getType()), getValue(I.getOperand(0)),
- getValue(I.getOperand(1)), sdl));
+ setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
+ getValue(I.getOperand(0)),
+ getValue(I.getOperand(1))));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b849ddb2a86d67..a25c09ade370e2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21883,8 +21883,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::experimental_vector_partial_reduce_add: {
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
- return DAG.expandPartialReductionIntrinsic(
- N->getValueType(0), N->getOperand(1), N->getOperand(2), SDLoc(N));
+ return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
>From 00a1be219912ca53f24dda3ee410229fa6286736 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 28 Aug 2024 10:22:45 +0100
Subject: [PATCH 21/27] Simplify shouldExpandPartialReductionIntrinsic
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 15 +++------------
1 file changed, 3 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a25c09ade370e2..9f5d5a61397d10 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1991,21 +1991,12 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const IntrinsicInst *I) const {
- VectorType *RetTy = dyn_cast<VectorType>(I->getType());
- if (!RetTy || !RetTy->isScalableTy())
+ if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
return true;
- if (RetTy->getScalarType()->isIntegerTy(32) &&
- RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
- return false;
- if (RetTy->getScalarType()->isIntegerTy(64) &&
- RetTy->getElementCount() == ElementCount::get(2, RetTy->isScalableTy()))
- return false;
- if (RetTy->getScalarType()->isIntegerTy(64) &&
- RetTy->getElementCount() == ElementCount::get(4, RetTy->isScalableTy()))
- return false;
+ EVT VT = EVT::getEVT(I->getType());
- return true;
+ return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::nxv4i64;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
>From f6c58393ddc03ae7328f7a1fe8236bce55924999 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 09:36:24 +0100
Subject: [PATCH 22/27] Remove nxv4i64 case
---
.../Target/AArch64/AArch64ISelLowering.cpp | 77 ++++++-------------
.../AArch64/partial-reduce-dot-product.ll | 72 -----------------
2 files changed, 23 insertions(+), 126 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9f5d5a61397d10..7193303c7084ab 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21795,35 +21795,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
if (A.getValueType() != B.getValueType())
return SDValue();
- // The fully-reduced type. Should be a vector of i32 or i64
- EVT FullType = N->getValueType(0);
- // The type that is extended to the wide type. Should be an i8 or i16
- EVT ExtendedType = A.getValueType();
- // The wide type with four times as many elements as the reduced type. Should
- // be a vector of i32 or i64, the same as the fully-reduced type
- EVT WideType = MulOp.getValueType();
- if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits())
- return SDValue();
- // Dot products operate on chunks of four elements so there must be four times
- // as many elements in the wide type
- if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() !=
- 4)
- return SDValue();
- switch (FullType.getScalarSizeInBits()) {
- case 32:
- if (ExtendedType.getScalarSizeInBits() != 8)
- return SDValue();
- break;
- case 64:
- // i8 to i64 can be done with an extended i32 dot product
- if (ExtendedType.getScalarSizeInBits() != 8 &&
- ExtendedType.getScalarSizeInBits() != 16)
- return SDValue();
- break;
- default:
- return SDValue();
- }
-
unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
if (IsSExt)
@@ -21834,33 +21805,31 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
"Unexpected dot product case encountered.");
- EVT Type = NarrowOp.getValueType();
+ auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
- // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
- // and extending the output
- bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
- Type.getScalarSizeInBits() == 64;
- SDValue Accumulator = NarrowOp;
- if (Extend) {
- Type =
- Type.changeVectorElementType(EVT::getIntegerVT(*DAG.getContext(), 32));
- // The accumulator is of the wider type so we insert a 0 accumulator and
- // add the proper one after extending
- Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
- DAG.getConstant(0, DL, MVT::i32));
- }
+ // The fully-reduced type. Should be a vector of i32 or i64
+ EVT ReducedType = N->getValueType(0);
+ // The type that is extended to the wide type. Should be an i8 or i16
+ EVT ExtendedType = A.getValueType();
+ // The wide type with four times as many elements as the reduced type. Should
+ // be a vector of i32 or i64, the same as the fully-reduced type
+ EVT WideType = MulOp.getValueType();
+ if (WideType.getScalarSizeInBits() != ReducedType.getScalarSizeInBits())
+ return SDValue();
- auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
- auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
- {IntrinsicId, Accumulator, A, B});
- if (Extend) {
- auto Extended = DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
- DL, NarrowOp.getValueType(), {DotProduct});
- auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
- {NarrowOp, Extended});
- DotProduct = AccAdd;
- }
- return DotProduct;
+ // Dot products operate on chunks of four elements so there must be four times
+ // as many elements in the wide type
+ if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 &&
+ ExtendedType == MVT::nxv16i8)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
+ {IntrinsicId, NarrowOp, A, B});
+
+ if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 &&
+ ExtendedType == MVT::nxv8i16)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i64,
+ {IntrinsicId, NarrowOp, A, B});
+
+ return SDValue();
}
static SDValue performIntrinsicCombine(SDNode *N,
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index c1cf9026d693ce..16ef219a93c9bf 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -61,78 +61,6 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i64> @dotp_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp_8to64:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z2.s, #0 // =0x0
-; CHECK-NEXT: udot z2.s, z0.b, z1.b
-; CHECK-NEXT: uunpklo z0.d, z2.s
-; CHECK-NEXT: uunpkhi z1.d, z2.s
-; CHECK-NEXT: ret
-entry:
- %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
- %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
- %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
- <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
- ret <vscale x 4 x i64> %partial.reduce
-}
-
-define <vscale x 4 x i64> @dotp_sext_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
-; CHECK-LABEL: dotp_sext_8to64:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z2.s, #0 // =0x0
-; CHECK-NEXT: sdot z2.s, z0.b, z1.b
-; CHECK-NEXT: sunpklo z0.d, z2.s
-; CHECK-NEXT: sunpkhi z1.d, z2.s
-; CHECK-NEXT: ret
-entry:
- %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
- %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
- %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
- <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
- ret <vscale x 4 x i64> %partial.reduce
-}
-
-define <vscale x 4 x i64> @dotp_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
-; CHECK-LABEL: dotp_8to64_accumulator:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NEXT: udot z4.s, z0.b, z1.b
-; CHECK-NEXT: uunpklo z0.d, z4.s
-; CHECK-NEXT: uunpkhi z1.d, z4.s
-; CHECK-NEXT: add z0.d, z2.d, z0.d
-; CHECK-NEXT: add z1.d, z3.d, z1.d
-; CHECK-NEXT: ret
-entry:
- %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
- %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
- %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
- <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
- ret <vscale x 4 x i64> %partial.reduce
-}
-
-define <vscale x 4 x i64> @dotp_sext_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
-; CHECK-LABEL: dotp_sext_8to64_accumulator:
-; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z4.s, #0 // =0x0
-; CHECK-NEXT: sdot z4.s, z0.b, z1.b
-; CHECK-NEXT: sunpklo z0.d, z4.s
-; CHECK-NEXT: sunpkhi z1.d, z4.s
-; CHECK-NEXT: add z0.d, z2.d, z0.d
-; CHECK-NEXT: add z1.d, z3.d, z1.d
-; CHECK-NEXT: ret
-entry:
- %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
- %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
- %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
- <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
- ret <vscale x 4 x i64> %partial.reduce
-}
-
define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
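The nxv4i64 tests removed here covered the earlier i8-to-i64 path, where the reduction was performed as a 32-bit dot product into a zeroed accumulator and the result was then widened and added to the real i64 accumulator. That rewrite is exact because at most four u8*u8 products are summed per lane, which cannot overflow 32 bits. A scalar sketch of one lane (unsigned case only, names illustrative):

#include <cstdint>
#include <cstdio>

// One result lane of the i8 -> i64 trick: dot into a zero 32-bit accumulator
// ("mov z.s, #0" + udot), then zero-extend and add the real accumulator.
// Exact since 4 * 255 * 255 = 260100 < 2^32.
uint64_t dot8to64Lane(uint64_t Acc, const uint8_t A[4], const uint8_t B[4]) {
  uint32_t Dot32 = 0;
  for (int I = 0; I < 4; ++I)
    Dot32 += uint32_t(A[I]) * uint32_t(B[I]);
  return Acc + uint64_t(Dot32); // the uunpk + add z.d step
}

int main() {
  uint8_t A[4] = {255, 255, 255, 255}, B[4] = {255, 255, 255, 255};
  std::printf("%llu\n", (unsigned long long)dot8to64Lane(1, A, B)); // 260101
}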
>From da20b2a03ef35f0f1172f0b4826ac3d792671da5 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 09:47:18 +0100
Subject: [PATCH 23/27] Add assertion
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 +++++
1 file changed, 5 insertions(+)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7193303c7084ab..a100d033f50d00 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21772,6 +21772,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+ getIntrinsicID(N) ==
+ Intrinsic::experimental_vector_partial_reduce_add &&
+ "Expected a partial reduction node");
+
if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable())
return SDValue();
>From 4697fc13d67421bd364ca907659a4896f629d0e3 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 09:48:31 +0100
Subject: [PATCH 24/27] Fix subtarget check
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a100d033f50d00..e303f6b693da18 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21777,7 +21777,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
Intrinsic::experimental_vector_partial_reduce_add &&
"Expected a partial reduction node");
- if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable())
+ if (!Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();
SDLoc DL(N);
@@ -21819,8 +21819,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// The wide type with four times as many elements as the reduced type. Should
// be a vector of i32 or i64, the same as the fully-reduced type
EVT WideType = MulOp.getValueType();
- if (WideType.getScalarSizeInBits() != ReducedType.getScalarSizeInBits())
- return SDValue();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
>From 31b75674f414b740c3770377cb8b498aaa607225 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 09:50:33 +0100
Subject: [PATCH 25/27] Emit a node instead of an intrinsic
---
.../Target/AArch64/AArch64ISelLowering.cpp | 19 ++++++++-----------
1 file changed, 8 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e303f6b693da18..ded99e5ec8c440 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21800,17 +21800,14 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
if (A.getValueType() != B.getValueType())
return SDValue();
- unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+ unsigned Opcode = 0;
if (IsSExt)
- DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+ Opcode = AArch64ISD::SDOT;
else if (IsZExt)
- DotIntrinsicId = Intrinsic::aarch64_sve_udot;
-
- assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
- "Unexpected dot product case encountered.");
+ Opcode = AArch64ISD::UDOT;
- auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+ assert(Opcode != 0 && "Unexpected dot product case encountered.");
// The fully-reduced type. Should be a vector of i32 or i64
EVT ReducedType = N->getValueType(0);
@@ -21824,13 +21821,13 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// as many elements in the wide type
if (WideType == MVT::nxv16i32 && ReducedType == MVT::nxv4i32 &&
ExtendedType == MVT::nxv16i8)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4i32,
- {IntrinsicId, NarrowOp, A, B});
+ return DAG.getNode(Opcode, DL, MVT::nxv4i32,
+ NarrowOp, A, B);
if (WideType == MVT::nxv8i64 && ReducedType == MVT::nxv2i64 &&
ExtendedType == MVT::nxv8i16)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i64,
- {IntrinsicId, NarrowOp, A, B});
+ return DAG.getNode(Opcode, DL, MVT::nxv2i64,
+ NarrowOp, A, B);
return SDValue();
}
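Switching to the AArch64ISD::UDOT/SDOT nodes keeps the earlier requirement that both multiply operands use the same extension. The two operations interpret the same input bytes differently, so a mixed sext/zext multiply matches neither; a small scalar illustration (assumed helper names, one lane only):

#include <cstdint>
#include <cstdio>

// sdot treats every input byte as signed, udot as unsigned.
int64_t sdotLane(int64_t Acc, const int8_t A[4], const int8_t B[4]) {
  for (int I = 0; I < 4; ++I)
    Acc += int64_t(A[I]) * int64_t(B[I]);
  return Acc;
}

uint64_t udotLane(uint64_t Acc, const uint8_t A[4], const uint8_t B[4]) {
  for (int I = 0; I < 4; ++I)
    Acc += uint64_t(A[I]) * uint64_t(B[I]);
  return Acc;
}

int main() {
  int8_t  SA[4] = {-1, -1, -1, -1};
  uint8_t UA[4] = {255, 255, 255, 255}; // same bit patterns as SA
  int8_t  SB[4] = {2, 2, 2, 2};
  uint8_t UB[4] = {2, 2, 2, 2};
  // Same bits, different results: -8 for the signed form, 2040 for unsigned.
  std::printf("%lld %llu\n", (long long)sdotLane(0, SA, SB),
              (unsigned long long)udotLane(0, UA, UB));
}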
>From 76296792b2cd94c78daaa03e8416e5bc86346ccb Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 29 Aug 2024 10:11:18 +0100
Subject: [PATCH 26/27] Pass accumulator from function in tests
---
.../AArch64/partial-reduce-dot-product.ll | 68 ++++++++-----------
1 file changed, 30 insertions(+), 38 deletions(-)
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 16ef219a93c9bf..b1354ab210f727 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,104 +1,96 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
-define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z2.s, #0 // =0x0
-; CHECK-NEXT: udot z2.s, z0.b, z1.b
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z2.d, #0 // =0x0
-; CHECK-NEXT: udot z2.d, z0.h, z1.h
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp_sext:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z2.s, #0 // =0x0
-; CHECK-NEXT: sdot z2.s, z0.b, z1.b
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide_sext:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: mov z2.d, #0 // =0x0
-; CHECK-NEXT: sdot z2.d, z0.h, z1.h
-; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: and z0.h, z0.h, #0xff
; CHECK-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEXT: and z2.h, z2.h, #0xff
; CHECK-NEXT: ptrue p0.s
-; CHECK-NEXT: uunpkhi z2.s, z0.h
-; CHECK-NEXT: uunpkhi z3.s, z1.h
-; CHECK-NEXT: uunpklo z0.s, z0.h
-; CHECK-NEXT: uunpklo z1.s, z1.h
-; CHECK-NEXT: mul z2.s, z2.s, z3.s
-; CHECK-NEXT: mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEXT: mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 8 x i32> %mult)
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
; CHECK-LABEL: not_dotp_wide:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: and z0.s, z0.s, #0xffff
; CHECK-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEXT: and z2.s, z2.s, #0xffff
; CHECK-NEXT: ptrue p0.d
-; CHECK-NEXT: uunpkhi z2.d, z0.s
-; CHECK-NEXT: uunpkhi z3.d, z1.s
-; CHECK-NEXT: uunpklo z0.d, z0.s
-; CHECK-NEXT: uunpklo z1.d, z1.s
-; CHECK-NEXT: mul z2.d, z2.d, z3.d
-; CHECK-NEXT: mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: uunpklo z3.d, z1.s
+; CHECK-NEXT: uunpklo z4.d, z2.s
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEXT: mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
%b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
%mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
- %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}
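The not_dotp tests in this file reduce eight wide lanes into four, a 2:1 ratio, so there is no group of four products per result lane for udot/sdot to consume and the generic multiply-accumulate expansion is used instead. A scalar sketch of one such 2:1 lane (illustrative only):

#include <cstdint>
#include <cstdio>

// Each result lane only absorbs two products, one per extracted subvector,
// which is what the two mla instructions in the CHECK lines correspond to.
uint32_t reduce2to1Lane(uint32_t Acc, const uint8_t A[2], const uint8_t B[2]) {
  for (int I = 0; I < 2; ++I)
    Acc += uint32_t(A[I]) * uint32_t(B[I]);
  return Acc;
}

int main() {
  uint8_t A[2] = {3, 4}, B[2] = {5, 6};
  std::printf("%u\n", reduce2to1Lane(10, A, B)); // 10 + 15 + 24 = 49
}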
>From 830df7624894cca7c00db2263129433cf0e1a9d7 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 30 Aug 2024 10:00:37 +0100
Subject: [PATCH 27/27] Remove nxv4i64 case from
shouldExpandPartialReductionIntrinsic
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ded99e5ec8c440..9ec25a4074a0a3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1996,7 +1996,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
EVT VT = EVT::getEVT(I->getType());
- return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::nxv4i64;
+ return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {