[llvm] [TLI] Pass replace-with-veclib works with Scalable Vectors. (PR #73642)
Maciej Gabka via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 20 06:34:29 PST 2023
================
@@ -38,138 +42,135 @@ STATISTIC(NumTLIFuncDeclAdded,
STATISTIC(NumFuncUsedAdded,
"Number of functions added to `llvm.compiler.used`");
-static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
- Module *M = CI.getModule();
-
- Function *OldFunc = CI.getCalledFunction();
-
- // Check if the vector library function is already declared in this module,
- // otherwise insert it.
+/// Returns a vector Function that it adds to the Module \p M. When an \p
+/// ScalarFunc is not null, it copies its attributes to the newly created
+/// Function.
+Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
+ Function *ScalarFunc, const StringRef TLIName) {
Function *TLIFunc = M->getFunction(TLIName);
if (!TLIFunc) {
- TLIFunc = Function::Create(OldFunc->getFunctionType(),
- Function::ExternalLinkage, TLIName, *M);
- TLIFunc->copyAttributesFrom(OldFunc);
+ TLIFunc =
+ Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
+ if (ScalarFunc)
+ TLIFunc->copyAttributesFrom(ScalarFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
<< TLIName << "` of type `" << *(TLIFunc->getType())
<< "` to module.\n");
++NumTLIFuncDeclAdded;
-
- // Add the freshly created function to llvm.compiler.used,
- // similar to as it is done in InjectTLIMappings
+ // Add the freshly created function to llvm.compiler.used, similar to as it
+ // is done in InjectTLIMappings.
appendToCompilerUsed(*M, {TLIFunc});
-
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
<< "` to `@llvm.compiler.used`.\n");
++NumFuncUsedAdded;
}
+ return TLIFunc;
+}
- // Replace the call to the vector intrinsic with a call
- // to the corresponding function from the vector library.
- IRBuilder<> IRBuilder(&CI);
- SmallVector<Value *> Args(CI.args());
- // Preserve the operand bundles.
- SmallVector<OperandBundleDef, 1> OpBundles;
- CI.getOperandBundlesAsDefs(OpBundles);
- CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
- assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
- "Expecting function types to be identical");
- CI.replaceAllUsesWith(Replacement);
- if (isa<FPMathOperator>(Replacement)) {
- // Preserve fast math flags for FP math.
- Replacement->copyFastMathFlags(&CI);
+/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
+/// the corresponding function from the vector library ( \p TLIVecFunc ).
+static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
+ Function *TLIVecFunc) {
+ IRBuilder<> IRBuilder(&CalltoReplace);
+ SmallVector<Value *> Args(CalltoReplace.args());
+ if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
+ auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
+ Info.Shape.VF);
+ Args.insert(Args.begin() + OptMaskpos.value(),
+ Constant::getAllOnesValue(MaskTy));
}
- LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
- << OldFunc->getName() << "` with call to `" << TLIName
- << "`.\n");
- ++NumCallsReplaced;
- return true;
+ // Preserve the operand bundles.
+ SmallVector<OperandBundleDef, 1> OpBundles;
+ CalltoReplace.getOperandBundlesAsDefs(OpBundles);
+ CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
+ CalltoReplace.replaceAllUsesWith(Replacement);
+ // Preserve fast math flags for FP math.
+ if (isa<FPMathOperator>(Replacement))
+ Replacement->copyFastMathFlags(&CalltoReplace);
}
+/// Returns true when successfully replaced \p CallToReplace with a suitable
+/// function taking vector arguments, based on available mappings in the \p TLI.
+/// Currently only works when \p CallToReplace is a call to vectorized
+/// intrinsic.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
- CallInst &CI) {
- if (!CI.getCalledFunction()) {
+ CallInst &CallToReplace) {
+ if (!CallToReplace.getCalledFunction())
return false;
- }
- auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
- if (IntrinsicID == Intrinsic::not_intrinsic) {
- // Replacement is only performed for intrinsic functions
+ auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
+ // Replacement is only performed for intrinsic functions.
+ if (IntrinsicID == Intrinsic::not_intrinsic)
return false;
- }
- // Convert vector arguments to scalar type and check that
- // all vector operands have identical vector width.
+ // Compute arguments types of the corresponding scalar call. Additionally
+ // checks if in the vector call, all vector operands have the same EC.
ElementCount VF = ElementCount::getFixed(0);
- SmallVector<Type *> ScalarTypes;
- for (auto Arg : enumerate(CI.args())) {
- auto *ArgType = Arg.value()->getType();
- // Vector calls to intrinsics can still have
- // scalar operands for specific arguments.
+ SmallVector<Type *> ScalarArgTypes;
+ for (auto Arg : enumerate(CallToReplace.args())) {
+ auto *ArgTy = Arg.value()->getType();
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
- ScalarTypes.push_back(ArgType);
- } else {
- // The argument in this place should be a vector if
- // this is a call to a vector intrinsic.
- auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
- if (!VectorArgTy) {
- // The argument is not a vector, do not perform
- // the replacement.
- return false;
- }
- ElementCount NumElements = VectorArgTy->getElementCount();
- if (NumElements.isScalable()) {
- // The current implementation does not support
- // scalable vectors.
+ ScalarArgTypes.push_back(ArgTy);
+ } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
+ ScalarArgTypes.push_back(ArgTy->getScalarType());
+ // Disallow vector arguments with different VFs. When processing the first
+ // vector argument, store its VF, and for the rest ensure that they match
+ // it.
+ if (VF.isZero())
+ VF = VectorArgTy->getElementCount();
+ else if (VF != VectorArgTy->getElementCount())
return false;
- }
- if (VF.isNonZero() && VF != NumElements) {
- // The different arguments differ in vector size.
- return false;
- } else {
- VF = NumElements;
- }
- ScalarTypes.push_back(VectorArgTy->getElementType());
- }
+ } else
+ // Exit when it is supposed to be a vector argument but it isn't.
+ return false;
}
- // Try to reconstruct the name for the scalar version of this
- // intrinsic using the intrinsic ID and the argument types
- // converted to scalar above.
- std::string ScalarName;
- if (Intrinsic::isOverloaded(IntrinsicID)) {
- ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
- } else {
- ScalarName = Intrinsic::getName(IntrinsicID).str();
- }
+ // Try to reconstruct the name for the scalar version of this intrinsic using
+ // the intrinsic ID and the argument types converted to scalar above.
+ std::string ScalarName =
+ (Intrinsic::isOverloaded(IntrinsicID)
+ ? Intrinsic::getName(IntrinsicID, ScalarArgTypes,
+ CallToReplace.getModule())
+ : Intrinsic::getName(IntrinsicID).str());
+
+ // Try to find the mapping for the scalar version of this intrinsic and the
+ // exact vector width of the call operands in the TargetLibraryInfo. First,
+ // check with a non-masked variant, and if that fails try with a masked one.
+ const VecDesc *VD = TLI.getVectorMappingInfo(ScalarName, VF, false);
----------------
mgabka wrote:
nit: please annotate the bare boolean argument with its parameter name, i.e. `/*Masked=*/false`.
https://github.com/llvm/llvm-project/pull/73642
More information about the llvm-commits mailing list