[llvm] e461bdf - [SPIR-V] Fix switch lowering with common compare register
Michal Paszkowski via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 13 13:59:48 PST 2023
Author: Michal Paszkowski
Date: 2023-01-13T22:56:22+01:00
New Revision: e461bdf65b66cb002e0b455b1ec11044e90b1e7b
URL: https://github.com/llvm/llvm-project/commit/e461bdf65b66cb002e0b455b1ec11044e90b1e7b
DIFF: https://github.com/llvm/llvm-project/commit/e461bdf65b66cb002e0b455b1ec11044e90b1e7b.diff
LOG: [SPIR-V] Fix switch lowering with common compare register
Differential Revision: https://reviews.llvm.org/D141203
Added:
llvm/test/CodeGen/SPIRV/transcoding/Two_OpSwitch_same_register.ll
Modified:
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 7b7455a53325d..f91b6ea5cb141 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -198,6 +198,7 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
for (auto &Op : I.operands())
if (Op.get()->getType()->isSized())
Args.push_back(Op);
+ IRB->SetInsertPoint(&I);
IRB->CreateIntrinsic(Intrinsic::spv_switch, {I.getOperand(0)->getType()},
{Args});
return &I;
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 7c24d9557711f..27d0e8a976f0d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -389,14 +389,11 @@ static void processInstrsWithTypeFolding(MachineFunction &MF,
static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
- DenseMap<Register, SmallDenseMap<uint64_t, MachineBasicBlock *>>
- SwitchRegToMBB;
- DenseMap<Register, MachineBasicBlock *> DefaultMBBs;
- DenseSet<Register> SwitchRegs;
- MachineRegisterInfo &MRI = MF.getRegInfo();
- // Before IRTranslator pass, spv_switch calls are inserted before each
- // switch instruction. IRTranslator lowers switches to ICMP+CBr+Br triples.
- // A switch with two cases may be translated to this MIR sequesnce:
+ // Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
+ // each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
+ // + G_BR triples. A switch with two cases may be transformed to this MIR
+ // sequence:
+ //
// intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
// %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
// G_BRCOND %Dst0, %bb.2
@@ -411,31 +408,48 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
// ...
// bb.4.sw.epilog:
// ...
- // Walk MIs and collect information about destination MBBs to update
- // spv_switch call. We assume that all spv_switch precede corresponding ICMPs.
+ //
+ // Sometimes (in case of range-compare switches), additional G_SUB
+ // instructions are inserted before G_ICMPs. Those need to be additionally
+ // processed and require type assignment.
+ //
+ // This function modifies spv_switch call's operands to include destination
+ // MBBs (default and for each constant value).
+ // Note that this function does not remove G_ICMP + G_BRCOND + G_BR sequences,
+ // but they are marked by ModuleAnalysis as skipped and as a result AsmPrinter
+ // does not output them.
+
+ MachineRegisterInfo &MRI = MF.getRegInfo();
+
+ // Collect all MIs relevant to switches across all MBBs in MF.
+ std::vector<MachineInstr *> RelevantInsts;
+
+ // Temporary set of compare registers. G_SUBs and G_ICMPs relating to
+ // spv_switch use these registers.
+ DenseSet<Register> CompareRegs;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
+ // Calls to spv_switch intrinsics representing IR switches.
if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
assert(MI.getOperand(1).isReg());
- Register Reg = MI.getOperand(1).getReg();
- SwitchRegs.insert(Reg);
- // Set the first successor as default MBB to support empty switches.
- DefaultMBBs[Reg] = *MBB.succ_begin();
+ CompareRegs.insert(MI.getOperand(1).getReg());
+ RelevantInsts.push_back(&MI);
}
- // Process G_SUB coming from switch range-compare lowering.
+
+ // G_SUBs coming from range-compare switch lowering. G_SUBs are found
+ // after spv_switch but before G_ICMP.
if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
- SwitchRegs.contains(MI.getOperand(1).getReg())) {
+ CompareRegs.contains(MI.getOperand(1).getReg())) {
assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
Register Dst = MI.getOperand(0).getReg();
- SwitchRegs.insert(Dst);
+ CompareRegs.insert(Dst);
SPIRVType *Ty = GR->getSPIRVTypeForVReg(MI.getOperand(1).getReg());
insertAssignInstr(Dst, nullptr, Ty, GR, MIB, MRI);
}
- // Process only ICMPs that relate to spv_switches.
+
+ // G_ICMPs relating to switches.
if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
- SwitchRegs.contains(MI.getOperand(2).getReg())) {
- assert(MI.getOperand(0).isReg() && MI.getOperand(1).isPredicate() &&
- MI.getOperand(3).isReg());
+ CompareRegs.contains(MI.getOperand(2).getReg())) {
Register Dst = MI.getOperand(0).getReg();
// Set type info for destination register of switch's ICMP instruction.
if (GR->getSPIRVTypeForVReg(Dst) == nullptr) {
@@ -445,60 +459,85 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MRI.setRegClass(Dst, &SPIRV::IDRegClass);
GR->assignSPIRVTypeToVReg(SpirvTy, Dst, MIB.getMF());
}
- Register CmpReg = MI.getOperand(2).getReg();
- MachineOperand &PredOp = MI.getOperand(1);
- const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
- assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
- MRI.hasOneUse(Dst) && MRI.hasOneDef(CmpReg));
- uint64_t Val = getIConstVal(MI.getOperand(3).getReg(), &MRI);
- MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
- assert(CBr->getOpcode() == SPIRV::G_BRCOND &&
- CBr->getOperand(1).isMBB());
- SwitchRegToMBB[CmpReg][Val] = CBr->getOperand(1).getMBB();
- // The next MI is always BR to either the next case or the default.
- MachineInstr *NextMI = CBr->getNextNode();
- assert(NextMI->getOpcode() == SPIRV::G_BR &&
- NextMI->getOperand(0).isMBB());
- MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
- assert(NextMBB != nullptr);
- // The default MBB is not started by ICMP with switch's cmp register.
- if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
- (NextMBB->front().getOperand(2).isReg() &&
- NextMBB->front().getOperand(2).getReg() != CmpReg))
- DefaultMBBs[CmpReg] = NextMBB;
+ RelevantInsts.push_back(&MI);
}
}
}
- // Modify spv_switch's operands by collected values. For the example above,
- // the result will be like this:
- // intrinsic(@llvm.spv.switch), %CmpReg, %bb.4, i32 0, %bb.2, i32 1, %bb.3
- // Note that ICMP+CBr+Br sequences are not removed, but ModuleAnalysis marks
- // them as skipped and AsmPrinter does not output them.
- for (MachineBasicBlock &MBB : MF) {
- for (MachineInstr &MI : MBB) {
- if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
+
+ // Update each spv_switch with destination MBBs.
+ for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
+ if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
+ continue;
+
+ // Currently considered spv_switch.
+ MachineInstr *Switch = *i;
+ // Set the first successor as default MBB to support empty switches.
+ MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
+ // Container for mapping values to MBBs.
+ SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;
+
+ // Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
+ // spv_switch (i) and break at any spv_switch with the same compare
+ // register (indicating we are back at the same scope).
+ Register CompareReg = Switch->getOperand(1).getReg();
+ for (auto j = i + 1; j != RelevantInsts.end(); j++) {
+ if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
+ (*j)->getOperand(1).getReg() == CompareReg)
+ break;
+
+ if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
+ (*j)->getOperand(2).getReg() == CompareReg))
continue;
- assert(MI.getOperand(1).isReg());
- Register Reg = MI.getOperand(1).getReg();
- unsigned NumOp = MI.getNumExplicitOperands();
- SmallVector<const ConstantInt *, 3> Vals;
- SmallVector<MachineBasicBlock *, 3> MBBs;
- for (unsigned i = 2; i < NumOp; i++) {
- Register CReg = MI.getOperand(i).getReg();
- uint64_t Val = getIConstVal(CReg, &MRI);
- MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
- if (!SwitchRegToMBB[Reg][Val])
- continue;
- Vals.push_back(ConstInstr->getOperand(1).getCImm());
- MBBs.push_back(SwitchRegToMBB[Reg][Val]);
- }
- for (unsigned i = MI.getNumExplicitOperands() - 1; i > 1; i--)
- MI.removeOperand(i);
- MI.addOperand(MachineOperand::CreateMBB(DefaultMBBs[Reg]));
- for (unsigned i = 0; i < Vals.size(); i++) {
- MI.addOperand(MachineOperand::CreateCImm(Vals[i]));
- MI.addOperand(MachineOperand::CreateMBB(MBBs[i]));
- }
+
+ MachineInstr *ICMP = *j;
+ Register Dst = ICMP->getOperand(0).getReg();
+ MachineOperand &PredOp = ICMP->getOperand(1);
+ const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
+ assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
+ MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
+ uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
+ MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
+ assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
+ MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();
+
+ // Map switch case Value to target MBB.
+ ValuesToMBBs[Value] = MBB;
+
+ // The next MI is always G_BR to either the next case or the default.
+ MachineInstr *NextMI = CBr->getNextNode();
+ assert(NextMI->getOpcode() == SPIRV::G_BR &&
+ NextMI->getOperand(0).isMBB());
+ MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
+ // Default MBB does not begin with G_ICMP using spv_switch compare
+ // register.
+ if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
+ (NextMBB->front().getOperand(2).isReg() &&
+ NextMBB->front().getOperand(2).getReg() != CompareReg))
+ DefaultMBB = NextMBB;
+ }
+
+ // Modify considered spv_switch operands using collected Values and
+ // MBBs.
+ SmallVector<const ConstantInt *, 3> Values;
+ SmallVector<MachineBasicBlock *, 3> MBBs;
+ for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
+ Register CReg = Switch->getOperand(k).getReg();
+ uint64_t Val = getIConstVal(CReg, &MRI);
+ MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
+ if (!ValuesToMBBs[Val])
+ continue;
+
+ Values.push_back(ConstInstr->getOperand(1).getCImm());
+ MBBs.push_back(ValuesToMBBs[Val]);
+ }
+
+ for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
+ Switch->removeOperand(k);
+
+ Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
+ for (unsigned k = 0; k < Values.size(); k++) {
+ Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
+ Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
}
}
}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/Two_OpSwitch_same_register.ll b/llvm/test/CodeGen/SPIRV/transcoding/Two_OpSwitch_same_register.ll
new file mode 100644
index 0000000000000..19c11ff64476b
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/Two_OpSwitch_same_register.ll
@@ -0,0 +1,42 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+
+define spir_kernel void @test_two_switch_same_register(i32 %value) {
+; CHECK-SPIRV: OpSwitch %[[#REGISTER:]] %[[#DEFAULT1:]] 1 %[[#CASE1:]] 0 %[[#CASE2:]]
+ switch i32 %value, label %default1 [
+ i32 1, label %case1
+ i32 0, label %case2
+ ]
+
+; CHECK-SPIRV: %[[#CASE1]] = OpLabel
+case1:
+; CHECK-SPIRV-NEXT: OpBranch %[[#DEFAULT1]]
+ br label %default1
+
+; CHECK-SPIRV: %[[#CASE2]] = OpLabel
+case2:
+; CHECK-SPIRV-NEXT: OpBranch %[[#DEFAULT1]]
+ br label %default1
+
+; CHECK-SPIRV: %[[#DEFAULT1]] = OpLabel
+default1:
+; CHECK-SPIRV-NEXT: OpSwitch %[[#REGISTER]] %[[#DEFAULT2:]] 0 %[[#CASE3:]] 1 %[[#CASE4:]]
+ switch i32 %value, label %default2 [
+ i32 0, label %case3
+ i32 1, label %case4
+ ]
+
+; CHECK-SPIRV: %[[#CASE3]] = OpLabel
+case3:
+; CHECK-SPIRV-NEXT: OpBranch %[[#DEFAULT2]]
+ br label %default2
+
+; CHECK-SPIRV: %[[#CASE4]] = OpLabel
+case4:
+; CHECK-SPIRV-NEXT: OpBranch %[[#DEFAULT2]]
+ br label %default2
+
+; CHECK-SPIRV: %[[#DEFAULT2]] = OpLabel
+default2:
+; CHECK-SPIRV-NEXT: OpReturn
+ ret void
+}
More information about the llvm-commits
mailing list