[llvm] [AArch64] Lower partial add reduction to udot or svdot (PR #101010)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 29 07:46:12 PDT 2024
================
@@ -1971,6 +1971,57 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+ const CallInst *CI) const {
+ const bool TargetLowers = false;
+ const bool GenericLowers = true;
+
+ auto *I = dyn_cast<IntrinsicInst>(CI);
+ if (!I)
+ return GenericLowers;
+
+ ScalableVectorType *RetTy = dyn_cast<ScalableVectorType>(I->getType());
+
+ if (!RetTy)
+ return GenericLowers;
+
+ ScalableVectorType *InputTy = nullptr;
+
+ auto RetScalarTy = RetTy->getScalarType();
+ if (RetScalarTy->isIntegerTy(64)) {
+ InputTy = ScalableVectorType::get(Type::getInt16Ty(I->getContext()), 8);
+ } else if (RetScalarTy->isIntegerTy(32)) {
+ InputTy = ScalableVectorType::get(Type::getInt8Ty(I->getContext()), 16);
+ }
+
+ if (!InputTy)
+ return GenericLowers;
+
+ Value *InputA;
+ Value *InputB;
+
+ auto Pattern = m_Intrinsic<Intrinsic::experimental_vector_partial_reduce_add>(
+ m_Value(), m_OneUse(m_Mul(m_OneUse(m_ZExtOrSExt(m_Value(InputA))),
+ m_OneUse(m_ZExtOrSExt(m_Value(InputB))))));
----------------
sdesmalen-arm wrote:
Would it be worth making this check earlier (before checking the types) and then doing something like:
```
if (match(I, m_Intrinsic(...))) {
if ((I->getType()->isIntegerTy(64) && InputA->getType()->isIntegerTy(16)) ||
(I->getType()->isIntegerTy(32) && InputA->getType()->isIntegerTy(8))) {
auto *Mul = cast<Instruction>(I->getOperand(1));
if (Mul->getOperand(0)->getOpcode() == Mul->getOperand(1)->getOpcode())
return false;
}
}
return true;
```
That way you don't need to construct any explicit types just to match against them later.
https://github.com/llvm/llvm-project/pull/101010
More information about the llvm-commits
mailing list