[llvm] AMDGPU: Verify function type matches when matching libcalls (PR #119043)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 6 15:54:40 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-amdgpu
Author: Matt Arsenault (arsenm)
<details>
<summary>Changes</summary>
Previously, a call to a mangled ldexp(float, float) would be recognized
as a candidate for replacement with the ldexp intrinsic. We need to verify
that the second parameter is in fact an integer.
Fixes: SWDEV-501389
---
Full diff: https://github.com/llvm/llvm-project/pull/119043.diff
4 Files Affected:
- (modified) llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp (+1-1)
- (modified) llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp (+36-6)
- (modified) llvm/lib/Target/AMDGPU/AMDGPULibFunc.h (+21-5)
- (modified) llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll (+41)
``````````diff
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index f2c0be76b771b5..cf8b416d23e50d 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -654,7 +654,7 @@ bool AMDGPULibCalls::fold(CallInst *CI) {
// Further check the number of arguments to see if they match.
// TODO: Check calling convention matches too
- if (!FInfo.isCompatibleSignature(CI->getFunctionType()))
+ if (!FInfo.isCompatibleSignature(*Callee->getParent(), CI->getFunctionType()))
return false;
LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n');
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp
index 4c596e37476c4e..c23472b147bcef 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp
@@ -969,7 +969,7 @@ static Type* getIntrinsicParamType(
return T;
}
-FunctionType *AMDGPUMangledLibFunc::getFunctionType(Module &M) const {
+FunctionType *AMDGPUMangledLibFunc::getFunctionType(const Module &M) const {
LLVMContext& C = M.getContext();
std::vector<Type*> Args;
ParamIterator I(Leads, manglingRules[FuncId]);
@@ -997,9 +997,39 @@ std::string AMDGPUMangledLibFunc::getName() const {
return std::string(OS.str());
}
-bool AMDGPULibFunc::isCompatibleSignature(const FunctionType *FuncTy) const {
- // TODO: Validate types make sense
- return !FuncTy->isVarArg() && FuncTy->getNumParams() == getNumArgs();
+bool AMDGPULibFunc::isCompatibleSignature(const Module &M,
+ const FunctionType *CallTy) const {
+ const FunctionType *FuncTy = getFunctionType(M);
+
+ // FIXME: UnmangledFuncInfo does not have any type information other than the
+ // number of arguments.
+ if (!FuncTy)
+ return getNumArgs() == CallTy->getNumParams();
+
+ // Normally the types should exactly match.
+ if (FuncTy == CallTy)
+ return true;
+
+ const unsigned NumParams = FuncTy->getNumParams();
+ if (NumParams != CallTy->getNumParams())
+ return false;
+
+ for (unsigned I = 0; I != NumParams; ++I) {
+ Type *FuncArgTy = FuncTy->getParamType(I);
+ Type *CallArgTy = CallTy->getParamType(I);
+ if (FuncArgTy == CallArgTy)
+ continue;
+
+ // Some cases permit implicit splatting a scalar value to a vector argument.
+ auto *FuncVecTy = dyn_cast<VectorType>(FuncArgTy);
+ if (FuncVecTy && FuncVecTy->getElementType() == CallArgTy &&
+ allowsImplicitVectorSplat(I))
+ continue;
+
+ return false;
+ }
+
+ return true;
}
Function *AMDGPULibFunc::getFunction(Module *M, const AMDGPULibFunc &fInfo) {
@@ -1012,7 +1042,7 @@ Function *AMDGPULibFunc::getFunction(Module *M, const AMDGPULibFunc &fInfo) {
if (F->hasFnAttribute(Attribute::NoBuiltin))
return nullptr;
- if (!fInfo.isCompatibleSignature(F->getFunctionType()))
+ if (!fInfo.isCompatibleSignature(*M, F->getFunctionType()))
return nullptr;
return F;
@@ -1028,7 +1058,7 @@ FunctionCallee AMDGPULibFunc::getOrInsertFunction(Module *M,
if (F->hasFnAttribute(Attribute::NoBuiltin))
return nullptr;
if (!F->isDeclaration() &&
- fInfo.isCompatibleSignature(F->getFunctionType()))
+ fInfo.isCompatibleSignature(*M, F->getFunctionType()))
return F;
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibFunc.h b/llvm/lib/Target/AMDGPU/AMDGPULibFunc.h
index 10551bee3fa8d4..580ef51b559d80 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibFunc.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibFunc.h
@@ -352,7 +352,7 @@ class AMDGPULibFuncImpl : public AMDGPULibFuncBase {
void setName(StringRef N) { Name = std::string(N); }
void setPrefix(ENamePrefix pfx) { FKind = pfx; }
- virtual FunctionType *getFunctionType(Module &M) const = 0;
+ virtual FunctionType *getFunctionType(const Module &M) const = 0;
protected:
EFuncId FuncId;
@@ -391,8 +391,22 @@ class AMDGPULibFunc : public AMDGPULibFuncBase {
return Impl->parseFuncName(MangledName);
}
+ /// Return true if it's legal to splat a scalar value passed in parameter \p
+ /// ArgIdx to a vector argument.
+ bool allowsImplicitVectorSplat(int ArgIdx) const {
+ switch (getId()) {
+ case EI_LDEXP:
+ return ArgIdx == 1;
+ case EI_FMIN:
+ case EI_FMAX:
+ return true;
+ default:
+ return false;
+ }
+ }
+
// Validate the call type matches the expected libfunc type.
- bool isCompatibleSignature(const FunctionType *FuncTy) const;
+ bool isCompatibleSignature(const Module &M, const FunctionType *FuncTy) const;
/// \return The mangled function name for mangled library functions
/// and unmangled function name for unmangled library functions.
@@ -401,7 +415,7 @@ class AMDGPULibFunc : public AMDGPULibFuncBase {
void setName(StringRef N) { Impl->setName(N); }
void setPrefix(ENamePrefix PFX) { Impl->setPrefix(PFX); }
- FunctionType *getFunctionType(Module &M) const {
+ FunctionType *getFunctionType(const Module &M) const {
return Impl->getFunctionType(M);
}
static Function *getFunction(llvm::Module *M, const AMDGPULibFunc &fInfo);
@@ -428,7 +442,7 @@ class AMDGPUMangledLibFunc : public AMDGPULibFuncImpl {
std::string getName() const override;
unsigned getNumArgs() const override;
- FunctionType *getFunctionType(Module &M) const override;
+ FunctionType *getFunctionType(const Module &M) const override;
static StringRef getUnmangledName(StringRef MangledName);
bool parseFuncName(StringRef &mangledName) override;
@@ -458,7 +472,9 @@ class AMDGPUUnmangledLibFunc : public AMDGPULibFuncImpl {
}
std::string getName() const override { return Name; }
unsigned getNumArgs() const override;
- FunctionType *getFunctionType(Module &M) const override { return FuncTy; }
+ FunctionType *getFunctionType(const Module &M) const override {
+ return FuncTy;
+ }
bool parseFuncName(StringRef &Name) override;
diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll
index 24082b8c666111..dc275b33b012da 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll
@@ -242,6 +242,47 @@ define float @test_ldexp_f32_strictfp(float %x, i32 %y) #4 {
ret float %ldexp
}
+;---------------------------------------------------------------------
+; Invalid signatures
+;---------------------------------------------------------------------
+
+; Declared with wrong type, second argument is float
+declare float @_Z5ldexpff(float noundef, float noundef)
+
+define float @call_wrong_typed_ldexp_f32_second_arg(float %x, float %wrongtype) {
+; CHECK-LABEL: define float @call_wrong_typed_ldexp_f32_second_arg
+; CHECK-SAME: (float [[X:%.*]], float [[WRONGTYPE:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call float @_Z5ldexpff(float [[X]], float [[WRONGTYPE]])
+; CHECK-NEXT: ret float [[CALL]]
+;
+ %call = call float @_Z5ldexpff(float %x, float %wrongtype)
+ ret float %call
+}
+
+declare <2 x float> @_Z5ldexpDv2_fS_(<2 x float>, <2 x float>)
+
+define <2 x float> @call_wrong_typed_ldexp_v2f32_second_arg(<2 x float> %x, <2 x float> %wrongtype) {
+; CHECK-LABEL: define <2 x float> @call_wrong_typed_ldexp_v2f32_second_arg
+; CHECK-SAME: (<2 x float> [[X:%.*]], <2 x float> [[WRONGTYPE:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call <2 x float> @_Z5ldexpDv2_fS_(<2 x float> [[X]], <2 x float> [[WRONGTYPE]])
+; CHECK-NEXT: ret <2 x float> [[CALL]]
+;
+ %call = call <2 x float> @_Z5ldexpDv2_fS_(<2 x float> %x, <2 x float> %wrongtype)
+ ret <2 x float> %call
+}
+
+declare <2 x float> @_Z5ldexpDv2_ff(<2 x float>, float)
+
+define <2 x float> @call_wrong_typed_ldexp_v2f32_f32(<2 x float> %x, float %wrongtype) {
+; CHECK-LABEL: define <2 x float> @call_wrong_typed_ldexp_v2f32_f32
+; CHECK-SAME: (<2 x float> [[X:%.*]], float [[WRONGTYPE:%.*]]) {
+; CHECK-NEXT: [[CALL:%.*]] = call <2 x float> @_Z5ldexpDv2_ff(<2 x float> [[X]], float [[WRONGTYPE]])
+; CHECK-NEXT: ret <2 x float> [[CALL]]
+;
+ %call = call <2 x float> @_Z5ldexpDv2_ff(<2 x float> %x, float %wrongtype)
+ ret <2 x float> %call
+}
+
attributes #0 = { nobuiltin }
attributes #1 = { "no-builtins" }
attributes #2 = { nounwind memory(none) }
``````````
</details>
https://github.com/llvm/llvm-project/pull/119043
More information about the llvm-commits mailing list