[llvm] [SPIR-V] Do not rely on type metadata for ptr element type resolution (PR #82678)

Michal Paszkowski via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 11:15:16 PST 2024


https://github.com/michalpaszkowski created https://github.com/llvm/llvm-project/pull/82678

This pull request aims to remove any dependency on OpenCL/SPIR-V type information in LLVM IR metadata. While using metadata might simplify and prettify the resulting SPIR-V output (and restore some of the information lost in the transition to opaque pointers), the overall methodology of resolving kernel parameter types from metadata is highly inefficient.

This pull request is a work in progress, but the high-level strategy is to assign kernel parameter types in the following order:

1. Resolving the types from builtin function calls, either from the mangled names (which must contain type information) or by looking up the builtin definitions in SPIRVBuiltins.td. Then:

  - Assigning the type temporarily using an intrinsic and later setting the right SPIR-V type in SPIRVGlobalRegistry after IRTranslation
  - Inserting a bitcast
2. Defaulting to LLVM IR types (for pointers, the generic i8* type)

In case of a type incompatibility (e.g. a parameter initially defined as sampler_t and later used as image_t), the error will be detected early, before IRTranslation (most likely in the SPIRVEmitIntrinsics pass).

The code repetition in parseBuiltinCallArgumentBaseType(...) will be removed in an amended commit.



>From a4a0d28d4f42fc28bb591028a136efa44dc7af04 Mon Sep 17 00:00:00 2001
From: Michal Paszkowski <michal at paszkowski.org>
Date: Thu, 22 Feb 2024 10:59:39 -0800
Subject: [PATCH] [WIP] [SPIR-V] Do not rely on type metadata for ptr element
 type resolution

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 71 +++++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVBuiltins.h         |  3 +
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 74 ++++++++++++-------
 llvm/test/CodeGen/SPIRV/half_no_extension.ll  |  3 -
 llvm/test/CodeGen/SPIRV/opencl/vstore2.ll     | 23 ++++++
 ...argument-builtin-vload-type-discrapency.ll | 34 +++++++++
 6 files changed, 176 insertions(+), 32 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/opencl/vstore2.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/kernel-argument-builtin-vload-type-discrapency.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c1bb27322443ff..714a6265866fe9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2043,7 +2043,8 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
           .addImm(Builtin->Number);
   for (auto Argument : Call->Arguments)
     MIB.addUse(Argument);
-  MIB.addImm(Builtin->ElementCount);
+  if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
+    MIB.addImm(Builtin->ElementCount);
 
   // Rounding mode should be passed as a last argument in the MI for builtins
   // like "vstorea_halfn_r".
@@ -2179,6 +2180,74 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   return false;
 }
 
+/// Returns the base type of argument \p ArgIdx of the demangled builtin call
+/// \p DemangledCall (the pointee type for pointer arguments). Scalar names
+/// ("char", "uchar", "unsigned char", ...) may carry a vector length as
+/// "typeN" or "type vector[N]"; signedness is dropped ("uint" and "int" both
+/// yield i32). Returns nullptr if \p ArgIdx is out of range or the argument
+/// is not a recognized pointer/builtin type.
+Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
+                                       unsigned ArgIdx, LLVMContext &Ctx) {
+  SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
+  StringRef BuiltinArgs =
+      DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
+  BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
+
+  // Bounds-check before subscripting: callers iterate over the IR call's
+  // operands, which need not line up with the demangled parameter list.
+  if (ArgIdx >= BuiltinArgsTypeStrs.size())
+    return nullptr;
+  StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
+  // Only pointer and builtin ("_t"/"ocl_") types are handled; anything else
+  // (e.g. the "..." of a variadic builtin) is reported as unrecognized.
+  if (!(TypeStr.ends_with("*") || TypeStr.ends_with("_t") ||
+        TypeStr.starts_with("ocl_")))
+    return nullptr;
+
+  // Itanium-demangled names spell unsigned types as "unsigned T", while
+  // OpenCL names use "uT" (listed below); both map to the same LLVM type.
+  TypeStr.consume_front("atomic_");
+  TypeStr.consume_front("unsigned ");
+  // No name below is a prefix of another, so the lookup order is irrelevant.
+  const struct {
+    StringRef Name;
+    Type *Ty;
+  } KnownTypes[] = {{"void", Type::getVoidTy(Ctx)},
+                    {"bool", Type::getIntNTy(Ctx, 1)},
+                    {"char", Type::getInt8Ty(Ctx)},
+                    {"uchar", Type::getInt8Ty(Ctx)},
+                    {"short", Type::getInt16Ty(Ctx)},
+                    {"ushort", Type::getInt16Ty(Ctx)},
+                    {"int", Type::getInt32Ty(Ctx)},
+                    {"uint", Type::getInt32Ty(Ctx)},
+                    {"long", Type::getInt64Ty(Ctx)},
+                    {"ulong", Type::getInt64Ty(Ctx)},
+                    {"half", Type::getHalfTy(Ctx)},
+                    {"float", Type::getFloatTy(Ctx)},
+                    {"double", Type::getDoubleTy(Ctx)}};
+
+  Type *BaseType = nullptr;
+  for (const auto &Known : KnownTypes) {
+    if (TypeStr.consume_front(Known.Name)) {
+      BaseType = Known.Ty;
+      break;
+    }
+  }
+  if (!BaseType)
+    return nullptr; // Unable to recognize SPIR-V type name.
+
+  // Parse the vector length in either "typeN" or "type vector[N]" format,
+  // ignoring a trailing '*'; a missing or non-numeric N means a scalar.
+  TypeStr.consume_back("*");
+  if (TypeStr.consume_front(" vector["))
+    TypeStr = TypeStr.substr(0, TypeStr.find(']'));
+
+  unsigned VecElts = 0;
+  if (!TypeStr.getAsInteger(10, VecElts) && VecElts > 0)
+    BaseType = VectorType::get(BaseType, VecElts, false);
+  return BaseType;
+}
+
 struct BuiltinType {
   StringRef Name;
   uint32_t Opcode;
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 6f957295464812..bc0c8da3223706 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -38,6 +38,9 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  const SmallVectorImpl<Register> &Args,
                                  SPIRVGlobalRegistry *GR);
 
+Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
+                                       unsigned ArgIdx, LLVMContext &Ctx);
+
 /// Translates a string representing a SPIR-V or OpenCL builtin type to a
 /// TargetExtType that can be further lowered with lowerBuiltinType().
 ///
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e32cd50be56e38..e2f22daf4f9630 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRV.h"
+#include "SPIRVBuiltins.h"
 #include "SPIRVMetadata.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
@@ -75,6 +76,9 @@ class SPIRVEmitIntrinsics
   void processInstrAfterVisit(Instruction *I);
   void insertAssignPtrTypeIntrs(Instruction *I);
   void insertAssignTypeIntrs(Instruction *I);
+  void replacePointerOperandWithPtrCast(Instruction *I, Value *Pointer,
+                                        Type *ExpectedElementType,
+                                        unsigned OperandToReplace);
   void insertPtrCastInstr(Instruction *I);
   void processGlobalValue(GlobalVariable &GV);
 
@@ -286,34 +290,9 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
   return NewI;
 }
 
-void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
-  Value *Pointer;
-  Type *ExpectedElementType;
-  unsigned OperandToReplace;
-
-  StoreInst *SI = dyn_cast<StoreInst>(I);
-  if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
-      SI->getValueOperand()->getType()->isPointerTy() &&
-      isa<Argument>(SI->getValueOperand())) {
-    Pointer = SI->getValueOperand();
-    ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
-    OperandToReplace = 0;
-  } else if (SI) {
-    Pointer = SI->getPointerOperand();
-    ExpectedElementType = SI->getValueOperand()->getType();
-    OperandToReplace = 1;
-  } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
-    Pointer = LI->getPointerOperand();
-    ExpectedElementType = LI->getType();
-    OperandToReplace = 0;
-  } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
-    Pointer = GEPI->getPointerOperand();
-    ExpectedElementType = GEPI->getSourceElementType();
-    OperandToReplace = 0;
-  } else {
-    return;
-  }
-
+void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
+    Instruction *I, Value *Pointer, Type *ExpectedElementType,
+    unsigned OperandToReplace) {
   // If Pointer is the result of nop BitCastInst (ptr -> ptr), use the source
   // pointer instead. The BitCastInst should be later removed when visited.
   while (BitCastInst *BC = dyn_cast<BitCastInst>(Pointer))
@@ -413,6 +392,45 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
   }
 }
 
+// Retypes pointer operands of loads, stores, GEPs and OpenCL/SPIR-V builtin
+// calls to the element type \p I uses, via replacePointerOperandWithPtrCast.
+void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
+  if (auto *SI = dyn_cast<StoreInst>(I)) {
+    Value *Val = SI->getValueOperand();
+    if (F->getCallingConv() == CallingConv::SPIR_KERNEL &&
+        Val->getType()->isPointerTy() && isa<Argument>(Val))
+      return replacePointerOperandWithPtrCast(
+          I, Val, IntegerType::getInt8Ty(F->getContext()), 0);
+    return replacePointerOperandWithPtrCast(I, SI->getPointerOperand(),
+                                            Val->getType(), 1);
+  }
+  if (auto *LI = dyn_cast<LoadInst>(I))
+    return replacePointerOperandWithPtrCast(I, LI->getPointerOperand(),
+                                            LI->getType(), 0);
+  if (auto *GEPI = dyn_cast<GetElementPtrInst>(I))
+    return replacePointerOperandWithPtrCast(I, GEPI->getPointerOperand(),
+                                            GEPI->getSourceElementType(), 0);
+
+  // Handle direct calls to builtins (getCalledFunction() is null otherwise).
+  CallInst *CI = dyn_cast<CallInst>(I);
+  Function *Callee = CI ? CI->getCalledFunction() : nullptr;
+  if (!Callee || Callee->isIntrinsic())
+    return;
+  std::string Demangled = getOclOrSpirvBuiltinDemangledName(Callee->getName());
+  if (Demangled.empty())
+    return;
+  for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
+    Value *Arg = CI->getArgOperand(OpIdx);
+    if (!Arg->getType()->isPointerTy())
+      continue;
+    Type *ElemTy = SPIRV::parseBuiltinCallArgumentBaseType(Demangled, OpIdx,
+                                                           I->getContext());
+    // Leave the operand alone if its pointee type is unknown or void.
+    if (ElemTy && !ElemTy->isVoidTy())
+      replacePointerOperandWithPtrCast(CI, Arg, ElemTy, OpIdx);
+  }
+}
+
 Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
   SmallVector<Type *, 4> Types = {I.getType(), I.getOperand(0)->getType(),
                                   I.getOperand(1)->getType(),
diff --git a/llvm/test/CodeGen/SPIRV/half_no_extension.ll b/llvm/test/CodeGen/SPIRV/half_no_extension.ll
index 6414b62874bc71..a5b0ec9c92d236 100644
--- a/llvm/test/CodeGen/SPIRV/half_no_extension.ll
+++ b/llvm/test/CodeGen/SPIRV/half_no_extension.ll
@@ -7,9 +7,6 @@
 
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
 
-; TODO(#60133): Requires updates following opaque pointer migration.
-; XFAIL: *
-
 ; CHECK-SPIRV:     OpCapability Float16Buffer
 ; CHECK-SPIRV-NOT: OpCapability Float16
 
diff --git a/llvm/test/CodeGen/SPIRV/opencl/vstore2.ll b/llvm/test/CodeGen/SPIRV/opencl/vstore2.ll
new file mode 100644
index 00000000000000..0ea14c7f25e877
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/opencl/vstore2.ll
@@ -0,0 +1,23 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; This test only intends to check the vstoren builtin name resolution.
+; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.
+
+; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"
+
+; CHECK-DAG: %[[#VOID:]] = OpTypeVoid
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
+; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
+
+; CHECK: %[[#DATA:]] = OpFunctionParameter %[[#VINT8]]
+; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
+; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]
+
+define spir_kernel void @test_fn(<2 x i8> %data, i64 %offset, ptr addrspace(1) %address) {
+; CHECK: %[[#]] = OpExtInst %[[#VOID]] %[[#IMPORT]] vstoren %[[#DATA]] %[[#OFFSET]] %[[#ADDRESS]]
+  call spir_func void @_Z7vstore2Dv2_cmPU3AS1c(<2 x i8> %data, i64 %offset, ptr addrspace(1) %address) ; vstore2(char2 data, ulong offset, global char *p)
+  ret void
+}
+
+declare spir_func void @_Z7vstore2Dv2_cmPU3AS1c(<2 x i8>, i64, ptr addrspace(1))
diff --git a/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-builtin-vload-type-discrapency.ll b/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-builtin-vload-type-discrapency.ll
new file mode 100644
index 00000000000000..2216eda027af33
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/kernel-argument-builtin-vload-type-discrapency.ll
@@ -0,0 +1,34 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
+
+define spir_kernel void @test_fn(ptr addrspace(1) %src, ptr addrspace(1) %dummy) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_type_qual !4 !kernel_arg_base_type !3 {
+entry:
+  %g1 = call spir_func i64 @_Z13get_global_idj(i32 0)
+  %i1 = insertelement <3 x i64> undef, i64 %g1, i32 0
+  %g2 = call spir_func i64 @_Z13get_global_idj(i32 1)
+  %i2 = insertelement <3 x i64> %i1, i64 %g2, i32 1
+  %g3 = call spir_func i64 @_Z13get_global_idj(i32 2)
+  %i3 = insertelement <3 x i64> %i2, i64 %g3, i32 2
+  %e = extractelement <3 x i64> %i3, i32 0
+  %c1 = trunc i64 %e to i32
+  %c2 = sext i32 %c1 to i64
+  %b = bitcast ptr addrspace(1) %src to ptr addrspace(1) ; no-op ptr -> ptr bitcast
+
+; !kernel_arg_type claims char3* for %src, but vload3 takes char*: make sure that the builtin call directly uses either an OpBitcast or an OpFunctionParameter of i8* type
+; CHECK: %[[#BITCASTorPARAMETER:]] = {{.*}} %[[#PTRINT8]] {{.*}}
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] vloadn %[[#]] %[[#BITCASTorPARAMETER]] 3
+  %call = call spir_func <3 x i8> @_Z6vload3mPU3AS1Kc(i64 %c2, ptr addrspace(1) %b) ; vload3(ulong offset, const global char *p)
+  
+  ret void
+}
+
+declare spir_func i64 @_Z13get_global_idj(i32)
+
+declare spir_func <3 x i8> @_Z6vload3mPU3AS1Kc(i64, ptr addrspace(1))
+
+!1 = !{i32 1, i32 1}
+!2 = !{!"none", !"none"}
+!3 = !{!"char3*", !"char*"}
+!4 = !{!"", !""}



More information about the llvm-commits mailing list