[llvm] 0fbaf03 - [SPIR-V] Cast ptr kernel args to i8* when used as Store's value operand (#78603)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 28 19:30:18 PST 2024
Author: Michal Paszkowski
Date: 2024-01-28T19:30:14-08:00
New Revision: 0fbaf03f703db2cb29d1bde23708b80db049164f
URL: https://github.com/llvm/llvm-project/commit/0fbaf03f703db2cb29d1bde23708b80db049164f
DIFF: https://github.com/llvm/llvm-project/commit/0fbaf03f703db2cb29d1bde23708b80db049164f.diff
LOG: [SPIR-V] Cast ptr kernel args to i8* when used as Store's value operand (#78603)
Handle a special case when StoreInst's value operand is a kernel
argument of a pointer type. Since these arguments could have either a
basic element type (e.g. float*) or OpenCL builtin type (sampler_t),
bitcast the StoreInst's value operand to default pointer element type
(i8).
This pull request addresses the issue
https://github.com/llvm/llvm-project/issues/72864
Added:
llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
llvm/lib/Target/SPIRV/SPIRVMetadata.h
llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-i8-default-element-type.ll
llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-no-bitcast.ll
llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-i8-ptr-as-value-operand.ll
llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll
Modified:
llvm/lib/Target/SPIRV/CMakeLists.txt
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 7d17c307db13a04..d9e24375dcb243f 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -26,6 +26,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVISelLowering.cpp
SPIRVLegalizerInfo.cpp
SPIRVMCInstLower.cpp
+ SPIRVMetadata.cpp
SPIRVModuleAnalysis.cpp
SPIRVPreLegalizer.cpp
SPIRVPrepareFunctions.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 0a8b5499a1fc2ac..62c08bab46eee27 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -17,6 +17,7 @@
#include "SPIRVBuiltins.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVISelLowering.h"
+#include "SPIRVMetadata.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
@@ -117,64 +118,12 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
}
-static MDString *getKernelArgAttribute(const Function &KernelFunction,
- unsigned ArgIdx,
- const StringRef AttributeName) {
- assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
- "Kernel attributes are attached/belong only to kernel functions");
-
- // Lookup the argument attribute in metadata attached to the kernel function.
- MDNode *Node = KernelFunction.getMetadata(AttributeName);
- if (Node && ArgIdx < Node->getNumOperands())
- return cast<MDString>(Node->getOperand(ArgIdx));
-
- // Sometimes metadata containing kernel attributes is not attached to the
- // function, but can be found in the named module-level metadata instead.
- // For example:
- // !opencl.kernels = !{!0}
- // !0 = !{void ()* @someKernelFunction, !1, ...}
- // !1 = !{!"kernel_arg_addr_space", ...}
- // In this case the actual index of searched argument attribute is ArgIdx + 1,
- // since the first metadata node operand is occupied by attribute name
- // ("kernel_arg_addr_space" in the example above).
- unsigned MDArgIdx = ArgIdx + 1;
- NamedMDNode *OpenCLKernelsMD =
- KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
- if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
- return nullptr;
-
- // KernelToMDNodeList contains kernel function declarations followed by
- // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
- // to the currently lowered kernel function.
- MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
- bool FoundLoweredKernelFunction = false;
- for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
- ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
- if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
- KernelFunction.getName()) {
- FoundLoweredKernelFunction = true;
- continue;
- }
- if (MaybeValue && FoundLoweredKernelFunction)
- return nullptr;
-
- MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
- if (FoundLoweredKernelFunction && MaybeNode &&
- cast<MDString>(MaybeNode->getOperand(0))->getString() ==
- AttributeName &&
- MDArgIdx < MaybeNode->getNumOperands())
- return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
- }
- return nullptr;
-}
-
static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function &F, unsigned ArgIdx) {
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
return SPIRV::AccessQualifier::ReadWrite;
- MDString *ArgAttribute =
- getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
+ MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
if (!ArgAttribute)
return SPIRV::AccessQualifier::ReadWrite;
@@ -186,9 +135,8 @@ getArgAccessQual(const Function &F, unsigned ArgIdx) {
}
static std::vector<SPIRV::Decoration::Decoration>
-getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
- MDString *ArgAttribute =
- getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
+getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
+ MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
return {SPIRV::Decoration::Volatile};
return {};
@@ -209,8 +157,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
isSpecialOpaqueType(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
- MDString *MDKernelArgType =
- getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+ MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
!MDKernelArgType->getString().ends_with("_t")))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 90ec98bb361d3c0..4169f7c6233c997 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "SPIRV.h"
+#include "SPIRVMetadata.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/IRBuilder.h"
@@ -282,7 +283,26 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
Value *Pointer;
Type *ExpectedElementType;
unsigned OperandToReplace;
- if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
+ bool AllowCastingToChar = false;
+
+ StoreInst *SI = dyn_cast<StoreInst>(I);
+ if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
+ SI->getValueOperand()->getType()->isPointerTy() &&
+ isa<Argument>(SI->getValueOperand())) {
+ Argument *Arg = cast<Argument>(SI->getValueOperand());
+ MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
+ if (!ArgType || ArgType->getString().starts_with("uchar*"))
+ return;
+
+ // Handle special case when StoreInst's value operand is a kernel argument
+ // of a pointer type. Since these arguments could have either a basic
+ // element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
+ // the StoreInst's value operand to default pointer element type (i8).
+ Pointer = Arg;
+ ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
+ OperandToReplace = 0;
+ AllowCastingToChar = true;
+ } else if (SI) {
Pointer = SI->getPointerOperand();
ExpectedElementType = SI->getValueOperand()->getType();
OperandToReplace = 1;
@@ -364,13 +384,15 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
// Do not emit spv_ptrcast if it would cast to the default pointer element
// type (i8) of the same address space.
- if (ExpectedElementType->isIntegerTy(8))
+ if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
return;
- // If this would be the first spv_ptrcast and there is no spv_assign_ptr_type
- // for this pointer before, do not emit spv_ptrcast but emit
- // spv_assign_ptr_type instead.
- if (FirstPtrCastOrAssignPtrType && isa<Instruction>(Pointer)) {
+ // If this would be the first spv_ptrcast and the pointer's defining
+ // instruction requires spv_assign_ptr_type but does not already have one,
+ // emit spv_assign_ptr_type instead of spv_ptrcast.
+ Instruction *PointerDefInst = dyn_cast<Instruction>(Pointer);
+ if (FirstPtrCastOrAssignPtrType && PointerDefInst &&
+ requireAssignPtrType(PointerDefInst)) {
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
ExpectedElementTypeConst, Pointer,
{IRB->getInt32(AddressSpace)});
diff --git a/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
new file mode 100644
index 000000000000000..e8c707742f24437
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
@@ -0,0 +1,92 @@
+//===--- SPIRVMetadata.cpp ---- IR Metadata Parsing Funcs -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains functions needed for parsing LLVM IR metadata relevant
+// to the SPIR-V target.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVMetadata.h"
+
+using namespace llvm;
+
+static MDString *getOCLKernelArgAttribute(const Function &F, unsigned ArgIdx,
+ const StringRef AttributeName) {
+ assert(
+ F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+ "Kernel attributes are attached/belong only to OpenCL kernel functions");
+
+ // Lookup the argument attribute in metadata attached to the kernel function.
+ MDNode *Node = F.getMetadata(AttributeName);
+ if (Node && ArgIdx < Node->getNumOperands())
+ return cast<MDString>(Node->getOperand(ArgIdx));
+
+ // Sometimes metadata containing kernel attributes is not attached to the
+ // function, but can be found in the named module-level metadata instead.
+ // For example:
+ // !opencl.kernels = !{!0}
+ // !0 = !{void ()* @someKernelFunction, !1, ...}
+ // !1 = !{!"kernel_arg_addr_space", ...}
+ // In this case the actual index of searched argument attribute is ArgIdx + 1,
+ // since the first metadata node operand is occupied by attribute name
+ // ("kernel_arg_addr_space" in the example above).
+ unsigned MDArgIdx = ArgIdx + 1;
+ NamedMDNode *OpenCLKernelsMD =
+ F.getParent()->getNamedMetadata("opencl.kernels");
+ if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
+ return nullptr;
+
+ // KernelToMDNodeList contains kernel function declarations followed by
+ // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
+ // to the currently lowered kernel function.
+ MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
+ bool FoundLoweredKernelFunction = false;
+ for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
+ ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
+ if (MaybeValue &&
+ dyn_cast<Function>(MaybeValue->getValue())->getName() == F.getName()) {
+ FoundLoweredKernelFunction = true;
+ continue;
+ }
+ if (MaybeValue && FoundLoweredKernelFunction)
+ return nullptr;
+
+ MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
+ if (FoundLoweredKernelFunction && MaybeNode &&
+ cast<MDString>(MaybeNode->getOperand(0))->getString() ==
+ AttributeName &&
+ MDArgIdx < MaybeNode->getNumOperands())
+ return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
+ }
+ return nullptr;
+}
+
+namespace llvm {
+
+MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx) {
+ assert(
+ F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+ "Kernel attributes are attached/belong only to OpenCL kernel functions");
+ return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
+}
+
+MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
+ assert(
+ F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+ "Kernel attributes are attached/belong only to OpenCL kernel functions");
+ return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type_qual");
+}
+
+MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx) {
+ assert(
+ F.getCallingConv() == CallingConv::SPIR_KERNEL &&
+ "Kernel attributes are attached/belong only to OpenCL kernel functions");
+ return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
+}
+
+} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/SPIRVMetadata.h b/llvm/lib/Target/SPIRV/SPIRVMetadata.h
new file mode 100644
index 000000000000000..50aee7234395927
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVMetadata.h
@@ -0,0 +1,31 @@
+//===--- SPIRVMetadata.h ---- IR Metadata Parsing Funcs ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains functions needed for parsing LLVM IR metadata relevant
+// to the SPIR-V target.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
+
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+
+namespace llvm {
+
+//===----------------------------------------------------------------------===//
+// OpenCL Metadata
+//
+
+MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx);
+MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx);
+MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx);
+
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
diff --git a/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-i8-default-element-type.ll b/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-i8-default-element-type.ll
new file mode 100644
index 000000000000000..55bddfdad699b22
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-i8-default-element-type.ll
@@ -0,0 +1,11 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
+; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]
+
+define spir_kernel void @foo(ptr addrspace(1) %arg) {
+ ret void
+}
+
+; CHECK: %[[#]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
diff --git a/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-no-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-no-bitcast.ll
new file mode 100644
index 000000000000000..0d2a832c496b1b6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-no-bitcast.ll
@@ -0,0 +1,14 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
+; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]
+
+define spir_kernel void @foo(i8 %a, ptr addrspace(1) %p) {
+ store i8 %a, ptr addrspace(1) %p
+ ret void
+}
+
+; CHECK: %[[#A:]] = OpFunctionParameter %[[#CHAR]]
+; CHECK: %[[#P:]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
+; CHECK: OpStore %[[#P]] %[[#A]]
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-i8-ptr-as-value-operand.ll b/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-i8-ptr-as-value-operand.ll
new file mode 100644
index 000000000000000..5adaf6f65688df1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-i8-ptr-as-value-operand.ll
@@ -0,0 +1,18 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
+; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]
+
+define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
+ %var = alloca ptr addrspace(1), align 8
+; CHECK: %[[#]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
+; CHECK-NOT: %[[#]] = OpBitcast %[[#]] %[[#]]
+ store ptr addrspace(1) %arg, ptr %var, align 8
+ ret void
+}
+
+!1 = !{i32 1}
+!2 = !{!"none"}
+!3 = !{!"char*"}
+!4 = !{!""}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll b/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll
new file mode 100644
index 000000000000000..e7ce3ef621e83a0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/store-kernel-arg-ptr-as-value-operand.ll
@@ -0,0 +1,19 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
+ %var = alloca ptr addrspace(1), align 8
+; CHECK: %[[#VAR:]] = OpVariable %[[#]] Function
+ store ptr addrspace(1) %arg, ptr %var, align 8
+; The test intends to verify that OpStore uses OpVariable result directly (without a bitcast).
+; Other type checking is done by spirv-val.
+; CHECK: OpStore %[[#VAR]] %[[#]] Aligned 8
+ %lod = load ptr addrspace(1), ptr %var, align 8
+ %idx = getelementptr inbounds i64, ptr addrspace(1) %lod, i64 0
+ ret void
+}
+
+!1 = !{i32 1}
+!2 = !{!"none"}
+!3 = !{!"ulong*"}
+!4 = !{!""}
More information about the llvm-commits
mailing list