[llvm] 25ee36c - [SPIRV] read kernel arg attributes from function/module metadata

Ilia Diachkov via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 5 17:59:53 PDT 2022


Author: Ilia Diachkov
Date: 2022-10-06T04:43:52+03:00
New Revision: 25ee36c6b19d31bbb5554969411cee59cc6b40c1

URL: https://github.com/llvm/llvm-project/commit/25ee36c6b19d31bbb5554969411cee59cc6b40c1
DIFF: https://github.com/llvm/llvm-project/commit/25ee36c6b19d31bbb5554969411cee59cc6b40c1.diff

LOG: [SPIRV] read kernel arg attributes from function/module metadata

The patch makes kernel argument lowering read argument attributes from both
function-attached metadata and module-level (!opencl.kernels) metadata.
Two tests are added to demonstrate the improvement.

Differential Revision: https://reviews.llvm.org/D135106

Co-authored-by: Aleksandr Bezzubikov <zuban32s at gmail.com>
Co-authored-by: Michal Paszkowski <michal.paszkowski at outlook.com>
Co-authored-by: Andrey Tretyakov <andrey.tretyakov at mail.com>
Co-authored-by: Konrad Trifunovic <konrad.trifunovic at intel.com>

Added: 
    llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_function_metadata.ll
    llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_module_metadata.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 774941d1f17ea..18193bf2a9ad5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -115,6 +115,102 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
 }
 
+/// Returns the AttributeName metadata string for kernel argument ArgIdx, or
+/// nullptr when it is missing or is not an MDString.
+static MDString *getKernelArgAttribute(const Function &KernelFunction,
+                                       unsigned ArgIdx,
+                                       const StringRef AttributeName) {
+  assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
+         "Kernel attributes are attached/belong only to kernel functions");
+
+  // Lookup the argument attribute in metadata attached to the kernel function.
+  MDNode *Node = KernelFunction.getMetadata(AttributeName);
+  if (Node && ArgIdx < Node->getNumOperands())
+    return dyn_cast_or_null<MDString>(Node->getOperand(ArgIdx).get());
+
+  // Otherwise search named module-level metadata, where operand 0 of each
+  // attribute node is its name, so the argument is at index ArgIdx + 1:
+  //   !opencl.kernels = !{!0}
+  //   !0 = !{void ()* @someKernelFunction, !1, ...}
+  //   !1 = !{!"kernel_arg_addr_space", ...}
+  unsigned MDArgIdx = ArgIdx + 1;
+  NamedMDNode *OpenCLKernelsMD =
+      KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
+  if (!OpenCLKernelsMD)
+    return nullptr;
+
+  // Each operand lists kernel functions, each followed by its attribute
+  // nodes. Search all operands, not only the first one.
+  for (const MDNode *KernelToMDNodeList : OpenCLKernelsMD->operands()) {
+    bool FoundLoweredKernelFunction = false;
+    for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
+      // Identity check (through casts); never dereference a non-Function.
+      if (auto *MaybeValue = dyn_cast_or_null<ValueAsMetadata>(Operand.get())) {
+        if (MaybeValue->getValue()->stripPointerCasts() == &KernelFunction) {
+          FoundLoweredKernelFunction = true;
+          continue;
+        }
+        if (FoundLoweredKernelFunction)
+          return nullptr; // Reached the next kernel: attribute not found.
+      }
+      auto *MaybeNode = dyn_cast_or_null<MDNode>(Operand.get());
+      if (!FoundLoweredKernelFunction || !MaybeNode ||
+          MDArgIdx >= MaybeNode->getNumOperands())
+        continue;
+      auto *Name = dyn_cast_or_null<MDString>(MaybeNode->getOperand(0).get());
+      if (Name && Name->getString() == AttributeName)
+        return dyn_cast_or_null<MDString>(
+            MaybeNode->getOperand(MDArgIdx).get());
+    }
+  }
+  return nullptr;
+}
+
+/// Access qualifier of argument ArgIdx: ReadOnly/WriteOnly when a kernel's
+/// "kernel_arg_access_qual" metadata says so, ReadWrite otherwise.
+static SPIRV::AccessQualifier::AccessQualifier
+getArgAccessQual(const Function &F, unsigned ArgIdx) {
+  MDString *AccessQualMD =
+      F.getCallingConv() == CallingConv::SPIR_KERNEL
+          ? getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual")
+          : nullptr;
+  // An empty string (no metadata) matches neither name -> ReadWrite.
+  StringRef Qual = AccessQualMD ? AccessQualMD->getString() : StringRef();
+  if (Qual == "read_only")
+    return SPIRV::AccessQualifier::ReadOnly;
+  if (Qual == "write_only")
+    return SPIRV::AccessQualifier::WriteOnly;
+  return SPIRV::AccessQualifier::ReadWrite;
+}
+
+static std::vector<SPIRV::Decoration::Decoration>
+getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
+  MDString *ArgAttribute =
+      getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
+  if (ArgAttribute && ArgAttribute->getString().contains("volatile"))
+    return {SPIRV::Decoration::Volatile}; // Also for e.g. "const volatile".
+  return {};
+}
+
+/// Lowered type of argument ArgIdx; kernel "_t" types map to "opencl.<name>".
+/// NOTE(review): every "kernel_arg_type" ending in "_t" matches, possibly
+/// including typedefs such as size_t -- confirm that is intended.
+static Type *getArgType(const Function &F, unsigned ArgIdx) {
+  Type *ParamTy = getOriginalFunctionType(F)->getParamType(ArgIdx);
+  MDString *TypeName = nullptr;
+  if (F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      !isSpecialOpaqueType(ParamTy))
+    TypeName = getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+  if (!TypeName || !TypeName->getString().endswith("_t"))
+    return ParamTy;
+
+  // Reuse an already-created opaque struct of that name, else create it.
+  std::string OpaqueName = "opencl." + TypeName->getString().str();
+  if (Type *Existing = StructType::getTypeByName(F.getContext(), OpaqueName))
+    return Existing;
+  return StructType::create(F.getContext(), OpaqueName);
+}
+
 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
                                              const Function &F,
                                              ArrayRef<ArrayRef<Register>> VRegs,
@@ -132,18 +228,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
       // TODO: handle the case of multiple registers.
       if (VRegs[i].size() > 1)
         return false;
-      Type *ArgTy = FTy->getParamType(i);
-      SPIRV::AccessQualifier::AccessQualifier AQ =
-          SPIRV::AccessQualifier::ReadWrite;
-      MDNode *Node = F.getMetadata("kernel_arg_access_qual");
-      if (Node && i < Node->getNumOperands()) {
-        StringRef AQString = cast<MDString>(Node->getOperand(i))->getString();
-        if (AQString.compare("read_only") == 0)
-          AQ = SPIRV::AccessQualifier::ReadOnly;
-        else if (AQString.compare("write_only") == 0)
-          AQ = SPIRV::AccessQualifier::WriteOnly;
-      }
-      auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ);
+      SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
+          getArgAccessQual(F, i);
+      auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
+                                           MIRBuilder, ArgAccessQual);
       ArgTypeVRegs.push_back(SpirvTy);
 
       if (Arg.hasName())
@@ -178,14 +266,15 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
         buildOpDecorate(VRegs[i][0], MIRBuilder,
                         SPIRV::Decoration::FuncParamAttr, {Attr});
       }
-      Node = F.getMetadata("kernel_arg_type_qual");
-      if (Node && i < Node->getNumOperands()) {
-        StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString();
-        if (TypeQual.compare("volatile") == 0)
-          buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile,
-                          {});
+
+      if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
+        std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
+            getKernelArgTypeQual(F, i);
+        for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
+          buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
       }
-      Node = F.getMetadata("spirv.ParameterDecorations");
+
+      MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
       if (Node && i < Node->getNumOperands() &&
           isa<MDNode>(Node->getOperand(i))) {
         MDNode *MD = cast<MDNode>(Node->getOperand(i));

diff  --git a/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_function_metadata.ll b/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_function_metadata.ll
new file mode 100644
index 0000000000000..ce5910efc6ccd
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_function_metadata.ll
@@ -0,0 +1,12 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; The i64 argument is typed as sampler_t only by function-attached metadata.
+; CHECK: %[[#TypeSampler:]] = OpTypeSampler
+define spir_kernel void @foo(i64 %sampler) !kernel_arg_addr_space !7 !kernel_arg_access_qual !8 !kernel_arg_type !9 !kernel_arg_type_qual !10 !kernel_arg_base_type !9 {
+entry:
+  ret void
+}
+
+!7 = !{i32 0}
+!8 = !{!"none"}
+!9 = !{!"sampler_t"}
+!10 = !{!""}

diff  --git a/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_module_metadata.ll b/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_module_metadata.ll
new file mode 100644
index 0000000000000..b5bb8433321da
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/opencl/kernel_arg_type_module_metadata.ll
@@ -0,0 +1,16 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; The i64 argument is typed as sampler_t only by !opencl.kernels metadata.
+; CHECK: %[[#TypeSampler:]] = OpTypeSampler
+define spir_kernel void @foo(i64 %sampler) {
+entry:
+  ret void
+}
+!opencl.kernels = !{!0}
+
+!0 = !{void (i64)* @foo, !1, !2, !3, !4, !5, !6}
+!1 = !{!"kernel_arg_addr_space", i32 0}
+!2 = !{!"kernel_arg_access_qual", !"none"}
+!3 = !{!"kernel_arg_type", !"sampler_t"}
+!4 = !{!"kernel_arg_type_qual", !""}
+!5 = !{!"kernel_arg_base_type", !"sampler_t"}
+!6 = !{!"kernel_arg_name", !"sampler"}


        


More information about the llvm-commits mailing list