[llvm] [SPIR-V] Do not reassign kernel arg SPIRVType based on later calls/uses (PR #75514)
Ilia Diachkov via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 4 10:53:35 PST 2024
================
@@ -125,12 +125,32 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
SmallVector<MachineInstr *, 10> ToErase;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
- if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
- continue;
- assert(MI.getOperand(2).isReg());
- MIB.setInsertPt(*MI.getParent(), MI);
- MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
- ToErase.push_back(&MI);
+ if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
+ assert(MI.getOperand(2).isReg());
+ MIB.setInsertPt(*MI.getParent(), MI);
+ MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
+ ToErase.push_back(&MI);
+ } else if (isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) {
+ assert(MI.getOperand(2).isReg());
+ MIB.setInsertPt(*MI.getParent(), MI);
+ Register Def = MI.getOperand(0).getReg();
+ Register Source = MI.getOperand(2).getReg();
+ SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
+ getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
+ SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
+ BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
+ addressSpaceToStorageClass(MI.getOperand(4).getImm()));
+
+ // If the bitcast would be redundant, replace all uses with the source
+ // register.
+ if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
+ MIB.getMRI()->replaceRegWith(Def, Source);
+ } else {
+ GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
+ MIB.buildBitcast(Def, Source);
+ }
+ ToErase.push_back(&MI);
+ }
----------------
iliya-diyachkov wrote:
Consider reducing the indentation and code duplication, for example:
```
if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) && !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
continue;
assert(MI.getOperand(2).isReg());
MIB.setInsertPt(*MI.getParent(), MI);
ToErase.push_back(&MI);
if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
continue;
}
Register Def = MI.getOperand(0).getReg();
...
```
https://github.com/llvm/llvm-project/pull/75514
More information about the llvm-commits mailing list