[llvm] [SPIR-V] Fix block sorting with irreducible CFG (PR #116996)
Nathan Gauër via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 20 08:09:30 PST 2024
https://github.com/Keenuts created https://github.com/llvm/llvm-project/pull/116996
Block sorting assumed a reducible CFG, meaning there was always a best node to continue with. An irreducible CFG breaks this assumption, so the algorithm looped indefinitely because no node was a valid candidate.
Fixes #116692
>From 00ce1ff77dd6c8cdf47294dbc5c3b20c8ee2f9b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 20 Nov 2024 16:54:43 +0100
Subject: [PATCH] [SPIR-V] Fix block sorting with irreducible CFG
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Block sorting assumed a reducible CFG, meaning there was always
a best node to continue with. An irreducible CFG breaks this
assumption, so the algorithm looped indefinitely because no
node was a valid candidate.
Fixes #116692
Signed-off-by: Nathan Gauër <brioche at google.com>
---
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 20 +-
.../SPIRV/structurizer/cf.if.nested.ll | 36 +--
llvm/unittests/Target/SPIRV/CMakeLists.txt | 4 +-
.../Target/SPIRV/SPIRVSortBlocksTests.cpp | 262 ++++++++++++++++++
4 files changed, 301 insertions(+), 21 deletions(-)
create mode 100644 llvm/unittests/Target/SPIRV/SPIRVSortBlocksTests.cpp
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index aeb2c29f7b8618..902dcf2ca28649 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -525,6 +525,12 @@ size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
continue;
auto Iterator = BlockToOrder.end();
+ // This block hasn't been ranked yet. Ignoring.
+ // This doesn't happen often, but when dealing with irreducible CFG, we have
+ // to rank nodes without knowing the rank of all their predecessors.
+ if (Iterator == BlockToOrder.end())
+ continue;
+
Loop *L = LI.getLoopFor(P);
BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;
@@ -550,15 +556,27 @@ size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
ToVisit.push(BB);
Queued.insert(BB);
+ // When the graph is irreducible, we can end up in a case where every
+ // queued node has a predecessor we haven't ranked yet.
+ // When such a case arises, we have to pick a node to continue with.
+ // This index is used to determine when we looped through all candidates.
+ // Each time a candidate is processed, this counter is reset.
+ // If the index is larger than the queue size, it means we looped.
+ size_t QueueIndex = 0;
+
while (ToVisit.size() != 0) {
BasicBlock *BB = ToVisit.front();
ToVisit.pop();
- if (!CanBeVisited(BB)) {
+ // Either the node is a candidate, or we looped already, and this is
+ // the first node we tried.
+ if (!CanBeVisited(BB) && QueueIndex <= ToVisit.size()) {
ToVisit.push(BB);
+ QueueIndex++;
continue;
}
+ QueueIndex = 0;
size_t Rank = GetNodeRank(BB);
OrderInfo Info = {Rank, BlockToOrder.size()};
BlockToOrder.emplace(BB, Info);
diff --git a/llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll b/llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll
index a69475a59db6f4..a44eec94db687e 100644
--- a/llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll
+++ b/llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll
@@ -34,28 +34,28 @@
; CHECK: %[[#bb30:]] = OpLabel
; CHECK: OpSelectionMerge %[[#bb31:]] None
; CHECK: OpBranchConditional %[[#]] %[[#bb32:]] %[[#bb33:]]
-; CHECK: %[[#bb32:]] = OpLabel
+; CHECK: %[[#bb32]] = OpLabel
; CHECK: OpSelectionMerge %[[#bb34:]] None
-; CHECK: OpBranchConditional %[[#]] %[[#bb35:]] %[[#bb34:]]
-; CHECK: %[[#bb33:]] = OpLabel
+; CHECK: OpBranchConditional %[[#]] %[[#bb35:]] %[[#bb34]]
+; CHECK: %[[#bb33]] = OpLabel
; CHECK: OpSelectionMerge %[[#bb36:]] None
; CHECK: OpBranchConditional %[[#]] %[[#bb37:]] %[[#bb38:]]
-; CHECK: %[[#bb35:]] = OpLabel
-; CHECK: OpBranch %[[#bb34:]]
-; CHECK: %[[#bb37:]] = OpLabel
-; CHECK: OpBranch %[[#bb36:]]
-; CHECK: %[[#bb38:]] = OpLabel
+; CHECK: %[[#bb35]] = OpLabel
+; CHECK: OpBranch %[[#bb34]]
+; CHECK: %[[#bb34]] = OpLabel
+; CHECK: OpBranch %[[#bb31]]
+; CHECK: %[[#bb37]] = OpLabel
+; CHECK: OpBranch %[[#bb36]]
+; CHECK: %[[#bb38]] = OpLabel
; CHECK: OpSelectionMerge %[[#bb39:]] None
-; CHECK: OpBranchConditional %[[#]] %[[#bb40:]] %[[#bb39:]]
-; CHECK: %[[#bb34:]] = OpLabel
-; CHECK: OpBranch %[[#bb31:]]
-; CHECK: %[[#bb40:]] = OpLabel
-; CHECK: OpBranch %[[#bb39:]]
-; CHECK: %[[#bb39:]] = OpLabel
-; CHECK: OpBranch %[[#bb36:]]
-; CHECK: %[[#bb36:]] = OpLabel
-; CHECK: OpBranch %[[#bb31:]]
-; CHECK: %[[#bb31:]] = OpLabel
+; CHECK: OpBranchConditional %[[#]] %[[#bb40:]] %[[#bb39]]
+; CHECK: %[[#bb40]] = OpLabel
+; CHECK: OpBranch %[[#bb39]]
+; CHECK: %[[#bb39]] = OpLabel
+; CHECK: OpBranch %[[#bb36]]
+; CHECK: %[[#bb36]] = OpLabel
+; CHECK: OpBranch %[[#bb31]]
+; CHECK: %[[#bb31]] = OpLabel
; CHECK: OpReturnValue %[[#]]
; CHECK: OpFunctionEnd
; CHECK: %[[#func_26:]] = OpFunction %[[#void:]] DontInline %[[#]]
diff --git a/llvm/unittests/Target/SPIRV/CMakeLists.txt b/llvm/unittests/Target/SPIRV/CMakeLists.txt
index e9fe4883e5b024..2af36225c5f200 100644
--- a/llvm/unittests/Target/SPIRV/CMakeLists.txt
+++ b/llvm/unittests/Target/SPIRV/CMakeLists.txt
@@ -15,6 +15,6 @@ set(LLVM_LINK_COMPONENTS
add_llvm_target_unittest(SPIRVTests
SPIRVConvergenceRegionAnalysisTests.cpp
+ SPIRVSortBlocksTests.cpp
SPIRVAPITest.cpp
- )
-
+)
diff --git a/llvm/unittests/Target/SPIRV/SPIRVSortBlocksTests.cpp b/llvm/unittests/Target/SPIRV/SPIRVSortBlocksTests.cpp
new file mode 100644
index 00000000000000..487f116aa3ace6
--- /dev/null
+++ b/llvm/unittests/Target/SPIRV/SPIRVSortBlocksTests.cpp
@@ -0,0 +1,262 @@
+//===- SPIRVSortBlocksTests.cpp - SPIR-V block sorting tests --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVUtils.h"
+#include "llvm/Analysis/DominanceFrontier.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <queue>
+
+using namespace llvm;
+using namespace llvm::SPIRV;
+
+class SPIRVSortBlocksTest : public testing::Test {
+protected:
+  void TearDown() override { M.reset(); }
+
+  /// Parses \p Assembly into this fixture's module and runs sortBlocks() on
+  /// its `@main` function.
+  ///
+  /// Must be called at most once per test: the parsed module is kept in \c M
+  /// so that checkBasicBlockOrder() can inspect the result afterwards.
+  ///
+  /// \returns the value returned by sortBlocks(), or false (after recording a
+  /// test failure) when the assembly does not parse or has no `@main`.
+  bool run(StringRef Assembly) {
+    assert(M == nullptr && "Calling run() multiple times is unsafe.");
+
+    SMDiagnostic Error;
+    M = parseAssemblyString(Assembly, Error, Context);
+    // Report malformed test input as a test failure rather than relying only
+    // on assert(), which is compiled out in release builds and would leave a
+    // null dereference behind.
+    if (!M) {
+      ADD_FAILURE() << "Bad assembly. Bad test? " << Error.getMessage().str();
+      return false;
+    }
+
+    llvm::Function *F = M->getFunction("main");
+    if (!F) {
+      ADD_FAILURE() << "Test assembly must define a @main function.";
+      return false;
+    }
+    return sortBlocks(*F);
+  }
+
+  /// Checks that the basic blocks of `@main` appear exactly in the
+  /// \p Expected order (compared by block name), with no extra blocks after
+  /// the last expected one. Must be called after a successful run().
+  void checkBasicBlockOrder(std::vector<const char *> &&Expected) {
+    ASSERT_TRUE(M != nullptr) << "run() must succeed before checking order.";
+    llvm::Function *F = M->getFunction("main");
+    ASSERT_TRUE(F != nullptr) << "Module has no @main function.";
+
+    auto It = F->begin();
+    for (const char *Name : Expected) {
+      ASSERT_TRUE(It != F->end())
+          << "Expected block \"" << Name
+          << "\" but reached the end of the function instead.";
+      ASSERT_TRUE(It->getName() == Name)
+          << "Error: expected block \"" << Name << "\" got \"" << It->getName()
+          << "\"";
+      ++It;
+    }
+    // Checked exactly once, fatally, with an explanatory message.
+    ASSERT_TRUE(It == F->end())
+        << "No more blocks were expected, but function has more.";
+  }
+
+protected:
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+};
+
+TEST_F(SPIRVSortBlocksTest, DefaultRegion) {
+  // A single-block function has nothing to reorder, so sortBlocks() is
+  // expected to return false.
+  const StringRef IR = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ ret void
+ }
+ )";
+
+  const bool Changed = run(IR);
+  EXPECT_FALSE(Changed);
+}
+
+// The blocks are declared as entry, exit, middle; the sorted order must place
+// `middle` (the only predecessor of `exit`) before `exit`.
+TEST_F(SPIRVSortBlocksTest, BasicBlockSwap) {
+  const StringRef IR = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ entry:
+ br label %middle
+ exit:
+ ret void
+ middle:
+ br label %exit
+ }
+ )";
+
+  const bool Changed = run(IR);
+  EXPECT_TRUE(Changed);
+  checkBasicBlockOrder({"entry", "middle", "exit"});
+}
+
+// Simple loop:
+//   entry -> header
+//   header -> body -> continue -> header (back-edge)
+//   header -> end
+//
+// The loop body must be laid out before the loop exit `end`.
+TEST_F(SPIRVSortBlocksTest, LoopOrdering) {
+  const StringRef IR = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ entry:
+ %1 = icmp ne i32 0, 0
+ br label %header
+ end:
+ ret void
+ body:
+ br label %continue
+ continue:
+ br label %header
+ header:
+ br i1 %1, label %body, label %end
+ }
+ )";
+
+  const bool Changed = run(IR);
+  EXPECT_TRUE(Changed);
+  checkBasicBlockOrder({"entry", "header", "body", "continue", "end"});
+}
+
+// Diamond condition:
+//   entry -> A -> C
+//   entry -> B -> C
+//
+// A and B may legitimately be placed in either order, but the result must
+// remain deterministic/stable.
+TEST_F(SPIRVSortBlocksTest, DiamondCondition) {
+  const StringRef IR = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ entry:
+ %1 = icmp ne i32 0, 0
+ br i1 %1, label %a, label %b
+ c:
+ ret void
+ b:
+ br label %c
+ a:
+ br label %c
+ }
+ )";
+
+  const bool Changed = run(IR);
+  EXPECT_TRUE(Changed);
+  checkBasicBlockOrder({"entry", "a", "b", "c"});
+}
+
+// Skip condition (entry may bypass A):
+//   entry -> A -> C
+//   entry -> C
+TEST_F(SPIRVSortBlocksTest, SkipCondition) {
+  const StringRef IR = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ entry:
+ %1 = icmp ne i32 0, 0
+ br i1 %1, label %a, label %c
+ c:
+ ret void
+ a:
+ br label %c
+ }
+ )";
+
+  const bool Changed = run(IR);
+  EXPECT_TRUE(Changed);
+  checkBasicBlockOrder({"entry", "a", "c"});
+}
+
+// Crossing conditions:
+//   entry -> A, entry -> B
+//   A -> C, A -> D
+//   B -> C, B -> D
+//   C -> E, D -> E
+//
+// A & B have the same rank.
+// C & D have the same rank, but come after A & B.
+// E is the last block.
+TEST_F(SPIRVSortBlocksTest, CrossingCondition) {
+  const StringRef IR = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ entry:
+ %1 = icmp ne i32 0, 0
+ br i1 %1, label %a, label %b
+ e:
+ ret void
+ c:
+ br label %e
+ b:
+ br i1 %1, label %d, label %c
+ d:
+ br label %e
+ a:
+ br i1 %1, label %c, label %d
+ }
+ )";
+
+  const bool Changed = run(IR);
+  EXPECT_TRUE(Changed);
+  checkBasicBlockOrder({"entry", "a", "b", "c", "d", "e"});
+}
+
+// Irreducible CFG:
+// digraph {
+//   entry -> A;
+//   A -> B; A -> C;
+//   B -> A; B -> C;
+//   C -> B;
+// }
+//
+// B and C are both entered directly from A, so neither dominates the other.
+// The order starts with entry and A; the relative order of B and C may vary,
+// but it must remain stable.
+TEST_F(SPIRVSortBlocksTest, IrreducibleOrdering) {
+  const StringRef IR = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ entry:
+ %1 = icmp ne i32 0, 0
+ br label %a
+
+ b:
+ br i1 %1, label %a, label %c
+
+ c:
+ br label %b
+
+ a:
+ br i1 %1, label %b, label %c
+
+ }
+ )";
+
+  const bool Changed = run(IR);
+  EXPECT_TRUE(Changed);
+  checkBasicBlockOrder({"entry", "a", "b", "c"});
+}
+
+// Irreducible cycle reached through a straight-line prefix:
+// digraph {
+//   entry -> A;
+//   A -> B;
+//   B -> C; B -> D;
+//   C -> D; C -> E;
+//   D -> B; D -> C;
+// }
+//
+// C and D form a cycle that B enters at both nodes, so neither C nor D
+// dominates the other. The expected order is entry, A, B, C, D, E.
+TEST_F(SPIRVSortBlocksTest, IrreducibleOrderingBeforeReduction) {
+  const StringRef IR = R"(
+ define void @main() convergent "hlsl.numthreads"="4,8,16" "hlsl.shader"="compute" {
+ entry:
+ %1 = icmp ne i32 0, 0
+ br label %a
+
+ c:
+ br i1 %1, label %d, label %e
+
+ e:
+ ret void
+
+ b:
+ br i1 %1, label %c, label %d
+
+ a:
+ br label %b
+
+ d:
+ br i1 %1, label %b, label %c
+
+ }
+ )";
+
+  const bool Changed = run(IR);
+  EXPECT_TRUE(Changed);
+  checkBasicBlockOrder({"entry", "a", "b", "c", "d", "e"});
+}
More information about the llvm-commits
mailing list