[llvm] [SPIR-V] Fix BB ordering & register lifetime (PR #111026)

via llvm-commits (llvm-commits at lists.llvm.org)
Mon Oct 21 06:16:19 PDT 2024


Nathan Gauër <brioche at google.com>
In-Reply-To: <llvm.org/llvm/llvm-project/pull/111026 at github.com>


llvmbot wrote:



@llvm/pr-subscribers-backend-spir-v

Author: Nathan Gauër (Keenuts)

<details>
<summary>Changes</summary>

The "topological" sorting was behaving incorrectly in some cases:
the exit of a loop could have a lower rank than a node inside the loop.
This caused issues when structurizing some patterns, as well as codegen
issues, since basic blocks could be emitted in an order that the SPIR-V
specification does not allow.
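The ranking rule can be illustrated with a minimal, standalone C++ sketch of the worklist approach that the diff below introduces in `PartialOrderingVisitor` (`CanBeVisited` and `GetNodeRank`). It is not the LLVM code: `Cfg`, `LoopDesc` and `RankVisitor` are hypothetical names, back edges are listed explicitly instead of being found with a dominator tree, and each loop is assumed to have a single latch (the form LoopSimplify produces). A block is ranked only once its forward predecessors are ranked, and a predecessor inside a loop that the block does not belong to contributes the rank of that loop's latch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <queue>
#include <set>
#include <utility>
#include <vector>

// Toy CFG model (hypothetical, not LLVM's API). Back edges and loops are
// given explicitly instead of being computed with a DominatorTree/LoopInfo.
struct LoopDesc {
  std::set<int> Blocks; // all blocks of the loop, header included
  int Latch;            // assumed single latch, as in LoopSimplify form
};

struct Cfg {
  std::map<int, std::vector<int>> Succs;
  std::map<int, std::vector<int>> Preds;
  std::set<std::pair<int, int>> BackEdges; // (latch, header)
  std::vector<LoopDesc> Loops;

  void addEdge(int From, int To) {
    Succs[From].push_back(To);
    Preds[To].push_back(From);
  }

  // Innermost (smallest) loop containing BB, or nullptr.
  const LoopDesc *loopFor(int BB) const {
    const LoopDesc *Best = nullptr;
    for (const LoopDesc &L : Loops)
      if (L.Blocks.count(BB) != 0 &&
          (Best == nullptr || L.Blocks.size() < Best->Blocks.size()))
        Best = &L;
    return Best;
  }
};

class RankVisitor {
  const Cfg &G;
  std::map<int, std::size_t> Rank;

  bool isBackEdge(int P, int BB) const {
    return G.BackEdges.count({P, BB}) != 0;
  }

  // Ready once every forward predecessor is ranked and, for a predecessor
  // inside a loop that BB does not belong to, that loop's latch as well.
  bool canBeVisited(int BB) const {
    auto It = G.Preds.find(BB);
    if (It == G.Preds.end())
      return true;
    for (int P : It->second) {
      if (isBackEdge(P, BB))
        continue;
      if (Rank.count(P) == 0)
        return false;
      const LoopDesc *L = G.loopFor(P);
      if (L != nullptr && L->Blocks.count(BB) == 0 && Rank.count(L->Latch) == 0)
        return false;
    }
    return true;
  }

  // 1 + the highest rank among forward predecessors. Leaving a loop uses the
  // latch's rank, so an exit is ranked after every block of that loop.
  std::size_t nodeRank(int BB) const {
    std::size_t Result = 0;
    auto It = G.Preds.find(BB);
    if (It == G.Preds.end())
      return Result;
    for (int P : It->second) {
      if (isBackEdge(P, BB))
        continue;
      const LoopDesc *L = G.loopFor(P);
      int Base = (L != nullptr && L->Blocks.count(BB) == 0) ? L->Latch : P;
      Result = std::max(Result, Rank.at(Base) + 1);
    }
    return Result;
  }

public:
  explicit RankVisitor(const Cfg &Graph) : G(Graph) {}

  // Breadth-first worklist: a block that is not ready goes back in the queue.
  // Assumes every block is reachable from Entry, otherwise this never ends.
  std::map<int, std::size_t> run(int Entry) {
    std::queue<int> ToVisit;
    std::set<int> Queued{Entry};
    ToVisit.push(Entry);
    while (!ToVisit.empty()) {
      int BB = ToVisit.front();
      ToVisit.pop();
      if (!canBeVisited(BB)) {
        ToVisit.push(BB);
        continue;
      }
      Rank[BB] = nodeRank(BB);
      auto It = G.Succs.find(BB);
      if (It == G.Succs.end())
        continue;
      for (int S : It->second)
        if (Queued.insert(S).second)
          ToVisit.push(S);
    }
    return Rank;
  }
};
```

On a CFG with entry 0, loop header 1, body 2, latch 3 (back edge from 3 to 1) and an exit 4 reached from the header, block 4 is requeued until the latch has a rank, and the ranks come out as 0, 1, 2, 3, 4: the exit lands after the latch rather than alongside the loop body.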

Fixing this ordering alone broke other parts of the structurizer that
had only worked by luck, so those are fixed as well.

Added more test cases, in particular for basic patterns.

I also needed to tweak or disable some tests, for two reasons:
 - The SPIR-V backend now requires reg2mem/mem2reg to run, which means
   dead stores are optimized away. Some tests needed tweaks to keep the
   whole function from being removed.
 - Reg2Mem generates variables and loads/stores, which produce
   G_BITCAST instructions in several cases. The backend currently
   mishandles G_BITCAST in a way that makes the MIR verifier complain.
   Until this is resolved, I disabled the -verify-machineinstrs flag on
   those tests.

---

Patch is 141.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111026.diff


97 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp (+15-7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+51-50) 
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+13-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+77-32) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+7-1) 
- (modified) llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll (+5-5) 
- (modified) llvm/test/CodeGen/SPIRV/OpVariable_order.ll (+2-2) 
- (modified) llvm/test/CodeGen/SPIRV/ShaderBufferImage.ll (+1-1) 
- (modified) llvm/test/CodeGen/SPIRV/ShaderImage.ll (+1-1) 
- (modified) llvm/test/CodeGen/SPIRV/basic_float_types.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/basic_int_types.ll (-3) 
- (modified) llvm/test/CodeGen/SPIRV/basic_int_types_spirvdis.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll (+1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/abs.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/acos.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll (+3-2) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/any.ll (+3-2) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/asin.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/atan.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/atan2.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/ceil.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cos.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cosh.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/countbits.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/degrees.ll (+1-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/exp.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/exp2.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll (+1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/floor.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fmad.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fmax.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fmin.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/frac.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll (+1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/imad.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/length.ll (+1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/lerp.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log10.ll (+10-22) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log2.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll (+1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/pow.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/radians.ll (+1-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rcp.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/reversebits.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/round.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/saturate.ll (+1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sign.ll (+1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sin.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sinh.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/smax.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/smin.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sqrt.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll (+1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/tan.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/tanh.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/trunc.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/umax.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/umin.ll (+2-1) 
- (modified) llvm/test/CodeGen/SPIRV/literals.ll (-3) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/basic-if.ll (+53) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/basic-imbalanced-if.ll (+47) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/basic-loop.ll (+59) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/basic-phi.ll (+58) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.cond-op.ll (+86-99) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.do.break.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.do.continue.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.do.nested.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.break.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.continue.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.nested.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.short-circuited-cond.ll (+2-3) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.if.const-cond.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.if.for.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.if.plain.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.logical-and.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.logical-or.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.return.early.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple2.ll (-1) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.while.break.ll (+27-29) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/condition-linear.ll (+72-70) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/do-continue.ll (+77-79) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/do-nested.ll (+55-67) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/do-plain.ll (+64-60) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/logical-or.ll (+54-61) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/loop-continue-split.ll (+104) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll (+19-19) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll (+9-9) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll (+20-15) 
- (added) llvm/test/CodeGen/SPIRV/structurizer/phi-exit.ll (+45) 
- (modified) llvm/test/CodeGen/SPIRV/structurizer/return-early.ll (+14-6) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
index 9930d067173df7..c22492ec43b095 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -130,6 +130,13 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
     assert(false && "Unhandled terminator type.");
   }
 
+  AllocaInst *CreateVariable(Function &F, Type *Type,
+                             BasicBlock::iterator Position) {
+    const DataLayout &DL = F.getDataLayout();
+    return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
+                          Position);
+  }
+
   // Run the pass on the given convergence region, ignoring the sub-regions.
   // Returns true if the CFG changed, false otherwise.
   bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
@@ -152,6 +159,9 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
     auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
     IRBuilder<> Builder(NewExitTarget);
 
+    AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(),
+                                          F->begin()->getFirstInsertionPt());
+
     // CodeGen output needs to be stable. Using the set as-is would order
     // the targets differently depending on the allocation pattern.
     // Sorting per basic-block ordering in the function.
@@ -176,18 +186,16 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
     std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
     for (auto Exit : SortedExits) {
       llvm::Value *Value = createExitVariable(Exit, TargetToValue);
+      IRBuilder<> B2(Exit);
+      B2.SetInsertPoint(Exit->getFirstInsertionPt());
+      B2.CreateStore(Value, Variable);
       ExitToVariable.emplace_back(std::make_pair(Exit, Value));
     }
 
-    // Gather the correct value depending on the exit we came from.
-    llvm::PHINode *node =
-        Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
-    for (auto [BB, Value] : ExitToVariable) {
-      node->addIncoming(Value, BB);
-    }
+    llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);
 
     // Creating the switch to jump to the correct exit target.
-    llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
+    llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
                                                 SortedExitTargets.size() - 1);
     for (size_t i = 1; i < SortedExitTargets.size(); i++) {
       BasicBlock *BB = SortedExitTargets[i];
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 211a060ee103bc..5b6d31782c2093 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -87,6 +87,8 @@ BasicBlock *getExitFor(const ConvergenceRegion *CR) {
 // Returns the merge block designated by I if I is a merge instruction, nullptr
 // otherwise.
 BasicBlock *getDesignatedMergeBlock(Instruction *I) {
+  if (I == nullptr)
+    return nullptr;
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
   if (II == nullptr)
     return nullptr;
@@ -102,6 +104,8 @@ BasicBlock *getDesignatedMergeBlock(Instruction *I) {
 // Returns the continue block designated by I if I is an OpLoopMerge, nullptr
 // otherwise.
 BasicBlock *getDesignatedContinueBlock(Instruction *I) {
+  if (I == nullptr)
+    return nullptr;
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
   if (II == nullptr)
     return nullptr;
@@ -284,18 +288,6 @@ void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
   assert(false && "Unhandled terminator type.");
 }
 
-// Replaces basic bloc operands |OldSrc| or OpPhi instructions in |BB| by
-// |NewSrc|. This function does not simplify the OpPhi instruction once
-// transformed.
-void replacePhiTargets(BasicBlock *BB, BasicBlock *OldSrc, BasicBlock *NewSrc) {
-  for (PHINode &Phi : BB->phis()) {
-    int index = Phi.getBasicBlockIndex(OldSrc);
-    if (index == -1)
-      continue;
-    Phi.setIncomingBlock(index, NewSrc);
-  }
-}
-
 } // anonymous namespace
 
 // Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
@@ -447,48 +439,40 @@ class SPIRVStructurizer : public FunctionPass {
     // clang-format on
     std::vector<Edge>
     createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
-      std::unordered_map<BasicBlock *, BasicBlock *> Seen;
+      std::unordered_set<BasicBlock *> Seen;
       std::vector<Edge> Output;
       Output.reserve(Edges.size());
 
       for (auto &[Src, Dst] : Edges) {
-        auto [iterator, inserted] = Seen.insert({Src, Dst});
-        if (inserted) {
-          Output.emplace_back(Src, Dst);
-          continue;
+        auto [iterator, inserted] = Seen.insert(Src);
+        if (!inserted) {
+          // Src already a source node. Cannot have 2 edges from A to B.
+          // Creating alias source block.
+          BasicBlock *NewSrc =
+              BasicBlock::Create(F.getContext(), "new.src", &F);
+          replaceBranchTargets(Src, Dst, NewSrc);
+          IRBuilder<> Builder(NewSrc);
+          Builder.CreateBr(Dst);
+          Src = NewSrc;
         }
 
-        // The exact same edge was already seen. Ignoring.
-        if (iterator->second == Dst)
-          continue;
-
-        // The same Src block branches to 2 distinct blocks. This will be an
-        // issue for the generated OpPhi. Creating alias block.
-        BasicBlock *NewSrc =
-            BasicBlock::Create(F.getContext(), "new.exit.src", &F);
-        replaceBranchTargets(Src, Dst, NewSrc);
-        replacePhiTargets(Dst, Src, NewSrc);
-
-        IRBuilder<> Builder(NewSrc);
-        Builder.CreateBr(Dst);
-
-        Seen.emplace(NewSrc, Dst);
-        Output.emplace_back(NewSrc, Dst);
+        Output.emplace_back(Src, Dst);
       }
 
       return Output;
     }
 
+    AllocaInst *CreateVariable(Function &F, Type *Type,
+                               BasicBlock::iterator Position) {
+      const DataLayout &DL = F.getDataLayout();
+      return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
+                            Position);
+    }
+
     // Given a construct defined by |Header|, and a list of exiting edges
     // |Edges|, creates a new single exit node, fixing up those edges.
     BasicBlock *createSingleExitNode(BasicBlock *Header,
                                      std::vector<Edge> &Edges) {
-      auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
-      IRBuilder<> ExitBuilder(NewExit);
-
-      std::vector<BasicBlock *> Dsts;
-      std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
-
       // Given 2 edges: Src1 -> Dst, Src2 -> Dst:
       // If Dst has an PHI node, and Src1 and Src2 are both operands, both Src1
       // and Src2 cannot be hidden by NewExit. Create 2 new nodes: Alias1,
@@ -496,6 +480,10 @@ class SPIRVStructurizer : public FunctionPass {
       // Dst PHI node to look for Alias1 and Alias2.
       std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
 
+      std::vector<BasicBlock *> Dsts;
+      std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
+      auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
+      IRBuilder<> ExitBuilder(NewExit);
       for (auto &[Src, Dst] : FixedEdges) {
         if (DstToIndex.count(Dst) != 0)
           continue;
@@ -506,33 +494,34 @@ class SPIRVStructurizer : public FunctionPass {
       if (Dsts.size() == 1) {
         for (auto &[Src, Dst] : FixedEdges) {
           replaceBranchTargets(Src, Dst, NewExit);
-          replacePhiTargets(Dst, Src, NewExit);
         }
         ExitBuilder.CreateBr(Dsts[0]);
         return NewExit;
       }
 
-      PHINode *PhiNode =
-          ExitBuilder.CreatePHI(ExitBuilder.getInt32Ty(), FixedEdges.size());
-
+      AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
+                                            F.begin()->getFirstInsertionPt());
       for (auto &[Src, Dst] : FixedEdges) {
-        PhiNode->addIncoming(DstToIndex[Dst], Src);
+        IRBuilder<> B2(Src);
+        B2.SetInsertPoint(Src->getFirstInsertionPt());
+        B2.CreateStore(DstToIndex[Dst], Variable);
         replaceBranchTargets(Src, Dst, NewExit);
-        replacePhiTargets(Dst, Src, NewExit);
       }
 
+      llvm::Value *Load =
+          ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
+
       // If we can avoid an OpSwitch, generate an OpBranch. Reason is some
       // OpBranch are allowed to exist without a new OpSelectionMerge if one of
       // the branch is the parent's merge node, while OpSwitches are not.
       if (Dsts.size() == 2) {
-        Value *Condition = ExitBuilder.CreateCmp(CmpInst::ICMP_EQ,
-                                                 DstToIndex[Dsts[0]], PhiNode);
+        Value *Condition =
+            ExitBuilder.CreateCmp(CmpInst::ICMP_EQ, DstToIndex[Dsts[0]], Load);
         ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
         return NewExit;
       }
 
-      SwitchInst *Sw =
-          ExitBuilder.CreateSwitch(PhiNode, Dsts[0], Dsts.size() - 1);
+      SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
       for (auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) {
         Sw->addCase(DstToIndex[*It], *It);
       }
@@ -576,7 +565,7 @@ class SPIRVStructurizer : public FunctionPass {
 
   // Creates a new basic block in F with a single OpUnreachable instruction.
   BasicBlock *CreateUnreachable(Function &F) {
-    BasicBlock *BB = BasicBlock::Create(F.getContext(), "new.exit", &F);
+    BasicBlock *BB = BasicBlock::Create(F.getContext(), "unreachable", &F);
     IRBuilder<> Builder(BB);
     Builder.CreateUnreachable();
     return BB;
@@ -1127,6 +1116,18 @@ class SPIRVStructurizer : public FunctionPass {
         continue;
 
       Modified = true;
+
+      if (Merge == nullptr) {
+        Merge = *successors(Header).begin();
+        IRBuilder<> Builder(Header);
+        Builder.SetInsertPoint(Header->getTerminator());
+
+        auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
+        SmallVector<Value *, 1> Args = {MergeAddress};
+        Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+        continue;
+      }
+
       Instruction *SplitInstruction = Merge->getTerminator();
       if (isMergeInstruction(SplitInstruction->getPrevNode()))
         SplitInstruction = SplitInstruction->getPrevNode();
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index e5384b2eb2c2c1..133a98375d840f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -29,6 +29,7 @@
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Pass.h"
 #include "llvm/Target/TargetOptions.h"
+#include "llvm/Transforms/Scalar/Reg2Mem.h"
 #include "llvm/Transforms/Utils.h"
 #include <optional>
 
@@ -162,6 +163,8 @@ void SPIRVPassConfig::addIRPasses() {
   TargetPassConfig::addIRPasses();
 
   if (TM.getSubtargetImpl()->isVulkanEnv()) {
+    addPass(createRegToMemWrapperPass());
+
     // 1.  Simplify loop for subsequent transformations. After this steps, loops
     // have the following properties:
     //  - loops have a single entry edge (pre-header to loop header).
@@ -169,13 +172,21 @@ void SPIRVPassConfig::addIRPasses() {
     //  - loops have a single back-edge.
     addPass(createLoopSimplifyPass());
 
-    // 2. Merge the convergence region exit nodes into one. After this step,
+    // 2. Removes registers whose lifetime spans across basic blocks. Also
+    // removes phi nodes. This will greatly simplify the next steps.
+    addPass(createRegToMemWrapperPass());
+
+    // 3. Merge the convergence region exit nodes into one. After this step,
     // regions are single-entry, single-exit. This will help determine the
     // correct merge block.
     addPass(createSPIRVMergeRegionExitTargetsPass());
 
-    // 3. Structurize.
+    // 4. Structurize.
     addPass(createSPIRVStructurizerPass());
+
+    // 5. Reduce the amount of variables required by pushing some operations
+    // back to virtual registers.
+    addPass(createPromoteMemoryToRegisterPass());
   }
 
   addPass(createSPIRVRegularizerPass());
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index dff33b16b9cfcf..f9b361e163c909 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -460,53 +460,98 @@ PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
   return Output;
 }
 
-size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Rank) {
-  if (Visited.count(BB) != 0)
-    return Rank;
+bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
+  for (BasicBlock *P : predecessors(BB)) {
+    // Ignore back-edges.
+    if (DT.dominates(BB, P))
+      continue;
 
-  Loop *L = LI.getLoopFor(BB);
-  const bool isLoopHeader = LI.isLoopHeader(BB);
+    // One of the predecessor hasn't been visited. Not ready yet.
+    if (BlockToOrder.count(P) == 0)
+      return false;
 
-  if (BlockToOrder.count(BB) == 0) {
-    OrderInfo Info = {Rank, Visited.size()};
-    BlockToOrder.emplace(BB, Info);
-  } else {
-    BlockToOrder[BB].Rank = std::max(BlockToOrder[BB].Rank, Rank);
+    // If the block is a loop exit, the loop must be finished before
+    // we can continue.
+    Loop *L = LI.getLoopFor(P);
+    if (L == nullptr || L->contains(BB))
+      continue;
+
+    // SPIR-V requires a single back-edge. And the backend first
+    // step transforms loops into the simplified format. If we have
+    // more than 1 back-edge, something is wrong.
+    assert(L->getNumBackEdges() <= 1);
+
+    // If the loop has no latch, loop's rank won't matter, so we can
+    // proceed.
+    BasicBlock *Latch = L->getLoopLatch();
+    assert(Latch);
+    if (Latch == nullptr)
+      continue;
+
+    // The latch is not ready yet, let's wait.
+    if (BlockToOrder.count(Latch) == 0)
+      return false;
   }
 
-  for (BasicBlock *Predecessor : predecessors(BB)) {
-    if (isLoopHeader && L->contains(Predecessor)) {
+  return true;
+}
+
+size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
+  size_t result = 0;
+
+  for (BasicBlock *P : predecessors(BB)) {
+    // Ignore back-edges.
+    if (DT.dominates(BB, P))
       continue;
-    }
 
-    if (BlockToOrder.count(Predecessor) == 0) {
-      return Rank;
+    auto Iterator = BlockToOrder.end();
+    Loop *L = LI.getLoopFor(P);
+    BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;
+
+    // If the predecessor is either outside a loop, or part of
+    // the same loop, simply take its rank + 1.
+    if (L == nullptr || L->contains(BB) || Latch == nullptr) {
+      Iterator = BlockToOrder.find(P);
+    } else {
+      // Otherwise, take the loop's rank (highest rank in the loop) as base.
+      // Since loops have a single latch, highest rank is easy to find.
+      // If the loop has no latch, then it doesn't matter.
+      Iterator = BlockToOrder.find(Latch);
     }
+
+    assert(Iterator != BlockToOrder.end());
+    result = std::max(result, Iterator->second.Rank + 1);
   }
 
-  Visited.insert(BB);
+  return result;
+}
+
+size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
+  ToVisit.push(BB);
+  Queued.insert(BB);
 
-  SmallVector<BasicBlock *, 2> OtherSuccessors;
-  SmallVector<BasicBlock *, 2> LoopSuccessors;
+  while (ToVisit.size() != 0) {
+    BasicBlock *BB = ToVisit.front();
+    ToVisit.pop();
 
-  for (BasicBlock *Successor : successors(BB)) {
-    // Ignoring back-edges.
-    if (DT.dominates(Successor, BB))
+    if (!CanBeVisited(BB)) {
+      ToVisit.push(BB);
       continue;
+    }
 
-    if (isLoopHeader && L->contains(Successor)) {
-      LoopSuccessors.push_back(Successor);
-    } else
-      OtherSuccessors.push_back(Successor);
-  }
+    size_t Rank = GetNodeRank(BB);
+    OrderInfo Info = {Rank, BlockToOrder.size()};
+    BlockToOrder.emplace(BB, Info);
 
-  for (BasicBlock *BB : LoopSuccessors)
-    Rank = std::max(Rank, visit(BB, Rank + 1));
+    for (BasicBlock *S : successors(BB)) {
+      if (Queued.count(S) != 0)
+        continue;
+      ToVisit.push(S);
+      Queued.insert(S);
+    }
+  }
 
-  size_t OutputRank = Rank;
-  for (BasicBlock *Item : OtherSuccessors)
-    OutputRank = std::max(OutputRank, visit(Item, Rank + 1));
-  return OutputRank;
+  return 0;
 }
 
 PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 83e717e6ea58fd..11fd3a5c61dcae 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -18,6 +18,7 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/TypedPointerType.h"
+#include <queue>
 #include <string>
 #include <unordered_set>
 
@@ -62,7 +63,9 @@ class SPIRVSubtarget;
 class PartialOrderingVisitor {
   DomTreeBuilder::BBDomTree DT;
   LoopInfo LI;
-  std::unordered_set<BasicBlock *> Visited = {};
+
+  std::unordered_set<BasicBlock *> Queued = {};
+  std::queue<BasicBlock *> ToVisit = {};
 
   struct OrderInfo {
     size_t Rank;
@@ -80,6 +83,9 @@ class PartialOrderingVisitor {
   // Visits |BB| with the current rank being |Rank|.
   size_t visit(BasicBlock *BB, size_t Rank);
 
+  size_t GetNodeRank(BasicBlock *BB) const;
+  bool CanBeVisited(BasicBlock *BB) const;
+
 public:
   // Build the visitor to operate on the function F.
   PartialOrderingVisitor(Function &F);
diff --git a/llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll b/llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll
index fe960f0d6f2f9a..2c9ad1a657a1a5 100644
--- a/llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll
+++ b/llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll
@@ -1,4 +1,4 @@
-; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-library %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv-vulkan-library %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-library %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: OpDecorate [[IntBufferVar:%[0-9]+]] DescriptorSet 16
@@ -18,13 +18,13 @@
 ; CHECK: {{%[0-9]+}} = OpFunction {{%[0-9]+}} DontInline {{%[0-9]+}}
 ; CHECK-NEXT: OpLabel
 define void @RWBufferLoad() #0 {
-; CHECK-NEXT: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
+; CHECK: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
   %buffer0 = call target("spirv.Image", i32, 5, 2, 0, 0, 2, 24)
       @llvm.spv.handle.fromBinding.tspirv.Image_f32_5_2_0_0_2_24(
           i32 16, i32 7, i32 1, i32 0, i1 false)
 
 ; Make sure we use the same variable with multiple loads.
-; CHECK-NEXT: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
+; CHECK: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
   %buffer1 = call target("spirv.Image", i32, 5, 2, 0, 0, 2, 24)
       @llvm.spv.handle.fromBinding.tspirv.Image_f32_5_2_0_0_2_24(
           i32 16, i32 7, i32 1, i32 0, i1 false)
@@ -36,7 +36,7 @@ define void @RWBufferLoad() #0 {
 define void @UseDifferentGlobalVar() #0 {
 ; Make sure we use a different variable from the first function. They have
 ; different types.
-; CHECK-NEXT: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeFloat]] [[FloatBufferVar]]
+; CHECK: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeFloat]] [[FloatBufferVar]]
   %buffer0 = call target("spirv.Image", float, 5, 2, 0, 0, 2, 3)
       @llvm.spv.handle.fromBinding.tspirv.Image_f32_5_2_0_0_2_3(
           i32 16, i32 7, i32 1, i32 0, i1 false)
@@ -48,7 +48,7 @@ define void @UseDifferentGlobalVar() #0 {
 define void @ReuseGlobalVarFromFirstFunction() #0 {
 ; Make sure we use the same variable as the first function. They should be the
 ; same in case one function calls the other.
-; CHECK-NEXT: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
+; CHECK: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
   %buffer1 = call target("spirv.Image", i32, 5, 2, 0, 0, 2, 24)
       @llvm.spv.handle.fromBinding.tspirv.Image_f32_5_2_0_0_2_24(
           i32 16, i32 7, i32 1, i32 0, i1 false)
diff --git a/llvm/test/CodeGen/SPIRV/OpVariable_order.ll b/llvm/test/CodeGen/SPIRV/OpVariable_order.ll
index 6057bf38d4c4c4..c68250697c4a7b 100644
--- a/llvm/test/Code...
[truncated]

``````````
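The `createSingleExitNode` change in the diff above replaces the PHI node with a stack variable: each exiting source stores the index of its original target into `reg`, and the new exit block loads it and branches. As a rough model of that dispatch, here is a standalone C++ sketch, not the LLVM implementation. Blocks are plain strings, `SingleExit`, `createAliasBlocks`, `buildSingleExit` and `dispatch` are hypothetical names, and target indices are assumed to be assigned in order of first appearance, since the loop filling `DstToIndex` is only partly visible in the truncated diff.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Toy model of the single-exit rewrite (hypothetical names, not LLVM code).
// Blocks are plain strings; an Edge is (Src, Dst).
using Edge = std::pair<std::string, std::string>;

struct SingleExit {
  std::vector<Edge> FixedEdges;          // exiting edges after aliasing
  std::map<std::string, int> StoreInSrc; // constant each Src stores to "reg"
  std::vector<std::string> Dsts;         // Dsts[I] is taken when reg == I
};

// Like createAliasBlocksForComplexEdges: if a source already has an exiting
// edge, the new edge is routed through a fresh alias block ("new.src<N>").
std::vector<Edge> createAliasBlocks(const std::vector<Edge> &Edges) {
  std::set<std::string> Seen;
  std::vector<Edge> Output;
  int Counter = 0;
  for (Edge E : Edges) {
    if (!Seen.insert(E.first).second)
      E.first = "new.src" + std::to_string(Counter++);
    Output.push_back(E);
  }
  return Output;
}

SingleExit buildSingleExit(const std::vector<Edge> &Edges) {
  SingleExit R;
  R.FixedEdges = createAliasBlocks(Edges);
  // Assumption: each distinct target gets the next index on first sight.
  std::map<std::string, int> DstToIndex;
  for (const Edge &E : R.FixedEdges) {
    if (DstToIndex.count(E.second) != 0)
      continue;
    DstToIndex[E.second] = static_cast<int>(R.Dsts.size());
    R.Dsts.push_back(E.second);
  }
  // Instead of a PHI in the exit block, every source stores its target index.
  for (const Edge &E : R.FixedEdges)
    R.StoreInSrc[E.first] = DstToIndex.at(E.second);
  return R;
}

// What the new exit does with the loaded value: an unconditional branch for
// one target, a conditional branch for two, otherwise a switch whose default
// is Dsts[0].
std::string dispatch(const SingleExit &X, int Load) {
  if (X.Dsts.size() == 1)
    return X.Dsts[0];
  if (X.Dsts.size() == 2)
    return Load == 0 ? X.Dsts[0] : X.Dsts[1];
  for (std::size_t I = 1; I < X.Dsts.size(); ++I)
    if (Load == static_cast<int>(I))
      return X.Dsts[I];
  return X.Dsts[0];
}
```

Because a source that already has an exiting edge gets an alias block, every path into the new exit passes through exactly one block that stores a single constant, which is what lets one `reg` variable replace the PHI node.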

</details>


https://github.com/llvm/llvm-project/pull/111026

