[llvm] [SPIR-V] Improve type inference: fix types of return values in call lowering (PR #116609)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 18 04:49:49 PST 2024
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/116609
The goal of this PR is to ensure that correct types are applied to virtual registers that are used as return values in call lowering. A reproducer is attached as a new test case; before this PR it fails because spirv-val considers the output invalid due to wrong result/operand types in OpPhi instructions.
>From d89bb33f0adb8d98fe668765353f4d79ec64dbd8 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 18 Nov 2024 04:47:27 -0800
Subject: [PATCH] Improve type inference: return values in call lowering
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 18 +-----
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 25 +++++++++
.../SPIRV/pointers/builtin-ret-reg-type.ll | 55 +++++++++++++++++++
.../SPIRV/transcoding/OpGenericCastToPtr.ll | 2 -
4 files changed, 83 insertions(+), 17 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 06a37f1f559d44..e34f6c3c282750 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2517,23 +2517,11 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
SPIRVGlobalRegistry *GR) {
LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
- // SPIR-V type and return register.
- Register ReturnRegister = OrigRet;
- SPIRVType *ReturnType = nullptr;
- if (OrigRetTy && !OrigRetTy->isVoidTy()) {
- ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
- if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
- MIRBuilder.getMRI()->setRegClass(ReturnRegister,
- GR->getRegClass(ReturnType));
- } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
- ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
- MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
- ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
- }
-
// Lookup the builtin in the TableGen records.
+ SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
+ assert(SpvType && "Inconsistent return register: expected valid type info");
std::unique_ptr<const IncomingCall> Call =
- lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
+ lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
if (!Call) {
LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 3c5397319aaf21..a7b6b0efa99551 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -539,6 +539,31 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
if (isFunctionDecl && !DemangledName.empty() &&
(canUseGLSL || canUseOpenCL)) {
+ if (ResVReg.isValid()) {
+ if (!GR->getSPIRVTypeForVReg(ResVReg)) {
+ const Type *RetTy = OrigRetTy;
+ if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
+ const Value *OrigValue = Info.OrigRet.OrigValue;
+ if (!OrigValue)
+ OrigValue = Info.CB;
+ if (OrigValue)
+ if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
+ RetTy =
+ TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
+ }
+ SPIRVType *SpvType = GR->getOrCreateSPIRVType(RetTy, MIRBuilder);
+ GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+ if (!MRI->getRegClassOrNull(ResVReg)) {
+ MRI->setRegClass(ResVReg, GR->getRegClass(SpvType));
+ MRI->setType(ResVReg, GR->getRegType(SpvType));
+ }
+ }
+ } else {
+ SPIRVType *SpvType = GR->getOrCreateSPIRVType(OrigRetTy, MIRBuilder);
+ ResVReg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
+ MRI->setType(ResVReg, GR->getRegType(SpvType));
+ GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+ }
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
new file mode 100644
index 00000000000000..afa97ccfc0a69c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
@@ -0,0 +1,55 @@
+; The goal of this test case is to ensure that correct types are applied to virtual registers that were
+; used as return values in call lowering. The pass criterion is that spirv-val considers the output valid.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#]] = OpPhi %[[#]] %[[#Ptr:]] %[[#]] %[[#]] %[[#]]
+; CHECK-SPIRV: %[[#Ptr]] = OpPtrAccessChain %[[#]] %[[#]] %[[#]]
+
+
+%t_half = type { half }
+%t_i17 = type { [17 x i32] }
+%t_h17 = type { [17 x %t_half] }
+
+define internal spir_func void @foo(i64 %arrayinit.cur.add_4, half %r1, ptr addrspace(4) noundef align 8 dereferenceable_or_null(72) %this) { ; reproducer: builtin ptr return values flowing into pointer phis; NOTE(review): %this is not referenced in the body
+entry:
+ %r_3 = alloca %t_h17, align 8
+ %p_src = alloca %t_i17, align 4
+ %p_src4 = addrspacecast ptr %p_src to ptr addrspace(4)
+ %call_2 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %p_src4, i32 noundef 7) ; builtin call with a ptr return value
+ br label %l_body
+
+l_body: ; preds = %l_body, %entry
+ %l_done = icmp eq i64 %arrayinit.cur.add_4, 34
+ br i1 %l_done, label %exit, label %l_body
+
+exit: ; preds = %l_body
+ %0 = addrspacecast ptr %call_2 to ptr addrspace(4)
+ %call_6 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %0, i32 noundef 7) ; its ptr return value is an incoming value of the %lsr.iv1 phi below
+ br label %for.cond_3
+
+for.cond_3: ; preds = %for.body_3, %exit
+ %lsr.iv1 = phi ptr [ %scevgep2, %for.body_3 ], [ %call_6, %exit ] ; phi mixing a GEP result and a builtin return value (presumably the OpPhi targeted by the checks above)
+ %lsr.iv = phi ptr [ %scevgep, %for.body_3 ], [ %r_3, %exit ]
+ %i.0_3 = phi i64 [ 0, %exit ], [ %inc_3, %for.body_3 ]
+ %cmp_3 = icmp ult i64 %i.0_3, 17
+ br i1 %cmp_3, label %for.body_3, label %exit2
+
+for.body_3: ; preds = %for.cond_3
+ %call2_5 = call spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef %r1, ptr noundef %lsr.iv1)
+ store half %call2_5, ptr %lsr.iv, align 2
+ %inc_3 = add nuw nsw i64 %i.0_3, 1
+ %scevgep = getelementptr i8, ptr %lsr.iv, i64 2
+ %scevgep2 = getelementptr i8, ptr %lsr.iv1, i64 4 ; GEP incoming value of the %lsr.iv1 phi
+ br label %for.cond_3
+
+exit2: ; preds = %for.cond_3
+ ret void
+}
+
+declare dso_local spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef, i32 noundef)
+declare dso_local spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef, ptr noundef)
+declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg) ; NOTE(review): not referenced by @foo
+declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg) ; NOTE(review): not referenced by @foo
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
index 54b2c786747768..2cba0f6ebd74be 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
@@ -2,9 +2,7 @@
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
-; CHECK-SPIRV-DAG: %[[#GlobalCharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
; CHECK-SPIRV-DAG: %[[#LocalCharPtr:]] = OpTypePointer Workgroup %[[#Char]]
-; CHECK-SPIRV-DAG: %[[#PrivateCharPtr:]] = OpTypePointer Function %[[#Char]]
; CHECK-SPIRV-DAG: %[[#GenericCharPtr:]] = OpTypePointer Generic %[[#Char]]
; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0
More information about the llvm-commits
mailing list