[llvm] [SPIR-V] Emit proper pointer type for OpenCL kernel arguments (PR #67726)
Michal Paszkowski via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 28 12:25:38 PDT 2023
https://github.com/michalpaszkowski created https://github.com/llvm/llvm-project/pull/67726
None
From fab425063b71a76131cb3610e3ae8c56f0851383 Mon Sep 17 00:00:00 2001
From: Michal Paszkowski <michal.paszkowski at outlook.com>
Date: Thu, 28 Sep 2023 12:10:34 -0700
Subject: [PATCH] [SPIR-V] Emit proper pointer type for OpenCL kernel arguments
---
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 47 ++++++++++++-------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 11 +++++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 5 ++
3 files changed, 46 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index cae7c0e9ac5b8ac..f2b4beb9696641f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -194,23 +194,39 @@ getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
   return {};
 }
 
-static Type *getArgType(const Function &F, unsigned ArgIdx) {
+// Returns the SPIR-V type of argument ArgIdx of F, using the OpenCL
+// "kernel_arg_type" metadata for kernel pointer ("*") and builtin ("_t") args.
+static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
+                                  SPIRVGlobalRegistry *GR,
+                                  MachineIRBuilder &MIRBuilder) {
+  SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
+      getArgAccessQual(F, ArgIdx);
   Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
-  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
-      isSpecialOpaqueType(OriginalArgType))
-    return OriginalArgType;
+  // Non-kernel functions and builtin opaque IR types keep the IR type.
+  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
+      isSpecialOpaqueType(OriginalArgType))
+    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
   MDString *MDKernelArgType =
       getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
-  if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
-    return OriginalArgType;
-
-  std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
-  Type *ExistingOpaqueType =
-      StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
-  return ExistingOpaqueType
-             ? ExistingOpaqueType
-             : StructType::create(F.getContext(), KernelArgTypeStr);
+  StringRef ArgTypeName =
+      MDKernelArgType ? MDKernelArgType->getString() : StringRef();
+  // Only pointer ("*") and opaque builtin ("_t") names are rewritten; any
+  // other name, or missing metadata, keeps the original IR type.
+  if (!ArgTypeName.ends_with("*") && !ArgTypeName.ends_with("_t"))
+    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
+
+  // Take the storage class from the IR address space (e.g. for __local)
+  // instead of assuming CrossWorkgroup for every pointer argument.
+  if (ArgTypeName.ends_with("*"))
+    return GR->getOrCreateSPIRVPointerTypeByName(
+        ArgTypeName, MIRBuilder,
+        addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));
+
+  std::string OpaqueBuiltinTypeStr = "opencl." + ArgTypeName.str();
+  return GR->getOrCreateSPIRVType(
+      TargetExtType::get(F.getContext(), OpaqueBuiltinTypeStr), MIRBuilder,
+      ArgAccessQual);
 }
 
 static bool isEntryPoint(const Function &F) {
@@ -262,10 +278,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
         return false;
-      SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
-          getArgAccessQual(F, i);
-      auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
-                                           MIRBuilder, ArgAccessQual);
+      // getArgSPIRVType takes no vreg, so bind the result to the arg here.
+      auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
+      GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
       ArgTypeVRegs.push_back(SpirvTy);
 
       if (Arg.hasName())
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index d68454f26a80282..4166264ddcee080 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -952,6 +952,25 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
   return nullptr;
 }
 
+// Returns a pointer type in storage class SC for a type name ending in '*'
+// (e.g. "float*"). Only the last '*' (and any space before it) is stripped,
+// so "T**" yields a pointer to "T*" rather than collapsing to a pointer to T.
+SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerTypeByName(
+    StringRef TypeStr, MachineIRBuilder &MIRBuilder,
+    SPIRV::StorageClass::StorageClass SC) {
+  if (!TypeStr.ends_with("*"))
+    llvm_unreachable("Unable to recognize SPIRV pointer type name.");
+
+  StringRef PointeeTypeStr = TypeStr.drop_back().rtrim();
+  // NOTE(review): the name carries no address space, so inner pointer levels
+  // reuse SC; confirm this is the intended storage class for them.
+  SPIRVType *BaseType =
+      PointeeTypeStr.ends_with("*")
+          ? getOrCreateSPIRVPointerTypeByName(PointeeTypeStr, MIRBuilder, SC)
+          : getOrCreateSPIRVTypeByName(PointeeTypeStr, MIRBuilder);
+  return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
+}
+
 // TODO: maybe use tablegen to implement this.
 SPIRVType *
 SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 88769f84b3e504b..5a6938a1922e206 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -141,6 +141,13 @@ class SPIRVGlobalRegistry {
   SPIRVType *getOrCreateSPIRVTypeByName(StringRef TypeStr,
                                         MachineIRBuilder &MIRBuilder);
 
+  // Return a pointer type in storage class SC for the type name TypeStr,
+  // which must end in '*' (e.g. "float*"); the pointee is resolved by name.
+  SPIRVType *
+  getOrCreateSPIRVPointerTypeByName(StringRef TypeStr,
+                                    MachineIRBuilder &MIRBuilder,
+                                    SPIRV::StorageClass::StorageClass SC);
+
   // Return the SPIR-V type instruction corresponding to the given VReg, or
   // nullptr if no such type instruction exists.
   SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
More information about the llvm-commits mailing list