[llvm] [TLI] replace-with-veclib works with FRem Instruction. (PR #76166)

Maciej Gabka via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 4 12:02:38 PST 2024


================
@@ -69,88 +69,97 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
   return TLIFunc;
 }
 
-/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
-/// the corresponding function from the vector library ( \p TLIVecFunc ).
-static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
+/// Replace the instruction \p I with a call to the corresponding function from
+/// the vector library ( \p TLIVecFunc ).
+static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
                                    Function *TLIVecFunc) {
-  IRBuilder<> IRBuilder(&CalltoReplace);
-  SmallVector<Value *> Args(CalltoReplace.args());
+  IRBuilder<> IRBuilder(&I);
+  auto *CI = dyn_cast<CallInst>(&I);
+  SmallVector<Value *> Args(CI ? CI->args() : I.operands());
   if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
-    auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
-                                   Info.Shape.VF);
+    auto *MaskTy =
+        VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
     Args.insert(Args.begin() + OptMaskpos.value(),
                 Constant::getAllOnesValue(MaskTy));
   }
 
-  // Preserve the operand bundles.
+  // If it is a call instruction, preserve the operand bundles.
   SmallVector<OperandBundleDef, 1> OpBundles;
-  CalltoReplace.getOperandBundlesAsDefs(OpBundles);
-  CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
-  CalltoReplace.replaceAllUsesWith(Replacement);
+  if (CI)
+    CI->getOperandBundlesAsDefs(OpBundles);
+
+  auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
+  I.replaceAllUsesWith(Replacement);
   // Preserve fast math flags for FP math.
   if (isa<FPMathOperator>(Replacement))
-    Replacement->copyFastMathFlags(&CalltoReplace);
+    Replacement->copyFastMathFlags(&I);
 }
 
-/// Returns true when successfully replaced \p CallToReplace with a suitable
-/// function taking vector arguments, based on available mappings in the \p TLI.
-/// Currently only works when \p CallToReplace is a call to vectorized
-/// intrinsic.
+/// Returns true when successfully replaced \p I with a suitable function taking
+/// vector arguments, based on available mappings in the \p TLI. Currently only
+/// works when \p I is a call to vectorized intrinsic or the frem instruction.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
-                                    CallInst &CallToReplace) {
-  if (!CallToReplace.getCalledFunction())
-    return false;
-
-  auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
-  // Replacement is only performed for intrinsic functions.
-  if (IntrinsicID == Intrinsic::not_intrinsic)
-    return false;
-
-  // Compute arguments types of the corresponding scalar call. Additionally
-  // checks if in the vector call, all vector operands have the same EC.
-  ElementCount VF = ElementCount::getFixed(0);
-  SmallVector<Type *> ScalarArgTypes;
-  for (auto Arg : enumerate(CallToReplace.args())) {
-    auto *ArgTy = Arg.value()->getType();
-    if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
-      ScalarArgTypes.push_back(ArgTy);
-    } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
-      ScalarArgTypes.push_back(ArgTy->getScalarType());
-      // Disallow vector arguments with different VFs. When processing the first
-      // vector argument, store it's VF, and for the rest ensure that they match
-      // it.
-      if (VF.isZero())
-        VF = VectorArgTy->getElementCount();
-      else if (VF != VectorArgTy->getElementCount())
+                                    Instruction &I) {
+  std::string ScalarName;
+  ElementCount EC = ElementCount::getFixed(0);
+  Function *FuncToReplace = nullptr;
+  SmallVector<Type *, 8> ScalarArgTypes;
+  // Compute the argument types of the corresponding scalar call, the scalar
+  // function name, and EC. For calls, it additionally checks if in the vector
+  // call, all vector operands have the same EC.
+  if (auto *CI = dyn_cast<CallInst>(&I)) {
+    Intrinsic::ID IID = CI->getCalledFunction()->getIntrinsicID();
+    assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
+    FuncToReplace = CI->getCalledFunction();
+    for (auto Arg : enumerate(CI->args())) {
----------------
mgabka wrote:

It just occurred to me that this loop checks only the input arguments; for intrinsics we never check the element count (EC) of the return type. Would it be worth initialising EC from the return type's element count instead of doing:

"ElementCount EC = ElementCount::getFixed(0);"

I believe we don't support intrinsics returning a scalar value or void at the moment, so the return type should always be a vector with a well-defined EC.

https://github.com/llvm/llvm-project/pull/76166


More information about the llvm-commits mailing list