[clang] [llvm] [SPIR-V] Add SPIR-V structurizer (PR #107408)

via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 5 07:39:26 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Nathan Gauër (Keenuts)

<details>
<summary>Changes</summary>

This commit adds an initial SPIR-V structurizer.
It leverages the previously merged passes and the convergence region analysis to determine the correct merge and continue blocks for SPIR-V.
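To make "merge and continue blocks" concrete, here is a textbook sketch (not the pass's code) of how a loop in simplified form maps onto them. It collects the natural loop of a back edge and then looks at the edges leaving it. For such a loop the latch is the natural candidate for the continue target, and a single exit target is the candidate merge block. The `CFG` map and both helpers are hypothetical stand-ins for LLVM's `BasicBlock`/`LoopInfo` machinery.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy CFG: block name -> successor names.
using CFG = std::map<std::string, std::vector<std::string>>;

// Predecessor lists derived from the successor lists.
std::map<std::string, std::vector<std::string>> predecessors(const CFG &G) {
  std::map<std::string, std::vector<std::string>> Preds;
  for (const auto &[From, Succs] : G)
    for (const auto &To : Succs)
      Preds[To].push_back(From);
  return Preds;
}

// Natural loop of the back edge Latch -> Header: the header plus every block
// that reaches the latch without passing through the header.
std::set<std::string> naturalLoop(const CFG &G, const std::string &Header,
                                  const std::string &Latch) {
  auto Preds = predecessors(G);
  std::set<std::string> Body = {Header};
  std::vector<std::string> Work = {Latch};
  while (!Work.empty()) {
    std::string B = Work.back();
    Work.pop_back();
    if (!Body.insert(B).second)
      continue; // already in the loop (this also stops the walk at the header)
    for (const auto &P : Preds[B])
      Work.push_back(P);
  }
  return Body;
}

// Exit targets: successors of loop blocks that lie outside the loop.
std::set<std::string> exitTargets(const CFG &G,
                                  const std::set<std::string> &Body) {
  std::set<std::string> Exits;
  for (const auto &B : Body) {
    auto It = G.find(B);
    if (It == G.end())
      continue;
    for (const auto &S : It->second)
      if (!Body.count(S))
        Exits.insert(S);
  }
  return Exits;
}
```

With the `for.cond`/`for.body`/`for.inc` shape checked by the `cf.for.plain.hlsl` test, the loop body is those three blocks and `for.end` is the only exit target.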

The first part does a branch cleanup (simplifying and legalizing switches); then merge instructions are added to cycles, to convergent blocks, and later to divergent blocks.
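As a rough illustration of which blocks receive which merge instruction, the sketch below finds loop headers as targets of DFS back edges (candidates for `OpLoopMerge`) and treats every other block with two or more distinct successors as a selection header (a candidate for `OpSelectionMerge`). This is a simplification under stated assumptions: the actual pass works on convergence regions, and a plain DFS back-edge test does not handle irreducible flow. Like the updated `isBackEdge` in the diff, it treats a self-loop as a back edge. All names are hypothetical.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy CFG: block name -> successor names.
using CFG = std::map<std::string, std::vector<std::string>>;

enum class MergeKind { None, Loop, Selection };

// DFS recording back edges: an edge whose target is still on the DFS stack.
// A self-loop (From == To) is therefore reported as a back edge as well.
void findLoopHeaders(const CFG &G, const std::string &B,
                     std::set<std::string> &Visited,
                     std::set<std::string> &OnStack,
                     std::set<std::string> &LoopHeaders) {
  Visited.insert(B);
  OnStack.insert(B);
  if (auto It = G.find(B); It != G.end()) {
    for (const auto &S : It->second) {
      if (OnStack.count(S))
        LoopHeaders.insert(S); // back edge B -> S
      else if (!Visited.count(S))
        findLoopHeaders(G, S, Visited, OnStack, LoopHeaders);
    }
  }
  OnStack.erase(B);
}

// Classifies every reachable block by the merge instruction it would need.
std::map<std::string, MergeKind> classifyHeaders(const CFG &G,
                                                 const std::string &Entry) {
  std::set<std::string> Visited, OnStack, LoopHeaders;
  findLoopHeaders(G, Entry, Visited, OnStack, LoopHeaders);

  std::map<std::string, MergeKind> Kinds;
  for (const auto &B : Visited) {
    std::set<std::string> Distinct;
    if (auto It = G.find(B); It != G.end())
      Distinct.insert(It->second.begin(), It->second.end());
    if (LoopHeaders.count(B))
      Kinds[B] = MergeKind::Loop; // candidate for OpLoopMerge
    else if (Distinct.size() > 1)
      Kinds[B] = MergeKind::Selection; // candidate for OpSelectionMerge
    else
      Kinds[B] = MergeKind::None;
  }
  return Kinds;
}
```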
Then comes the important part: splitting critical edges and making sure the boundaries of divergent constructs don't cross.
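A critical edge is one whose source has several successors and whose destination has several predecessors; splitting it gives that edge a block of its own, where new branches or merge points can be introduced without affecting other paths. A minimal sketch on a hypothetical name-keyed CFG follows. On real IR, LLVM's `SplitCriticalEdge` helpers in `llvm/Transforms/Utils/BasicBlockUtils.h` do this job, but whether this PR uses them is not visible in the truncated diff.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy CFG: block name -> successor names.
using CFG = std::map<std::string, std::vector<std::string>>;

// Splits every critical edge From -> To (From has 2+ successors, To has 2+
// predecessors) by routing it through a fresh block that only branches to To.
// Returns the number of edges redirected.
int splitCriticalEdges(CFG &G) {
  std::map<std::string, int> NumPreds;
  for (const auto &[From, Succs] : G)
    for (const auto &To : Succs)
      ++NumPreds[To];

  int NumSplit = 0;
  std::vector<std::pair<std::string, std::string>> NewBlocks; // (Mid, To)
  for (auto &[From, Succs] : G) {
    if (Succs.size() < 2)
      continue;
    for (auto &To : Succs) {
      if (NumPreds[To] < 2)
        continue;
      std::string Mid = From + "." + To + ".split"; // hypothetical naming
      NewBlocks.emplace_back(Mid, To);
      To = Mid; // redirect the edge; the target block itself is untouched
      ++NumSplit;
    }
  }
  // New blocks are added after the walk so the map is not modified while it
  // is being iterated.
  for (const auto &[Mid, To] : NewBlocks)
    G[Mid] = {To};
  return NumSplit;
}
```

Running it a second time on the result finds nothing left to split, since every new block has exactly one predecessor and one successor.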

- we split blocks with multiple headers into 2 blocks.
- we split blocks that are merge blocks for 2 or more constructs: the SPIR-V spec disallows a merge block being shared by 2 loop/switch/condition constructs.
- we split merge & continue blocks: the SPIR-V spec disallows a basic block from being both a continue block and a merge block.
- we remove superfluous headers: when a header doesn't bring more information than its parent about the divergence state, it must be removed.
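The two spec rules on merge and continue blocks in the list above can be phrased as simple checks over a summary of the constructs. The `Construct` record below is hypothetical (the pass derives this information from the CFG and the convergence regions, not from such a table); the sketch only shows what a "shared merge block" and a "merge & continue block" look like before any splitting happens.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical summary of one structured construct.
struct Construct {
  std::string Header;
  std::string Merge;
  std::optional<std::string> Continue; // set for loops only
};

// Merge blocks declared by more than one construct (to be split).
std::set<std::string> sharedMergeBlocks(const std::vector<Construct> &Cs) {
  std::map<std::string, int> Uses;
  for (const auto &C : Cs)
    ++Uses[C.Merge];
  std::set<std::string> Shared;
  for (const auto &[Block, N] : Uses)
    if (N > 1)
      Shared.insert(Block);
  return Shared;
}

// Blocks used both as a merge block and as a continue target (to be split).
std::set<std::string> mergeAndContinueBlocks(const std::vector<Construct> &Cs) {
  std::set<std::string> Merges, Continues, Both;
  for (const auto &C : Cs) {
    Merges.insert(C.Merge);
    if (C.Continue)
      Continues.insert(*C.Continue);
  }
  for (const auto &B : Continues)
    if (Merges.count(B))
      Both.insert(B);
  return Both;
}
```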

This PR leverages the merged SPIR-V simulator for testing, as well as spirv-val. For now, most DXC structurization tests pass. The unsupported ones are caused either by unsupported features like switches on boolean types, or by switches in region exits, because the MergeExit pass doesn't support those yet (there is a FIXME).
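For context on the region-exit limitation: the MergeExit pass (`SPIRVMergeRegionExitTargets.cpp` in the diff) gathers the exit targets of a convergence region and redirects the exits to one new exit block. The toy version below models only that redirection on a hypothetical name-keyed CFG. The returned table stands in for the dispatch the new block would perform (assumed here to select among the original targets), and the toy does not model how that selector value is computed in IR.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy CFG: block name -> successor names.
using CFG = std::map<std::string, std::vector<std::string>>;

// Redirects every edge leaving Region to NewExit when the region has two or
// more distinct exit targets. Returns the targets NewExit must dispatch to
// (index = case value), or an empty vector if nothing was changed.
std::vector<std::string> mergeExitTargets(CFG &G,
                                          const std::set<std::string> &Region,
                                          const std::string &NewExit) {
  // Distinct exit targets, in first-seen order (Region is walked sorted).
  std::vector<std::string> Targets;
  for (const auto &B : Region)
    for (const auto &S : G[B])
      if (!Region.count(S) &&
          std::find(Targets.begin(), Targets.end(), S) == Targets.end())
        Targets.push_back(S);
  if (Targets.size() < 2)
    return {};

  // Every exiting edge now goes to the single new exit block.
  for (const auto &B : Region)
    for (auto &S : G[B])
      if (!Region.count(S))
        S = NewExit;

  // In IR, NewExit would pick the original target (e.g. with a switch);
  // here it simply lists the targets as its successors.
  G[NewExit] = Targets;
  return Targets;
}
```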

This PR is quite large, and the addition is not trivial, so I tried to keep it simple. E.g., as soon as the CFG changes, I recompute the dominator trees and other structures instead of updating them.
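The "recompute instead of update" trade-off can be illustrated with the classic iterative dominator formulation, rerun from scratch after every CFG edit. This is a sketch with hypothetical names (`computeDominators`, `runToFixpoint`), not LLVM's `DominatorTree`, which uses a much faster algorithm and also supports incremental updates.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy CFG: block name -> successor names.
using CFG = std::map<std::string, std::vector<std::string>>;
using DomSets = std::map<std::string, std::set<std::string>>;

// Iterative data-flow: Dom(entry) = {entry},
// Dom(n) = {n} U intersection of Dom(p) over predecessors p.
// Only blocks reachable from Entry are considered.
DomSets computeDominators(const CFG &G, const std::string &Entry) {
  std::set<std::string> Reach;
  std::vector<std::string> Work = {Entry};
  while (!Work.empty()) {
    std::string B = Work.back();
    Work.pop_back();
    if (!Reach.insert(B).second)
      continue;
    if (auto It = G.find(B); It != G.end())
      for (const auto &S : It->second)
        Work.push_back(S);
  }
  std::map<std::string, std::vector<std::string>> Preds;
  for (const auto &B : Reach)
    if (auto It = G.find(B); It != G.end())
      for (const auto &S : It->second)
        Preds[S].push_back(B);

  DomSets Dom;
  for (const auto &B : Reach)
    Dom[B] = Reach; // start from "dominated by everything"
  Dom[Entry] = {Entry};
  bool Changed = true;
  while (Changed) {
    Changed = false;
    for (const auto &B : Reach) {
      if (B == Entry)
        continue;
      std::set<std::string> New = Reach;
      for (const auto &P : Preds[B]) {
        std::set<std::string> Kept;
        for (const auto &D : New)
          if (Dom[P].count(D))
            Kept.insert(D);
        New = std::move(Kept);
      }
      New.insert(B);
      if (New != Dom[B]) {
        Dom[B] = std::move(New);
        Changed = true;
      }
    }
  }
  return Dom;
}

// Applies Transform until it reports no change, rebuilding the dominator
// sets from scratch before every call. Returns the number of changing rounds.
template <typename Transform>
int runToFixpoint(CFG &G, const std::string &Entry, Transform T) {
  int Rounds = 0;
  while (T(G, computeDominators(G, Entry)))
    ++Rounds;
  return Rounds;
}
```

Each round pays the full analysis cost, which is the inefficiency accepted here in exchange for simpler code.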

---

Patch is 195.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107408.diff


53 Files Affected:

- (added) clang/test/CodeGenHLSL/convergence/cf.for.plain.hlsl (+44) 
- (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+2) 
- (modified) llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp (+3-2) 
- (modified) llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h (+3) 
- (modified) llvm/lib/Target/SPIRV/CMakeLists.txt (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRV.h (+1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+13) 
- (modified) llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp (+12-8) 
- (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+115-57) 
- (added) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+1410) 
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+5-5) 
- (modified) llvm/test/CMakeLists.txt (+1) 
- (removed) llvm/test/CodeGen/SPIRV/scfg-add-pre-headers.ll (-87) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.cond-op.ll (+168) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.do.break.ll (+169) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.do.continue.ll (+167) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.do.nested.ll (+138) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.for.break.ll (+178) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.for.continue.hlsl (+47) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.for.nested.hlsl (+25) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.for.plain.ll (+105) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.for.short-circuited-cond.hlsl (+42) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.if.const-cond.hlsl (+33) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.if.for.hlsl (+46) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.hlsl (+29) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.if.plain.hlsl (+39) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.logical-and.hlsl (+27) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.logical-or.hlsl (+24) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.return.early.hlsl (+58) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.return.early.simple.hlsl (+20) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.return.void.hlsl (+14) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.hlsl (+122) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple.hlsl (+25) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple2.hlsl (+45) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.opswitch.hlsl (+360) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.opswitch.literal.hlsl (+36) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.opswitch.simple.hlsl (+36) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.while.break.hlsl (+87) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.while.continue.hlsl (+89) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.while.nested.hlsl (+79) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.while.plain.hlsl (+101) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/cf.while.short-circuited-cond.hlsl (+20) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/condition-linear.ll (+128) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/do-break.ll (+89) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/do-continue.ll (+124) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/do-nested.ll (+102) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/do-plain.ll (+98) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/logical-or.ll (+93) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll (+15-21) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll (+23-33) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll (+28-36) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-simple-while-identity.ll (+4-3) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/return-early.ll (+131) 


``````````diff
diff --git a/clang/test/CodeGenHLSL/convergence/cf.for.plain.hlsl b/clang/test/CodeGenHLSL/convergence/cf.for.plain.hlsl
new file mode 100644
index 00000000000000..2f08854f84d955
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/cf.for.plain.hlsl
@@ -0,0 +1,44 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+int process() {
+// CHECK: entry:
+// CHECK:   %[[#entry_token:]] = call token @llvm.experimental.convergence.entry()
+  int val = 0;
+
+// CHECK: for.cond:
+// CHECK-NEXT:   %[[#]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[#entry_token]]) ]
+// CHECK: br i1 {{.*}}, label %for.body, label %for.end
+  for (int i = 0; i < 10; ++i) {
+
+// CHECK: for.body:
+// CHECK:   br label %for.inc
+    val = i;
+
+// CHECK: for.inc:
+// CHECK:   br label %for.cond
+  }
+
+// CHECK: for.end:
+// CHECK:   br label %for.cond1
+
+  // Infinite loop
+  for ( ; ; ) {
+// CHECK: for.cond1:
+// CHECK-NEXT:   %[[#]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[#entry_token]]) ]
+// CHECK:   br label %for.cond1
+    val = 0;
+  }
+
+// CHECK-NEXT: }
+// This loop in unreachable. Not generated.
+  // Null body
+  for (int j = 0; j < 10; ++j)
+  ;
+  return val;
+}
+
+[numthreads(1, 1, 1)]
+void main() {
+  process();
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index cbf6e04f2844d6..21b3d2ec8b9649 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -31,6 +31,8 @@ let TargetPrefix = "spv" in {
   def int_spv_bitcast : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
   def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
   def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
+  def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
+  def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
   def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
   def int_spv_unreachable : Intrinsic<[], []>;
   def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
index 25e285e35f9336..cc6daf7ef34426 100644
--- a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.cpp
@@ -203,7 +203,8 @@ class ConvergenceRegionAnalyzer {
 
 private:
   bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
-    assert(From != To && "From == To. This is awkward.");
+    if (From == To)
+      return true;
 
     // We only handle loop in the simplified form. This means:
     // - a single back-edge, a single latch.
@@ -230,6 +231,7 @@ class ConvergenceRegionAnalyzer {
     auto *Terminator = From->getTerminator();
     for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
       auto *To = Terminator->getSuccessor(i);
+      // Ignore back edges.
       if (isBackEdge(From, To))
         continue;
 
@@ -276,7 +278,6 @@ class ConvergenceRegionAnalyzer {
     while (ToProcess.size() != 0) {
       auto *L = ToProcess.front();
       ToProcess.pop();
-      assert(L->isLoopSimplifyForm());
 
       auto CT = getConvergenceToken(L->getHeader());
       SmallPtrSet<BasicBlock *, 8> RegionBlocks(L->block_begin(),
diff --git a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
index f9e30e4effa1d9..e435c88c919c9c 100644
--- a/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
+++ b/llvm/lib/Target/SPIRV/Analysis/SPIRVConvergenceRegionAnalysis.h
@@ -130,6 +130,9 @@ class ConvergenceRegionInfo {
   }
 
   const ConvergenceRegion *getTopLevelRegion() const { return TopLevelRegion; }
+  ConvergenceRegion *getWritableTopLevelRegion() const {
+    return TopLevelRegion;
+  }
 };
 
 } // namespace SPIRV
diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 5f8aea5fc8d84d..198483e03a46d7 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -31,6 +31,7 @@ add_llvm_target(SPIRVCodeGen
   SPIRVMCInstLower.cpp
   SPIRVMetadata.cpp
   SPIRVModuleAnalysis.cpp
+  SPIRVStructurizer.cpp
   SPIRVPreLegalizer.cpp
   SPIRVPostLegalizer.cpp
   SPIRVPrepareFunctions.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h
index 6c35a467f53bef..384133e7b4bd18 100644
--- a/llvm/lib/Target/SPIRV/SPIRV.h
+++ b/llvm/lib/Target/SPIRV/SPIRV.h
@@ -20,6 +20,7 @@ class InstructionSelector;
 class RegisterBankInfo;
 
 ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
+FunctionPass *createSPIRVStructurizerPass();
 FunctionPass *createSPIRVMergeRegionExitTargetsPass();
 FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
 FunctionPass *createSPIRVRegularizerPass();
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index fed82b904af4f7..44f2da11edb051 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2296,6 +2296,19 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     }
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
+  case Intrinsic::spv_loop_merge:
+  case Intrinsic::spv_selection_merge: {
+    const auto Opcode = IID == Intrinsic::spv_selection_merge
+                            ? SPIRV::OpSelectionMerge
+                            : SPIRV::OpLoopMerge;
+    auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
+    for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
+      assert(I.getOperand(i).isMBB());
+      MIB.addMBB(I.getOperand(i).getMBB());
+    }
+    MIB.addImm(SPIRV::SelectionControl::None);
+    return MIB.constrainAllUses(TII, TRI, RBI);
+  }
   case Intrinsic::spv_cmpxchg:
     return selectAtomicCmpXchg(ResVReg, ResType, I);
   case Intrinsic::spv_unreachable:
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
index 0747dd1bbaf40a..9930d067173df7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -133,7 +133,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
   // Run the pass on the given convergence region, ignoring the sub-regions.
   // Returns true if the CFG changed, false otherwise.
   bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
-                                       const SPIRV::ConvergenceRegion *CR) {
+                                       SPIRV::ConvergenceRegion *CR) {
     // Gather all the exit targets for this region.
     SmallPtrSet<BasicBlock *, 4> ExitTargets;
     for (BasicBlock *Exit : CR->Exits) {
@@ -198,14 +198,19 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
     for (auto Exit : CR->Exits)
       replaceBranchTargets(Exit, ExitTargets, NewExitTarget);
 
+    CR = CR->Parent;
+    while (CR) {
+      CR->Blocks.insert(NewExitTarget);
+      CR = CR->Parent;
+    }
+
     return true;
   }
 
   /// Run the pass on the given convergence region and sub-regions (DFS).
   /// Returns true if a region/sub-region was modified, false otherwise.
   /// This returns as soon as one region/sub-region has been modified.
-  bool runOnConvergenceRegion(LoopInfo &LI,
-                              const SPIRV::ConvergenceRegion *CR) {
+  bool runOnConvergenceRegion(LoopInfo &LI, SPIRV::ConvergenceRegion *CR) {
     for (auto *Child : CR->Children)
       if (runOnConvergenceRegion(LI, Child))
         return true;
@@ -235,10 +240,10 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
 
   virtual bool runOnFunction(Function &F) override {
     LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-    const auto *TopLevelRegion =
+    auto *TopLevelRegion =
         getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
             .getRegionInfo()
-            .getTopLevelRegion();
+            .getWritableTopLevelRegion();
 
     // FIXME: very inefficient method: each time a region is modified, we bubble
     // back up, and recompute the whole convergence region tree. Once the
@@ -246,9 +251,6 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
     // to be efficient instead of simple.
     bool modified = false;
     while (runOnConvergenceRegion(LI, TopLevelRegion)) {
-      TopLevelRegion = getAnalysis<SPIRVConvergenceRegionAnalysisWrapperPass>()
-                           .getRegionInfo()
-                           .getTopLevelRegion();
       modified = true;
     }
 
@@ -262,6 +264,8 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
     AU.addRequired<SPIRVConvergenceRegionAnalysisWrapperPass>();
+
+    AU.addPreserved<SPIRVConvergenceRegionAnalysisWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }
 };
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index df1b75bc1cb9eb..1784f00be600dd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -744,79 +744,139 @@ static void insertSpirvDecorations(MachineFunction &MF, MachineIRBuilder MIB) {
     MI->eraseFromParent();
 }
 
-// Find basic blocks of the switch and replace registers in spv_switch() by its
-// MBB equivalent.
-static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
-                            MachineIRBuilder MIB) {
-  DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
-  SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>>
-      Switches;
+// LLVM allows the switches to use registers as cases, while SPIR-V required
+// those to be immediate values. This function replaces such operands with the
+// equivalent immediate constant.
+static void processSwitchesConstants(MachineFunction &MF,
+                                     SPIRVGlobalRegistry *GR,
+                                     MachineIRBuilder MIB) {
+  MachineRegisterInfo &MRI = MF.getRegInfo();
   for (MachineBasicBlock &MBB : MF) {
-    MachineRegisterInfo &MRI = MF.getRegInfo();
-    BB2MBB[MBB.getBasicBlock()] = &MBB;
     for (MachineInstr &MI : MBB) {
       if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
         continue;
-      // Calls to spv_switch intrinsics representing IR switches.
-      SmallVector<MachineInstr *, 8> NewOps;
-      for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
+
+      SmallVector<MachineOperand, 8> NewOperands;
+      NewOperands.push_back(MI.getOperand(0)); // Opcode
+      NewOperands.push_back(MI.getOperand(1)); // Condition
+      NewOperands.push_back(MI.getOperand(2)); // Default
+      for (unsigned i = 3; i < MI.getNumOperands(); i += 2) {
         Register Reg = MI.getOperand(i).getReg();
-        if (i % 2 == 1) {
-          MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
-          NewOps.push_back(ConstInstr);
-        } else {
-          MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
-          assert(BuildMBB &&
-                 BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
-                 BuildMBB->getOperand(1).isBlockAddress() &&
-                 BuildMBB->getOperand(1).getBlockAddress());
-          NewOps.push_back(BuildMBB);
-        }
+        MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
+        NewOperands.push_back(
+            MachineOperand::CreateCImm(ConstInstr->getOperand(1).getCImm()));
+
+        NewOperands.push_back(MI.getOperand(i + 1));
       }
-      Switches.push_back(std::make_pair(&MI, NewOps));
+
+      assert(MI.getNumOperands() == NewOperands.size());
+      while (MI.getNumOperands() > 0)
+        MI.removeOperand(0);
+      for (auto &MO : NewOperands)
+        MI.addOperand(MO);
     }
   }
+}
 
+// Some instructions are used during CodeGen but should never be emitted.
+// Cleaning up those.
+static void cleanupHelperInstructions(MachineFunction &MF) {
   SmallPtrSet<MachineInstr *, 8> ToEraseMI;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (isSpvIntrinsic(MI, Intrinsic::spv_track_constant) ||
+          MI.getOpcode() == TargetOpcode::G_BRINDIRECT)
+        ToEraseMI.insert(&MI);
+    }
+  }
+
+  for (MachineInstr *MI : ToEraseMI)
+    MI->eraseFromParent();
+}
+
+// Find all usages of G_BLOCK_ADDR in our intrinsics and replace those
+// operands/registers by the actual MBB it references.
+static void processBlockAddr(MachineFunction &MF, SPIRVGlobalRegistry *GR,
+                             MachineIRBuilder MIB) {
+  // Gather the reverse-mapping BB -> MBB.
+  DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
+  for (MachineBasicBlock &MBB : MF)
+    BB2MBB[MBB.getBasicBlock()] = &MBB;
+
+  // Gather instructions requiring patching. For now, only those can use
+  // G_BLOCK_ADDR.
+  SmallVector<MachineInstr *, 8> InstructionsToPatch;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (isSpvIntrinsic(MI, Intrinsic::spv_switch) ||
+          isSpvIntrinsic(MI, Intrinsic::spv_loop_merge) ||
+          isSpvIntrinsic(MI, Intrinsic::spv_selection_merge))
+        InstructionsToPatch.push_back(&MI);
+    }
+  }
+
+  // For each instruction to fix, we replace all the G_BLOCK_ADDR operands by
+  // the actual MBB it references. Once those references updated, we can cleanup
+  // remaining G_BLOCK_ADDR references.
   SmallPtrSet<MachineBasicBlock *, 8> ClearAddressTaken;
-  for (auto &SwIt : Switches) {
-    MachineInstr &MI = *SwIt.first;
-    MachineBasicBlock *MBB = MI.getParent();
-    SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
+  SmallPtrSet<MachineInstr *, 8> ToEraseMI;
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  for (MachineInstr *MI : InstructionsToPatch) {
     SmallVector<MachineOperand, 8> NewOps;
-    for (unsigned i = 0; i < Ins.size(); ++i) {
-      if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
-        BasicBlock *CaseBB =
-            Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock();
-        auto It = BB2MBB.find(CaseBB);
-        if (It == BB2MBB.end())
-          report_fatal_error("cannot find a machine basic block by a basic "
-                             "block in a switch statement");
-        MachineBasicBlock *Succ = It->second;
-        ClearAddressTaken.insert(Succ);
-        NewOps.push_back(MachineOperand::CreateMBB(Succ));
-        if (!llvm::is_contained(MBB->successors(), Succ))
-          MBB->addSuccessor(Succ);
-        ToEraseMI.insert(Ins[i]);
-      } else {
-        NewOps.push_back(
-            MachineOperand::CreateCImm(Ins[i]->getOperand(1).getCImm()));
+    for (unsigned i = 0; i < MI->getNumOperands(); ++i) {
+      // The operand is not a register, keep as-is.
+      if (!MI->getOperand(i).isReg()) {
+        NewOps.push_back(MI->getOperand(i));
+        continue;
+      }
+
+      Register Reg = MI->getOperand(i).getReg();
+      MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
+      // The register is not the result of G_BLOCK_ADDR, keep as-is.
+      if (!BuildMBB || BuildMBB->getOpcode() != TargetOpcode::G_BLOCK_ADDR) {
+        NewOps.push_back(MI->getOperand(i));
+        continue;
       }
+
+      assert(BuildMBB && BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
+             BuildMBB->getOperand(1).isBlockAddress() &&
+             BuildMBB->getOperand(1).getBlockAddress());
+      BasicBlock *BB =
+          BuildMBB->getOperand(1).getBlockAddress()->getBasicBlock();
+      auto It = BB2MBB.find(BB);
+      if (It == BB2MBB.end())
+        report_fatal_error("cannot find a machine basic block by a basic block "
+                           "in a switch statement");
+      MachineBasicBlock *ReferencedBlock = It->second;
+      NewOps.push_back(MachineOperand::CreateMBB(ReferencedBlock));
+
+      ClearAddressTaken.insert(ReferencedBlock);
+      ToEraseMI.insert(BuildMBB);
     }
-    for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
-      MI.removeOperand(i);
+
+    // Replace the operands.
+    assert(MI->getNumOperands() == NewOps.size());
+    while (MI->getNumOperands() > 0)
+      MI->removeOperand(0);
     for (auto &MO : NewOps)
-      MI.addOperand(MO);
-    if (MachineInstr *Next = MI.getNextNode()) {
+      MI->addOperand(MO);
+
+    if (MachineInstr *Next = MI->getNextNode()) {
       if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
         ToEraseMI.insert(Next);
-        Next = MI.getNextNode();
+        Next = MI->getNextNode();
       }
       if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
         ToEraseMI.insert(Next);
     }
   }
 
+  // BlockAddress operands were used to keep information between passes,
+  // let's undo the "address taken" status to reflect that Succ doesn't
+  // actually correspond to an IR-level basic block.
+  for (MachineBasicBlock *Succ : ClearAddressTaken)
+    Succ->setAddressTakenIRBlock(nullptr);
+
   // If we just delete G_BLOCK_ADDR instructions with BlockAddress operands,
   // this leaves their BasicBlock counterparts in a "address taken" status. This
   // would make AsmPrinter to generate a series of unneeded labels of a "Address
@@ -835,12 +895,6 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
     }
     BlockAddrI->eraseFromParent();
   }
-
-  // BlockAddress operands were used to keep information between passes,
-  // let's undo the "address taken" status to reflect that Succ doesn't
-  // actually correspond to an IR-level basic block.
-  for (MachineBasicBlock *Succ : ClearAddressTaken)
-    Succ->setAddressTakenIRBlock(nullptr);
 }
 
 static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
@@ -891,7 +945,11 @@ bool SPIRVPreLegalizer::runOnMachineFunction(MachineFunction &MF) {
   foldConstantsIntoIntrinsics(MF, TrackedConstRegs);
   insertBitcasts(MF, GR, MIB);
   generateAssignInstrs(MF, GR, MIB, TargetExtConstTypes);
-  processSwitches(MF, GR, MIB);
+
+  processSwitchesConstants(MF, GR, MIB);
+  processBlockAddr(MF, GR, MIB);
+  cleanupHelperInstructions(MF);
+
   processInstrsWithTypeFolding(MF, GR, MIB);
   removeImplicitFallthroughs(MF, MIB);
   insertSpirvDecorations(MF, MIB);
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
new file mode 100644
index 00000000000000..f663b7f427e235
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -0,0 +1,1410 @@
+//===-- SPIRVStructurizer.cpp ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//===----------------------------------------------------------------------===//
+
+#include "Analysis/SPIRVConvergenceRegionAnalysis.h"
+#include "SPIRV.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "SPIRVUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/CodeGen/IntrinsicLowering.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
+#include <queue>
+#include <stack>
+
+using namespace llvm;
+using namespace SPIRV;
+
+namespace llvm {
+
+void initializeSPIRVStructurizerPass(PassRegistry &);
+
+namespace {
+
+using BlockSet = std::unordered_set<BasicBlock *>;
+using Edge = std::pair<BasicBlock *, BasicBlock *>;
+
+// This class implements a partial ordering visitor, which visits a cyclic graph
+// in natural topological-like ordering. Topological or...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/107408

