[llvm] [SPIR-V] Add type analysis pass (PR #131348)
Nathan Gauër via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 14 09:00:44 PDT 2025
https://github.com/Keenuts created https://github.com/llvm/llvm-project/pull/131348
The backend currently generates 2 intrinsics to deduce value types:
- spv.assign_type
- spv.assign_ptr_type
This has 3 issues:
- it adds many instructions to the IR, which makes passes working on those types harder to implement: we must make sure not to break deduced types when modifying the IR. And if we change a type, we need to reimplement the type propagation logic.
- it's currently implemented in the EmitIntrinsics pass, which also lowers other LLVM-IR instructions into SPV intrinsics, meaning passes requiring type information also have trouble since they need to understand both LLVM-IR and the equivalent SPV intrinsics (e.g. spv.gep vs getelementptr).
- Lastly, OpenCL can do pointer casts, meaning the pointer type deduction can be 'best-effort' and fall back to *i8. For graphical SPIR-V, this is illegal, so we must make sure each pointer is strongly typed.
This commit adds an analysis which, given a Value, returns its non-opaque type (i.e. a type which contains no opaque pointer).
For now, this analysis implements a subset of the LLVM-IR, hence this is still a draft.
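The deduction scheme in the patch (deduce what each instruction allows, push the result back to pointer operands, retry incomplete values until nothing changes, and resolve conflicts in favor of the type that contains the other) can be illustrated with a small standalone model. This is only a sketch under loud assumptions: it does not use LLVM at all, types are plain strings where `ptr` is an opaque pointer and `*T` is a typed pointer to `T`, and `Inst`, `ToyTypeAnalysis`, `FirstMember` and the other names are hypothetical. Only `alloca` and `load` are modeled.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy model, not LLVM's API: a type is a string; "ptr" is an opaque pointer
// and "*T" is a typed pointer to T.
using Ty = std::string;

struct Inst {
  enum Kind { Alloca, Load } K;
  std::string Name;    // Name of the produced value.
  Ty DeclTy;           // Allocated type (Alloca) or loaded type (Load).
  std::string Pointer; // Pointer operand (Load only).
};

struct ToyTypeAnalysis {
  std::map<std::string, Ty> FirstMember;    // Struct name -> type of member 0.
  std::map<std::string, const Inst *> Defs; // Value name -> defining inst.
  std::map<std::string, Ty> Types;          // Value name -> deduced type.

  // True if Wrapper is Needle, or reaches it through pointees/first members.
  bool contains(const Ty &Wrapper, const Ty &Needle) const {
    if (Wrapper == Needle)
      return true;
    if (Wrapper[0] == '*' && Needle[0] == '*')
      return contains(Wrapper.substr(1), Needle.substr(1));
    auto It = FirstMember.find(Wrapper);
    return It != FirstMember.end() && contains(It->second, Needle);
  }

  // Records T for V, resolving conflicts, then pushes it to V's operand.
  void propagate(Ty T, const std::string &V) {
    auto It = Types.find(V);
    if (It != Types.end()) {
      if (It->second == T)
        return; // Already known, nothing new to propagate.
      if (contains(It->second, T))
        T = It->second; // The known type is the wider one.
      else if (!contains(T, It->second))
        return; // Unrelated types: keep the first deduction.
    }
    Types[V] = T;
    const Inst *I = Defs.at(V);
    if (I->K == Inst::Load)
      propagate("*" + T, I->Pointer);
  }

  // Tries to deduce the type of I; returns false if more context is needed.
  bool deduce(const Inst &I) {
    if (Types.count(I.Name))
      return true;
    if (I.K == Inst::Alloca) {
      if (I.DeclTy == "ptr")
        return false; // Opaque allocation: wait for a user to tell us more.
      Types[I.Name] = "*" + I.DeclTy;
      return true;
    }
    if (I.DeclTy != "ptr") { // The loaded type is fully known.
      Types[I.Name] = I.DeclTy;
      propagate("*" + I.DeclTy, I.Pointer);
      return true;
    }
    auto It = Types.find(I.Pointer);
    if (It == Types.end())
      return false;
    Types[I.Name] = It->second.substr(1); // Pointee of the operand's type.
    return true;
  }

  std::map<std::string, Ty> analyze(const std::vector<Inst> &Program) {
    for (const Inst &I : Program)
      Defs[I.Name] = &I;
    std::vector<const Inst *> Incomplete;
    for (const Inst &I : Program)
      if (!deduce(I))
        Incomplete.push_back(&I);
    // Retry until a full pass makes no progress (fixed point).
    bool Progress = true;
    while (Progress) {
      Progress = false;
      for (auto It = Incomplete.begin(); It != Incomplete.end();) {
        if (deduce(**It)) {
          It = Incomplete.erase(It);
          Progress = true;
        } else {
          ++It;
        }
      }
    }
    return Types;
  }
};
```

With `FirstMember["st"] = "i32"`, feeding this model the four instructions of the `StructMemberConflictingDeductionFromLoadB` test (an `alloca ptr` named `var`, a `load ptr` from it named `l1`, then `load i32` and `load st` through `l1`) maps `l1` to `*st` and `var` to `**st`: the struct wins the conflict because its first member is the `i32` seen by the other load.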
>From 4c1d357b4d6d0712f4daa4d14bb2117333b300bd Mon Sep 17 00:00:00 2001
From: Nathan Gauër <brioche at google.com>
Date: Mon, 10 Mar 2025 16:23:27 +0100
Subject: [PATCH] [SPIR-V] Add type analysis pass
The backend currently generates 2 intrinsics to deduce value
types:
- spv.assign_type
- spv.assign_ptr_type
This has 3 issues:
- it adds many instructions to the IR, which makes passes working on
those types harder to implement: we must make sure not to break
deduced types when modifying the IR. And if we change a type, we
need to reimplement the type propagation logic.
- it's currently implemented in the EmitIntrinsics pass, which also
lowers other LLVM-IR instructions into SPV intrinsics, meaning passes
requiring type information also have trouble since they need to
understand both LLVM-IR and the equivalent SPV intrinsics (e.g.
spv.gep vs getelementptr).
- Lastly, OpenCL can do pointer casts, meaning the pointer type
deduction can be 'best-effort' and fall back to *i8. For graphical
SPIR-V, this is illegal, so we must make sure each pointer is
strongly typed.
This commit adds an analysis which, given a Value, returns its
non-opaque type (i.e. a type which contains no opaque pointer).
For now, this analysis implements a subset of the LLVM-IR, but most hard
cases should be handled:
- cross-function type deduction
- phi/select type deduction
- inner element access through pointer load/store.
- type conflict caused by those inner element loads.
---
llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt | 1 +
.../SPIRV/Analysis/SPIRVTypeAnalysis.cpp | 453 +++++++++++
.../Target/SPIRV/Analysis/SPIRVTypeAnalysis.h | 95 +++
llvm/lib/Target/SPIRV/SPIRV.h | 1 +
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 10 +-
llvm/unittests/Target/SPIRV/CMakeLists.txt | 1 +
.../Target/SPIRV/SPIRVTypeAnalysisTests.cpp | 722 ++++++++++++++++++
7 files changed, 1281 insertions(+), 2 deletions(-)
create mode 100644 llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.cpp
create mode 100644 llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.h
create mode 100644 llvm/unittests/Target/SPIRV/SPIRVTypeAnalysisTests.cpp
diff --git a/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt b/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
index 4d4351132d3de..30661413aaeb1 100644
--- a/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/Analysis/CMakeLists.txt
@@ -1,5 +1,6 @@
add_llvm_component_library(LLVMSPIRVAnalysis
SPIRVConvergenceRegionAnalysis.cpp
+ SPIRVTypeAnalysis.cpp
LINK_COMPONENTS
Analysis
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.cpp
new file mode 100644
index 0000000000000..992994943f077
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.cpp
@@ -0,0 +1,453 @@
+//===- SPIRVTypeAnalysis.cpp ---------------------------*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This analysis links a type information to every register/pointer, allowing
+// us to legalize type mismatches when required (graphical SPIR-V pointers for
+// ex).
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVTypeAnalysis.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/InitializePasses.h"
+
+#include <optional>
+#include <queue>
+
+#define DEBUG_TYPE "spirv-type-analysis"
+
+using namespace llvm;
+
+namespace llvm {
+void initializeSPIRVTypeAnalysisWrapperPassPass(PassRegistry &);
+} // namespace llvm
+
+INITIALIZE_PASS_BEGIN(SPIRVTypeAnalysisWrapperPass, "type-analysis",
+ "SPIRV type analysis", true, true)
+INITIALIZE_PASS_END(SPIRVTypeAnalysisWrapperPass, "type-analysis",
+ "SPIRV type analysis", true, true)
+
+namespace llvm {
+namespace SPIRV {
+namespace {} // anonymous namespace
+
+// Returns true if this type contains an opaque pointer (recursively).
+bool TypeInfo::isOpaqueType(const Type *T) {
+ if (T->isPointerTy())
+ return true;
+
+ if (const ArrayType *AT = dyn_cast<ArrayType>(T))
+ return TypeInfo::isOpaqueType(AT->getElementType());
+ if (const VectorType *VT = dyn_cast<VectorType>(T))
+ return TypeInfo::isOpaqueType(VT->getElementType());
+
+ return false;
+}
+
+class TypeAnalyzer {
+
+public:
+ TypeAnalyzer(Module &M)
+ : M(M), TypeMap(new DenseMap<const Value *, Type *>()) {}
+
+ TypeInfo analyze() {
+ for (const Function &F : M) {
+ for (const BasicBlock &BB : F) {
+ for (const Value &V : BB) {
+ if (!deduceElementType(&V))
+ IncompleteTypeDefinition.insert(&V);
+ }
+ }
+ }
+
+ size_t IncompleteCount;
+ do {
+ IncompleteCount = IncompleteTypeDefinition.size();
+ for (const Value *Item : IncompleteTypeDefinition) {
+ if (deduceElementType(Item)) {
+ IncompleteTypeDefinition.erase(Item);
+ break;
+ }
+ }
+ } while (IncompleteTypeDefinition.size() < IncompleteCount);
+
+ return TypeInfo(TypeMap);
+ }
+
+private:
+ Type *getMappedType(const Value *V) {
+ auto It = TypeMap->find(V);
+ if (It == TypeMap->end())
+ return nullptr;
+ return It->second;
+ }
+
+ bool typeContainsType(Type *Wrapper, Type *Needle) {
+ if (Wrapper == Needle)
+ return true;
+
+ TypedPointerType *LP = dyn_cast<TypedPointerType>(Wrapper);
+ TypedPointerType *RP = dyn_cast<TypedPointerType>(Needle);
+ if (LP && RP)
+ return typeContainsType(LP->getElementType(), RP->getElementType());
+
+ if (StructType *ST = dyn_cast<StructType>(Wrapper))
+ return typeContainsType(ST->getElementType(0), Needle);
+ if (ArrayType *AT = dyn_cast<ArrayType>(Wrapper))
+ return typeContainsType(AT->getElementType(), Needle);
+ if (VectorType *VT = dyn_cast<VectorType>(Wrapper))
+ return typeContainsType(VT->getElementType(), Needle);
+
+ return false;
+ }
+
+ Type *resolveTypeConflict(TypedPointerType *A, TypedPointerType *B) {
+ if (typeContainsType(A->getElementType(), B->getElementType()))
+ return A;
+ if (typeContainsType(B->getElementType(), A->getElementType()))
+ return B;
+ return nullptr;
+ }
+
+ void propagateType(Type *DeducedType, const Value *V) {
+ assert(!TypeInfo::isOpaqueType(DeducedType));
+
+ auto It = TypeMap->find(V);
+ // The value type has already been deduced.
+ if (It != TypeMap->end()) {
+ // There is no conflict.
+ if (DeducedType == It->second)
+ return;
+
+ TypedPointerType *DeducedPtrType =
+ dyn_cast<TypedPointerType>(DeducedType);
+ TypedPointerType *KnownPtrType = dyn_cast<TypedPointerType>(It->second);
+ // Cannot resolve conflict on non-pointer types.
+ if (!DeducedPtrType || !KnownPtrType) {
+ assert(0); // FIXME: shall I ignore, fail, crash?
+ return;
+ }
+
+ DeducedType = resolveTypeConflict(DeducedPtrType, KnownPtrType);
+ if (!DeducedType)
+ return;
+ (*TypeMap)[V] = DeducedType;
+ }
+
+ if (const Constant *C = dyn_cast<Constant>(V))
+ propagateTypeDetails(DeducedType, C);
+ else if (const Argument *C = dyn_cast<Argument>(V))
+ propagateTypeDetails(DeducedType, C);
+ else if (const AllocaInst *C = dyn_cast<AllocaInst>(V))
+ propagateTypeDetails(DeducedType, C);
+ else if (const CallInst *C = dyn_cast<CallInst>(V))
+ propagateTypeDetails(DeducedType, C);
+ else if (const GetElementPtrInst *C = dyn_cast<GetElementPtrInst>(V))
+ propagateTypeDetails(DeducedType, C);
+ else if (const LoadInst *C = dyn_cast<LoadInst>(V))
+ propagateTypeDetails(DeducedType, C);
+ else if (const ReturnInst *C = dyn_cast<ReturnInst>(V))
+ propagateTypeDetails(DeducedType, C);
+ else if (const StoreInst *C = dyn_cast<StoreInst>(V))
+ propagateTypeDetails(DeducedType, C);
+ else
+ llvm_unreachable("FIXME: unsupported instruction");
+
+ // for (const User *U : V->users())
+ // if (TypeMap->find(U) == TypeMap->end())
+ // deduceElementType(U);
+
+ // X(CallInst, V);
+ // X(GetElementPtrInst, V);
+ // X(LoadInst, V);
+ // X(ReturnInst, V);
+ // X(StoreInst, V);
+ // TODO: GlobalValue
+ // TODO: addrspacecast
+ // TODO: bitcast
+ // TODO: AtomicCmpXchgInst
+ // TODO: AtomicRMWInst
+ // TODO: PHINode
+ // TODO: SelectInst
+ // TODO: CallInst
+ }
+
+ void propagateTypeDetails(Type *DeducedType, const Constant *C) {
+ (*TypeMap)[C] = DeducedType;
+ }
+
+ void propagateTypeDetails(Type *DeducedType, const Argument *A) {
+ (*TypeMap)[A] = DeducedType;
+
+ unsigned ArgNo = A->getArgNo();
+ for (const User *U : A->getParent()->users()) {
+ const CallInst *CI = cast<CallInst>(U);
+ propagateType(DeducedType, CI->getOperand(ArgNo));
+ }
+ }
+
+ void propagateTypeDetails(Type *DeducedType, const LoadInst *I) {
+ TypeMap->try_emplace(I, DeducedType);
+ propagateType(TypedPointerType::get(DeducedType, 0),
+ I->getPointerOperand());
+ }
+
+ void propagateTypeDetails(Type *DeducedType, const StoreInst *I) {
+ TypeMap->try_emplace(I, DeducedType);
+ }
+
+ void propagateTypeDetails(Type *DeducedType, const AllocaInst *I) {
+ Type *StoredType = cast<TypedPointerType>(DeducedType)->getElementType();
+ if (ArrayType *AT = dyn_cast<ArrayType>(I->getAllocatedType())) {
+ Type *NewType = ArrayType::get(StoredType, AT->getNumElements());
+ TypeMap->try_emplace(I, NewType);
+ return;
+ }
+
+ TypeMap->try_emplace(I, DeducedType);
+ }
+
+ void propagateTypeDetails(Type *DeducedType, const GetElementPtrInst *GEP) {
+ if (!TypeInfo::isOpaqueType(GEP->getSourceElementType())) {
+ // If the source is non-opaque, the result must be non-opaque (subset of
+ // source). If not, this means the GEP is using the wrong base-type. We
+ // don't support this.
+ assert(!TypeInfo::isOpaqueType(GEP->getResultElementType()));
+ // If the result is non-opaque, this means each use was non-opaque, and
+ // thus we shouldn't have a deduction mismatch. If this happens, something
+ // is wrong with this analysis.
+ TypedPointerType *DeducedPtr = cast<TypedPointerType>(DeducedType);
+ assert(GEP->getResultElementType() == DeducedPtr->getElementType());
+ propagateType(TypedPointerType::get(GEP->getSourceElementType(), 0),
+ GEP->getPointerOperand());
+ return;
+ }
+
+ TypeMap->try_emplace(GEP, DeducedType);
+
+ // The source type is opaque. We might be able to deduce more info from the
+ // new result type.
+ Type *NewType = DeducedType;
+ std::vector<Type *> Types = {GEP->getSourceElementType()};
+ std::vector<uint64_t> Indices;
+ for (const Use &U : GEP->indices())
+ Indices.push_back(cast<ConstantInt>(&*U)->getZExtValue());
+
+ for (unsigned I = 1; I < Indices.size(); ++I)
+ Types.push_back(
+ GetElementPtrInst::getTypeAtIndex(Types[I - 1], Indices[I]));
+
+ for (unsigned I = 1; I < GEP->getNumIndices(); ++I) {
+ unsigned Index = GEP->getNumIndices() - 1 - I;
+ Type *T = Types[Index];
+ if (T->isPointerTy())
+ return;
+
+ if (ArrayType *AT = dyn_cast<ArrayType>(T))
+ NewType = ArrayType::get(NewType, AT->getNumElements());
+ else if (VectorType *VT = dyn_cast<VectorType>(T))
+ NewType = VectorType::get(NewType, VT->getElementCount());
+ else if (StructType *ST = dyn_cast<StructType>(T))
+ assert(0 && "Opaque struct types are not supported.");
+ else
+ llvm_unreachable("Unsupported aggregate type?");
+ }
+
+ // The first index of a GEP is indexing from the passed pointer. So we need
+ // to add one layer.
+ NewType = TypedPointerType::get(NewType, 0);
+
+ propagateType(NewType, GEP->getPointerOperand());
+ }
+
+ void propagateTypeDetails(Type *DeducedType, const CallInst *CI) {
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI))
+ llvm_unreachable("Not implemented");
+ // return deduceType(II);
+
+ FunctionType *FT = CI->getFunctionType();
+ if (!TypeInfo::isOpaqueType(FT->getReturnType())) {
+ assert(FT->getReturnType() == DeducedType);
+ return;
+ }
+
+ TypeMap->try_emplace(CI, DeducedType);
+ propagateType(DeducedType, CI->getCalledFunction());
+
+ for (const BasicBlock &BB : *CI->getCalledFunction())
+ for (const Instruction &I : BB)
+ if (const ReturnInst *RI = dyn_cast<ReturnInst>(&I))
+ propagateType(DeducedType, RI);
+ }
+
+ void propagateTypeDetails(Type *DeducedType, const ReturnInst *RI) {
+ Value *RV = RI->getReturnValue();
+ assert(RV || DeducedType->isVoidTy());
+
+ TypeMap->try_emplace(RI, DeducedType);
+ propagateType(DeducedType, RV);
+ }
+
+ bool deduceElementType(const Value *V) {
+ assert(V != nullptr);
+
+ auto It = TypeMap->find(V);
+ if (It != TypeMap->end())
+ return true;
+
+#define X(Type, Value) \
+ if (auto *Casted = dyn_cast<Type>(Value)) \
+ return deduceType(Casted);
+
+ X(AllocaInst, V);
+ X(CallInst, V);
+ X(GetElementPtrInst, V);
+ X(LoadInst, V);
+ X(ReturnInst, V);
+ X(StoreInst, V);
+ // TODO: GlobalValue
+ // TODO: addrspacecast
+ // TODO: bitcast
+ // TODO: AtomicCmpXchgInst
+ // TODO: AtomicRMWInst
+ // TODO: PHINode
+ // TODO: SelectInst
+ // TODO: CallInst
+#undef X
+
+ llvm_unreachable("FIXME: unsupported instruction");
+ return false;
+ }
+
+ bool deduceType(const AllocaInst *I) {
+ if (TypeInfo::isOpaqueType(I->getAllocatedType()))
+ return false;
+ TypeMap->try_emplace(I, TypedPointerType::get(I->getAllocatedType(), 0));
+ return true;
+ }
+
+ bool deduceType(const LoadInst *I) {
+ // First case: the loaded type is complete: we can assign the result type.
+ if (!TypeInfo::isOpaqueType(I->getType())) {
+ TypeMap->try_emplace(I, I->getType());
+ propagateType(TypedPointerType::get(I->getType(), 0),
+ I->getPointerOperand());
+ return true;
+ }
+
+ // The pointer operand is non-opaque, we can deduce the loaded type.
+ if (Type *PointerOperandTy = getMappedType(I->getPointerOperand())) {
+ // FIXME: only supports pointer of pointers for now.
+ Type *ElementType =
+ cast<TypedPointerType>(PointerOperandTy)->getElementType();
+ assert(!ElementType->isPointerTy());
+ TypeMap->try_emplace(I, ElementType);
+ return true;
+ }
+
+ return false;
+ }
+
+ Type *getDeducedType(const Value *V) {
+ auto It = TypeMap->find(V);
+ return It == TypeMap->end() ? nullptr : It->second;
+ }
+
+ bool deduceType(const StoreInst *I) {
+ Type *ValueType = I->getValueOperand()->getType();
+ Type *SourceType = getDeducedType(I->getPointerOperand());
+ bool isOpaqueType = TypeInfo::isOpaqueType(ValueType);
+
+ if (isOpaqueType && !SourceType)
+ return false;
+
+ if (isOpaqueType)
+ ValueType = cast<TypedPointerType>(SourceType)->getElementType();
+ propagateType(ValueType, I->getValueOperand());
+ propagateType(TypedPointerType::get(ValueType, 0), I->getPointerOperand());
+ return true;
+ }
+
+ bool deduceType(const ReturnInst *I) {
+ Value *RV = I->getReturnValue();
+ if (nullptr == RV) {
+ TypeMap->try_emplace(I, Type::getVoidTy(I->getContext()));
+ return true;
+ }
+
+ Type *T = TypeInfo::isOpaqueType(RV->getType()) ? getDeducedType(RV)
+ : RV->getType();
+ if (nullptr == T)
+ return false;
+
+ TypeMap->try_emplace(I, T);
+ propagateType(T, I->getFunction());
+ return true;
+ }
+
+ bool deduceType(const GetElementPtrInst *I) {
+ if (TypeInfo::isOpaqueType(I->getResultElementType()))
+ return false;
+
+ Type *T = TypedPointerType::get(I->getResultElementType(), 0);
+ // TypeMap->try_emplace(I, T);
+ propagateType(T, I);
+ return true;
+ }
+
+ bool deduceType(const CallInst *CI) {
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI))
+ llvm_unreachable("Not implemented");
+ // return deduceType(II);
+
+ Type *ReturnType = CI->getFunctionType()->getReturnType();
+ ReturnType = TypeInfo::isOpaqueType(ReturnType)
+ ? getDeducedType(CI->getCalledFunction())
+ : ReturnType;
+ if (nullptr == ReturnType)
+ return false;
+
+ TypeMap->try_emplace(CI, ReturnType);
+ propagateType(ReturnType, CI->getCalledFunction());
+ return true;
+ }
+
+public:
+ Module &M;
+ DenseMap<const Value *, Type *> *TypeMap;
+ std::unordered_set<const Value *> IncompleteTypeDefinition;
+};
+
+TypeInfo getTypeInfo(Module &M) {
+ TypeAnalyzer Analyzer(M);
+ return Analyzer.analyze();
+}
+
+} // namespace SPIRV
+
+char SPIRVTypeAnalysisWrapperPass::ID = 0;
+
+SPIRVTypeAnalysisWrapperPass::SPIRVTypeAnalysisWrapperPass() : ModulePass(ID) {}
+
+bool SPIRVTypeAnalysisWrapperPass::runOnModule(Module &M) {
+ TI = SPIRV::getTypeInfo(M);
+ return false;
+}
+
+SPIRVTypeAnalysis::Result SPIRVTypeAnalysis::run(Module &M,
+ ModuleAnalysisManager &MAM) {
+ return SPIRV::getTypeInfo(M);
+}
+
+AnalysisKey SPIRVTypeAnalysis::Key;
+
+} // namespace llvm
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.h b/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.h
new file mode 100644
index 0000000000000..ba11bcc796c01
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVTypeAnalysis.h
@@ -0,0 +1,95 @@
+//===- SPIRVTypeAnalysis.h ------------------------*- C++ -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This analysis links a type information to every register/pointer, allowing
+// us to legalize type mismatches when required (graphical SPIR-V pointers for
+// ex).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTYPEANALYSIS_H
+#define LLVM_LIB_TARGET_SPIRV_SPIRVTYPEANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include <iostream>
+#include <optional>
+#include <unordered_set>
+
+namespace llvm {
+class SPIRVSubtarget;
+class MachineFunction;
+class MachineModuleInfo;
+
+namespace SPIRV {
+
+// Maps each analyzed Value to its deduced non-opaque type.
+class TypeInfo {
+ DenseMap<const Value *, Type *> *TypeMap;
+
+public:
+ TypeInfo() : TypeMap(nullptr) {}
+ TypeInfo(DenseMap<const Value *, Type *> *TypeMap) : TypeMap(TypeMap) {}
+
+ Type *getType(const Value *V) {
+ auto It = TypeMap->find(V);
+ if (It != TypeMap->end())
+ return It->second;
+
+ // In some cases, type deduction is not possible from the IR. This should
+ // only happen when handling opaque pointers, otherwise it means the type
+ // deduction is broken.
+ assert(V->getType()->isPointerTy());
+ return V->getType();
+ }
+
+ // Returns true if this type contains an opaque pointer (recursively).
+ static bool isOpaqueType(const Type *T);
+};
+
+} // namespace SPIRV
+
+// Wrapper around the function above to use it with the legacy pass manager.
+class SPIRVTypeAnalysisWrapperPass : public ModulePass {
+ SPIRV::TypeInfo TI;
+
+public:
+ static char ID;
+
+ SPIRVTypeAnalysisWrapperPass();
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ };
+
+ bool runOnModule(Module &F) override;
+
+ SPIRV::TypeInfo &getTypeInfo() { return TI; }
+ const SPIRV::TypeInfo &getTypeInfo() const { return TI; }
+};
+
+// Wrapper around the function above to use it with the new pass manager.
+class SPIRVTypeAnalysis : public AnalysisInfoMixin<SPIRVTypeAnalysis> {
+ friend AnalysisInfoMixin<SPIRVTypeAnalysis>;
+ static AnalysisKey Key;
+
+public:
+ using Result = SPIRV::TypeInfo;
+
+ Result run(Module &F, ModuleAnalysisManager &AM);
+};
+
+namespace SPIRV {
+TypeInfo getTypeInfo(Module &F);
+} // namespace SPIRV
+
+} // namespace llvm
+#endif // LLVM_LIB_TARGET_SPIRV_SPIRVTYPEANALYSIS_H
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index d765dfe370be2..6cd9c492ae13e 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -37,6 +37,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
+void initializeSPIRVTypeAnalysisWrapperPassPass(PassRegistry &);
void initializeSPIRVPreLegalizerPass(PassRegistry &);
void initializeSPIRVPreLegalizerCombinerPass(PassRegistry &);
void initializeSPIRVPostLegalizerPass(PassRegistry &);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 68651f4ee4d2f..6dddf310c547b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "Analysis/SPIRVTypeAnalysis.h"
#include "SPIRV.h"
#include "SPIRVBuiltins.h"
#include "SPIRVMetadata.h"
@@ -225,6 +226,7 @@ class SPIRVEmitIntrinsics
bool runOnModule(Module &M) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
+ // AU.addRequired<SPIRVTypeAnalysisWrapperPass>();
ModulePass::getAnalysisUsage(AU);
}
};
@@ -257,8 +259,11 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
char SPIRVEmitIntrinsics::ID = 0;
-INITIALIZE_PASS(SPIRVEmitIntrinsics, "emit-intrinsics", "SPIRV emit intrinsics",
- false, false)
+INITIALIZE_PASS_BEGIN(SPIRVEmitIntrinsics, "emit-intrinsics",
+ "SPIRV emit intrinsics", false, false)
+// INITIALIZE_PASS_DEPENDENCY(SPIRVTypeAnalysisWrapperPass)
+INITIALIZE_PASS_END(SPIRVEmitIntrinsics, "emit-intrinsics",
+ "SPIRV emit intrinsics", false, false)
static inline bool isAssignTypeInstr(const Instruction *I) {
return isa<IntrinsicInst>(I) &&
@@ -2551,6 +2556,7 @@ void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
bool Changed = false;
+ // auto TI = getAnalysis<SPIRVTypeAnalysisWrapperPass>().getTypeInfo();
parseFunDeclarations(M);
diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt
index d7f0290089c4c..ef2fcc612d7df 100644
--- a/llvm/unittests/Target/SPIRV/CMakeLists.txt
+++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt
@@ -16,6 +16,7 @@ set(LLVM_LINK_COMPONENTS
add_llvm_target_unittest(SPIRVTests
SPIRVConvergenceRegionAnalysisTests.cpp
+ SPIRVTypeAnalysisTests.cpp
SPIRVSortBlocksTests.cpp
SPIRVPartialOrderingVisitorTests.cpp
SPIRVAPITest.cpp
diff --git a/llvm/unittests/Target/SPIRV/SPIRVTypeAnalysisTests.cpp b/llvm/unittests/Target/SPIRV/SPIRVTypeAnalysisTests.cpp
new file mode 100644
index 0000000000000..3c2e7983838e8
--- /dev/null
+++ b/llvm/unittests/Target/SPIRV/SPIRVTypeAnalysisTests.cpp
@@ -0,0 +1,722 @@
+//===- SPIRVTypeAnalysisTests.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVTypeAnalysis.h"
+#include "llvm/Analysis/DominanceFrontier.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <iostream>
+
+using ::testing::Contains;
+using ::testing::Pair;
+
+using namespace llvm;
+using namespace llvm::SPIRV;
+
+template <typename T> struct IsA {
+ friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
+};
+
+class SPIRVTypeAnalysisTest : public testing::Test {
+protected:
+ void SetUp() override {
+ // Required for tests.
+ FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+ MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+
+ MAM.registerPass([&] { return SPIRVTypeAnalysis(); });
+ }
+
+ void TearDown() override { M.reset(); }
+
+ SPIRVTypeAnalysis::Result &runAnalysis(StringRef Assembly) {
+ assert(M == nullptr &&
+ "Calling runAnalysis multiple times is unsafe. See getAnalysis().");
+
+ SMDiagnostic Error;
+ M = parseAssemblyString(Assembly, Error, Context);
+ if (!M) {
+ std::cerr << Error.getMessage().str() << std::endl;
+ std::cerr << "> " << Error.getLineContents().str() << std::endl;
+ }
+ assert(M && "Bad assembly. Bad test?");
+
+ ModulePassManager MPM;
+ MPM.run(*M, MAM);
+
+ // Setup helper types.
+ IntTy = IntegerType::get(M->getContext(), 32);
+ FloatTy = Type::getFloatTy(M->getContext());
+
+ return MAM.getResult<SPIRVTypeAnalysis>(*M);
+ }
+
+ const Value *getValue(StringRef Name) {
+ assert(M != nullptr && "Has runAnalysis been called before?");
+
+ for (auto &F : *M) {
+ for (Argument &A : F.args())
+ if (A.getName() == Name)
+ return &A;
+
+ for (auto &BB : F)
+ for (auto &V : BB)
+ if (Name == V.getName())
+ return &V;
+ }
+ ADD_FAILURE() << "Error: Could not locate requested variable. Bad test?";
+ return nullptr;
+ }
+
+ StructType *getStructType(StringRef Name) {
+ for (StructType *ST : M->getIdentifiedStructTypes()) {
+ if (ST->getName() == Name)
+ return ST;
+ }
+
+ ADD_FAILURE() << "Error: Could not locate requested struct type. Bad test?";
+ return nullptr;
+ }
+
+protected:
+ LLVMContext Context;
+ FunctionAnalysisManager FAM;
+ ModuleAnalysisManager MAM;
+ std::unique_ptr<Module> M;
+
+ // Helper types for writing tests.
+ Type *IntTy = nullptr;
+ Type *FloatTy = nullptr;
+};
+
+TEST_F(SPIRVTypeAnalysisTest, ScalarAlloca) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %test = alloca i32
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("test")),
+ TypedPointerType::get(IntTy, /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, ScalarAllocaFloat) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %test = alloca float
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("test")),
+ TypedPointerType::get(FloatTy, /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaArray) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %test = alloca [5 x i32]
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("test")),
+ TypedPointerType::get(ArrayType::get(IntTy, 5), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaArrayArray) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %test = alloca [5 x [10 x i32]]
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(
+ TI.getType(getValue("test")),
+ TypedPointerType::get(ArrayType::get(ArrayType::get(IntTy, 10), 5), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaVector) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %test = alloca <4 x i32>
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("test")),
+ TypedPointerType::get(VectorType::get(IntTy, 4, false), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaArrayVector) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %test = alloca [5 x <4 x i32>]
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("test")),
+ TypedPointerType::get(
+ ArrayType::get(VectorType::get(IntTy, 4, false), 5), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaLoad) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %test = alloca [5 x <4 x i32>]
+ %v = load <4 x i32>, ptr %test
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ auto VT = VectorType::get(IntegerType::get(M->getContext(), 32), 4,
+ /* scalable= */ false);
+ EXPECT_EQ(TI.getType(getValue("test")),
+ TypedPointerType::get(ArrayType::get(VT, 5), 0));
+ EXPECT_EQ(TI.getType(getValue("v")), VT);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaPtrIndirectDeduction) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %v = alloca ptr
+ %l = load i32, ptr %v
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("v")),
+ TypedPointerType::get(IntTy, /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("l")), IntTy);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaPtrNestedIndirectDeduction) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %v = alloca ptr
+ %l1 = load ptr, ptr %v
+ %l2 = load ptr, ptr %l1
+ %l3 = load ptr, ptr %l2
+ %l4 = load i32, ptr %l3
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ Type *LoadType = IntTy;
+ EXPECT_EQ(TI.getType(getValue("l4")), LoadType);
+ LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("l3")), LoadType);
+ LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("l2")), LoadType);
+ LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("l1")), LoadType);
+ LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("v")), LoadType);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaLoadArrayPtr) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ; %ptr_array = alloca [5 x *i32]
+ %ptr_array = alloca [5 x ptr]
+ %l1 = load ptr, ptr %ptr_array
+ %l2 = load i32, ptr %l1
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ EXPECT_EQ(TI.getType(getValue("ptr_array")),
+ ArrayType::get(TypedPointerType::get(IntTy, /* AS= */ 0), 5));
+ EXPECT_EQ(TI.getType(getValue("l1")),
+ TypedPointerType::get(IntTy, /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("l2")), IntTy);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaStruct) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %var = alloca %st
+ %l = load %st, ptr %var
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ auto ST = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("var")),
+ TypedPointerType::get(ST, /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("l")), ST);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, AllocaStructDeducedFromLoad) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %var = alloca ptr
+ %l1 = load ptr, ptr %var
+ %l2 = load %st, ptr %l1
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ Type *LoadType = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("l2")), LoadType);
+ LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("l1")), LoadType);
+ LoadType = TypedPointerType::get(LoadType, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("var")), LoadType);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, StructMemberDirectLoad) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %var = alloca %st
+ %l2 = load i32, ptr %var
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ Type *ST = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("l2")), IntTy);
+ EXPECT_EQ(TI.getType(getValue("var")),
+ TypedPointerType::get(ST, /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, StructMemberConflictingDeductionFromLoadA) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ; %var = alloca *%st
+ %var = alloca ptr
+
+ %l1 = load ptr, ptr %var
+
+ %l2 = load %st, ptr %l1
+ %l3 = load i32, ptr %l1
+
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ Type *ST = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("l3")), IntTy);
+ EXPECT_EQ(TI.getType(getValue("l2")), ST);
+ EXPECT_EQ(TI.getType(getValue("l1")), TypedPointerType::get(ST, /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("var")),
+ TypedPointerType::get(TypedPointerType::get(ST, /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, StructMemberConflictingDeductionFromLoadB) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ; %var = alloca *%st
+ %var = alloca ptr
+
+ %l1 = load ptr, ptr %var
+
+ %l3 = load i32, ptr %l1
+ %l2 = load %st, ptr %l1
+
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ Type *ST = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("l3")), IntTy);
+ EXPECT_EQ(TI.getType(getValue("l2")), ST);
+ EXPECT_EQ(TI.getType(getValue("l1")), TypedPointerType::get(ST, /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("var")),
+ TypedPointerType::get(TypedPointerType::get(ST, /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, StructMemberConflictingTree) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ; %var = alloca *%st
+ %var = alloca ptr
+
+ %l_0 = load ptr, ptr %var
+ %l_1 = load ptr, ptr %var
+
+ %l_00 = load %st, ptr %l_0
+ %l_01 = load %st, ptr %l_1
+
+ %l_10 = load i32, ptr %l_0
+ %l_11 = load i32, ptr %l_1
+
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ Type *ST = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("l_10")), IntTy);
+ EXPECT_EQ(TI.getType(getValue("l_11")), IntTy);
+
+ EXPECT_EQ(TI.getType(getValue("l_00")), ST);
+ EXPECT_EQ(TI.getType(getValue("l_01")), ST);
+
+ EXPECT_EQ(TI.getType(getValue("l_0")),
+ TypedPointerType::get(ST, /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("l_1")),
+ TypedPointerType::get(ST, /* AS= */ 0));
+
+ EXPECT_EQ(TI.getType(getValue("var")),
+ TypedPointerType::get(TypedPointerType::get(ST, /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, ArrayElementConflict) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ; %var = alloca *[5 x %st]
+ %var = alloca ptr
+
+ %l1 = load ptr, ptr %var
+
+ %l2 = load %st, ptr %l1
+ %l3 = load i32, ptr %l1
+ %l4 = load [5 x %st], ptr %l1
+
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ Type *ST = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("l4")), ArrayType::get(ST, 5));
+ EXPECT_EQ(TI.getType(getValue("l3")), IntTy);
+ EXPECT_EQ(TI.getType(getValue("l2")), ST);
+ EXPECT_EQ(TI.getType(getValue("l1")),
+ TypedPointerType::get(ArrayType::get(ST, 5), /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("var")),
+ TypedPointerType::get(
+ TypedPointerType::get(ArrayType::get(ST, 5), /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, VectorElementConflict) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ; %var = alloca *<4 x i32>
+ %var = alloca ptr
+
+ %l1 = load ptr, ptr %var
+
+ %l3 = load i32, ptr %l1
+ %l4 = load <4 x i32>, ptr %l1
+
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ auto *VT = VectorType::get(IntTy, 4, /* scalable= */ false);
+
+ EXPECT_EQ(TI.getType(getValue("l4")), VT);
+ EXPECT_EQ(TI.getType(getValue("l3")), IntTy);
+ EXPECT_EQ(TI.getType(getValue("l1")), TypedPointerType::get(VT, /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("var")),
+ TypedPointerType::get(TypedPointerType::get(VT, /* AS= */ 0), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, MissingInformationOnLoad) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %var = alloca ptr
+ %l1 = load ptr, ptr %var
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ EXPECT_EQ(TI.getType(getValue("l1")),
+ PointerType::get(M->getContext(), /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("var")),
+ PointerType::get(M->getContext(), /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, MissingInformationAlloca) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %var = alloca ptr
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+
+ EXPECT_EQ(TI.getType(getValue("var")),
+ PointerType::get(M->getContext(), /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromGep) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ %var = alloca ptr
+ %ptr = getelementptr [5 x i32], ptr %var, i64 0, i64 0
+ %val = load i32, ptr %ptr
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ Type *AT = ArrayType::get(IntTy, 5);
+ EXPECT_EQ(TI.getType(getValue("var")),
+ TypedPointerType::get(AT, /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromGepOpaqueBaseType) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ; %var = alloca **i32
+ %var = alloca ptr
+ %ptr1 = getelementptr ptr, ptr %var, i64 0
+ %ptr2 = getelementptr ptr, ptr %ptr1, i64 0
+ %val = load i32, ptr %ptr2
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ Type *T = IntTy;
+ EXPECT_EQ(TI.getType(getValue("val")), T);
+ T = TypedPointerType::get(T, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("ptr2")), T);
+ T = TypedPointerType::get(T, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("ptr1")), T);
+ T = TypedPointerType::get(T, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("var")), T);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromGepPartialOpaqueBaseType) {
+ StringRef Assembly = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ; %var = alloca [5 x **i32]
+ %var = alloca ptr
+ %ptr1 = getelementptr [5 x ptr], ptr %var, i64 0, i64 0
+ %ptr2 = getelementptr ptr, ptr %ptr1, i64 0
+ %val = load i32, ptr %ptr2
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ Type *T = IntTy;
+ EXPECT_EQ(TI.getType(getValue("val")), T);
+ T = TypedPointerType::get(T, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("ptr2")), T);
+ T = TypedPointerType::get(T, /* AS= */ 0);
+ EXPECT_EQ(TI.getType(getValue("ptr1")), T);
+ T = TypedPointerType::get(ArrayType::get(T, 5), 0);
+ EXPECT_EQ(TI.getType(getValue("var")), T);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceParamFromLoad) {
+ StringRef Assembly = R"(
+ define void @foo(ptr %input) {
+ %a = load i32, ptr %input
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("input")), TypedPointerType::get(IntTy, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromStore) {
+ StringRef Assembly = R"(
+ define void @foo(ptr %input) {
+ store i32 0, ptr %input
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("input")), TypedPointerType::get(IntTy, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromStoreStructConflict) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @foo(ptr %a, ptr %b) {
+ %s = load %st, ptr %a
+ store i32 0, ptr %b
+ store %st %s, ptr %b
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ auto *ST = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("a")), TypedPointerType::get(ST, 0));
+ EXPECT_EQ(TI.getType(getValue("b")), TypedPointerType::get(ST, 0));
+ EXPECT_EQ(TI.getType(getValue("s")), ST);
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromStoreStructInline) {
+ StringRef Assembly = R"(
+ %st = type { i32 }
+
+ define void @foo(ptr %a) {
+ store i32 0, ptr %a
+ store %st { i32 0 }, ptr %a
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ auto *ST = getStructType("st");
+ EXPECT_EQ(TI.getType(getValue("a")), TypedPointerType::get(ST, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromStoreArray) {
+ StringRef Assembly = R"(
+ define void @foo(ptr %a) {
+ store i32 0, ptr %a
+ store [2 x i32] [ i32 0, i32 1 ], ptr %a
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("a")),
+ TypedPointerType::get(ArrayType::get(IntTy, 2), 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromCall) {
+ StringRef Assembly = R"(
+ define ptr @foo(ptr %par) {
+ ret ptr %par
+ }
+
+ define void @bar() {
+ %var = alloca i32
+ %res = call ptr @foo(ptr %var)
+ store i32 0, ptr %res
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("var")), TypedPointerType::get(IntTy, 0));
+ EXPECT_EQ(TI.getType(getValue("res")), TypedPointerType::get(IntTy, 0));
+ EXPECT_EQ(TI.getType(getValue("par")), TypedPointerType::get(IntTy, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromCallLackInfo) {
+ StringRef Assembly = R"(
+ define ptr @foo(ptr %par) {
+ ret ptr %par
+ }
+
+ define void @bar() {
+ %var = alloca ptr
+ %res = call ptr @foo(ptr %var)
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("var")),
+ PointerType::get(M->getContext(), /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("res")),
+ PointerType::get(M->getContext(), /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("par")),
+ PointerType::get(M->getContext(), /* AS= */ 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceFromCallPartial) {
+ StringRef Assembly = R"(
+ define ptr @foo(ptr %fpar1, ptr %fpar2) {
+ ret ptr %fpar2
+ }
+
+ define void @bar(ptr %bpar1, ptr %bpar2) {
+ %res = call ptr @foo(ptr %bpar1, ptr %bpar2)
+ store i32 0, ptr %res
+ ret void
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("fpar1")),
+ PointerType::get(M->getContext(), /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("bpar1")),
+ PointerType::get(M->getContext(), /* AS= */ 0));
+ EXPECT_EQ(TI.getType(getValue("res")), TypedPointerType::get(IntTy, 0));
+ EXPECT_EQ(TI.getType(getValue("fpar2")), TypedPointerType::get(IntTy, 0));
+ EXPECT_EQ(TI.getType(getValue("bpar2")), TypedPointerType::get(IntTy, 0));
+}
+
+TEST_F(SPIRVTypeAnalysisTest, DeduceRecursive) {
+ StringRef Assembly = R"(
+ define ptr @foo(ptr %par) {
+ store i32 0, ptr %par
+ %tmp = call ptr @foo(ptr %par)
+ ret ptr %par
+ }
+ )";
+
+ auto TI = runAnalysis(Assembly);
+ EXPECT_EQ(TI.getType(getValue("tmp")), TypedPointerType::get(IntTy, 0));
+ EXPECT_EQ(TI.getType(getValue("par")), TypedPointerType::get(IntTy, 0));
+}
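
The conflict tests above (StructMemberConflictingDeductionFromLoadA/B, ArrayElementConflict, VectorElementConflict, DeduceFromStoreArray) all expect the same outcome: when one pointer is accessed both as an aggregate and as the type of that aggregate's first element, the aggregate wins, whatever the order of the accesses. The sketch below is a toy model of that rule as it can be read from the expectations only. It is not the implementation in this patch, and `ToyType`, `startsWith` and `mergePointee` are hypothetical names that do not exist in LLVM.

```cpp
#include <cassert>
#include <string>

// Hypothetical toy type, unrelated to llvm::Type. `First` is the type of the
// first member (struct) or of the elements (array/vector); null for scalars.
struct ToyType {
  std::string Name;
  const ToyType *First;
};

// True if Inner is Outer itself, or can be reached from Outer by repeatedly
// stepping into the first member/element.
bool startsWith(const ToyType *Outer, const ToyType *Inner) {
  for (const ToyType *T = Outer; T != nullptr; T = T->First)
    if (T == Inner)
      return true;
  return false;
}

// Merge two candidate pointee types deduced for the same pointer.
// Returns the enclosing type when one candidate starts with the other,
// or nullptr when the candidates are unrelated (assumed to be an error
// or a fallback case in a real analysis).
const ToyType *mergePointee(const ToyType *A, const ToyType *B) {
  if (startsWith(A, B))
    return A;
  if (startsWith(B, A))
    return B;
  return nullptr;
}
```

With `i32`, `%st = { i32 }` and `[5 x %st]` modelled as a chain of `ToyType` values, `mergePointee` returns `%st` for the pair (`%st`, `i32`) in either order, and `[5 x %st]` once the array access is merged in, which mirrors what the tests expect for `%l1`.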