[llvm] [SPIRV] Improve type inference of operand presented by opaque pointers and aggregate types (PR #98035)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 9 14:45:30 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR improves type inference of operands represented by opaque pointers and aggregate types:
* tries to restore the original function return type for aggregate types, so that a correct type can be deduced during the emit-intrinsics step (see llvm/test/CodeGen/SPIRV/SpecConstants/restore-spec-type.ll for a reproducer of the previously existing issue, where spirv-val found a mismatch between the object and pointer types in OpStore because aggregate types were traced incorrectly),
* explores untyped pointer operands of store instructions to deduce the correct pointee types,
* creates an extension type to track pointee types from the emit-intrinsics step onward, instead of using TypedPointerType directly and naively, which previously led to crashes because creating a Value of TypedPointerType type is prohibited,
* tracks instructions with incomplete type information and tries to improve their type info after the pass has calculated types for all machine functions (it doesn't traverse the code, but checks only those instructions that were tracked as incomplete),
* addresses more cases of removing unnecessary bitcasts (see, for example, the changes in test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll, where the `CHECK-SPIRV-NEXT` LIT checks show that unneeded bitcasts are gone and that the unmangled and mangled versions are now typed properly, with equivalent type info),
* addresses more cases of well-known types or relations between types within instructions (see, for example, the atomic*.ll test cases and the Event-related test cases for the improved SPIR-V code generated by the backend).
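
The tracking scheme from the list above (record values whose pointee type had to fall back to `i8`, then revisit only those values once more type information is available) can be sketched with a toy model. This is a hypothetical, self-contained C++ illustration, not the backend's code: values are plain strings, and `ToyTypeInference`, `deduce` and `postprocess` are invented names whose containers only loosely mirror the `UncompleteTypeInfo` set and `PostprocessWorklist` vector added by the patch.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical toy model (not the SPIR-V backend's API): values are identified
// by name, and "i8" stands for the fallback pointee type used when nothing
// better is known during the first pass.
struct ToyTypeInference {
  std::unordered_map<std::string, std::string> DeducedElemTy;
  std::unordered_set<std::string> Uncomplete; // loosely mirrors UncompleteTypeInfo
  std::vector<std::string> Worklist;          // loosely mirrors PostprocessWorklist

  // First pass: store a known pointee type, or fall back to i8 and remember
  // the value so that it can be revisited later.
  std::string deduce(const std::string &V, const std::string *Known) {
    if (Known) {
      DeducedElemTy[V] = *Known;
      return *Known;
    }
    DeducedElemTy[V] = "i8";
    if (Uncomplete.insert(V).second) // avoid duplicate worklist entries
      Worklist.push_back(V);
    return "i8";
  }

  // Postprocessing: visit only the tracked values (no full traversal) and
  // upgrade those for which better information has become available.
  bool postprocess(const std::unordered_map<std::string, std::string> &LateInfo) {
    bool Changed = false;
    for (const std::string &V : Worklist) {
      if (Uncomplete.count(V) == 0)
        continue; // already resolved earlier
      auto It = LateInfo.find(V);
      if (It == LateInfo.end())
        continue; // still unknown: keep the i8 fallback
      DeducedElemTy[V] = It->second;
      Uncomplete.erase(V);
      Changed = true;
    }
    return Changed;
  }
};
```

Unlike the patch, which appears to push an instruction onto the worklist every time it is recorded, this sketch uses the result of the set insertion to avoid duplicate worklist entries; that is a choice made for the example only.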

---

Patch is 82.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98035.diff


19 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+26-19) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+318-61) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+14) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+3-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+2-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp (+5) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+55-1) 
- (added) llvm/test/CodeGen/SPIRV/SpecConstants/restore-spec-type.ll (+46) 
- (added) llvm/test/CodeGen/SPIRV/instructions/atomic-ptr.ll (+38) 
- (modified) llvm/test/CodeGen/SPIRV/instructions/atomic.ll (+38-43) 
- (modified) llvm/test/CodeGen/SPIRV/instructions/atomic_acqrel.ll (+28-36) 
- (modified) llvm/test/CodeGen/SPIRV/instructions/atomic_seq.ll (+28-36) 
- (modified) llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll (+4-8) 
- (added) llvm/test/CodeGen/SPIRV/pointers/type-deduce-sycl-stub.ll (+127) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll (+40-40) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/OpGroupAsyncCopy-strided.ll (+3-5) 
- (modified) llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll (+1-1) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 286bdb9a7ebac..1609576c038d0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -169,21 +169,9 @@ using namespace InstructionSet;
 // TableGen records
 //===----------------------------------------------------------------------===//
 
-/// Looks up the demangled builtin call in the SPIRVBuiltins.td records using
-/// the provided \p DemangledCall and specified \p Set.
-///
-/// The lookup follows the following algorithm, returning the first successful
-/// match:
-/// 1. Search with the plain demangled name (expecting a 1:1 match).
-/// 2. Search with the prefix before or suffix after the demangled name
-/// signyfying the type of the first argument.
-///
-/// \returns Wrapper around the demangled call and found builtin definition.
-static std::unique_ptr<const SPIRV::IncomingCall>
-lookupBuiltin(StringRef DemangledCall,
-              SPIRV::InstructionSet::InstructionSet Set,
-              Register ReturnRegister, const SPIRVType *ReturnType,
-              const SmallVectorImpl<Register> &Arguments) {
+namespace SPIRV {
+/// Parses the name part of the demangled builtin call.
+std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
   const static std::string PassPrefix = "(anonymous namespace)::";
   std::string BuiltinName;
   // Itanium Demangler result may have "(anonymous namespace)::" prefix
@@ -215,6 +203,27 @@ lookupBuiltin(StringRef DemangledCall,
     BuiltinName = BuiltinName.substr(0, BuiltinName.find("_R"));
   }
 
+  return BuiltinName;
+}
+} // namespace SPIRV
+
+/// Looks up the demangled builtin call in the SPIRVBuiltins.td records using
+/// the provided \p DemangledCall and specified \p Set.
+///
+/// The lookup follows the following algorithm, returning the first successful
+/// match:
+/// 1. Search with the plain demangled name (expecting a 1:1 match).
+/// 2. Search with the prefix before or suffix after the demangled name
+/// signyfying the type of the first argument.
+///
+/// \returns Wrapper around the demangled call and found builtin definition.
+static std::unique_ptr<const SPIRV::IncomingCall>
+lookupBuiltin(StringRef DemangledCall,
+              SPIRV::InstructionSet::InstructionSet Set,
+              Register ReturnRegister, const SPIRVType *ReturnType,
+              const SmallVectorImpl<Register> &Arguments) {
+  std::string BuiltinName = SPIRV::lookupBuiltinNameHelper(DemangledCall);
+
   SmallVector<StringRef, 10> BuiltinArgumentTypes;
   StringRef BuiltinArgs =
       DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
@@ -2610,9 +2619,6 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
     // Unable to recognize SPIRV type name.
     return nullptr;
 
-  if (BaseType->isVoidTy())
-    BaseType = Type::getInt8Ty(Ctx);
-
   // Handle "typeN*" or "type vector[N]*".
   TypeStr.consume_back("*");
 
@@ -2621,7 +2627,8 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
 
   TypeStr.getAsInteger(10, VecElts);
   if (VecElts > 0)
-    BaseType = VectorType::get(BaseType, VecElts, false);
+    BaseType = VectorType::get(
+        BaseType->isVoidTy() ? Type::getInt8Ty(Ctx) : BaseType, VecElts, false);
 
   return BaseType;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 68bff602d1d10..d07fc7c6ca874 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -19,6 +19,8 @@
 
 namespace llvm {
 namespace SPIRV {
+/// Parses the name part of the demangled builtin call.
+std::string lookupBuiltinNameHelper(StringRef DemangledCall);
 /// Lowers a builtin function call using the provided \p DemangledCall skeleton
 /// and external instruction \p Set.
 ///
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 566eafd41e9bd..d9864ab50ecfe 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -46,6 +46,10 @@
 using namespace llvm;
 
 namespace llvm {
+namespace SPIRV {
+#define GET_BuiltinGroup_DECL
+#include "SPIRVGenTables.inc"
+} // namespace SPIRV
 void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
 } // namespace llvm
 
@@ -69,22 +73,38 @@ class SPIRVEmitIntrinsics
   DenseSet<Instruction *> AggrStores;
   SPIRV::InstructionSet::InstructionSet InstrSet;
 
+  // a register of Instructions that don't have a complete type definition
+  SmallPtrSet<Value *, 8> UncompleteTypeInfo;
+  SmallVector<Instruction *> PostprocessWorklist;
+
+  // well known result types of builtins
+  enum WellKnownTypes { Event };
+
   // deduce element type of untyped pointers
   Type *deduceElementType(Value *I, bool UnknownElemTypeI8);
-  Type *deduceElementTypeHelper(Value *I);
-  Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited);
+  Type *deduceElementTypeHelper(Value *I, bool UnknownElemTypeI8);
+  Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited,
+                                bool UnknownElemTypeI8);
   Type *deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand,
-                                     std::unordered_set<Value *> &Visited);
+                                     bool UnknownElemTypeI8);
+  Type *deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand,
+                                     std::unordered_set<Value *> &Visited,
+                                     bool UnknownElemTypeI8);
   Type *deduceElementTypeByUsersDeep(Value *Op,
-                                     std::unordered_set<Value *> &Visited);
+                                     std::unordered_set<Value *> &Visited,
+                                     bool UnknownElemTypeI8);
+  void maybeAssignPtrType(Type *&Ty, Value *I, Type *RefTy,
+                          bool UnknownElemTypeI8);
 
   // deduce nested types of composites
-  Type *deduceNestedTypeHelper(User *U);
+  Type *deduceNestedTypeHelper(User *U, bool UnknownElemTypeI8);
   Type *deduceNestedTypeHelper(User *U, Type *Ty,
-                               std::unordered_set<Value *> &Visited);
+                               std::unordered_set<Value *> &Visited,
+                               bool UnknownElemTypeI8);
 
   // deduce Types of operands of the Instruction if possible
-  void deduceOperandElementType(Instruction *I);
+  void deduceOperandElementType(Instruction *I, Instruction *AskOp = 0,
+                                Type *AskTy = 0, CallInst *AssignCI = 0);
 
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
@@ -151,6 +171,7 @@ class SPIRVEmitIntrinsics
 
   bool runOnModule(Module &M) override;
   bool runOnFunction(Function &F);
+  bool postprocessTypes();
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     ModulePass::getAnalysisUsage(AU);
@@ -223,6 +244,41 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
                        false);
 }
 
+static bool IsKernelArgInt8(Function *F, StoreInst *SI) {
+  return SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
+         isPointerTy(SI->getValueOperand()->getType()) &&
+         isa<Argument>(SI->getValueOperand());
+}
+
+// Maybe restore original function return type.
+static inline Type *restoreMutatedType(SPIRVGlobalRegistry *GR, Instruction *I,
+                                       Type *Ty) {
+  CallInst *CI = dyn_cast<CallInst>(I);
+  if (!CI || CI->isIndirectCall() || CI->isInlineAsm() ||
+      !CI->getCalledFunction() || CI->getCalledFunction()->isIntrinsic())
+    return Ty;
+  if (Type *OriginalTy = GR->findMutated(CI->getCalledFunction()))
+    return OriginalTy;
+  return Ty;
+}
+
+// Reconstruct type with nested element types according to deduced type info.
+// Return nullptr if no detailed type info is available.
+static inline Type *reconstructType(SPIRVGlobalRegistry *GR, Value *Op) {
+  Type *Ty = Op->getType();
+  if (!isUntypedPointerTy(Ty))
+    return Ty;
+  // try to find the pointee type
+  if (Type *NestedTy = GR->findDeducedElementType(Op))
+    return getTypedPointerWrapper(NestedTy, getPointerAddressSpace(Ty));
+  // not a pointer according to the type info (e.g., Event object)
+  CallInst *CI = GR->findAssignPtrTypeInstr(Op);
+  if (!CI)
+    return nullptr;
+  MetadataAsValue *MD = cast<MetadataAsValue>(CI->getArgOperand(1));
+  return cast<ConstantAsMetadata>(MD->getMetadata())->getType();
+}
+
 void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty,
                                           Value *Arg) {
   Value *OfType = PoisonValue::get(Ty);
@@ -263,15 +319,26 @@ void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg,
 
 // Set element pointer type to the given value of ValueTy and tries to
 // specify this type further (recursively) by Operand value, if needed.
+Type *
+SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand,
+                                                  bool UnknownElemTypeI8) {
+  std::unordered_set<Value *> Visited;
+  return deduceElementTypeByValueDeep(ValueTy, Operand, Visited,
+                                      UnknownElemTypeI8);
+}
+
 Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
-    Type *ValueTy, Value *Operand, std::unordered_set<Value *> &Visited) {
+    Type *ValueTy, Value *Operand, std::unordered_set<Value *> &Visited,
+    bool UnknownElemTypeI8) {
   Type *Ty = ValueTy;
   if (Operand) {
     if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
-      if (Type *NestedTy = deduceElementTypeHelper(Operand, Visited))
-        Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+      if (Type *NestedTy =
+              deduceElementTypeHelper(Operand, Visited, UnknownElemTypeI8))
+        Ty = getTypedPointerWrapper(NestedTy, PtrTy->getAddressSpace());
     } else {
-      Ty = deduceNestedTypeHelper(dyn_cast<User>(Operand), Ty, Visited);
+      Ty = deduceNestedTypeHelper(dyn_cast<User>(Operand), Ty, Visited,
+                                  UnknownElemTypeI8);
     }
   }
   return Ty;
@@ -279,12 +346,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
 
 // Traverse User instructions to deduce an element pointer type of the operand.
 Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
-    Value *Op, std::unordered_set<Value *> &Visited) {
+    Value *Op, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8) {
   if (!Op || !isPointerTy(Op->getType()))
     return nullptr;
 
-  if (auto PType = dyn_cast<TypedPointerType>(Op->getType()))
-    return PType->getElementType();
+  if (auto ElemTy = getPointeeType(Op->getType()))
+    return ElemTy;
 
   // maybe we already know operand's element type
   if (Type *KnownTy = GR->findDeducedElementType(Op))
@@ -292,7 +359,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
 
   for (User *OpU : Op->users()) {
     if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
-      if (Type *Ty = deduceElementTypeHelper(Inst, Visited))
+      if (Type *Ty = deduceElementTypeHelper(Inst, Visited, UnknownElemTypeI8))
         return Ty;
     }
   }
@@ -314,13 +381,27 @@ static Type *getPointeeTypeByCallInst(StringRef DemangledName,
 
 // Deduce and return a successfully deduced Type of the Instruction,
 // or nullptr otherwise.
-Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I,
+                                                   bool UnknownElemTypeI8) {
   std::unordered_set<Value *> Visited;
-  return deduceElementTypeHelper(I, Visited);
+  return deduceElementTypeHelper(I, Visited, UnknownElemTypeI8);
+}
+
+void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
+                                             bool UnknownElemTypeI8) {
+  if (isUntypedPointerTy(RefTy)) {
+    if (!UnknownElemTypeI8)
+      return;
+    if (auto *I = dyn_cast<Instruction>(Op)) {
+      UncompleteTypeInfo.insert(I);
+      PostprocessWorklist.push_back(I);
+    }
+  }
+  Ty = RefTy;
 }
 
 Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
-    Value *I, std::unordered_set<Value *> &Visited) {
+    Value *I, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8) {
   // allow to pass nullptr as an argument
   if (!I)
     return nullptr;
@@ -338,34 +419,41 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
   Type *Ty = nullptr;
   // look for known basic patterns of type inference
   if (auto *Ref = dyn_cast<AllocaInst>(I)) {
-    Ty = Ref->getAllocatedType();
+    maybeAssignPtrType(Ty, I, Ref->getAllocatedType(), UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
     Ty = Ref->getResultElementType();
   } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
     Ty = deduceElementTypeByValueDeep(
         Ref->getValueType(),
-        Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited);
+        Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited,
+        UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
-    Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited);
+    Type *RefTy = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
+                                          UnknownElemTypeI8);
+    maybeAssignPtrType(Ty, I, RefTy, UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
     if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
         isPointerTy(Src) && isPointerTy(Dest))
-      Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited);
+      Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited,
+                                   UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<AtomicCmpXchgInst>(I)) {
     Value *Op = Ref->getNewValOperand();
-    Ty = deduceElementTypeByValueDeep(Op->getType(), Op, Visited);
+    if (isPointerTy(Op->getType()))
+      Ty = deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<AtomicRMWInst>(I)) {
     Value *Op = Ref->getValOperand();
-    Ty = deduceElementTypeByValueDeep(Op->getType(), Op, Visited);
+    if (isPointerTy(Op->getType()))
+      Ty = deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<PHINode>(I)) {
     for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
-      Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited);
+      Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited,
+                                        UnknownElemTypeI8);
       if (Ty)
         break;
     }
   } else if (auto *Ref = dyn_cast<SelectInst>(I)) {
     for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
-      Ty = deduceElementTypeByUsersDeep(Op, Visited);
+      Ty = deduceElementTypeByUsersDeep(Op, Visited, UnknownElemTypeI8);
       if (Ty)
         break;
     }
@@ -384,10 +472,12 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
     if (Function *CalledF = CI->getCalledFunction()) {
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+      if (DemangledName.length() > 0)
+        DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
       auto AsArgIt = ResTypeByArg.find(DemangledName);
       if (AsArgIt != ResTypeByArg.end()) {
         Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
-                                     Visited);
+                                     Visited, UnknownElemTypeI8);
       }
     }
   }
@@ -404,13 +494,15 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
 // Re-create a type of the value if it has untyped pointer fields, also nested.
 // Return the original value type if no corrections of untyped pointer
 // information is found or needed.
-Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U) {
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U,
+                                                  bool UnknownElemTypeI8) {
   std::unordered_set<Value *> Visited;
-  return deduceNestedTypeHelper(U, U->getType(), Visited);
+  return deduceNestedTypeHelper(U, U->getType(), Visited, UnknownElemTypeI8);
 }
 
 Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
-    User *U, Type *OrigTy, std::unordered_set<Value *> &Visited) {
+    User *U, Type *OrigTy, std::unordered_set<Value *> &Visited,
+    bool UnknownElemTypeI8) {
   if (!U)
     return OrigTy;
 
@@ -432,10 +524,12 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
       Type *Ty = OpTy;
       if (Op) {
         if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
-          if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+          if (Type *NestedTy =
+                  deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8))
             Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
         } else {
-          Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+          Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited,
+                                      UnknownElemTypeI8);
         }
       }
       Tys.push_back(Ty);
@@ -451,10 +545,12 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
       Type *OpTy = ArrTy->getElementType();
       Type *Ty = OpTy;
       if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
-        if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+        if (Type *NestedTy =
+                deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8))
           Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
       } else {
-        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited,
+                                    UnknownElemTypeI8);
       }
       if (Ty != OpTy) {
         Type *NewTy = ArrayType::get(Ty, ArrTy->getNumElements());
@@ -467,10 +563,12 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
       Type *OpTy = VecTy->getElementType();
       Type *Ty = OpTy;
       if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
-        if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
-          Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+        if (Type *NestedTy =
+                deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8))
+          Ty = getTypedPointerWrapper(NestedTy, PtrTy->getAddressSpace());
       } else {
-        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited,
+                                    UnknownElemTypeI8);
       }
       if (Ty != OpTy) {
         Type *NewTy = VectorType::get(Ty, VecTy->getElementCount());
@@ -484,16 +582,38 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
 }
 
 Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) {
-  if (Type *Ty = deduceElementTypeHelper(I))
+  if (Type *Ty = deduceElementTypeHelper(I, UnknownElemTypeI8))
     return Ty;
-  return UnknownElemTypeI8 ? IntegerType::getInt8Ty(I->getContext()) : nullptr;
+  if (!UnknownElemTypeI8)
+    return nullptr;
+  if (auto *Instr = dyn_cast<Instruction>(I)) {
+    UncompleteTypeInfo.insert(Instr);
+    PostprocessWorklist.push_back(Instr);
+  }
+  return IntegerType::getInt8Ty(I->getContext());
+}
+
+static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
+                                    Value *PointerOperand) {
+  Type *PointeeTy = GR->findDeducedElementType(PointerOperand);
+  if (PointeeTy && !isUntypedPointerTy(PointeeTy))
+    return nullptr;
+  auto *PtrTy = dyn_cast<PointerType>(I->getType());
+  if (!PtrTy)
+    return I->getType();
+  if (Type *NestedTy = GR->findDeducedElementType(I))
+    return getTypedPointerWrapper(NestedTy, PtrTy->getAddressSpace());
+  return nullptr;
 }
 
 // If the Instruction has Pointer operands with unresolved types, this function
 // tries to deduce them. If the Instruction has Pointer operands with known
 // types which differ from expected, this function tries to insert a bitcast to
 // resolve the issue.
-void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
+void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
+  ...
[truncated]

``````````
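
The `parseBuiltinCallArgumentBaseType` hunk in `SPIRVBuiltins.cpp` changes when a `void` base type is replaced by `i8`: before the patch this happened unconditionally, while afterwards it happens only when the parsed type string carries a vector element count. A minimal sketch of the before/after behavior, using a hypothetical string-based `ToyType` instead of `llvm::Type` (the names `ToyType`, `parseOld` and `parseNew` are invented for illustration):

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an LLVM type: a scalar type name plus a vector
// element count (0 means "not a vector").
struct ToyType {
  std::string Scalar;
  unsigned VecElts;
};

// Before the patch: a "void" base type was rewritten to i8 unconditionally,
// before the vector element count was examined.
ToyType parseOld(std::string Base, unsigned VecElts) {
  if (Base == "void")
    Base = "i8";
  return ToyType{Base, VecElts};
}

// After the patch: "void" is replaced by i8 only when it becomes the element
// type of a vector; a scalar "void" base type is returned unchanged.
ToyType parseNew(const std::string &Base, unsigned VecElts) {
  if (VecElts > 0)
    return ToyType{Base == "void" ? std::string("i8") : Base, VecElts};
  return ToyType{Base, 0};
}
```

For example, `parseNew("void", 4)` yields a 4-element `i8` vector, whereas `parseNew("void", 0)` keeps `void`, which the old code would have turned into `i8`.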

</details>


https://github.com/llvm/llvm-project/pull/98035

