[llvm] 44cfbef - [AArch64] Lower partial add reduction to udot or svdot (#101010)
Author: Sam Tebbs
Date: 2024-09-02T14:06:14+01:00
New Revision: 44cfbef1b3cb0dd33886cc27441930008a245963
URL: https://github.com/llvm/llvm-project/commit/44cfbef1b3cb0dd33886cc27441930008a245963
DIFF: https://github.com/llvm/llvm-project/commit/44cfbef1b3cb0dd33886cc27441930008a245963.diff
LOG: [AArch64] Lower partial add reduction to udot or svdot (#101010)
This patch introduces lowering of the partial add reduction intrinsic to
a UDOT or SDOT instruction for AArch64. It also adds a
`shouldExpandPartialReductionIntrinsic` target hook, from which AArch64
returns false in the cases where the intrinsic can be lowered directly.
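To make the intrinsic's semantics concrete, here is a minimal scalar model
of the generic expansion (an illustrative sketch, not code from the patch):
the accumulator operand keeps its lane count, and each accumulator-width
slice of the wide operand is added into it lane-wise.

#include <cstdint>
#include <vector>

// Scalar model: Acc has the reduced type's lane count; In's size is a
// multiple of it, as the intrinsic requires.
std::vector<uint32_t> partialReduceAdd(std::vector<uint32_t> Acc,
                                       const std::vector<uint32_t> &In) {
  const size_t Stride = Acc.size();              // lanes in the reduced type
  const size_t ScaleFactor = In.size() / Stride; // subvectors to fold in
  for (size_t I = 0; I < ScaleFactor; ++I)
    for (size_t Lane = 0; Lane < Stride; ++Lane)
      Acc[Lane] += In[I * Stride + Lane];
  return Acc;
}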
Added:
llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
Modified:
llvm/include/llvm/CodeGen/SelectionDAG.h
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1514d92b36b3c2..7ee8ca18c2c1de 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1594,6 +1594,11 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
+ /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
+ /// its operands and ReducedTy is the intrinsic's return type.
+ SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+ SDValue Op2);
+
/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
SDValue expandVAArg(SDNode *Node);
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index eda38cd8a564d6..e17d68d2690c86 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,13 @@ class TargetLoweringBase {
return true;
}
+ /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+ /// should be expanded using generic code in SelectionDAGBuilder.
+ virtual bool
+ shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
+ return true;
+ }
+
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
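To illustrate the hook's contract (a hypothetical target, not code from this
patch): returning true requests the generic expansion in SelectionDAGBuilder,
while returning false keeps the call as a target intrinsic node for the
backend to lower itself.

// Hypothetical override for an imaginary backend; MyTargetLowering and the
// scalable-only policy are assumptions for illustration.
bool MyTargetLowering::shouldExpandPartialReductionIntrinsic(
    const IntrinsicInst *I) const {
  // Only the add flavour is handled; expand everything else generically.
  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
    return true;
  // Keep scalable-vector results for custom lowering; expand fixed-width ones.
  return !I->getType()->isScalableTy();
}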
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7f57b6db40ef49..aa468fa9ebb4c3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -74,6 +74,7 @@
#include <cassert>
#include <cstdint>
#include <cstdlib>
+#include <deque>
#include <limits>
#include <optional>
#include <set>
@@ -2439,6 +2440,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
+SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+ SDValue Op2) {
+ EVT FullTy = Op2.getValueType();
+
+ unsigned Stride = ReducedTy.getVectorMinNumElements();
+ unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+ // Collect all of the subvectors
+ std::deque<SDValue> Subvectors = {Op1};
+ for (unsigned I = 0; I < ScaleFactor; I++) {
+ auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+ Subvectors.push_back(
+ getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+ }
+
+ // Flatten the subvector tree
+ while (Subvectors.size() > 1) {
+ Subvectors.push_back(
+ getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+ Subvectors.pop_front();
+ Subvectors.pop_front();
+ }
+
+ assert(Subvectors.size() == 1 &&
+ "There should only be one subvector after tree flattening");
+
+ return Subvectors[0];
+}
+
SDValue SelectionDAG::expandVAArg(SDNode *Node) {
SDLoc dl(Node);
const TargetLowering &TLI = getTargetLoweringInfo();
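The deque-based loop in getPartialReduceAdd builds a balanced tree of ADD
nodes rather than a serial chain: each iteration pops the two oldest entries
and queues their sum. The same pattern on plain integers (an illustrative
sketch only):

#include <cassert>
#include <deque>

int flattenTree(std::deque<int> Values) {
  assert(!Values.empty() && "need at least one value");
  while (Values.size() > 1) {
    // Combine the two oldest entries and requeue the result, so earlier
    // values are consumed before newly created sums.
    Values.push_back(Values[0] + Values[1]);
    Values.pop_front();
    Values.pop_front();
  }
  return Values.front();
}

With five inputs a, b, c, d, e this yields (c+d)+(e+(a+b)), a tree of depth
three, where a left-to-right fold would be depth four.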
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 4b326ba76f97f2..382a555aa656f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8038,34 +8038,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
- SDValue OpNode = getValue(I.getOperand(1));
- EVT ReducedTy = EVT::getEVT(I.getType());
- EVT FullTy = OpNode.getValueType();
- unsigned Stride = ReducedTy.getVectorMinNumElements();
- unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
- // Collect all of the subvectors
- std::deque<SDValue> Subvectors;
- Subvectors.push_back(getValue(I.getOperand(0)));
- for (unsigned i = 0; i < ScaleFactor; i++) {
- auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
- Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
- {OpNode, SourceIndex}));
- }
-
- // Flatten the subvector tree
- while (Subvectors.size() > 1) {
- Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
- {Subvectors[0], Subvectors[1]}));
- Subvectors.pop_front();
- Subvectors.pop_front();
+ if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+ visitTargetIntrinsic(I, Intrinsic);
+ return;
}
- assert(Subvectors.size() == 1 &&
- "There should only be one subvector after tree flattening");
-
- setValue(&I, Subvectors[0]);
+ setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
+ getValue(I.getOperand(0)),
+ getValue(I.getOperand(1))));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 11aca69db0a148..1735ff5cd69748 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1988,6 +1988,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+ const IntrinsicInst *I) const {
+ if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
+ return true;
+
+ EVT VT = EVT::getEVT(I->getType());
+ return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
+}
+
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -21763,6 +21772,61 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+ const AArch64Subtarget *Subtarget,
+ SelectionDAG &DAG) {
+
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+ getIntrinsicID(N) ==
+ Intrinsic::experimental_vector_partial_reduce_add &&
+ "Expected a partial reduction node");
+
+ if (!Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // The narrower of the two operands. Used as the accumulator
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ return SDValue();
+
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
+
+ unsigned Opcode = 0;
+
+ if (IsSExt)
+ Opcode = AArch64ISD::SDOT;
+ else if (IsZExt)
+ Opcode = AArch64ISD::UDOT;
+
+ assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+ EVT ReducedType = N->getValueType(0);
+ EVT MulSrcType = A.getValueType();
+
+ // Dot products operate on chunks of four elements so there must be four times
+ // as many elements in the wide type
+ if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
+ return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
+
+ if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
+ return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
+
+ return SDValue();
+}
+
static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -21771,6 +21835,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
+ case Intrinsic::experimental_vector_partial_reduce_add: {
+ if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+ return Dot;
+ return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
+ }
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
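For reference, a scalar model of the UDOT operation selected above (an
illustrative sketch, not code from the patch): each 32-bit result lane
accumulates the dot product of four adjacent 8-bit lanes from each source.
This grouping differs from the strided one used by the generic expansion,
which is fine because the partial reduction intrinsic leaves the mapping
from input lanes to accumulator lanes unspecified.

#include <array>
#include <cstdint>

template <size_t N>
std::array<uint32_t, N> udotModel(std::array<uint32_t, N> Acc,
                                  const std::array<uint8_t, 4 * N> &A,
                                  const std::array<uint8_t, 4 * N> &B) {
  for (size_t I = 0; I < N; ++I)
    for (size_t J = 0; J < 4; ++J)
      // Widen, multiply, and accumulate four adjacent byte lanes.
      Acc[I] += uint32_t(A[4 * I + J]) * uint32_t(B[4 * I + J]);
  return Acc;
}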
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 39d5df0de0eec7..f9d45b02d30e30 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,6 +998,9 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
+ bool
+ shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
+
bool shouldExpandCttzElements(EVT VT) const override;
/// If a change in streaming mode is required on entry to/return from a
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..b1354ab210f727
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,96 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_dotp:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z1.h, z1.h, #0xff
+; CHECK-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEXT: mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+ %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+ %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z1.s, z1.s, #0xffff
+; CHECK-NEXT: and z2.s, z2.s, #0xffff
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpklo z3.d, z1.s
+; CHECK-NEXT: uunpklo z4.d, z2.s
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEXT: mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+ %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+ %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}