[llvm] b7ac8fd - [SPIR-V] Improve type inference: deduce types of composite data structures (#86782)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 28 00:08:10 PDT 2024
Author: Vyacheslav Levytskyy
Date: 2024-03-28T08:08:06+01:00
New Revision: b7ac8fddb54816256fab70696ebc176717a391c3
URL: https://github.com/llvm/llvm-project/commit/b7ac8fddb54816256fab70696ebc176717a391c3
DIFF: https://github.com/llvm/llvm-project/commit/b7ac8fddb54816256fab70696ebc176717a391c3.diff
LOG: [SPIR-V] Improve type inference: deduce types of composite data structures (#86782)
This PR improves type inference in general and deduces types of
composite data structures in particular. It also adds a way to insert a
bitcast that keeps a function call valid when argument types mismatch
as a result of opaque-pointer type inference.
The attached test `pointers/nested-struct-opaque-pointers.ll`
demonstrates the new capabilities: the SPIR-V code emitted for this test is
now (1) valid with respect to data field types and (2) accepted by
`spirv-val`.
Stricter LIT checks, support for more composite data structures, and
improving the type correctness of function calls are the main TODOs at
the moment.
Added:
llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll
Modified:
llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
llvm/lib/Target/SPIRV/SPIRVUtils.h
llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index afdca01561b0bc..ad4e72a3128b1e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -201,21 +201,30 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
if (!isPointerTy(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
- // In case OriginalArgType is of pointer type, there are three possibilities:
+ Argument *Arg = F.getArg(ArgIdx);
+ Type *ArgType = Arg->getType();
+ if (isTypedPointerTy(ArgType)) {
+ SPIRVType *ElementType = GR->getOrCreateSPIRVType(
+ cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
+ return GR->getOrCreateSPIRVPointerType(
+ ElementType, MIRBuilder,
+ addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
+ }
+
+ // In case OriginalArgType is of untyped pointer type, there are three
+ // possibilities:
// 1) This is a pointer of an LLVM IR element type, passed byval/byref.
// 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
- // intrinsic assigning a TargetExtType.
+ // intrinsic assigning a TargetExtType.
// 3) This is a pointer, try to retrieve pointer element type from a
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
// type.
- Argument *Arg = F.getArg(ArgIdx);
- if (HasPointeeTypeAttr(Arg)) {
- Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
- : Arg->getParamByRefType();
- SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
+ if (hasPointeeTypeAttr(Arg)) {
+ SPIRVType *ElementType =
+ GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
- addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST));
+ addressSpaceToStorageClass(getPointerAddressSpace(ArgType), ST));
}
for (auto User : Arg->users()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5828db6669ff18..7c5a38fa48d009 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -14,6 +14,7 @@
#include "SPIRV.h"
#include "SPIRVBuiltins.h"
#include "SPIRVMetadata.h"
+#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/IRBuilder.h"
@@ -53,14 +54,22 @@ class SPIRVEmitIntrinsics
: public FunctionPass,
public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
SPIRVTargetMachine *TM = nullptr;
+ SPIRVGlobalRegistry *GR = nullptr;
Function *F = nullptr;
bool TrackConstants = true;
DenseMap<Instruction *, Constant *> AggrConsts;
+ DenseMap<Instruction *, Type *> AggrConstTypes;
DenseSet<Instruction *> AggrStores;
- // deduce values type
- DenseMap<Value *, Type *> DeducedElTys;
+ // deduce element type of untyped pointers
Type *deduceElementType(Value *I);
+ Type *deduceElementTypeHelper(Value *I);
+ Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited);
+
+ // deduce nested types of composites
+ Type *deduceNestedTypeHelper(User *U);
+ Type *deduceNestedTypeHelper(User *U, Type *Ty,
+ std::unordered_set<Value *> &Visited);
void preprocessCompositeConstants(IRBuilder<> &B);
void preprocessUndefs(IRBuilder<> &B);
@@ -92,9 +101,9 @@ class SPIRVEmitIntrinsics
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
void processParamTypes(Function *F, IRBuilder<> &B);
- Type *deduceFunParamType(Function *F, unsigned OpIdx);
- Type *deduceFunParamType(Function *F, unsigned OpIdx,
- std::unordered_set<Function *> &FVisited);
+ Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
+ Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
+ std::unordered_set<Function *> &FVisited);
public:
static char ID;
@@ -169,17 +178,20 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
// Deduce and return a successfully deduced Type of the Instruction,
// or nullptr otherwise.
-static Type *deduceElementTypeHelper(Value *I,
- std::unordered_set<Value *> &Visited,
- DenseMap<Value *, Type *> &DeducedElTys) {
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
+ std::unordered_set<Value *> Visited;
+ return deduceElementTypeHelper(I, Visited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
+ Value *I, std::unordered_set<Value *> &Visited) {
// allow to pass nullptr as an argument
if (!I)
return nullptr;
// maybe already known
- auto It = DeducedElTys.find(I);
- if (It != DeducedElTys.end())
- return It->second;
+ if (Type *KnownTy = GR->findDeducedElementType(I))
+ return KnownTy;
// maybe a cycle
if (Visited.find(I) != Visited.end())
@@ -195,25 +207,99 @@ static Type *deduceElementTypeHelper(Value *I,
Ty = Ref->getResultElementType();
} else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
Ty = Ref->getValueType();
+ if (Value *Op = Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr) {
+ if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
+ if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+ Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+ } else {
+ Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), Ty, Visited);
+ }
+ }
} else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
- Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
- DeducedElTys);
+ Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited);
} else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
isPointerTy(Src) && isPointerTy(Dest))
- Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
+ Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited);
}
// remember the found relationship
- if (Ty)
- DeducedElTys[I] = Ty;
+ if (Ty) {
+ // specify nested types if needed, otherwise return unchanged
+ GR->addDeducedElementType(I, Ty);
+ }
return Ty;
}
-Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+// Re-create a type of the value if it has untyped pointer fields, also nested.
+// Return the original value type if no corrections of untyped pointer
+// information is found or needed.
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(User *U) {
std::unordered_set<Value *> Visited;
- if (Type *Ty = deduceElementTypeHelper(I, Visited, DeducedElTys))
+ return deduceNestedTypeHelper(U, U->getType(), Visited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
+ User *U, Type *OrigTy, std::unordered_set<Value *> &Visited) {
+ if (!U)
+ return OrigTy;
+
+ // maybe already known
+ if (Type *KnownTy = GR->findDeducedCompositeType(U))
+ return KnownTy;
+
+ // maybe a cycle
+ if (Visited.find(U) != Visited.end())
+ return OrigTy;
+ Visited.insert(U);
+
+ if (dyn_cast<StructType>(OrigTy)) {
+ SmallVector<Type *> Tys;
+ bool Change = false;
+ for (unsigned i = 0; i < U->getNumOperands(); ++i) {
+ Value *Op = U->getOperand(i);
+ Type *OpTy = Op->getType();
+ Type *Ty = OpTy;
+ if (Op) {
+ if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
+ if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+ Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+ } else {
+ Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+ }
+ }
+ Tys.push_back(Ty);
+ Change |= Ty != OpTy;
+ }
+ if (Change) {
+ Type *NewTy = StructType::create(Tys);
+ GR->addDeducedCompositeType(U, NewTy);
+ return NewTy;
+ }
+ } else if (auto *ArrTy = dyn_cast<ArrayType>(OrigTy)) {
+ if (Value *Op = U->getNumOperands() > 0 ? U->getOperand(0) : nullptr) {
+ Type *OpTy = ArrTy->getElementType();
+ Type *Ty = OpTy;
+ if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
+ if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+ Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+ } else {
+ Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+ }
+ if (Ty != OpTy) {
+ Type *NewTy = ArrayType::get(Ty, ArrTy->getNumElements());
+ GR->addDeducedCompositeType(U, NewTy);
+ return NewTy;
+ }
+ }
+ }
+
+ return OrigTy;
+}
+
+Type *SPIRVEmitIntrinsics::deduceElementType(Value *I) {
+ if (Type *Ty = deduceElementTypeHelper(I))
return Ty;
return IntegerType::getInt8Ty(I->getContext());
}
@@ -257,6 +343,7 @@ void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
Worklist.push(IntrUndef);
I->replaceUsesOfWith(Op, IntrUndef);
AggrConsts[IntrUndef] = AggrUndef;
+ AggrConstTypes[IntrUndef] = AggrUndef->getType();
}
}
}
@@ -282,6 +369,7 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
I->replaceUsesOfWith(Op, CCI);
KeepInst = true;
SEI.AggrConsts[CCI] = AggrC;
+ SEI.AggrConstTypes[CCI] = SEI.deduceNestedTypeHelper(AggrC);
};
if (auto *AggrC = dyn_cast<ConstantAggregate>(Op)) {
@@ -396,8 +484,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
Pointer = BC->getOperand(0);
// Do not emit spv_ptrcast if Pointer's element type is ExpectedElementType
- std::unordered_set<Value *> Visited;
- Type *PointerElemTy = deduceElementTypeHelper(Pointer, Visited, DeducedElTys);
+ Type *PointerElemTy = deduceElementTypeHelper(Pointer);
if (PointerElemTy == ExpectedElementType)
return;
@@ -456,8 +543,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
CallInst *CI = buildIntrWithMD(
Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
ExpectedElementTypeConst, Pointer, {B.getInt32(AddressSpace)}, B);
- DeducedElTys[CI] = ExpectedElementType;
- DeducedElTys[Pointer] = ExpectedElementType;
+ GR->addDeducedElementType(CI, ExpectedElementType);
+ GR->addDeducedElementType(Pointer, ExpectedElementType);
return;
}
@@ -498,25 +585,29 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
Function *CalledF = CI->getCalledFunction();
SmallVector<Type *, 4> CalledArgTys;
bool HaveTypes = false;
- for (auto &CalledArg : CalledF->args()) {
- if (!isPointerTy(CalledArg.getType())) {
+ for (unsigned OpIdx = 0; OpIdx < CalledF->arg_size(); ++OpIdx) {
+ Argument *CalledArg = CalledF->getArg(OpIdx);
+ Type *ArgType = CalledArg->getType();
+ if (!isPointerTy(ArgType)) {
CalledArgTys.push_back(nullptr);
- continue;
- }
- auto It = DeducedElTys.find(&CalledArg);
- Type *ParamTy = It != DeducedElTys.end() ? It->second : nullptr;
- if (!ParamTy) {
- for (User *U : CalledArg.users()) {
- if (Instruction *Inst = dyn_cast<Instruction>(U)) {
- std::unordered_set<Value *> Visited;
- ParamTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
- if (ParamTy)
- break;
+ } else if (isTypedPointerTy(ArgType)) {
+ CalledArgTys.push_back(cast<TypedPointerType>(ArgType)->getElementType());
+ HaveTypes = true;
+ } else {
+ Type *ElemTy = GR->findDeducedElementType(CalledArg);
+ if (!ElemTy && hasPointeeTypeAttr(CalledArg))
+ ElemTy = getPointeeTypeByAttr(CalledArg);
+ if (!ElemTy) {
+ for (User *U : CalledArg->users()) {
+ if (Instruction *Inst = dyn_cast<Instruction>(U)) {
+ if ((ElemTy = deduceElementTypeHelper(Inst)) != nullptr)
+ break;
+ }
}
}
+ HaveTypes |= ElemTy != nullptr;
+ CalledArgTys.push_back(ElemTy);
}
- HaveTypes |= ParamTy != nullptr;
- CalledArgTys.push_back(ParamTy);
}
std::string DemangledName =
@@ -706,6 +797,10 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
if (GV.getName() == "llvm.global.annotations")
return;
if (GV.hasInitializer() && !isa<UndefValue>(GV.getInitializer())) {
+ // Deduce element type and store results in Global Registry.
+ // Result is ignored, because TypedPointerType is not supported
+ // by llvm IR general logic.
+ deduceElementTypeHelper(&GV);
Constant *Init = GV.getInitializer();
Type *Ty = isAggrToReplace(Init) ? B.getInt32Ty() : Init->getType();
Constant *Const = isAggrToReplace(Init) ? B.getInt32(1) : Init;
@@ -732,7 +827,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
unsigned AddressSpace = getPointerAddressSpace(I->getType());
CallInst *CI = buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {I->getType()},
EltTyConst, I, {B.getInt32(AddressSpace)}, B);
- DeducedElTys[CI] = ElemTy;
+ GR->addDeducedElementType(CI, ElemTy);
}
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
@@ -745,9 +840,10 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
if (II->getIntrinsicID() == Intrinsic::spv_const_composite ||
II->getIntrinsicID() == Intrinsic::spv_undef) {
- auto t = AggrConsts.find(II);
- assert(t != AggrConsts.end());
- TypeToAssign = t->second->getType();
+ auto It = AggrConstTypes.find(II);
+ if (It == AggrConstTypes.end())
+ report_fatal_error("Unknown composite intrinsic type");
+ TypeToAssign = It->second;
}
}
Constant *Const = UndefValue::get(TypeToAssign);
@@ -807,12 +903,13 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
}
}
-Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
+Type *SPIRVEmitIntrinsics::deduceFunParamElementType(Function *F,
+ unsigned OpIdx) {
std::unordered_set<Function *> FVisited;
- return deduceFunParamType(F, OpIdx, FVisited);
+ return deduceFunParamElementType(F, OpIdx, FVisited);
}
-Type *SPIRVEmitIntrinsics::deduceFunParamType(
+Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
// maybe a cycle
if (FVisited.find(F) != FVisited.end())
@@ -830,15 +927,15 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
if (!isPointerTy(OpArg->getType()))
continue;
// maybe we already know operand's element type
- if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
- return It->second;
+ if (Type *KnownTy = GR->findDeducedElementType(OpArg))
+ return KnownTy;
// search in actual parameter's users
for (User *OpU : OpArg->users()) {
Instruction *Inst = dyn_cast<Instruction>(OpU);
if (!Inst || Inst == CI)
continue;
Visited.clear();
- if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
+ if (Type *Ty = deduceElementTypeHelper(Inst, Visited))
return Ty;
}
// check if it's a formal parameter of the outer function
@@ -857,7 +954,7 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
// search in function parameters
for (auto &Pair : Lookup) {
- if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
+ if (Type *Ty = deduceFunParamElementType(Pair.first, Pair.second, FVisited))
return Ty;
}
@@ -866,19 +963,23 @@ Type *SPIRVEmitIntrinsics::deduceFunParamType(
void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
B.SetInsertPointPastAllocas(F);
- DenseMap<Argument *, Type *> Args;
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
Argument *Arg = F->getArg(OpIdx);
- if (isUntypedPointerTy(Arg->getType()) &&
- DeducedElTys.find(Arg) == DeducedElTys.end() &&
- !HasPointeeTypeAttr(Arg)) {
- if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
+ if (!isUntypedPointerTy(Arg->getType()))
+ continue;
+
+ Type *ElemTy = GR->findDeducedElementType(Arg);
+ if (!ElemTy) {
+ if (hasPointeeTypeAttr(Arg) &&
+ (ElemTy = getPointeeTypeByAttr(Arg)) != nullptr) {
+ GR->addDeducedElementType(Arg, ElemTy);
+ } else if ((ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) {
CallInst *AssignPtrTyCI = buildIntrWithMD(
Intrinsic::spv_assign_ptr_type, {Arg->getType()},
Constant::getNullValue(ElemTy), Arg,
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
- DeducedElTys[AssignPtrTyCI] = ElemTy;
- DeducedElTys[Arg] = ElemTy;
+ GR->addDeducedElementType(AssignPtrTyCI, ElemTy);
+ GR->addDeducedElementType(Arg, ElemTy);
}
}
}
@@ -887,9 +988,14 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
if (Func.isDeclaration())
return false;
+
+ const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
+ GR = ST.getSPIRVGlobalRegistry();
+
F = &Func;
IRBuilder<> B(Func.getContext());
AggrConsts.clear();
+ AggrConstTypes.clear();
AggrStores.clear();
// StoreInst's operand type can be changed during the next transformations,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index ed0f90ff89ce6e..e0099e52944725 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -41,9 +41,13 @@ class SPIRVGlobalRegistry {
// map a Function to its definition (as a machine instruction operand)
DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
+ DenseMap<const MachineInstr *, const Function *> FunctionToInstrRev;
// map function pointer (as a machine instruction operand) to the used
// Function
DenseMap<const MachineOperand *, const Function *> InstrToFunction;
+ // Maps Functions to their calls (in a form of the machine instruction,
+ // OpFunctionCall) that happened before the definition is available
+ DenseMap<const Function *, SmallVector<MachineInstr *>> ForwardCalls;
// Look for an equivalent of the newType in the map. Return the equivalent
// if it's found, otherwise insert newType to the map and return the type.
@@ -59,6 +63,13 @@ class SPIRVGlobalRegistry {
// Holds the maximum ID we have in the module.
unsigned Bound;
+ // Maps values associated with untyped pointers into deduced element types of
+ // untyped pointers.
+ DenseMap<Value *, Type *> DeducedElTys;
+ // Maps composite values to deduced types where untyped pointers are replaced
+ // with typed ones
+ DenseMap<Value *, Type *> DeducedNestedTys;
+
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ =
@@ -122,6 +133,37 @@ class SPIRVGlobalRegistry {
void setBound(unsigned V) { Bound = V; }
unsigned getBound() { return Bound; }
+ // Deduced element types of untyped pointers and composites:
+ // - Add a record to the map of deduced element types.
+ void addDeducedElementType(Value *Val, Type *Ty) { DeducedElTys[Val] = Ty; }
+ // - Find a record in the map of deduced element types.
+ Type *findDeducedElementType(const Value *Val) {
+ auto It = DeducedElTys.find(Val);
+ return It == DeducedElTys.end() ? nullptr : It->second;
+ }
+ // - Add a record to the map of deduced composite types.
+ void addDeducedCompositeType(Value *Val, Type *Ty) {
+ DeducedNestedTys[Val] = Ty;
+ }
+ // - Find a record in the map of deduced composite types.
+ Type *findDeducedCompositeType(const Value *Val) {
+ auto It = DeducedNestedTys.find(Val);
+ return It == DeducedNestedTys.end() ? nullptr : It->second;
+ }
+ // - Find a type of the given Global value
+ Type *getDeducedGlobalValueType(const GlobalValue *Global) {
+ // we may know element type if it was deduced earlier
+ Type *ElementTy = findDeducedElementType(Global);
+ if (!ElementTy) {
+ // or we may know element type if it's associated with a composite
+ // value
+ if (Value *GlobalElem =
+ Global->getNumOperands() > 0 ? Global->getOperand(0) : nullptr)
+ ElementTy = findDeducedCompositeType(GlobalElem);
+ }
+ return ElementTy ? ElementTy : Global->getValueType();
+ }
+
// Map a machine operand that represents a use of a function via function
// pointer to a machine operand that represents the function definition.
// Return either the register or invalid value, because we have no context for
@@ -133,18 +175,56 @@ class SPIRVGlobalRegistry {
auto ResReg = FunctionToInstr.find(ResF->second);
return ResReg == FunctionToInstr.end() ? nullptr : ResReg->second;
}
+
+ // Map a Function to a machine instruction that represents the function
+ // definition.
+ const MachineInstr *getFunctionDefinition(const Function *F) {
+ if (!F)
+ return nullptr;
+ auto MOIt = FunctionToInstr.find(F);
+ return MOIt == FunctionToInstr.end() ? nullptr : MOIt->second->getParent();
+ }
+
+ // Map a Function to a machine instruction that represents the function
+ // definition.
+ const Function *getFunctionByDefinition(const MachineInstr *MI) {
+ if (!MI)
+ return nullptr;
+ auto FIt = FunctionToInstrRev.find(MI);
+ return FIt == FunctionToInstrRev.end() ? nullptr : FIt->second;
+ }
+
// map function pointer (as a machine instruction operand) to the used
// Function
void recordFunctionPointer(const MachineOperand *MO, const Function *F) {
InstrToFunction[MO] = F;
}
+
// map a Function to its definition (as a machine instruction)
void recordFunctionDefinition(const Function *F, const MachineOperand *MO) {
FunctionToInstr[F] = MO;
+ FunctionToInstrRev[MO->getParent()] = F;
}
+
// Return true if any OpConstantFunctionPointerINTEL were generated
bool hasConstFunPtr() { return !InstrToFunction.empty(); }
+ // Add a record about forward function call.
+ void addForwardCall(const Function *F, MachineInstr *MI) {
+ auto It = ForwardCalls.find(F);
+ if (It == ForwardCalls.end())
+ ForwardCalls[F] = {MI};
+ else
+ It->second.push_back(MI);
+ }
+
+ // Map a Function to the vector of machine instructions that represents
+ // forward function calls or to nullptr if not found.
+ SmallVector<MachineInstr *> *getForwardCalls(const Function *F) {
+ auto It = ForwardCalls.find(F);
+ return It == ForwardCalls.end() ? nullptr : &It->second;
+ }
+
// Get or create a SPIR-V type corresponding the given LLVM IR type,
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 55b4c47c197dab..4f5c1dc4f90b0d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -86,8 +86,8 @@ bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
// when there is a type mismatch between results and operand types.
static void validatePtrTypes(const SPIRVSubtarget &STI,
MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
- MachineInstr &I, SPIRVType *ResType,
- unsigned OpIdx) {
+ MachineInstr &I, unsigned OpIdx,
+ SPIRVType *ResType, const Type *ResTy = nullptr) {
Register OpReg = I.getOperand(OpIdx).getReg();
SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
SPIRVType *OpType = GR.getSPIRVTypeForVReg(
@@ -97,7 +97,13 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
- if (!ElemType || ElemType == ResType)
+ if (!ElemType)
+ return;
+ bool IsSameMF =
+ ElemType->getParent()->getParent() == ResType->getParent()->getParent();
+ bool IsEqualTypes = IsSameMF ? ElemType == ResType
+ : GR.getTypeForSPIRVType(ElemType) == ResTy;
+ if (IsEqualTypes)
return;
// There is a type mismatch between results and operand types
// and we insert a bitcast before the instruction to keep SPIR-V code valid
@@ -105,7 +111,11 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
static_cast<SPIRV::StorageClass::StorageClass>(
OpType->getOperand(1).getImm());
MachineIRBuilder MIB(I);
- SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC);
+ SPIRVType *NewBaseType =
+ IsSameMF ? ResType
+ : GR.getOrCreateSPIRVType(
+ ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
+ SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
if (!GR.isBitcastCompatible(NewPtrType, OpType))
report_fatal_error(
"insert validation bitcast: incompatible result and operand types");
@@ -123,6 +133,74 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
I.getOperand(OpIdx).setReg(NewReg);
}
+// Insert a bitcast before the function call instruction to keep SPIR-V code
+// valid when there is a type mismatch between actual and expected types of an
+// argument:
+// %formal = OpFunctionParameter %formal_type
+// ...
+// %res = OpFunctionCall %ty %fun %actual ...
+// implies that %actual is of %formal_type, and in case of opaque pointers.
+// We may need to insert a bitcast to ensure this.
+void validateFunCallMachineDef(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *DefMRI,
+ MachineRegisterInfo *CallMRI,
+ SPIRVGlobalRegistry &GR, MachineInstr &FunCall,
+ MachineInstr *FunDef) {
+ if (FunDef->getOpcode() != SPIRV::OpFunction)
+ return;
+ unsigned OpIdx = 3;
+ for (FunDef = FunDef->getNextNode();
+ FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter &&
+ OpIdx < FunCall.getNumOperands();
+ FunDef = FunDef->getNextNode(), OpIdx++) {
+ SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg());
+ SPIRVType *DefElemType =
+ DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer
+ ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg())
+ : nullptr;
+ if (DefElemType) {
+ const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
+ // Switch GR context to the call site instead of the (default) definition
+ // side
+ GR.setCurrentFunc(*FunCall.getParent()->getParent());
+ validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
+ DefElemTy);
+ GR.setCurrentFunc(*FunDef->getParent()->getParent());
+ }
+ }
+}
+
+// Ensure there is no mismatch between actual and expected arg types: calls
+// with a processed definition. Return Function pointer if it's a forward
+// call (ahead of definition), and nullptr otherwise.
+const Function *validateFunCall(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR,
+ MachineInstr &FunCall) {
+ const GlobalValue *GV = FunCall.getOperand(2).getGlobal();
+ const Function *F = dyn_cast<Function>(GV);
+ MachineInstr *FunDef =
+ const_cast<MachineInstr *>(GR.getFunctionDefinition(F));
+ if (!FunDef)
+ return F;
+ validateFunCallMachineDef(STI, MRI, MRI, GR, FunCall, FunDef);
+ return nullptr;
+}
+
+// Ensure there is no mismatch between actual and expected arg types: calls
+// ahead of a processed definition.
+void validateForwardCalls(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
+ MachineInstr &FunDef) {
+ const Function *F = GR.getFunctionByDefinition(&FunDef);
+ if (SmallVector<MachineInstr *> *FwdCalls = GR.getForwardCalls(F))
+ for (MachineInstr *FunCall : *FwdCalls) {
+ MachineRegisterInfo *CallMRI =
+ &FunCall->getParent()->getParent()->getRegInfo();
+ validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef);
+ }
+}
+
// TODO: the logic of inserting additional bitcast's is to be moved
// to pre-IRTranslation passes eventually
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
@@ -137,14 +215,28 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
switch (MI.getOpcode()) {
case SPIRV::OpLoad:
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
- validatePtrTypes(STI, MRI, GR, MI,
- GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()), 2);
+ validatePtrTypes(STI, MRI, GR, MI, 2,
+ GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
break;
case SPIRV::OpStore:
// OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
- validatePtrTypes(STI, MRI, GR, MI,
- GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()), 0);
+ validatePtrTypes(STI, MRI, GR, MI, 0,
+ GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
break;
+
+ case SPIRV::OpFunctionCall:
+ // ensure there is no mismatch between actual and expected arg types:
+ // calls with a processed definition
+ if (MI.getNumOperands() > 3)
+ if (const Function *F = validateFunCall(STI, MRI, GR, MI))
+ GR.addForwardCall(F, &MI);
+ break;
+ case SPIRV::OpFunction:
+ // ensure there is no mismatch between actual and expected arg types:
+ // calls ahead of a processed definition
+ validateForwardCalls(STI, MRI, GR, MI);
+ break;
+
// ensure that LLVM IR bitwise instructions result in logical SPIR-V
// instructions when applied to bool type
case SPIRV::OpBitwiseOrS:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 505b19a4d66edb..f4525e713c987f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1897,7 +1897,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
// FIXME: don't use MachineIRBuilder here, replace it with BuildMI.
MachineIRBuilder MIRBuilder(I);
const GlobalValue *GV = I.getOperand(1).getGlobal();
- Type *GVType = GV->getValueType();
+ Type *GVType = GR.getDeducedGlobalValueType(GV);
SPIRVType *PointerBaseType;
if (GVType->isArrayTy()) {
SPIRVType *ArrayElementType =
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 41807da6afcbc7..b133f0ae85de20 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -186,8 +186,9 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
}
case TargetOpcode::G_GLOBAL_VALUE: {
MIB.setInsertPt(*MI->getParent(), MI);
- const auto *Global = MI->getOperand(1).getGlobal();
- auto *Ty = TypedPointerType::get(Global->getValueType(),
+ const GlobalValue *Global = MI->getOperand(1).getGlobal();
+ Type *ElementTy = GR->getDeducedGlobalValueType(Global);
+ auto *Ty = TypedPointerType::get(ElementTy,
Global->getType()->getAddressSpace());
SpirvTy = GR->getOrCreateSPIRVType(Ty, MIB);
break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index eb87349f0941c5..c2c3475e1a936f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -127,8 +127,26 @@ inline unsigned getPointerAddressSpace(const Type *T) {
}
// Return true if the Argument is decorated with a pointee type
-inline bool HasPointeeTypeAttr(Argument *Arg) {
- return Arg->hasByValAttr() || Arg->hasByRefAttr();
+inline bool hasPointeeTypeAttr(Argument *Arg) {
+ return Arg->hasByValAttr() || Arg->hasByRefAttr() || Arg->hasStructRetAttr();
+}
+
+// Return the pointee type of the argument or nullptr otherwise
+inline Type *getPointeeTypeByAttr(Argument *Arg) {
+ if (Arg->hasByValAttr())
+ return Arg->getParamByValType();
+ if (Arg->hasStructRetAttr())
+ return Arg->getParamStructRetType();
+ if (Arg->hasByRefAttr())
+ return Arg->getParamByRefType();
+ return nullptr;
+}
+
+inline Type *reconstructFunctionType(Function *F) {
+ SmallVector<Type *> ArgTys;
+ for (unsigned i = 0; i < F->arg_size(); ++i)
+ ArgTys.push_back(F->getArg(i)->getType());
+ return FunctionType::get(F->getReturnType(), ArgTys, F->isVarArg());
}
} // namespace llvm
diff --git a/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll
new file mode 100644
index 00000000000000..77b895c7762fba
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/nested-struct-opaque-pointers.ll
@@ -0,0 +1,20 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-NOT: OpTypeInt 8 0
+
+ at GI = addrspace(1) constant i64 42
+
+ at GS = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GI, ptr addrspace(1) @GI }
+ at GS2 = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GS, ptr addrspace(1) @GS }
+ at GS3 = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @GS2, ptr addrspace(1) @GS2 }
+
+ at GPS = addrspace(1) global ptr addrspace(1) @GS3
+
+ at GPI1 = addrspace(1) global ptr addrspace(1) @GI
+ at GPI2 = addrspace(1) global ptr addrspace(1) @GPI1
+ at GPI3 = addrspace(1) global ptr addrspace(1) @GPI2
+
+define spir_kernel void @foo() {
+ ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
index ce3ab8895a5948..6d4913f802c289 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
@@ -1,14 +1,14 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
-; CHECK: %[[TyInt8:.*]] = OpTypeInt 8 0
-; CHECK: %[[TyInt8Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt8]]
-; CHECK: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt8Ptr]] %[[TyInt8Ptr]]
+; CHECK: %[[TyInt64:.*]] = OpTypeInt 64 0
+; CHECK: %[[TyInt64Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt64]]
+; CHECK: %[[TyStruct:.*]] = OpTypeStruct %[[TyInt64Ptr]] %[[TyInt64Ptr]]
; CHECK: %[[ConstStruct:.*]] = OpConstantComposite %[[TyStruct]] %[[ConstField:.*]] %[[ConstField]]
; CHECK: %[[TyStructPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyStruct]]
; CHECK: OpVariable %[[TyStructPtr]] {{[a-zA-Z]+}} %[[ConstStruct]]
- at a = addrspace(1) constant i32 123
+ at a = addrspace(1) constant i64 42
@struct = addrspace(1) global {ptr addrspace(1), ptr addrspace(1)} { ptr addrspace(1) @a, ptr addrspace(1) @a }
define spir_kernel void @foo() {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
index 703f1e22a0321a..1071d3443056cb 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
@@ -34,6 +34,12 @@ entry:
%addr = addrspacecast ptr addrspace(1) %lptr to ptr addrspace(4)
%object = bitcast ptr addrspace(4) %addr to ptr addrspace(4)
call spir_func void @foo(ptr addrspace(4) %object, i32 3)
+ %halfptr = getelementptr inbounds half, ptr addrspace(1) %_arg_cum, i64 1
+ %halfaddr = addrspacecast ptr addrspace(1) %halfptr to ptr addrspace(4)
+ call spir_func void @foo(ptr addrspace(4) %halfaddr, i32 3)
+ %dblptr = getelementptr inbounds double, ptr addrspace(1) %_arg_cum, i64 1
+ %dbladdr = addrspacecast ptr addrspace(1) %dblptr to ptr addrspace(4)
+ call spir_func void @foo(ptr addrspace(4) %dbladdr, i32 3)
ret void
}
@@ -49,4 +55,3 @@ define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) {
tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order)
ret void
}
-
More information about the llvm-commits
mailing list