[llvm] [SPIR-V] Add type analysis pass (PR #131348)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 14 09:00:44 PDT 2025


https://github.com/Keenuts created https://github.com/llvm/llvm-project/pull/131348

The backend currently generates 2 intrinsics to deduce value types:
 - spv.assign_type
 - spv.assign_ptr_type

This has 3 issues:
 - it adds many instructions to the IR, which makes passes working on those types harder to implement: we must make sure not to break the deduced types when modifying the IR, and if we change a type, we need to reimplement the type propagation logic.
 - it's currently implemented in the EmitIntrinsics pass, which also lowers other LLVM-IR instructions into SPV intrinsics. Passes requiring type information therefore need to understand both the LLVM-IR instructions and their SPV intrinsic equivalents (e.g. spv.gep vs getelementptr).
 - Lastly, OpenCL allows pointer casts, meaning the pointer type deduction can be 'best-effort' and fall back to i8*. For graphics SPIR-V this is illegal, so we must make sure each pointer is strongly typed.
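
Keeping every pointer strongly typed means that a value with two different deduced pointer types cannot simply be demoted to i8*. The sketch below is a minimal standalone model of the rule that `resolveTypeConflict()` and `typeContainsType()` in the patch appear to implement. When two pointer types are deduced for one value, it keeps the pointer whose pointee begins, at offset 0, with the other's pointee. `ToyType`, `containsAtOffsetZero` and `resolveConflict` are hypothetical names used only for this illustration, not LLVM APIs. The single `Aggregate` kind stands in for arrays, vectors and structs.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a uniqued type, compared by pointer identity
// like llvm::Type. This is NOT an LLVM class.
struct ToyType {
  enum Kind { Scalar, Pointer, Aggregate } K;
  // Pointer: {pointee}; Aggregate (array/vector/struct): elements in order.
  std::vector<const ToyType *> Inner;
};

// Does a value of type Wrapper begin, at offset 0, with a value of type
// Needle? Two pointers are compared through their pointees.
bool containsAtOffsetZero(const ToyType *Wrapper, const ToyType *Needle) {
  if (Wrapper == Needle)
    return true;
  if (Wrapper->K == ToyType::Pointer && Needle->K == ToyType::Pointer)
    return containsAtOffsetZero(Wrapper->Inner[0], Needle->Inner[0]);
  if (Wrapper->K == ToyType::Aggregate && !Wrapper->Inner.empty())
    return containsAtOffsetZero(Wrapper->Inner[0], Needle);
  return false;
}

// Two pointer types were deduced for the same value: keep the one whose
// pointee contains the other's pointee. nullptr means "unresolved".
const ToyType *resolveConflict(const ToyType *A, const ToyType *B) {
  assert(A->K == ToyType::Pointer && B->K == ToyType::Pointer);
  if (containsAtOffsetZero(A->Inner[0], B->Inner[0]))
    return A;
  if (containsAtOffsetZero(B->Inner[0], A->Inner[0]))
    return B;
  return nullptr;
}
```

With `i32` and `[4 x i32]` pointees, the array pointer wins in either argument order, since an `i32` load through it reads element 0. A `float` vs `i32` conflict stays unresolved. In that case the patch's `propagateType()` keeps the type it had already recorded.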

This commit adds an analysis which, given a Value, returns its non-opaque type (i.e. a type which contains no opaque pointer).
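
As a rough model of how such an analysis can reach types that no single instruction spells out, the sketch below iterates over loads until nothing changes. A typed load (`load i32`) types its pointer operand, and a typed pointer operand in turn types an opaque load (`load ptr`). It is a simplified toy written for this description, not the patch's code. Types are spelled as strings, only loads are modelled, and each pass is a full sweep rather than the patch's one-item-at-a-time worklist. `ToyLoad`, `deduceOne` and `deduceAll` are hypothetical names.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy IR: `Result = load <type>, ptr Ptr`. KnownResultType is
// set when the IR spells out a non-opaque type (e.g. `load i32`) and left
// empty for `load ptr`.
struct ToyLoad {
  std::string Result;
  std::string Ptr;
  std::optional<std::string> KnownResultType;
};

// Value name -> deduced type. Typed pointers are spelled "ptr(<pointee>)".
using ToyTypeMap = std::map<std::string, std::string>;

std::string pointerTo(const std::string &Pointee) {
  return "ptr(" + Pointee + ")";
}

// Learn what we can from one load; returns true if the map grew.
bool deduceOne(const ToyLoad &L, ToyTypeMap &Types) {
  bool Changed = false;

  // The result type is either spelled in the IR or was deduced earlier.
  std::optional<std::string> ResultTy = L.KnownResultType;
  if (!ResultTy) {
    auto Found = Types.find(L.Result);
    if (Found != Types.end())
      ResultTy = Found->second;
  }

  if (ResultTy) {
    // Backward: loading a T means the operand points to a T.
    Changed |= Types.emplace(L.Ptr, pointerTo(*ResultTy)).second;
    Changed |= Types.emplace(L.Result, *ResultTy).second;
    return Changed;
  }

  // Forward: an operand of type ptr(X) means the load produces an X.
  auto PtrTy = Types.find(L.Ptr);
  if (PtrTy != Types.end() && PtrTy->second.rfind("ptr(", 0) == 0) {
    const std::string &S = PtrTy->second;
    assert(S.back() == ')' && "malformed toy pointer type");
    Types[L.Result] = S.substr(4, S.size() - 5);
    Changed = true;
  }
  return Changed;
}

// Sweep all loads until a full pass makes no progress (a fixed point).
ToyTypeMap deduceAll(const std::vector<ToyLoad> &Loads) {
  ToyTypeMap Types;
  bool Changed = true;
  while (Changed) {
    Changed = false;
    for (const ToyLoad &L : Loads)
      Changed |= deduceOne(L, Types);
  }
  return Types;
}
```

Take the load chain from the AllocaPtrNestedIndirectDeduction test (`%l1 = load ptr, ptr %v` ... `%l4 = load i32, ptr %l3`). Running the sketch on it yields `ptr(ptr(ptr(ptr(i32))))` for `%v` after four productive sweeps, which mirrors the nested TypedPointerType that test expects.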

For now, this analysis implements a subset of the LLVM-IR, hence this is still a draft.

From 4c1d357b4d6d0712f4daa4d14bb2117333b300bd Mon Sep 17 00:00:00 2001
From: Nathan Gauër <brioche at google.com>
Date: Mon, 10 Mar 2025 16:23:27 +0100
Subject: [PATCH] [SPIR-V] Add type analysis pass

The backend currently generates 2 intrinsics to deduce value
types:
 - spv.assign_type
 - spv.assign_ptr_type

This has 3 issues:
 - it adds many instructions to the IR, which makes passes working on
   those types harder to implement: we must make sure not to break
   the deduced types when modifying the IR, and if we change a type,
   we need to reimplement the type propagation logic.
 - it's currently implemented in the EmitIntrinsics pass, which also
   lowers other LLVM-IR instructions into SPV intrinsics. Passes
   requiring type information therefore need to understand both the
   LLVM-IR instructions and their SPV intrinsic equivalents (e.g.
   spv.gep vs getelementptr).
 - Lastly, OpenCL allows pointer casts, meaning the pointer type
   deduction can be 'best-effort' and fall back to i8*. For graphics
   SPIR-V this is illegal, so we must make sure each pointer is
   strongly typed.

This commit adds an analysis which, given a Value, returns its
non-opaque type (i.e. a type which contains no opaque pointer).

For now, this analysis implements a subset of the LLVM-IR, but most hard
cases should be handled:
 - cross-function type deduction
 - phi/select type deduction
 - inner element access through pointer load/store.
 - type conflict caused by those inner element loads.
---
 llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt |   1 +
 .../SPIRV/Analysis/SPIRVTypeAnalysis.cpp      | 453 +++++++++++
 .../Target/SPIRV/Analysis/SPIRVTypeAnalysis.h |  95 +++
 llvm/lib/Target/SPIRV/SPIRV.h                 |   1 +
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  10 +-
 llvm/unittests/Target/SPIRV/CMakeLists.txt    |   1 +
 .../Target/SPIRV/SPIRVTypeAnalysisTests.cpp   | 722 ++++++++++++++++++
 7 files changed, 1281 insertions(+), 2 deletions(-)
 create mode 100644 llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.cpp
 create mode 100644 llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.h
 create mode 100644 llvm/unittests/Target/SPIRV/SPIRVTypeAnalysisTests.cpp

diff --git a/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt b/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
index 4d4351132d3de..30661413aaeb1 100644
--- a/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_component_library(LLVMSPIRVAnalysis
   SPIRVConvergenceRegionAnalysis.cpp
+  SPIRVTypeAnalysis.cpp
 
   LINK_COMPONENTS
   Analysis
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.cpp
new file mode 100644
index 0000000000000..992994943f077
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.cpp
@@ -0,0 +1,453 @@
+//===- SPIRVTypeAnalysis.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This analysis associates type information with every register/pointer,
+// allowing us to legalize type mismatches when required (e.g. for graphics
+// SPIR-V pointers).
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVTypeAnalysis.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/InitializePasses.h"
+
+#include <optional>
+#include <queue>
+
+#define DEBUG_TYPE "spirv-type-analysis"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVTypeAnalysisWrapperPassPass(PassRegistry &);
+} // namespace llvm
+
+INITIALIZE_PASS_BEGIN(SPIRVTypeAnalysisWrapperPass, "type-analysis",
+                      "SPIRV type analysis", true, true)
+INITIALIZE_PASS_END(SPIRVTypeAnalysisWrapperPass, "type-analysis",
+                    "SPIRV type analysis", true, true)
+
+namespace llvm {
+namespace SPIRV {
+namespace {} // anonymous namespace
+
+// Returns true if this type contains opaque pointers (recursively).
+bool TypeInfo::isOpaqueType(const Type *T) {
+  if (T->isPointerTy())
+    return true;
+
+  if (const ArrayType *AT = dyn_cast<ArrayType>(T))
+    return TypeInfo::isOpaqueType(AT->getElementType());
+  if (const VectorType *VT = dyn_cast<VectorType>(T))
+    return TypeInfo::isOpaqueType(VT->getElementType());
+
+  return false;
+}
+
+class TypeAnalyzer {
+
+public:
+  TypeAnalyzer(Module &M)
+      : M(M), TypeMap(new DenseMap<const Value *, Type *>()) {}
+
+  TypeInfo analyze() {
+    for (const Function &F : M) {
+      for (const BasicBlock &BB : F) {
+        for (const Value &V : BB) {
+          if (!deduceElementType(&V))
+            IncompleteTypeDefinition.insert(&V);
+        }
+      }
+    }
+
+    size_t IncompleteCount;
+    do {
+      IncompleteCount = IncompleteTypeDefinition.size();
+      for (const Value *Item : IncompleteTypeDefinition) {
+        if (deduceElementType(Item)) {
+          IncompleteTypeDefinition.erase(Item);
+          break;
+        }
+      }
+    } while (IncompleteTypeDefinition.size() < IncompleteCount);
+
+    return TypeInfo(TypeMap);
+  }
+
+private:
+  Type *getMappedType(const Value *V) {
+    auto It = TypeMap->find(V);
+    if (It == TypeMap->end())
+      return nullptr;
+    return It->second;
+  }
+
+  bool typeContainsType(Type *Wrapper, Type *Needle) {
+    if (Wrapper == Needle)
+      return true;
+
+    TypedPointerType *LP = dyn_cast<TypedPointerType>(Wrapper);
+    TypedPointerType *RP = dyn_cast<TypedPointerType>(Needle);
+    if (LP && RP)
+      return typeContainsType(LP->getElementType(), RP->getElementType());
+
+    if (StructType *ST = dyn_cast<StructType>(Wrapper))
+      return typeContainsType(ST->getElementType(0), Needle);
+    if (ArrayType *AT = dyn_cast<ArrayType>(Wrapper))
+      return typeContainsType(AT->getElementType(), Needle);
+    if (VectorType *VT = dyn_cast<VectorType>(Wrapper))
+      return typeContainsType(VT->getElementType(), Needle);
+
+    return false;
+  }
+
+  Type *resolveTypeConflict(TypedPointerType *A, TypedPointerType *B) {
+    if (typeContainsType(A->getElementType(), B->getElementType()))
+      return A;
+    if (typeContainsType(B->getElementType(), A->getElementType()))
+      return B;
+    return nullptr;
+  }
+
+  void propagateType(Type *DeducedType, const Value *V) {
+    assert(!TypeInfo::isOpaqueType(DeducedType));
+
+    auto It = TypeMap->find(V);
+    // The value type has already been deduced.
+    if (It != TypeMap->end()) {
+      // There is no conflict.
+      if (DeducedType == It->second)
+        return;
+
+      TypedPointerType *DeducedPtrType =
+          dyn_cast<TypedPointerType>(DeducedType);
+      TypedPointerType *KnownPtrType = dyn_cast<TypedPointerType>(It->second);
+      // Cannot resolve conflict on non-pointer types.
+      if (!DeducedPtrType || !KnownPtrType) {
+        assert(0); // FIXME: shall I ignore, fail, crash?
+        return;
+      }
+
+      DeducedType = resolveTypeConflict(DeducedPtrType, KnownPtrType);
+      if (!DeducedType)
+        return;
+      (*TypeMap)[V] = DeducedType;
+    }
+
+    if (const Constant *C = dyn_cast<Constant>(V))
+      propagateTypeDetails(DeducedType, C);
+    else if (const Argument *C = dyn_cast<Argument>(V))
+      propagateTypeDetails(DeducedType, C);
+    else if (const AllocaInst *C = dyn_cast<AllocaInst>(V))
+      propagateTypeDetails(DeducedType, C);
+    else if (const CallInst *C = dyn_cast<CallInst>(V))
+      propagateTypeDetails(DeducedType, C);
+    else if (const GetElementPtrInst *C = dyn_cast<GetElementPtrInst>(V))
+      propagateTypeDetails(DeducedType, C);
+    else if (const LoadInst *C = dyn_cast<LoadInst>(V))
+      propagateTypeDetails(DeducedType, C);
+    else if (const ReturnInst *C = dyn_cast<ReturnInst>(V))
+      propagateTypeDetails(DeducedType, C);
+    else if (const StoreInst *C = dyn_cast<StoreInst>(V))
+      propagateTypeDetails(DeducedType, C);
+    else
+      llvm_unreachable("FIXME: unsupported instruction");
+
+    // for (const User *U : V->users())
+    //   if (TypeMap->find(U) == TypeMap->end())
+    //     deduceElementType(U);
+
+    // X(CallInst, V);
+    // X(GetElementPtrInst, V);
+    // X(LoadInst, V);
+    // X(ReturnInst, V);
+    // X(StoreInst, V);
+    //  TODO:  GlobalValue
+    //  TODO: addrspacecast
+    //  TODO: bitcast
+    //  TODO: AtomicCmpXchgInst
+    //  TODO: AtomicRMWInst
+    //  TODO: PHINode
+    //  TODO: SelectInst
+    //  TODO: CallInst
+  }
+
+  void propagateTypeDetails(Type *DeducedType, const Constant *C) {
+    (*TypeMap)[C] = DeducedType;
+  }
+
+  void propagateTypeDetails(Type *DeducedType, const Argument *A) {
+    (*TypeMap)[A] = DeducedType;
+
+    unsigned ArgNo = A->getArgNo();
+    for (const User *U : A->getParent()->users()) {
+      const CallInst *CI = cast<CallInst>(U);
+      propagateType(DeducedType, CI->getOperand(ArgNo));
+    }
+  }
+
+  void propagateTypeDetails(Type *DeducedType, const LoadInst *I) {
+    TypeMap->try_emplace(I, DeducedType);
+    propagateType(TypedPointerType::get(DeducedType, 0),
+                  I->getPointerOperand());
+  }
+
+  void propagateTypeDetails(Type *DeducedType, const StoreInst *I) {
+    TypeMap->try_emplace(I, DeducedType);
+  }
+
+  void propagateTypeDetails(Type *DeducedType, const AllocaInst *I) {
+    Type *StoredType = cast<TypedPointerType>(DeducedType)->getElementType();
+    if (ArrayType *AT = dyn_cast<ArrayType>(I->getAllocatedType())) {
+      Type *NewType = ArrayType::get(StoredType, AT->getNumElements());
+      TypeMap->try_emplace(I, NewType);
+      return;
+    }
+
+    TypeMap->try_emplace(I, DeducedType);
+  }
+
+  void propagateTypeDetails(Type *DeducedType, const GetElementPtrInst *GEP) {
+    // Record the GEP result type first: both paths below rely on it.
+    TypeMap->try_emplace(GEP, DeducedType);
+    if (!TypeInfo::isOpaqueType(GEP->getSourceElementType())) {
+      // If the source is non-opaque, the result must be non-opaque (subset of
+      // source). If not, this means the GEP is using the wrong base-type. We
+      // don't support this.
+      assert(!TypeInfo::isOpaqueType(GEP->getResultElementType()));
+      // If the result is non-opaque, this means each use was non-opaque, and
+      // thus we shouldn't have a deduction mismatch. If this happens, something
+      // is wrong with this analysis.
+      TypedPointerType *DeducedPtr = cast<TypedPointerType>(DeducedType);
+      assert(GEP->getResultElementType() == DeducedPtr->getElementType());
+      propagateType(TypedPointerType::get(GEP->getSourceElementType(), 0),
+                    GEP->getPointerOperand());
+      return;
+    }
+
+    // The source type is opaque. We might be able to deduce more info from the
+    // new result type.
+    Type *NewType = DeducedType;
+    std::vector<Type *> Types = {GEP->getSourceElementType()};
+    std::vector<uint64_t> Indices;
+    for (const Use &U : GEP->indices())
+      Indices.push_back(cast<ConstantInt>(&*U)->getZExtValue());
+
+    for (unsigned I = 1; I < Indices.size(); ++I)
+      Types.push_back(
+          GetElementPtrInst::getTypeAtIndex(Types[I - 1], Indices[I]));
+
+    for (unsigned I = 1; I < GEP->getNumIndices(); ++I) {
+      unsigned Index = GEP->getNumIndices() - 1 - I;
+      Type *T = Types[Index];
+      if (T->isPointerTy())
+        return;
+
+      if (ArrayType *AT = dyn_cast<ArrayType>(T))
+        NewType = ArrayType::get(NewType, AT->getNumElements());
+      else if (VectorType *VT = dyn_cast<VectorType>(T))
+        NewType = VectorType::get(NewType, VT->getElementCount());
+      else if (StructType *ST = dyn_cast<StructType>(T))
+        assert(0 && "Opaque struct types are not supported.");
+      else
+        llvm_unreachable("Unsupported aggregate type?");
+    }
+
+    // The first index of a GEP is indexing from the passed pointer. So we need
+    // to add one layer.
+    NewType = TypedPointerType::get(NewType, 0);
+
+    propagateType(NewType, GEP->getPointerOperand());
+  }
+
+  void propagateTypeDetails(Type *DeducedType, const CallInst *CI) {
+    if (isa<IntrinsicInst>(CI))
+      llvm_unreachable("Not implemented");
+    // return deduceType(cast<IntrinsicInst>(CI));
+
+    FunctionType *FT = CI->getFunctionType();
+    if (!TypeInfo::isOpaqueType(FT->getReturnType())) {
+      assert(FT->getReturnType() == DeducedType);
+      return;
+    }
+
+    TypeMap->try_emplace(CI, DeducedType);
+    propagateType(DeducedType, CI->getCalledFunction());
+
+    for (const BasicBlock &BB : *CI->getCalledFunction())
+      for (const Instruction &I : BB)
+        if (const ReturnInst *RI = dyn_cast<ReturnInst>(&I))
+          propagateType(DeducedType, RI);
+  }
+
+  void propagateTypeDetails(Type *DeducedType, const ReturnInst *RI) {
+    Value *RV = RI->getReturnValue();
+    assert(RV || DeducedType->isVoidTy());
+    TypeMap->try_emplace(RI, DeducedType);
+    if (RV)
+      propagateType(DeducedType, RV);
+  }
+
+  bool deduceElementType(const Value *V) {
+    assert(V != nullptr);
+
+    auto It = TypeMap->find(V);
+    if (It != TypeMap->end())
+      return true;
+
+#define X(Type, Value)                                                         \
+  if (auto *Casted = dyn_cast<Type>(Value))                                    \
+    return deduceType(Casted);
+
+    X(AllocaInst, V);
+    X(CallInst, V);
+    X(GetElementPtrInst, V);
+    X(LoadInst, V);
+    X(ReturnInst, V);
+    X(StoreInst, V);
+    // TODO:  GlobalValue
+    // TODO: addrspacecast
+    // TODO: bitcast
+    // TODO: AtomicCmpXchgInst
+    // TODO: AtomicRMWInst
+    // TODO: PHINode
+    // TODO: SelectInst
+    // TODO: CallInst
+#undef X
+
+    llvm_unreachable("FIXME: unsupported instruction");
+    return false;
+  }
+
+  bool deduceType(const AllocaInst *I) {
+    if (TypeInfo::isOpaqueType(I->getAllocatedType()))
+      return false;
+    TypeMap->try_emplace(I, TypedPointerType::get(I->getAllocatedType(), 0));
+    return true;
+  }
+
+  bool deduceType(const LoadInst *I) {
+    // First case: the loaded type is complete: we can assign the result type.
+    if (!TypeInfo::isOpaqueType(I->getType())) {
+      TypeMap->try_emplace(I, I->getType());
+      propagateType(TypedPointerType::get(I->getType(), 0),
+                    I->getPointerOperand());
+      return true;
+    }
+
+    // The pointer operand is non-opaque, we can deduce the loaded type.
+    if (Type *PointerOperandTy = getMappedType(I->getPointerOperand())) {
+      // FIXME: only supports pointer of pointers for now.
+      Type *ElementType =
+          cast<TypedPointerType>(PointerOperandTy)->getElementType();
+      assert(!ElementType->isPointerTy());
+      TypeMap->try_emplace(I, ElementType);
+      return true;
+    }
+
+    return false;
+  }
+
+  Type *getDeducedType(const Value *V) {
+    auto It = TypeMap->find(V);
+    return It == TypeMap->end() ? nullptr : It->second;
+  }
+
+  bool deduceType(const StoreInst *I) {
+    Type *ValueType = I->getValueOperand()->getType();
+    Type *SourceType = getDeducedType(I->getPointerOperand());
+    bool isOpaqueType = TypeInfo::isOpaqueType(ValueType);
+
+    if (isOpaqueType && !SourceType)
+      return false;
+
+    if (isOpaqueType)
+      ValueType = cast<TypedPointerType>(SourceType)->getElementType();
+    propagateType(ValueType, I->getValueOperand());
+    propagateType(TypedPointerType::get(ValueType, 0), I->getPointerOperand());
+    return true;
+  }
+
+  bool deduceType(const ReturnInst *I) {
+    Value *RV = I->getReturnValue();
+    if (nullptr == RV) {
+      TypeMap->try_emplace(I, Type::getVoidTy(I->getContext()));
+      return true;
+    }
+
+    Type *T = TypeInfo::isOpaqueType(RV->getType()) ? getDeducedType(RV)
+                                                    : RV->getType();
+    if (nullptr == T)
+      return false;
+
+    TypeMap->try_emplace(I, T);
+    propagateType(T, I->getFunction());
+    return true;
+  }
+
+  bool deduceType(const GetElementPtrInst *I) {
+    if (TypeInfo::isOpaqueType(I->getResultElementType()))
+      return false;
+
+    Type *T = TypedPointerType::get(I->getResultElementType(), 0);
+    // TypeMap->try_emplace(I, T);
+    propagateType(T, I);
+    return true;
+  }
+
+  bool deduceType(const CallInst *CI) {
+    if (isa<IntrinsicInst>(CI))
+      llvm_unreachable("Not implemented");
+    // return deduceType(cast<IntrinsicInst>(CI));
+
+    Type *ReturnType = CI->getFunctionType()->getReturnType();
+    ReturnType = TypeInfo::isOpaqueType(ReturnType)
+                     ? getDeducedType(CI->getCalledFunction())
+                     : ReturnType;
+    if (nullptr == ReturnType)
+      return false;
+
+    TypeMap->try_emplace(CI, ReturnType);
+    propagateType(ReturnType, CI->getCalledFunction());
+    return true;
+  }
+
+public:
+  Module &M;
+  DenseMap<const Value *, Type *> *TypeMap;
+  std::unordered_set<const Value *> IncompleteTypeDefinition;
+};
+
+TypeInfo getTypeInfo(Module &M) {
+  TypeAnalyzer Analyzer(M);
+  return Analyzer.analyze();
+}
+
+} // namespace SPIRV
+
+char SPIRVTypeAnalysisWrapperPass::ID = 0;
+
+SPIRVTypeAnalysisWrapperPass::SPIRVTypeAnalysisWrapperPass() : ModulePass(ID) {}
+
+bool SPIRVTypeAnalysisWrapperPass::runOnModule(Module &M) {
+  TI = SPIRV::getTypeInfo(M);
+  return false;
+}
+
+SPIRVTypeAnalysis::Result SPIRVTypeAnalysis::run(Module &M,
+                                                 ModuleAnalysisManager &MAM) {
+  return SPIRV::getTypeInfo(M);
+}
+
+AnalysisKey SPIRVTypeAnalysis::Key;
+
+} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.h b/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.h
new file mode 100644
index 0000000000000..ba11bcc796c01
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.h
@@ -0,0 +1,95 @@
+//===- SPIRVTypeAnalysis.h ------------------------*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This analysis associates type information with every register/pointer,
+// allowing us to legalize type mismatches when required (e.g. for graphics
+// SPIR-V pointers).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTYPEANALYSIS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVTYPEANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include <iostream>
+#include <optional>
+#include <unordered_set>
+
+namespace llvm {
+class SPIRVSubtarget;
+class MachineFunction;
+class MachineModuleInfo;
+
+namespace SPIRV {
+
+// Holds the type deduced for each value by the SPIR-V type analysis.
+class TypeInfo {
+  DenseMap<const Value *, Type *> *TypeMap;
+
+public:
+  TypeInfo() : TypeMap(nullptr) {}
+  TypeInfo(DenseMap<const Value *, Type *> *TypeMap) : TypeMap(TypeMap) {}
+
+  Type *getType(const Value *V) {
+    auto It = TypeMap->find(V);
+    if (It != TypeMap->end())
+      return It->second;
+
+    // In some cases, type deduction is not possible from the IR. This should
+    // only happen when handling opaque pointers, otherwise it means the type
+    // deduction is broken.
+    assert(V->getType()->isPointerTy());
+    return V->getType();
+  }
+
+  // Returns true if this type contains opaque pointers (recursively).
+  static bool isOpaqueType(const Type *T);
+};
+
+} // namespace SPIRV
+
+// Wrapper around SPIRV::getTypeInfo() to use it with the legacy pass manager.
+class SPIRVTypeAnalysisWrapperPass : public ModulePass {
+  SPIRV::TypeInfo TI;
+
+public:
+  static char ID;
+
+  SPIRVTypeAnalysisWrapperPass();
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+  };
+
+  bool runOnModule(Module &F) override;
+
+  SPIRV::TypeInfo &getTypeInfo() { return TI; }
+  const SPIRV::TypeInfo &getTypeInfo() const { return TI; }
+};
+
+// Wrapper around SPIRV::getTypeInfo() to use it with the new pass manager.
+class SPIRVTypeAnalysis : public AnalysisInfoMixin<SPIRVTypeAnalysis> {
+  friend AnalysisInfoMixin<SPIRVTypeAnalysis>;
+  static AnalysisKey Key;
+
+public:
+  using Result = SPIRV::TypeInfo;
+
+  Result run(Module &F, ModuleAnalysisManager &AM);
+};
+
+namespace SPIRV {
+TypeInfo getTypeInfo(Module &F);
+} // namespace SPIRV
+
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVTYPEANALYSIS_H
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index d765dfe370be2..6cd9c492ae13e 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -37,6 +37,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
 
 void initializeSPIRVModuleAnalysisPass(PassRegistry &);
 void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
+void initializeSPIRVTypeAnalysisWrapperPassPass(PassRegistry &);
 void initializeSPIRVPreLegalizerPass(PassRegistry &);
 void initializeSPIRVPreLegalizerCombinerPass(PassRegistry &);
 void initializeSPIRVPostLegalizerPass(PassRegistry &);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 68651f4ee4d2f..6dddf310c547b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "Analysis/SPIRVTypeAnalysis.h"
 #include "SPIRV.h"
 #include "SPIRVBuiltins.h"
 #include "SPIRVMetadata.h"
@@ -225,6 +226,7 @@ class SPIRVEmitIntrinsics
   bool runOnModule(Module &M) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // AU.addRequired<SPIRVTypeAnalysisWrapperPass>();
     ModulePass::getAnalysisUsage(AU);
   }
 };
@@ -257,8 +259,11 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
 
 char SPIRVEmitIntrinsics::ID = 0;
 
-INITIALIZE_PASS(SPIRVEmitIntrinsics, "emit-intrinsics", "SPIRV emit intrinsics",
-                false, false)
+INITIALIZE_PASS_BEGIN(SPIRVEmitIntrinsics, "emit-intrinsics",
+                      "SPIRV emit intrinsics", false, false)
+// INITIALIZE_PASS_DEPENDENCY(SPIRVTypeAnalysisWrapperPass)
+INITIALIZE_PASS_END(SPIRVEmitIntrinsics, "emit-intrinsics",
+                    "SPIRV emit intrinsics", false, false)
 
 static inline bool isAssignTypeInstr(const Instruction *I) {
   return isa<IntrinsicInst>(I) &&
@@ -2551,6 +2556,7 @@ void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
 
 bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
   bool Changed = false;
+  // auto TI = getAnalysis<SPIRVTypeAnalysisWrapperPass>().getTypeInfo();
 
   parseFunDeclarations(M);
 
diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt
index d7f0290089c4c..ef2fcc612d7df 100644
--- a/llvm/unittests/Target/SPIRV/CMakeLists.txt
+++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt
@@ -16,6 +16,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_target_unittest(SPIRVTests
   SPIRVConvergenceRegionAnalysisTests.cpp
+  SPIRVTypeAnalysisTests.cpp
   SPIRVSortBlocksTests.cpp
   SPIRVPartialOrderingVisitorTests.cpp
   SPIRVAPITest.cpp
diff --git a/llvm/unittests/Target/SPIRV/SPIRVTypeAnalysisTests.cpp b/llvm/unittests/Target/SPIRV/SPIRVTypeAnalysisTests.cpp
new file mode 100644
index 0000000000000..3c2e7983838e8
--- /dev/null
+++ b/llvm/unittests/Target/SPIRV/SPIRVTypeAnalysisTests.cpp
@@ -0,0 +1,722 @@
+//===- SPIRVTypeAnalysisTests.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVTypeAnalysis.h"
+#include "llvm/Analysis/DominanceFrontier.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <iostream>
+
+using ::testing::Contains;
+using ::testing::Pair;
+
+using namespace llvm;
+using namespace llvm::SPIRV;
+
+template <typename T> struct IsA {
+  friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
+};
+
+class SPIRVTypeAnalysisTest : public testing::Test {
+protected:
+  void SetUp() override {
+    // Required for tests.
+    FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+
+    MAM.registerPass([&] { return SPIRVTypeAnalysis(); });
+  }
+
+  void TearDown() override { M.reset(); }
+
+  SPIRVTypeAnalysis::Result &runAnalysis(StringRef Assembly) {
+    assert(M == nullptr &&
+           "Calling runAnalysis multiple times is unsafe. See getAnalysis().");
+
+    SMDiagnostic Error;
+    M = parseAssemblyString(Assembly, Error, Context);
+    if (!M) {
+      std::cerr << Error.getMessage().str() << std::endl;
+      std::cerr << "> " << Error.getLineContents().str() << std::endl;
+    }
+    assert(M && "Bad assembly. Bad test?");
+
+    ModulePassManager MPM;
+    MPM.run(*M, MAM);
+
+    // Setup helper types.
+    IntTy = IntegerType::get(M->getContext(), 32);
+    FloatTy = Type::getFloatTy(M->getContext());
+
+    return MAM.getResult<SPIRVTypeAnalysis>(*M);
+  }
+
+  const Value *getValue(StringRef Name) {
+    assert(M != nullptr && "Has runAnalysis been called before?");
+
+    for (auto &F : *M) {
+      for (Argument &A : F.args())
+        if (A.getName() == Name)
+          return &A;
+
+      for (auto &BB : F)
+        for (auto &V : BB)
+          if (Name == V.getName())
+            return &V;
+    }
+    ADD_FAILURE() << "Error: Could not locate requested variable. Bad test?";
+    return nullptr;
+  }
+
+  StructType *getStructType(StringRef Name) {
+    for (StructType *ST : M->getIdentifiedStructTypes()) {
+      if (ST->getName() == Name)
+        return ST;
+    }
+
+    ADD_FAILURE() << "Error: Could not locate requested struct type. Bad test?";
+    return nullptr;
+  }
+
+protected:
+  LLVMContext Context;
+  FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
+  std::unique_ptr<Module> M;
+
+  // Helper types for writing tests.
+  Type *IntTy = nullptr;
+  Type *FloatTy = nullptr;
+};
+
+TEST_F(SPIRVTypeAnalysisTest, ScalarAlloca) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %test = alloca i32
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("test")),
+            TypedPointerType::get(IntTy, /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, ScalarAllocaFloat) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %test = alloca float
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("test")),
+            TypedPointerType::get(FloatTy, /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaArray) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %test = alloca [5 x i32]
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("test")),
+            TypedPointerType::get(ArrayType::get(IntTy, 5), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaArrayArray) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %test = alloca [5 x [10 x i32]]
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(
+      TI.getType(getValue("test")),
+      TypedPointerType::get(ArrayType::get(ArrayType::get(IntTy, 10), 5), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaVector) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %test = alloca <4 x i32>
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("test")),
+            TypedPointerType::get(VectorType::get(IntTy, 4, false), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaArrayVector) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %test = alloca [5 x <4 x i32>]
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("test")),
+            TypedPointerType::get(
+                ArrayType::get(VectorType::get(IntTy, 4, false), 5), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaLoad) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %test = alloca [5 x <4 x i32>]
+      %v = load <4 x i32>, ptr %test
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  auto VT = VectorType::get(IntegerType::get(M->getContext(), 32), 4,
+                            /* scalable= */ false);
+  EXPECT_EQ(TI.getType(getValue("test")),
+            TypedPointerType::get(ArrayType::get(VT, 5), 0));
+  EXPECT_EQ(TI.getType(getValue("v")), VT);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaPtrIndirectDeduction) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %v = alloca ptr
+      %l = load i32, ptr %v
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("v")),
+            TypedPointerType::get(IntTy, /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("l")), IntTy);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaPtrNestedIndirectDeduction) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %v = alloca ptr
+      %l1 = load ptr, ptr %v
+      %l2 = load ptr, ptr %l1
+      %l3 = load ptr, ptr %l2
+      %l4 = load i32, ptr %l3
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  Type *LoadType = IntTy;
+  EXPECT_EQ(TI.getType(getValue("l4")), LoadType);
+  LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("l3")), LoadType);
+  LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("l2")), LoadType);
+  LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("l1")), LoadType);
+  LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("v")), LoadType);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaLoadArrayPtr) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ; %ptr_array = alloca [5 x *i32]
+      %ptr_array = alloca [5 x ptr]
+      %l1 = load ptr, ptr %ptr_array
+      %l2 = load i32, ptr %l1
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  EXPECT_EQ(TI.getType(getValue("ptr_array")),
+            TypedPointerType::get(
+                ArrayType::get(TypedPointerType::get(IntTy, /* AS= */ 0), 5),
+                /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("l1")),
+            TypedPointerType::get(IntTy, /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("l2")), IntTy);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaStruct) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %var = alloca %st
+      %l = load %st, ptr %var
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  auto ST = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("var")),
+            TypedPointerType::get(ST, /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("l")), ST);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaStructDeducedFromLoad) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %var = alloca ptr
+      %l1 = load ptr, ptr %var
+      %l2 = load %st, ptr %l1
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  Type *LoadType = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("l2")), LoadType);
+  LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("l1")), LoadType);
+  LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("var")), LoadType);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, StructMemberDirectLoad) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %var = alloca %st
+      %l2 = load i32, ptr %var
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  Type *ST = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("l2")), IntTy);
+  EXPECT_EQ(TI.getType(getValue("var")),
+            TypedPointerType::get(ST, /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, StructMemberConflictingDeductionFromLoadA) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ; %var = alloca *%st
+      %var = alloca ptr
+
+      %l1 = load ptr, ptr %var
+
+      %l2 = load %st, ptr %l1
+      %l3 = load i32, ptr %l1
+
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  Type *ST = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("l3")), IntTy);
+  EXPECT_EQ(TI.getType(getValue("l2")), ST);
+  EXPECT_EQ(TI.getType(getValue("l1")), TypedPointerType::get(ST, /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("var")),
+            TypedPointerType::get(TypedPointerType::get(ST, /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, StructMemberConflictingDeductionFromLoadB) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ; %var = alloca *%st
+      %var = alloca ptr
+
+      %l1 = load ptr, ptr %var
+
+      %l3 = load i32, ptr %l1
+      %l2 = load %st, ptr %l1
+
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  Type *ST = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("l3")), IntTy);
+  EXPECT_EQ(TI.getType(getValue("l2")), ST);
+  EXPECT_EQ(TI.getType(getValue("l1")), TypedPointerType::get(ST, /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("var")),
+            TypedPointerType::get(TypedPointerType::get(ST, /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, StructMemberConflictingTree) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ; %var = alloca *%st
+      %var = alloca ptr
+
+      %l_0 = load ptr, ptr %var
+      %l_1 = load ptr, ptr %var
+
+      %l_00 = load %st, ptr %l_0
+      %l_01 = load %st, ptr %l_1
+
+      %l_10 = load i32, ptr %l_0
+      %l_11 = load i32, ptr %l_1
+
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  Type *ST = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("l_10")), IntTy);
+  EXPECT_EQ(TI.getType(getValue("l_11")), IntTy);
+
+  EXPECT_EQ(TI.getType(getValue("l_00")), ST);
+  EXPECT_EQ(TI.getType(getValue("l_01")), ST);
+
+  EXPECT_EQ(TI.getType(getValue("l_0")),
+            TypedPointerType::get(ST, /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("l_1")),
+            TypedPointerType::get(ST, /* AS= */ 0));
+
+  EXPECT_EQ(TI.getType(getValue("var")),
+            TypedPointerType::get(TypedPointerType::get(ST, /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, ArrayElementConflict) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ; %var = alloca *[5 x %st]
+      %var = alloca ptr
+
+      %l1 = load ptr, ptr %var
+
+      %l2 = load %st, ptr %l1
+      %l3 = load i32, ptr %l1
+      %l4 = load [5 x %st], ptr %l1
+
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  Type *ST = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("l4")), ArrayType::get(ST, 5));
+  EXPECT_EQ(TI.getType(getValue("l3")), IntTy);
+  EXPECT_EQ(TI.getType(getValue("l2")), ST);
+  EXPECT_EQ(TI.getType(getValue("l1")),
+            TypedPointerType::get(ArrayType::get(ST, 5), /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("var")),
+            TypedPointerType::get(
+                TypedPointerType::get(ArrayType::get(ST, 5), /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, VectorElementConflict) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ; %var = alloca *<4 x i32>
+      %var = alloca ptr
+
+      %l1 = load ptr, ptr %var
+
+      %l3 = load i32, ptr %l1
+      %l4 = load <4 x i32>, ptr %l1
+
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  auto *VT = VectorType::get(IntTy, 4, /* scalable= */ false);
+
+  EXPECT_EQ(TI.getType(getValue("l4")), VT);
+  EXPECT_EQ(TI.getType(getValue("l3")), IntTy);
+  EXPECT_EQ(TI.getType(getValue("l1")), TypedPointerType::get(VT, /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("var")),
+            TypedPointerType::get(TypedPointerType::get(VT, /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, MissingInformationOnLoad) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %var = alloca ptr
+      %l1 = load ptr, ptr %var
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  EXPECT_EQ(TI.getType(getValue("l1")),
+            PointerType::get(M->getContext(), /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("var")),
+            PointerType::get(M->getContext(), /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, MissingInformationAlloca) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %var = alloca ptr
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+
+  EXPECT_EQ(TI.getType(getValue("var")),
+            PointerType::get(M->getContext(), /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromGep) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      %var = alloca ptr
+      %ptr = getelementptr [5 x i32], ptr %var, i64 0, i64 0
+      %val = load i32, ptr %ptr
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  Type *AT = ArrayType::get(IntTy, 5);
+  EXPECT_EQ(TI.getType(getValue("var")),
+            TypedPointerType::get(AT, /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromGepOpaqueBaseType) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ; %var = alloca **i32
+      ; %ptr1: **i32
+      ; %ptr2: *i32
+      %var = alloca ptr
+      %ptr1 = getelementptr ptr, ptr %var, i64 0
+      %ptr2 = getelementptr ptr, ptr %ptr1, i64 0
+      %val = load i32, ptr %ptr2
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  Type *T = IntTy;
+  EXPECT_EQ(TI.getType(getValue("val")), T);
+  T = TypedPointerType::get(T, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("ptr2")), T);
+  T = TypedPointerType::get(T, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("ptr1")), T);
+  T = TypedPointerType::get(T, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("var")), T);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromGepPartialOpaqueBaseType) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ; *[5 x *[2 x i32] ]
+      ; *[5 x ptr ]
+      ;   ptr
+      %var = alloca ptr
+      %ptr1 = getelementptr [5 x ptr], ptr %var, i64 0, i64 0
+      %ptr2 = getelementptr ptr, ptr %ptr1, i64 0
+      %val = load i32, ptr %ptr2
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  Type *T = IntTy;
+  EXPECT_EQ(TI.getType(getValue("val")), T);
+  T = TypedPointerType::get(T, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("ptr2")), T);
+  T = TypedPointerType::get(T, /* AS= */ 0);
+  EXPECT_EQ(TI.getType(getValue("ptr1")), T);
+  T = TypedPointerType::get(ArrayType::get(T, 5), 0);
+  EXPECT_EQ(TI.getType(getValue("var")), T);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceParamFromLoad) {
+  StringRef Assembly = R"(
+    define void @foo(ptr %input) {
+      %a = load i32, ptr %input
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("input")), TypedPointerType::get(IntTy, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromStore) {
+  StringRef Assembly = R"(
+    define void @foo(ptr %input) {
+      store i32 0, ptr %input
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("input")), TypedPointerType::get(IntTy, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromStoreStructConflict) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @foo(ptr %a, ptr %b) {
+      %s = load %st, ptr %a
+      store i32 0, ptr %b
+      store %st %s, ptr %b
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  auto ST = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("a")), TypedPointerType::get(ST, 0));
+  EXPECT_EQ(TI.getType(getValue("b")), TypedPointerType::get(ST, 0));
+  EXPECT_EQ(TI.getType(getValue("s")), ST);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromStoreStructInline) {
+  StringRef Assembly = R"(
+    %st = type { i32 }
+
+    define void @foo(ptr %a) {
+      store i32 0, ptr %a
+      store %st { i32 0 }, ptr %a
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  auto ST = getStructType("st");
+  EXPECT_EQ(TI.getType(getValue("a")), TypedPointerType::get(ST, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromStoreArray) {
+  StringRef Assembly = R"(
+    define void @foo(ptr %a) {
+      store i32 0, ptr %a
+      store [2 x i32] [ i32 0, i32 1 ], ptr %a
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("a")),
+            TypedPointerType::get(ArrayType::get(IntTy, 2), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromCall) {
+  StringRef Assembly = R"(
+    define ptr @foo(ptr %par) {
+      ret ptr %par
+    }
+
+    define void @bar() {
+      %var = alloca i32
+      %res = call ptr @foo(ptr %var)
+      store i32 0, ptr %res
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("var")), TypedPointerType::get(IntTy, 0));
+  EXPECT_EQ(TI.getType(getValue("res")), TypedPointerType::get(IntTy, 0));
+  EXPECT_EQ(TI.getType(getValue("par")), TypedPointerType::get(IntTy, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromCallLackInfo) {
+  StringRef Assembly = R"(
+    define ptr @foo(ptr %par) {
+      ret ptr %par
+    }
+
+    define void @bar() {
+      %var = alloca ptr
+      %res = call ptr @foo(ptr %var)
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("var")),
+            PointerType::get(M->getContext(), /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("res")),
+            PointerType::get(M->getContext(), /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("par")),
+            PointerType::get(M->getContext(), /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromCallPartial) {
+  StringRef Assembly = R"(
+    define ptr @foo(ptr %fpar1, ptr %fpar2) {
+      ret ptr %fpar2
+    }
+
+    define void @bar(ptr %bpar1, ptr %bpar2) {
+      %res = call ptr @foo(ptr %bpar1, ptr %bpar2)
+      store i32 0, ptr %res
+      ret void
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("fpar1")),
+            PointerType::get(M->getContext(), /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("bpar1")),
+            PointerType::get(M->getContext(), /* AS= */ 0));
+  EXPECT_EQ(TI.getType(getValue("res")), TypedPointerType::get(IntTy, 0));
+  EXPECT_EQ(TI.getType(getValue("fpar2")), TypedPointerType::get(IntTy, 0));
+  EXPECT_EQ(TI.getType(getValue("bpar2")), TypedPointerType::get(IntTy, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceRecursive) {
+  StringRef Assembly = R"(
+    define ptr @foo(ptr %par) {
+      store i32 0, ptr %par
+      %tmp = call ptr @foo(ptr %par)
+      ret ptr %par
+    }
+  )";
+
+  auto TI = runAnalysis(Assembly);
+  EXPECT_EQ(TI.getType(getValue("tmp")), TypedPointerType::get(IntTy, 0));
+  EXPECT_EQ(TI.getType(getValue("par")), TypedPointerType::get(IntTy, 0));
+}
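The load-based tests above encode two rules: a pointer's pointee type is deduced backwards from what is loaded through it, and when several types are loaded through the same pointer, the type that starts with the others wins (`%st` over `i32`, `[5 x %st]` over `%st`, `<4 x i32>` over `i32`). The following is a minimal standalone C++ sketch of that reading under a toy model, not the analysis added by this patch: types are plain strings, and `FirstElement`, `subsumes`, `Alloca`, `Load` and `deduce` are hypothetical names.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy model (hypothetical, not the patch's data structures): types are
// spelled as strings, "ptr" is an opaque pointer and "*T" a pointer to T.
using Type = std::string;

// Hypothetical layout table: the type stored at offset 0 of each aggregate,
// mirroring %st = type { i32 } and the array/vector cases in the tests.
const std::map<Type, Type> FirstElement = {
    {"%st", "i32"}, {"[5 x %st]", "%st"}, {"<4 x i32>", "i32"}};

// True if A is B or (transitively) starts with a B, i.e. a B can be read
// at offset 0 of an A.
bool subsumes(Type A, const Type &B) {
  while (true) {
    if (A == B)
      return true;
    auto It = FirstElement.find(A);
    if (It == FirstElement.end())
      return false;
    A = It->second;
  }
}

struct Alloca { std::string Result; Type AllocatedTy; };
struct Load { std::string Result; Type LoadedTy; std::string Ptr; };

// Propagates pointee types backwards, from each load's result to the
// pointer it was loaded from, until a fixed point is reached.
std::map<std::string, Type> deduce(const std::vector<Alloca> &Allocas,
                                   const std::vector<Load> &Loads) {
  std::map<std::string, Type> Ty;
  for (const Alloca &A : Allocas)
    Ty[A.Result] = "ptr";
  for (const Load &L : Loads)
    Ty[L.Result] = L.LoadedTy;

  bool Changed = true;
  while (Changed) {
    Changed = false;
    // Every known type accessed through a pointer is a candidate pointee.
    std::map<std::string, std::vector<Type>> Candidates;
    for (const Alloca &A : Allocas)
      if (A.AllocatedTy != "ptr")
        Candidates[A.Result].push_back(A.AllocatedTy);
    for (const Load &L : Loads)
      if (Ty[L.Result] != "ptr")
        Candidates[L.Ptr].push_back(Ty[L.Result]);

    for (const auto &Entry : Candidates) {
      // On conflict, keep the candidate that subsumes the others. Unrelated
      // types are not diagnosed in this toy; the first one is kept.
      Type Best = Entry.second.front();
      for (const Type &T : Entry.second)
        if (subsumes(T, Best))
          Best = T;
      const Type NewTy = "*" + Best;
      if (Ty[Entry.first] != NewTy) {
        Ty[Entry.first] = NewTy;
        Changed = true;
      }
    }
  }
  return Ty;
}
```

With `%var = alloca ptr`, `%l1 = load ptr, ptr %var` and loads of `%st` and `i32` through `%l1`, `deduce` gives `*%st` for `%l1` and `**%st` for `%var`, which matches StructMemberConflictingDeductionFromLoadA; with no typed access at all, both values stay `ptr`, as in MissingInformationOnLoad.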

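The call-based tests (DeduceFromCall, DeduceFromCallPartial, DeduceRecursive) suggest that arguments flow into parameters and returned values flow into call results, so a single typed access anywhere along such a chain types the whole chain, while an unconnected parameter such as `%fpar1` stays an opaque `ptr`. One way to model that, sketched here with a union-find over value names (a technique chosen for illustration, not taken from the patch; `PointerClasses`, `unite`, `access` and `typeOf` are hypothetical names), is:

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical sketch: values in one equivalence class are assumed to
// carry the same pointer type.
class PointerClasses {
  std::map<std::string, std::string> Parent;
  std::map<std::string, std::string> Pointee; // keyed by class root

public:
  // Returns the root of V's class, creating a singleton class if needed.
  std::string find(const std::string &V) {
    if (!Parent.count(V))
      Parent[V] = V;
    const std::string P = Parent[V];
    if (P == V)
      return V;
    const std::string Root = find(P);
    Parent[V] = Root; // path compression
    return Root;
  }

  // Records that A and B must have the same type (call argument and
  // parameter, or returned value and call result).
  void unite(const std::string &A, const std::string &B) {
    const std::string RA = find(A), RB = find(B);
    if (RA == RB)
      return;
    Parent[RA] = RB;
    // Move a pointee already known for RA onto the new root.
    auto It = Pointee.find(RA);
    if (It != Pointee.end()) {
      if (!Pointee.count(RB))
        Pointee[RB] = It->second;
      Pointee.erase(It);
    }
  }

  // Records a typed access (alloca, load or store) of PointeeTy through V.
  // This toy keeps the first recorded pointee and ignores conflicts.
  void access(const std::string &V, const std::string &PointeeTy) {
    const std::string R = find(V);
    if (!Pointee.count(R))
      Pointee[R] = PointeeTy;
  }

  // "*T" if a pointee is known for V's class, otherwise opaque "ptr".
  std::string typeOf(const std::string &V) {
    auto It = Pointee.find(find(V));
    return It == Pointee.end() ? "ptr" : "*" + It->second;
  }
};
```

For DeduceFromCall, `unite("var", "par")`, `unite("par", "res")` and `access("var", "i32")` give `*i32` for all three values. Merging a parameter with every call-site argument is a simplification: a function called with differently typed pointers would need a conflict-resolution rule or per-call-site handling, which this sketch omits.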