[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 01:12:20 PDT 2024


================
@@ -21237,6 +21279,83 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
+  case Intrinsic::experimental_vector_partial_reduce_add: {
+    SDLoc DL(N);
+
+    bool IsValidDotProduct = true;
+
+    auto NarrowOp = N->getOperand(1);
+    auto MulOp = N->getOperand(2);
+    if (MulOp->getOpcode() != ISD::MUL)
+      IsValidDotProduct = false;
+
+    auto ExtA = MulOp->getOperand(0);
+    auto ExtB = MulOp->getOperand(1);
+    bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+    bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+    if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+      IsValidDotProduct = false;
+
+    unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+    if (IsSExt && IsValidDotProduct)
+      DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+    else if (IsZExt && IsValidDotProduct)
+      DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+    assert((!IsValidDotProduct || DotIntrinsicId != Intrinsic::not_intrinsic) &&
+           "Unexpected dot product case encountered.");
+
+    if (IsValidDotProduct) {
+      auto A = ExtA->getOperand(0);
+      auto B = ExtB->getOperand(0);
+
+      auto IntrinsicId = DAG.getConstant(DotIntrinsicId, DL, MVT::i64);
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NarrowOp.getValueType(),
+                         {IntrinsicId, NarrowOp, A, B});
+    } else {
+      // If the node doesn't match a dot product, lower to a series of ADDs
+      // instead.
+      SDValue Op0 = N->getOperand(1);
+      SDValue Op1 = N->getOperand(2);
+      EVT Type0 = Op0->getValueType(0);
+      EVT Type1 = Op1->getValueType(0);
+
+      // Canonicalise so that Op1 has the larger type
+      if (Type1.getVectorNumElements() > Type0.getVectorNumElements()) {
+        std::swap(Op0, Op1);
+        std::swap(Type0, Type1);
+      }
+
+      auto Type0Elements = Type0.getVectorNumElements();
+      auto Type1Elements = Type1.getVectorNumElements();
+      auto Type0ElementSize =
+          Type0.getVectorElementType().getScalarSizeInBits();
+      auto Type1ElementSize =
+          Type1.getVectorElementType().getScalarSizeInBits();
+
+      // If the types are equal then a single ADD is fine
+      if (Type0 == Type1)
+        return DAG.getNode(ISD::ADD, DL, Type0, {Op0, Op1});
+
+      // Otherwise, we need to add each subvector together so that the output is
+      // the intrinsic's return type. For example, <4 x i32>
+      // partial.reduction(<4 x i32> a, <16 x i32> b) becomes a + b[0..3] +
+      // b[4..7] + b[8..11] + b[12..15]
+      SDValue Add = Op0;
+      for (unsigned i = 0; i < Type1Elements / Type0Elements; i++) {
+        SDValue Subvec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Type0, Op1,
+                                     DAG.getConstant(i, DL, MVT::i64));
+
+        if (Type1ElementSize < Type0ElementSize)
+          Subvec = DAG.getNode(ISD::ANY_EXTEND, DL, Type0, Subvec);
+        else if (Type1ElementSize > Type0ElementSize)
+          Subvec = DAG.getNode(ISD::TRUNCATE, DL, Type0, Subvec);
----------------
davemgreen wrote:

I don't think this should be extending/truncating the type. From what I understand, the element types of the two operands should be the same size.

https://github.com/llvm/llvm-project/pull/101010


More information about the llvm-commits mailing list