[llvm] 513334f - [NFC][SPIRV] Fix function type recovery (#165934)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 3 09:30:08 PST 2025


Author: Alex Voicu
Date: 2025-11-03T19:30:04+02:00
New Revision: 513334faec2594bbeb3ac00f0092bed20b23abd3

URL: https://github.com/llvm/llvm-project/commit/513334faec2594bbeb3ac00f0092bed20b23abd3
DIFF: https://github.com/llvm/llvm-project/commit/513334faec2594bbeb3ac00f0092bed20b23abd3.diff

LOG: [NFC][SPIRV] Fix function type recovery (#165934)

Due to limitations in GISel / IRTranslator, the SPIR-V BE replaces aggregate function arguments with `i32` placeholders; after IR translation, the original types are recovered from metadata. Due to what appears to be an oversight, the current implementation handles only a single type mutation: instead of iterating over all of the function's metadata entries, it reads only the first mutation operand. This patch addresses that limitation by iterating over every mutation entry in the metadata.

Added: 
    

Modified: 
    llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
    llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 9e11c3a281a1b..dd57b74d79a5e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -149,23 +149,23 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
         return isa<MDString>(N->getOperand(0)) &&
                cast<MDString>(N->getOperand(0))->getString() == F.getName();
       });
-  // TODO: probably one function can have numerous type mutations,
-  // so we should support this.
   if (ThisFuncMDIt != NamedMD->op_end()) {
     auto *ThisFuncMD = *ThisFuncMDIt;
-    MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
-    assert(MD && "MDNode operand is expected");
-    ConstantInt *Const = getConstInt(MD, 0);
-    if (Const) {
-      auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
-      assert(CMeta && "ConstantAsMetadata operand is expected");
-      assert(Const->getSExtValue() >= -1);
-      // Currently -1 indicates return value, greater values mean
-      // argument numbers.
-      if (Const->getSExtValue() == -1)
-        RetTy = CMeta->getType();
-      else
-        ArgTypes[Const->getSExtValue()] = CMeta->getType();
+    for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
+      MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
+      assert(MD && "MDNode operand is expected");
+      ConstantInt *Const = getConstInt(MD, 0);
+      if (Const) {
+        auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
+        assert(CMeta && "ConstantAsMetadata operand is expected");
+        assert(Const->getSExtValue() >= -1);
+        // Currently -1 indicates return value, greater values mean
+        // argument numbers.
+        if (Const->getSExtValue() == -1)
+          RetTy = CMeta->getType();
+        else
+          ArgTypes[Const->getSExtValue()] = CMeta->getType();
+      }
     }
   }
 

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll b/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll
index 73c46b18bfa78..c9b2968a4aed7 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll
@@ -10,6 +10,7 @@
 
 ; CHECK-DAG: %[[#Int8:]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[#Half:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#Float:]] = OpTypeFloat 32
 ; CHECK-DAG: %[[#Struct:]] = OpTypeStruct %[[#Half]]
 ; CHECK-DAG: %[[#Void:]] = OpTypeVoid
 ; CHECK-DAG: %[[#PtrInt8:]] = OpTypePointer CrossWorkgroup %[[#Int8:]]
@@ -17,12 +18,20 @@
 ; CHECK-DAG: %[[#Int64:]] = OpTypeInt 64 0
 ; CHECK-DAG: %[[#PtrInt64:]] = OpTypePointer CrossWorkgroup %[[#Int64]]
 ; CHECK-DAG: %[[#BarType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt64]] %[[#Struct]]
+; CHECK-DAG: %[[#BazType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt8]] %[[#Struct]] %[[#Int8]] %[[#Struct]] %[[#Float]] %[[#Struct]]
 ; CHECK: OpFunction %[[#Void]] None %[[#FooType]]
 ; CHECK: OpFunctionParameter %[[#PtrInt8]]
 ; CHECK: OpFunctionParameter %[[#Struct]]
 ; CHECK: OpFunction %[[#Void]] None %[[#BarType]]
 ; CHECK: OpFunctionParameter %[[#PtrInt64]]
 ; CHECK: OpFunctionParameter %[[#Struct]]
+; CHECK: OpFunction %[[#Void]] None %[[#BazType]]
+; CHECK: OpFunctionParameter %[[#PtrInt8]]
+; CHECK: OpFunctionParameter %[[#Struct]]
+; CHECK: OpFunctionParameter %[[#Int8]]
+; CHECK: OpFunctionParameter %[[#Struct]]
+; CHECK: OpFunctionParameter %[[#Float]]
+; CHECK: OpFunctionParameter %[[#Struct]]
 
 %t_half = type { half }
 
@@ -38,4 +47,9 @@ entry:
   ret void
 }
 
+define spir_kernel void @baz(ptr addrspace(1) %a, %t_half %b, i8 %c, %t_half %d, float %e, %t_half %f) {
+entry:
+  ret void
+}
+
 declare spir_func %t_half @_Z29__spirv_SpecConstantComposite(half)


        


More information about the llvm-commits mailing list