[llvm] [TLI] Pass replace-with-veclib works with Scalable Vectors. (PR #73642)

Maciej Gabka via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 20 06:34:28 PST 2023


================
@@ -38,138 +42,135 @@ STATISTIC(NumTLIFuncDeclAdded,
 STATISTIC(NumFuncUsedAdded,
           "Number of functions added to `llvm.compiler.used`");
 
-static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
-  Module *M = CI.getModule();
-
-  Function *OldFunc = CI.getCalledFunction();
-
-  // Check if the vector library function is already declared in this module,
-  // otherwise insert it.
+/// Returns the vector Function named \p TLIName, declaring it in the Module
+/// \p M first if it is not already there. When \p ScalarFunc is not null, its
+/// attributes are copied to the newly created Function.
+Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
+                         Function *ScalarFunc, const StringRef TLIName) {
   Function *TLIFunc = M->getFunction(TLIName);
   if (!TLIFunc) {
-    TLIFunc = Function::Create(OldFunc->getFunctionType(),
-                               Function::ExternalLinkage, TLIName, *M);
-    TLIFunc->copyAttributesFrom(OldFunc);
+    TLIFunc =
+        Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
+    if (ScalarFunc)
+      TLIFunc->copyAttributesFrom(ScalarFunc);
 
     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                       << TLIName << "` of type `" << *(TLIFunc->getType())
                       << "` to module.\n");
 
     ++NumTLIFuncDeclAdded;
-
-    // Add the freshly created function to llvm.compiler.used,
-    // similar to as it is done in InjectTLIMappings
+    // Add the freshly created function to llvm.compiler.used, similar to how it
+    // is done in InjectTLIMappings.
     appendToCompilerUsed(*M, {TLIFunc});
-
     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
                       << "` to `@llvm.compiler.used`.\n");
     ++NumFuncUsedAdded;
   }
+  return TLIFunc;
+}
 
-  // Replace the call to the vector intrinsic with a call
-  // to the corresponding function from the vector library.
-  IRBuilder<> IRBuilder(&CI);
-  SmallVector<Value *> Args(CI.args());
-  // Preserve the operand bundles.
-  SmallVector<OperandBundleDef, 1> OpBundles;
-  CI.getOperandBundlesAsDefs(OpBundles);
-  CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
-  assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
-         "Expecting function types to be identical");
-  CI.replaceAllUsesWith(Replacement);
-  if (isa<FPMathOperator>(Replacement)) {
-    // Preserve fast math flags for FP math.
-    Replacement->copyFastMathFlags(&CI);
+/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
+/// the corresponding function from the vector library ( \p TLIVecFunc ).
+static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
+                                   Function *TLIVecFunc) {
+  IRBuilder<> IRBuilder(&CalltoReplace);
+  SmallVector<Value *> Args(CalltoReplace.args());
+  if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
+    auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
+                                   Info.Shape.VF);
+    Args.insert(Args.begin() + OptMaskpos.value(),
+                Constant::getAllOnesValue(MaskTy));
   }
 
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-                    << OldFunc->getName() << "` with call to `" << TLIName
-                    << "`.\n");
-  ++NumCallsReplaced;
-  return true;
+  // Preserve the operand bundles.
+  SmallVector<OperandBundleDef, 1> OpBundles;
+  CalltoReplace.getOperandBundlesAsDefs(OpBundles);
+  CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
+  CalltoReplace.replaceAllUsesWith(Replacement);
+  // Preserve fast math flags for FP math.
+  if (isa<FPMathOperator>(Replacement))
+    Replacement->copyFastMathFlags(&CalltoReplace);
 }
 
+/// Returns true when successfully replaced \p CallToReplace with a suitable
+/// function taking vector arguments, based on available mappings in the \p TLI.
+/// Currently only works when \p CallToReplace is a call to a vectorized
+/// intrinsic.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
-                                    CallInst &CI) {
-  if (!CI.getCalledFunction()) {
+                                    CallInst &CallToReplace) {
+  if (!CallToReplace.getCalledFunction())
     return false;
-  }
 
-  auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
-  if (IntrinsicID == Intrinsic::not_intrinsic) {
-    // Replacement is only performed for intrinsic functions
+  auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
+  // Replacement is only performed for intrinsic functions.
+  if (IntrinsicID == Intrinsic::not_intrinsic)
     return false;
-  }
 
-  // Convert vector arguments to scalar type and check that
-  // all vector operands have identical vector width.
+  // Compute the argument types of the corresponding scalar call. Additionally,
+  // check that all vector operands of the vector call have the same EC.
   ElementCount VF = ElementCount::getFixed(0);
-  SmallVector<Type *> ScalarTypes;
-  for (auto Arg : enumerate(CI.args())) {
-    auto *ArgType = Arg.value()->getType();
-    // Vector calls to intrinsics can still have
-    // scalar operands for specific arguments.
+  SmallVector<Type *> ScalarArgTypes;
+  for (auto Arg : enumerate(CallToReplace.args())) {
+    auto *ArgTy = Arg.value()->getType();
     if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
-      ScalarTypes.push_back(ArgType);
-    } else {
-      // The argument in this place should be a vector if
-      // this is a call to a vector intrinsic.
-      auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
-      if (!VectorArgTy) {
-        // The argument is not a vector, do not perform
-        // the replacement.
-        return false;
-      }
-      ElementCount NumElements = VectorArgTy->getElementCount();
-      if (NumElements.isScalable()) {
-        // The current implementation does not support
-        // scalable vectors.
+      ScalarArgTypes.push_back(ArgTy);
+    } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
+      ScalarArgTypes.push_back(ArgTy->getScalarType());
+      // Disallow vector arguments with different VFs. When processing the first
+      // vector argument, store its VF, and for the rest ensure that they match
+      // it.
+      if (VF.isZero())
+        VF = VectorArgTy->getElementCount();
+      else if (VF != VectorArgTy->getElementCount())
         return false;
-      }
-      if (VF.isNonZero() && VF != NumElements) {
-        // The different arguments differ in vector size.
-        return false;
-      } else {
-        VF = NumElements;
-      }
-      ScalarTypes.push_back(VectorArgTy->getElementType());
-    }
+    } else
+      // Exit when it is supposed to be a vector argument but it isn't.
+      return false;
   }
 
-  // Try to reconstruct the name for the scalar version of this
-  // intrinsic using the intrinsic ID and the argument types
-  // converted to scalar above.
-  std::string ScalarName;
-  if (Intrinsic::isOverloaded(IntrinsicID)) {
-    ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
-  } else {
-    ScalarName = Intrinsic::getName(IntrinsicID).str();
-  }
+  // Try to reconstruct the name for the scalar version of this intrinsic using
+  // the intrinsic ID and the argument types converted to scalar above.
+  std::string ScalarName =
+      (Intrinsic::isOverloaded(IntrinsicID)
+           ? Intrinsic::getName(IntrinsicID, ScalarArgTypes,
+                                CallToReplace.getModule())
+           : Intrinsic::getName(IntrinsicID).str());
+
+  // Try to find the mapping for the scalar version of this intrinsic and the
+  // exact vector width of the call operands in the TargetLibraryInfo. First,
+  // check with a non-masked variant, and if that fails try with a masked one.
+  const VecDesc *VD = TLI.getVectorMappingInfo(ScalarName, VF, false);
+  if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, VF, true)))
----------------
mgabka wrote:

nit: annotate the boolean argument with its parameter name, i.e. `/*Masked=*/true`.

https://github.com/llvm/llvm-project/pull/73642


More information about the llvm-commits mailing list