[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 13 01:12:20 PDT 2024
================
@@ -21237,6 +21279,83 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
+ case Intrinsic::experimental_vector_partial_reduce_add: {
+ SDLoc DL(N);
+
+ bool IsValidDotProduct = true;
+
+ auto NarrowOp = N->getOperand(1);
+ auto MulOp = N->getOperand(2);
+ if (MulOp->getOpcode() != ISD::MUL)
+ IsValidDotProduct = false;
+
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
+ bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+ if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+ IsValidDotProduct = false;
+
+ unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+ if (IsSExt && IsValidDotProduct)
+ DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+ else if (IsZExt && IsValidDotProduct)
+ DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+ assert((!IsValidDotProduct || DotIntrinsicId != Intrinsic::not_intrinsic) &&
+ "Unexpected dot product case encountered.");
+
+ if (IsValidDotProduct) {
+ auto A = ExtA->getOperand(0);
+ auto B = ExtB->getOperand(0);
+
+ auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+ {IntrinsicId, NarrowOp, A, B});
+ } else {
+ // If the node doesn't match a dot product, lower to a series of ADDs
+ // instead.
+ SDValue Op0 = N->getOperand(1);
+ SDValue Op1 = N->getOperand(2);
+ EVT Type0 = Op0->getValueType(0);
+ EVT Type1 = Op1->getValueType(0);
+
+ // Canonicalise so that Op1 has the larger type
+ if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) {
+ std::swap(Op0, Op1);
+ std::swap(Type0, Type1);
+ }
+
+ auto Type0Elements = Type0.getVectorNumElements();
+ auto Type1Elements = Type1.getVectorNumElements();
+ auto Type0ElementSize =
+ Type0.getVectorElementType().getScalarSizeInBits();
+ auto Type1ElementSize =
+ Type1.getVectorElementType().getScalarSizeInBits();
+
+ // If the types are equal then a single ADD is fine
+ if (Type0 == Type1)
+ return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1});
+
+ // Otherwise, we need to add each subvector together so that the output is
+ // the intrinsic's return type. For example, <4 x i32>
+ // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] +
+ // b[4..7] + b[8..11] + b[12..15]
+ SDValue Add = Op0;
+ for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) {
+ SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
+ DAG.getConstant(i, DL, MVT::i64));
+
+ if (Type1ElementSize < Type0ElementSize)
+ Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
+ else if (Type1ElementSize > Type0ElementSize)
+ Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
----------------
davemgreen wrote:
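I don't think this should be extending or truncating the subvector type. From what I understand, the element types of the two operands should always be the same size, so no extension or truncation is needed here.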
https://github.com/llvm/llvm-project/pull/101010
More information about the llvm-commits mailing list