[llvm] [AArch64][SVE] Add codegen support for partial reduction lowering to wide add instructions (PR #114406)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 6 01:57:41 PST 2024
https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/114406
>From daa1cdb04b4904b2cb97fde0f608f467bdab557f Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 31 Oct 2024 13:08:06 +0000
Subject: [PATCH 1/2] [AArch64][SVE] Add codegen support for partial reduction
lowering to wide add instructions
---
.../Target/AArch64/AArch64ISelLowering.cpp | 61 ++++++++++++++-
.../AArch64/sve-partial-reduce-wide-add.ll | 74 +++++++++++++++++++
2 files changed, 134 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4c0cd1ac3d4512..8efc8244426ef3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2042,7 +2042,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
EVT VT = EVT::getEVT(I->getType());
return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
- VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
+ VT != MVT::nxv8i16 && VT != MVT::v4i64 && VT != MVT::v4i32 &&
+ VT != MVT::v2i32 && VT != MVT::v8i16;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21783,6 +21784,62 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
}
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+ const AArch64Subtarget *Subtarget,
+ SelectionDAG &DAG) {
+
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+ getIntrinsicID(N) ==
+ Intrinsic::experimental_vector_partial_reduce_add &&
+ "Expected a partial reduction node");
+
+ bool Scalable = N->getValueType(0).isScalableVector();
+ if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+
+ SDLoc DL(N);
+
+ auto Accumulator = N->getOperand(1);
+ auto ExtInput = N->getOperand(2);
+
+ EVT AccumulatorType = Accumulator.getValueType();
+ EVT AccumulatorElementType = AccumulatorType.getVectorElementType();
+
+ if (ExtInput.getValueType().getVectorElementType() != AccumulatorElementType)
+ return SDValue();
+
+ unsigned ExtInputOpcode = ExtInput->getOpcode();
+ if (!ISD::isExtOpcode(ExtInputOpcode))
+ return SDValue();
+
+ auto Input = ExtInput->getOperand(0);
+ EVT InputType = Input.getValueType();
+
+ // To do this transformation, output element size needs to be double input
+ // element size, and output number of elements needs to be half the input
+ // number of elements
+ if (!(InputType.getVectorElementType().getSizeInBits() * 2 ==
+ AccumulatorElementType.getSizeInBits()) ||
+ !(AccumulatorType.getVectorElementCount() * 2 ==
+ InputType.getVectorElementCount()) ||
+ !(AccumulatorType.isScalableVector() == InputType.isScalableVector()))
+ return SDValue();
+
+ bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
+ auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
+ : Intrinsic::aarch64_sve_uaddwb;
+ auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
+ : Intrinsic::aarch64_sve_uaddwt;
+
+ auto BottomID =
+ DAG.getTargetConstant(BottomIntrinsic, DL, AccumulatorElementType);
+ auto BottomNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType,
+ BottomID, Accumulator, Input);
+ auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccumulatorElementType);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType, TopID,
+ BottomNode, Input);
+}
+
static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -21794,6 +21851,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::experimental_vector_partial_reduce_add: {
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
+ if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+ return WideAdd;
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
}
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
new file mode 100644
index 00000000000000..6fe3da2a25c0cd
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: signed_wide_add_nxv4i32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: saddwb z0.d, z0.d, z1.s
+; CHECK-NEXT: saddwt z0.d, z0.d, z1.s
+; CHECK-NEXT: ret
+entry:
+ %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uaddwb z0.d, z0.d, z1.s
+; CHECK-NEXT: uaddwt z0.d, z0.d, z1.s
+; CHECK-NEXT: ret
+entry:
+ %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: signed_wide_add_nxv8i16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: saddwb z0.s, z0.s, z1.h
+; CHECK-NEXT: saddwt z0.s, z0.s, z1.h
+; CHECK-NEXT: ret
+entry:
+ %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uaddwb z0.s, z0.s, z1.h
+; CHECK-NEXT: uaddwt z0.s, z0.s, z1.h
+; CHECK-NEXT: ret
+entry:
+ %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: signed_wide_add_nxv16i8:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: saddwb z0.h, z0.h, z1.b
+; CHECK-NEXT: saddwt z0.h, z0.h, z1.b
+; CHECK-NEXT: ret
+entry:
+ %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+ %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+ ret <vscale x 8 x i16> %partial.reduce
+}
+
+define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uaddwb z0.h, z0.h, z1.b
+; CHECK-NEXT: uaddwt z0.h, z0.h, z1.b
+; CHECK-NEXT: ret
+entry:
+ %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+ %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+ ret <vscale x 8 x i16> %partial.reduce
+}
>From 9cb54822ecb3196902238055f25b41fba8c3c891 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 4 Nov 2024 10:15:23 +0000
Subject: [PATCH 2/2] Minor changes to previous patch
Rename variables, eliminate a redundant condition in an if
statement and refactor code checking types.
Also require SVE2 (or streaming SVE), since [su]addw[bt] are SVE2
instructions, and only lower the legal scalable type pairings
(nxv2i64/nxv4i32, nxv4i32/nxv8i16, nxv8i16/nxv16i8) so that
fixed-length vectors and types needing splitting fall back to the
generic expansion instead of failing instruction selection. Use an
i64 constant for the intrinsic ID operand.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 54 ++++++++++++-----------
1 file changed, 29 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8efc8244426ef3..9b9494f61049d9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2041,9 +2041,13 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
 return true;
 EVT VT = EVT::getEVT(I->getType());
- return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
- VT != MVT::nxv8i16 && VT != MVT::v4i64 && VT != MVT::v4i32 &&
- VT != MVT::v2i32 && VT != MVT::v8i16;
+ auto Op1 = I->getOperand(1);
+ EVT Op1VT = EVT::getEVT(Op1->getType());
+ if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
+ (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
+ VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
+ return false;
+ return true;
 }
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21793,19 +21797,19 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
 Intrinsic::experimental_vector_partial_reduce_add &&
 "Expected a partial reduction node");
- bool Scalable = N->getValueType(0).isScalableVector();
- if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+ // [SU]ADDW[BT] are SVE2 instructions (also available in streaming mode).
+ if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
 return SDValue();
 SDLoc DL(N);
- auto Accumulator = N->getOperand(1);
+ auto Acc = N->getOperand(1);
 auto ExtInput = N->getOperand(2);
- EVT AccumulatorType = Accumulator.getValueType();
- EVT AccumulatorElementType = AccumulatorType.getVectorElementType();
+ EVT AccVT = Acc.getValueType();
+ EVT AccElemVT = AccVT.getVectorElementType();
- if (ExtInput.getValueType().getVectorElementType() != AccumulatorElementType)
+ if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
 return SDValue();
 unsigned ExtInputOpcode = ExtInput->getOpcode();
@@ -21813,16 +21817,16 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
 return SDValue();
 auto Input = ExtInput->getOperand(0);
- EVT InputType = Input.getValueType();
+ EVT InputVT = Input.getValueType();
- // To do this transformation, output element size needs to be double input
- // element size, and output number of elements needs to be half the input
- // number of elements
- if (!(InputType.getVectorElementType().getSizeInBits() * 2 ==
- AccumulatorElementType.getSizeInBits()) ||
- !(AccumulatorType.getVectorElementCount() * 2 ==
- InputType.getVectorElementCount()) ||
- !(AccumulatorType.isScalableVector() == InputType.isScalableVector()))
+ // The accumulator elements must be twice as wide as the input elements, and
+ // there must be half as many of them. Only the legal SVE pairings are
+ // matched: fixed-length vectors and scalable types that would need splitting
+ // cannot be selected as these SVE intrinsics, so they are left to the
+ // generic expansion instead.
+ if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+ !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+ !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
 return SDValue();
 bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
@@ -21831,13 +21835,13 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
 auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
 : Intrinsic::aarch64_sve_uaddwt;
- auto BottomID =
- DAG.getTargetConstant(BottomIntrinsic, DL, AccumulatorElementType);
- auto BottomNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType,
- BottomID, Accumulator, Input);
- auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccumulatorElementType);
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType, TopID,
- BottomNode, Input);
+ // The intrinsic ID operand is unrelated to the accumulator element type.
+ SDValue BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, MVT::i64);
+ SDValue BottomNode =
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
+ SDValue TopID = DAG.getTargetConstant(TopIntrinsic, DL, MVT::i64);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
+ Input);
 }
 static SDValue performIntrinsicCombine(SDNode *N,
static SDValue performIntrinsicCombine(SDNode *N,
More information about the llvm-commits
mailing list