[Mlir-commits] [mlir] [mlir][spirv] Split conditional basic blocks during deserialization (PR #127639)
Igor Wodiany
llvmlistbot at llvm.org
Tue Feb 18 06:09:00 PST 2025
https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/127639
With the current design, some values are sunk into a selection region even though they are also used outside that region. This happens because the deserializer sinks the entire basic block containing the conditional branch that forms the header of a selection construct, without accounting for values that are used outside it. This manifests as, for example:
```
<unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct
<unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function>
```
The proposed solution is to split the conditional basic block into two: one block containing just the conditional branch, and another containing the rest of the instructions. With this split, the logic that structures selection regions sinks only the conditional branch, keeping the rest of the instructions outside the selection region.
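As a rough illustration of the split, here is a self-contained sketch that models it outside of MLIR. Everything in it is an assumption made for illustration: `ToyBlock`, `ToyMergeInfo`, `ToyFunction` and `ToyMergeMap` are hypothetical stand-ins rather than MLIR types, blocks are identified by their index in a vector, ops are plain strings, and the new block is appended at the end of the function instead of being placed right after the block it was split from. It mirrors the steps of the patch: skip loop headers, move the conditional branch into its own block, terminate the original block with an unconditional branch, and re-key the merge info so that the new block becomes the selection header.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a basic block: ops are strings and the last
// entry is the terminator. This is not the MLIR Block class.
struct ToyBlock {
  std::vector<std::string> ops;
};

// Hypothetical stand-in for the per-header merge info. A non-negative
// continueBlock marks a loop header, as in the check done by the patch.
struct ToyMergeInfo {
  int mergeBlock = -1;
  int continueBlock = -1;
};

using ToyFunction = std::vector<ToyBlock>;
using ToyMergeMap = std::map<int, ToyMergeInfo>; // keyed by header index

// Splits every non-loop-header block that ends in a conditional branch and
// contains at least one other op. Returns the number of blocks split.
std::size_t splitConditionalBlocks(ToyFunction &fn, ToyMergeMap &mergeInfo) {
  std::size_t numSplit = 0;
  const std::size_t originalCount = fn.size(); // do not revisit new blocks
  for (std::size_t i = 0; i < originalCount; ++i) {
    auto it = mergeInfo.find(static_cast<int>(i));
    // Do not split loop headers.
    if (it != mergeInfo.end() && it->second.continueBlock >= 0)
      continue;
    // Nothing to split off if the block holds only its terminator.
    if (fn[i].ops.size() <= 1)
      continue;
    // Only blocks terminated by a conditional branch are of interest.
    if (fn[i].ops.back().rfind("BranchConditional", 0) != 0)
      continue;

    // Move the conditional branch into a fresh block of its own.
    ToyBlock tail;
    tail.ops.push_back(fn[i].ops.back());
    fn[i].ops.pop_back();

    // Terminate the original block with a branch to the new block.
    const int newIndex = static_cast<int>(fn.size());
    fn[i].ops.push_back("Branch ^bb" + std::to_string(newIndex));
    fn.push_back(tail);

    // The new block is now the selection header, so re-key its merge info.
    if (it != mergeInfo.end()) {
      ToyMergeInfo info = it->second;
      mergeInfo.erase(it);
      mergeInfo.emplace(newIndex, info);
    }
    ++numSplit;
  }
  return numSplit;
}
```

In this toy model, a header block that holds a variable, a store, a load, a comparison and a `BranchConditional` ends up keeping everything except the branch. A structurizer that sinks the header block would then sink only the new single-op block, so the variable stays visible to the merge block.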
A SPIR-V test is required, as the problem can happen only during deserialization and cannot be tested with `--test-spirv-roundtrip`. An MLIR test exhibiting the problematic behaviour would be invalid MLIR in the first place.
This solution is proposed as an alternative to the unfinished PR #123371, which is unlikely to be merged in the foreseeable future, as its author "stepped away from this for a time being". There is also a Discourse thread, https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494, that tried to solicit some feedback on the topic.
cc @mishaobu
From 10751804da0b7e224970c9f5e7643e3bb1d1f33a Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 25 Nov 2024 16:41:03 +0000
Subject: [PATCH] [mlir][spirv] Split conditional basic block during
deserialization
With the current design, some values are sunk into a selection region even
though they are also used outside that region. This happens because the
deserializer sinks the entire basic block containing the conditional branch
that forms the header of a selection construct, without accounting for values
used outside it. This manifests as (for example):
```
<unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct
<unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function>
```
The proposed solution is to split the conditional basic block into two: one
block containing just the conditional branch, and another containing the rest
of the instructions. With this split, the logic that structures selection
regions sinks only the conditional branch, keeping the rest of the
instructions outside the selection region.
A SPIR-V test is required, as the problem can happen only during
deserialization and cannot be tested with `--test-spirv-roundtrip`. An MLIR
test exhibiting the problematic behaviour would be invalid MLIR in the first
place.
This solution is proposed as an alternative to the unfinished PR #123371, which
is unlikely to be merged in the foreseeable future, as its author "stepped away
from this for a time being". There is also a Discourse thread:
https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494
that tried, unfortunately unsuccessfully, to solicit some feedback on the topic.
---
.../SPIRV/Deserialization/Deserializer.cpp | 42 +++++++++++++++++++
.../SPIRV/Deserialization/Deserializer.h | 4 ++
mlir/test/Target/SPIRV/loop.mlir | 2 +
mlir/test/Target/SPIRV/selection.mlir | 2 +
mlir/test/Target/SPIRV/selection.spv | 40 ++++++++++++++++++
mlir/test/lit.cfg.py | 1 +
6 files changed, 91 insertions(+)
create mode 100644 mlir/test/Target/SPIRV/selection.spv
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819..ebf2ecee3207a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2158,6 +2158,39 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
return success();
}
+LogicalResult spirv::Deserializer::splitConditionalBlocks() {
+ auto splitBlock = [&](Block *block) {
+ // Do not split loop headers
+ if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
+ if (it->second.continueBlock) {
+ return;
+ }
+ }
+
+ if (!block->mightHaveTerminator())
+ return;
+
+ auto terminator = block->getTerminator();
+ assert(terminator != nullptr);
+
+ if (isa<spirv::BranchConditionalOp>(terminator) &&
+ std::distance(block->begin(), block->end()) > 1) {
+ auto newBlock = block->splitBlock(terminator);
+ OpBuilder builder(block, block->end());
+ builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
+
+ if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
+ auto value = std::move(it->second);
+ blockMergeInfo.erase(it);
+ blockMergeInfo.try_emplace(newBlock, std::move(value));
+ }
+ }
+ };
+ curFunction->walk(splitBlock);
+
+ return success();
+}
+
LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
@@ -2165,6 +2198,15 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
logger.indent();
});
+ LLVM_DEBUG({
+ logger.startLine() << "[cf] split conditional blocks\n";
+ logger.startLine() << "\n";
+ });
+
+ if (failed(splitConditionalBlocks())) {
+ return failure();
+ }
+
while (!blockMergeInfo.empty()) {
Block *headerBlock = blockMergeInfo.begin()->first;
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 264d580c40f09..8dd35aa876726 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -246,6 +246,10 @@ class Deserializer {
return opBuilder.getStringAttr(attrName);
}
+ // Move a conditional branch into a separate basic block to avoid sinking
+ // defs that are required outside a selection region.
+ LogicalResult splitConditionalBlocks();
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index d89600558f56d..dd0d3e1af19dc 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -267,6 +267,8 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
}
// CHECK-NEXT: %[[LOAD:.+]] = spirv.Load "Function" %[[VAR]] : i1
%load = spirv.Load "Function" %var : i1
+// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]
// CHECK-NEXT: spirv.BranchConditional %[[LOAD]], ^[[CONTINUE:.+]](%[[ARG1]] : i64), ^[[LOOP_MERGE:.+]]
spirv.BranchConditional %load, ^continue(%arg1 : i64), ^loop_merge
// CHECK-NEXT: ^[[CONTINUE]](%[[ARG2:.+]]: i64):
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index f1d35d74dba15..24abb12998d06 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%var = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
// CHECK-NEXT: ^[[BB]]:
+// CHECK: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
// CHECK-NEXT: spirv.mlir.selection {
spirv.mlir.selection {
diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spv
new file mode 100644
index 0000000000000..b96e839f5c805
--- /dev/null
+++ b/mlir/test/Target/SPIRV/selection.spv
@@ -0,0 +1,40 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+; CHECK: spirv.module
+ OpCapability Shader
+ %2 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %colorOut
+ OpExecutionMode %main OriginUpperLeft
+ OpDecorate %colorOut Location 0
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+%fun_v4float = OpTypePointer Function %v4float
+ %float_1 = OpConstant %float 1
+ %float_0 = OpConstant %float 0
+ %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+%out_v4float = OpTypePointer Output %v4float
+ %colorOut = OpVariable %out_v4float Output
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %out_float = OpTypePointer Output %float
+ %bool = OpTypeBool
+ %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
+ %main = OpFunction %void None %4
+ %6 = OpLabel
+ %color = OpVariable %fun_v4float Function
+ OpStore %color %13
+ %19 = OpAccessChain %out_float %colorOut %uint_0
+ %20 = OpLoad %float %19
+ %22 = OpFOrdEqual %bool %20 %float_1
+ OpSelectionMerge %24 None
+ OpBranchConditional %22 %23 %24
+ %23 = OpLabel
+ OpStore %color %25
+ OpBranch %24
+ %24 = OpLabel
+ %26 = OpLoad %v4float %color
+ OpStore %colorOut %26
+ OpReturn
+ OpFunctionEnd
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 32b2f8b53d5fa..c447a047eea89 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -43,6 +43,7 @@
".test",
".pdll",
".c",
+ ".spv"
]
# test_source_root: The root path where tests are located.