[llvm] [SPIR-V] Emit proper pointer type for OpenCL kernel arguments (PR #67726)

Michal Paszkowski via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 28 12:25:38 PDT 2023


https://github.com/michalpaszkowski created https://github.com/llvm/llvm-project/pull/67726

(No pull request description was provided.)

>From fab425063b71a76131cb3610e3ae8c56f0851383 Mon Sep 17 00:00:00 2001
From: Michal Paszkowski <michal.paszkowski at outlook.com>
Date: Thu, 28 Sep 2023 12:10:34 -0700
Subject: [PATCH] [SPIR-V] Emit proper pointer type for OpenCL kernel arguments

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 47 ++++++++++++-------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 11 +++++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  5 ++
 3 files changed, 46 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index cae7c0e9ac5b8ac..f2b4beb9696641f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -194,23 +194,62 @@ getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
   return {};
 }
 
-static Type *getArgType(const Function &F, unsigned ArgIdx) {
+// Returns the SPIR-V type for formal argument ArgIdx of F, creating it in the
+// global registry if it does not exist yet.
+//
+// For kernel functions (SPIR_KERNEL calling convention) the "kernel_arg_type"
+// metadata is consulted to recover what the IR type does not carry:
+//  - a name ending in "*" becomes a SPIR-V pointer to the named pointee type,
+//    in the storage class derived from the IR argument's address space;
+//  - a name ending in "_t" (when the IR type is not already a builtin target
+//    extension type) becomes the target extension type "opencl.<name>".
+// In all other cases the original IR argument type is used.
+static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
+                                  SPIRVGlobalRegistry *GR,
+                                  MachineIRBuilder &MIRBuilder) {
+  // Read argument's access qualifier from metadata or default.
+  SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
+      getArgAccessQual(F, ArgIdx);
+
   Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
-  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
-      isSpecialOpaqueType(OriginalArgType))
-    return OriginalArgType;
+
+  // Non-kernel functions, and arguments that already are builtin target
+  // extension types, need no metadata-driven translation.
+  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
+      isSpecialOpaqueType(OriginalArgType))
+    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
   MDString *MDKernelArgType =
       getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
-  if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
-    return OriginalArgType;
-
-  std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
-  Type *ExistingOpaqueType =
-      StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
-  return ExistingOpaqueType
-             ? ExistingOpaqueType
-             : StructType::create(F.getContext(), KernelArgTypeStr);
+  // NOTE(review): this name may be a typedef (e.g. "size_t"); confirm such
+  // names are handled below, or consider "kernel_arg_base_type" instead.
+
+  // Without metadata, or when the name is neither a pointer ("*") nor an
+  // OpenCL opaque ("_t") type name, the IR type is used as is.
+  if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
+                           !MDKernelArgType->getString().ends_with("_t")))
+    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
+
+  // Pointer argument: take the pointee type from the metadata name and the
+  // storage class from the IR address space instead of always assuming
+  // CrossWorkgroup (kernel pointers may also be __local or __constant).
+  if (MDKernelArgType->getString().ends_with("*")) {
+    SPIRV::StorageClass::StorageClass SC =
+        OriginalArgType->isPointerTy()
+            ? addressSpaceToStorageClass(
+                  OriginalArgType->getPointerAddressSpace())
+            : SPIRV::StorageClass::CrossWorkgroup;
+    return GR->getOrCreateSPIRVPointerTypeByName(MDKernelArgType->getString(),
+                                                 MIRBuilder, SC);
+  }
+
+  // Remaining case: a "_t" name not yet represented as a target extension
+  // type in the IR, e.g. "image2d_t" becomes "opencl.image2d_t".
+  std::string OpaqueBuiltinTypeStr =
+      "opencl." + MDKernelArgType->getString().str();
+  return GR->getOrCreateSPIRVType(
+      TargetExtType::get(F.getContext(), OpaqueBuiltinTypeStr), MIRBuilder,
+      ArgAccessQual);
 }
 
 static bool isEntryPoint(const Function &F) {
@@ -262,10 +301,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
         return false;
-      SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
-          getArgAccessQual(F, i);
-      auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
-                                           MIRBuilder, ArgAccessQual);
+      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
+      // getArgSPIRVType does not touch the vreg; record the type on it here.
+      GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
       ArgTypeVRegs.push_back(SpirvTy);
 
       if (Arg.hasName())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d68454f26a80282..4166264ddcee080 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -952,6 +952,27 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
   return nullptr;
 }
 
+// Returns (creating it if needed) the SPIR-V pointer type, in storage class
+// SC, whose pointee is the type named by TypeStr without its trailing '*'
+// (e.g. "float*" -> pointer to the type named "float").
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeByName(
+    StringRef TypeStr, MachineIRBuilder &MIRBuilder,
+    SPIRV::StorageClass::StorageClass SC) {
+  assert(TypeStr.ends_with("*") &&
+         "Unable to recognize SPIRV pointer type name.");
+
+  // Strip only the outermost '*' (plus any space before it). A nested pointer
+  // name such as "int**" is resolved recursively rather than collapsing to a
+  // single pointer; the inner level reuses SC since the name does not encode
+  // a storage class.
+  StringRef PointeeStr = TypeStr.drop_back().rtrim();
+  SPIRVType *BaseType =
+      PointeeStr.ends_with("*")
+          ? getOrCreateSPIRVPointerTypeByName(PointeeStr, MIRBuilder, SC)
+          : getOrCreateSPIRVTypeByName(PointeeStr, MIRBuilder);
+  return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
+}
+
 // TODO: maybe use tablegen to implement this.
 SPIRVType *
 SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 88769f84b3e504b..5a6938a1922e206 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -141,6 +141,13 @@ class SPIRVGlobalRegistry {
   SPIRVType *getOrCreateSPIRVTypeByName(StringRef TypeStr,
                                         MachineIRBuilder &MIRBuilder);
 
+  // Get or create a pointer type in storage class SC from a pointer type name
+  // ending in '*' (e.g. "float*"); the pointee type is resolved by name.
+  SPIRVType *
+  getOrCreateSPIRVPointerTypeByName(StringRef TypeStr,
+                                    MachineIRBuilder &MIRBuilder,
+                                    SPIRV::StorageClass::StorageClass SC);
+
   // Return the SPIR-V type instruction corresponding to the given VReg, or
   // nullptr if no such type instruction exists.
   SPIRVType *getSPIRVTypeForVReg(Register VReg) const;



More information about the llvm-commits mailing list