[llvm] [TLI] replace-with-veclib works with FRem Instruction. (PR #76166)
Maciej Gabka via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 4 12:02:37 PST 2024
================
@@ -69,88 +69,97 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
return TLIFunc;
}
-/// Replace the call to the vector intrinsic ( \p CalltoReplace ) with a call to
-/// the corresponding function from the vector library ( \p TLIVecFunc ).
-static void replaceWithTLIFunction(CallInst &CalltoReplace, VFInfo &Info,
+/// Replace the instruction \p I with a call to the corresponding function from
+/// the vector library ( \p TLIVecFunc ).
+static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Function *TLIVecFunc) {
- IRBuilder<> IRBuilder(&CalltoReplace);
- SmallVector<Value *> Args(CalltoReplace.args());
+ IRBuilder<> IRBuilder(&I);
+ auto *CI = dyn_cast<CallInst>(&I);
+ SmallVector<Value *> Args(CI ? CI->args() : I.operands());
if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
- auto *MaskTy = VectorType::get(Type::getInt1Ty(CalltoReplace.getContext()),
- Info.Shape.VF);
+ auto *MaskTy =
+ VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
Args.insert(Args.begin() + OptMaskpos.value(),
Constant::getAllOnesValue(MaskTy));
}
- // Preserve the operand bundles.
+ // If it is a call instruction, preserve the operand bundles.
SmallVector<OperandBundleDef, 1> OpBundles;
- CalltoReplace.getOperandBundlesAsDefs(OpBundles);
- CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
- CalltoReplace.replaceAllUsesWith(Replacement);
+ if (CI)
+ CI->getOperandBundlesAsDefs(OpBundles);
+
+ auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
+ I.replaceAllUsesWith(Replacement);
// Preserve fast math flags for FP math.
if (isa<FPMathOperator>(Replacement))
- Replacement->copyFastMathFlags(&CalltoReplace);
+ Replacement->copyFastMathFlags(&I);
}
-/// Returns true when successfully replaced \p CallToReplace with a suitable
-/// function taking vector arguments, based on available mappings in the \p TLI.
-/// Currently only works when \p CallToReplace is a call to vectorized
-/// intrinsic.
+/// Returns true when successfully replaced \p I with a suitable function taking
+/// vector arguments, based on available mappings in the \p TLI. Currently only
+/// works when \p I is a call to vectorized intrinsic or the frem instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
- CallInst &CallToReplace) {
- if (!CallToReplace.getCalledFunction())
- return false;
-
- auto IntrinsicID = CallToReplace.getCalledFunction()->getIntrinsicID();
- // Replacement is only performed for intrinsic functions.
- if (IntrinsicID == Intrinsic::not_intrinsic)
- return false;
-
- // Compute arguments types of the corresponding scalar call. Additionally
- // checks if in the vector call, all vector operands have the same EC.
- ElementCount VF = ElementCount::getFixed(0);
- SmallVector<Type *> ScalarArgTypes;
- for (auto Arg : enumerate(CallToReplace.args())) {
- auto *ArgTy = Arg.value()->getType();
- if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
- ScalarArgTypes.push_back(ArgTy);
- } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
- ScalarArgTypes.push_back(ArgTy->getScalarType());
- // Disallow vector arguments with different VFs. When processing the first
- // vector argument, store it's VF, and for the rest ensure that they match
- // it.
- if (VF.isZero())
- VF = VectorArgTy->getElementCount();
- else if (VF != VectorArgTy->getElementCount())
+ Instruction &I) {
+ std::string ScalarName;
+ ElementCount EC = ElementCount::getFixed(0);
+ Function *FuncToReplace = nullptr;
+ SmallVector<Type *, 8> ScalarArgTypes;
+ // Compute the argument types of the corresponding scalar call, the scalar
+ // function name, and EC. For calls, it additionally checks if in the vector
+ // call, all vector operands have the same EC.
+ if (auto *CI = dyn_cast<CallInst>(&I)) {
+ Intrinsic::ID IID = CI->getCalledFunction()->getIntrinsicID();
+ assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
+ FuncToReplace = CI->getCalledFunction();
+ for (auto Arg : enumerate(CI->args())) {
+ auto *ArgTy = Arg.value()->getType();
+ if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
+ ScalarArgTypes.push_back(ArgTy);
+ } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
+ ScalarArgTypes.push_back(VectorArgTy->getElementType());
+ // Disallow vector arguments with different VFs. When processing the
+ // first vector argument, store it's VF, and for the rest ensure that
+ // they match it.
+ if (EC.isZero())
+ EC = VectorArgTy->getElementCount();
+ else if (EC != VectorArgTy->getElementCount())
+ return false;
+ } else
+ // Exit when it is supposed to be a vector argument but it isn't.
return false;
- } else
- // Exit when it is supposed to be a vector argument but it isn't.
+ }
+ // Try to reconstruct the name for the scalar version of the instruction,
+ // using scalar argument types.
+ ScalarName = Intrinsic::isOverloaded(IID)
+ ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
+ : Intrinsic::getName(IID).str();
+ } else {
+ assert(I.getType()->isVectorTy() && "Instruction must use vectors");
+ LibFunc Func;
+ auto *ScalarTy = I.getType()->getScalarType();
+ if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
return false;
+ ScalarName = TLI.getName(Func);
+ ScalarArgTypes = {ScalarTy, ScalarTy};
+ if (auto *VTy = dyn_cast<VectorType>(I.getType()))
----------------
mgabka wrote:
To me it feels like this could be done before the assert, i.e.:
auto *VTy = dyn_cast<VectorType>(I.getType());
assert(VTy && "Instruction must use vectors");
Then you can use VTy everywhere instead of calling getType several times.
It will also remove the confusing "if (auto *VTy = dyn_cast<VectorType>(I.getType()))" without an else; in theory the missing else case should never happen, as isSupportedInstruction ensures that we only get here with the vector version of frem.
https://github.com/llvm/llvm-project/pull/76166
More information about the llvm-commits
mailing list