[Mlir-commits] [mlir] [mlir][spirv] Split conditional basic blocks during deserialization (PR #127639)
llvmlistbot at llvm.org
Tue Feb 18 06:09:39 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Igor Wodiany (IgWod-IMG)
<details>
<summary>Changes</summary>
With the current design, some values are sunk into a selection region even though they are also used outside that region. This happens because the deserializer sinks the entire basic block that contains the conditional branch forming the header of a selection construct, without accounting for values in that block that are used outside the region. This manifests as, for example:
```
<unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct
<unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function>
```
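The condition behind this error can be illustrated with a minimal C++ sketch on a hypothetical toy IR. The `Inst` and `Block` types and the `escapingDefs` helper are illustrative assumptions, not MLIR or deserializer APIs. The helper collects the values defined in the blocks that would be moved into the selection region and reports those that are still read after the merge point, which is roughly what the structurizer rejects.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy IR (not the MLIR API): an instruction optionally defines
// one value and reads any number of operands, all identified by name.
struct Inst {
  std::string result;                 // empty if nothing is defined
  std::vector<std::string> operands;  // names of the values it reads
};

using Block = std::vector<Inst>;

// Returns the values defined in `region` (the blocks that would be moved
// into the selection) that are still read by an instruction in `outside`.
// A non-empty result is the situation reported as "has uses outside of the
// enclosing selection/loop construct".
std::set<std::string> escapingDefs(const std::vector<Block> &region,
                                   const std::vector<Block> &outside) {
  std::set<std::string> defs;
  for (const Block &block : region)
    for (const Inst &inst : block)
      if (!inst.result.empty())
        defs.insert(inst.result);

  std::set<std::string> escaping;
  for (const Block &block : outside)
    for (const Inst &inst : block)
      for (const std::string &operand : inst.operands)
        if (defs.count(operand) != 0)
          escaping.insert(operand);
  return escaping;
}
```

Modelled on the new `selection.spv` test: if the whole entry block (which defines `%color` and ends in `OpBranchConditional %22`) is sunk, `color` is reported as escaping because the merge block loads it; if the sunk block holds only the `OpBranchConditional`, nothing escapes.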
The proposed solution is to split the conditional basic block into two: one block containing just the conditional branch, and another containing the rest of the instructions. This way, the logic that structurizes selection regions sinks only the conditional branch and keeps the remaining instructions outside the selection region.
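The splitting described above can be sketched in C++ on a hypothetical toy model. The `Block`, `Inst` and `MergeInfo` types and the `.cond` label suffix are illustrative assumptions, not the deserializer's data structures. The sketch follows the same steps as `splitConditionalBlocks` in the diff: skip loop headers, move a trailing conditional branch into a new block, terminate the original block with an unconditional branch to it, and re-key the merge info so that the new block becomes the selection header.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not the MLIR or deserializer data structures).
struct Inst {
  std::string opcode;                // e.g. "Load", "Branch", "BranchConditional"
  std::vector<std::string> targets;  // successor labels for branches
};

struct Block {
  std::string label;
  std::vector<Inst> insts;  // the last instruction is the terminator
};

// Stand-in for the per-header merge information.
struct MergeInfo {
  std::string mergeBlock;
  std::string continueBlock;  // non-empty only for loop headers
};

// For every block that ends in a conditional branch and also holds other
// instructions: move the branch into a new block "<label>.cond" placed right
// after it, terminate the original block with an unconditional branch to the
// new one, and re-key any merge info so the new block becomes the header.
void splitConditionalBlocks(std::vector<Block> &blocks,
                            std::map<std::string, MergeInfo> &mergeInfo) {
  for (std::size_t i = 0; i < blocks.size(); ++i) {
    const std::string label = blocks[i].label;
    auto it = mergeInfo.find(label);
    if (it != mergeInfo.end() && !it->second.continueBlock.empty())
      continue;  // loop headers are left alone, as in the patch

    std::vector<Inst> &insts = blocks[i].insts;
    if (insts.size() <= 1 || insts.back().opcode != "BranchConditional")
      continue;

    Block cond{label + ".cond", {insts.back()}};
    insts.back() = Inst{"Branch", {cond.label}};
    if (it != mergeInfo.end()) {
      MergeInfo info = std::move(it->second);
      mergeInfo.erase(it);
      mergeInfo.emplace(cond.label, std::move(info));
    }
    blocks.insert(blocks.begin() + static_cast<std::ptrdiff_t>(i) + 1,
                  std::move(cond));
    ++i;  // skip the new block: it holds only the branch
  }
}
```

Re-keying the merge info is what makes the structurizer treat the branch-only block as the header; everything before it (in the test, `%color` and the comparison) stays in the predecessor block, outside the region.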
A SPIR-V test is required, as the problem can occur only during deserialization and cannot be tested with `--test-spirv-roundtrip`: an MLIR test exhibiting the problematic behaviour would be invalid MLIR in the first place.
This solution is proposed as an alternative to the unfinished PR #123371, which is unlikely to be merged in the foreseeable future, as its author "stepped away from this for a time being". There is also a Discourse thread, https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494, that tried to solicit feedback on the topic.
cc @mishaobu
---
Full diff: https://github.com/llvm/llvm-project/pull/127639.diff
6 Files Affected:
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+42)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+4)
- (modified) mlir/test/Target/SPIRV/loop.mlir (+2)
- (modified) mlir/test/Target/SPIRV/selection.mlir (+2)
- (added) mlir/test/Target/SPIRV/selection.spv (+40)
- (modified) mlir/test/lit.cfg.py (+1)
``````````diff
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819..ebf2ecee3207a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2158,6 +2158,39 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
return success();
}
+LogicalResult spirv::Deserializer::splitConditionalBlocks() {
+ auto splitBlock = [&](Block *block) {
+ // Do not split loop headers
+ if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
+ if (it->second.continueBlock) {
+ return;
+ }
+ }
+
+ if (!block->mightHaveTerminator())
+ return;
+
+ auto terminator = block->getTerminator();
+ assert(terminator != nullptr);
+
+ if (isa<spirv::BranchConditionalOp>(terminator) &&
+ std::distance(block->begin(), block->end()) > 1) {
+ auto newBlock = block->splitBlock(terminator);
+ OpBuilder builder(block, block->end());
+ builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
+
+ if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
+ auto value = std::move(it->second);
+ blockMergeInfo.erase(it);
+ blockMergeInfo.try_emplace(newBlock, std::move(value));
+ }
+ }
+ };
+ curFunction->walk(splitBlock);
+
+ return success();
+}
+
LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
@@ -2165,6 +2198,15 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
logger.indent();
});
+ LLVM_DEBUG({
+ logger.startLine() << "[cf] split conditional blocks\n";
+ logger.startLine() << "\n";
+ });
+
+ if (failed(splitConditionalBlocks())) {
+ return failure();
+ }
+
while (!blockMergeInfo.empty()) {
Block *headerBlock = blockMergeInfo.begin()->first;
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 264d580c40f09..8dd35aa876726 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -246,6 +246,10 @@ class Deserializer {
return opBuilder.getStringAttr(attrName);
}
+ // Move a conditional branch into a separate basic block to avoid sinking
+ // defs that are required outside a selection region.
+ LogicalResult splitConditionalBlocks();
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index d89600558f56d..dd0d3e1af19dc 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -267,6 +267,8 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
}
// CHECK-NEXT: %[[LOAD:.+]] = spirv.Load "Function" %[[VAR]] : i1
%load = spirv.Load "Function" %var : i1
+// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]
// CHECK-NEXT: spirv.BranchConditional %[[LOAD]], ^[[CONTINUE:.+]](%[[ARG1]] : i64), ^[[LOOP_MERGE:.+]]
spirv.BranchConditional %load, ^continue(%arg1 : i64), ^loop_merge
// CHECK-NEXT: ^[[CONTINUE]](%[[ARG2:.+]]: i64):
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index f1d35d74dba15..24abb12998d06 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%var = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
// CHECK-NEXT: ^[[BB]]:
+// CHECK: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
// CHECK-NEXT: spirv.mlir.selection {
spirv.mlir.selection {
diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spv
new file mode 100644
index 0000000000000..b96e839f5c805
--- /dev/null
+++ b/mlir/test/Target/SPIRV/selection.spv
@@ -0,0 +1,40 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+; CHECK: spirv.module
+ OpCapability Shader
+ %2 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %colorOut
+ OpExecutionMode %main OriginUpperLeft
+ OpDecorate %colorOut Location 0
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+%fun_v4float = OpTypePointer Function %v4float
+ %float_1 = OpConstant %float 1
+ %float_0 = OpConstant %float 0
+ %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+%out_v4float = OpTypePointer Output %v4float
+ %colorOut = OpVariable %out_v4float Output
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %out_float = OpTypePointer Output %float
+ %bool = OpTypeBool
+ %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
+ %main = OpFunction %void None %4
+ %6 = OpLabel
+ %color = OpVariable %fun_v4float Function
+ OpStore %color %13
+ %19 = OpAccessChain %out_float %colorOut %uint_0
+ %20 = OpLoad %float %19
+ %22 = OpFOrdEqual %bool %20 %float_1
+ OpSelectionMerge %24 None
+ OpBranchConditional %22 %23 %24
+ %23 = OpLabel
+ OpStore %color %25
+ OpBranch %24
+ %24 = OpLabel
+ %26 = OpLoad %v4float %color
+ OpStore %colorOut %26
+ OpReturn
+ OpFunctionEnd
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 32b2f8b53d5fa..c447a047eea89 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -43,6 +43,7 @@
".test",
".pdll",
".c",
+ ".spv"
]
# test_source_root: The root path where tests are located.
``````````
</details>
https://github.com/llvm/llvm-project/pull/127639