[llvm] b7ac8fd - [SPIR-V] Improve type inference: deduce types of composite data structures (#86782)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 28 00:08:10 PDT 2024


Author: Vyacheslav Levytskyy
Date: 2024-03-28T08:08:06+01:00
New Revision: b7ac8fddb54816256fab70696ebc176717a391c3

URL: https://github.com/llvm/llvm-project/commit/b7ac8fddb54816256fab70696ebc176717a391c3
DIFF: https://github.com/llvm/llvm-project/commit/b7ac8fddb54816256fab70696ebc176717a391c3.diff

LOG: [SPIR-V] Improve type inference: deduce types of composite data structures (#86782)

This PR improves type inference in general and deduces types of
composite data structures in particular. It also adds a way to insert a
bitcast that keeps a function call valid when argument types mismatch
because of type inference for opaque pointers.

The attached test `pointers/nested-struct-opaque-pointers.ll`
demonstrates the new capabilities: the SPIR-V code emitted for this test
is now (1) valid in the sense of data field types and (2) accepted by
`spirv-val`.

Stricter LIT checks, support for more composite data structures, and
improved type correctness of function calls are the main TODOs at the
moment.

Added: 
    llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
    llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
    llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
    llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
    llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
    llvm/lib/Target/SPIRV/SPIRVUtils.h
    llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
    llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index afdca01561b0bc..ad4e72a3128b1e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -201,21 +201,30 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   if (!isPointerTy(OriginalArgType))
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
-  // In case OriginalArgType is of pointer type, there are three possibilities:
+  Argument *Arg = F.getArg(ArgIdx);
+  Type *ArgType = Arg->getType();
+  if (isTypedPointerTy(ArgType)) {
+    SPIRVType *ElementType = GR->getOrCreateSPIRVType(
+        cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
+    return GR->getOrCreateSPIRVPointerType(
+        ElementType, MIRBuilder,
+        addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
+  }
+
+  // In case OriginalArgType is of untyped pointer type, there are three
+  // possibilities:
   // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
   // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
-  // intrinsic assigning a TargetExtType.
+  //    intrinsic assigning a TargetExtType.
   // 3) This is a pointer, try to retrieve pointer element type from a
   // spv_assign_ptr_type intrinsic or otherwise use default pointer element
   // type.
-  Argument *Arg = F.getArg(ArgIdx);
-  if (HasPointeeTypeAttr(Arg)) {
-    Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
-                                             : Arg->getParamByRefType();
-    SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
+  if (hasPointeeTypeAttr(Arg)) {
+    SPIRVType *ElementType =
+        GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
     return GR->getOrCreateSPIRVPointerType(
         ElementType, MIRBuilder,
-        addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST));
+        addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
   }
 
   for (auto User : Arg->users()) {

diff  --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5828db6669ff18..7c5a38fa48d009 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -14,6 +14,7 @@
 #include "SPIRV.h"
 #include "SPIRVBuiltins.h"
 #include "SPIRVMetadata.h"
+#include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
 #include "llvm/IR/IRBuilder.h"
@@ -53,14 +54,22 @@ class SPIRVEmitIntrinsics
     : public FunctionPass,
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
   SPIRVTargetMachine *TM = nullptr;
+  SPIRVGlobalRegistry *GR = nullptr;
   Function *F = nullptr;
   bool TrackConstants = true;
   DenseMap<Instruction *, Constant *> AggrConsts;
+  DenseMap<Instruction *, Type *> AggrConstTypes;
   DenseSet<Instruction *> AggrStores;
 
-  // deduce values type
-  DenseMap<Value *, Type *> DeducedElTys;
+  // deduce element type of untyped pointers
   Type *deduceElementType(Value *I);
+  Type *deduceElementTypeHelper(Value *I);
+  Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited);
+
+  // deduce nested types of composites
+  Type *deduceNestedTypeHelper(User *U);
+  Type *deduceNestedTypeHelper(User *U, Type *Ty,
+                               std::unordered_set<Value *> &Visited);
 
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
@@ -92,9 +101,9 @@ class SPIRVEmitIntrinsics
   void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
   void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
   void processParamTypes(Function *F, IRBuilder<> &B);
-  Type *deduceFunParamType(Function *F, unsigned OpIdx);
-  Type *deduceFunParamType(Function *F, unsigned OpIdx,
-                           std::unordered_set<Function *> &FVisited);
+  Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
+  Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
+                                  std::unordered_set<Function *> &FVisited);
 
 public:
   static char ID;
@@ -169,17 +178,20 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
 
 // Deduce and return a successfully deduced Type of the Instruction,
 // or nullptr otherwise.
-static Type *deduceElementTypeHelper(Value *I,
-                                     std::unordered_set<Value *> &Visited,
-                                     DenseMap<Value *, Type *> &DeducedElTys) {
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
+  std::unordered_set<Value *> Visited;
+  return deduceElementTypeHelper(I, Visited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
+    Value *I, std::unordered_set<Value *> &Visited) {
   // allow to pass nullptr as an argument
   if (!I)
     return nullptr;
 
   // maybe already known
-  auto It = DeducedElTys.find(I);
-  if (It != DeducedElTys.end())
-    return It->second;
+  if (Type *KnownTy = GR->findDeducedElementType(I))
+    return KnownTy;
 
   // maybe a cycle
   if (Visited.find(I) != Visited.end())
@@ -195,25 +207,99 @@ static Type *deduceElementTypeHelper(Value *I,
     Ty = Ref->getResultElementType();
   } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
     Ty = Ref->getValueType();
+    if (Value *Op = Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr) {
+      if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+        if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+          Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+      } else {
+        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), Ty, Visited);
+      }
+    }
   } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
-    Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
-                                 DeducedElTys);
+    Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited);
   } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
     if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
         isPointerTy(Src) && isPointerTy(Dest))
-      Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
+      Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited);
   }
 
   // remember the found relationship
-  if (Ty)
-    DeducedElTys[I] = Ty;
+  if (Ty) {
+    // specify nested types if needed, otherwise return unchanged
+    GR->addDeducedElementType(I, Ty);
+  }
 
   return Ty;
 }
 
-Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+// Re-create a type of the value if it has untyped pointer fields, also nested.
+// Return the original value type if no corrections of untyped pointer
+// information is found or needed.
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U) {
   std::unordered_set<Value *> Visited;
-  if (Type *Ty = deduceElementTypeHelper(I, Visited, DeducedElTys))
+  return deduceNestedTypeHelper(U, U->getType(), Visited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
+    User *U, Type *OrigTy, std::unordered_set<Value *> &Visited) {
+  if (!U)
+    return OrigTy;
+
+  // maybe already known
+  if (Type *KnownTy = GR->findDeducedCompositeType(U))
+    return KnownTy;
+
+  // maybe a cycle
+  if (Visited.find(U) != Visited.end())
+    return OrigTy;
+  Visited.insert(U);
+
+  if (dyn_cast<StructType>(OrigTy)) {
+    SmallVector<Type *> Tys;
+    bool Change = false;
+    for (unsigned i = 0; i < U->getNumOperands(); ++i) {
+      Value *Op = U->getOperand(i);
+      Type *OpTy = Op->getType();
+      Type *Ty = OpTy;
+      if (Op) {
+        if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
+          if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+            Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+        } else {
+          Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+        }
+      }
+      Tys.push_back(Ty);
+      Change |= Ty != OpTy;
+    }
+    if (Change) {
+      Type *NewTy = StructType::create(Tys);
+      GR->addDeducedCompositeType(U, NewTy);
+      return NewTy;
+    }
+  } else if (auto *ArrTy = dyn_cast<ArrayType>(OrigTy)) {
+    if (Value *Op = U->getNumOperands() > 0 ? U->getOperand(0) : nullptr) {
+      Type *OpTy = ArrTy->getElementType();
+      Type *Ty = OpTy;
+      if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
+        if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+          Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+      } else {
+        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+      }
+      if (Ty != OpTy) {
+        Type *NewTy = ArrayType::get(Ty, ArrTy->getNumElements());
+        GR->addDeducedCompositeType(U, NewTy);
+        return NewTy;
+      }
+    }
+  }
+
+  return OrigTy;
+}
+
+Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+  if (Type *Ty = deduceElementTypeHelper(I))
     return Ty;
   return IntegerType::getInt8Ty(I->getContext());
 }
@@ -257,6 +343,7 @@ void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
       Worklist.push(IntrUndef);
       I->replaceUsesOfWith(Op, IntrUndef);
       AggrConsts[IntrUndef] = AggrUndef;
+      AggrConstTypes[IntrUndef] = AggrUndef->getType();
     }
   }
 }
@@ -282,6 +369,7 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
             I->replaceUsesOfWith(Op, CCI);
             KeepInst = true;
             SEI.AggrConsts[CCI] = AggrC;
+            SEI.AggrConstTypes[CCI] = SEI.deduceNestedTypeHelper(AggrC);
           };
 
       if (auto *AggrC = dyn_cast<ConstantAggregate>(Op)) {
@@ -396,8 +484,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
     Pointer = BC->getOperand(0);
 
   // Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType
-  std::unordered_set<Value *> Visited;
-  Type *PointerElemTy = deduceElementTypeHelper(Pointer, Visited, DeducedElTys);
+  Type *PointerElemTy = deduceElementTypeHelper(Pointer);
   if (PointerElemTy == ExpectedElementType)
     return;
 
@@ -456,8 +543,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
     CallInst *CI = buildIntrWithMD(
         Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
         ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
-    DeducedElTys[CI] = ExpectedElementType;
-    DeducedElTys[Pointer] = ExpectedElementType;
+    GR->addDeducedElementType(CI, ExpectedElementType);
+    GR->addDeducedElementType(Pointer, ExpectedElementType);
     return;
   }
 
@@ -498,25 +585,29 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
   Function *CalledF = CI->getCalledFunction();
   SmallVector<Type *, 4> CalledArgTys;
   bool HaveTypes = false;
-  for (auto &CalledArg : CalledF->args()) {
-    if (!isPointerTy(CalledArg.getType())) {
+  for (unsigned OpIdx = 0; OpIdx < CalledF->arg_size(); ++OpIdx) {
+    Argument *CalledArg = CalledF->getArg(OpIdx);
+    Type *ArgType = CalledArg->getType();
+    if (!isPointerTy(ArgType)) {
       CalledArgTys.push_back(nullptr);
-      continue;
-    }
-    auto It = DeducedElTys.find(&CalledArg);
-    Type *ParamTy = It != DeducedElTys.end() ? It->second : nullptr;
-    if (!ParamTy) {
-      for (User *U : CalledArg.users()) {
-        if (Instruction *Inst = dyn_cast<Instruction>(U)) {
-          std::unordered_set<Value *> Visited;
-          ParamTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
-          if (ParamTy)
-            break;
+    } else if (isTypedPointerTy(ArgType)) {
+      CalledArgTys.push_back(cast<TypedPointerType>(ArgType)->getElementType());
+      HaveTypes = true;
+    } else {
+      Type *ElemTy = GR->findDeducedElementType(CalledArg);
+      if (!ElemTy && hasPointeeTypeAttr(CalledArg))
+        ElemTy = getPointeeTypeByAttr(CalledArg);
+      if (!ElemTy) {
+        for (User *U : CalledArg->users()) {
+          if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+            if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
+              break;
+          }
         }
       }
+      HaveTypes |= ElemTy != nullptr;
+      CalledArgTys.push_back(ElemTy);
     }
-    HaveTypes |= ParamTy != nullptr;
-    CalledArgTys.push_back(ParamTy);
   }
 
   std::string DemangledName =
@@ -706,6 +797,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
   if (GV.getName() == "llvm.global.annotations")
     return;
   if (GV.hasInitializer() && !isa<UndefValue>(GV.getInitializer())) {
+    // Deduce element type and store results in Global Registry.
+    // Result is ignored, because TypedPointerType is not supported
+    // by llvm IR general logic.
+    deduceElementTypeHelper(&GV);
     Constant *Init = GV.getInitializer();
     Type *Ty = isAggrToReplace(Init) ? B.getInt32Ty() : Init->getType();
     Constant *Const = isAggrToReplace(Init) ? B.getInt32(1) : Init;
@@ -732,7 +827,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   unsigned AddressSpace = getPointerAddressSpace(I->getType());
   CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
                                  EltTyConst, I, {B.getInt32(AddressSpace)}, B);
-  DeducedElTys[CI] = ElemTy;
+  GR->addDeducedElementType(CI, ElemTy);
 }
 
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -745,9 +840,10 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
       if (II->getIntrinsicID() == Intrinsic::spv_const_composite ||
           II->getIntrinsicID() == Intrinsic::spv_undef) {
-        auto t = AggrConsts.find(II);
-        assert(t != AggrConsts.end());
-        TypeToAssign = t->second->getType();
+        auto It = AggrConstTypes.find(II);
+        if (It == AggrConstTypes.end())
+          report_fatal_error("Unknown composite intrinsic type");
+        TypeToAssign = It->second;
       }
     }
     Constant *Const = UndefValue::get(TypeToAssign);
@@ -807,12 +903,13 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
   }
 }
 
-Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
+Type *SPIRVEmitIntrinsics::deduceFunParamElementType(Function *F,
+                                                     unsigned OpIdx) {
   std::unordered_set<Function *> FVisited;
-  return deduceFunParamType(F, OpIdx, FVisited);
+  return deduceFunParamElementType(F, OpIdx, FVisited);
 }
 
-Type *SPIRVEmitIntrinsics::deduceFunParamType(
+Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
     Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
   // maybe a cycle
   if (FVisited.find(F) != FVisited.end())
@@ -830,15 +927,15 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
     if (!isPointerTy(OpArg->getType()))
       continue;
     // maybe we already know operand's element type
-    if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
-      return It->second;
+    if (Type *KnownTy = GR->findDeducedElementType(OpArg))
+      return KnownTy;
     // search in actual parameter's users
     for (User *OpU : OpArg->users()) {
       Instruction *Inst = dyn_cast<Instruction>(OpU);
       if (!Inst || Inst == CI)
         continue;
       Visited.clear();
-      if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
+      if (Type *Ty = deduceElementTypeHelper(Inst, Visited))
         return Ty;
     }
     // check if it's a formal parameter of the outer function
@@ -857,7 +954,7 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
 
   // search in function parameters
   for (auto &Pair : Lookup) {
-    if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
+    if (Type *Ty = deduceFunParamElementType(Pair.first, Pair.second, FVisited))
       return Ty;
   }
 
@@ -866,19 +963,23 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
 
 void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
   B.SetInsertPointPastAllocas(F);
-  DenseMap<Argument *, Type *> Args;
   for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
     Argument *Arg = F->getArg(OpIdx);
-    if (isUntypedPointerTy(Arg->getType()) &&
-        DeducedElTys.find(Arg) == DeducedElTys.end() &&
-        !HasPointeeTypeAttr(Arg)) {
-      if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
+    if (!isUntypedPointerTy(Arg->getType()))
+      continue;
+
+    Type *ElemTy = GR->findDeducedElementType(Arg);
+    if (!ElemTy) {
+      if (hasPointeeTypeAttr(Arg) &&
+          (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
+        GR->addDeducedElementType(Arg, ElemTy);
+      } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
         CallInst *AssignPtrTyCI = buildIntrWithMD(
             Intrinsic::spv_assign_ptr_type, {Arg->getType()},
             Constant::getNullValue(ElemTy), Arg,
             {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
-        DeducedElTys[AssignPtrTyCI] = ElemTy;
-        DeducedElTys[Arg] = ElemTy;
+        GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+        GR->addDeducedElementType(Arg, ElemTy);
       }
     }
   }
@@ -887,9 +988,14 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
 bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   if (Func.isDeclaration())
     return false;
+
+  const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
+  GR = ST.getSPIRVGlobalRegistry();
+
   F = &Func;
   IRBuilder<> B(Func.getContext());
   AggrConsts.clear();
+  AggrConstTypes.clear();
   AggrStores.clear();
 
   // StoreInst's operand type can be changed during the next transformations,

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ed0f90ff89ce6e..e0099e52944725 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -41,9 +41,13 @@ class SPIRVGlobalRegistry {
 
   // map a Function to its definition (as a machine instruction operand)
   DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
+  DenseMap<const MachineInstr *, const Function *> FunctionToInstrRev;
   // map function pointer (as a machine instruction operand) to the used
   // Function
   DenseMap<const MachineOperand *, const Function *> InstrToFunction;
+  // Maps Functions to their calls (in a form of the machine instruction,
+  // OpFunctionCall) that happened before the definition is available
+  DenseMap<const Function *, SmallVector<MachineInstr *>> ForwardCalls;
 
   // Look for an equivalent of the newType in the map. Return the equivalent
   // if it's found, otherwise insert newType to the map and return the type.
@@ -59,6 +63,13 @@ class SPIRVGlobalRegistry {
   // Holds the maximum ID we have in the module.
   unsigned Bound;
 
+  // Maps values associated with untyped pointers into deduced element types of
+  // untyped pointers.
+  DenseMap<Value *, Type *> DeducedElTys;
+  // Maps composite values to deduced types where untyped pointers are replaced
+  // with typed ones
+  DenseMap<Value *, Type *> DeducedNestedTys;
+
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                              SPIRV::AccessQualifier::AccessQualifier AQ =
@@ -122,6 +133,37 @@ class SPIRVGlobalRegistry {
   void setBound(unsigned V) { Bound = V; }
   unsigned getBound() { return Bound; }
 
+  // Deduced element types of untyped pointers and composites:
+  // - Add a record to the map of deduced element types.
+  void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
+  // - Find a record in the map of deduced element types.
+  Type *findDeducedElementType(const Value *Val) {
+    auto It = DeducedElTys.find(Val);
+    return It == DeducedElTys.end() ? nullptr : It->second;
+  }
+  // - Add a record to the map of deduced composite types.
+  void addDeducedCompositeType(Value *Val, Type *Ty) {
+    DeducedNestedTys[Val] = Ty;
+  }
+  // - Find a record in the map of deduced composite types.
+  Type *findDeducedCompositeType(const Value *Val) {
+    auto It = DeducedNestedTys.find(Val);
+    return It == DeducedNestedTys.end() ? nullptr : It->second;
+  }
+  // - Find a type of the given Global value
+  Type *getDeducedGlobalValueType(const GlobalValue *Global) {
+    // we may know element type if it was deduced earlier
+    Type *ElementTy = findDeducedElementType(Global);
+    if (!ElementTy) {
+      // or we may know element type if it's associated with a composite
+      // value
+      if (Value *GlobalElem =
+              Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
+        ElementTy = findDeducedCompositeType(GlobalElem);
+    }
+    return ElementTy ? ElementTy : Global->getValueType();
+  }
+
   // Map a machine operand that represents a use of a function via function
   // pointer to a machine operand that represents the function definition.
   // Return either the register or invalid value, because we have no context for
@@ -133,18 +175,56 @@ class SPIRVGlobalRegistry {
     auto ResReg = FunctionToInstr.find(ResF->second);
     return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
   }
+
+  // Map a Function to a machine instruction that represents the function
+  // definition.
+  const MachineInstr *getFunctionDefinition(const Function *F) {
+    if (!F)
+      return nullptr;
+    auto MOIt = FunctionToInstr.find(F);
+    return MOIt == FunctionToInstr.end() ? nullptr : MOIt->second->getParent();
+  }
+
+  // Map a Function to a machine instruction that represents the function
+  // definition.
+  const Function *getFunctionByDefinition(const MachineInstr *MI) {
+    if (!MI)
+      return nullptr;
+    auto FIt = FunctionToInstrRev.find(MI);
+    return FIt == FunctionToInstrRev.end() ? nullptr : FIt->second;
+  }
+
   // map function pointer (as a machine instruction operand) to the used
   // Function
   void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
     InstrToFunction[MO] = F;
   }
+
   // map a Function to its definition (as a machine instruction)
   void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
     FunctionToInstr[F] = MO;
+    FunctionToInstrRev[MO->getParent()] = F;
   }
+
   // Return true if any OpConstantFunctionPointerINTEL were generated
   bool hasConstFunPtr() { return !InstrToFunction.empty(); }
 
+  // Add a record about forward function call.
+  void addForwardCall(const Function *F, MachineInstr *MI) {
+    auto It = ForwardCalls.find(F);
+    if (It == ForwardCalls.end())
+      ForwardCalls[F] = {MI};
+    else
+      It->second.push_back(MI);
+  }
+
+  // Map a Function to the vector of machine instructions that represents
+  // forward function calls or to nullptr if not found.
+  SmallVector<MachineInstr *> *getForwardCalls(const Function *F) {
+    auto It = ForwardCalls.find(F);
+    return It == ForwardCalls.end() ? nullptr : &It->second;
+  }
+
   // Get or create a SPIR-V type corresponding the given LLVM IR type,
   // and map it to the given VReg by creating an ASSIGN_TYPE instruction.
   SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,

diff  --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 55b4c47c197dab..4f5c1dc4f90b0d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -86,8 +86,8 @@ bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
 // when there is a type mismatch between results and operand types.
 static void validatePtrTypes(const SPIRVSubtarget &STI,
                              MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
-                             MachineInstr &I, SPIRVType *ResType,
-                             unsigned OpIdx) {
+                             MachineInstr &I, unsigned OpIdx,
+                             SPIRVType *ResType, const Type *ResTy = nullptr) {
   Register OpReg = I.getOperand(OpIdx).getReg();
   SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
   SPIRVType *OpType = GR.getSPIRVTypeForVReg(
@@ -97,7 +97,13 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
   SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
-  if (!ElemType || ElemType == ResType)
+  if (!ElemType)
+    return;
+  bool IsSameMF =
+      ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+  bool IsEqualTypes = IsSameMF ? ElemType == ResType
+                               : GR.getTypeForSPIRVType(ElemType) == ResTy;
+  if (IsEqualTypes)
     return;
   // There is a type mismatch between results and operand types
   // and we insert a bitcast before the instruction to keep SPIR-V code valid
@@ -105,7 +111,11 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
       static_cast<SPIRV::StorageClass::StorageClass>(
           OpType->getOperand(1).getImm());
   MachineIRBuilder MIB(I);
-  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC);
+  SPIRVType *NewBaseType =
+      IsSameMF ? ResType
+               : GR.getOrCreateSPIRVType(
+                     ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
+  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
   if (!GR.isBitcastCompatible(NewPtrType, OpType))
     report_fatal_error(
         "insert validation bitcast: incompatible result and operand types");
@@ -123,6 +133,74 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
   I.getOperand(OpIdx).setReg(NewReg);
 }
 
+// Insert a bitcast before the function call instruction to keep SPIR-V code
+// valid when there is a type mismatch between actual and expected types of an
+// argument:
+// %formal = OpFunctionParameter %formal_type
+// ...
+// %res = OpFunctionCall %ty %fun %actual ...
+// implies that %actual is of %formal_type, and in case of opaque pointers.
+// We may need to insert a bitcast to ensure this.
+void validateFunCallMachineDef(const SPIRVSubtarget &STI,
+                               MachineRegisterInfo *DefMRI,
+                               MachineRegisterInfo *CallMRI,
+                               SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
+                               MachineInstr *FunDef) {
+  if (FunDef->getOpcode() != SPIRV::OpFunction)
+    return;
+  unsigned OpIdx = 3;
+  for (FunDef = FunDef->getNextNode();
+       FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
+       OpIdx < FunCall.getNumOperands();
+       FunDef = FunDef->getNextNode(), OpIdx++) {
+    SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
+    SPIRVType *DefElemType =
+        DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
+            ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+            : nullptr;
+    if (DefElemType) {
+      const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
+      // Switch GR context to the call site instead of the (default) definition
+      // side
+      GR.setCurrentFunc(*FunCall.getParent()->getParent());
+      validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
+                       DefElemTy);
+      GR.setCurrentFunc(*FunDef->getParent()->getParent());
+    }
+  }
+}
+
+// Ensure there is no mismatch between actual and expected arg types: calls
+// with a processed definition. Return Function pointer if it's a forward
+// call (ahead of definition), and nullptr otherwise.
+const Function *validateFunCall(const SPIRVSubtarget &STI,
+                                MachineRegisterInfo *MRI,
+                                SPIRVGlobalRegistry &GR,
+                                MachineInstr &FunCall) {
+  const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
+  const Function *F = dyn_cast<Function>(GV);
+  MachineInstr *FunDef =
+      const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
+  if (!FunDef)
+    return F;
+  validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+  return nullptr;
+}
+
+// Ensure there is no mismatch between actual and expected arg types: calls
+// ahead of a processed definition.
+void validateForwardCalls(const SPIRVSubtarget &STI,
+                          MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
+                          MachineInstr &FunDef) {
+  const Function *F = GR.getFunctionByDefinition(&FunDef);
+  if (SmallVector<MachineInstr *> *FwdCalls = GR.getForwardCalls(F))
+    for (MachineInstr *FunCall : *FwdCalls) {
+      MachineRegisterInfo *CallMRI =
+          &FunCall->getParent()->getParent()->getRegInfo();
+      validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
+    }
+}
+
 // TODO: the logic of inserting additional bitcast's is to be moved
 // to pre-IRTranslation passes eventually
 void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
@@ -137,14 +215,28 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
       switch (MI.getOpcode()) {
       case SPIRV::OpLoad:
         // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
-        validatePtrTypes(STI, MRI, GR, MI,
-                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()), 2);
+        validatePtrTypes(STI, MRI, GR, MI, 2,
+                         GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
         break;
       case SPIRV::OpStore:
         // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
-        validatePtrTypes(STI, MRI, GR, MI,
-                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()), 0);
+        validatePtrTypes(STI, MRI, GR, MI, 0,
+                         GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
         break;
+
+      case SPIRV::OpFunctionCall:
+        // ensure there is no mismatch between actual and expected arg types:
+        // calls with a processed definition
+        if (MI.getNumOperands() > 3)
+          if (const Function *F = validateFunCall(STI, MRI, GR, MI))
+            GR.addForwardCall(F, &MI);
+        break;
+      case SPIRV::OpFunction:
+        // ensure there is no mismatch between actual and expected arg types:
+        // calls ahead of a processed definition
+        validateForwardCalls(STI, MRI, GR, MI);
+        break;
+
       // ensure that LLVM IR bitwise instructions result in logical SPIR-V
       // instructions when applied to bool type
       case SPIRV::OpBitwiseOrS:

diff  --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 505b19a4d66edb..f4525e713c987f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1897,7 +1897,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   // FIXME: don't use MachineIRBuilder here, replace it with BuildMI.
   MachineIRBuilder MIRBuilder(I);
   const GlobalValue *GV = I.getOperand(1).getGlobal();
-  Type *GVType = GV->getValueType();
+  Type *GVType = GR.getDeducedGlobalValueType(GV);
   SPIRVType *PointerBaseType;
   if (GVType->isArrayTy()) {
     SPIRVType *ArrayElementType =

diff  --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 41807da6afcbc7..b133f0ae85de20 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -186,8 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
       }
       case TargetOpcode::G_GLOBAL_VALUE: {
         MIB.setInsertPt(*MI->getParent(), MI);
-        const auto *Global = MI->getOperand(1).getGlobal();
-        auto *Ty = TypedPointerType::get(Global->getValueType(),
+        const GlobalValue *Global = MI->getOperand(1).getGlobal();
+        Type *ElementTy = GR->getDeducedGlobalValueType(Global);
+        auto *Ty = TypedPointerType::get(ElementTy,
                                          Global->getType()->getAddressSpace());
         SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
         break;

diff  --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index eb87349f0941c5..c2c3475e1a936f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -127,8 +127,26 @@ inline unsigned getPointerAddressSpace(const Type *T) {
 }
 
 // Return true if the Argument is decorated with a pointee type
-inline bool HasPointeeTypeAttr(Argument *Arg) {
-  return Arg->hasByValAttr() || Arg->hasByRefAttr();
+inline bool hasPointeeTypeAttr(Argument *Arg) {
+  return Arg->hasStructRetAttr() || Arg->hasByValAttr() || Arg->hasByRefAttr();
+}
+
+// Pointee type from a byval/sret/byref attribute (in that order), or nullptr.
+inline Type *getPointeeTypeByAttr(Argument *Arg) {
+  if (Arg->hasByValAttr())
+    return Arg->getParamByValType();
+  if (Arg->hasStructRetAttr())
+    return Arg->getParamStructRetType();
+  return Arg->hasByRefAttr() ? Arg->getParamByRefType() : nullptr;
+}
+
+// Rebuild a FunctionType from F's return type, the current types of its
+// arguments and its vararg flag.
+inline Type *reconstructFunctionType(Function *F) {
+  SmallVector<Type *> ArgTys;
+  for (Argument &Arg : F->args())
+    ArgTys.push_back(Arg.getType());
+  return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
 }
 
 } // namespace llvm

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll
new file mode 100644
index 00000000000000..77b895c7762fba
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll
@@ -0,0 +1,20 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-NOT: OpTypeInt 8 0
+
+@GI = addrspace(1) constant i64 42
+
+@GS = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GI, ptr addrspace(1) @GI }
+@GS2 = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GS, ptr addrspace(1) @GS }
+@GS3 = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GS2, ptr addrspace(1) @GS2 }
+
+@GPS = addrspace(1) global ptr addrspace(1) @GS3
+
+@GPI1 = addrspace(1) global ptr addrspace(1) @GI
+@GPI2 = addrspace(1) global ptr addrspace(1) @GPI1
+@GPI3 = addrspace(1) global ptr addrspace(1) @GPI2
+
+define spir_kernel void @foo() {
+  ret void
+}

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
index ce3ab8895a5948..6d4913f802c289 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
@@ -1,14 +1,14 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
-; CHECK: %[[TyInt8:.*]] = OpTypeInt 8 0
-; CHECK: %[[TyInt8Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt8]]
-; CHECK: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt8Ptr]] %[[TyInt8Ptr]]
+; CHECK: %[[TyInt64:.*]] = OpTypeInt 64 0
+; CHECK: %[[TyInt64Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt64]]
+; CHECK: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt64Ptr]] %[[TyInt64Ptr]]
 ; CHECK: %[[ConstStruct:.*]] = OpConstantComposite %[[TyStruct]] %[[ConstField:.*]] %[[ConstField]]
 ; CHECK: %[[TyStructPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyStruct]]
 ; CHECK: OpVariable %[[TyStructPtr]] {{[a-zA-Z]+}} %[[ConstStruct]]
 
-@a = addrspace(1) constant i32 123
+@a = addrspace(1) constant i64 42
 @struct = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @a, ptr addrspace(1) @a }
 
 define spir_kernel void @foo() {

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
index 703f1e22a0321a..1071d3443056cb 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
@@ -34,6 +34,12 @@ entry:
   %addr = addrspacecast ptr addrspace(1) %lptr to ptr addrspace(4)
   %object = bitcast ptr addrspace(4) %addr to ptr addrspace(4)
   call spir_func void @foo(ptr addrspace(4) %object, i32 3)
+  %halfptr = getelementptr inbounds half, ptr addrspace(1) %_arg_cum, i64 1
+  %halfaddr = addrspacecast ptr addrspace(1) %halfptr to ptr addrspace(4)
+  call spir_func void @foo(ptr addrspace(4) %halfaddr, i32 3)
+  %dblptr = getelementptr inbounds double, ptr addrspace(1) %_arg_cum, i64 1
+  %dbladdr = addrspacecast ptr addrspace(1) %dblptr to ptr addrspace(4)
+  call spir_func void @foo(ptr addrspace(4) %dbladdr, i32 3)
   ret void
 }
 
@@ -49,4 +55,3 @@ define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) {
   tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order)
   ret void
 }
-


        


More information about the llvm-commits mailing list