[llvm] [TLI] Pass replace-with-veclib works with Scalable Vectors. (PR #73642)

Maciej Gabka via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 13 09:06:48 PST 2023


================
@@ -38,138 +42,166 @@ STATISTIC(NumTLIFuncDeclAdded,
 STATISTIC(NumFuncUsedAdded,
           "Number of functions added to `llvm.compiler.used`");
 
-static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
-  Module *M = CI.getModule();
-
-  Function *OldFunc = CI.getCalledFunction();
-
-  // Check if the vector library function is already declared in this module,
-  // otherwise insert it.
+/// Returns a vector Function that it adds to the Module \p M. When an \p
+/// OptOldFunc is given, it copies its attributes to the newly created Function.
+Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
+                         std::optional<Function *> OptOldFunc,
+                         const StringRef TLIName) {
   Function *TLIFunc = M->getFunction(TLIName);
   if (!TLIFunc) {
-    TLIFunc = Function::Create(OldFunc->getFunctionType(),
-                               Function::ExternalLinkage, TLIName, *M);
-    TLIFunc->copyAttributesFrom(OldFunc);
+    TLIFunc =
+        Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
+    if (OptOldFunc)
+      TLIFunc->copyAttributesFrom(*OptOldFunc);
 
     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                       << TLIName << "` of type `" << *(TLIFunc->getType())
                       << "` to module.\n");
 
     ++NumTLIFuncDeclAdded;
-
-    // Add the freshly created function to llvm.compiler.used,
-    // similar to as it is done in InjectTLIMappings
+    // Add the freshly created function to llvm.compiler.used, similar to as it
+    // is done in InjectTLIMappings
     appendToCompilerUsed(*M, {TLIFunc});
-
     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
                       << "` to `@llvm.compiler.used`.\n");
     ++NumFuncUsedAdded;
   }
+  return TLIFunc;
+}
 
-  // Replace the call to the vector intrinsic with a call
-  // to the corresponding function from the vector library.
+/// Replace the call to the vector intrinsic ( \p OldFunc ) with a call to the
+/// corresponding function from the vector library ( \p TLIFunc ).
+static bool replaceWithTLIFunction(const Module *M, CallInst &CI,
+                                   const ElementCount &VecVF, Function *OldFunc,
+                                   Function *TLIFunc, FunctionType *VecFTy,
+                                   bool IsMasked) {
   IRBuilder<> IRBuilder(&CI);
   SmallVector<Value *> Args(CI.args());
+  if (IsMasked) {
+    if (Args.size() == VecFTy->getNumParams())
+      static_assert(true && "mask was already in place");
+
+    auto *MaskTy = VectorType::get(Type::getInt1Ty(M->getContext()), VecVF);
+    Args.push_back(Constant::getAllOnesValue(MaskTy));
+  }
+
   // Preserve the operand bundles.
   SmallVector<OperandBundleDef, 1> OpBundles;
   CI.getOperandBundlesAsDefs(OpBundles);
   CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
-  assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
+  assert(VecFTy == TLIFunc->getFunctionType() &&
          "Expecting function types to be identical");
   CI.replaceAllUsesWith(Replacement);
-  if (isa<FPMathOperator>(Replacement)) {
-    // Preserve fast math flags for FP math.
+  // Preserve fast math flags for FP math.
+  if (isa<FPMathOperator>(Replacement))
     Replacement->copyFastMathFlags(&CI);
-  }
 
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-                    << OldFunc->getName() << "` with call to `" << TLIName
-                    << "`.\n");
+                    << OldFunc->getName() << "` with call to `"
+                    << TLIFunc->getName() << "`.\n");
   ++NumCallsReplaced;
   return true;
 }
 
+/// Utility method to get the VecDesc, depending on whether there is a TLI
+/// mapping, either with or without a mask.
+static std::optional<const VecDesc *> getVecDesc(const TargetLibraryInfo &TLI,
+                                                 const StringRef &ScalarName,
+                                                 const ElementCount &VF) {
+  const VecDesc *VDMasked = TLI.getVectorMappingInfo(ScalarName, VF, true);
+  const VecDesc *VDNoMask = TLI.getVectorMappingInfo(ScalarName, VF, false);
+  // Invalid when there are both variants (ie masked and unmasked), or none
+  if ((VDMasked == nullptr) == (VDNoMask == nullptr))
+    return std::nullopt;
+
+  return {VDMasked != nullptr ? VDMasked : VDNoMask};
+}
+
+/// Returns whether it is able to replace a call to the intrinsic \p CI with a
+/// TLI mapped call.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
                                     CallInst &CI) {
-  if (!CI.getCalledFunction()) {
+  if (!CI.getCalledFunction())
     return false;
-  }
 
   auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
-  if (IntrinsicID == Intrinsic::not_intrinsic) {
-    // Replacement is only performed for intrinsic functions
+  // Replacement is only performed for intrinsic functions
+  if (IntrinsicID == Intrinsic::not_intrinsic)
     return false;
-  }
 
-  // Convert vector arguments to scalar type and check that
-  // all vector operands have identical vector width.
+  // Convert vector arguments to scalar type and check that all vector operands
+  // have identical vector width.
   ElementCount VF = ElementCount::getFixed(0);
   SmallVector<Type *> ScalarTypes;
   for (auto Arg : enumerate(CI.args())) {
-    auto *ArgType = Arg.value()->getType();
-    // Vector calls to intrinsics can still have
-    // scalar operands for specific arguments.
+    auto *ArgTy = Arg.value()->getType();
     if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
-      ScalarTypes.push_back(ArgType);
-    } else {
-      // The argument in this place should be a vector if
-      // this is a call to a vector intrinsic.
-      auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
-      if (!VectorArgTy) {
-        // The argument is not a vector, do not perform
-        // the replacement.
-        return false;
-      }
-      ElementCount NumElements = VectorArgTy->getElementCount();
-      if (NumElements.isScalable()) {
-        // The current implementation does not support
-        // scalable vectors.
+      ScalarTypes.push_back(ArgTy);
+    } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
+      ScalarTypes.push_back(ArgTy->getScalarType());
+      // Disallow vector arguments with different VFs. When processing the first
+      // vector argument, store its VF, and for the rest ensure that they match
+      // it.
+      if (VF.isZero())
+        VF = VectorArgTy->getElementCount();
+      else if (VF != VectorArgTy->getElementCount())
         return false;
-      }
-      if (VF.isNonZero() && VF != NumElements) {
-        // The different arguments differ in vector size.
-        return false;
-      } else {
-        VF = NumElements;
-      }
-      ScalarTypes.push_back(VectorArgTy->getElementType());
+    } else {
+      // enters when it is supposed to be a vector argument but it isn't.
+      return false;
     }
   }
 
-  // Try to reconstruct the name for the scalar version of this
-  // intrinsic using the intrinsic ID and the argument types
-  // converted to scalar above.
-  std::string ScalarName;
-  if (Intrinsic::isOverloaded(IntrinsicID)) {
-    ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
-  } else {
-    ScalarName = Intrinsic::getName(IntrinsicID).str();
-  }
+  // Try to reconstruct the name for the scalar version of this intrinsic using
+  // the intrinsic ID and the argument types converted to scalar above.
+  std::string ScalarName =
+      (Intrinsic::isOverloaded(IntrinsicID)
+           ? Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule())
+           : Intrinsic::getName(IntrinsicID).str());
 
-  if (!TLI.isFunctionVectorizable(ScalarName)) {
-    // The TargetLibraryInfo does not contain a vectorized version of
-    // the scalar function.
+  // The TargetLibraryInfo does not contain a vectorized version of the scalar
+  // function.
+  if (!TLI.isFunctionVectorizable(ScalarName))
     return false;
-  }
 
-  // Try to find the mapping for the scalar version of this intrinsic
-  // and the exact vector width of the call operands in the
-  // TargetLibraryInfo.
-  StringRef TLIName = TLI.getVectorizedFunction(ScalarName, VF);
+  auto OptVD = getVecDesc(TLI, ScalarName, VF);
+  if (!OptVD)
+    return false;
 
+  const VecDesc *VD = *OptVD;
+  // Try to find the mapping for the scalar version of this intrinsic and the
+  // exact vector width of the call operands in the TargetLibraryInfo.
+  StringRef TLIName = TLI.getVectorizedFunction(ScalarName, VF, VD->isMasked());
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
                     << ScalarName << "` and vector width " << VF << ".\n");
 
-  if (!TLIName.empty()) {
-    // Found the correct mapping in the TargetLibraryInfo,
-    // replace the call to the intrinsic with a call to
-    // the vector library function.
-    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
-                      << "`.\n");
-    return replaceWithTLIFunction(CI, TLIName);
-  }
+  // TLI failed to find a correct mapping.
+  if (TLIName.empty())
+    return false;
 
-  return false;
+  // Find the vector Function and replace the call to the intrinsic with a call
+  // to the vector library function.
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
+                    << "`.\n");
+
+  Type *ScalarRetTy = CI.getType()->getScalarType();
+  FunctionType *ScalarFTy = FunctionType::get(ScalarRetTy, ScalarTypes, false);
+  const std::string MangledName = VD->getVectorFunctionABIVariantString();
+  auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
+  if (!OptInfo)
+    return false;
+
+  // get the vector FunctionType
+  Module *M = CI.getModule();
+  auto OptFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
+  if (!OptFTy)
+    return false;
+
+  Function *OldFunc = CI.getCalledFunction();
+  FunctionType *VectorFTy = *OptFTy;
+  Function *TLIFunc = getTLIFunction(M, VectorFTy, OldFunc, TLIName);
+  return replaceWithTLIFunction(M, CI, OptInfo->Shape.VF, OldFunc, TLIFunc,
+                                VectorFTy, VD->isMasked());
----------------
mgabka wrote:

I think we should use OptInfo->isMasked() here, since OptInfo is what holds the demangled information.

It must be consistent with what you previously received from TLI.getVectorMappingInfo, but it feels more logical to use OptInfo here, so that we have a single source of truth for the vector function's properties.


It also feels like you could pass the entire OptInfo here, as it has everything you need: the VF, the vector name, and the mask.

https://github.com/llvm/llvm-project/pull/73642


More information about the llvm-commits mailing list