[llvm] [TLI] replace-with-veclib works with FRem Instruction. (PR #76166)

Maciej Gabka via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 28 04:50:58 PST 2023


================
@@ -69,73 +70,83 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
   return TLIFunc;
 }
 
-/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
-/// the corresponding function from the vector library ( \p TLIVecFunc ).
-static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
+/// Replace the Instruction \p I with a call to the corresponding function from
+/// the vector library ( \p TLIVecFunc ).
+static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
                                    Function *TLIVecFunc) {
-  IRBuilder<> IRBuilder(&CalltoReplace);
-  SmallVector<Value *> Args(CalltoReplace.args());
+  IRBuilder<> IRBuilder(&I);
+  auto *CI = dyn_cast<CallInst>(&I);
+  SmallVector<Value *> Args(CI ? CI->args() : I.operands());
   if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
-    auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
-                                   Info.Shape.VF);
+    auto *MaskTy =
+        VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
     Args.insert(Args.begin() + OptMaskpos.value(),
                 Constant::getAllOnesValue(MaskTy));
   }
 
-  // Preserve the operand bundles.
+  // Preserve the operand bundles for CallInsts.
   SmallVector<OperandBundleDef, 1> OpBundles;
-  CalltoReplace.getOperandBundlesAsDefs(OpBundles);
+  if (CI)
+    CI->getOperandBundlesAsDefs(OpBundles);
+
   CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
-  CalltoReplace.replaceAllUsesWith(Replacement);
+  I.replaceAllUsesWith(Replacement);
   // Preserve fast math flags for FP math.
   if (isa<FPMathOperator>(Replacement))
-    Replacement->copyFastMathFlags(&CalltoReplace);
+    Replacement->copyFastMathFlags(&I);
 }
 
-/// Returns true when successfully replaced \p CallToReplace with a suitable
-/// function taking vector arguments, based on available mappings in the \p TLI.
-/// Currently only works when \p CallToReplace is a call to vectorized
-/// intrinsic.
+/// Returns true when successfully replaced \p I with a suitable function taking
+/// vector arguments, based on available mappings in the \p TLI. Currently only
+/// works when \p I is a call to vectorized intrinsic or the frem Instruction.
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
-                                    CallInst &CallToReplace) {
-  if (!CallToReplace.getCalledFunction())
-    return false;
-
-  auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
-  // Replacement is only performed for intrinsic functions.
-  if (IntrinsicID == Intrinsic::not_intrinsic)
-    return false;
-
-  // Compute arguments types of the corresponding scalar call. Additionally
-  // checks if in the vector call, all vector operands have the same EC.
+                                    Instruction &I) {
+  std::string ScalarName;
   ElementCount VF = ElementCount::getFixed(0);
-  SmallVector<Type *> ScalarArgTypes;
-  for (auto Arg : enumerate(CallToReplace.args())) {
-    auto *ArgTy = Arg.value()->getType();
-    if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
-      ScalarArgTypes.push_back(ArgTy);
-    } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
-      ScalarArgTypes.push_back(ArgTy->getScalarType());
-      // Disallow vector arguments with different VFs. When processing the first
-      // vector argument, store it's VF, and for the rest ensure that they match
-      // it.
-      if (VF.isZero())
-        VF = VectorArgTy->getElementCount();
-      else if (VF != VectorArgTy->getElementCount())
-        return false;
-    } else
-      // Exit when it is supposed to be a vector argument but it isn't.
+  CallInst *CI = dyn_cast<CallInst>(&I);
+  SmallVector<Type *, 8> ScalarArgTypes;
+  if (CI) {
+    Intrinsic::ID IID = Intrinsic::not_intrinsic;
+    IID = CI->getCalledFunction()->getIntrinsicID();
+    // Compute arguments types of the corresponding scalar call. Additionally
+    // checks if in the vector call, all vector operands have the same EC.
+    for (auto Arg : enumerate(CI ? CI->args() : I.operands())) {
+      auto *ArgTy = Arg.value()->getType();
+      if (CI && isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+        ScalarArgTypes.push_back(ArgTy);
+      } else {
+        auto *VectorArgTy = dyn_cast<VectorType>(ArgTy);
+        // We are expecting only VectorTypes, as:
+        // - with a CallInst, scalar operands are handled earlier
+        // - with the frem Instruction, both operands must be vectors.
+        if (!VectorArgTy)
+          return false;
+        ScalarArgTypes.push_back(ArgTy->getScalarType());
----------------
mgabka wrote:

Using `VectorArgTy->getElementType()` is better here: you already know this is a vector type, so it avoids the extra check that `getScalarType()` performs.

https://github.com/llvm/llvm-project/pull/76166


More information about the llvm-commits mailing list