[llvm] [SPIR-V] Improve type inference: fix types of return values in call lowering (PR #116609)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 18 04:49:49 PST 2024


https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/116609

The goal of this PR is to ensure that the correct types are applied to virtual registers that are used as return values in call lowering. A reproducer is attached as a new test case; before this PR, it fails because spirv-val considers the output invalid due to wrong result/operand types in OpPhi instructions.

>From d89bb33f0adb8d98fe668765353f4d79ec64dbd8 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 18 Nov 2024 04:47:27 -0800
Subject: [PATCH] Improve type inference: return values in call lowering

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 18 +-----
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 25 +++++++++
 .../SPIRV/pointers/builtin-ret-reg-type.ll    | 55 +++++++++++++++++++
 .../SPIRV/transcoding/OpGenericCastToPtr.ll   |  2 -
 4 files changed, 83 insertions(+), 17 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 06a37f1f559d44..e34f6c3c282750 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2517,23 +2517,11 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRVGlobalRegistry *GR) {
   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
 
-  // SPIR-V type and return register.
-  Register ReturnRegister = OrigRet;
-  SPIRVType *ReturnType = nullptr;
-  if (OrigRetTy && !OrigRetTy->isVoidTy()) {
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-    if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
-      MIRBuilder.getMRI()->setRegClass(ReturnRegister,
-                                       GR->getRegClass(ReturnType));
-  } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
-    ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
-    MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-  }
-
   // Lookup the builtin in the TableGen records.
+  SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
+  assert(SpvType && "Inconsistent return register: expected valid type info");
   std::unique_ptr<const IncomingCall> Call =
-      lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
+      lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
 
   if (!Call) {
     LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 3c5397319aaf21..a7b6b0efa99551 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -539,6 +539,31 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   if (isFunctionDecl && !DemangledName.empty() &&
       (canUseGLSL || canUseOpenCL)) {
+    if (ResVReg.isValid()) {
+      if (!GR->getSPIRVTypeForVReg(ResVReg)) {
+        const Type *RetTy = OrigRetTy;
+        if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
+          const Value *OrigValue = Info.OrigRet.OrigValue;
+          if (!OrigValue)
+            OrigValue = Info.CB;
+          if (OrigValue)
+            if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
+              RetTy =
+                  TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
+        }
+        SPIRVType *SpvType = GR->getOrCreateSPIRVType(RetTy, MIRBuilder);
+        GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+        if (!MRI->getRegClassOrNull(ResVReg)) {
+          MRI->setRegClass(ResVReg, GR->getRegClass(SpvType));
+          MRI->setType(ResVReg, GR->getRegType(SpvType));
+        }
+      }
+    } else {
+      SPIRVType *SpvType = GR->getOrCreateSPIRVType(OrigRetTy, MIRBuilder);
+      ResVReg = MRI->createVirtualRegister(GR->getRegClass(SpvType));
+      MRI->setType(ResVReg, GR->getRegType(SpvType));
+      GR->assignSPIRVTypeToVReg(SpvType, ResVReg, MF);
+    }
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
diff --git a/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
new file mode 100644
index 00000000000000..afa97ccfc0a69c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll
@@ -0,0 +1,55 @@
+; The goal of the test case is to ensure that correct types are applied to virtual registers which were
+; used as return values in call lowering. Pass criterion is that spirv-val considers output valid.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#]] = OpPhi %[[#]] %[[#Ptr:]] %[[#]] %[[#]] %[[#]]
+; CHECK-SPIRV: %[[#Ptr]] = OpPtrAccessChain %[[#]] %[[#]] %[[#]]
+
+
+%t_half = type { half }
+%t_i17 = type { [17 x i32] }
+%t_h17 = type { [17 x %t_half] }
+
+define internal spir_func void @foo(i64 %arrayinit.cur.add_4, half %r1, ptr addrspace(4) noundef align 8 dereferenceable_or_null(72) %this) {
+entry:
+  %r_3 = alloca %t_h17, align 8
+  %p_src = alloca %t_i17, align 4
+  %p_src4 = addrspacecast ptr %p_src to ptr addrspace(4)
+  %call_2 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %p_src4, i32 noundef 7)
+  br label %l_body
+
+l_body:                                           ; preds = %l_body, %entry
+  %l_done = icmp eq i64 %arrayinit.cur.add_4, 34
+  br i1 %l_done, label %exit, label %l_body
+
+exit:                                             ; preds = %l_body
+  %0 = addrspacecast ptr %call_2 to ptr addrspace(4)
+  %call_6 = call spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef %0, i32 noundef 7)
+  br label %for.cond_3
+
+for.cond_3:                                       ; preds = %for.body_3, %exit
+  %lsr.iv1 = phi ptr [ %scevgep2, %for.body_3 ], [ %call_6, %exit ]
+  %lsr.iv = phi ptr [ %scevgep, %for.body_3 ], [ %r_3, %exit ]
+  %i.0_3 = phi i64 [ 0, %exit ], [ %inc_3, %for.body_3 ]
+  %cmp_3 = icmp ult i64 %i.0_3, 17
+  br i1 %cmp_3, label %for.body_3, label %exit2
+
+for.body_3:                                       ; preds = %for.cond_3
+  %call2_5 = call spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef %r1, ptr noundef %lsr.iv1)
+  store half %call2_5, ptr %lsr.iv, align 2
+  %inc_3 = add nuw nsw i64 %i.0_3, 1
+  %scevgep = getelementptr i8, ptr %lsr.iv, i64 2
+  %scevgep2 = getelementptr i8, ptr %lsr.iv1, i64 4
+  br label %for.cond_3
+
+exit2:                                            ; preds = %for.cond_3
+  ret void
+}
+
+declare dso_local spir_func noundef ptr @_Z42__spirv_GenericCastToPtrExplicit_ToPrivatePvi(ptr addrspace(4) noundef, i32 noundef)
+declare dso_local spir_func noundef half @_Z17__spirv_ocl_frexpDF16_PU3AS0i(half noundef, ptr noundef)
+declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
+declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg)
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
index 54b2c786747768..2cba0f6ebd74be 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll
@@ -2,9 +2,7 @@
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
-; CHECK-SPIRV-DAG: %[[#GlobalCharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
 ; CHECK-SPIRV-DAG: %[[#LocalCharPtr:]] = OpTypePointer Workgroup %[[#Char]]
-; CHECK-SPIRV-DAG: %[[#PrivateCharPtr:]] = OpTypePointer Function %[[#Char]]
 ; CHECK-SPIRV-DAG: %[[#GenericCharPtr:]] = OpTypePointer Generic %[[#Char]]
 
 ; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0



More information about the llvm-commits mailing list