[llvm] [SPIR-V] Ensure correct pointee types of some OpenCL Extended Instructions' pointer arguments (PR #114846)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 03:05:52 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/114846

>From 371b58d05e2dddeebbc8d4e2d4f6d0625a860370 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 4 Nov 2024 10:21:09 -0800
Subject: [PATCH 1/2] ensure correct pointee types of some OpenCL Extended
 Instructions' pointer arguments

---
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   | 55 +++++++++++++++++--
 .../OpExtInst-OpenCL_std-ptr-types.ll         | 34 ++++++++++++
 2 files changed, 83 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 682fca7cc7747c..a0b7a27a109ba0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -223,10 +223,10 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
   doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
 }
 
-static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
-                                      MachineRegisterInfo *MRI,
-                                      SPIRVGlobalRegistry &GR, MachineInstr &I,
-                                      unsigned OpIdx) {
+static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
+                                         MachineRegisterInfo *MRI,
+                                         SPIRVGlobalRegistry &GR,
+                                         MachineInstr &I, unsigned OpIdx) {
   MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   Register OpTypeReg = getTypeReg(MRI, OpReg);
@@ -440,8 +440,8 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
           validateLifetimeStart(STI, MRI, GR, MI);
         break;
       case SPIRV::OpGroupAsyncCopy:
-        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
-        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+        validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
+        validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
         break;
       case SPIRV::OpGroupWaitEvents:
         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
@@ -467,6 +467,49 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
         if (Type->getParent() == Curr && !Curr->pred_empty())
           ToMove.insert(const_cast<MachineInstr *>(Type));
       } break;
+      case SPIRV::OpExtInst: {
+        // Only OpenCL.std extended instructions are checked here.
+        if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
+            MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
+          continue;
+        switch (MI.getOperand(3).getImm()) {
+        case SPIRV::OpenCLExtInst::frexp:
+        case SPIRV::OpenCLExtInst::lgamma_r:
+        case SPIRV::OpenCLExtInst::remquo: {
+          // The last operand must be a pointer to i32 (or to an i32 vector).
+          MachineIRBuilder MIB(MI);
+          SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
+          SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
+          assert(RetType && "Expect return type");
+          validatePtrTypes(
+              STI, MRI, GR, MI, MI.getNumOperands() - 1,
+              RetType->getOpcode() != SPIRV::OpTypeVector
+                  ? Int32Type
+                  : GR.getOrCreateSPIRVVectorType(
+                        Int32Type, RetType->getOperand(2).getImm(), MIB));
+        } break;
+        case SPIRV::OpenCLExtInst::fract:
+        case SPIRV::OpenCLExtInst::modf:
+        case SPIRV::OpenCLExtInst::sincos:
+          // The last operand must be of a pointer to the base type represented
+          // by the previous operand.
+          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+                 "Expect v-reg");
+          validatePtrTypes(
+              STI, MRI, GR, MI, MI.getNumOperands() - 1,
+              GR.getSPIRVTypeForVReg(
+                  MI.getOperand(MI.getNumOperands() - 2).getReg()));
+          break;
+        case SPIRV::OpenCLExtInst::prefetch:
+          // Expected `ptr` type is a pointer to float, integer or vector, but
+          // the pontee value can be wrapped into a struct.
+          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+                 "Expect v-reg");
+          validatePtrUnwrapStructField(STI, MRI, GR, MI,
+                                       MI.getNumOperands() - 2);
+          break;
+        }
+      } break;
       }
     }
     for (MachineInstr *MI : ToMove) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll b/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll
new file mode 100644
index 00000000000000..8e29876d61d339
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll
@@ -0,0 +1,34 @@
+; The goal of the test is to ensure that the output SPIR-V is valid from the perspective of the spirv-val tool.
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+%clsid = type { %arr }
+%arr = type { [1 x i64] }
+%struct_half = type { half }
+
+define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef readonly align 2 %_acc, ptr noundef byval(%clsid) align 8 %_acc_id, ptr addrspace(3) noundef align 2 %_arg_loc) {
+entry:
+  %r1 = load i64, ptr %_acc_id, align 8
+  %add.ptr.i41 = getelementptr inbounds %struct_half, ptr addrspace(1) %_acc, i64 %r1
+  %idx = addrspacecast ptr addrspace(1) %add.ptr.i41 to ptr addrspace(4)
+  %call.i.i290 = call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef %idx, i32 noundef 5)
+  call spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef %call.i.i290, i64 noundef 0)
+
+  %locidx = addrspacecast ptr addrspace(3) %_arg_loc to ptr addrspace(4)
+  %ptr1 = tail call spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef %locidx, i32 noundef 4)
+  %sincos_r = tail call spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef 0xH3145, ptr addrspace(3) noundef %ptr1)
+
+  %p1 = addrspacecast ptr addrspace(1) %_acc to ptr addrspace(4)
+  %ptr2 = tail call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef %p1, i32 noundef 5)
+  %remquo_r = tail call spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef 0xH3A37, half noundef 0xH32F4, ptr addrspace(1) noundef %ptr2)
+
+  ret void
+}
+
+declare dso_local spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef, i64 noundef)
+declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef, i32 noundef)
+
+declare dso_local spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef, ptr addrspace(3) noundef)
+declare dso_local spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef, i32 noundef)
+
+declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef, i32 noundef)
+declare dso_local spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef, half noundef, ptr addrspace(1) noundef)

>From a7c3101562a50f46b1fdd4200da3eee365755dfc Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 6 Nov 2024 03:05:35 -0800
Subject: [PATCH 2/2] apply reviewer comments

---
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index a0b7a27a109ba0..ecbceb5b472fa1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -478,7 +478,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
           MachineIRBuilder MIB(MI);
           SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
           SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
-          assert(RetType && "Expect return type");
+          assert(RetType && "Expected return type");
           validatePtrTypes(
               STI, MRI, GR, MI, MI.getNumOperands() - 1,
               RetType->getOpcode() != SPIRV::OpTypeVector
@@ -494,7 +494,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
           // The last operand must be of a pointer to the base type represented
           // by the previous operand.
           assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
-                 "Expect v-reg");
+                 "Expected v-reg");
           validatePtrTypes(
               STI, MRI, GR, MI, MI.getNumOperands() - 1,
               GR.getSPIRVTypeForVReg(
@@ -504,7 +504,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
           // Expected `ptr` type is a pointer to float, integer or vector, but
           // the pontee value can be wrapped into a struct.
           assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
-                 "Expect v-reg");
+                 "Expected v-reg");
           validatePtrUnwrapStructField(STI, MRI, GR, MI,
                                        MI.getNumOperands() - 2);
           break;



More information about the llvm-commits mailing list