[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 14 03:05:02 PDT 2024


================
@@ -1971,6 +1971,48 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
   return false;
 }
 
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+    const CallInst *CI) const {
+
+  auto *I = dyn_cast<IntrinsicInst>(CI);
+  assert(I && "shouldExpandPartialReductionIntrinsic expects an intrinisc");
+
+  VectorType *RetTy = dyn_cast<VectorType>(I->getType());
+  if (!RetTy || !RetTy->isScalableTy())
+    return true;
+
+  Value *InputA;
+  Value *InputB;
+  if (match(I, m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+                   m_Value(),
+                   m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+                                  m_OneUse(m_ZExtOrSExt(m_Value(InputB)))))))) {
+    VectorType *InputAType = dyn_cast<VectorType>(InputA->getType());
+    VectorType *InputBType = dyn_cast<VectorType>(InputB->getType());
+    if (!InputAType || !InputBType)
+      return true;
+    ElementCount ExpectedCount8 = ElementCount::get(8, RetTy->isScalableTy());
+    ElementCount ExpectedCount16 = ElementCount::get(16, RetTy->isScalableTy());
+    if ((RetTy->getScalarType()->isIntegerTy(64) &&
+         InputAType->getElementType()->isIntegerTy(16) &&
+         InputAType->getElementCount() == ExpectedCount8 &&
+         InputAType == InputBType) ||
+
----------------
SamTebbs33 wrote:

Done.

https://github.com/llvm/llvm-project/pull/101010


More information about the llvm-commits mailing list