[llvm] [CostModel][AArch64] Make extractelement, with fmul user, free whenev… (PR #111479)
Sushant Gokhale via llvm-commits
llvm-commits at lists.llvm.org
Sun Nov 10 23:18:00 PST 2024
================
@@ -3226,6 +3227,130 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
// compile-time considerations.
}
+ // In case of Neon, if there exists extractelement from lane != 0 such that
+ // 1. extractelement does not necessitate a move from vector_reg -> GPR.
+ // 2. extractelement result feeds into fmul.
+ // 3. Other operand of fmul is an extractelement from lane 0 or lane
+ // equivalent to 0.
+ // then the extractelement can be merged with fmul in the backend and it
+ // incurs no cost.
+ // e.g.
+ // define double @foo(<2 x double> %a) {
+ // %1 = extractelement <2 x double> %a, i32 0
+ // %2 = extractelement <2 x double> %a, i32 1
+ // %res = fmul double %1, %2
+ // ret double %res
+ // }
+ // %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
+ auto ExtractCanFuseWithFmul = [&]() {
+ // We bail out if the extract is from lane 0.
+ if (Index == 0)
+ return false;
+
+ // Check if the scalar element type of the vector operand of ExtractElement
+ // instruction is one of the allowed types.
+ auto IsAllowedScalarTy = [&](const Type *T) {
+ return T->isFloatTy() || T->isDoubleTy() ||
+ (T->isHalfTy() && ST->hasFullFP16());
+ };
+
+ // Check if the extractelement user is scalar fmul.
+ auto IsUserFMulScalarTy = [](const Value *EEUser) {
+ // Check if the user is scalar fmul.
+ const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+ return BO && BO->getOpcode() == BinaryOperator::FMul &&
+ !BO->getType()->isVectorTy();
+ };
+
+ // Check if the type constraints on input vector type and result scalar type
+ // of extractelement instruction are satisfied.
+ auto TypeConstraintsOnEESatisfied =
+ [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
+ return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
+ };
+
+ // Check if the extract index is from lane 0 or lane equivalent to 0 for a
+ // certain scalar type and a certain vector register width.
+ auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
+ const unsigned &EltSz) {
+ auto RegWidth =
+ getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+ .getFixedValue();
+ return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+ };
+
+ if (Opcode.has_value()) {
+ if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
+ return false;
+
+ DenseMap<User *, unsigned> UserToExtractIdx;
+ for (auto *U : Scalar->users()) {
+ if (!IsUserFMulScalarTy(U))
+ return false;
+ // Recording entry for the user is important. Index value is not
+ // important.
+ UserToExtractIdx[U];
+ }
+ for (auto &[S, U, L] : ScalarUserAndIdx) {
+ for (auto *U : S->users()) {
+ if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
+ auto *FMul = cast<BinaryOperator>(U);
+ auto *Op0 = FMul->getOperand(0);
+ auto *Op1 = FMul->getOperand(1);
+ if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
+ UserToExtractIdx[U] = L;
+ break;
+ }
+ }
+ }
+ }
+ for (auto &[U, L] : UserToExtractIdx) {
+ if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
+ !IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
+ return false;
+ }
+ } else {
+ const auto *EE = cast<ExtractElementInst>(I);
+
+ const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
+ if (!IdxOp)
+ return false;
+
+ if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
+ EE->getType()))
+ return false;
+
+ return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
+ if (!IsUserFMulScalarTy(U))
+ return false;
+
+ // Check if the other operand of the fmul is also an extractelement
+ // from a lane equivalent to 0.
+ const auto *BO = cast<BinaryOperator>(U);
+ const auto *OtherEE = dyn_cast<ExtractElementInst>(
+ BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
+ if (OtherEE) {
+ const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
+ if (!IdxOp)
+ return false;
+ return IsExtractLaneEquivalentToZero(
+ cast<ConstantInt>(OtherEE->getIndexOperand())
+ ->getValue()
+ .getZExtValue(),
+ OtherEE->getType()->getScalarSizeInBits());
+ }
+ return true;
+ });
+ }
+ return true;
----------------
sushgokh wrote:
It's already false from all the other places if the requirements are not met.
Hence, the last statement returns true.
https://github.com/llvm/llvm-project/pull/111479
More information about the llvm-commits
mailing list