[llvm] [SPIRV] Added support for 2 kernel query builtins (PR #142280)
via llvm-commits
llvm-commits at lists.llvm.org
Sat May 31 10:41:09 PDT 2025
https://github.com/EbinJose2002 created https://github.com/llvm/llvm-project/pull/142280
Added support for 2 kernel query builtins - OpGetKernelNDrangeMaxSubGroupSize and OpGetKernelNDrangeSubGroupCount
>From 1c31d738d9fc2e5c12df7341823fe150a14d142d Mon Sep 17 00:00:00 2001
From: EbinJose2002 <ebin.jose at multicorewareinc.com>
Date: Wed, 16 Apr 2025 10:39:42 +0530
Subject: [PATCH] Added support for 2 kernel query builtins
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 116 +++++++++++++++---
llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 3 +
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 4 +
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 5 +
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 10 ++
llvm/lib/Target/SPIRV/SPIRVUtils.h | 2 +-
.../CodeGen/SPIRV/transcoding/kernel_query.ll | 95 ++++++++++++++
7 files changed, 215 insertions(+), 20 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f73a39c6ee9da..99910aeb20c43 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -372,18 +372,15 @@ static MachineInstr *getBlockStructInstr(Register ParamReg,
// We expect the following sequence of instructions:
// %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
// or = G_GLOBAL_VALUE @block_literal_global
- // %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
- // %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
+ // %1:_(p4) = G_ADDRSPACE_CAST %0:_(pN)
MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
MI->getOperand(1).isReg());
- Register BitcastReg = MI->getOperand(1).getReg();
- MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
- assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
- BitcastMI->getOperand(2).isReg());
- Register ValueReg = BitcastMI->getOperand(2).getReg();
- MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
- return ValueMI;
+ Register PtrReg = MI->getOperand(1).getReg();
+ MachineInstr *PtrMI = MRI->getUniqueVRegDef(PtrReg);
+ assert(PtrMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
+ isSpvIntrinsic(*PtrMI, Intrinsic::spv_alloca));
+ return PtrMI;
}
// Return an integer constant corresponding to the given register and
@@ -2490,25 +2487,103 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
// Invoke: Pointer to invoke function.
- MIB.addGlobalAddress(BlockMI->getOperand(1).getGlobal());
+ Register BlockFReg = BlockMI->getOperand(0).getReg();
+ MIB.addUse(BlockFReg);
+ MRI->setRegClass(BlockFReg, &SPIRV::pIDRegClass);
Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
// Param: Pointer to block literal.
MIB.addUse(BlockLiteralReg);
-
- Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
- // TODO: these numbers should be obtained from block literal structure.
- // Param Size: Size of block literal structure.
- MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
- // Param Aligment: Aligment of block literal structure.
- MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
- MIRBuilder, GR));
-
+ BlockMI = MRI->getUniqueVRegDef(BlockLiteralReg);
+ Register BlockMIReg =
+ stripAddrspaceCast(BlockMI->getOperand(1).getReg(), *MRI);
+ BlockMI = MRI->getUniqueVRegDef(BlockMIReg);
+
+ if (BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) {
+ auto BlockLiteralMI = MRI->getUniqueVRegDef(BlockMIReg);
+ const GlobalValue *GV = BlockLiteralMI->getOperand(1).getGlobal();
+ const GlobalVariable *BlockGV = dyn_cast<GlobalVariable>(GV);
+ assert(BlockGV && BlockGV->hasInitializer());
+
+ const Constant *Init = BlockGV->getInitializer();
+ const ConstantStruct *CS = dyn_cast<ConstantStruct>(Init);
+ assert(CS && "Expected constant struct for block literal");
+
+ // Extract fields
+ const ConstantInt *SizeConst = dyn_cast<ConstantInt>(CS->getOperand(0));
+ const ConstantInt *AlignConst = dyn_cast<ConstantInt>(CS->getOperand(1));
+ uint64_t BlockSize = SizeConst->getZExtValue();
+ uint64_t BlockAlign = AlignConst->getZExtValue();
+ MIB.addUse(buildConstantIntReg32(BlockSize, MIRBuilder, GR));
+ MIB.addUse(buildConstantIntReg32(BlockAlign, MIRBuilder, GR));
+ }
+ else {
+ Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
+ // TODO: these numbers should be obtained from block literal structure.
+ // Param Size: Size of block literal structure.
+ MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
+ // Param Alignment: Alignment of block literal structure.
+ MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
+ MIRBuilder, GR));
+ }
for (unsigned i = 0; i < LocalSizes.size(); i++)
MIB.addUse(LocalSizes[i]);
return true;
}
+/// Emits the kernel query instruction \p Opcode
+/// (OpGetKernelNDrangeSubGroupCount or OpGetKernelNDrangeMaxSubGroupSize)
+/// for \p Call.
+///
+/// The call is expected to carry three arguments: the ND-range, the invoke
+/// function and the block literal. The emitted instruction has the operands
+///   result, result type, ND-range, Invoke, Param, Param Size, Param Align.
+///
+/// Param Size and Param Align are taken from the first two fields of the
+/// block literal's constant initializer when the literal traces back to a
+/// global variable whose initializer has that shape. In every other case
+/// they are derived from the block struct type through the data layout.
+///
+/// \param Call       Incoming builtin call being lowered.
+/// \param Opcode     SPIR-V opcode of the query instruction to build.
+/// \param MIRBuilder Builder positioned where the instruction is inserted.
+/// \param GR         Registry used for type ids and integer constants.
+/// \returns true; the instruction is always emitted.
+static bool buildNDRangeSubGroup(const SPIRV::IncomingCall *Call,
+                                 unsigned Opcode, MachineIRBuilder &MIRBuilder,
+                                 SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  const DataLayout &DL = MIRBuilder.getDataLayout();
+
+  auto MIB = MIRBuilder.buildInstr(Opcode)
+                 .addDef(Call->ReturnRegister)
+                 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+                 .addUse(Call->Arguments[0]);
+  const unsigned BlockFIdx = 1;
+  MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
+  assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
+  // Invoke: Pointer to invoke function.
+  Register BlockFReg = BlockMI->getOperand(0).getReg();
+  MIB.addUse(BlockFReg);
+  MRI->setRegClass(BlockFReg, &SPIRV::pIDRegClass);
+
+  // Param: Pointer to block literal.
+  Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
+  MIB.addUse(BlockLiteralReg);
+
+  // Look through any address-space casts to the value behind the literal.
+  // Stripping from the literal register itself avoids reading operand 1 of
+  // an instruction that is not known to be a cast.
+  Register LiteralSrcReg = stripAddrspaceCast(BlockLiteralReg, *MRI);
+  const MachineInstr *LiteralSrcMI = MRI->getUniqueVRegDef(LiteralSrcReg);
+
+  // Size and align are given explicitly by a global block literal, but only
+  // trust them when the initializer really is { iN size, iN align, ... }.
+  const ConstantInt *SizeConst = nullptr;
+  const ConstantInt *AlignConst = nullptr;
+  if (LiteralSrcMI &&
+      LiteralSrcMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) {
+    const auto *BlockGV =
+        dyn_cast<GlobalVariable>(LiteralSrcMI->getOperand(1).getGlobal());
+    const ConstantStruct *CS = nullptr;
+    if (BlockGV && BlockGV->hasInitializer())
+      CS = dyn_cast<ConstantStruct>(BlockGV->getInitializer());
+    if (CS && CS->getNumOperands() >= 2) {
+      SizeConst = dyn_cast<ConstantInt>(CS->getOperand(0));
+      AlignConst = dyn_cast<ConstantInt>(CS->getOperand(1));
+    }
+  }
+
+  if (SizeConst && AlignConst) {
+    MIB.addUse(
+        buildConstantIntReg32(SizeConst->getZExtValue(), MIRBuilder, GR));
+    MIB.addUse(
+        buildConstantIntReg32(AlignConst->getZExtValue(), MIRBuilder, GR));
+  } else {
+    // Fallback: no usable explicit fields, so use the block struct type.
+    Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
+    MIB.addUse(
+        buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
+    MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
+                                     MIRBuilder, GR));
+  }
+  return true;
+}
+
static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
@@ -2544,6 +2619,9 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
return buildNDRange(Call, MIRBuilder, GR);
case SPIRV::OpEnqueueKernel:
return buildEnqueueKernel(Call, MIRBuilder, GR);
+ case SPIRV::OpGetKernelNDrangeSubGroupCount:
+ case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
+ return buildNDRangeSubGroup(Call, Opcode, MIRBuilder, GR);
default:
return false;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 6842e5ff067cf..95ca3295b1733 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -671,6 +671,9 @@ defm : DemangledNativeBuiltin<"__spirv_GetDefaultQueue", OpenCL_std, Enqueue, 0,
defm : DemangledNativeBuiltin<"ndrange_1D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
defm : DemangledNativeBuiltin<"ndrange_2D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
defm : DemangledNativeBuiltin<"ndrange_3D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
+// Kernel query builtins. Both are called with three arguments: the ND-range, the invoke function and the block literal.
+defm : DemangledNativeBuiltin<"__get_kernel_sub_group_count_for_ndrange_impl", OpenCL_std, Enqueue, 3, 3, OpGetKernelNDrangeSubGroupCount>;
+defm : DemangledNativeBuiltin<"__get_kernel_max_sub_group_size_for_ndrange_impl", OpenCL_std, Enqueue, 3, 3, OpGetKernelNDrangeMaxSubGroupSize>;
+
// Spec constant builtin records:
defm : DemangledNativeBuiltin<"__spirv_SpecConstant", OpenCL_std, SpecConstant, 2, 2, OpSpecConstant>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 338f6809a3e46..9df0f79142bb9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -759,6 +759,10 @@ def OpGetDefaultQueue: Op<303, (outs ID:$res), (ins TYPE:$type),
"$res = OpGetDefaultQueue $type">;
def OpBuildNDRange: Op<304, (outs ID:$res), (ins TYPE:$type, ID:$GWS, ID:$LWS, ID:$GWO),
"$res = OpBuildNDRange $type $GWS $LWS $GWO">;
+// Kernel queries for an ND-range. Operands: the ND-range, the invoke function, and the block literal pointer with its size and alignment.
+def OpGetKernelNDrangeSubGroupCount: Op<293, (outs ID:$res), (ins TYPE:$type, ID:$NDR, ID:$Invoke, ID:$Param, ID:$ParamSize, ID:$ParamAlign),
+                  "$res = OpGetKernelNDrangeSubGroupCount $type $NDR $Invoke $Param $ParamSize $ParamAlign">;
+def OpGetKernelNDrangeMaxSubGroupSize: Op<294, (outs ID:$res), (ins TYPE:$type, ID:$NDR, ID:$Invoke, ID:$Param, ID:$ParamSize, ID:$ParamAlign),
+                  "$res = OpGetKernelNDrangeMaxSubGroupSize $type $NDR $Invoke $Param $ParamSize $ParamAlign">;
// TODO: 3.42.23. Pipe Instructions
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 2fdd54fdfc390..18d6f1e4886e7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1848,6 +1848,11 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
break;
}
+ case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
+ case SPIRV::OpGetKernelNDrangeSubGroupCount: {
+ Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
+ break;
+ }
default:
break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 725a7979d3e5b..f1c94c4486b35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -488,6 +488,16 @@ bool isEntryPoint(const Function &F) {
return false;
}
+// Walks backwards through a chain of G_ADDRSPACE_CAST instructions starting
+// at Reg and returns the source register of the innermost cast. Reg is
+// returned unchanged when it has no definition or when its definition is not
+// an address-space cast.
+Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI) {
+  for (MachineInstr *CastMI = MRI.getVRegDef(Reg);
+       CastMI && CastMI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST;
+       CastMI = MRI.getVRegDef(Reg))
+    Reg = CastMI->getOperand(1).getReg(); // Step to the cast's source.
+  return Reg;
+}
+
Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
TypeName.consume_front("atomic_");
if (TypeName.consume_front("void"))
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index f14a7d356ea58..a2699cbc57145 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -257,7 +257,7 @@ bool isSpecialOpaqueType(const Type *Ty);
// Check if the function is an SPIR-V entry point
bool isEntryPoint(const Function &F);
-
+Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI);
// Parse basic scalar type name, substring TypeName, and return LLVM type.
Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll b/llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll
new file mode 100644
index 0000000000000..4d1917ca92abe
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll
@@ -0,0 +1,95 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+
+target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir-unknown-unknown"
+
+%struct.ndrange_t = type { i32 }
+%1 = type <{ i32, i32 }>
+
+ at __block_literal_global = internal addrspace(1) constant { i32, i32 } { i32 8, i32 4 }, align 4
+ at __block_literal_global.1 = internal addrspace(1) constant { i32, i32 } { i32 8, i32 4 }, align 4
+
+; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#C4:]] = OpConstant %[[#Int32Ty]] 4
+; CHECK-DAG: %[[#C8:]] = OpConstant %[[#Int32Ty]] 8
+; CHECK-DAG: %[[#NDRangeTy:]] = OpTypeStruct %[[#Int32Ty]]
+; CHECK-DAG: %[[#NDRangePtrTy:]] = OpTypePointer Function %[[#NDRangeTy]]
+
+; Function Attrs: convergent noinline nounwind optnone
+define spir_kernel void @device_side_enqueue() #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !2 !kernel_arg_type !2 !kernel_arg_base_type !2 !kernel_arg_type_qual !2 {
+entry:
+
+; CHECK: %[[#NDRange:]] = OpVariable %[[#NDRangePtrTy]]
+
+ %ndrange = alloca %struct.ndrange_t, align 4
+
+; CHECK: %[[#BlockLit1:]] = OpPtrCastToGeneric %[[#]] %[[#]]
+; CHECK: %[[#]] = OpGetKernelNDrangeMaxSubGroupSize %[[#Int32Ty]] %[[#NDRange]] %[[#]] %[[#BlockLit1]] %[[#C8]] %[[#C4]]
+
+ %0 = call i32 @__get_kernel_max_sub_group_size_for_ndrange_impl(ptr %ndrange, ptr addrspace(4) addrspacecast (ptr @__device_side_enqueue_block_invoke_kernel to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @__block_literal_global to ptr addrspace(4)))
+
+; CHECK: %[[#BlockLit2:]] = OpPtrCastToGeneric %[[#]] %[[#]]
+; CHECK: %[[#]] = OpGetKernelNDrangeSubGroupCount %[[#Int32Ty]] %[[#NDRange]] %[[#]] %[[#BlockLit2]] %[[#C8]] %[[#C4]]
+
+ %1 = call i32 @__get_kernel_sub_group_count_for_ndrange_impl(ptr %ndrange, ptr addrspace(4) addrspacecast (ptr @__device_side_enqueue_block_invoke_1_kernel to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @__block_literal_global.1 to ptr addrspace(4)))
+ ret void
+}
+
+declare i32 @__get_kernel_preferred_work_group_size_multiple_impl(ptr addrspace(4), ptr addrspace(4))
+
+; Function Attrs: convergent noinline nounwind optnone
+define internal spir_func void @__device_side_enqueue_block_invoke(ptr addrspace(4) %.block_descriptor) #1 {
+entry:
+ %.block_descriptor.addr = alloca ptr addrspace(4), align 4
+ %block.addr = alloca ptr addrspace(4), align 4
+ store ptr addrspace(4) %.block_descriptor, ptr %.block_descriptor.addr, align 4
+ store ptr addrspace(4) %.block_descriptor, ptr %block.addr, align 4
+ ret void
+}
+
+; Function Attrs: nounwind
+define internal spir_kernel void @__device_side_enqueue_block_invoke_kernel(ptr addrspace(4)) #2 {
+entry:
+ call void @__device_side_enqueue_block_invoke(ptr addrspace(4) %0)
+ ret void
+}
+
+declare i32 @__get_kernel_max_sub_group_size_for_ndrange_impl(ptr, ptr addrspace(4), ptr addrspace(4))
+
+; Function Attrs: convergent noinline nounwind optnone
+define internal spir_func void @__device_side_enqueue_block_invoke_1(ptr addrspace(4) %.block_descriptor) #1 {
+entry:
+ %.block_descriptor.addr = alloca ptr addrspace(4), align 4
+ %block.addr = alloca ptr addrspace(4), align 4
+ store ptr addrspace(4) %.block_descriptor, ptr %.block_descriptor.addr, align 4
+ store ptr addrspace(4) %.block_descriptor, ptr %block.addr, align 4
+ ret void
+}
+
+; Function Attrs: nounwind
+define internal spir_kernel void @__device_side_enqueue_block_invoke_1_kernel(ptr addrspace(4)) #2 {
+entry:
+ call void @__device_side_enqueue_block_invoke_1(ptr addrspace(4) %0)
+ ret void
+}
+
+declare i32 @__get_kernel_sub_group_count_for_ndrange_impl(ptr, ptr addrspace(4), ptr addrspace(4))
+
+attributes #0 = { convergent noinline nounwind optnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { convergent noinline nounwind optnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #2 = { nounwind }
+attributes #3 = { argmemonly nounwind }
+
+!llvm.module.flags = !{!0}
+!opencl.enable.FP_CONTRACT = !{}
+!opencl.ocl.version = !{!1}
+!opencl.spir.version = !{!1}
+!opencl.used.extensions = !{!2}
+!opencl.used.optional.core.features = !{!2}
+!opencl.compiler.options = !{!2}
+!llvm.ident = !{!3}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 2, i32 0}
+!2 = !{}
+!3 = !{!"clang version 7.0.0"}
More information about the llvm-commits
mailing list