[llvm] [SPIR-V] Improve type inference for a known instruction's builtin: OpGroupAsyncCopy (PR #96895)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 2 04:57:06 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-backend-spir-v
Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)
This PR improves type inference for a known instruction's builtin, OpGroupAsyncCopy:
* deduces the type of one source/destination pointer when the type of the other pointer argument can be deduced, and
* validates the src and dest types and tries to unfold a parameter if it's a structure wrapper around a scalar/vector type.
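
To illustrate the second point, here is a minimal sketch (not part of the patch; names are illustrative) of the "structure wrapper" shape the validation targets: a single-member struct such as `{ <4 x i8> }`, whose pointer operand the backend now unfolds by bitcasting it to a pointer to the member type (see `validateGroupAsyncCopyPtr` and the `%Vec4` test case in the diff below):

```cpp
// Sketch only: builds the single-member "structure wrapper" type that
// validateGroupAsyncCopyPtr peels off before an OpGroupAsyncCopy.
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  // <4 x i8>: a scalar/vector element type OpGroupAsyncCopy accepts.
  auto *V4I8 = FixedVectorType::get(Type::getInt8Ty(Ctx), 4);
  // { <4 x i8> }: the wrapper; a pointer to it gets bitcast to a
  // pointer to <4 x i8> so the resulting SPIR-V stays valid.
  StructType *Wrapper = StructType::create(Ctx, {V4I8}, "Vec4");
  return Wrapper->getNumElements() == 1 ? 0 : 1;
}
```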
---
Full diff: https://github.com/llvm/llvm-project/pull/96895.diff
5 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+73-2)
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+6)
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+27-1)
- (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+37)
- (modified) llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll (+51-3)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 0b93a4d85eedf..5e5fa26a0f5f7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -2377,9 +2377,80 @@ static bool generateLoadStoreInst(const SPIRV::IncomingCall *Call,
return true;
}
-/// Lowers a builtin funtion call using the provided \p DemangledCall skeleton
-/// and external instruction \p Set.
namespace SPIRV {
+// Try to find builtin function attributes by a demangled function name and
+// return a tuple <builtin group, opcode, ext instruction number>, or the
+// special tuple value <-1, 0, 0> if the builtin function is not found.
+// Not all builtin functions are supported, only those with a ready-to-use
+// opcode or instruction number defined in TableGen.
+// TODO: consider a major rework of mapping demangled calls into builtin
+// functions to unify search and decrease the number of individual cases.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+ SPIRV::InstructionSet::InstructionSet Set) {
+ Register Reg;
+ SmallVector<Register> Args;
+ std::unique_ptr<const IncomingCall> Call =
+ lookupBuiltin(DemangledCall, Set, Reg, nullptr, Args);
+ if (!Call)
+ return std::make_tuple(-1, 0, 0);
+
+ switch (Call->Builtin->Group) {
+ case SPIRV::Relational:
+ case SPIRV::Atomic:
+ case SPIRV::Barrier:
+ case SPIRV::CastToPtr:
+ case SPIRV::ImageMiscQuery:
+ case SPIRV::SpecConstant:
+ case SPIRV::Enqueue:
+ case SPIRV::AsyncCopy:
+ case SPIRV::LoadStore:
+ case SPIRV::CoopMatr:
+ if (const auto *R =
+ SPIRV::lookupNativeBuiltin(Call->Builtin->Name, Call->Builtin->Set))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::Extended:
+ if (const auto *R = SPIRV::lookupExtendedBuiltin(Call->Builtin->Name,
+ Call->Builtin->Set))
+ return std::make_tuple(Call->Builtin->Group, 0, R->Number);
+ break;
+ case SPIRV::VectorLoadStore:
+ if (const auto *R = SPIRV::lookupVectorLoadStoreBuiltin(Call->Builtin->Name,
+ Call->Builtin->Set))
+ return std::make_tuple(SPIRV::Extended, 0, R->Number);
+ break;
+ case SPIRV::Group:
+ if (const auto *R = SPIRV::lookupGroupBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::AtomicFloating:
+ if (const auto *R = SPIRV::lookupAtomicFloatingBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::IntelSubgroups:
+ if (const auto *R = SPIRV::lookupIntelSubgroupsBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::GroupUniform:
+ if (const auto *R = SPIRV::lookupGroupUniformBuiltin(Call->Builtin->Name))
+ return std::make_tuple(Call->Builtin->Group, R->Opcode, 0);
+ break;
+ case SPIRV::WriteImage:
+ return std::make_tuple(Call->Builtin->Group, SPIRV::OpImageWrite, 0);
+ case SPIRV::Select:
+ return std::make_tuple(Call->Builtin->Group, TargetOpcode::G_SELECT, 0);
+ case SPIRV::Construct:
+ return std::make_tuple(Call->Builtin->Group, SPIRV::OpCompositeConstruct,
+ 0);
+ case SPIRV::KernelClock:
+ return std::make_tuple(Call->Builtin->Group, SPIRV::OpReadClockKHR, 0);
+ default:
+ return std::make_tuple(-1, 0, 0);
+ }
+ return std::make_tuple(-1, 0, 0);
+}
+
std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
SPIRV::InstructionSet::InstructionSet Set,
MachineIRBuilder &MIRBuilder,
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 649f5bfd1d7c2..482b2ed853f7d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -38,6 +38,12 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
const SmallVectorImpl<Register> &Args,
SPIRVGlobalRegistry *GR);
+/// Helper external function for finding builtin function attributes
+/// by a demangled function name. Defined in SPIRVBuiltins.cpp.
+std::tuple<int, unsigned, unsigned>
+mapBuiltinToOpcode(const StringRef DemangledCall,
+ SPIRV::InstructionSet::InstructionSet Set);
+
/// Parses the provided \p ArgIdx argument base type in the \p DemangledCall
/// skeleton. A base type is either a basic type (e.g. i32 for int), pointer
/// element type (e.g. i8 for char*), or builtin type (TargetExtType).
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index dd5884096b85d..04e29b24d8fd6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,6 +67,7 @@ class SPIRVEmitIntrinsics
DenseMap<Instruction *, Constant *> AggrConsts;
DenseMap<Instruction *, Type *> AggrConstTypes;
DenseSet<Instruction *> AggrStores;
+ SPIRV::InstructionSet::InstructionSet InstrSet;
// deduce element type of untyped pointers
Type *deduceElementType(Value *I, bool UnknownElemTypeI8);
@@ -384,9 +385,10 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
auto AsArgIt = ResTypeByArg.find(DemangledName);
- if (AsArgIt != ResTypeByArg.end())
+ if (AsArgIt != ResTypeByArg.end()) {
Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
Visited);
+ }
}
}
@@ -544,6 +546,28 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I) {
KnownElemTy = ElemTy1;
Ops.push_back(std::make_pair(Op0, 0));
}
+ } else if (auto *CI = dyn_cast<CallInst>(I)) {
+ if (Function *CalledF = CI->getCalledFunction()) {
+ std::string DemangledName =
+ getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+ if (DemangledName.length() > 0 &&
+ !StringRef(DemangledName).starts_with("llvm.")) {
+ auto [Grp, Opcode, ExtNo] =
+ SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
+ if (Opcode == SPIRV::OpGroupAsyncCopy) {
+ for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
+ ++i) {
+ Value *Op = CI->getArgOperand(i);
+ if (!isPointerTy(Op->getType()))
+ continue;
+ ++PtrCnt;
+ if (Type *ElemTy = GR->findDeducedElementType(Op))
+ KnownElemTy = ElemTy; // src will rewrite dest if both are defined
+ Ops.push_back(std::make_pair(Op, i));
+ }
+ }
+ }
+ }
}
// There is not enough info to deduce types, or all is valid.
@@ -1385,6 +1409,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
const SPIRVSubtarget &ST = TM->getSubtarget<SPIRVSubtarget>(Func);
GR = ST.getSPIRVGlobalRegistry();
+ InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
+ : SPIRV::InstructionSet::GLSL_std_450;
F = &Func;
IRBuilder<> B(Func.getContext());
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 4383d1c5c0e25..2344ec529e16d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -203,6 +203,39 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}
+static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR, MachineInstr &I,
+ unsigned OpIdx) {
+ MachineFunction *MF = I.getParent()->getParent();
+ Register OpReg = I.getOperand(OpIdx).getReg();
+ Register OpTypeReg = getTypeReg(MRI, OpReg);
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+ if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+ return;
+ SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+ if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct ||
+ ElemType->getNumOperands() != 2)
+ return;
+ // It's a structure-wrapper around another type with a single member field.
+ SPIRVType *MemberType =
+ GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg());
+ if (!MemberType)
+ return;
+ unsigned MemberTypeOp = MemberType->getOpcode();
+ if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt &&
+ MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool)
+ return;
+ // It's a structure-wrapper around a valid type. Insert a bitcast before the
+ // instruction to keep SPIR-V code valid.
+ SPIRV::StorageClass::StorageClass SC =
+ static_cast<SPIRV::StorageClass::StorageClass>(
+ OpType->getOperand(1).getImm());
+ MachineIRBuilder MIB(I);
+ SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC);
+ doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
// Insert a bitcast before the function call instruction to keep SPIR-V code
// valid when there is a type mismatch between actual and expected types of an
// argument:
@@ -380,6 +413,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRV::OpTypeBool))
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
break;
+ case SPIRV::OpGroupAsyncCopy:
+ validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
+ validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+ break;
case SPIRV::OpGroupWaitEvents:
// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
validateGroupWaitEventsPtr(STI, MRI, GR, MI);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
index fe0d96f2773ec..df11565ca8180 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll
@@ -8,7 +8,18 @@
; CHECK-DAG: %[[#TyStruct:]] = OpTypeStruct %[[#TyEvent]]
; CHECK-DAG: %[[#ConstEvent:]] = OpConstantNull %[[#TyEvent]]
; CHECK-DAG: %[[#TyEventPtr:]] = OpTypePointer Function %[[#TyEvent]]
+; CHECK-DAG: %[[#TyEventPtrGen:]] = OpTypePointer Generic %[[#TyEvent]]
; CHECK-DAG: %[[#TyStructPtr:]] = OpTypePointer Function %[[#TyStruct]]
+; CHECK-DAG: %[[#TyChar:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#TyV4:]] = OpTypeVector %[[#TyChar]] 4
+; CHECK-DAG: %[[#TyStructV4:]] = OpTypeStruct %[[#TyV4]]
+; CHECK-DAG: %[[#TyPtrSV4_W:]] = OpTypePointer Workgroup %[[#TyStructV4]]
+; CHECK-DAG: %[[#TyPtrSV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyStructV4]]
+; CHECK-DAG: %[[#TyPtrV4_W:]] = OpTypePointer Workgroup %[[#TyV4]]
+; CHECK-DAG: %[[#TyPtrV4_CW:]] = OpTypePointer CrossWorkgroup %[[#TyV4]]
+
+; Check correct translation of __spirv_GroupAsyncCopy and target("spirv.Event") zeroinitializer
+
; CHECK: OpFunction
; CHECK: OpFunctionParameter
; CHECK: %[[#Src:]] = OpFunctionParameter
@@ -17,12 +28,13 @@
; CHECK: %[[#Dest:]] = OpInBoundsPtrAccessChain
; CHECK: %[[#CopyRes:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#Dest]] %[[#Src]] %[[#]] %[[#]] %[[#ConstEvent]]
; CHECK: OpStore %[[#EventVar]] %[[#CopyRes]]
+; CHECK: OpFunctionEnd
-%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+%StructEvent = type { target("spirv.Event") }
-define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) noundef %_arg_local_acc) {
+define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
entry:
- %var = alloca %"class.sycl::_V1::device_event"
+ %var = alloca %StructEvent
%dev_event.i.sroa.0 = alloca target("spirv.Event")
%add.ptr.i26 = getelementptr inbounds i32, ptr addrspace(1) %_arg_out_ptr, i64 0
%call3.i = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %add.ptr.i26, ptr addrspace(3) %_arg_local_acc, i64 16, i64 10, target("spirv.Event") zeroinitializer)
@@ -31,3 +43,39 @@ entry:
}
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
+
+; Check correct type inference when calling __spirv_GroupAsyncCopy:
+; we expect the backend to deduce the type of %_arg_Local, given that
+; the type of %_arg can be deduced and that %_arg_Local and %_arg are
+; the destination/source arguments of OpGroupAsyncCopy.
+
+; CHECK: OpFunction
+; CHECK: %[[#BarArg1:]] = OpFunctionParameter %[[#TyPtrSV4_W]]
+; CHECK: %[[#BarArg2:]] = OpFunctionParameter %[[#TyPtrSV4_CW]]
+; CHECK: %[[#EventVarBar:]] = OpVariable %[[#TyStructPtr]] Function
+; CHECK: %[[#SrcBar:]] = OpInBoundsPtrAccessChain %[[#TyPtrSV4_CW]] %[[#BarArg2]] %[[#]]
+; CHECK-DAG: %[[#BarArg1Casted:]] = OpBitcast %[[#TyPtrV4_W]] %[[#BarArg1]]
+; CHECK-DAG: %[[#SrcBarCasted:]] = OpBitcast %[[#TyPtrV4_CW]] %[[#SrcBar]]
+; CHECK: %[[#ResBar:]] = OpGroupAsyncCopy %[[#TyEvent]] %[[#]] %[[#BarArg1Casted]] %[[#SrcBarCasted]] %[[#]] %[[#]] %[[#ConstEvent]]
+; CHECK: %[[#EventVarBarCasted:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
+; CHECK: OpStore %[[#EventVarBarCasted]] %[[#ResBar]]
+; CHECK: %[[#EventVarBarCasted2:]] = OpBitcast %[[#TyEventPtr]] %[[#EventVarBar]]
+; CHECK: %[[#EventVarBarGen:]] = OpPtrCastToGeneric %[[#TyEventPtrGen]] %[[#EventVarBarCasted2]]
+; CHECK: OpGroupWaitEvents %[[#]] %[[#]] %[[#EventVarBarGen]]
+; CHECK: OpFunctionEnd
+
+%Vec4 = type { <4 x i8> }
+
+define spir_kernel void @bar(ptr addrspace(3) %_arg_Local, ptr addrspace(1) readonly %_arg) {
+entry:
+ %E1 = alloca %StructEvent
+ %srcptr = getelementptr inbounds %Vec4, ptr addrspace(1) %_arg, i64 0
+ %r1 = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_Local, ptr addrspace(1) %srcptr, i64 16, i64 10, target("spirv.Event") zeroinitializer)
+ store target("spirv.Event") %r1, ptr %E1
+ %E.ascast.i = addrspacecast ptr %E1 to ptr addrspace(4)
+ call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %E.ascast.i)
+ ret void
+}
+
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))
``````````
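
As a usage note, a minimal sketch of how the new helper's result is meant to be consumed, mirroring the `SPIRVEmitIntrinsics.cpp` hunk above (the free-standing wrapper function is hypothetical, not part of the patch):

```cpp
// Sketch only: consume mapBuiltinToOpcode's tuple the way the patch does.
// <-1, 0, 0> means "builtin not found"; native builtins set the opcode
// field, extended builtins set the ext-instruction number instead.
#include "SPIRVBuiltins.h"

using namespace llvm;

static bool isGroupAsyncCopy(StringRef DemangledName,
                             SPIRV::InstructionSet::InstructionSet Set) {
  auto [Grp, Opcode, ExtNo] = SPIRV::mapBuiltinToOpcode(DemangledName, Set);
  (void)ExtNo; // only meaningful for extended-instruction builtins
  return Grp >= 0 && Opcode == SPIRV::OpGroupAsyncCopy;
}
```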
https://github.com/llvm/llvm-project/pull/96895