[llvm] [SPIR-V] Do not reassign kernel arg SPIRVType based on later calls/uses (PR #75514)

Ilia Diachkov via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 4 10:53:35 PST 2024


================
@@ -125,12 +125,32 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
   SmallVector<MachineInstr *, 10> ToErase;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
-      if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast))
-        continue;
-      assert(MI.getOperand(2).isReg());
-      MIB.setInsertPt(*MI.getParent(), MI);
-      MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
-      ToErase.push_back(&MI);
+      if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
+        assert(MI.getOperand(2).isReg());
+        MIB.setInsertPt(*MI.getParent(), MI);
+        MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
+        ToErase.push_back(&MI);
+      } else if (isSpvIntrinsic(MI, Intrinsic::spv_ptrcast)) {
+        assert(MI.getOperand(2).isReg());
+        MIB.setInsertPt(*MI.getParent(), MI);
+        Register Def = MI.getOperand(0).getReg();
+        Register Source = MI.getOperand(2).getReg();
+        SPIRVType *BaseTy = GR->getOrCreateSPIRVType(
+            getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
+        SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
+            BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
+            addressSpaceToStorageClass(MI.getOperand(4).getImm()));
+
+        // If the bitcast would be redundant, replace all uses with the source
+        // register.
+        if (GR->getSPIRVTypeForVReg(Source) == AssignedPtrType) {
+          MIB.getMRI()->replaceRegWith(Def, Source);
+        } else {
+          GR->assignSPIRVTypeToVReg(AssignedPtrType, Def, MF);
+          MIB.buildBitcast(Def, Source);
+        }
+        ToErase.push_back(&MI);
+      }
----------------
iliya-diyachkov wrote:

Consider reducing the indentation and the code duplicated between the two branches, for example:
```
if (!isSpvIntrinsic(MI, Intrinsic::spv_bitcast) && !isSpvIntrinsic(MI, Intrinsic::spv_ptrcast))
  continue;
assert(MI.getOperand(2).isReg());
MIB.setInsertPt(*MI.getParent(), MI);
ToErase.push_back(&MI);
if (isSpvIntrinsic(MI, Intrinsic::spv_bitcast)) {
  MIB.buildBitcast(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
  continue;
}
Register Def = MI.getOperand(0).getReg();
...
```

https://github.com/llvm/llvm-project/pull/75514


More information about the llvm-commits mailing list