[llvm] [SPIRV] Added support for 2 kernel query builtins (PR #142280)

via llvm-commits llvm-commits at lists.llvm.org
Sat May 31 10:41:09 PDT 2025


https://github.com/EbinJose2002 created https://github.com/llvm/llvm-project/pull/142280

Added support for 2 kernel query builtins - OpGetKernelNDrangeMaxSubGroupSize and OpGetKernelNDrangeSubGroupCount

>From 1c31d738d9fc2e5c12df7341823fe150a14d142d Mon Sep 17 00:00:00 2001
From: EbinJose2002 <ebin.jose at multicorewareinc.com>
Date: Wed, 16 Apr 2025 10:39:42 +0530
Subject: [PATCH] Added support for 2 kernel query builtins

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 116 +++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        |   3 +
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |   4 +
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |   5 +
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  10 ++
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |   2 +-
 .../CodeGen/SPIRV/transcoding/kernel_query.ll |  95 ++++++++++++++
 7 files changed, 215 insertions(+), 20 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f73a39c6ee9da..99910aeb20c43 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -372,18 +372,15 @@ static MachineInstr *getBlockStructInstr(Register ParamReg,
   // We expect the following sequence of instructions:
   //   %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
   //   or       = G_GLOBAL_VALUE @block_literal_global
-  //   %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
-  //   %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
+  //   %1:_(p4) = G_ADDRSPACE_CAST %0:_(pN)
   MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
   assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
          MI->getOperand(1).isReg());
-  Register BitcastReg = MI->getOperand(1).getReg();
-  MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
-  assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
-         BitcastMI->getOperand(2).isReg());
-  Register ValueReg = BitcastMI->getOperand(2).getReg();
-  MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
-  return ValueMI;
+  Register PtrReg = MI->getOperand(1).getReg();
+  MachineInstr *PtrMI = MRI->getUniqueVRegDef(PtrReg);
+  assert(PtrMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
+         isSpvIntrinsic(*PtrMI, Intrinsic::spv_alloca));
+  return PtrMI;
 }
 
 // Return an integer constant corresponding to the given register and
@@ -2490,25 +2487,103 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
   MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
   assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
   // Invoke: Pointer to invoke function.
-  MIB.addGlobalAddress(BlockMI->getOperand(1).getGlobal());
+  Register BlockFReg = BlockMI->getOperand(0).getReg();
+  MIB.addUse(BlockFReg);
+  MRI->setRegClass(BlockFReg, &SPIRV::pIDRegClass);
 
   Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
   // Param: Pointer to block literal.
   MIB.addUse(BlockLiteralReg);
-
-  Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
-  // TODO: these numbers should be obtained from block literal structure.
-  // Param Size: Size of block literal structure.
-  MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
-  // Param Aligment: Aligment of block literal structure.
-  MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
-                                   MIRBuilder, GR));
-
+  BlockMI = MRI->getUniqueVRegDef(BlockLiteralReg);
+  Register BlockMIReg =
+      stripAddrspaceCast(BlockMI->getOperand(1).getReg(), *MRI);
+  BlockMI = MRI->getUniqueVRegDef(BlockMIReg);
+
+  if (BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) {
+    auto BlockLiteralMI = MRI->getUniqueVRegDef(BlockMIReg);
+    const GlobalValue *GV = BlockLiteralMI->getOperand(1).getGlobal();
+    const GlobalVariable *BlockGV = dyn_cast<GlobalVariable>(GV);
+    assert(BlockGV && BlockGV->hasInitializer());
+
+    const Constant *Init = BlockGV->getInitializer();
+    const ConstantStruct *CS = dyn_cast<ConstantStruct>(Init);
+    assert(CS && "Expected constant struct for block literal");
+
+    // Extract fields
+    const ConstantInt *SizeConst = dyn_cast<ConstantInt>(CS->getOperand(0));
+    const ConstantInt *AlignConst = dyn_cast<ConstantInt>(CS->getOperand(1));
+    uint64_t BlockSize = SizeConst->getZExtValue();
+    uint64_t BlockAlign = AlignConst->getZExtValue();
+    MIB.addUse(buildConstantIntReg32(BlockSize, MIRBuilder, GR));
+    MIB.addUse(buildConstantIntReg32(BlockAlign, MIRBuilder, GR));
+  }
+  else {
+    Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
+    // TODO: these numbers should be obtained from block literal structure.
+    // Param Size: Size of block literal structure.
+    MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
+    // Param Aligment: Aligment of block literal structure.
+    MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
+                                    MIRBuilder, GR));
+  }
   for (unsigned i = 0; i < LocalSizes.size(); i++)
     MIB.addUse(LocalSizes[i]);
   return true;
 }
 
+/// Lower a kernel sub-group query builtin to the SPIR-V instruction \p Opcode.
+/// Operands are emitted as <NDRange, Invoke, Param, ParamSize, ParamAlign>;
+/// size and alignment are read from the block literal global's first two
+/// fields when it is a global, otherwise from the block struct type layout.
+static bool buildNDRangeSubGroup(const SPIRV::IncomingCall *Call,
+                                 unsigned Opcode, MachineIRBuilder &MIRBuilder,
+                                 SPIRVGlobalRegistry *GR) {
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  const DataLayout &DL = MIRBuilder.getDataLayout();
+
+  auto MIB = MIRBuilder.buildInstr(Opcode)
+                 .addDef(Call->ReturnRegister)
+                 .addUse(GR->getSPIRVTypeID(Call->ReturnType))
+                 .addUse(Call->Arguments[0]);
+  unsigned int BlockFIdx = 1;
+  MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
+  assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
+  // Invoke: Pointer to invoke function.
+  Register BlockFReg = BlockMI->getOperand(0).getReg();
+  MIB.addUse(BlockFReg);
+  MRI->setRegClass(BlockFReg, &SPIRV::pIDRegClass);
+
+  Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
+  // Param: Pointer to block literal.
+  MIB.addUse(BlockLiteralReg);
+  // Look through address space casts to the definition of the block literal.
+  Register BlockMIReg = stripAddrspaceCast(BlockLiteralReg, *MRI);
+  BlockMI = MRI->getUniqueVRegDef(BlockMIReg);
+
+  if (BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) {
+    // The block literal is a global whose initializer starts with the explicit
+    // size and alignment fields. cast<> asserts on an unexpected shape rather
+    // than dereferencing the null pointer a failed dyn_cast<> would return.
+    const GlobalValue *GV = BlockMI->getOperand(1).getGlobal();
+    const auto *BlockGV = cast<GlobalVariable>(GV);
+    assert(BlockGV->hasInitializer() &&
+           "Block literal should have an initializer");
+    const auto *CS = cast<ConstantStruct>(BlockGV->getInitializer());
+    const auto *SizeCI = cast<ConstantInt>(CS->getOperand(0));
+    const auto *AlignCI = cast<ConstantInt>(CS->getOperand(1));
+    MIB.addUse(buildConstantIntReg32(SizeCI->getZExtValue(), MIRBuilder, GR));
+    MIB.addUse(buildConstantIntReg32(AlignCI->getZExtValue(), MIRBuilder, GR));
+  } else {
+    Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
+    // Fallback: derive size and alignment from the block struct type.
+    MIB.addUse(
+        buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
+    MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
+                                     MIRBuilder, GR));
+  }
+  return true;
+}
+
 static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder,
                                 SPIRVGlobalRegistry *GR) {
@@ -2544,6 +2619,9 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
     return buildNDRange(Call, MIRBuilder, GR);
   case SPIRV::OpEnqueueKernel:
     return buildEnqueueKernel(Call, MIRBuilder, GR);
+  case SPIRV::OpGetKernelNDrangeSubGroupCount:
+  case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
+    return buildNDRangeSubGroup(Call, Opcode, MIRBuilder, GR);
   default:
     return false;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 6842e5ff067cf..95ca3295b1733 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -671,6 +671,9 @@ defm : DemangledNativeBuiltin<"__spirv_GetDefaultQueue", OpenCL_std, Enqueue, 0,
 defm : DemangledNativeBuiltin<"ndrange_1D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
 defm : DemangledNativeBuiltin<"ndrange_2D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
 defm : DemangledNativeBuiltin<"ndrange_3D", OpenCL_std, Enqueue, 1, 3, OpBuildNDRange>;
+defm : DemangledNativeBuiltin<"__get_kernel_sub_group_count_for_ndrange_impl", OpenCL_std, Enqueue, 2, 2, OpGetKernelNDrangeSubGroupCount>;
+defm : DemangledNativeBuiltin<"__get_kernel_max_sub_group_size_for_ndrange_impl", OpenCL_std, Enqueue, 2, 2, OpGetKernelNDrangeMaxSubGroupSize>;
+
 
 // Spec constant builtin records:
 defm : DemangledNativeBuiltin<"__spirv_SpecConstant", OpenCL_std, SpecConstant, 2, 2, OpSpecConstant>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 338f6809a3e46..9df0f79142bb9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -759,6 +759,10 @@ def OpGetDefaultQueue: Op<303, (outs ID:$res), (ins TYPE:$type),
                   "$res = OpGetDefaultQueue $type">;
 def OpBuildNDRange: Op<304, (outs ID:$res), (ins TYPE:$type, ID:$GWS, ID:$LWS, ID:$GWO),
                   "$res = OpBuildNDRange $type $GWS $LWS $GWO">;
+def OpGetKernelNDrangeSubGroupCount: Op<293, (outs ID:$res), (ins TYPE:$type, ID:$NDR, ID:$Invoke, ID:$Param, ID:$ParamSize, ID:$ParamAlign),
+                  "$res = OpGetKernelNDrangeSubGroupCount $type $NDR $Invoke $Param $ParamSize $ParamAlign">;
+def OpGetKernelNDrangeMaxSubGroupSize: Op<294, (outs ID:$res), (ins TYPE:$type, ID:$NDR, ID:$Invoke, ID:$Param, ID:$ParamSize, ID:$ParamAlign),
+                  "$res = OpGetKernelNDrangeMaxSubGroupSize $type $NDR $Invoke $Param $ParamSize $ParamAlign">;
 
 // TODO: 3.42.23. Pipe Instructions
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 2fdd54fdfc390..18d6f1e4886e7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1848,6 +1848,11 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addCapability(SPIRV::Capability::TernaryBitwiseFunctionINTEL);
     break;
   }
+  case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
+  case SPIRV::OpGetKernelNDrangeSubGroupCount: {
+    Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
+    break;
+  }
 
   default:
     break;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 725a7979d3e5b..f1c94c4486b35 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -488,6 +488,16 @@ bool isEntryPoint(const Function &F) {
   return false;
 }
 
+/// Follow a chain of G_ADDRSPACE_CAST definitions starting at \p Reg and
+/// return the first register that is not defined by such a cast.
+Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI) {
+  for (MachineInstr *CastMI = MRI.getVRegDef(Reg);
+       CastMI && CastMI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST;
+       CastMI = MRI.getVRegDef(Reg))
+    Reg = CastMI->getOperand(1).getReg(); // Continue from the cast's source.
+  return Reg;
+}
+
 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
   TypeName.consume_front("atomic_");
   if (TypeName.consume_front("void"))
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index f14a7d356ea58..a2699cbc57145 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -257,7 +257,7 @@ bool isSpecialOpaqueType(const Type *Ty);
 
 // Check if the function is an SPIR-V entry point
 bool isEntryPoint(const Function &F);
-
+Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI);
 // Parse basic scalar type name, substring TypeName, and return LLVM type.
 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx);
 
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll b/llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll
new file mode 100644
index 0000000000000..4d1917ca92abe
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/kernel_query.ll
@@ -0,0 +1,95 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+
+target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir-unknown-unknown"
+
+%struct.ndrange_t = type { i32 }
+%1 = type <{ i32, i32 }>
+
+@__block_literal_global = internal addrspace(1) constant { i32, i32 } { i32 8, i32 4 }, align 4
+@__block_literal_global.1 = internal addrspace(1) constant { i32, i32 } { i32 8, i32 4 }, align 4
+
+; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#C4:]] = OpConstant %[[#Int32Ty]] 4
+; CHECK-DAG: %[[#C8:]] = OpConstant %[[#Int32Ty]] 8
+; CHECK-DAG: %[[#NDRangeTy:]] = OpTypeStruct %[[#Int32Ty]]
+; CHECK-DAG: %[[#NDRangePtrTy:]] = OpTypePointer Function %[[#NDRangeTy]]
+
+; Function Attrs: convergent noinline nounwind optnone
+define spir_kernel void @device_side_enqueue() #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !2 !kernel_arg_type !2 !kernel_arg_base_type !2 !kernel_arg_type_qual !2 {
+entry:
+
+; CHECK: %[[#NDRange:]] = OpVariable %[[#NDRangePtrTy]]
+
+  %ndrange = alloca %struct.ndrange_t, align 4
+
+; CHECK: %[[#BlockLit1:]] = OpPtrCastToGeneric %[[#]] %[[#]]
+; CHECK: %[[#]] = OpGetKernelNDrangeMaxSubGroupSize %[[#Int32Ty]] %[[#NDRange]] %[[#]] %[[#BlockLit1]] %[[#C8]] %[[#C4]]
+
+  %0 = call i32 @__get_kernel_max_sub_group_size_for_ndrange_impl(ptr %ndrange, ptr addrspace(4) addrspacecast (ptr @__device_side_enqueue_block_invoke_kernel to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @__block_literal_global to ptr addrspace(4)))
+
+; CHECK: %[[#BlockLit2:]] = OpPtrCastToGeneric %[[#]] %[[#]]
+; CHECK: %[[#]] = OpGetKernelNDrangeSubGroupCount %[[#Int32Ty]] %[[#NDRange]] %[[#]] %[[#BlockLit2]] %[[#C8]] %[[#C4]]
+
+  %1 = call i32 @__get_kernel_sub_group_count_for_ndrange_impl(ptr %ndrange, ptr addrspace(4) addrspacecast (ptr @__device_side_enqueue_block_invoke_1_kernel to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @__block_literal_global.1 to ptr addrspace(4)))
+  ret void
+}
+
+declare i32 @__get_kernel_preferred_work_group_size_multiple_impl(ptr addrspace(4), ptr addrspace(4))
+
+; Function Attrs: convergent noinline nounwind optnone
+define internal spir_func void @__device_side_enqueue_block_invoke(ptr addrspace(4) %.block_descriptor) #1 {
+entry:
+  %.block_descriptor.addr = alloca ptr addrspace(4), align 4
+  %block.addr = alloca ptr addrspace(4), align 4
+  store ptr addrspace(4) %.block_descriptor, ptr %.block_descriptor.addr, align 4
+  store ptr addrspace(4) %.block_descriptor, ptr %block.addr, align 4
+  ret void
+}
+
+; Function Attrs: nounwind
+define internal spir_kernel void @__device_side_enqueue_block_invoke_kernel(ptr addrspace(4)) #2 {
+entry:
+  call void @__device_side_enqueue_block_invoke(ptr addrspace(4) %0)
+  ret void
+}
+
+declare i32 @__get_kernel_max_sub_group_size_for_ndrange_impl(ptr, ptr addrspace(4), ptr addrspace(4))
+
+; Function Attrs: convergent noinline nounwind optnone
+define internal spir_func void @__device_side_enqueue_block_invoke_1(ptr addrspace(4) %.block_descriptor) #1 {
+entry:
+  %.block_descriptor.addr = alloca ptr addrspace(4), align 4
+  %block.addr = alloca ptr addrspace(4), align 4
+  store ptr addrspace(4) %.block_descriptor, ptr %.block_descriptor.addr, align 4
+  store ptr addrspace(4) %.block_descriptor, ptr %block.addr, align 4
+  ret void
+}
+
+; Function Attrs: nounwind
+define internal spir_kernel void @__device_side_enqueue_block_invoke_1_kernel(ptr addrspace(4)) #2 {
+entry:
+  call void @__device_side_enqueue_block_invoke_1(ptr addrspace(4) %0)
+  ret void
+}
+
+declare i32 @__get_kernel_sub_group_count_for_ndrange_impl(ptr, ptr addrspace(4), ptr addrspace(4))
+
+attributes #0 = { convergent noinline nounwind optnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "uniform-work-group-size"="false" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { convergent noinline nounwind optnone "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #2 = { nounwind }
+attributes #3 = { argmemonly nounwind }
+
+!llvm.module.flags = !{!0}
+!opencl.enable.FP_CONTRACT = !{}
+!opencl.ocl.version = !{!1}
+!opencl.spir.version = !{!1}
+!opencl.used.extensions = !{!2}
+!opencl.used.optional.core.features = !{!2}
+!opencl.compiler.options = !{!2}
+!llvm.ident = !{!3}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 2, i32 0}
+!2 = !{}
+!3 = !{!"clang version 7.0.0"}



More information about the llvm-commits mailing list