[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)
Sam Tebbs via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 20 06:31:49 PDT 2024
https://github.com/SamTebbs33 updated https://github.com/llvm/llvm-project/pull/101010
From 0b9ce21c0019fea07188ffd142bd9cf580d09f35 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 10:46:16 +0100
Subject: [PATCH 01/17] [AArch64] Lower add partial reduction to udot
This patch introduces lowering of the @llvm.experimental.vector.partial.reduce.add
intrinsic to a udot or sdot instruction for AArch64.
---
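A note on the intrinsic's semantics, since they are what make this lowering
legal: as I read LangRef, @llvm.experimental.vector.partial.reduce.add takes a
narrow accumulator and a wide vector whose lane count is a multiple of the
accumulator's, and deliberately leaves unspecified which input lanes fold into
which result lane; only the full reduction has to come out right. SVE's
udot/sdot sum groups of four adjacent widened products into each accumulator
lane, which is one association the intrinsic permits. A scalar sketch of what
`udot z2.s, z0.b, z1.b` computes per 128-bit granule (illustrative only, not
upstream code):

  #include <cstdint>

  // One 128-bit granule of UDOT: each 32-bit accumulator lane gains the
  // sum of four adjacent widened 8-bit products (4 * 255 * 255 cannot
  // overflow 32 bits).
  void udotGranule(uint32_t Acc[4], const uint8_t A[16], const uint8_t B[16]) {
    for (unsigned Lane = 0; Lane < 4; ++Lane)
      for (unsigned J = 0; J < 4; ++J)
        Acc[Lane] += uint32_t(A[4 * Lane + J]) * uint32_t(B[4 * Lane + J]);
  }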
llvm/include/llvm/CodeGen/TargetLowering.h | 6 +
.../SelectionDAG/SelectionDAGBuilder.cpp | 6 +
.../Target/AArch64/AArch64ISelLowering.cpp | 77 +++++++++++++
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 +
.../AArch64/AArch64TargetTransformInfo.cpp | 30 +++++
.../AArch64/AArch64TargetTransformInfo.h | 6 +
.../AArch64/partial-reduce-dot-product.ll | 109 ++++++++++++++++++
7 files changed, 236 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9d9886f4920a29..07d99aec47122a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,12 @@ class TargetLoweringBase {
return true;
}
+ /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+ /// should be expanded using generic code in SelectionDAGBuilder.
+ virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+ return true;
+ }
+
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1791f1b503379e..c70ab253c1aabc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7985,6 +7985,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
+
+ if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+ visitTargetIntrinsic(I, Intrinsic);
+ return;
+ }
+
SDValue OpNode = getValue(I.getOperand(1));
EVT ReducedTy = EVT::getEVT(I.getType());
EVT FullTy = OpNode.getValueType();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d86e52d49000ae..d1ee58668ecbd7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1971,6 +1971,57 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+ const CallInst *CI) const {
+ const bool TargetLowers = false;
+ const bool GenericLowers = true;
+
+ auto *I = dyn_cast<IntrinsicInst>(CI);
+ if (!I)
+ return GenericLowers;
+
+ ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
+
+ if (!RetTy)
+ return GenericLowers;
+
+ ScalableVectorType *InputTy = nullptr;
+
+ auto RetScalarTy = RetTy->getScalarType();
+ if (RetScalarTy->isIntegerTy(64)) {
+ InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+ } else if (RetScalarTy->isIntegerTy(32)) {
+ InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+ }
+
+ if (!InputTy)
+ return GenericLowers;
+
+ Value *InputA;
+ Value *InputB;
+
+ auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+ m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+ m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
+
+ if (!match(I, Pattern))
+ return GenericLowers;
+
+ auto Mul = cast<Instruction>(I->getOperand(1));
+
+ auto getOpcodeOfOperand = [&](unsigned Idx) {
+ return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
+ };
+
+ if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
+ return GenericLowers;
+
+ if (InputA->getType() != InputTy || InputB->getType() != InputTy)
+ return GenericLowers;
+
+ return TargetLowers;
+}
+
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -21237,6 +21288,32 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
+ case Intrinsic::experimental_vector_partial_reduce_add: {
+ SDLoc DL(N);
+
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+
+ unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+ if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+ DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+ else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+ DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+ assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+ "Unexpected dot product case encountered.");
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+
+ auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+ {IntrinsicId, NarrowOp, A, B});
+ }
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81e15185f985d5..fc79d9766719bc 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -991,6 +991,8 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
+ bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+
bool shouldExpandCttzElements(EVT VT) const override;
/// If a change in streaming mode is required on entry to/return from a
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 45148449dfb821..792bd546019192 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3533,6 +3533,36 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
return Cost;
}
+bool AArch64TTIImpl::isPartialReductionSupported(
+ const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
+ bool IsInputASignExtended, bool IsInputBSignExtended,
+ const Instruction *BinOp) const {
+ if (ReductionInstr->getOpcode() != Instruction::Add)
+ return false;
+
+ // Check that both extends are of the same type
+ if (IsInputASignExtended != IsInputBSignExtended)
+ return false;
+
+ if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
+ return false;
+
+ // Dot product only supports a scale factor of 4
+ if (ScaleFactor != 4)
+ return false;
+
+ Type *ReductionType = ReductionInstr->getType();
+ if (ReductionType->isIntegerTy(32)) {
+ if (!InputType->isIntegerTy(8))
+ return false;
+ } else if (ReductionType->isIntegerTy(64)) {
+ if (!InputType->isIntegerTy(16))
+ return false;
+ }
+
+ return true;
+}
+
unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
return ST->getMaxInterleaveFactor();
}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index a9189fd53f40bb..592b452134e778 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,6 +155,12 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return VF.getKnownMinValue() * ST->getVScaleForTuning();
}
+ bool isPartialReductionSupported(const Instruction *ReductionInstr,
+ Type *InputType, unsigned ScaleFactor,
+ bool IsInputASignExtended,
+ bool IsInputBSignExtended,
+ const Instruction *BinOp = nullptr) const;
+
unsigned getMaxInterleaveFactor(ElementCount VF);
bool prefersVectorizedAddressing() const;
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..23b39387fb7a0c
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,109 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: udot z2.s, z0.b, z1.b
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.d, #0 // =0x0
+; CHECK-NEXT: udot z2.d, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: dotp_sext:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: sdot z2.s, z0.b, z1.b
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.d, #0 // =0x0
+; CHECK-NEXT: sdot z2.d, z0.h, z1.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+; CHECK-LABEL: not_dotp:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z0.h, z0.h, #0xff
+; CHECK-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: uunpkhi z2.s, z0.h
+; CHECK-NEXT: uunpkhi z3.s, z1.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: uunpklo z1.s, z1.h
+; CHECK-NEXT: mul z2.s, z2.s, z3.s
+; CHECK-NEXT: mad z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+ %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+ %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> zeroinitializer, <vscale x 8 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z0.s, z0.s, #0xffff
+; CHECK-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpkhi z2.d, z0.s
+; CHECK-NEXT: uunpkhi z3.d, z1.s
+; CHECK-NEXT: uunpklo z0.d, z0.s
+; CHECK-NEXT: uunpklo z1.d, z1.s
+; CHECK-NEXT: mul z2.d, z2.d, z3.d
+; CHECK-NEXT: mad z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+ %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+ %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+attributes #0 = { "target-features"="+sve2" }
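For anyone wanting to experiment with the matcher used in
shouldExpandPartialReductionIntrinsic outside this patch, the PatternMatch
combinators compose as below. A minimal self-contained sketch (the helper name
is mine, not part of the patch):

  #include <utility>
  #include "llvm/IR/IntrinsicInst.h"
  #include "llvm/IR/PatternMatch.h"
  using namespace llvm;
  using namespace llvm::PatternMatch;

  // Binds the two pre-extension inputs of
  // partial.reduce.add(acc, mul(ext(A), ext(B))); returns nullptrs otherwise.
  static std::pair<Value *, Value *> matchDotOperands(IntrinsicInst *I) {
    Value *A = nullptr, *B = nullptr;
    if (!match(I,
               m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
                   m_Value(),
                   m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(A))),
                                  m_OneUse(m_ZExtOrSExt(m_Value(B))))))))
      return {nullptr, nullptr};
    return {A, B};
  }

Note that m_ZExtOrSExt alone does not check that both extends share an opcode,
hence the extra getOpcodeOfOperand comparison in the patch.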
From 7a5661702155198e6e4f9eed4d83bbc2cd0cee6e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:08:52 +0100
Subject: [PATCH 02/17] Remove TargetLowers and GenericLowers
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 16 +++++++---------
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1ee58668ecbd7..b936e4cb4dccf3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1973,17 +1973,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const CallInst *CI) const {
- const bool TargetLowers = false;
- const bool GenericLowers = true;
auto *I = dyn_cast<IntrinsicInst>(CI);
if (!I)
- return GenericLowers;
+ return true;
ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
if (!RetTy)
- return GenericLowers;
+ return true;
ScalableVectorType *InputTy = nullptr;
@@ -1995,7 +1993,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
}
if (!InputTy)
- return GenericLowers;
+ return true;
Value *InputA;
Value *InputB;
@@ -2005,7 +2003,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
if (!match(I, Pattern))
- return GenericLowers;
+ return true;
auto Mul = cast<Instruction>(I->getOperand(1));
@@ -2014,12 +2012,12 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
};
if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
- return GenericLowers;
+ return true;
if (InputA->getType() != InputTy || InputB->getType() != InputTy)
- return GenericLowers;
+ return true;
- return TargetLowers;
+ return false;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
From 7bddd3b71c858d21ea626eb7f453126ad1a1b65b Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:16:16 +0100
Subject: [PATCH 03/17] Assert that shouldExpandPartialReductionIntrinsic sees
an intrinsic
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b936e4cb4dccf3..fdcce7a9d124b1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1975,8 +1975,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const CallInst *CI) const {
auto *I = dyn_cast<IntrinsicInst>(CI);
- if (!I)
- return true;
+ assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
From 48798c1e4ec770f6a47c69e841c048a83bb9bff6 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 18:31:24 +0100
Subject: [PATCH 04/17] Allow non-scalable vector types
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fdcce7a9d124b1..28e13e2e2841cf 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1977,18 +1977,17 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
auto *I = dyn_cast<IntrinsicInst>(CI);
assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
- ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
-
+ VectorType *RetTy = dyn_cast<VectorType>(I->getType());
if (!RetTy)
return true;
- ScalableVectorType *InputTy = nullptr;
+ VectorType *InputTy = nullptr;
auto RetScalarTy = RetTy->getScalarType();
if (RetScalarTy->isIntegerTy(64)) {
- InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+ InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
} else if (RetScalarTy->isIntegerTy(32)) {
- InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+ InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
}
if (!InputTy)
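A side note on the API used here: VectorType::get's third parameter selects
fixed versus scalable, so the one call site now covers both <16 x i8> and
<vscale x 16 x i8>. Minimal illustration (assuming an in-scope LLVMContext):

  #include "llvm/IR/DerivedTypes.h"
  using namespace llvm;

  // Scalable=false yields <16 x i8>; Scalable=true yields <vscale x 16 x i8>.
  VectorType *i8x16(LLVMContext &Ctx, bool Scalable) {
    return VectorType::get(Type::getInt8Ty(Ctx), 16, Scalable);
  }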
From 955c84e7aa1f811f2a78585d9dbf985672d3e21e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 29 Jul 2024 19:10:08 +0100
Subject: [PATCH 05/17] Clean up type checking
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 9 +++------
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 28e13e2e2841cf..8cf997cf0a3f29 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1984,13 +1984,11 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
VectorType *InputTy = nullptr;
auto RetScalarTy = RetTy->getScalarType();
- if (RetScalarTy->isIntegerTy(64)) {
+ if (RetScalarTy->isIntegerTy(64))
InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
- } else if (RetScalarTy->isIntegerTy(32)) {
+ else if (RetScalarTy->isIntegerTy(32))
InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
- }
-
- if (!InputTy)
+ else
return true;
Value *InputA;
@@ -2004,7 +2002,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
auto Mul = cast<Instruction>(I->getOperand(1));
-
auto getOpcodeOfOperand = [&](unsigned Idx) {
return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
};
From 7acadbc84967f045e803ce0d9e9008ab1554cdf8 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 1 Aug 2024 11:04:37 +0100
Subject: [PATCH 06/17] Restrict to scalable vector types and clean up type
checking
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +-
.../lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 11 +++--------
2 files changed, 4 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8cf997cf0a3f29..510a92dc8d7345 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1978,7 +1978,7 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
VectorType *RetTy = dyn_cast<VectorType>(I->getType());
- if (!RetTy)
+ if (!RetTy || !RetTy->isScalableTy())
return true;
VectorType *InputTy = nullptr;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 792bd546019192..afa9acf5c8de3f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3552,15 +3552,10 @@ bool AArch64TTIImpl::isPartialReductionSupported(
return false;
Type *ReductionType = ReductionInstr->getType();
- if (ReductionType->isIntegerTy(32)) {
- if (!InputType->isIntegerTy(8))
- return false;
- } else if (ReductionType->isIntegerTy(64)) {
- if (!InputType->isIntegerTy(16))
- return false;
- }
- return true;
+ return ((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
+ (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) &&
+ ReductionType->isScalableTy();
}
unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
From 0213f5d0be6f27888b150e2f3b8af8d6ae64ccee Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Thu, 1 Aug 2024 11:36:50 +0100
Subject: [PATCH 07/17] Simplify instruction matching in
shouldExpandPartialReduction
---
.../Target/AArch64/AArch64ISelLowering.cpp | 56 +++++++++----------
1 file changed, 27 insertions(+), 29 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 510a92dc8d7345..4fbe3231170716 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1981,38 +1981,36 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
if (!RetTy || !RetTy->isScalableTy())
return true;
- VectorType *InputTy = nullptr;
-
- auto RetScalarTy = RetTy->getScalarType();
- if (RetScalarTy->isIntegerTy(64))
- InputTy = VectorType::get(Type::getInt16Ty(I->getContext()), 8, RetTy->isScalableTy());
- else if (RetScalarTy->isIntegerTy(32))
- InputTy = VectorType::get(Type::getInt8Ty(I->getContext()), 16, RetTy->isScalableTy());
- else
- return true;
-
Value *InputA;
Value *InputB;
+ if (match(I, m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+ m_Value(),
+ m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+ m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
+ VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
+ VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
+ if (!InputAType || !InputBType)
+ return true;
+ ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
+ ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
+ if ((RetTy->getScalarType()->isIntegerTy(64) &&
+ InputAType->getElementType()->isIntegerTy(16) &&
+ InputAType->getElementCount() == ExpectedCount8 &&
+ InputAType == InputBType) ||
+
+ (RetTy->getScalarType()->isIntegerTy(32) &&
+ InputAType->getElementType()->isIntegerTy(8) &&
+ InputAType->getElementCount() == ExpectedCount16 &&
+ InputAType == InputBType)) {
+ auto *Mul = cast<Instruction>(I->getOperand(1));
+ auto *Mul0 = cast<Instruction>(Mul->getOperand(0));
+ auto *Mul1 = cast<Instruction>(Mul->getOperand(1));
+ if (Mul0->getOpcode() == Mul1->getOpcode())
+ return false;
+ }
+ }
- auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
- m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
- m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
-
- if (!match(I, Pattern))
- return true;
-
- auto Mul = cast<Instruction>(I->getOperand(1));
- auto getOpcodeOfOperand = [&](unsigned Idx) {
- return cast<Instruction>(Mul->getOperand(Idx))->getOpcode();
- };
-
- if (getOpcodeOfOperand(0) != getOpcodeOfOperand(1))
- return true;
-
- if (InputA->getType() != InputTy || InputB->getType() != InputTy)
- return true;
-
- return false;
+ return true;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
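The rewrite leans on ElementCount, which pairs a known-minimum lane count with
a scalable flag, so a fixed and a scalable count with the same minimum compare
unequal. A small sketch of that behaviour:

  #include "llvm/Support/TypeSize.h"
  using namespace llvm;

  bool demo() {
    ElementCount Scalable8 = ElementCount::get(8, /*Scalable=*/true); // vscale x 8
    ElementCount Fixed8 = ElementCount::getFixed(8);                  // exactly 8
    return Scalable8 == Fixed8; // false: scalability participates in equality
  }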
From 82769ae7bbacda12114c5a389bab9ddd13900e46 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Fri, 9 Aug 2024 16:38:22 +0100
Subject: [PATCH 08/17] Add fallback in case the nodes aren't as we expect at
lowering time
---
.../Target/AArch64/AArch64ISelLowering.cpp | 67 ++++++++++++++++---
1 file changed, 59 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4fbe3231170716..2ddcb5686042da 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21282,28 +21282,79 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::experimental_vector_partial_reduce_add: {
SDLoc DL(N);
+ bool IsValidDotProduct = false;
+
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() == ISD::MUL)
+ IsValidDotProduct = true;
auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() == ExtB->getOpcode() && (IsSExt || IsZExt))
+ IsValidDotProduct = true;
unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
- if (ExtA->getOpcode() == ISD::SIGN_EXTEND)
+ if (IsSExt && IsValidDotProduct)
DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
- else if (ExtA->getOpcode() == ISD::ZERO_EXTEND)
+ else if (IsZExt && IsValidDotProduct)
DotIntrinsicId = Intrinsic::aarch64_sve_udot;
- assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+ assert((!IsValidDotProduct || DotIntrinsicId != Intrinsic::not_intrinsic) &&
"Unexpected dot product case encountered.");
- auto A = ExtA->getOperand(0);
- auto B = ExtB->getOperand(0);
+ if (IsValidDotProduct) {
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+
+ auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+ {IntrinsicId, NarrowOp, A, B});
+ } else {
+ // If the node doesn't match a dot product, lower to a series of ADDs
+ // instead.
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ EVT Type0 = Op0->getValueType(0);
+ EVT Type1 = Op1->getValueType(0);
+
+ // Canonicalise so that Op1 has the larger type
+ if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) {
+ std::swap(Op0, Op1);
+ std::swap(Type0, Type1);
+ }
- auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
- {IntrinsicId, NarrowOp, A, B});
+ auto Type0Elements = Type0.getVectorNumElements();
+ auto Type1Elements = Type1.getVectorNumElements();
+ auto Type0ElementSize =
+ Type0.getVectorElementType().getScalarSizeInBits();
+ auto Type1ElementSize =
+ Type1.getVectorElementType().getScalarSizeInBits();
+
+ // If the types are equal then a single ADD is fine
+ if (Type0 == Type1)
+ return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1});
+
+ // Otherwise, we need to add each subvector together so that the output is
+ // the intrinsic's return type. For example, <4 x i32>
+ // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] +
+ // b[4..7] + b[8..11] + b[12..15]
+ SDValue Add = Op0;
+ for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) {
+ SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
+ DAG.getConstant(i, DL, MVT::i64));
+
+ if (Type1ElementSize < Type0ElementSize)
+ Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
+ else if (Type1ElementSize > Type0ElementSize)
+ Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
+ Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec});
+ }
+ return Add;
+ }
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
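To make the fallback concrete, here is a scalar model of the series-of-ADDs
expansion for a 16-into-4 reduction (a sketch of the semantics, not the DAG
code): the accumulator receives each 4-lane chunk of the wide vector,
lane-wise.

  #include <cstdint>

  // Acc[lane] += In[lane] + In[lane+4] + In[lane+8] + In[lane+12], i.e.
  // acc + b[0..3] + b[4..7] + b[8..11] + b[12..15] from the comment above.
  void partialReduceAdd(int32_t Acc[4], const int32_t In[16]) {
    for (unsigned Chunk = 0; Chunk < 4; ++Chunk)
      for (unsigned Lane = 0; Lane < 4; ++Lane)
        Acc[Lane] += In[4 * Chunk + Lane];
  }

Note the lane association differs from udot's adjacent groups of four, which
is fine: the intrinsic leaves that association unspecified.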
From f8ee528af87e5fa2bd88538961addc1d3ea92afb Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Mon, 12 Aug 2024 11:02:28 +0100
Subject: [PATCH 09/17] Fix logic error with fallback case
---
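(Context for the fix: on an ISD::INTRINSIC_WO_CHAIN node, operand 0 is the
intrinsic-ID constant and the call arguments start at operand 1, so the
accumulator and the wide vector are operands 1 and 2. The fallback path was
reading operands 0 and 1.)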
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2ddcb5686042da..d39e0a512f9d6c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21282,19 +21282,19 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::experimental_vector_partial_reduce_add: {
SDLoc DL(N);
- bool IsValidDotProduct = false;
+ bool IsValidDotProduct = true;
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
- if (MulOp->getOpcode() == ISD::MUL)
- IsValidDotProduct = true;
+ if (MulOp->getOpcode() != ISD::MUL)
+ IsValidDotProduct = false;
auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
- if (ExtA->getOpcode() == ExtB->getOpcode() && (IsSExt || IsZExt))
- IsValidDotProduct = true;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ IsValidDotProduct = false;
unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
@@ -21316,8 +21316,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
} else {
// If the node doesn't match a dot product, lower to a series of ADDs
// instead.
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
+ SDValue Op0 = N->getOperand(1);
+ SDValue Op1 = N->getOperand(2);
EVT Type0 = Op0->getValueType(0);
EVT Type1 = Op1->getValueType(0);
From a3df2e98a02b6fa4e568288f1ac8607eec86aab7 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:18:53 +0100
Subject: [PATCH 10/17] Pass IntrinsicInst to
shouldExpandPartialReductionIntrinsic
---
llvm/include/llvm/CodeGen/TargetLowering.h | 3 ++-
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 +-
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 5 +----
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 3 ++-
4 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 07d99aec47122a..b20c6be16f9e8a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -455,7 +455,8 @@ class TargetLoweringBase {
/// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
/// should be expanded using generic code in SelectionDAGBuilder.
- virtual bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const {
+ virtual bool
+ shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
return true;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index c70ab253c1aabc..7211c00240bb27 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7986,7 +7986,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
}
case Intrinsic::experimental_vector_partial_reduce_add: {
- if (!TLI.shouldExpandPartialReductionIntrinsic(&I)) {
+ if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
return;
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d39e0a512f9d6c..7957451173c012 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1972,10 +1972,7 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
}
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
- const CallInst *CI) const {
-
- auto *I = dyn_cast<IntrinsicInst>(CI);
- assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinsic");
+ const IntrinsicInst *I) const {
VectorType *RetTy = dyn_cast<VectorType>(I->getType());
if (!RetTy || !RetTy->isScalableTy())
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index fc79d9766719bc..a870fb5f551209 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -991,7 +991,8 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
- bool shouldExpandPartialReductionIntrinsic(const CallInst *I) const override;
+ bool
+ shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
bool shouldExpandCttzElements(EVT VT) const override;
From e47670e6897218917519b7648a31930787949b6e Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:31:22 +0100
Subject: [PATCH 11/17] Remove one-use restriction
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7957451173c012..362faad4d925cb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1980,10 +1980,10 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
Value *InputA;
Value *InputB;
- if (match(I, m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
- m_Value(),
- m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
- m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
+ if (match(I,
+ m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+ m_Value(), m_OneUse(m_Mul(m_ZExtOrSExt(m_Value(InputA)),
+ m_ZExtOrSExt(m_Value(InputB))))))) {
VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
if (!InputAType || !InputBType)
From ff5f96a8588cf9c9f359bcf092df2791cebc70cb Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 14:32:10 +0100
Subject: [PATCH 12/17] Remove new line
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 362faad4d925cb..20417aab02c8e4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1994,7 +1994,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
InputAType->getElementType()->isIntegerTy(16) &&
InputAType->getElementCount() == ExpectedCount8 &&
InputAType == InputBType) ||
-
(RetTy->getScalarType()->isIntegerTy(32) &&
InputAType->getElementType()->isIntegerTy(8) &&
InputAType->getElementCount() == ExpectedCount16 &&
From 3ff122c183431b472a85105500287e422c7861f3 Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 20:21:43 +0100
Subject: [PATCH 13/17] Remove extending/truncating for fallback case
---
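(The extend/truncate handling was dead code as far as I can tell: the
intrinsic requires the accumulator and input vectors to share an element type,
so Type0 and Type1 can only differ in lane count, never in element size.)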
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 8 --------
1 file changed, 8 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 20417aab02c8e4..4b7a9cab8a57e5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21325,10 +21325,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
auto Type0Elements = Type0.getVectorNumElements();
auto Type1Elements = Type1.getVectorNumElements();
- auto Type0ElementSize =
- Type0.getVectorElementType().getScalarSizeInBits();
- auto Type1ElementSize =
- Type1.getVectorElementType().getScalarSizeInBits();
// If the types are equal then a single ADD is fine
if (Type0 == Type1)
@@ -21343,10 +21339,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
DAG.getConstant(i, DL, MVT::i64));
- if (Type1ElementSize < Type0ElementSize)
- Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
- else if (Type1ElementSize > Type0ElementSize)
- Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
Add = DAG.getNode(ISD::ADD, DL, Type0, {Add, Subvec});
}
return Add;
From 81d5b0c7036af75e6a452e14901eca1adc66608a Mon Sep 17 00:00:00 2001
From: Sam Tebbs <samuel.tebbs at arm.com>
Date: Tue, 13 Aug 2024 20:27:04 +0100
Subject: [PATCH 14/17] Clean up test target
---
llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 23b39387fb7a0c..0facb2049135f6 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,8 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc -mtriple=aarch64-unknown-linux-gnu -mattr=+sve2 -O3 %s -o - | FileCheck %s
-
-target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-none-unknown-elf"
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
; CHECK-LABEL: dotp:
From 127bfc4def806e5ed5260d77aac0cbd226965606 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 14 Aug 2024 09:42:32 +0100
Subject: [PATCH 15/17] Remove #0 attribute from test
---
.../CodeGen/AArch64/partial-reduce-dot-product.ll | 14 ++++++--------
1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 0facb2049135f6..16ef219a93c9bf 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -1,7 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
-define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+define <vscale x 4 x i32> @dotp(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z2.s, #0 // =0x0
@@ -16,7 +16,7 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+define <vscale x 2 x i64> @dotp_wide(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z2.d, #0 // =0x0
@@ -31,7 +31,7 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+define <vscale x 4 x i32> @dotp_sext(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z2.s, #0 // =0x0
@@ -46,7 +46,7 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) #0 {
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: mov z2.d, #0 // =0x0
@@ -61,7 +61,7 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
-define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) #0 {
+define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z0.h, z0.h, #0xff
@@ -82,7 +82,7 @@ entry:
ret <vscale x 4 x i32> %partial.reduce
}
-define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) #0 {
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
; CHECK-LABEL: not_dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z0.s, z0.s, #0xffff
@@ -102,5 +102,3 @@ entry:
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> zeroinitializer, <vscale x 4 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}
-
-attributes #0 = { "target-features"="+sve2" }
From 9f791a1f4b79e2bc58c9efc3a277c9bf3e9292ad Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Wed, 14 Aug 2024 10:55:12 +0100
Subject: [PATCH 16/17] Allow i8 to i64 dot products
---
.../Target/AArch64/AArch64ISelLowering.cpp | 34 ++++++++-
.../AArch64/partial-reduce-dot-product.ll | 72 +++++++++++++++++++
2 files changed, 103 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4b7a9cab8a57e5..d084fc3f969f34 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1990,11 +1990,15 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
+ // Check that the input type is 4 times smaller than the output type. If the
+ // output type is 64 bit then we can accept 8 bit inputs if we do a 32 bit
+ // dot product and add a zext/sext.
if ((RetTy->getScalarType()->isIntegerTy(64) &&
InputAType->getElementType()->isIntegerTy(16) &&
InputAType->getElementCount() == ExpectedCount8 &&
InputAType == InputBType) ||
- (RetTy->getScalarType()->isIntegerTy(32) &&
+ ((RetTy->getScalarType()->isIntegerTy(32) ||
+ RetTy->getScalarType()->isIntegerTy(64)) &&
InputAType->getElementType()->isIntegerTy(8) &&
InputAType->getElementCount() == ExpectedCount16 &&
InputAType == InputBType)) {
@@ -21305,10 +21309,34 @@ static SDValue performIntrinsicCombine(SDNode *N,
if (IsValidDotProduct) {
auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
+ EVT Type = NarrowOp.getValueType();
+
+ // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
+ // and extending the output
+ bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
+ Type.getScalarSizeInBits() == 64;
+ SDValue Accumulator = NarrowOp;
+ if (Extend) {
+ Type = Type.changeVectorElementType(
+ EVT::getIntegerVT(*DAG.getContext(), 32));
+ // The accumulator is of the wider type so we insert a 0 accumulator and
+ // add the proper one after extending
+ Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
+ DAG.getConstant(0, DL, MVT::i32));
+ }
auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
- {IntrinsicId, NarrowOp, A, B});
+ auto DotProduct = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Type,
+ {IntrinsicId, Accumulator, A, B});
+ if (Extend) {
+ auto Extended =
+ DAG.getNode(IsZExt ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL,
+ NarrowOp.getValueType(), {DotProduct});
+ auto AccAdd = DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(),
+ {NarrowOp, Extended});
+ DotProduct = AccAdd;
+ }
+ return DotProduct;
} else {
// If the node doesn't match a dot product, lower to a series of ADDs
// instead.
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
index 16ef219a93c9bf..c1cf9026d693ce 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -61,6 +61,78 @@ entry:
ret <vscale x 2 x i64> %partial.reduce
}
+define <vscale x 4 x i64> @dotp_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: udot z2.s, z0.b, z1.b
+; CHECK-NEXT: uunpklo z0.d, z2.s
+; CHECK-NEXT: uunpkhi z1.d, z2.s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+ <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_sext_8to64(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext_8to64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.s, #0 // =0x0
+; CHECK-NEXT: sdot z2.s, z0.b, z1.b
+; CHECK-NEXT: sunpklo z0.d, z2.s
+; CHECK-NEXT: sunpkhi z1.d, z2.s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+ <vscale x 4 x i64> zeroinitializer, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
+; CHECK-LABEL: dotp_8to64_accumulator:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z0.b, z1.b
+; CHECK-NEXT: uunpklo z0.d, z4.s
+; CHECK-NEXT: uunpkhi z1.d, z4.s
+; CHECK-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @dotp_sext_8to64_accumulator(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 4 x i64> %acc) {
+; CHECK-LABEL: dotp_sext_8to64_accumulator:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z0.b, z1.b
+; CHECK-NEXT: sunpklo z0.d, z4.s
+; CHECK-NEXT: sunpkhi z1.d, z4.s
+; CHECK-NEXT: add z0.d, z2.d, z0.d
+; CHECK-NEXT: add z1.d, z3.d, z1.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>
+ %mult = mul nuw nsw <vscale x 16 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(
+ <vscale x 4 x i64> %acc, <vscale x 16 x i64> %mult)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
define <vscale x 4 x i32> @not_dotp(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
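A scalar model of the new i8-to-i64 path may help review (a sketch of the
semantics, not the combine itself): dot into 32-bit partial sums first, which
cannot overflow, then widen once and add the real 64-bit accumulator.

  #include <cstdint>

  // Models: zeroed nxv4i32 accumulator, udot, zext, add with the nxv4i64
  // accumulator. Four u8*u8 products per lane stay well below 2^32.
  void dot8to64(uint64_t Acc[4], const uint8_t A[16], const uint8_t B[16]) {
    uint32_t Partial[4] = {0, 0, 0, 0};
    for (unsigned Lane = 0; Lane < 4; ++Lane)
      for (unsigned J = 0; J < 4; ++J)
        Partial[Lane] += uint32_t(A[4 * Lane + J]) * uint32_t(B[4 * Lane + J]);
    for (unsigned Lane = 0; Lane < 4; ++Lane)
      Acc[Lane] += uint64_t(Partial[Lane]); // the zext + add from the combine
  }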
From 0f2dfedb57c64c49cdbbc71187e94338ee5b49d8 Mon Sep 17 00:00:00 2001
From: Samuel Tebbs <samuel.tebbs at arm.com>
Date: Tue, 20 Aug 2024 13:53:11 +0100
Subject: [PATCH 17/17] Remove isPartialReductionSupported
---
.../AArch64/AArch64TargetTransformInfo.cpp | 25 -------------------
.../AArch64/AArch64TargetTransformInfo.h | 6 -----
2 files changed, 31 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index afa9acf5c8de3f..45148449dfb821 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3533,31 +3533,6 @@ AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
return Cost;
}
-bool AArch64TTIImpl::isPartialReductionSupported(
- const Instruction *ReductionInstr, Type *InputType, unsigned ScaleFactor,
- bool IsInputASignExtended, bool IsInputBSignExtended,
- const Instruction *BinOp) const {
- if (ReductionInstr->getOpcode() != Instruction::Add)
- return false;
-
- // Check that both extends are of the same type
- if (IsInputASignExtended != IsInputBSignExtended)
- return false;
-
- if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
- return false;
-
- // Dot product only supports a scale factor of 4
- if (ScaleFactor != 4)
- return false;
-
- Type *ReductionType = ReductionInstr->getType();
-
- return ((ReductionType->isIntegerTy(32) && InputType->isIntegerTy(8)) ||
- (ReductionType->isIntegerTy(64) && InputType->isIntegerTy(16))) &&
- ReductionType->isScalableTy();
-}
-
unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
return ST->getMaxInterleaveFactor();
}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 592b452134e778..a9189fd53f40bb 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -155,12 +155,6 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return VF.getKnownMinValue() * ST->getVScaleForTuning();
}
- bool isPartialReductionSupported(const Instruction *ReductionInstr,
- Type *InputType, unsigned ScaleFactor,
- bool IsInputASignExtended,
- bool IsInputBSignExtended,
- const Instruction *BinOp = nullptr) const;
-
unsigned getMaxInterleaveFactor(ElementCount VF);
bool prefersVectorizedAddressing() const;