[llvm] [SPIRV] OpEnqueueKernel Instruction generation correction (PR #136094)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 17 21:32:48 PDT 2025


https://github.com/EbinJose2002 updated https://github.com/llvm/llvm-project/pull/136094

>From 1662ed1c1ccfbcec4aede083e2e26c5bc79c049d Mon Sep 17 00:00:00 2001
From: EbinJose2002 <ebin.jose at multicorewareinc.com>
Date: Thu, 17 Apr 2025 11:17:07 +0530
Subject: [PATCH]  - Enqueue Kernel builtin correction  - Handle cases where
 the block size and alignment are given explicitly in the initializer of the
 G_GLOBAL_VALUE's global  - Add a SPIRVUtils function that strips address
 space casts, similar to the one in the SPIRV-LLVM Translator

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 58 +++++++++++++++++--------
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp    | 10 +++++
 llvm/lib/Target/SPIRV/SPIRVUtils.h      |  3 ++
 3 files changed, 52 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 16364ab30f280..2591c9dcedcb4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -372,18 +372,15 @@ static MachineInstr *getBlockStructInstr(Register ParamReg,
   // We expect the following sequence of instructions:
   //   %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
   //   or       = G_GLOBAL_VALUE @block_literal_global
-  //   %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
-  //   %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
+  //   %1:_(p4) = G_ADDRSPACE_CAST %0:_(pN)
   MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
   assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
          MI->getOperand(1).isReg());
-  Register BitcastReg = MI->getOperand(1).getReg();
-  MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
-  assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
-         BitcastMI->getOperand(2).isReg());
-  Register ValueReg = BitcastMI->getOperand(2).getReg();
-  MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
-  return ValueMI;
+  Register PtrReg = MI->getOperand(1).getReg();
+  MachineInstr *PtrMI = MRI->getUniqueVRegDef(PtrReg);
+  assert(PtrMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
+         isSpvIntrinsic(*PtrMI, Intrinsic::spv_alloca));
+  return PtrMI;
 }
 
 // Return an integer constant corresponding to the given register and
@@ -2436,20 +2433,43 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
   MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
   assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
   // Invoke: Pointer to invoke function.
-  MIB.addGlobalAddress(BlockMI->getOperand(1).getGlobal());
+  Register BlockFReg = BlockMI->getOperand(0).getReg();
+  MIB.addUse(BlockFReg);
+  MRI->setRegClass(BlockFReg, &SPIRV::pIDRegClass);
 
   Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
   // Param: Pointer to block literal.
   MIB.addUse(BlockLiteralReg);
-
-  Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
-  // TODO: these numbers should be obtained from block literal structure.
-  // Param Size: Size of block literal structure.
-  MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
-  // Param Aligment: Aligment of block literal structure.
-  MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
-                                   MIRBuilder, GR));
-
+  BlockMI = MRI->getUniqueVRegDef(BlockLiteralReg);
+  Register BlockMIReg =
+      stripAddrspaceCast(BlockMI->getOperand(1).getReg(), *MRI);
+  BlockMI = MRI->getUniqueVRegDef(BlockMIReg);
+
+  if (BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) {
+    // Size and align are given explicitly here.
+    const GlobalValue *GV = BlockMI->getOperand(1).getGlobal();
+    const GlobalVariable *BlockGV = dyn_cast<GlobalVariable>(GV);
+    assert(BlockGV && BlockGV->hasInitializer());
+    const Constant *Init = BlockGV->getInitializer();
+    const ConstantStruct *CS = dyn_cast<ConstantStruct>(Init);
+    // Extract fields
+    const ConstantInt *SizeConst = dyn_cast<ConstantInt>(CS->getOperand(0));
+    const ConstantInt *AlignConst = dyn_cast<ConstantInt>(CS->getOperand(1));
+    uint64_t BlockSize = SizeConst->getZExtValue();
+    uint64_t BlockAlign = AlignConst->getZExtValue();
+    MIB.addUse(buildConstantIntReg32(BlockSize, MIRBuilder, GR));
+    MIB.addUse(buildConstantIntReg32(BlockAlign, MIRBuilder, GR));
+  }
+  else {
+    Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
+    // TODO: these numbers should be obtained from block literal structure.
+    // Param Size: Size of block literal structure.
+    MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
+    // Param Alignment: Alignment of block literal structure.
+    MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
+                                     MIRBuilder, GR));
+  
+  }
   for (unsigned i = 0; i < LocalSizes.size(); i++)
     MIB.addUse(LocalSizes[i]);
   return true;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index f38794afab436..df969ef76590a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -486,6 +486,16 @@ bool isEntryPoint(const Function &F) {
   return false;
 }
 
+Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI) {
+  while (true) { // Follow the chain of G_ADDRSPACE_CAST definitions.
+    MachineInstr *Def = MRI.getVRegDef(Reg);
+    if (!Def || Def->getOpcode() != TargetOpcode::G_ADDRSPACE_CAST)
+      break; // No def, or the def is not a cast: Reg is fully stripped.
+    Reg = Def->getOperand(1).getReg(); // Step to the cast's source operand.
+  }
+  return Reg; // Unchanged if the input Reg was not defined by a cast.
+}
+
 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
   TypeName.consume_front("atomic_");
   if (TypeName.consume_front("void"))
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 0498c7beb073c..7765dca6a1df1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -243,6 +243,9 @@ bool isSpecialOpaqueType(const Type *Ty);
 // Check if the function is an SPIR-V entry point
 bool isEntryPoint(const Function &F);
 
+// Returns the register left after stripping all G_ADDRSPACE_CASTs from Reg.
+Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI);
+
 // Parse basic scalar type name, substring TypeName, and return LLVM type.
 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx);
 



More information about the llvm-commits mailing list