[llvm] [SPIR-V] Do not rely on type metadata for ptr element type resolution (PR #82678)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 28 04:50:21 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Michal Paszkowski (michalpaszkowski)

<details>
<summary>Changes</summary>

This pull request aims to remove any dependency on OpenCL/SPIR-V type information in LLVM IR metadata. While using metadata might simplify and prettify the resulting SPIR-V output (and restore some of the information lost in the transition to opaque pointers), the overall methodology for resolving kernel parameter types is highly inefficient.

This pull request is a work in progress, but the high-level strategy is to assign kernel parameter types in the following order:

1. Resolving the types from builtin function calls, either from the type information that the mangled names must contain or by looking up the builtin definition in SPIRVBuiltins.td, and then:

   - assigning the type temporarily using an intrinsic and later setting the right SPIR-V type in SPIRVGlobalRegistry after IRTranslation;
   - inserting a bitcast.
2. Defaulting to LLVM IR types (for pointers, the generic i8* type).

In case of type incompatibility (e.g. a parameter initially defined as sampler_t and later used as image_t), the error will be detected early, before IRTranslation (most likely in the SPIRVEmitIntrinsics pass).

The code repetition in parseBuiltinCallArgumentBaseType(...) will be removed in an amended commit.



---

Patch is 62.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82678.diff


35 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+65-10) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+13-3) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+35-19) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+108-71) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+5-39) 
- (modified) llvm/lib/Target/SPIRV/SPIRVMetadata.cpp (-7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVMetadata.h (-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+4-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+29-13) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+3-4) 
- (modified) llvm/test/CodeGen/SPIRV/function/alloca-load-store.ll (+4-7) 
- (modified) llvm/test/CodeGen/SPIRV/half_no_extension.ll (-3) 
- (modified) llvm/test/CodeGen/SPIRV/instructions/undef-nested-composite-store.ll (+4-6) 
- (modified) llvm/test/CodeGen/SPIRV/instructions/undef-simple-composite-store.ll (+4-6) 
- (modified) llvm/test/CodeGen/SPIRV/opaque_pointers.ll (+5-8) 
- (modified) llvm/test/CodeGen/SPIRV/opencl/basic/get_global_offset.ll (+10-14) 
- (removed) llvm/test/CodeGen/SPIRV/opencl/metadata/kernel_arg_type_function_metadata.ll (-12) 
- (removed) llvm/test/CodeGen/SPIRV/opencl/metadata/kernel_arg_type_module_metadata.ll (-16) 
- (modified) llvm/test/CodeGen/SPIRV/opencl/vload2.ll (+14-6) 
- (added) llvm/test/CodeGen/SPIRV/opencl/vstore2.ll (+23) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-TargetExtType-arg-no-spv_assign_type.ll (+12) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-divergent-spv_assign_ptr_type.ll (+12) 
- (added) llvm/test/CodeGen/SPIRV/passes/SPIRVEmitIntrinsics-no-duplicate-spv_assign_type.ll (+14) 
- (modified) llvm/test/CodeGen/SPIRV/pointers/getelementptr-kernel-arg-char.ll (+6-14) 
- (added) llvm/test/CodeGen/SPIRV/pointers/kernel-argument-builtin-vload-type-discrapency.ll (+35) 
- (added) llvm/test/CodeGen/SPIRV/pointers/kernel-argument-pointer-type-deduction-mismatch.ll (+12) 
- (added) llvm/test/CodeGen/SPIRV/pointers/kernel-argument-pointer-type-deduction-no-metadata.ll (+13) 
- (added) llvm/test/CodeGen/SPIRV/pointers/store-operand-ptr-to-struct.ll (+19) 
- (renamed) llvm/test/CodeGen/SPIRV/pointers/two-bitcast-or-param-users.ll (+3-7) 
- (modified) llvm/test/CodeGen/SPIRV/pointers/two-subsequent-bitcasts.ll (+3-4) 
- (modified) llvm/test/CodeGen/SPIRV/sitofp-with-bool.ll (+2-3) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/OpenCL/atomic_cmpxchg.ll (+2-3) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/OpenCL/atomic_legacy.ll (+2-3) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/spirv-private-array-initialization.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/uitofp-with-bool.ll (+2-3) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c1bb27322443ff..119a15b5f1bfb9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1775,7 +1775,7 @@ static const Type *getMachineInstrType(MachineInstr *MI) {
     return nullptr;
   Type *Ty = getMDOperandAsType(NextMI->getOperand(2).getMetadata(), 0);
   assert(Ty && "Type is expected");
-  return getTypedPtrEltType(Ty);
+  return Ty;
 }
 
 static const Type *getBlockStructType(Register ParamReg,
@@ -1787,7 +1787,7 @@ static const Type *getBlockStructType(Register ParamReg,
   // section 6.12.5 should guarantee that we can do this.
   MachineInstr *MI = getBlockStructInstr(ParamReg, MRI);
   if (MI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE)
-    return getTypedPtrEltType(MI->getOperand(1).getGlobal()->getType());
+    return MI->getOperand(1).getGlobal()->getType();
   assert(isSpvIntrinsic(*MI, Intrinsic::spv_alloca) &&
          "Blocks in OpenCL C must be traceable to allocation site");
   return getMachineInstrType(MI);
@@ -2043,7 +2043,8 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
           .addImm(Builtin->Number);
   for (auto Argument : Call->Arguments)
     MIB.addUse(Argument);
-  MIB.addImm(Builtin->ElementCount);
+  if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
+    MIB.addImm(Builtin->ElementCount);
 
   // Rounding mode should be passed as a last argument in the MI for builtins
   // like "vstorea_halfn_r".
@@ -2179,6 +2180,61 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
   return false;
 }
 
+Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
+                                       unsigned ArgIdx, LLVMContext &Ctx) {
+  SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
+  StringRef BuiltinArgs =
+      DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
+  BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
+  if (ArgIdx >= BuiltinArgsTypeStrs.size())
+    return nullptr;
+  StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
+
+  // Parse strings representing OpenCL builtin types.
+  if (hasBuiltinTypePrefix(TypeStr)) {
+    // OpenCL builtin types in demangled call strings have the following format:
+    // e.g. ocl_image2d_ro
+    bool IsOCLBuiltinType = TypeStr.consume_front("ocl_");
+    assert(IsOCLBuiltinType && "Invalid OpenCL builtin prefix");
+
+    // Check if this is pointer to a builtin type and not just pointer
+    // representing a builtin type. In case it is a pointer to builtin type,
+    // this will require additional handling in the method calling
+    // parseBuiltinCallArgumentBaseType(...) as this function only retrieves the
+    // base types.
+    if (TypeStr.ends_with("*"))
+      TypeStr = TypeStr.slice(0, TypeStr.find_first_of(" "));
+
+    return parseBuiltinTypeNameToTargetExtType("opencl." + TypeStr.str() + "_t",
+                                               Ctx);
+  }
+
+  // Parse type name in either "typeN" or "type vector[N]" format, where
+  // N is the number of elements of the vector.
+  Type *BaseType;
+  unsigned VecElts = 0;
+
+  BaseType = parseBasicTypeName(TypeStr, Ctx);
+  if (!BaseType)
+    // Unable to recognize SPIRV type name.
+    return nullptr;
+
+  if (BaseType->isVoidTy())
+    BaseType = Type::getInt8Ty(Ctx);
+
+  // Handle "typeN*" or "type vector[N]*".
+  TypeStr.consume_back("*");
+
+  if (TypeStr.consume_front(" vector["))
+    TypeStr = TypeStr.substr(0, TypeStr.find(']'));
+
+  TypeStr.getAsInteger(10, VecElts);
+  if (VecElts > 0)
+    BaseType = VectorType::get(BaseType, VecElts, false);
+
+  return BaseType;
+}
+
 struct BuiltinType {
   StringRef Name;
   uint32_t Opcode;
@@ -2277,9 +2333,8 @@ static SPIRVType *getSampledImageType(const TargetExtType *OpaqueType,
 }
 
 namespace SPIRV {
-const TargetExtType *
-parseBuiltinTypeNameToTargetExtType(std::string TypeName,
-                                    MachineIRBuilder &MIRBuilder) {
+TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
+                                                   LLVMContext &Context) {
   StringRef NameWithParameters = TypeName;
 
   // Pointers-to-opaque-structs representing OpenCL types are first translated
@@ -2303,7 +2358,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
   // Parameterized SPIR-V builtins names follow this format:
   // e.g. %spirv.Image._void_1_0_0_0_0_0_0, %spirv.Pipe._0
   if (!NameWithParameters.contains('_'))
-    return TargetExtType::get(MIRBuilder.getContext(), NameWithParameters);
+    return TargetExtType::get(Context, NameWithParameters);
 
   SmallVector<StringRef> Parameters;
   unsigned BaseNameLength = NameWithParameters.find('_') - 1;
@@ -2313,7 +2368,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
   bool HasTypeParameter = !isDigit(Parameters[0][0]);
   if (HasTypeParameter)
     TypeParameters.push_back(parseTypeString(
-        Parameters[0], MIRBuilder.getMF().getFunction().getContext()));
+        Parameters[0], Context));
   SmallVector<unsigned> IntParameters;
   for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
     unsigned IntParameter = 0;
@@ -2323,7 +2378,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
            "Invalid format of SPIR-V builtin parameter literal!");
     IntParameters.push_back(IntParameter);
   }
-  return TargetExtType::get(MIRBuilder.getContext(),
+  return TargetExtType::get(Context,
                             NameWithParameters.substr(0, BaseNameLength),
                             TypeParameters, IntParameters);
 }
@@ -2343,7 +2398,7 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
   const TargetExtType *BuiltinType = dyn_cast<TargetExtType>(OpaqueType);
   if (!BuiltinType)
     BuiltinType = parseBuiltinTypeNameToTargetExtType(
-        OpaqueType->getStructName().str(), MIRBuilder);
+        OpaqueType->getStructName().str(), MIRBuilder.getContext());
 
   unsigned NumStartingVRegs = MIRBuilder.getMRI()->getNumVirtRegs();
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 6f957295464812..649f5bfd1d7c26 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -38,6 +38,17 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  const SmallVectorImpl<Register> &Args,
                                  SPIRVGlobalRegistry *GR);
 
+/// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
+/// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
+/// element type (e.g. i8 for char*), or builtin type (TargetExtType).
+///
+/// \return LLVM Type or nullptr if unrecognized
+///
+/// \p DemangledCall is the skeleton of the lowered builtin function call.
+/// \p ArgIdx is the index of the argument to parse.
+Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
+                                       unsigned ArgIdx, LLVMContext &Ctx);
+
 /// Translates a string representing a SPIR-V or OpenCL builtin type to a
 /// TargetExtType that can be further lowered with lowerBuiltinType().
 ///
@@ -45,9 +56,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
 ///
 /// \p TypeName is the full string representation of the SPIR-V or OpenCL
 /// builtin type.
-const TargetExtType *
-parseBuiltinTypeNameToTargetExtType(std::string TypeName,
-                                    MachineIRBuilder &MIRBuilder);
+TargetExtType *parseBuiltinTypeNameToTargetExtType(std::string TypeName,
+                                                   LLVMContext &Context);
 
 /// Handles the translation of the provided special opaque/builtin type \p Type
 /// to SPIR-V type. Generates the corresponding machine instructions for the
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index cc438b2bb8d4d7..f9197b805f0637 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -22,6 +22,8 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVUtils.h"
 #include "llvm/CodeGen/FunctionLoweringInfo.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
 #include "llvm/Support/ModRef.h"
 
 using namespace llvm;
@@ -157,28 +159,42 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
 
   Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
 
-  // In case of non-kernel SPIR-V function or already TargetExtType, use the
-  // original IR type.
-  if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
-      isSpecialOpaqueType(OriginalArgType))
+  // If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
+  // be legally reassigned later).
+  if (!OriginalArgType->isPointerTy())
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
-  SPIRVType *ResArgType = nullptr;
-  if (MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx)) {
-    StringRef MDTypeStr = MDKernelArgType->getString();
-    if (MDTypeStr.ends_with("*"))
-      ResArgType = GR->getOrCreateSPIRVTypeByName(
-          MDTypeStr, MIRBuilder,
-          addressSpaceToStorageClass(
-              OriginalArgType->getPointerAddressSpace()));
-    else if (MDTypeStr.ends_with("_t"))
-      ResArgType = GR->getOrCreateSPIRVTypeByName(
-          "opencl." + MDTypeStr.str(), MIRBuilder,
-          SPIRV::StorageClass::Function, ArgAccessQual);
+  // In case OriginalArgType is of pointer type, there are two possibilities:
+  // 1) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
+  // intrinsic assigning a TargetExtType.
+  // 2) This is a pointer, try to retrieve pointer element type from a
+  // spv_assign_ptr_type intrinsic or otherwise use default pointer element
+  // type.
+  for (auto User : F.getArg(ArgIdx)->users()) {
+    auto *II = dyn_cast<IntrinsicInst>(User);
+    // Check if this is spv_assign_type assigning OpenCL/SPIR-V builtin type.
+    if (II && II->getIntrinsicID() == Intrinsic::spv_assign_type) {
+      MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
+      Type *BuiltinType =
+          cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
+      assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
+      return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, ArgAccessQual);
+    }
+
+    // Check if this is spv_assign_ptr_type assigning pointer element type.
+    if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_ptr_type)
+      continue;
+
+    MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
+    SPIRVType *ElementType = GR->getOrCreateSPIRVType(
+        cast<ConstantAsMetadata>(VMD->getMetadata())->getType(), MIRBuilder);
+    return GR->getOrCreateSPIRVPointerType(
+        ElementType, MIRBuilder,
+        addressSpaceToStorageClass(
+            cast<ConstantInt>(II->getOperand(2))->getZExtValue()));
   }
-  return ResArgType ? ResArgType
-                    : GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,
-                                               ArgAccessQual);
+
+  return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 }
 
 static SPIRV::ExecutionModel::ExecutionModel
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e32cd50be56e38..c627427bd9c7a0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRV.h"
+#include "SPIRVBuiltins.h"
 #include "SPIRVMetadata.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
@@ -75,7 +76,11 @@ class SPIRVEmitIntrinsics
   void processInstrAfterVisit(Instruction *I);
   void insertAssignPtrTypeIntrs(Instruction *I);
   void insertAssignTypeIntrs(Instruction *I);
-  void insertPtrCastInstr(Instruction *I);
+  void insertAssignTypeInstrForTargetExtTypes(TargetExtType* AssignedType, Value *V);
+  void replacePointerOperandWithPtrCast(Instruction *I, Value *Pointer,
+                                        Type *ExpectedElementType,
+                                        unsigned OperandToReplace);
+  void insertPtrCastOrAssignTypeInstr(Instruction *I);
   void processGlobalValue(GlobalVariable &GV);
 
 public:
@@ -130,13 +135,6 @@ static void setInsertPointSkippingPhis(IRBuilder<> &B, Instruction *I) {
     B.SetInsertPoint(I);
 }
 
-static bool requireAssignPtrType(Instruction *I) {
-  if (isa<AllocaInst>(I) || isa<GetElementPtrInst>(I))
-    return true;
-
-  return false;
-}
-
 static bool requireAssignType(Instruction *I) {
   IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(I);
   if (Intr) {
@@ -269,7 +267,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
   // SPIR-V, contrary to LLVM 17+ IR, supports bitcasts between pointers of
   // varying element types. In case of IR coming from older versions of LLVM
   // such bitcasts do not provide sufficient information, should be just skipped
-  // here, and handled in insertPtrCastInstr.
+  // here, and handled in insertPtrCastOrAssignTypeInstr.
   if (I.getType()->isPointerTy()) {
     I.replaceAllUsesWith(Source);
     I.eraseFromParent();
@@ -286,34 +284,37 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
   return NewI;
 }
 
-void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
-  Value *Pointer;
-  Type *ExpectedElementType;
-  unsigned OperandToReplace;
+void SPIRVEmitIntrinsics::insertAssignTypeInstrForTargetExtTypes(
+    TargetExtType *AssignedType, Value *V) {
+  // Do not emit spv_assign_type if the V is of the AssignedType already.
+  if (V->getType() == AssignedType)
+    return;
 
-  StoreInst *SI = dyn_cast<StoreInst>(I);
-  if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
-      SI->getValueOperand()->getType()->isPointerTy() &&
-      isa<Argument>(SI->getValueOperand())) {
-    Pointer = SI->getValueOperand();
-    ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
-    OperandToReplace = 0;
-  } else if (SI) {
-    Pointer = SI->getPointerOperand();
-    ExpectedElementType = SI->getValueOperand()->getType();
-    OperandToReplace = 1;
-  } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
-    Pointer = LI->getPointerOperand();
-    ExpectedElementType = LI->getType();
-    OperandToReplace = 0;
-  } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
-    Pointer = GEPI->getPointerOperand();
-    ExpectedElementType = GEPI->getSourceElementType();
-    OperandToReplace = 0;
-  } else {
+  // Do not emit spv_assign_type if there is one already targetting V. If the
+  // found spv_assign_type assigns a type different than AssignedType, report an
+  // error. Builtin types cannot be redeclared or casted.
+  for (auto User : V->users()) {
+    auto *II = dyn_cast<IntrinsicInst>(User);
+    if (!II || II->getIntrinsicID() != Intrinsic::spv_assign_type)
+      continue;
+
+    MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
+    Type *BuiltinType = dyn_cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
+    if (BuiltinType != AssignedType)
+      report_fatal_error("Type mismatch " + BuiltinType->getTargetExtName() +
+                             "/" + AssignedType->getTargetExtName() +
+                             " for value " + V->getName(),
+                         false);
     return;
   }
 
+  Constant *Const = UndefValue::get(AssignedType);
+  buildIntrWithMD(Intrinsic::spv_assign_type, {V->getType()}, Const, V, {});
+}
+
+void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
+    Instruction *I, Value *Pointer, Type *ExpectedElementType,
+    unsigned OperandToReplace) {
   // If Pointer is the result of nop BitCastInst (ptr -> ptr), use the source
   // pointer instead. The BitCastInst should be later removed when visited.
   while (BitCastInst *BC = dyn_cast<BitCastInst>(Pointer))
@@ -378,38 +379,76 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
     return;
   }
 
-  // Do not emit spv_ptrcast if it would cast to the default pointer element
-  // type (i8) of the same address space. In case of OpenCL kernels, make sure
-  // i8 is the pointer element type defined for the given kernel argument.
-  if (ExpectedElementType->isIntegerTy(8) &&
-      F->getCallingConv() != CallingConv::SPIR_KERNEL)
-    return;
+  // // Do not emit spv_ptrcast if it would cast to the default pointer element
+  // // type (i8) of the same address space.
+  // if (ExpectedElementType->isIntegerTy(8))
+  //   return;
 
-  Argument *Arg = dyn_cast<Argument>(Pointer);
-  if (ExpectedElementType->isIntegerTy(8) &&
-      F->getCallingConv() == CallingConv::SPIR_KERNEL && Arg) {
-    MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
-    if (ArgType && ArgType->getString().starts_with("uchar*"))
-      return;
-  }
-
-  // If this would be the first spv_ptrcast, the pointer's defining instruction
-  // requires spv_assign_ptr_type and does not already have one, do not emit
-  // spv_ptrcast and emit spv_assign_ptr_type instead.
-  Instruction *PointerDefInst = dyn_cast<Instruction>(Pointer);
-  if (FirstPtrCastOrAssignPtrType && PointerDefInst &&
-      requireAssignPtrType(PointerDefInst)) {
+  // If this would be the first spv_ptrcast, do not emit spv_ptrcast and emit
+  // spv_assign_ptr_type instead.
+  if (FirstPtrCastOrAssignPtrType &&
+      (isa<Instruction>(Pointer) || isa<Argument>(Pointer))) {
     buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
                     ExpectedElementTypeConst, Pointer,
                     {IRB->getInt32(AddressSpace)});
     return;
-  } else {
-    SmallVector<Type *, 2> Types = {Pointer->getType(), Pointer->getType()};
-    SmallVector<Value *, 2> Args = {Pointer, VMD, IRB->getInt32(AddressSpace)};
-    auto *PtrCastI =
-        IRB->CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
-    I->setOperand(OperandToReplace, PtrCastI);
+  }
+
+  // Emit spv_ptrcast
+  SmallVector<Type *, 2> Types = {Pointer->getType(), Pointer->getType()};
+  SmallVector<Value *, 2> Args = {Pointer, VMD, IRB->getInt32(AddressSpace)};
+  auto *PtrCastI = IRB->CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+  I->setOperand(OperandToReplace, PtrCastI);
+}
+
+void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I) {
+  // Handle basic instructions:
+  StoreInst *SI = dyn_cast<StoreInst>(I);
+  if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
+      SI->getValueOperand()->getType()->isPointerTy() &&
+      isa<Argument>(SI->getValueOperand())) {
+    return replacePointerOperandWithPtrCast(
+        I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0);
+  } else if (SI) {
+    return replacePointerOperandWithPtrCast(
+        I, SI->getPointerOperand(), SI->getValueOperand()->getType(), 1);
+  } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+    return replacePointerOperandWithPtrCast(I, LI->getPointerOperand(),
+                                            LI->getType(), 0);
+  } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
+    return replacePointerOperandWithPtrCast(I, GEPI->getPointerOperand(),
+                                            GEPI->getSourceElementType(), 0);
+  }
+
+  // Handle calls to builtins (non-intrinsics):
+  CallInst *CI = dyn_cast<CallInst>(I);
+  if (!CI || CI->isIndirectCall() || CI->getCalledFunction()->isIntrinsic())
+    return;
+
+  std::string DemangledName =
+      getOclOrSpirvBuiltinDemangledName(CI->getCalledFunction()->getName());
+  if (DemangledName.empty())
     return;
+
+  for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
+    Value *ArgOperand = CI->getArgOperand(OpIdx);
+    if (!isa<PointerType>(ArgOperand->getType()))
+      continue;
+
+    // Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
+    if (!isa<Instruction>(ArgOperand) && !isa<Argument>(ArgOperand))
+      continue;
+
+    Type *ExpectedType = SPIRV::parseBuiltinCallArgumentBaseType(
+        DemangledName, OpIdx, I->getContext());
+    if (!ExpectedType)
+      cont...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82678


More information about the llvm-commits mailing list