[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 27 10:41:27 PDT 2024


================
@@ -21229,6 +21249,101 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+                                      const AArch64Subtarget *Subtarget,
+                                      SelectionDAG &DAG) {
+
+  if (!Subtarget->isSVEAvailable() && !Subtarget->isNeonAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // The narrower of the two operands. Used as the accumulator
+  auto NarrowOp = N->getOperand(1);
+  auto MulOp = N->getOperand(2);
+  if (MulOp->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  auto ExtA = MulOp->getOperand(0);
+  auto ExtB = MulOp->getOperand(1);
+  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+    return SDValue();
+
+  auto A = ExtA->getOperand(0);
+  auto B = ExtB->getOperand(0);
+  if (A.getValueType() != B.getValueType())
+    return SDValue();
+
+  // The fully-reduced type. Should be a vector of i32 or i64
+  EVT FullType = N->getValueType(0);
+  // The type that is extended to the wide type. Should be an i8 or i16
+  EVT ExtendedType = A.getValueType();
+  // The wide type with four times as many elements as the reduced type. Should
+  // be a vector of i32 or i64, the same as the fully-reduced type
+  EVT WideType = MulOp.getValueType();
+  if (WideType.getScalarSizeInBits() != FullType.getScalarSizeInBits())
+    return SDValue();
+  // Dot products operate on chunks of four elements so there must be four times
+  // as many elements in the wide type
+  if (WideType.getVectorMinNumElements() / FullType.getVectorMinNumElements() !=
+      4)
+    return SDValue();
+  switch (FullType.getScalarSizeInBits()) {
+  case 32:
+    if (ExtendedType.getScalarSizeInBits() != 8)
+      return SDValue();
+    break;
+  case 64:
+    // i8 to i64 can be done with an extended i32 dot product
+    if (ExtendedType.getScalarSizeInBits() != 8 &&
+        ExtendedType.getScalarSizeInBits() != 16)
+      return SDValue();
+    break;
+  default:
+    return SDValue();
+  }
+
+  unsigned DotIntrinsicId = Intrinsic::not_intrinsic;
+
+  if (IsSExt)
+    DotIntrinsicId = Intrinsic::aarch64_sve_sdot;
+  else if (IsZExt)
+    DotIntrinsicId = Intrinsic::aarch64_sve_udot;
+
+  assert(DotIntrinsicId != Intrinsic::not_intrinsic &&
+         "Unexpected dot product case encountered.");
+
+  EVT Type = NarrowOp.getValueType();
+
+  // 8 bit input to 64 bit output can be done by doing a 32 bit dot product
+  // and extending the output
+  bool Extend = A->getValueType(0).getScalarSizeInBits() == 8 &&
+                Type.getScalarSizeInBits() == 64;
+  SDValue Accumulator = NarrowOp;
+  if (Extend) {
+    Type =
+        Type.changeVectorElementType(EVT::getIntegerVT(*DAG.getContext(), 32));
+    // The accumulator is of the wider type so we insert a 0 accumulator and
+    // add the proper one after extending
+    Accumulator = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv4i32,
+                              DAG.getConstant(0, DL, MVT::i32));
----------------
paulwalker-arm wrote:

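You should be able to use `DAG.getConstant(0, DL, MVT::nxv4i32)` here. `getConstant` with a vector type already produces the splat, so the explicit `ISD::SPLAT_VECTOR` node is unnecessary.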

https://github.com/llvm/llvm-project/pull/101010


More information about the llvm-commits mailing list