[Mlir-commits] [mlir] [mlir][spirv] Split conditional basic blocks during deserialization (PR #127639)
Igor Wodiany
llvmlistbot at llvm.org
Wed Feb 19 09:25:29 PST 2025
https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/127639
>From 76498185986d2f468cb7147ea474a92548fab61a Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 25 Nov 2024 16:41:03 +0000
Subject: [PATCH] [mlir][spirv] Split conditional basic blocks during
deserialization
With the current design, some values are sunk into a selection region despite
also being used outside that region. This is because the deserializer sinks the
entire basic block that contains the conditional branch forming the header of a
selection construct, without accounting for values that are used outside the
construct. This manifests as, for example:
```
<unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct
<unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function>
```
The proposed solution is to split the conditional basic block into two: one
block containing just the conditional branch, and the other containing the rest
of the instructions. With this split, the logic that structurizes selection
regions sinks only the conditional branch, keeping the rest of the instructions
(including the comparison that feeds the branch) outside the selection region.
A SPIR-V test is required, because the problem can only happen during
deserialization and cannot be tested with `--test-spirv-roundtrip`. An MLIR
test exhibiting the problematic behaviour would be invalid MLIR in the first
place.
This solution is proposed as an alternative to the unfinished PR #123371, which
is unlikely to be merged in the foreseeable future, as the author "stepped away
from this for a time being". There is also a Discourse thread:
https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494
that sought feedback on the topic.
---
.../SPIRV/Deserialization/Deserializer.cpp | 37 ++++++++++++
.../SPIRV/Deserialization/Deserializer.h | 4 ++
mlir/test/Target/SPIRV/selection.mlir | 2 +
mlir/test/Target/SPIRV/selection.spv | 60 +++++++++++++++++++
mlir/test/lit.cfg.py | 1 +
5 files changed, 104 insertions(+)
create mode 100644 mlir/test/Target/SPIRV/selection.spv
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819..a2a6691775bc3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2158,6 +2158,34 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
return success();
}
+/// Splits each selection-header block that ends in a
+/// `spirv.BranchConditional` and also holds other operations. The conditional
+/// branch is moved into a new block, which the original block reaches through
+/// an unconditional `spirv.Branch`. The new block becomes the selection header
+/// in `blockMergeInfo`, so structurization sinks only the branch. Definitions
+/// that precede it (which may be used after the selection) stay outside the
+/// selection region. Loop headers are left untouched.
+LogicalResult spirv::Deserializer::splitConditionalBlocks() {
+  // Iterate over a snapshot: splitting a header re-keys its entry in
+  // `blockMergeInfo`. A DenseMap must not be modified while it is being
+  // iterated. Assigning to an entry's key in place would also leave the entry
+  // in a bucket that no longer matches its hash, breaking later lookups.
+  auto blockMergeInfoCopy = blockMergeInfo;
+  for (auto &[block, mergeInfo] : blockMergeInfoCopy) {
+    // Loop constructs are recognized by a non-null continue block and are not
+    // split. Keep scanning: later entries may still be selection headers.
+    if (mergeInfo.continueBlock)
+      continue;
+
+    if (!block->mightHaveTerminator())
+      continue;
+
+    Operation *terminator = block->getTerminator();
+    assert(terminator && "expected header block to have a terminator");
+
+    // Only split when the conditional branch is not already alone in its
+    // block; a branch-only block has nothing that could be wrongly sunk.
+    if (!isa<spirv::BranchConditionalOp>(terminator) ||
+        block->begin() == block->end() ||
+        std::next(block->begin()) == block->end())
+      continue;
+
+    // Move the conditional branch into a fresh block and fall through to it.
+    Block *newBlock = block->splitBlock(terminator);
+    OpBuilder builder(block, block->end());
+    builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
+
+    // The new block is now the selection header: re-key its merge info.
+    blockMergeInfo.erase(block);
+    blockMergeInfo.try_emplace(newBlock, mergeInfo);
+  }
+
+  return success();
+}
+
LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
@@ -2165,6 +2193,15 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
logger.indent();
});
+ LLVM_DEBUG({
+ logger.startLine() << "[cf] split conditional blocks\n";
+ logger.startLine() << "\n";
+ });
+
+ if (failed(splitConditionalBlocks())) {
+ return failure();
+ }
+
while (!blockMergeInfo.empty()) {
Block *headerBlock = blockMergeInfo.begin()->first;
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 264d580c40f09..8dd35aa876726 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -246,6 +246,10 @@ class Deserializer {
return opBuilder.getStringAttr(attrName);
}
+ // Move a conditional branch into a separate basic block to avoid sinking
+ // defs that are required outside a selection region.
+ //
+ // Targets selection-header blocks (entries of `blockMergeInfo` without a
+ // continue block) whose terminator is a `spirv.BranchConditional` and which
+ // contain other operations. Each such block is split just before the
+ // conditional branch. The first half ends with an unconditional
+ // `spirv.Branch` to the second, and the second half (holding only the
+ // conditional branch) becomes the selection header. Blocks heading loop
+ // constructs are not split. Intended to run at the start of control-flow
+ // structurization.
+ LogicalResult splitConditionalBlocks();
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index f1d35d74dba15..24abb12998d06 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%var = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
// CHECK-NEXT: ^[[BB]]:
+// CHECK: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
// CHECK-NEXT: spirv.mlir.selection {
spirv.mlir.selection {
diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spv
new file mode 100644
index 0000000000000..9642d0a44fb59
--- /dev/null
+++ b/mlir/test/Target/SPIRV/selection.spv
@@ -0,0 +1,60 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+
+; COM: The purpose of this test is to check that a variable (in this case %color) that
+; COM: is defined before a selection region and used both in the selection region and
+; COM: after it, is not sunk into that selection region by the deserializer. If the
+; COM: variable is sunk, then it cannot be accessed outside the region and causes
+; COM: control-flow structurization to fail.
+
+; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+; CHECK: spirv.func @main() "None" {
+; CHECK: spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
+; CHECK: spirv.mlir.selection {
+; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
+; CHECK-NEXT: ^[[bb:.+]]
+; CHECK: spirv.Branch ^[[bb:.+]]
+; CHECK-NEXT: ^[[bb:.+]]:
+; CHECK-NEXT: spirv.mlir.merge
+; CHECK-NEXT: }
+; CHECK: spirv.Return
+; CHECK-NEXT: }
+; CHECK: }
+
+; Module preamble: capabilities, entry point, decorations, types and constants.
+ OpCapability Shader
+ %2 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %colorOut
+ OpExecutionMode %main OriginUpperLeft
+ OpDecorate %colorOut Location 0
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+%fun_v4float = OpTypePointer Function %v4float
+ %float_1 = OpConstant %float 1
+ %float_0 = OpConstant %float 0
+ %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+%out_v4float = OpTypePointer Output %v4float
+ %colorOut = OpVariable %out_v4float Output
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %out_float = OpTypePointer Output %float
+ %bool = OpTypeBool
+ %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
+; Fragment entry point. %color is declared in the selection header block %6,
+; may be overwritten inside the selection body %23, and is loaded again in the
+; merge block %24. It is therefore used both inside and after the selection
+; and must stay outside the selection region after deserialization.
+ %main = OpFunction %void None %4
+ %6 = OpLabel
+ %color = OpVariable %fun_v4float Function
+ OpStore %color %13
+ %19 = OpAccessChain %out_float %colorOut %uint_0
+ %20 = OpLoad %float %19
+ %22 = OpFOrdEqual %bool %20 %float_1
+; Selection header: branch to %23 when %22 holds, otherwise go straight to the
+; merge block %24.
+ OpSelectionMerge %24 None
+ OpBranchConditional %22 %23 %24
+ %23 = OpLabel
+ OpStore %color %25
+ OpBranch %24
+; Merge block: reads %color after the selection has completed.
+ %24 = OpLabel
+ %26 = OpLoad %v4float %color
+ OpStore %colorOut %26
+ OpReturn
+ OpFunctionEnd
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 32b2f8b53d5fa..8578c76969e74 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -43,6 +43,7 @@
".test",
".pdll",
".c",
+ ".spv",
]
# test_source_root: The root path where tests are located.
More information about the Mlir-commits
mailing list