[llvm] [SPIR-V] Fix block sorting with irreducible CFG (PR #116996)

Nathan Gauër via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 08:09:30 PST 2024


https://github.com/Keenuts created https://github.com/llvm/llvm-project/pull/116996

Block sorting assumed a reducible CFG, meaning there was always a best node to continue with. An irreducible CFG breaks this assumption: no node is a valid candidate, so the algorithm looped indefinitely.

Fixes #116692

>From 00ce1ff77dd6c8cdf47294dbc5c3b20c8ee2f9b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 20 Nov 2024 16:54:43 +0100
Subject: [PATCH] [SPIR-V] Fix block sorting with irreducible CFG
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Block sorting assumed a reducible CFG, meaning there was always
a best node to continue with. An irreducible CFG breaks this
assumption: no node is a valid candidate, so the algorithm
looped indefinitely.

Fixes #116692

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  20 +-
 .../SPIRV/structurizer/cf.if.nested.ll        |  36 +--
 llvm/unittests/Target/SPIRV/CMakeLists.txt    |   4 +-
 .../Target/SPIRV/SPIRVSortBlocksTests.cpp     | 262 ++++++++++++++++++
 4 files changed, 301 insertions(+), 21 deletions(-)
 create mode 100644 llvm/unittests/Target/SPIRV/SPIRVSortBlocksTests.cpp

diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index aeb2c29f7b8618..902dcf2ca28649 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -525,6 +525,12 @@ size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
       continue;

     auto Iterator = BlockToOrder.end();
+    // Skip predecessors not ranked yet; this can happen with irreducible CFGs.
+    // NOTE(review): Iterator is still end() here, so this always continues;
+    // the check presumably belongs after the BlockToOrder lookup; confirm.
+    if (Iterator == BlockToOrder.end())
+      continue;
+
     Loop *L = LI.getLoopFor(P);
     BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;

@@ -550,15 +556,27 @@ size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
   ToVisit.push(BB);
   Queued.insert(BB);

+  // When the graph is irreducible, we can end up in a state where every
+  // queued node has a predecessor that hasn't been ranked yet. In that case
+  // we must force-pick a node to make progress. QueueIndex counts consecutive
+  // nodes re-queued because they could not be visited, and is reset whenever
+  // a node is processed. Once it exceeds the remaining queue size, every
+  // candidate has been tried and the next one is visited unconditionally.
+  size_t QueueIndex = 0;
+
   while (ToVisit.size() != 0) {
     BasicBlock *BB = ToVisit.front();
     ToVisit.pop();

-    if (!CanBeVisited(BB)) {
+    // Re-queue the node unless it can be visited, or we already cycled through
+    // the whole queue, in which case this is the first node we deferred.
+    if (!CanBeVisited(BB) && QueueIndex <= ToVisit.size()) {
       ToVisit.push(BB);
+      QueueIndex++;
       continue;
     }

+    QueueIndex = 0;
     size_t Rank = GetNodeRank(BB);
     OrderInfo Info = {Rank, BlockToOrder.size()};
     BlockToOrder.emplace(BB, Info);
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll b/llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll
index a69475a59db6f4..a44eec94db687e 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll
@@ -34,28 +34,28 @@
 ; CHECK:    %[[#bb30:]] = OpLabel
 ; CHECK:                  OpSelectionMerge %[[#bb31:]] None
 ; CHECK:                  OpBranchConditional %[[#]] %[[#bb32:]] %[[#bb33:]]
-; CHECK:    %[[#bb32:]] = OpLabel
+; CHECK:     %[[#bb32]] = OpLabel
 ; CHECK:                  OpSelectionMerge %[[#bb34:]] None
-; CHECK:                  OpBranchConditional %[[#]] %[[#bb35:]] %[[#bb34:]]
-; CHECK:    %[[#bb33:]] = OpLabel
+; CHECK:                  OpBranchConditional %[[#]] %[[#bb35:]] %[[#bb34]]
+; CHECK:     %[[#bb33]] = OpLabel
 ; CHECK:                  OpSelectionMerge %[[#bb36:]] None
 ; CHECK:                  OpBranchConditional %[[#]] %[[#bb37:]] %[[#bb38:]]
-; CHECK:    %[[#bb35:]] = OpLabel
-; CHECK:                  OpBranch %[[#bb34:]]
-; CHECK:    %[[#bb37:]] = OpLabel
-; CHECK:                  OpBranch %[[#bb36:]]
-; CHECK:    %[[#bb38:]] = OpLabel
+; CHECK:     %[[#bb35]] = OpLabel
+; CHECK:                  OpBranch %[[#bb34]]
+; CHECK:     %[[#bb34]] = OpLabel
+; CHECK:                  OpBranch %[[#bb31]]
+; CHECK:     %[[#bb37]] = OpLabel
+; CHECK:                  OpBranch %[[#bb36]]
+; CHECK:     %[[#bb38]] = OpLabel
 ; CHECK:                  OpSelectionMerge %[[#bb39:]] None
-; CHECK:                  OpBranchConditional %[[#]] %[[#bb40:]] %[[#bb39:]]
-; CHECK:    %[[#bb34:]] = OpLabel
-; CHECK:                  OpBranch %[[#bb31:]]
-; CHECK:    %[[#bb40:]] = OpLabel
-; CHECK:                  OpBranch %[[#bb39:]]
-; CHECK:    %[[#bb39:]] = OpLabel
-; CHECK:                  OpBranch %[[#bb36:]]
-; CHECK:    %[[#bb36:]] = OpLabel
-; CHECK:                  OpBranch %[[#bb31:]]
-; CHECK:    %[[#bb31:]] = OpLabel
+; CHECK:                  OpBranchConditional %[[#]] %[[#bb40:]] %[[#bb39]]
+; CHECK:     %[[#bb40]] = OpLabel
+; CHECK:                  OpBranch %[[#bb39]]
+; CHECK:     %[[#bb39]] = OpLabel
+; CHECK:                  OpBranch %[[#bb36]]
+; CHECK:     %[[#bb36]] = OpLabel
+; CHECK:                  OpBranch %[[#bb31]]
+; CHECK:     %[[#bb31]] = OpLabel
 ; CHECK:                  OpReturnValue %[[#]]
 ; CHECK:                  OpFunctionEnd
 ; CHECK: %[[#func_26:]] = OpFunction %[[#void:]] DontInline %[[#]]
diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt
index e9fe4883e5b024..2af36225c5f200 100644
--- a/llvm/unittests/Target/SPIRV/CMakeLists.txt
+++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt
@@ -15,6 +15,6 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_target_unittest(SPIRVTests
   SPIRVConvergenceRegionAnalysisTests.cpp
+  SPIRVSortBlocksTests.cpp
   SPIRVAPITest.cpp
-  )
-
+)
diff --git a/llvm/unittests/Target/SPIRV/SPIRVSortBlocksTests.cpp b/llvm/unittests/Target/SPIRV/SPIRVSortBlocksTests.cpp
new file mode 100644
index 00000000000000..487f116aa3ace6
--- /dev/null
+++ b/llvm/unittests/Target/SPIRV/SPIRVSortBlocksTests.cpp
@@ -0,0 +1,262 @@
+//===- SPIRVSortBlocksTests.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVUtils.h"
+#include "llvm/Analysis/DominanceFrontier.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <queue>
+
+using namespace llvm;
+using namespace llvm::SPIRV;
+
+class SPIRVSortBlocksTest : public testing::Test {
+protected:
+  void TearDown() override { M.reset(); }
+
+  // Parses Assembly into M, then runs sortBlocks() on its @main function.
+  bool run(StringRef Assembly) {
+    assert(M == nullptr && "Calling run() multiple times is unsafe.");
+    SMDiagnostic Error;
+    M = parseAssemblyString(Assembly, Error, Context);
+    assert(M && "Bad assembly. Bad test?");
+    llvm::Function *F = M->getFunction("main");
+    return sortBlocks(*F);
+  }
+
+  // Checks that @main contains exactly the blocks in Expected, in order.
+  void checkBasicBlockOrder(std::vector<const char *> &&Expected) {
+    llvm::Function *F = M->getFunction("main");
+    auto It = F->begin();
+    for (const auto *Name : Expected) {
+      ASSERT_TRUE(It != F->end())
+          << "Expected block \"" << Name
+          << "\" but reached the end of the function instead.";
+      ASSERT_TRUE(It->getName() == Name)
+          << "Error: expected block \"" << Name << "\" got \"" << It->getName()
+          << "\"";
+      ++It;
+    }
+    // The message below is only evaluated on failure, i.e. when It != end().
+    ASSERT_TRUE(It == F->end())
+        << "No more blocks were expected, but function has more (next: \""
+        << It->getName() << "\").";
+  }
+
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+};
+
+TEST_F(SPIRVSortBlocksTest, DefaultRegion) {
+  const StringRef IR = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+      ret void
+    }
+  )";
+  // With a single block there is nothing to reorder.
+  EXPECT_FALSE(run(IR));
+}
+
+TEST_F(SPIRVSortBlocksTest, BasicBlockSwap) {
+  const StringRef IR = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+    entry:
+      br label %middle
+    exit:
+      ret void
+    middle:
+      br label %exit
+    }
+  )";
+  // "exit" is declared before "middle", so the order must change.
+  EXPECT_TRUE(run(IR));
+  checkBasicBlockOrder({"entry", "middle", "exit"});
+}
+
+// Simple loop:
+// entry -> header <-----------------+
+//           | `-> body -> continue -+
+//           `-> end
+TEST_F(SPIRVSortBlocksTest, LoopOrdering) {
+  const StringRef IR = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+    entry:
+      %1 = icmp ne i32 0, 0
+      br label %header
+    end:
+      ret void
+    body:
+      br label %continue
+    continue:
+      br label %header
+    header:
+      br i1 %1, label %body, label %end
+    }
+  )";
+  // The loop (header, body, continue) is laid out before its exit block.
+  EXPECT_TRUE(run(IR));
+  checkBasicBlockOrder({"entry", "header", "body", "continue", "end"});
+}
+
+// Diamond condition:
+//         +-> A -+
+//  entry -+      +-> C
+//         +-> B -+
+//
+// The order of A and B could be flipped with no effect, but it must remain
+// deterministic/stable.
+TEST_F(SPIRVSortBlocksTest, DiamondCondition) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+    entry:
+      %1 = icmp ne i32 0, 0
+      br i1 %1, label %a, label %b
+    c:
+      ret void
+    b:
+      br label %c
+    a:
+      br label %c
+    }
+  )";
+
+  EXPECT_TRUE(run(Assembly));
+  checkBasicBlockOrder({"entry", "a", "b", "c"});
+}
+
+// Skip condition:
+//         +-> A -+
+//  entry -+      +-> C
+//         +------+
+TEST_F(SPIRVSortBlocksTest, SkipCondition) {
+  const StringRef IR = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+    entry:
+      %1 = icmp ne i32 0, 0
+      br i1 %1, label %a, label %c
+    c:
+      ret void
+    a:
+      br label %c
+    }
+  )";
+  // c is reachable from both entry and a, so it must be placed after a.
+  EXPECT_TRUE(run(IR));
+  checkBasicBlockOrder({"entry", "a", "c"});
+}
+
+// Crossing conditions:
+//             +------+  +-> C -+
+//         +-> A -+   |  |      |
+//  entry -+      +--_|_-+      +-> E
+//         +-> B -+   |         |
+//             +------+----> D -+
+//
+// A & B have the same rank.
+// C & D have the same rank, but come after A & B.
+// E is the last block.
+TEST_F(SPIRVSortBlocksTest, CrossingCondition) {
+  StringRef Assembly = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+    entry:
+      %1 = icmp ne i32 0, 0
+      br i1 %1, label %a, label %b
+    e:
+      ret void
+    c:
+      br label %e
+    b:
+      br i1 %1, label %d, label %c
+    d:
+      br label %e
+    a:
+      br i1 %1, label %c, label %d
+    }
+  )";
+
+  EXPECT_TRUE(run(Assembly));
+  checkBasicBlockOrder({"entry", "a", "b", "c", "d", "e"});
+}
+
+// Irreducible CFG
+// digraph {
+//    entry -> A;
+//
+//    A -> B;
+//    A -> C;
+//
+//    B -> A;
+//    B -> C;
+//
+//    C -> B;
+// }
+//
+// entry and A always come first. B and C may be ordered either way, but the
+// result must be stable.
+TEST_F(SPIRVSortBlocksTest, IrreducibleOrdering) {
+  const StringRef IR = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+    entry:
+      %1 = icmp ne i32 0, 0
+      br label %a
+
+    b:
+      br i1 %1, label %a, label %c
+
+    c:
+      br label %b
+
+    a:
+      br i1 %1, label %b, label %c
+
+    }
+  )";
+  // The b <-> c cycle is entered at both b and c, so sorting must not hang.
+  EXPECT_TRUE(run(IR));
+  checkBasicBlockOrder({"entry", "a", "b", "c"});
+}
+
+TEST_F(SPIRVSortBlocksTest, IrreducibleOrderingBeforeReduction) {
+  const StringRef IR = R"(
+    define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+    entry:
+      %1 = icmp ne i32 0, 0
+      br label %a
+
+    c:
+      br i1 %1, label %d, label %e
+
+    e:
+      ret void
+
+    b:
+      br i1 %1, label %c, label %d
+
+    a:
+      br label %b
+
+    d:
+      br i1 %1, label %b, label %c
+
+    }
+  )";
+  // c <-> d is a cycle with two entries (both from b); e is the only exit.
+  EXPECT_TRUE(run(IR));
+  checkBasicBlockOrder({"entry", "a", "b", "c", "d", "e"});
+}



More information about the llvm-commits mailing list