[llvm] [AArch64] Lower partial add reduction to udot or sdot (PR #101010)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 2 02:27:46 PDT 2024


================
@@ -21757,6 +21768,70 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+                                      const AArch64Subtarget *Subtarget,
+                                      SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  if (!Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // The narrower of the two operands. Used as the accumulator
+  auto NarrowOp = N->getOperand(1);
+  auto MulOp = N->getOperand(2);
+  if (MulOp->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  auto ExtA = MulOp->getOperand(0);
+  auto ExtB = MulOp->getOperand(1);
+  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+    return SDValue();
+
+  auto A = ExtA->getOperand(0);
+  auto B = ExtB->getOperand(0);
+  if (A.getValueType() != B.getValueType())
+    return SDValue();
+
+  unsigned Opcode = 0;
+
+  if (IsSExt)
+    Opcode = AArch64ISD::SDOT;
+  else if (IsZExt)
+    Opcode = AArch64ISD::UDOT;
+
+  assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+  // The fully-reduced type. Should be a vector of i32 or i64
+  EVT ReducedType = N->getValueType(0);
+  // The type that is extended to the wide type. Should be an i8 or i16
+  EVT ExtendedType = A.getValueType();
----------------
paulwalker-arm wrote:

I see, perhaps `PreExtendedType` then?

https://github.com/llvm/llvm-project/pull/101010


More information about the llvm-commits mailing list