[llvm] e461bdf - [SPIR-V] Fix switch lowering with common compare register

Michal Paszkowski via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 13 13:59:48 PST 2023


Author: Michal Paszkowski
Date: 2023-01-13T22:56:22+01:00
New Revision: e461bdf65b66cb002e0b455b1ec11044e90b1e7b

URL: https://github.com/llvm/llvm-project/commit/e461bdf65b66cb002e0b455b1ec11044e90b1e7b
DIFF: https://github.com/llvm/llvm-project/commit/e461bdf65b66cb002e0b455b1ec11044e90b1e7b.diff

LOG: [SPIR-V] Fix switch lowering with common compare register

Differential Revision: https://reviews.llvm.org/D141203

Added: 
    llvm/test/CodeGen/SPIRV/transcoding/Two_OpSwitch_same_register.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
    llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 7b7455a53325d..f91b6ea5cb141 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -198,6 +198,7 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
   for (auto &Op : I.operands())
     if (Op.get()->getType()->isSized())
       Args.push_back(Op);
+  IRB->SetInsertPoint(&I);
   IRB->CreateIntrinsic(Intrinsic::spv_switch, {I.getOperand(0)->getType()},
                        {Args});
   return &I;

diff  --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 7c24d9557711f..27d0e8a976f0d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -389,14 +389,11 @@ static void processInstrsWithTypeFolding(MachineFunction &MF,
 
 static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
                             MachineIRBuilder MIB) {
-  DenseMap<Register, SmallDenseMap<uint64_t, MachineBasicBlock *>>
-      SwitchRegToMBB;
-  DenseMap<Register, MachineBasicBlock *> DefaultMBBs;
-  DenseSet<Register> SwitchRegs;
-  MachineRegisterInfo &MRI = MF.getRegInfo();
-  // Before IRTranslator pass, spv_switch calls are inserted before each
-  // switch instruction. IRTranslator lowers switches to ICMP+CBr+Br triples.
-  // A switch with two cases may be translated to this MIR sequesnce:
+  // Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
+  // each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
+  // + G_BR triples. A switch with two cases may be transformed to this MIR
+  // sequence:
+  //
   //   intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
   //   %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
   //   G_BRCOND %Dst0, %bb.2
@@ -411,31 +408,48 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
   //   ...
   // bb.4.sw.epilog:
   //   ...
-  // Walk MIs and collect information about destination MBBs to update
-  // spv_switch call. We assume that all spv_switch precede corresponding ICMPs.
+  //
+  // Sometimes (for range-compare switches), additional G_SUB
+  // instructions are inserted before G_ICMPs. These must also be
+  // processed, and their results require type assignment.
+  //
+  // This function modifies spv_switch call's operands to include destination
+  // MBBs (default and for each constant value).
+  // Note that this function does not remove G_ICMP + G_BRCOND + G_BR sequences,
+  // but they are marked by ModuleAnalysis as skipped and as a result AsmPrinter
+  // does not output them.
+
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  // Collect all MIs relevant to switches across all MBBs in MF.
+  std::vector<MachineInstr *> RelevantInsts;
+
+  // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
+  // spv_switch use these registers.
+  DenseSet<Register> CompareRegs;
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
+      // Calls to spv_switch intrinsics representing IR switches.
       if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
         assert(MI.getOperand(1).isReg());
-        Register Reg = MI.getOperand(1).getReg();
-        SwitchRegs.insert(Reg);
-        // Set the first successor as default MBB to support empty switches.
-        DefaultMBBs[Reg] = *MBB.succ_begin();
+        CompareRegs.insert(MI.getOperand(1).getReg());
+        RelevantInsts.push_back(&MI);
       }
-      // Process G_SUB coming from switch range-compare lowering.
+
+      // G_SUBs coming from range-compare switch lowering. G_SUBs are found
+      // after spv_switch but before G_ICMP.
       if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
-          SwitchRegs.contains(MI.getOperand(1).getReg())) {
+          CompareRegs.contains(MI.getOperand(1).getReg())) {
         assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
         Register Dst = MI.getOperand(0).getReg();
-        SwitchRegs.insert(Dst);
+        CompareRegs.insert(Dst);
         SPIRVType *Ty = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
         insertAssignInstr(Dst, nullptr, Ty, GR, MIB, MRI);
       }
-      // Process only ICMPs that relate to spv_switches.
+
+      // G_ICMPs relating to switches.
       if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
-          SwitchRegs.contains(MI.getOperand(2).getReg())) {
-        assert(MI.getOperand(0).isReg() && MI.getOperand(1).isPredicate() &&
-               MI.getOperand(3).isReg());
+          CompareRegs.contains(MI.getOperand(2).getReg())) {
         Register Dst = MI.getOperand(0).getReg();
         // Set type info for destination register of switch's ICMP instruction.
         if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
@@ -445,60 +459,85 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           MRI.setRegClass(Dst, &SPIRV::IDRegClass);
           GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
         }
-        Register CmpReg = MI.getOperand(2).getReg();
-        MachineOperand &PredOp = MI.getOperand(1);
-        const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
-        assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
-               MRI.hasOneUse(Dst) && MRI.hasOneDef(CmpReg));
-        uint64_t Val = getIConstVal(MI.getOperand(3).getReg(), &MRI);
-        MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
-        assert(CBr->getOpcode() == SPIRV::G_BRCOND &&
-               CBr->getOperand(1).isMBB());
-        SwitchRegToMBB[CmpReg][Val] = CBr->getOperand(1).getMBB();
-        // The next MI is always BR to either the next case or the default.
-        MachineInstr *NextMI = CBr->getNextNode();
-        assert(NextMI->getOpcode() == SPIRV::G_BR &&
-               NextMI->getOperand(0).isMBB());
-        MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
-        assert(NextMBB != nullptr);
-        // The default MBB is not started by ICMP with switch's cmp register.
-        if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
-            (NextMBB->front().getOperand(2).isReg() &&
-             NextMBB->front().getOperand(2).getReg() != CmpReg))
-          DefaultMBBs[CmpReg] = NextMBB;
+        RelevantInsts.push_back(&MI);
       }
     }
   }
-  // Modify spv_switch's operands by collected values. For the example above,
-  // the result will be like this:
-  //   intrinsic(@llvm.spv.switch), %CmpReg, %bb.4, i32 0, %bb.2, i32 1, %bb.3
-  // Note that ICMP+CBr+Br sequences are not removed, but ModuleAnalysis marks
-  // them as skipped and AsmPrinter does not output them.
-  for (MachineBasicBlock &MBB : MF) {
-    for (MachineInstr &MI : MBB) {
-      if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
+
+  // Update each spv_switch with destination MBBs.
+  for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
+    if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
+      continue;
+
+    // Currently considered spv_switch.
+    MachineInstr *Switch = *i;
+    // Set the first successor as default MBB to support empty switches.
+    MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
+    // Maps each switch case value to its destination MBB.
+    SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;
+
+    // Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
+    // spv_switch (i) and break at any spv_switch with the same compare
+    // register (indicating we are back at the same scope).
+    Register CompareReg = Switch->getOperand(1).getReg();
+    for (auto j = i + 1; j != RelevantInsts.end(); j++) {
+      if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
+          (*j)->getOperand(1).getReg() == CompareReg)
+        break;
+
+      if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
+            (*j)->getOperand(2).getReg() == CompareReg))
         continue;
-      assert(MI.getOperand(1).isReg());
-      Register Reg = MI.getOperand(1).getReg();
-      unsigned NumOp = MI.getNumExplicitOperands();
-      SmallVector<const ConstantInt *, 3> Vals;
-      SmallVector<MachineBasicBlock *, 3> MBBs;
-      for (unsigned i = 2; i < NumOp; i++) {
-        Register CReg = MI.getOperand(i).getReg();
-        uint64_t Val = getIConstVal(CReg, &MRI);
-        MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
-        if (!SwitchRegToMBB[Reg][Val])
-          continue;
-        Vals.push_back(ConstInstr->getOperand(1).getCImm());
-        MBBs.push_back(SwitchRegToMBB[Reg][Val]);
-      }
-      for (unsigned i = MI.getNumExplicitOperands() - 1; i > 1; i--)
-        MI.removeOperand(i);
-      MI.addOperand(MachineOperand::CreateMBB(DefaultMBBs[Reg]));
-      for (unsigned i = 0; i < Vals.size(); i++) {
-        MI.addOperand(MachineOperand::CreateCImm(Vals[i]));
-        MI.addOperand(MachineOperand::CreateMBB(MBBs[i]));
-      }
+
+      MachineInstr *ICMP = *j;
+      Register Dst = ICMP->getOperand(0).getReg();
+      MachineOperand &PredOp = ICMP->getOperand(1);
+      const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
+      assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
+             MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
+      uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
+      MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
+      assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
+      MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();
+
+      // Map switch case Value to target MBB.
+      ValuesToMBBs[Value] = MBB;
+
+      // The next MI is always G_BR to either the next case or the default.
+      MachineInstr *NextMI = CBr->getNextNode();
+      assert(NextMI->getOpcode() == SPIRV::G_BR &&
+             NextMI->getOperand(0).isMBB());
+      MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
+      // Default MBB does not begin with G_ICMP using spv_switch compare
+      // register.
+      if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
+          (NextMBB->front().getOperand(2).isReg() &&
+           NextMBB->front().getOperand(2).getReg() != CompareReg))
+        DefaultMBB = NextMBB;
+    }
+
+    // Modify considered spv_switch operands using collected Values and
+    // MBBs.
+    SmallVector<const ConstantInt *, 3> Values;
+    SmallVector<MachineBasicBlock *, 3> MBBs;
+    for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
+      Register CReg = Switch->getOperand(k).getReg();
+      uint64_t Val = getIConstVal(CReg, &MRI);
+      MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
+      if (!ValuesToMBBs[Val])
+        continue;
+
+      Values.push_back(ConstInstr->getOperand(1).getCImm());
+      MBBs.push_back(ValuesToMBBs[Val]);
+    }
+
+    for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
+      Switch->removeOperand(k);
+
+    Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
+    for (unsigned k = 0; k < Values.size(); k++) {
+      Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
+      Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
     }
   }
 }

diff  --git a/llvm/test/CodeGen/SPIRV/transcoding/Two_OpSwitch_same_register.ll b/llvm/test/CodeGen/SPIRV/transcoding/Two_OpSwitch_same_register.ll
new file mode 100644
index 0000000000000..19c11ff64476b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/Two_OpSwitch_same_register.ll
@@ -0,0 +1,42 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+
+define spir_kernel void @test_two_switch_same_register(i32 %value) {
+; CHECK-SPIRV:      OpSwitch %[[#REGISTER:]] %[[#DEFAULT1:]] 1 %[[#CASE1:]] 0 %[[#CASE2:]]
+  switch i32 %value, label %default1 [
+    i32 1, label %case1
+    i32 0, label %case2
+  ]
+
+; CHECK-SPIRV:      %[[#CASE1]] = OpLabel
+case1:
+; CHECK-SPIRV-NEXT: OpBranch %[[#DEFAULT1]]
+  br label %default1
+
+; CHECK-SPIRV:      %[[#CASE2]] = OpLabel
+case2:
+; CHECK-SPIRV-NEXT: OpBranch %[[#DEFAULT1]]
+  br label %default1
+
+; CHECK-SPIRV:      %[[#DEFAULT1]] = OpLabel
+default1:
+; CHECK-SPIRV-NEXT:      OpSwitch %[[#REGISTER]] %[[#DEFAULT2:]] 0 %[[#CASE3:]] 1 %[[#CASE4:]]
+  switch i32 %value, label %default2 [
+    i32 0, label %case3
+    i32 1, label %case4
+  ]
+
+; CHECK-SPIRV:      %[[#CASE3]] = OpLabel
+case3:
+; CHECK-SPIRV-NEXT: OpBranch %[[#DEFAULT2]]
+  br label %default2
+
+; CHECK-SPIRV:      %[[#CASE4]] = OpLabel
+case4:
+; CHECK-SPIRV-NEXT: OpBranch %[[#DEFAULT2]]
+  br label %default2
+
+; CHECK-SPIRV:      %[[#DEFAULT2]] = OpLabel
+default2:
+; CHECK-SPIRV-NEXT: OpReturn
+  ret void
+}


        


More information about the llvm-commits mailing list