[llvm] [SPIRV] OpEnqueueKernel Instruction generation correction (PR #136094)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 17 21:32:48 PDT 2025
https://github.com/EbinJose2002 updated https://github.com/llvm/llvm-project/pull/136094
>From 1662ed1c1ccfbcec4aede083e2e26c5bc79c049d Mon Sep 17 00:00:00 2001
From: EbinJose2002 <ebin.jose at multicorewareinc.com>
Date: Thu, 17 Apr 2025 11:17:07 +0530
Subject: [PATCH] [SPIRV] Fix OpEnqueueKernel generation - Handle cases where
the block literal's size and alignment are given explicitly in its global
variable initializer - Add a SPIRVUtils helper that strips address space
casts, similar to the one in the SPIRV-LLVM-Translator
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 58 +++++++++++++++++--------
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 10 +++++
llvm/lib/Target/SPIRV/SPIRVUtils.h | 3 ++
3 files changed, 52 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 16364ab30f280..2591c9dcedcb4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -372,18 +372,17 @@ static MachineInstr *getBlockStructInstr(Register ParamReg,
// We expect the following sequence of instructions:
// %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
// or = G_GLOBAL_VALUE @block_literal_global
- // %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
- // %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
+  // %1:_(p4) = G_ADDRSPACE_CAST %0:_(pN)
MachineInstr *MI = MRI->getUniqueVRegDef(ParamReg);
assert(MI->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST &&
MI->getOperand(1).isReg());
- Register BitcastReg = MI->getOperand(1).getReg();
- MachineInstr *BitcastMI = MRI->getUniqueVRegDef(BitcastReg);
- assert(isSpvIntrinsic(*BitcastMI, Intrinsic::spv_bitcast) &&
- BitcastMI->getOperand(2).isReg());
- Register ValueReg = BitcastMI->getOperand(2).getReg();
- MachineInstr *ValueMI = MRI->getUniqueVRegDef(ValueReg);
- return ValueMI;
+  Register PtrReg = MI->getOperand(1).getReg();
+  MachineInstr *PtrMI = MRI->getUniqueVRegDef(PtrReg);
+  assert(PtrMI &&
+         (PtrMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
+          isSpvIntrinsic(*PtrMI, Intrinsic::spv_alloca)) &&
+         "Block literal must be a G_GLOBAL_VALUE or an spv_alloca result");
+  return PtrMI;
}
// Return an integer constant corresponding to the given register and
@@ -2436,20 +2433,46 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
MachineInstr *BlockMI = getBlockStructInstr(Call->Arguments[BlockFIdx], MRI);
assert(BlockMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
// Invoke: Pointer to invoke function.
- MIB.addGlobalAddress(BlockMI->getOperand(1).getGlobal());
+  Register BlockFReg = BlockMI->getOperand(0).getReg();
+  MIB.addUse(BlockFReg);
+  MRI->setRegClass(BlockFReg, &SPIRV::pIDRegClass);
Register BlockLiteralReg = Call->Arguments[BlockFIdx + 1];
// Param: Pointer to block literal.
MIB.addUse(BlockLiteralReg);
-
- Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
- // TODO: these numbers should be obtained from block literal structure.
- // Param Size: Size of block literal structure.
- MIB.addUse(buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
- // Param Aligment: Aligment of block literal structure.
- MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
- MIRBuilder, GR));
-
+
+  // Param Size / Param Align: prefer the explicit values stored in the
+  // block literal global's initializer ({ i32 size, i32 align, ... });
+  // otherwise derive them from the block literal structure type.
+  const ConstantInt *SizeConst = nullptr;
+  const ConstantInt *AlignConst = nullptr;
+  // The literal may reach the call with or without address space casts.
+  Register LiteralPtrReg = stripAddrspaceCast(BlockLiteralReg, *MRI);
+  MachineInstr *LiteralMI = MRI->getUniqueVRegDef(LiteralPtrReg);
+  if (LiteralMI && LiteralMI->getOpcode() == TargetOpcode::G_GLOBAL_VALUE) {
+    const auto *LiteralGV =
+        dyn_cast<GlobalVariable>(LiteralMI->getOperand(1).getGlobal());
+    if (LiteralGV && LiteralGV->hasInitializer()) {
+      const auto *Init = dyn_cast<ConstantStruct>(LiteralGV->getInitializer());
+      if (Init && Init->getNumOperands() >= 2) {
+        SizeConst = dyn_cast<ConstantInt>(Init->getOperand(0));
+        AlignConst = dyn_cast<ConstantInt>(Init->getOperand(1));
+      }
+    }
+  }
+  if (SizeConst && AlignConst) {
+    MIB.addUse(
+        buildConstantIntReg32(SizeConst->getZExtValue(), MIRBuilder, GR));
+    MIB.addUse(
+        buildConstantIntReg32(AlignConst->getZExtValue(), MIRBuilder, GR));
+  } else {
+    Type *PType = const_cast<Type *>(getBlockStructType(BlockLiteralReg, MRI));
+    MIB.addUse(
+        buildConstantIntReg32(DL.getTypeStoreSize(PType), MIRBuilder, GR));
+    MIB.addUse(buildConstantIntReg32(DL.getPrefTypeAlign(PType).value(),
+                                     MIRBuilder, GR));
+  }
+
for (unsigned i = 0; i < LocalSizes.size(); i++)
MIB.addUse(LocalSizes[i]);
return true;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index f38794afab436..df969ef76590a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -486,6 +486,16 @@ bool isEntryPoint(const Function &F) {
return false;
}
+Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI) {
+  // Walk up the def chain while the current register is produced by a
+  // G_ADDRSPACE_CAST; stop at the first non-cast (or missing) definition.
+  for (const MachineInstr *Def = MRI.getVRegDef(Reg);
+       Def && Def->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST;
+       Def = MRI.getVRegDef(Reg))
+    Reg = Def->getOperand(1).getReg();
+  return Reg;
+}
+
Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
TypeName.consume_front("atomic_");
if (TypeName.consume_front("void"))
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 0498c7beb073c..7765dca6a1df1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -243,6 +243,9 @@ bool isSpecialOpaqueType(const Type *Ty);
// Check if the function is an SPIR-V entry point
bool isEntryPoint(const Function &F);
+// Follow G_ADDRSPACE_CAST defs from Reg; return the first non-cast register.
+Register stripAddrspaceCast(Register Reg, const MachineRegisterInfo &MRI);
+
// Parse basic scalar type name, substring TypeName, and return LLVM type.
Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx);
More information about the llvm-commits
mailing list