[llvm] 513334f - [NFC][SPIRV] Fix function type recovery (#165934)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 3 09:30:08 PST 2025
Author: Alex Voicu
Date: 2025-11-03T19:30:04+02:00
New Revision: 513334faec2594bbeb3ac00f0092bed20b23abd3
URL: https://github.com/llvm/llvm-project/commit/513334faec2594bbeb3ac00f0092bed20b23abd3
DIFF: https://github.com/llvm/llvm-project/commit/513334faec2594bbeb3ac00f0092bed20b23abd3.diff
LOG: [NFC][SPIRV] Fix function type recovery (#165934)
Due to limitations in GISel / IRTranslator, the SPIR-V backend replaces aggregate function arguments with `i32` placeholders; the original types are recorded in metadata and recovered from it after IR translation. Due to what appears to be an oversight, the current implementation handles only a single type mutation per function: rather than traversing the function's metadata node, it reads only the first mutation entry (operand 1, the operand immediately after the function name). This patch addresses that limitation by iterating over all of the mutation entries.
Added:
Modified:
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 9e11c3a281a1b..dd57b74d79a5e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -149,23 +149,23 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
return isa<MDString>(N->getOperand(0)) &&
cast<MDString>(N->getOperand(0))->getString() == F.getName();
});
- // TODO: probably one function can have numerous type mutations,
- // so we should support this.
if (ThisFuncMDIt != NamedMD->op_end()) {
auto *ThisFuncMD = *ThisFuncMDIt;
- MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
- assert(MD && "MDNode operand is expected");
- ConstantInt *Const = getConstInt(MD, 0);
- if (Const) {
- auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
- assert(CMeta && "ConstantAsMetadata operand is expected");
- assert(Const->getSExtValue() >= -1);
- // Currently -1 indicates return value, greater values mean
- // argument numbers.
- if (Const->getSExtValue() == -1)
- RetTy = CMeta->getType();
- else
- ArgTypes[Const->getSExtValue()] = CMeta->getType();
+ for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
+ MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
+ assert(MD && "MDNode operand is expected");
+ ConstantInt *Const = getConstInt(MD, 0);
+ if (Const) {
+ auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
+ assert(CMeta && "ConstantAsMetadata operand is expected");
+ assert(Const->getSExtValue() >= -1);
+ // Currently -1 indicates return value, greater values mean
+ // argument numbers.
+ if (Const->getSExtValue() == -1)
+ RetTy = CMeta->getType();
+ else
+ ArgTypes[Const->getSExtValue()] = CMeta->getType();
+ }
}
}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll b/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll
index 73c46b18bfa78..c9b2968a4aed7 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll
@@ -10,6 +10,7 @@
; CHECK-DAG: %[[#Int8:]] = OpTypeInt 8 0
; CHECK-DAG: %[[#Half:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#Float:]] = OpTypeFloat 32
; CHECK-DAG: %[[#Struct:]] = OpTypeStruct %[[#Half]]
; CHECK-DAG: %[[#Void:]] = OpTypeVoid
; CHECK-DAG: %[[#PtrInt8:]] = OpTypePointer CrossWorkgroup %[[#Int8:]]
@@ -17,12 +18,20 @@
; CHECK-DAG: %[[#Int64:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#PtrInt64:]] = OpTypePointer CrossWorkgroup %[[#Int64]]
; CHECK-DAG: %[[#BarType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt64]] %[[#Struct]]
+; CHECK-DAG: %[[#BazType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt8]] %[[#Struct]] %[[#Int8]] %[[#Struct]] %[[#Float]] %[[#Struct]]
; CHECK: OpFunction %[[#Void]] None %[[#FooType]]
; CHECK: OpFunctionParameter %[[#PtrInt8]]
; CHECK: OpFunctionParameter %[[#Struct]]
; CHECK: OpFunction %[[#Void]] None %[[#BarType]]
; CHECK: OpFunctionParameter %[[#PtrInt64]]
; CHECK: OpFunctionParameter %[[#Struct]]
+; CHECK: OpFunction %[[#Void]] None %[[#BazType]]
+; CHECK: OpFunctionParameter %[[#PtrInt8]]
+; CHECK: OpFunctionParameter %[[#Struct]]
+; CHECK: OpFunctionParameter %[[#Int8]]
+; CHECK: OpFunctionParameter %[[#Struct]]
+; CHECK: OpFunctionParameter %[[#Float]]
+; CHECK: OpFunctionParameter %[[#Struct]]
%t_half = type { half }
@@ -38,4 +47,9 @@ entry:
ret void
}
+define spir_kernel void @baz(ptr addrspace(1) %a, %t_half %b, i8 %c, %t_half %d, float %e, %t_half %f) {
+entry:
+ ret void
+}
+
declare spir_func %t_half @_Z29__spirv_SpecConstantComposite(half)
More information about the llvm-commits mailing list