[llvm] [SPIRV] Add type inference of function parameters by call instances (PR #85077)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 13 06:46:07 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
<details>
<summary>Changes</summary>
This PR adds inference of function parameter types from the arguments passed at their call sites. Two test cases that demonstrate the problem are added.
---
Full diff: https://github.com/llvm/llvm-project/pull/85077.diff
5 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+1-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+62)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+5)
- (added) llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll (+28)
- (added) llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll (+28)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index f1fbe2ba1bc416..77319f58ff4d97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -209,7 +209,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
// type.
Argument *Arg = F.getArg(ArgIdx);
- if (Arg->hasByValAttr() || Arg->hasByRefAttr()) {
+ if (HasPointeeTypeAttr(Arg)) {
Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
: Arg->getParamByRefType();
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c5b901235402c1..e9099fac1c1a3c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -91,6 +91,7 @@ class SPIRVEmitIntrinsics
IRBuilder<> &B);
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
+ void processParamTypes(Function *F, IRBuilder<> &B);
public:
static char ID;
@@ -794,6 +795,62 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
}
}
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+  // Untyped pointer params with unknown element type, keyed by formal index.
+  DenseMap<unsigned, Argument *> Args;
+  for (Argument &Arg : F->args()) {
+    if (isUntypedPointerTy(Arg.getType()) &&
+        DeducedElTys.find(&Arg) == DeducedElTys.end() &&
+        !HasPointeeTypeAttr(&Arg))
+      Args[Arg.getArgNo()] = &Arg;
+  }
+  if (Args.empty())
+    return;
+
+  // Deduce the missing element types from the arguments of direct calls to F.
+  B.SetInsertPointPastAllocas(F);
+  std::unordered_set<Value *> Visited;
+  for (User *U : F->users()) {
+    CallInst *CI = dyn_cast<CallInst>(U);
+    if (!CI || CI->getCalledFunction() != F)
+      continue;
+    for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && !Args.empty(); ++OpIdx) {
+      auto It = Args.find(OpIdx);
+      if (It == Args.end())
+        continue;
+      Argument *Arg = It->second;
+      Value *OpArg = CI->getArgOperand(OpIdx);
+      if (!isPointerTy(OpArg->getType()))
+        continue;
+      // Maybe we already know the operand's element type.
+      auto DeducedIt = DeducedElTys.find(OpArg);
+      Type *ElemTy =
+          DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
+      if (!ElemTy) {
+        for (User *OpU : OpArg->users()) {
+          if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
+            Visited.clear();
+            ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
+            if (ElemTy)
+              break;
+          }
+        }
+      }
+      if (ElemTy) {
+        unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
+        CallInst *AssignPtrTyCI = buildIntrWithMD(
+            Intrinsic::spv_assign_ptr_type, {Arg->getType()},
+            Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
+        DeducedElTys[AssignPtrTyCI] = ElemTy;
+        DeducedElTys[Arg] = ElemTy;
+        Args.erase(It);
+      }
+    }
+    if (Args.empty())
+      break;
+  }
+}
+
bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
if (Func.isDeclaration())
return false;
@@ -839,6 +896,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
continue;
processInstrAfterVisit(I, B);
}
+
+ // check if function parameter types are set
+ if (!F->isIntrinsic())
+ processParamTypes(F, B);
+
return true;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index d5ed501def9986..eb87349f0941c5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -126,5 +126,10 @@ inline unsigned getPointerAddressSpace(const Type *T) {
: cast<TypedPointerType>(SubT)->getAddressSpace();
}
+// Return true if Arg carries its pointee type via a byval or byref attribute.
+inline bool HasPointeeTypeAttr(const Argument *Arg) {
+  return Arg->hasByValAttr() || Arg->hasByRefAttr();
+}
+
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
new file mode 100644
index 00000000000000..3f8edfef78b03c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args-rev.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: %[[Int:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[IntPtr:.*]] = OpTypePointer CrossWorkgroup %[[Int]]
+; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[IntPtr]]
+; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[IntPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[IntPtr]]
+
+define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
+  call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+  ret void
+}
+
+define spir_func void @foo(ptr addrspace(1) %known_type_ptr) {
+entry:
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
new file mode 100644
index 00000000000000..be8582f9226d5c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-args.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
+; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
+; CHECK-SPIRV-DAG: %[[Int:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[IntPtr:.*]] = OpTypePointer CrossWorkgroup %[[Int]]
+; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[IntPtr]]
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[IntPtr]]
+; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
+; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[IntPtr]]
+; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
+
+define spir_func void @foo(ptr addrspace(1) %known_type_ptr) {
+entry:
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
+entry:
+  %elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
+  call spir_func void @foo(ptr addrspace(1) %unknown_type_ptr)
+  ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/85077
More information about the llvm-commits mailing list