[llvm] [SPIR-V] Improve type inference for a known instruction's builtin: OpGroupAsyncCopy (PR #96895)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 27 11:13:31 PDT 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/96895
>From d54f3a10269c82c43ac518502e4e03f3744ab91b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 03:43:34 -0700
Subject: [PATCH 1/3] Improve type inference for a known instruction's builtin
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 71 ++++++++++++++++++-
llvm/lib/Target/SPIRV/SPIRVBuiltins.h | 6 ++
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 25 ++++++-
.../SPIRV/transcoding/spirv-event-null.ll | 23 +++++-
4 files changed, 119 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 0b93a4d85eedf..bdada81cef89f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2377,9 +2377,76 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
return true;
}
-/// Lowers a builtin funtion call using the provided \p DemangledCall skeleton
-/// and external instruction \p Set.
namespace SPIRV {
+// Try to find a builtin function's attributes by its demangled name and
+// return a tuple <builtin group, op code, ext instruction number>, or the
+// special tuple value <-1, 0, 0> if the builtin function is not found.
+// Not all builtin functions are supported, only those with a ready-to-use
+// op code or instruction number defined in TableGen.
+// TODO: consider a major rework of the mapping from demangled calls to
+// builtin functions to unify the search and decrease special cases.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+ SPIRV::InstructionSet::InstructionSet Set) {
+ const SPIRV::DemangledBuiltin *Builtin =
+ SPIRV::lookupBuiltin(DemangledCall, Set);
+ if (!Builtin)
+ return std::make_tuple(-1, 0, 0);
+
+ switch (Builtin->Group) {
+ case SPIRV::Relational:
+ case SPIRV::Atomic:
+ case SPIRV::Barrier:
+ case SPIRV::CastToPtr:
+ case SPIRV::ImageMiscQuery:
+ case SPIRV::SpecConstant:
+ case SPIRV::Enqueue:
+ case SPIRV::AsyncCopy:
+ case SPIRV::LoadStore:
+ case SPIRV::CoopMatr:
+ if (const auto *R = SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set))
+ return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::Extended:
+ if (const auto *R =
+ SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set))
+ return std::make_tuple(Builtin->Group, 0, R->Number);
+ break;
+ case SPIRV::VectorLoadStore:
+ if (const auto *R =
+ SPIRV::lookupVectorLoadStoreBuiltin(Builtin->Name, Builtin->Set))
+ return std::make_tuple(SPIRV::Extended, 0, R->Number);
+ break;
+ case SPIRV::Group:
+ if (const auto *R = SPIRV::lookupGroupBuiltin(Builtin->Name))
+ return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::AtomicFloating:
+ if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name))
+ return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::IntelSubgroups:
+ if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name))
+ return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::GroupUniform:
+ if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Builtin->Name))
+ return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::WriteImage:
+ return std::make_tuple(Builtin->Group, SPIRV::OpImageWrite, 0);
+ case SPIRV::Select:
+ return std::make_tuple(Builtin->Group, TargetOpcode::G_SELECT, 0);
+ case SPIRV::Construct:
+ return std::make_tuple(Builtin->Group, SPIRV::OpCompositeConstruct, 0);
+ case SPIRV::KernelClock:
+ return std::make_tuple(Builtin->Group, SPIRV::OpReadClockKHR, 0);
+ default:
+ return std::make_tuple(-1, 0, 0);
+ }
+ return std::make_tuple(-1, 0, 0);
+}
+
std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
SPIRV::InstructionSet::InstructionSet Set,
MachineIRBuilder &MIRBuilder,
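For reference, a minimal standalone sketch of how a caller can consume the <group, opcode, ext-instruction-number> tuple via structured bindings. The lookup function, opcode constant, and group value below are illustrative stand-ins only, not the actual SPIR-V backend API:

// Illustrative stand-in for the SPIRV::mapBuiltinToOpcode usage pattern.
#include <cstdio>
#include <string_view>
#include <tuple>

// Hypothetical opcode constant, used only for this sketch.
constexpr unsigned OpGroupAsyncCopyOpc = 259;

// Stand-in lookup: returns <group, opcode, ext-number>, or <-1, 0, 0> on failure.
std::tuple<int, unsigned, unsigned> mapBuiltinSketch(std::string_view Demangled) {
  if (Demangled == "__spirv_GroupAsyncCopy")
    return std::make_tuple(/*Group=*/8, OpGroupAsyncCopyOpc, /*ExtNo=*/0);
  return std::make_tuple(-1, 0, 0);
}

int main() {
  auto [Grp, Opcode, ExtNo] = mapBuiltinSketch("__spirv_GroupAsyncCopy");
  if (Grp < 0)
    std::puts("not a known builtin");
  else if (Opcode == OpGroupAsyncCopyOpc)
    std::puts("OpGroupAsyncCopy: deduce pointee types from the two pointer operands");
  return 0;
}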
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 649f5bfd1d7c2..482b2ed853f7d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -38,6 +38,12 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
const SmallVectorImpl<Register> &Args,
SPIRVGlobalRegistry *GR);
+/// Helper external function for finding a builtin function's attributes
+/// by its demangled name. Defined in SPIRVBuiltins.cpp.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+ SPIRV::InstructionSet::InstructionSet Set);
+
/// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
/// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
/// element type (e.g. i8 for char*), or builtin type (TargetExtType).
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index dd5884096b85d..53f00eb2d1f10 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,6 +67,7 @@ class SPIRVEmitIntrinsics
DenseMap<Instruction *, Constant *> AggrConsts;
DenseMap<Instruction *, Type *> AggrConstTypes;
DenseSet<Instruction *> AggrStores;
+ SPIRV::InstructionSet::InstructionSet InstrSet;
// deduce element type of untyped pointers
Type *deduceElementType(Value *I, bool UnknownElemTypeI8);
@@ -384,9 +385,10 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
auto AsArgIt = ResTypeByArg.find(DemangledName);
- if (AsArgIt != ResTypeByArg.end())
+ if (AsArgIt != ResTypeByArg.end()) {
Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
Visited);
+ }
}
}
@@ -544,6 +546,25 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
KnownElemTy = ElemTy1;
Ops.push_back(std::make_pair(Op0, 0));
}
+ } else if (auto *CI = dyn_cast<CallInst>(I)) {
+ if (Function *CalledF = CI->getCalledFunction()) {
+ std::string DemangledName =
+ getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+ auto [Grp, Opcode, ExtNo] =
+ SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+ if (Opcode == SPIRV::OpGroupAsyncCopy) {
+ for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+ ++i) {
+ Value *Op = CI->getArgOperand(i);
+ if (!isPointerTy(Op->getType()))
+ continue;
+ ++PtrCnt;
+ if (Type *ElemTy = GR->findDeducedElementType(Op))
+ KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+ Ops.push_back(std::make_pair(Op, i));
+ }
+ }
+ }
}
// There is no enough info to deduce types or all is valid.
@@ -1385,6 +1406,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
GR = ST.getSPIRVGlobalRegistry();
+ InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
+ : SPIRV::InstructionSet::GLSL_std_450;
F = &Func;
IRBuilder<> B(Func.getContext());
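The new deduceOperandElementType branch scans the call's arguments, takes the first two pointer operands (the Dst and Src of OpGroupAsyncCopy), and reuses any element type already deduced for one of them for the other. A rough standalone sketch of that scanning pattern, with toy stand-in types instead of the real llvm::Value/llvm::Type classes:

// Simplified sketch of the pointer-operand scan; Value here is a toy stand-in.
#include <cstdio>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct Value {
  bool IsPointer = false;
  std::optional<std::string> DeducedElemTy; // previously deduced pointee type, if any
};

// Collect up to the first two pointer operands and the best-known element type.
std::pair<std::optional<std::string>, std::vector<unsigned>>
scanAsyncCopyOperands(const std::vector<Value> &Args) {
  std::optional<std::string> KnownElemTy;
  std::vector<unsigned> PtrArgIdx;
  for (unsigned i = 0, PtrCnt = 0; i < Args.size() && PtrCnt < 2; ++i) {
    if (!Args[i].IsPointer)
      continue;
    ++PtrCnt;
    if (Args[i].DeducedElemTy)
      KnownElemTy = Args[i].DeducedElemTy; // Src (seen last) overrides Dst
    PtrArgIdx.push_back(i);
  }
  return {KnownElemTy, PtrArgIdx};
}

int main() {
  std::vector<Value> Args = {{false, {}}, {true, {}}, {true, std::string("i32")}};
  auto [ElemTy, Idx] = scanAsyncCopyOperands(Args);
  std::printf("deduced %s for %zu pointer operands\n",
              ElemTy ? ElemTy->c_str() : "<none>", Idx.size());
  return 0;
}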
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index fe0d96f2773ec..82eff6f14469d 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -9,6 +9,7 @@
; CHECK-DAG: %[[#ConstEvent:]] = OpConstantNull %[[#TyEvent]]
; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
+
; CHECK: OpFunction
; CHECK: OpFunctionParameter
; CHECK: %[[#Src:]] = OpFunctionParameter
@@ -18,11 +19,11 @@
; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
-%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+%StructEvent = type { target("spirv.Event") }
-define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) noundef %_arg_local_acc) {
+define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
entry:
- %var = alloca %"class.sycl::_V1::device_event"
+ %var = alloca %StructEvent
%dev_event.i.sroa.0 = alloca target("spirv.Event")
%add.ptr.i26 = getelementptr inbounds i32, ptr addrspace(1) %_arg_out_ptr, i64 0
%call3.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %add.ptr.i26, ptr addrspace(3) %_arg_local_acc, i64 16, i64 10, target("spirv.Event") zeroinitializer)
@@ -31,3 +32,19 @@ entry:
}
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
+
+%Vec4 = type { <4 x i8> }
+
+define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg_In) {
+entry:
+ %E.i = alloca %StructEvent
+ %add.ptr.i42 = getelementptr inbounds %Vec4, ptr addrspace(1) %_arg_In, i64 0
+ %call3.i.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_Local, ptr addrspace(1) %add.ptr.i42, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+ store target("spirv.Event") %call3.i.i, ptr %E.i
+ %E.ascast.i = addrspacecast ptr %E.i to ptr addrspace(4)
+ call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %E.ascast.i)
+ ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))
>From ebf6d801693fd1ca0e7d86f5ff75efd667911e90 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 10:41:46 -0700
Subject: [PATCH 2/3] deduce type of arguments of OpGroupAsyncCopy
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 44 ++++++++++---------
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 26 ++++++-----
2 files changed, 37 insertions(+), 33 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index bdada81cef89f..c99165e786216 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2388,12 +2388,14 @@ namespace SPIRV {
std::tuple<int, unsigned, unsigned>
mapBuiltinToOpcode(const StringRef DemangledCall,
SPIRV::InstructionSet::InstructionSet Set) {
- const SPIRV::DemangledBuiltin *Builtin =
- SPIRV::lookupBuiltin(DemangledCall, Set);
- if (!Builtin)
+ Register Reg;
+ SmallVector<Register> Args;
+ std::unique_ptr<const IncomingCall> Call =
+ lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
+ if (!Call)
return std::make_tuple(-1, 0, 0);
- switch (Builtin->Group) {
+ switch (Call->Builtin->Group) {
case SPIRV::Relational:
case SPIRV::Atomic:
case SPIRV::Barrier:
@@ -2404,43 +2406,43 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
case SPIRV::AsyncCopy:
case SPIRV::LoadStore:
case SPIRV::CoopMatr:
- if (const auto *R = SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set))
- return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ if (const auto *R = SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
break;
case SPIRV::Extended:
if (const auto *R =
- SPIRV::lookupExtendedBuiltin(Builtin->Name, Builtin->Set))
- return std::make_tuple(Builtin->Group, 0, R->Number);
+ SPIRV::lookupExtendedBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+ return std::make_tuple(Call->Builtin->Group, 0, R->Number);
break;
case SPIRV::VectorLoadStore:
if (const auto *R =
- SPIRV::lookupVectorLoadStoreBuiltin(Builtin->Name, Builtin->Set))
+ SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name, Call->Builtin->Set))
return std::make_tuple(SPIRV::Extended, 0, R->Number);
break;
case SPIRV::Group:
- if (const auto *R = SPIRV::lookupGroupBuiltin(Builtin->Name))
- return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
break;
case SPIRV::AtomicFloating:
- if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name))
- return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
break;
case SPIRV::IntelSubgroups:
- if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name))
- return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
break;
case SPIRV::GroupUniform:
- if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Builtin->Name))
- return std::make_tuple(Builtin->Group, R->Opcode, 0);
+ if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
break;
case SPIRV::WriteImage:
- return std::make_tuple(Builtin->Group, SPIRV::OpImageWrite, 0);
+ return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
case SPIRV::Select:
- return std::make_tuple(Builtin->Group, TargetOpcode::G_SELECT, 0);
+ return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
case SPIRV::Construct:
- return std::make_tuple(Builtin->Group, SPIRV::OpCompositeConstruct, 0);
+ return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct, 0);
case SPIRV::KernelClock:
- return std::make_tuple(Builtin->Group, SPIRV::OpReadClockKHR, 0);
+ return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
default:
return std::make_tuple(-1, 0, 0);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 53f00eb2d1f10..5e7891496ce93 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -550,18 +550,20 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
if (Function *CalledF = CI->getCalledFunction()) {
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
- auto [Grp, Opcode, ExtNo] =
- SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
- if (Opcode == SPIRV::OpGroupAsyncCopy) {
- for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
- ++i) {
- Value *Op = CI->getArgOperand(i);
- if (!isPointerTy(Op->getType()))
- continue;
- ++PtrCnt;
- if (Type *ElemTy = GR->findDeducedElementType(Op))
- KnownElemTy = ElemTy; // src will rewrite dest if both are defined
- Ops.push_back(std::make_pair(Op, i));
+ if (!StringRef(DemangledName).starts_with("llvm.")) {
+ auto [Grp, Opcode, ExtNo] =
+ SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+ if (Opcode == SPIRV::OpGroupAsyncCopy) {
+ for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+ ++i) {
+ Value *Op = CI->getArgOperand(i);
+ if (!isPointerTy(Op->getType()))
+ continue;
+ ++PtrCnt;
+ if (Type *ElemTy = GR->findDeducedElementType(Op))
+ KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+ Ops.push_back(std::make_pair(Op, i));
+ }
}
}
}
>From b448cb76dc4a5005f8137a70e288490a76357081 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 27 Jun 2024 11:13:20 -0700
Subject: [PATCH 3/3] fix function name condition
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5e7891496ce93..04e29b24d8fd6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -550,7 +550,8 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
if (Function *CalledF = CI->getCalledFunction()) {
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
- if (!StringRef(DemangledName).starts_with("llvm.")) {
+ if (DemangledName.length() > 0 &&
+ !StringRef(DemangledName).starts_with("llvm.")) {
auto [Grp, Opcode, ExtNo] =
SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
if (Opcode == SPIRV::OpGroupAsyncCopy) {