[llvm] b446c20 - AMDGPU: Verify function type matches when matching libcalls (#119043)

via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 15 22:01:52 PST 2024


Author: Matt Arsenault
Date: 2024-12-16T15:01:48+09:00
New Revision: b446c208a5f0e2ad7193cc23e70642d207db4d13

URL: https://github.com/llvm/llvm-project/commit/b446c208a5f0e2ad7193cc23e70642d207db4d13
DIFF: https://github.com/llvm/llvm-project/commit/b446c208a5f0e2ad7193cc23e70642d207db4d13.diff

LOG: AMDGPU: Verify function type matches when matching libcalls (#119043)

Previously this would recognize a call to a mangled ldexp(float, float)
as a candidate to replace with the intrinsic. We need to verify the second
parameter is in fact an integer.

Fixes: SWDEV-501389

Added: 
    llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-image-function-signatures.ll

Modified: 
    llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
    llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp
    llvm/lib/Target/AMDGPU/AMDGPULibFunc.h
    llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index f2c0be76b771b5..cf8b416d23e50d 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -654,7 +654,7 @@ bool AMDGPULibCalls::fold(CallInst *CI) {
 
   // Further check the number of arguments to see if they match.
   // TODO: Check calling convention matches too
-  if (!FInfo.isCompatibleSignature(CI->getFunctionType()))
+  if (!FInfo.isCompatibleSignature(*Callee->getParent(), CI->getFunctionType()))
     return false;
 
   LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << '\n');

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp
index 4c596e37476c4e..64db58be032def 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibFunc.cpp
@@ -620,17 +620,17 @@ bool ItaniumParamParser::parseItaniumParam(StringRef& param,
   // parse type
   char const TC = param.front();
   if (isDigit(TC)) {
-    res.ArgType = StringSwitch<AMDGPULibFunc::EType>
-      (eatLengthPrefixedName(param))
-      .Case("ocl_image1darray" , AMDGPULibFunc::IMG1DA)
-      .Case("ocl_image1dbuffer", AMDGPULibFunc::IMG1DB)
-      .Case("ocl_image2darray" , AMDGPULibFunc::IMG2DA)
-      .Case("ocl_image1d"      , AMDGPULibFunc::IMG1D)
-      .Case("ocl_image2d"      , AMDGPULibFunc::IMG2D)
-      .Case("ocl_image3d"      , AMDGPULibFunc::IMG3D)
-      .Case("ocl_event"        , AMDGPULibFunc::DUMMY)
-      .Case("ocl_sampler"      , AMDGPULibFunc::DUMMY)
-      .Default(AMDGPULibFunc::DUMMY);
+    res.ArgType =
+        StringSwitch<AMDGPULibFunc::EType>(eatLengthPrefixedName(param))
+            .Case("ocl_image1darray", AMDGPULibFunc::IMG1DA)
+            .Case("ocl_image1dbuffer", AMDGPULibFunc::IMG1DB)
+            .Case("ocl_image2darray", AMDGPULibFunc::IMG2DA)
+            .StartsWith("ocl_image1d", AMDGPULibFunc::IMG1D)
+            .StartsWith("ocl_image2d", AMDGPULibFunc::IMG2D)
+            .StartsWith("ocl_image3d", AMDGPULibFunc::IMG3D)
+            .Case("ocl_event", AMDGPULibFunc::DUMMY)
+            .Case("ocl_sampler", AMDGPULibFunc::DUMMY)
+            .Default(AMDGPULibFunc::DUMMY);
   } else {
     drop_front(param);
     switch (TC) {
@@ -969,7 +969,7 @@ static Type* getIntrinsicParamType(
   return T;
 }
 
-FunctionType *AMDGPUMangledLibFunc::getFunctionType(Module &M) const {
+FunctionType *AMDGPUMangledLibFunc::getFunctionType(const Module &M) const {
   LLVMContext& C = M.getContext();
   std::vector<Type*> Args;
   ParamIterator I(Leads, manglingRules[FuncId]);
@@ -997,9 +997,39 @@ std::string AMDGPUMangledLibFunc::getName() const {
   return std::string(OS.str());
 }
 
-bool AMDGPULibFunc::isCompatibleSignature(const FunctionType *FuncTy) const {
-  // TODO: Validate types make sense
-  return !FuncTy->isVarArg() && FuncTy->getNumParams() == getNumArgs();
+bool AMDGPULibFunc::isCompatibleSignature(const Module &M,
+                                          const FunctionType *CallTy) const {
+  const FunctionType *FuncTy = getFunctionType(M);
+
+  // FIXME: UnmangledFuncInfo does not have any type information other than the
+  // number of arguments.
+  if (!FuncTy)
+    return getNumArgs() == CallTy->getNumParams();
+
+  // Normally the types should exactly match.
+  if (FuncTy == CallTy)
+    return true;
+
+  const unsigned NumParams = FuncTy->getNumParams();
+  if (NumParams != CallTy->getNumParams())
+    return false;
+
+  for (unsigned I = 0; I != NumParams; ++I) {
+    Type *FuncArgTy = FuncTy->getParamType(I);
+    Type *CallArgTy = CallTy->getParamType(I);
+    if (FuncArgTy == CallArgTy)
+      continue;
+
+    // Some cases permit implicitly splatting a scalar to a vector argument.
+    auto *FuncVecTy = dyn_cast<VectorType>(FuncArgTy);
+    if (FuncVecTy && FuncVecTy->getElementType() == CallArgTy &&
+        allowsImplicitVectorSplat(I))
+      continue;
+
+    return false;
+  }
+
+  return true;
 }
 
 Function *AMDGPULibFunc::getFunction(Module *M, const AMDGPULibFunc &fInfo) {
@@ -1012,7 +1042,7 @@ Function *AMDGPULibFunc::getFunction(Module *M, const AMDGPULibFunc &fInfo) {
   if (F->hasFnAttribute(Attribute::NoBuiltin))
     return nullptr;
 
-  if (!fInfo.isCompatibleSignature(F->getFunctionType()))
+  if (!fInfo.isCompatibleSignature(*M, F->getFunctionType()))
     return nullptr;
 
   return F;
@@ -1028,7 +1058,7 @@ FunctionCallee AMDGPULibFunc::getOrInsertFunction(Module *M,
     if (F->hasFnAttribute(Attribute::NoBuiltin))
       return nullptr;
     if (!F->isDeclaration() &&
-        fInfo.isCompatibleSignature(F->getFunctionType()))
+        fInfo.isCompatibleSignature(*M, F->getFunctionType()))
       return F;
   }
 

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPULibFunc.h b/llvm/lib/Target/AMDGPU/AMDGPULibFunc.h
index 10551bee3fa8d4..580ef51b559d80 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibFunc.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibFunc.h
@@ -352,7 +352,7 @@ class AMDGPULibFuncImpl : public AMDGPULibFuncBase {
   void setName(StringRef N) { Name = std::string(N); }
   void setPrefix(ENamePrefix pfx) { FKind = pfx; }
 
-  virtual FunctionType *getFunctionType(Module &M) const = 0;
+  virtual FunctionType *getFunctionType(const Module &M) const = 0;
 
 protected:
   EFuncId FuncId;
@@ -391,8 +391,22 @@ class AMDGPULibFunc : public AMDGPULibFuncBase {
     return Impl->parseFuncName(MangledName);
   }
 
+  /// Return true if it's legal to splat a scalar value passed in parameter \p
+  /// ArgIdx to a vector argument.
+  bool allowsImplicitVectorSplat(int ArgIdx) const {
+    switch (getId()) {
+    case EI_LDEXP:
+      return ArgIdx == 1;
+    case EI_FMIN:
+    case EI_FMAX:
+      return true;
+    default:
+      return false;
+    }
+  }
+
   // Validate the call type matches the expected libfunc type.
-  bool isCompatibleSignature(const FunctionType *FuncTy) const;
+  bool isCompatibleSignature(const Module &M, const FunctionType *FuncTy) const;
 
   /// \return The mangled function name for mangled library functions
   /// and unmangled function name for unmangled library functions.
@@ -401,7 +415,7 @@ class AMDGPULibFunc : public AMDGPULibFuncBase {
   void setName(StringRef N) { Impl->setName(N); }
   void setPrefix(ENamePrefix PFX) { Impl->setPrefix(PFX); }
 
-  FunctionType *getFunctionType(Module &M) const {
+  FunctionType *getFunctionType(const Module &M) const {
     return Impl->getFunctionType(M);
   }
   static Function *getFunction(llvm::Module *M, const AMDGPULibFunc &fInfo);
@@ -428,7 +442,7 @@ class AMDGPUMangledLibFunc : public AMDGPULibFuncImpl {
 
   std::string getName() const override;
   unsigned getNumArgs() const override;
-  FunctionType *getFunctionType(Module &M) const override;
+  FunctionType *getFunctionType(const Module &M) const override;
   static StringRef getUnmangledName(StringRef MangledName);
 
   bool parseFuncName(StringRef &mangledName) override;
@@ -458,7 +472,9 @@ class AMDGPUUnmangledLibFunc : public AMDGPULibFuncImpl {
   }
   std::string getName() const override { return Name; }
   unsigned getNumArgs() const override;
-  FunctionType *getFunctionType(Module &M) const override { return FuncTy; }
+  FunctionType *getFunctionType(const Module &M) const override {
+    return FuncTy;
+  }
 
   bool parseFuncName(StringRef &Name) override;
 

diff  --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-image-function-signatures.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-image-function-signatures.ll
new file mode 100644
index 00000000000000..ab06292e949948
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-image-function-signatures.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -mtriple=amdgcn-amd-amdhsa -passes=amdgpu-simplifylib %s | FileCheck %s
+
+; Make sure we can produce a valid FunctionType for the expected
+; signature of image functions.
+
+declare i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4))
+
+define i32 @call_ocl_image2d_depth(ptr addrspace(4) %img) {
+; CHECK-LABEL: define i32 @call_ocl_image2d_depth(
+; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4) [[IMG]])
+; CHECK-NEXT:    ret i32 [[RESULT]]
+;
+  %result = call i32 @_Z16get_image_height20ocl_image2d_depth_rw(ptr addrspace(4) %img)
+  ret i32 %result
+}
+
+declare i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4))
+
+define i32 @call_ocl_image3d_depth(ptr addrspace(4) %img) {
+; CHECK-LABEL: define i32 @call_ocl_image3d_depth(
+; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4) [[IMG]])
+; CHECK-NEXT:    ret i32 [[RESULT]]
+;
+  %result = call i32 @_Z15get_image_width14ocl_image3d_ro(ptr addrspace(4) %img)
+  ret i32 %result
+}
+
+declare i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4))
+
+define i32 @call_get_image_width14ocl_image1d_ro(ptr addrspace(4) %img) {
+; CHECK-LABEL: define i32 @call_get_image_width14ocl_image1d_ro(
+; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4) [[IMG]])
+; CHECK-NEXT:    ret i32 [[RESULT]]
+;
+  %result = call i32 @_Z15get_image_width14ocl_image1d_ro(ptr addrspace(4) %img)
+  ret i32 %result
+}
+
+declare <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4))
+
+define <2 x i32> @call_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) %img) {
+; CHECK-LABEL: define <2 x i32> @call_Z13get_image_dim20ocl_image2d_array_ro(
+; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) [[IMG]])
+; CHECK-NEXT:    ret <2 x i32> [[RESULT]]
+;
+  %result = call <2 x i32> @_Z13get_image_dim20ocl_image2d_array_ro(ptr addrspace(4) %img)
+  ret <2 x i32> %result
+}
+
+declare i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4))
+
+define i32 @call_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) %img) {
+; CHECK-LABEL: define i32 @call_Z15get_image_width20ocl_image1d_array_ro(
+; CHECK-SAME: ptr addrspace(4) [[IMG:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) [[IMG]])
+; CHECK-NEXT:    ret i32 [[RESULT]]
+;
+  %result = call i32 @_Z15get_image_width20ocl_image1d_array_ro(ptr addrspace(4) %img)
+  ret i32 %result
+}

diff  --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll
index 24082b8c666111..dc275b33b012da 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-ldexp.ll
@@ -242,6 +242,47 @@ define float @test_ldexp_f32_strictfp(float %x, i32 %y) #4 {
   ret float %ldexp
 }
 
+;---------------------------------------------------------------------
+; Invalid signatures
+;---------------------------------------------------------------------
+
+; Declared with the wrong type: the second argument is float, not an integer.
+declare float @_Z5ldexpff(float noundef, float noundef)
+
+define float @call_wrong_typed_ldexp_f32_second_arg(float %x, float %wrongtype) {
+; CHECK-LABEL: define float @call_wrong_typed_ldexp_f32_second_arg
+; CHECK-SAME: (float [[X:%.*]], float [[WRONGTYPE:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call float @_Z5ldexpff(float [[X]], float [[WRONGTYPE]])
+; CHECK-NEXT:    ret float [[CALL]]
+;
+  %call = call float @_Z5ldexpff(float %x, float %wrongtype)
+  ret float %call
+}
+
+declare <2 x float> @_Z5ldexpDv2_fS_(<2 x float>, <2 x float>)
+
+define <2 x float> @call_wrong_typed_ldexp_v2f32_second_arg(<2 x float> %x, <2 x float> %wrongtype) {
+; CHECK-LABEL: define <2 x float> @call_wrong_typed_ldexp_v2f32_second_arg
+; CHECK-SAME: (<2 x float> [[X:%.*]], <2 x float> [[WRONGTYPE:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call <2 x float> @_Z5ldexpDv2_fS_(<2 x float> [[X]], <2 x float> [[WRONGTYPE]])
+; CHECK-NEXT:    ret <2 x float> [[CALL]]
+;
+  %call = call <2 x float> @_Z5ldexpDv2_fS_(<2 x float> %x, <2 x float> %wrongtype)
+  ret <2 x float> %call
+}
+
+declare <2 x float> @_Z5ldexpDv2_ff(<2 x float>, float)
+
+define <2 x float> @call_wrong_typed_ldexp_v2f32_f32(<2 x float> %x, float %wrongtype) {
+; CHECK-LABEL: define <2 x float> @call_wrong_typed_ldexp_v2f32_f32
+; CHECK-SAME: (<2 x float> [[X:%.*]], float [[WRONGTYPE:%.*]]) {
+; CHECK-NEXT:    [[CALL:%.*]] = call <2 x float> @_Z5ldexpDv2_ff(<2 x float> [[X]], float [[WRONGTYPE]])
+; CHECK-NEXT:    ret <2 x float> [[CALL]]
+;
+  %call = call <2 x float> @_Z5ldexpDv2_ff(<2 x float> %x, float %wrongtype)
+  ret <2 x float> %call
+}
+
 attributes #0 = { nobuiltin }
 attributes #1 = { "no-builtins" }
 attributes #2 = { nounwind memory(none) }


        


More information about the llvm-commits mailing list