[llvm] [SPIR-V] Improve type inference in SPIR-V Backend for opaque pointers (PR #86283)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 22 06:30:58 PDT 2024
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/86283
This PR improves type inference in the SPIR-V Backend for opaque pointers, accounting for the case when a chain of function calls makes it possible to deduce formal parameter types from actual arguments. The attached test demonstrates the case.
>From 8c40f8867eb1075846c5770d8af7991c6caa2518 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 22 Mar 2024 06:28:20 -0700
Subject: [PATCH] improve type inference: case of chain of calls of arbitrary
order
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 123 +++++++++++-------
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 1 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 12 +-
.../pointers/type-deduce-by-call-chain.ll | 52 ++++++++
4 files changed, 138 insertions(+), 50 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 458af9229ed7b1..5828db6669ff18 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -92,6 +92,9 @@ class SPIRVEmitIntrinsics
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
void processParamTypes(Function *F, IRBuilder<> &B);
+ Type *deduceFunParamType(Function *F, unsigned OpIdx);
+ Type *deduceFunParamType(Function *F, unsigned OpIdx,
+ std::unordered_set<Function *> &FVisited);
public:
static char ID;
@@ -169,6 +172,10 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
static Type *deduceElementTypeHelper(Value *I,
std::unordered_set<Value *> &Visited,
DenseMap<Value *, Type *> &DeducedElTys) {
+ // allow to pass nullptr as an argument
+ if (!I)
+ return nullptr;
+
// maybe already known
auto It = DeducedElTys.find(I);
if (It != DeducedElTys.end())
@@ -182,15 +189,20 @@ static Type *deduceElementTypeHelper(Value *I,
// fallback value in case when we fail to deduce a type
Type *Ty = nullptr;
// look for known basic patterns of type inference
- if (auto *Ref = dyn_cast<AllocaInst>(I))
+ if (auto *Ref = dyn_cast<AllocaInst>(I)) {
Ty = Ref->getAllocatedType();
- else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
+ } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
Ty = Ref->getResultElementType();
- else if (auto *Ref = dyn_cast<GlobalValue>(I))
+ } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
Ty = Ref->getValueType();
- else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
+ } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
DeducedElTys);
+ } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
+ if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
+ isPointerTy(Src) && isPointerTy(Dest))
+ Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
+ }
// remember the found relationship
if (Ty)
@@ -795,61 +807,80 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
}
}
-void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
-  DenseMap<unsigned, Argument *> Args;
-  unsigned i = 0;
-  for (Argument &Arg : F->args()) {
-    if (isUntypedPointerTy(Arg.getType()) &&
-        DeducedElTys.find(&Arg) == DeducedElTys.end() &&
-        !HasPointeeTypeAttr(&Arg))
-      Args[i] = &Arg;
-    i++;
-  }
-  if (Args.size() == 0)
-    return;
+Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
+  std::unordered_set<Function *> FVisited;
+  return deduceFunParamType(F, OpIdx, FVisited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceFunParamType(
+    Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
+  // Bail out on a cycle: insert() reports false when F is already on the
+  // current lookup path.
+  if (!FVisited.insert(F).second)
+    return nullptr;
 
-  // Args contains opaque pointers without element type definition
-  B.SetInsertPointPastAllocas(F);
   std::unordered_set<Value *> Visited;
+  SmallVector<std::pair<Function *, unsigned>> Lookup;
+  // Search in F's call sites, skipping uses where F is not the callee.
   for (User *U : F->users()) {
     CallInst *CI = dyn_cast<CallInst>(U);
-    if (!CI)
+    if (!CI || CI->getCalledFunction() != F || OpIdx >= CI->arg_size())
       continue;
-    for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
-         OpIdx++) {
-      auto It = Args.find(OpIdx);
-      Argument *Arg = It == Args.end() ? nullptr : It->second;
-      if (!Arg)
-        continue;
-      Value *OpArg = CI->getArgOperand(OpIdx);
-      if (!isPointerTy(OpArg->getType()))
+    Value *OpArg = CI->getArgOperand(OpIdx);
+    if (!isPointerTy(OpArg->getType()))
+      continue;
+    // Maybe we already know the operand's element type.
+    if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
+      return It->second;
+    // Otherwise, try to deduce it from the actual argument's other users.
+    for (User *OpU : OpArg->users()) {
+      Instruction *Inst = dyn_cast<Instruction>(OpU);
+      if (!Inst || Inst == CI)
         continue;
-      // maybe we already know the operand's element type
-      auto DeducedIt = DeducedElTys.find(OpArg);
-      Type *ElemTy =
-          DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
-      if (!ElemTy) {
-        for (User *OpU : OpArg->users()) {
-          if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
-            Visited.clear();
-            ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
-            if (ElemTy)
-              break;
-          }
-        }
+      Visited.clear();
+      if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
+        return Ty;
+    }
+    // Check whether the actual argument is a formal parameter of the caller.
+    if (!CI->getParent() || !CI->getParent()->getParent())
+      continue;
+    Function *OuterF = CI->getParent()->getParent();
+    if (FVisited.find(OuterF) != FVisited.end())
+      continue;
+    for (unsigned i = 0; i < OuterF->arg_size(); ++i) {
+      if (OuterF->getArg(i) == OpArg) {
+        Lookup.push_back(std::make_pair(OuterF, i));
+        break;
       }
-      if (ElemTy) {
-        unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
+    }
+  }
+
+  // Fall back to the callers' own call sites for forwarded parameters.
+  for (auto &Pair : Lookup) {
+    if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
+      return Ty;
+  }
+
+  return nullptr;
+}
+
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+  B.SetInsertPointPastAllocas(F);
+  // Deduce element types of untyped pointer params that still lack one.
+  for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
+    Argument *Arg = F->getArg(OpIdx);
+    if (isUntypedPointerTy(Arg->getType()) &&
+        DeducedElTys.find(Arg) == DeducedElTys.end() &&
+        !HasPointeeTypeAttr(Arg)) {
+      if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
         CallInst *AssignPtrTyCI = buildIntrWithMD(
             Intrinsic::spv_assign_ptr_type, {Arg->getType()},
-            Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
+            Constant::getNullValue(ElemTy), Arg,
+            {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
         DeducedElTys[AssignPtrTyCI] = ElemTy;
         DeducedElTys[Arg] = ElemTy;
-        Args.erase(It);
       }
     }
-    if (Args.size() == 0)
-      break;
   }
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..ee52163a5d127f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -479,6 +479,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
GVar = M->getGlobalVariable(Name);
if (GVar == nullptr) {
const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
+ // Module takes ownership of the global var.
GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
GlobalValue::ExternalLinkage, nullptr,
Twine(Name));
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5bb8f6084f9671..39228e2196b3af 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -499,6 +499,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
Register GV = I.getOperand(1).getReg();
MachineRegisterInfo::def_instr_iterator II = MRI->def_instr_begin(GV);
+ (void)II;
assert(((*II).getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
(*II).getOpcode() == TargetOpcode::COPY ||
(*II).getOpcode() == SPIRV::OpVariable) &&
@@ -771,10 +772,13 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
// TODO: check if we have such GV, add init, use buildGlobalVariable.
- Type *LLVMArrTy = ArrayType::get(
- IntegerType::get(GR.CurMF->getFunction().getContext(), 8), Num);
- GlobalVariable *GV =
- new GlobalVariable(LLVMArrTy, true, GlobalValue::InternalLinkage);
+ Function &CurFunction = GR.CurMF->getFunction();
+ Type *LLVMArrTy =
+ ArrayType::get(IntegerType::get(CurFunction.getContext(), 8), Num);
+ // Module takes ownership of the global var.
+ GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
+ true, GlobalValue::InternalLinkage,
+ Constant::getNullValue(LLVMArrTy));
Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
GR.add(GV, GR.CurMF, VarReg);
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
new file mode 100644
index 00000000000000..703f1e22a0321a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
@@ -0,0 +1,52 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; Expect the pointer parameters of @foo and @foo_stub to be Generic i32 pointers.
+; CHECK-SPIRV-DAG: OpName %[[ArgCum:.*]] "_arg_cum"
+; CHECK-SPIRV-DAG: OpName %[[FunTest:.*]] "test"
+; CHECK-SPIRV-DAG: OpName %[[Addr:.*]] "addr"
+; CHECK-SPIRV-DAG: OpName %[[StubObj:.*]] "stub_object"
+; CHECK-SPIRV-DAG: OpName %[[MemOrder:.*]] "mem_order"
+; CHECK-SPIRV-DAG: OpName %[[FooStub:.*]] "foo_stub"
+; CHECK-SPIRV-DAG: OpName %[[FooObj:.*]] "foo_object"
+; CHECK-SPIRV-DAG: OpName %[[FooMemOrder:.*]] "mem_order"
+; CHECK-SPIRV-DAG: OpName %[[FooFunc:.*]] "foo"
+; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer CrossWorkgroup %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyFunPtrLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrLong]]
+; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyFunGenPtrLongLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrLong]] %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[Const3:.*]] = OpConstant %[[TyLong]] 3
+; CHECK-SPIRV: %[[FunTest]] = OpFunction %[[TyVoid]] None %[[TyFunPtrLong]]
+; CHECK-SPIRV: %[[ArgCum]] = OpFunctionParameter %[[TyPtrLong]]
+; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[Addr]] %[[Const3]]
+; CHECK-SPIRV: %[[FooStub]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
+; CHECK-SPIRV: %[[StubObj]] = OpFunctionParameter %[[TyGenPtrLong]]
+; CHECK-SPIRV: %[[MemOrder]] = OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV: %[[FooFunc]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
+; CHECK-SPIRV: %[[FooObj]] = OpFunctionParameter %[[TyGenPtrLong]]
+; CHECK-SPIRV: %[[FooMemOrder]] = OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooStub]] %[[FooObj]] %[[FooMemOrder]]
+; Only @test's GEP reveals i32; @foo and @foo_stub must inherit it via the calls.
+define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_cum) {
+entry:
+ %lptr = getelementptr inbounds i32, ptr addrspace(1) %_arg_cum, i64 1
+ %addr = addrspacecast ptr addrspace(1) %lptr to ptr addrspace(4)
+ %object = bitcast ptr addrspace(4) %addr to ptr addrspace(4)
+ call spir_func void @foo(ptr addrspace(4) %object, i32 3)
+ ret void
+}
+
+define void @foo_stub(ptr addrspace(4) noundef %stub_object, i32 noundef %mem_order) {
+entry:
+ %object.addr = alloca ptr addrspace(4)
+ %object.addr.ascast = addrspacecast ptr %object.addr to ptr addrspace(4)
+ store ptr addrspace(4) %stub_object, ptr addrspace(4) %object.addr.ascast
+ ret void
+}
+
+define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) {
+ tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order)
+ ret void
+}
+
More information about the llvm-commits
mailing list