[llvm] [SPIR-V] Improve type inference: deduce types of composite data structures (PR #86782)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 27 02:13:04 PDT 2024
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/86782
This PR improves type inference in general and deduces types of composite data structures in particular.
>From b7ec2847aee7a3a263affb04a27a634bc8baea8a Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 27 Mar 2024 02:11:26 -0700
Subject: [PATCH] deduce types of composite data structures; improve and fix
type deduction
---
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 25 ++-
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 212 +++++++++++++-----
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 40 ++++
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 2 +-
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp | 5 +-
llvm/lib/Target/SPIRV/SPIRVUtils.h | 23 +-
.../pointers/nested-struct-opaque-pointers.ll | 29 +++
.../SPIRV/pointers/struct-opaque-pointers.ll | 8 +-
8 files changed, 273 insertions(+), 71 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index afdca01561b0bc..ad4e72a3128b1e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -201,21 +201,30 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
if (!isPointerTy(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
- // In case OriginalArgType is of pointer type, there are three possibilities:
+ Argument *Arg = F.getArg(ArgIdx);
+ Type *ArgType = Arg->getType();
+ if (isTypedPointerTy(ArgType)) {
+ SPIRVType *ElementType = GR->getOrCreateSPIRVType(
+ cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
+ return GR->getOrCreateSPIRVPointerType(
+ ElementType, MIRBuilder,
+ addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
+ }
+
+ // In case OriginalArgType is of untyped pointer type, there are three
+ // possibilities:
// 1) This is a pointer of an LLVM IR element type, passed byval/byref.
// 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
- // intrinsic assigning a TargetExtType.
+ // intrinsic assigning a TargetExtType.
// 3) This is a pointer, try to retrieve pointer element type from a
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
// type.
- Argument *Arg = F.getArg(ArgIdx);
- if (HasPointeeTypeAttr(Arg)) {
- Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
- : Arg->getParamByRefType();
- SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
+ if (hasPointeeTypeAttr(Arg)) {
+ SPIRVType *ElementType =
+ GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
- addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST));
+ addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
}
for (auto User : Arg->users()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5828db6669ff18..b4e71dd9b8800e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -14,6 +14,7 @@
#include "SPIRV.h"
#include "SPIRVBuiltins.h"
#include "SPIRVMetadata.h"
+#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/IRBuilder.h"
@@ -53,14 +54,22 @@ class SPIRVEmitIntrinsics
: public FunctionPass,
public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
SPIRVTargetMachine *TM = nullptr;
+ SPIRVGlobalRegistry *GR = nullptr;
Function *F = nullptr;
bool TrackConstants = true;
DenseMap<Instruction *, Constant *> AggrConsts;
+ DenseMap<Instruction *, Type *> AggrConstTypes;
DenseSet<Instruction *> AggrStores;
- // deduce values type
- DenseMap<Value *, Type *> DeducedElTys;
+ // deduce element type of untyped pointers
Type *deduceElementType(Value *I);
+ Type *deduceElementTypeHelper(Value *I);
+ Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited);
+
+ // deduce nested types of composites
+ Type *deduceNestedTypeHelper(User *U);
+ Type *deduceNestedTypeHelper(User *U, Type *Ty,
+ std::unordered_set<Value *> &Visited);
void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);
@@ -92,9 +101,9 @@ class SPIRVEmitIntrinsics
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
void processParamTypes(Function *F, IRBuilder<> &B);
- Type *deduceFunParamType(Function *F, unsigned OpIdx);
- Type *deduceFunParamType(Function *F, unsigned OpIdx,
- std::unordered_set<Function *> &FVisited);
+ Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
+ Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
+ std::unordered_set<Function *> &FVisited);
public:
static char ID;
@@ -169,17 +178,20 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
// Deduce and return a successfully deduced Type of the Instruction,
// or nullptr otherwise.
-static Type *deduceElementTypeHelper(Value *I,
- std::unordered_set<Value *> &Visited,
- DenseMap<Value *, Type *> &DeducedElTys) {
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
+  // Each top-level query starts a fresh traversal; the recursive overload
+  // uses this set to detect cycles.
+  std::unordered_set<Value *> Seen;
+  return deduceElementTypeHelper(I, Seen);
+}
+
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
+ Value *I, std::unordered_set<Value *> &Visited) {
// allow to pass nullptr as an argument
if (!I)
return nullptr;
// maybe already known
- auto It = DeducedElTys.find(I);
- if (It != DeducedElTys.end())
- return It->second;
+ if (Type *KnownTy = GR->findDeducedElementType(I))
+ return KnownTy;
// maybe a cycle
if (Visited.find(I) != Visited.end())
@@ -195,25 +207,99 @@ static Type *deduceElementTypeHelper(Value *I,
Ty = Ref->getResultElementType();
} else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
Ty = Ref->getValueType();
+ if (Value *Op = Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr) {
+ if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+ if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+ Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+ } else {
+ Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), Ty, Visited);
+ }
+ }
} else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
- Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
- DeducedElTys);
+ Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited);
} else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
isPointerTy(Src) && isPointerTy(Dest))
- Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
+ Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited);
}
// remember the found relationship
- if (Ty)
- DeducedElTys[I] = Ty;
+ if (Ty) {
+ // specify nested types if needed, otherwise return unchanged
+ GR->addDeducedElementType(I, Ty);
+ }
return Ty;
}
-Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+// Re-create a type of the value if it has untyped pointer fields, also nested.
+// Return the original value type if no corrections of untyped pointer
+// information is found or needed.
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U) {
std::unordered_set<Value *> Visited;
-  if (Type *Ty = deduceElementTypeHelper(I, Visited, DeducedElTys))
+  // Start from U's own IR type with an empty visited set.
+  return deduceNestedTypeHelper(U, U->getType(), Visited);
+}
+
+// Walk the operands of the composite constant U (whose IR type is OrigTy) and
+// rebuild OrigTy with untyped pointer members replaced by TypedPointerType
+// whenever the pointee type can be deduced; nested structs/arrays are handled
+// recursively. A rebuilt type is cached in the global registry. OrigTy is
+// returned unchanged if U is null, a cycle is hit, or nothing was refined.
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
+    User *U, Type *OrigTy, std::unordered_set<Value *> &Visited) {
+  if (!U)
+    return OrigTy;
+
+  // maybe already known
+  if (Type *KnownTy = GR->findDeducedCompositeType(U))
+    return KnownTy;
+
+  // maybe a cycle
+  if (Visited.find(U) != Visited.end())
+    return OrigTy;
+  Visited.insert(U);
+
+  // Refine the declared type OpTy of one member/element operand Op: an untyped
+  // pointer becomes a TypedPointerType when its pointee can be deduced, and a
+  // nested composite is processed recursively.
+  auto DeduceOperandType = [&](Value *Op, Type *OpTy) -> Type * {
+    if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
+      if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+        return TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+      return OpTy;
+    }
+    return deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+  };
+
+  if (auto *StructTy = dyn_cast<StructType>(OrigTy)) {
+    // Operands of a constant struct correspond 1:1 to its members. Anything
+    // else (e.g. zeroinitializer, which has no operands) is left unchanged
+    // rather than producing a struct with the wrong number of members.
+    if (U->getNumOperands() != StructTy->getNumElements())
+      return OrigTy;
+    SmallVector<Type *> Tys;
+    bool Change = false;
+    for (unsigned i = 0; i < U->getNumOperands(); ++i) {
+      // Read the member type from the struct itself so that a null operand
+      // is never dereferenced.
+      Type *OpTy = StructTy->getElementType(i);
+      Value *Op = U->getOperand(i);
+      Type *Ty = Op ? DeduceOperandType(Op, OpTy) : OpTy;
+      Tys.push_back(Ty);
+      Change |= Ty != OpTy;
+    }
+    if (Change) {
+      // Preserve the packedness of the original struct so the member layout
+      // does not change.
+      Type *NewTy = StructType::create(Tys, "", StructTy->isPacked());
+      GR->addDeducedCompositeType(U, NewTy);
+      return NewTy;
+    }
+  } else if (auto *ArrTy = dyn_cast<ArrayType>(OrigTy)) {
+    // An array has a single element type, so only the first element is
+    // inspected. NOTE(review): elements whose pointees differ from the first
+    // one are not reconciled.
+    if (Value *Op = U->getNumOperands() > 0 ? U->getOperand(0) : nullptr) {
+      Type *OpTy = ArrTy->getElementType();
+      Type *Ty = DeduceOperandType(Op, OpTy);
+      if (Ty != OpTy) {
+        Type *NewTy = ArrayType::get(Ty, ArrTy->getNumElements());
+        GR->addDeducedCompositeType(U, NewTy);
+        return NewTy;
+      }
+    }
+  }
+
+  return OrigTy;
+}
+
+// Deduce the element type of the untyped pointer I; fall back to i8 when no
+// type can be deduced.
+Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+  if (Type *Ty = deduceElementTypeHelper(I))
return Ty;
return IntegerType::getInt8Ty(I->getContext());
}
@@ -257,6 +343,7 @@ void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
Worklist.push(IntrUndef);
I->replaceUsesOfWith(Op, IntrUndef);
AggrConsts[IntrUndef] = AggrUndef;
+ AggrConstTypes[IntrUndef] = AggrUndef->getType();
}
}
}
@@ -282,6 +369,7 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
I->replaceUsesOfWith(Op, CCI);
KeepInst = true;
SEI.AggrConsts[CCI] = AggrC;
+ SEI.AggrConstTypes[CCI] = SEI.deduceNestedTypeHelper(AggrC);
};
if (auto *AggrC = dyn_cast<ConstantAggregate>(Op)) {
@@ -396,8 +484,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
Pointer = BC->getOperand(0);
// Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType
- std::unordered_set<Value *> Visited;
- Type *PointerElemTy = deduceElementTypeHelper(Pointer, Visited, DeducedElTys);
+ Type *PointerElemTy = deduceElementTypeHelper(Pointer);
if (PointerElemTy == ExpectedElementType)
return;
@@ -456,8 +543,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
CallInst *CI = buildIntrWithMD(
Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
- DeducedElTys[CI] = ExpectedElementType;
- DeducedElTys[Pointer] = ExpectedElementType;
+ GR->addDeducedElementType(CI, ExpectedElementType);
+ GR->addDeducedElementType(Pointer, ExpectedElementType);
return;
}
@@ -498,25 +585,29 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
Function *CalledF = CI->getCalledFunction();
SmallVector<Type *, 4> CalledArgTys;
bool HaveTypes = false;
- for (auto &CalledArg : CalledF->args()) {
- if (!isPointerTy(CalledArg.getType())) {
+ for (unsigned OpIdx = 0; OpIdx < CalledF->arg_size(); ++OpIdx) {
+ Argument *CalledArg = CalledF->getArg(OpIdx);
+ Type *ArgType = CalledArg->getType();
+ if (!isPointerTy(ArgType)) {
CalledArgTys.push_back(nullptr);
- continue;
- }
- auto It = DeducedElTys.find(&CalledArg);
- Type *ParamTy = It != DeducedElTys.end() ? It->second : nullptr;
- if (!ParamTy) {
- for (User *U : CalledArg.users()) {
- if (Instruction *Inst = dyn_cast<Instruction>(U)) {
- std::unordered_set<Value *> Visited;
- ParamTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
- if (ParamTy)
- break;
+ } else if (isTypedPointerTy(ArgType)) {
+ CalledArgTys.push_back(cast<TypedPointerType>(ArgType)->getElementType());
+ HaveTypes = true;
+ } else {
+ Type *ElemTy = GR->findDeducedElementType(CalledArg);
+ if (!ElemTy && hasPointeeTypeAttr(CalledArg))
+ ElemTy = getPointeeTypeByAttr(CalledArg);
+ if (!ElemTy) {
+ for (User *U : CalledArg->users()) {
+ if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+ if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
+ break;
+ }
}
}
+ HaveTypes |= ElemTy != nullptr;
+ CalledArgTys.push_back(ElemTy);
}
- HaveTypes |= ParamTy != nullptr;
- CalledArgTys.push_back(ParamTy);
}
std::string DemangledName =
@@ -706,6 +797,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
if (GV.getName() == "llvm.global.annotations")
return;
if (GV.hasInitializer() && !isa<UndefValue>(GV.getInitializer())) {
+ // Deduce element type and store results in Global Registry.
+ // Result is ignored, because TypedPointerType is not supported
+ // by llvm IR general logic.
+ deduceElementTypeHelper(&GV);
Constant *Init = GV.getInitializer();
Type *Ty = isAggrToReplace(Init) ? B.getInt32Ty() : Init->getType();
Constant *Const = isAggrToReplace(Init) ? B.getInt32(1) : Init;
@@ -732,7 +827,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
unsigned AddressSpace = getPointerAddressSpace(I->getType());
CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
EltTyConst, I, {B.getInt32(AddressSpace)}, B);
- DeducedElTys[CI] = ElemTy;
+ GR->addDeducedElementType(CI, ElemTy);
}
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -745,9 +840,10 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
if (II->getIntrinsicID() == Intrinsic::spv_const_composite ||
II->getIntrinsicID() == Intrinsic::spv_undef) {
- auto t = AggrConsts.find(II);
- assert(t != AggrConsts.end());
- TypeToAssign = t->second->getType();
+ auto It = AggrConstTypes.find(II);
+ if (It == AggrConstTypes.end())
+ report_fatal_error("Unknown composite intrinsic type");
+ TypeToAssign = It->second;
}
}
Constant *Const = UndefValue::get(TypeToAssign);
@@ -807,12 +903,13 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
}
}
-Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
+Type *SPIRVEmitIntrinsics::deduceFunParamElementType(Function *F,
+                                                     unsigned OpIdx) {
+  // Entry point: start with an empty set of visited functions, which the
+  // recursive overload checks to stop on cycles.
std::unordered_set<Function *> FVisited;
-  return deduceFunParamType(F, OpIdx, FVisited);
+  return deduceFunParamElementType(F, OpIdx, FVisited);
}
-Type *SPIRVEmitIntrinsics::deduceFunParamType(
+Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
// maybe a cycle
if (FVisited.find(F) != FVisited.end())
@@ -830,15 +927,15 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
if (!isPointerTy(OpArg->getType()))
continue;
// maybe we already know operand's element type
- if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
- return It->second;
+ if (Type *KnownTy = GR->findDeducedElementType(OpArg))
+ return KnownTy;
// search in actual parameter's users
for (User *OpU : OpArg->users()) {
Instruction *Inst = dyn_cast<Instruction>(OpU);
if (!Inst || Inst == CI)
continue;
Visited.clear();
- if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
+ if (Type *Ty = deduceElementTypeHelper(Inst, Visited))
return Ty;
}
// check if it's a formal parameter of the outer function
@@ -857,7 +954,7 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
// search in function parameters
for (auto &Pair : Lookup) {
- if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
+ if (Type *Ty = deduceFunParamElementType(Pair.first, Pair.second, FVisited))
return Ty;
}
@@ -866,19 +963,21 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
B.SetInsertPointPastAllocas(F);
- DenseMap<Argument *, Type *> Args;
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
Argument *Arg = F->getArg(OpIdx);
if (isUntypedPointerTy(Arg->getType()) &&
- DeducedElTys.find(Arg) == DeducedElTys.end() &&
- !HasPointeeTypeAttr(Arg)) {
- if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
+ !GR->findDeducedElementType(Arg)) {
+ Type *ElemTy = nullptr;
+ if (hasPointeeTypeAttr(Arg) &&
+ (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
+ GR->addDeducedElementType(Arg, ElemTy);
+ } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
CallInst *AssignPtrTyCI = buildIntrWithMD(
Intrinsic::spv_assign_ptr_type, {Arg->getType()},
Constant::getNullValue(ElemTy), Arg,
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
- DeducedElTys[AssignPtrTyCI] = ElemTy;
- DeducedElTys[Arg] = ElemTy;
+ GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+ GR->addDeducedElementType(Arg, ElemTy);
}
}
}
@@ -887,9 +986,14 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
if (Func.isDeclaration())
return false;
+
+ const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
+ GR = ST.getSPIRVGlobalRegistry();
+
F = &Func;
IRBuilder<> B(Func.getContext());
AggrConsts.clear();
+ AggrConstTypes.clear();
AggrStores.clear();
// StoreInst's operand type can be changed during the next transformations,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ed0f90ff89ce6e..acaf1bd5327ab6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -59,6 +59,13 @@ class SPIRVGlobalRegistry {
// Holds the maximum ID we have in the module.
unsigned Bound;
+ // Maps values associated with untyped pointers into deduced element types of
+ // untyped pointers.
+ DenseMap<Value *, Type *> DeducedElTys;
+ // Maps composite values to deduced types where untyped pointers are replaced
+ // with typed ones
+ DenseMap<Value *, Type *> DeducedNestedTys;
+
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ =
@@ -122,6 +129,39 @@ class SPIRVGlobalRegistry {
void setBound(unsigned V) { Bound = V; }
unsigned getBound() { return Bound; }
+ // Deduced element types of untyped pointers and composites:
+ // - Add a record to the map of deduced element types.
+  void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
+ // - Find a record in the map of deduced element types.
+  // Pure lookup, hence const; returns nullptr when nothing is recorded.
+  Type *findDeducedElementType(const Value *Val) const {
+    auto It = DeducedElTys.find(Val);
+    return It == DeducedElTys.end() ? nullptr : It->second;
+  }
+ // - Add a record to the map of deduced composite types.
+  void addDeducedCompositeType(Value *Val, Type *Ty) {
+    // Any type recorded earlier for Val is overwritten.
+    Type *&Slot = DeducedNestedTys[Val];
+    Slot = Ty;
+  }
+ // - Find a record in the map of deduced composite types.
+  // Pure lookup, hence const; returns nullptr when nothing is recorded.
+  Type *findDeducedCompositeType(const Value *Val) const {
+    auto It = DeducedNestedTys.find(Val);
+    return It == DeducedNestedTys.end() ? nullptr : It->second;
+  }
+ // - Find a type of the given Global value
+  Type *getDeducedGlobalValueType(const GlobalValue *Global) {
+    // An element type deduced for the global itself takes precedence.
+    if (Type *DeducedTy = findDeducedElementType(Global))
+      return DeducedTy;
+    // Next, try a refined composite type recorded for the global's first
+    // operand (for a GlobalVariable, that operand is its initializer).
+    Value *FirstOp =
+        Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr;
+    if (FirstOp) {
+      if (Type *CompositeTy = findDeducedCompositeType(FirstOp))
+        return CompositeTy;
+    }
+    // Otherwise use the value type declared in the IR.
+    return Global->getValueType();
+  }
+
// Map a machine operand that represents a use of a function via function
// pointer to a machine operand that represents the function definition.
// Return either the register or invalid value, because we have no context for
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 505b19a4d66edb..f4525e713c987f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1897,7 +1897,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
// FIXME: don't use MachineIRBuilder here, replace it with BuildMI.
MachineIRBuilder MIRBuilder(I);
const GlobalValue *GV = I.getOperand(1).getGlobal();
- Type *GVType = GV->getValueType();
+ Type *GVType = GR.getDeducedGlobalValueType(GV);
SPIRVType *PointerBaseType;
if (GVType->isArrayTy()) {
SPIRVType *ArrayElementType =
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 41807da6afcbc7..b133f0ae85de20 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -186,8 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
}
case TargetOpcode::G_GLOBAL_VALUE: {
MIB.setInsertPt(*MI->getParent(), MI);
- const auto *Global = MI->getOperand(1).getGlobal();
- auto *Ty = TypedPointerType::get(Global->getValueType(),
+ const GlobalValue *Global = MI->getOperand(1).getGlobal();
+ Type *ElementTy = GR->getDeducedGlobalValueType(Global);
+ auto *Ty = TypedPointerType::get(ElementTy,
Global->getType()->getAddressSpace());
SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index eb87349f0941c5..9fdc1d0dc9a559 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -127,8 +127,27 @@ inline unsigned getPointerAddressSpace(const Type *T) {
}
// Return true if the Argument is decorated with a pointee type
-inline bool HasPointeeTypeAttr(Argument *Arg) {
- return Arg->hasByValAttr() || Arg->hasByRefAttr();
+inline bool hasPointeeTypeAttr(const Argument *Arg) {
+  // Keep in sync with the attributes handled by getPointeeTypeByAttr().
+  return Arg->hasByValAttr() || Arg->hasByRefAttr() || Arg->hasStructRetAttr();
+}
+
+// Return the pointee type recorded by the byval, sret or byref attribute of the
+// argument, or nullptr if the argument carries none of them.
+inline Type *getPointeeTypeByAttr(const Argument *Arg) {
+  if (Arg->hasByValAttr())
+    return Arg->getParamByValType();
+  if (Arg->hasStructRetAttr())
+    return Arg->getParamStructRetType();
+  if (Arg->hasByRefAttr())
+    return Arg->getParamByRefType();
+  return nullptr;
+}
+
+// Build a FunctionType from F's return type, the current types of its formal
+// arguments, and its variadic flag.
+inline Type *reconstructFunctionType(const Function *F) {
+  SmallVector<Type *> ArgTys;
+  ArgTys.reserve(F->arg_size());
+  for (const Argument &Arg : F->args())
+    ArgTys.push_back(Arg.getType());
+  return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
}
} // namespace llvm
diff --git a/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll
new file mode 100644
index 00000000000000..0cf8b513e7f4dc
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll
@@ -0,0 +1,29 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-NOT: OpTypeInt 8 0
+
+; TODO: %[[TyInt64:.*]] = OpTypeInt 64 0
+; TODO: %[[TyInt64Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt64]]
+; TODO: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt64Ptr]] %[[TyInt64Ptr]]
+; TODO: %[[ConstStruct:.*]] = OpConstantComposite %[[TyStruct]] %[[ConstField:.*]] %[[ConstField]]
+; TODO: %[[TyStructPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyStruct]]
+; TODO: OpVariable %[[TyStructPtr]] {{[a-zA-Z]+}} %[[ConstStruct]]
+
+@GI = addrspace(1) constant i64 42
+
+@GS = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GI, ptr addrspace(1) @GI }
+@GS2 = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GS, ptr addrspace(1) @GS }
+@GS3 = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GS2, ptr addrspace(1) @GS2 }
+
+@GPS = addrspace(1) global ptr addrspace(1) @GS3
+
+;@GPI0 = external addrspace(1) global ptr addrspace(1)
+
+@GPI1 = addrspace(1) global ptr addrspace(1) @GI
+@GPI2 = addrspace(1) global ptr addrspace(1) @GPI1
+@GPI3 = addrspace(1) global ptr addrspace(1) @GPI2
+
+define spir_kernel void @foo() {
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
index ce3ab8895a5948..6d4913f802c289 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
@@ -1,14 +1,14 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-; CHECK: %[[TyInt8:.*]] = OpTypeInt 8 0
-; CHECK: %[[TyInt8Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt8]]
-; CHECK: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt8Ptr]] %[[TyInt8Ptr]]
+; CHECK: %[[TyInt64:.*]] = OpTypeInt 64 0
+; CHECK: %[[TyInt64Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt64]]
+; CHECK: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt64Ptr]] %[[TyInt64Ptr]]
; CHECK: %[[ConstStruct:.*]] = OpConstantComposite %[[TyStruct]] %[[ConstField:.*]] %[[ConstField]]
; CHECK: %[[TyStructPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyStruct]]
; CHECK: OpVariable %[[TyStructPtr]] {{[a-zA-Z]+}} %[[ConstStruct]]
-@a = addrspace(1) constant i32 123
+@a = addrspace(1) constant i64 42
@struct = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @a, ptr addrspace(1) @a }
define spir_kernel void @foo() {
More information about the llvm-commits
mailing list