[llvm] [CostModel][AArch64] Make extractelement, with fmul user, free whenev… (PR #111479)

Sushant Gokhale via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 10 23:18:00 PST 2024


================
@@ -3226,6 +3227,130 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
     // compile-time considerations.
   }
 
+  // In the case of Neon, if there exists an extractelement from lane != 0
+  // such that
+  // 1. the extractelement does not necessitate a move from vector_reg -> GPR,
+  // 2. the extractelement result feeds into an fmul, and
+  // 3. the other operand of the fmul is an extractelement from lane 0 or a
+  //    lane equivalent to 0,
+  // then the extractelement can be merged with the fmul in the backend and
+  // incurs no cost.
+  // e.g.
+  // define double @foo(<2 x double> %a) {
+  //   %1 = extractelement <2 x double> %a, i32 0
+  //   %2 = extractelement <2 x double> %a, i32 1
+  //   %res = fmul double %1, %2
+  //   ret double %res
+  // }
+  // %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
+  auto ExtractCanFuseWithFmul = [&]() {
+    // We bail out if the extract is from lane 0.
+    if (Index == 0)
+      return false;
+
+    // Check if the scalar element type of the vector operand of ExtractElement
+    // instruction is one of the allowed types.
+    auto IsAllowedScalarTy = [&](const Type *T) {
+      return T->isFloatTy() || T->isDoubleTy() ||
+             (T->isHalfTy() && ST->hasFullFP16());
+    };
+
+    // Check if the extractelement user is scalar fmul.
+    auto IsUserFMulScalarTy = [](const Value *EEUser) {
+      // Check if the user is scalar fmul.
+      const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+      return BO && BO->getOpcode() == BinaryOperator::FMul &&
+             !BO->getType()->isVectorTy();
+    };
+
+    // Check if the type constraints on input vector type and result scalar type
+    // of extractelement instruction are satisfied.
+    auto TypeConstraintsOnEESatisfied =
+        [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
+          return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
+        };
+
+    // Check if the extract index is from lane 0 or lane equivalent to 0 for a
+    // certain scalar type and a certain vector register width.
+    auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
+                                             const unsigned &EltSz) {
+      auto RegWidth =
+          getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedValue();
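+      // e.g. with 128-bit vector registers, lane 2 of <4 x double> (EltSz =
+      // 64) starts at bit 128, i.e. at lane 0 of the second 128-bit
+      // register, and so counts as equivalent to lane 0.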
+      return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+    };
+
+    if (Opcode.has_value()) {
+      if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
+        return false;
+
+      DenseMap<User *, unsigned> UserToExtractIdx;
+      for (auto *U : Scalar->users()) {
+        if (!IsUserFMulScalarTy(U))
+          return false;
+        // Recording an entry for the user is what matters here; the index
+        // value is not important yet and gets filled in below.
+        UserToExtractIdx[U];
+      }
+      for (auto &[S, U, L] : ScalarUserAndIdx) {
+        for (auto *U : S->users()) {
+          if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
+            auto *FMul = cast<BinaryOperator>(U);
+            auto *Op0 = FMul->getOperand(0);
+            auto *Op1 = FMul->getOperand(1);
+            if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
+              UserToExtractIdx[U] = L;
+              break;
+            }
+          }
+        }
+      }
+      for (auto &[U, L] : UserToExtractIdx) {
+        if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
+            !IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
+          return false;
+      }
+    } else {
+      const auto *EE = cast<ExtractElementInst>(I);
+
+      const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
+      if (!IdxOp)
+        return false;
+
+      if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
+                                        EE->getType()))
+        return false;
+
+      return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
+        if (!IsUserFMulScalarTy(U))
+          return false;
+
+        // Check if the other operand of extractelement is also extractelement
+        // from lane equivalent to 0.
+        const auto *BO = cast<BinaryOperator>(U);
+        const auto *OtherEE = dyn_cast<ExtractElementInst>(
+            BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
+        if (OtherEE) {
+          const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
+          if (!IdxOp)
+            return false;
+          return IsExtractLaneEquivalentToZero(
+              IdxOp->getZExtValue(),
+              OtherEE->getType()->getScalarSizeInBits());
+        }
+        return true;
+      });
+    }
+    return true;
----------------
sushgokh wrote:

It already returns false from all the other places if the requirements are not met.

Hence, the last statement returns true.

https://github.com/llvm/llvm-project/pull/111479
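
For context, here is a minimal standalone sketch (not part of the patch; the
function name and the hard-coded 128-bit register width are assumptions, the
real helper queries getRegisterBitWidth()) of the lane-equivalence arithmetic
the new code relies on: an extract from lane Idx can only be treated as a
lane-0 extract if the lane starts exactly at a vector-register boundary.

    #include <cstdio>

    // Hypothetical restatement of the check used in the patch: lane Idx of a
    // vector with EltSz-bit elements is "equivalent to lane 0" when the lane
    // begins at a register boundary, i.e. (Idx * EltSz) % RegWidth == 0.
    static bool isExtractLaneEquivalentToZero(unsigned Idx, unsigned EltSz,
                                              unsigned RegWidth = 128) {
      return Idx == 0 || (Idx * EltSz) % RegWidth == 0;
    }

    int main() {
      // <2 x double>: lane 1 starts at bit 64, not a register boundary -> 0.
      std::printf("%d\n", isExtractLaneEquivalentToZero(1, 64));
      // <4 x double> split across two 128-bit registers: lane 2 starts at
      // bit 128, i.e. lane 0 of the second register -> 1.
      std::printf("%d\n", isExtractLaneEquivalentToZero(2, 64));
      // <8 x float>: lane 4 likewise maps to lane 0 of the second register -> 1.
      std::printf("%d\n", isExtractLaneEquivalentToZero(4, 32));
      return 0;
    }

In the <2 x double> example from the patch comment, only the extract at index
1 that feeds the fmul is free, since it folds into the indexed form
(fmul d0, d0, v1.d[1]); the check on the other operand is what requires it to
come from lane 0 or a lane-0-equivalent position.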

