[llvm] [SPIR-V] Cast ptr kernel args to i8* when used as Store's value operand (PR #78603)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 18 08:36:51 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Michal Paszkowski (michalpaszkowski)

<details>
<summary>Changes</summary>

Handle the special case where a StoreInst's value operand is a pointer-typed kernel argument. Since such an argument may have either a basic pointer type (e.g. float*) or an OpenCL builtin type (e.g. sampler_t), bitcast the StoreInst's value operand to a pointer to the default pointer element type (i8).

---
Full diff: https://github.com/llvm/llvm-project/pull/78603.diff


6 Files Affected:

- (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+5-58) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+28-6) 
- (added) llvm/lib/Target/SPIRV/SPIRVMetadata.cpp (+92) 
- (added) llvm/lib/Target/SPIRV/SPIRVMetadata.h (+31) 
- (added) llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll (+19) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 7d17c307db13a04..d9e24375dcb243f 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -26,6 +26,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVISelLowering.cpp
   SPIRVLegalizerInfo.cpp
   SPIRVMCInstLower.cpp
+  SPIRVMetadata.cpp
   SPIRVModuleAnalysis.cpp
   SPIRVPreLegalizer.cpp
   SPIRVPrepareFunctions.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 0a8b5499a1fc2ac..62c08bab46eee27 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -17,6 +17,7 @@
 #include "SPIRVBuiltins.h"
 #include "SPIRVGlobalRegistry.h"
 #include "SPIRVISelLowering.h"
+#include "SPIRVMetadata.h"
 #include "SPIRVRegisterInfo.h"
 #include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
@@ -117,64 +118,12 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
 }
 
-static MDString *getKernelArgAttribute(const Function &KernelFunction,
-                                       unsigned ArgIdx,
-                                       const StringRef AttributeName) {
-  assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
-         "Kernel attributes are attached/belong only to kernel functions");
-
-  // Lookup the argument attribute in metadata attached to the kernel function.
-  MDNode *Node = KernelFunction.getMetadata(AttributeName);
-  if (Node && ArgIdx < Node->getNumOperands())
-    return cast<MDString>(Node->getOperand(ArgIdx));
-
-  // Sometimes metadata containing kernel attributes is not attached to the
-  // function, but can be found in the named module-level metadata instead.
-  // For example:
-  //   !opencl.kernels = !{!0}
-  //   !0 = !{void ()* @someKernelFunction, !1, ...}
-  //   !1 = !{!"kernel_arg_addr_space", ...}
-  // In this case the actual index of searched argument attribute is ArgIdx + 1,
-  // since the first metadata node operand is occupied by attribute name
-  // ("kernel_arg_addr_space" in the example above).
-  unsigned MDArgIdx = ArgIdx + 1;
-  NamedMDNode *OpenCLKernelsMD =
-      KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
-  if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
-    return nullptr;
-
-  // KernelToMDNodeList contains kernel function declarations followed by
-  // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
-  // to the currently lowered kernel function.
-  MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
-  bool FoundLoweredKernelFunction = false;
-  for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
-    ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
-    if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
-                          KernelFunction.getName()) {
-      FoundLoweredKernelFunction = true;
-      continue;
-    }
-    if (MaybeValue && FoundLoweredKernelFunction)
-      return nullptr;
-
-    MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
-    if (FoundLoweredKernelFunction && MaybeNode &&
-        cast<MDString>(MaybeNode->getOperand(0))->getString() ==
-            AttributeName &&
-        MDArgIdx < MaybeNode->getNumOperands())
-      return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
-  }
-  return nullptr;
-}
-
 static SPIRV::AccessQualifier::AccessQualifier
 getArgAccessQual(const Function &F, unsigned ArgIdx) {
   if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
     return SPIRV::AccessQualifier::ReadWrite;
 
-  MDString *ArgAttribute =
-      getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
+  MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
   if (!ArgAttribute)
     return SPIRV::AccessQualifier::ReadWrite;
 
@@ -186,9 +135,8 @@ getArgAccessQual(const Function &F, unsigned ArgIdx) {
 }
 
 static std::vector<SPIRV::Decoration::Decoration>
-getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
-  MDString *ArgAttribute =
-      getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
+getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
+  MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
   if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
     return {SPIRV::Decoration::Volatile};
   return {};
@@ -209,8 +157,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
       isSpecialOpaqueType(OriginalArgType))
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
-  MDString *MDKernelArgType =
-      getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+  MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
   if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
                            !MDKernelArgType->getString().ends_with("_t")))
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 90ec98bb361d3c0..56384ba1c006478 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRV.h"
+#include "SPIRVMetadata.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
 #include "llvm/IR/IRBuilder.h"
@@ -282,7 +283,26 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
   Value *Pointer;
   Type *ExpectedElementType;
   unsigned OperandToReplace;
-  if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
+  bool AllowCastingToChar = false;
+
+  StoreInst *SI = dyn_cast<StoreInst>(I);
+  if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
+      SI->getValueOperand()->getType()->isPointerTy() &&
+      isa<Argument>(SI->getValueOperand())) {
+    Argument *Arg = cast<Argument>(SI->getValueOperand());
+    MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
+    if (!ArgType || ArgType->getString().starts_with("uchar*"))
+      return; // Type unknown, or already a pointer to i8 (uchar*).
+
+    // Special case: the stored value is a pointer-typed kernel argument, which
+    // may have either a basic pointer type (e.g. float*) or an OpenCL builtin
+    // type (e.g. sampler_t). Bitcast the stored value to a pointer to the
+    // default pointer element type (i8).
+    Pointer = Arg;
+    ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
+    OperandToReplace = 0;
+    AllowCastingToChar = true;
+  } else if (SI) {
     Pointer = SI->getPointerOperand();
     ExpectedElementType = SI->getValueOperand()->getType();
     OperandToReplace = 1;
@@ -364,13 +384,15 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
 
   // Do not emit spv_ptrcast if it would cast to the default pointer element
   // type (i8) of the same address space.
-  if (ExpectedElementType->isIntegerTy(8))
+  if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
     return;
 
-  // If this would be the first spv_ptrcast and there is no spv_assign_ptr_type
-  // for this pointer before, do not emit spv_ptrcast but emit
-  // spv_assign_ptr_type instead.
-  if (FirstPtrCastOrAssignPtrType && isa<Instruction>(Pointer)) {
+  // If this would be the first spv_ptrcast, and the pointer's defining
+  // instruction requires spv_assign_ptr_type but does not have one yet,
+  // emit spv_assign_ptr_type instead of spv_ptrcast.
+  Instruction *PointerDefInst = dyn_cast<Instruction>(Pointer);
+  if (FirstPtrCastOrAssignPtrType && PointerDefInst &&
+      requireAssignPtrType(PointerDefInst)) {
     buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
                     ExpectedElementTypeConst, Pointer,
                     {IRB->getInt32(AddressSpace)});
diff --git a/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
new file mode 100644
index 000000000000000..e8c707742f24437
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
@@ -0,0 +1,92 @@
+//===--- SPIRVMetadata.cpp ---- IR Metadata Parsing Funcs -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains functions needed for parsing LLVM IR metadata relevant
+// to the SPIR-V target.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVMetadata.h"
+
+using namespace llvm;
+
+static MDString *getOCLKernelArgAttribute(const Function &F, unsigned ArgIdx,
+                                          const StringRef AttributeName) {
+  assert(
+      F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      "Kernel attributes are attached/belong only to OpenCL kernel functions");
+
+  // Lookup the argument attribute in metadata attached to the kernel function.
+  MDNode *Node = F.getMetadata(AttributeName);
+  if (Node && ArgIdx < Node->getNumOperands())
+    return cast<MDString>(Node->getOperand(ArgIdx));
+
+  // Sometimes metadata containing kernel attributes is not attached to the
+  // function, but can be found in the named module-level metadata instead.
+  // For example:
+  //   !opencl.kernels = !{!0}
+  //   !0 = !{void ()* @someKernelFunction, !1, ...}
+  //   !1 = !{!"kernel_arg_addr_space", ...}
+  // In this case the actual index of searched argument attribute is ArgIdx + 1,
+  // since the first metadata node operand is occupied by attribute name
+  // ("kernel_arg_addr_space" in the example above).
+  unsigned MDArgIdx = ArgIdx + 1;
+  NamedMDNode *OpenCLKernelsMD =
+      F.getParent()->getNamedMetadata("opencl.kernels");
+  if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
+    return nullptr;
+
+  // KernelToMDNodeList contains kernel function declarations followed by
+  // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
+  // to the currently lowered kernel function.
+  MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
+  bool FoundLoweredKernelFunction = false;
+  for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
+    ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
+    // Compare by identity: the referenced value need not be a Function.
+    if (MaybeValue && MaybeValue->getValue() == &F) {
+      FoundLoweredKernelFunction = true;
+      continue;
+    }
+    if (MaybeValue && FoundLoweredKernelFunction)
+      return nullptr;
+
+    MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
+    // Check bounds first so operand 0 is never read from an empty node.
+    if (FoundLoweredKernelFunction && MaybeNode &&
+        MDArgIdx < MaybeNode->getNumOperands() &&
+        cast<MDString>(MaybeNode->getOperand(0))->getString() == AttributeName)
+      return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
+  }
+  return nullptr;
+}
+
+namespace llvm {
+
+MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx) {
+  assert(
+      F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      "Kernel attributes are attached/belong only to OpenCL kernel functions");
+  return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
+}
+
+MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
+  assert(
+      F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      "Kernel attributes are attached/belong only to OpenCL kernel functions");
+  return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type_qual");
+}
+
+MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx) {
+  assert(
+      F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+      "Kernel attributes are attached/belong only to OpenCL kernel functions");
+  return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+}
+
+} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVMetadata.h b/llvm/lib/Target/SPIRV/SPIRVMetadata.h
new file mode 100644
index 000000000000000..50aee7234395927
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMetadata.h
@@ -0,0 +1,31 @@
+//===--- SPIRVMetadata.h ---- IR Metadata Parsing Funcs ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains functions needed for parsing LLVM IR metadata relevant
+// to the SPIR-V target.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
+
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+
+namespace llvm {
+
+//===----------------------------------------------------------------------===//
+// OpenCL Metadata. F must be a SPIR_KERNEL function; each getter returns the
+// kernel_arg_* string for argument ArgIdx, or nullptr if none is found.
+
+MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx);
+MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx);
+MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx);
+
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll b/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll
new file mode 100644
index 000000000000000..e7ce3ef621e83a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll
@@ -0,0 +1,19 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
+  %var = alloca ptr addrspace(1), align 8
+; CHECK: %[[#VAR:]] = OpVariable %[[#]] Function
+  store ptr addrspace(1) %arg, ptr %var, align 8
+; This test verifies that OpStore uses the OpVariable result directly as its pointer operand (without a bitcast).
+; Other type checking is done by spirv-val.
+; CHECK: OpStore %[[#VAR]] %[[#]] Aligned 8
+  %lod = load ptr addrspace(1), ptr %var, align 8
+  %idx = getelementptr inbounds i64, ptr addrspace(1) %lod, i64 0
+  ret void
+}
+
+!1 = !{i32 1}
+!2 = !{!"none"}
+!3 = !{!"ulong*"}
+!4 = !{!""}

``````````

</details>


https://github.com/llvm/llvm-project/pull/78603


More information about the llvm-commits mailing list