[Mlir-commits] [mlir] 594919c - [mlir][spirv] Split conditional basic blocks during deserialization (#127639)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 24 14:22:01 PST 2025
Author: Igor Wodiany
Date: 2025-02-24T17:21:58-05:00
New Revision: 594919c263122e1d0468dfecee6eb5962e892b44
URL: https://github.com/llvm/llvm-project/commit/594919c263122e1d0468dfecee6eb5962e892b44
DIFF: https://github.com/llvm/llvm-project/commit/594919c263122e1d0468dfecee6eb5962e892b44.diff
LOG: [mlir][spirv] Split conditional basic blocks during deserialization (#127639)
With the current design, some values are sunk into a selection region
even though they are also used outside that region. This is because the
current deserializer logic sinks the entire basic block that ends in the
conditional branch forming the header of a selection construct, without
accounting for values in that block being used outside the construct.
This manifests as (for example):
```
<unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct
<unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function>
```
The proposed solution is to split the conditional basic block into two:
one block containing just the conditional branch, and the other containing
the rest of the instructions. This way, the logic that structurizes
selection regions sinks only the block holding the conditional branch,
keeping the rest of the instructions (including the comparison that
computes the condition) outside the selection region.
A SPIR-V test is required, as the problem can happen only during
deserialization and cannot be tested with `--test-spirv-roundtrip`. An
MLIR test exhibiting the problematic behaviour would be invalid MLIR in
the first place.
This solution is proposed as an alternative to the unfinished PR #123371,
which is unlikely to be merged in the foreseeable future, as the author
"stepped away from this for a time being". There is also a Discourse
thread:
https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494
that tried to solicit some feedback on the topic.
Added:
mlir/test/Target/SPIRV/selection.spv
Modified:
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/test/Target/SPIRV/selection.mlir
mlir/test/lit.cfg.py
Removed:
################################################################################
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 04469f1933819..8ebe8d54b041c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2158,6 +2158,53 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
return success();
}
+/// Splits every selection-header block that ends in a conditional branch and
+/// also holds other ops into two blocks:
+///   * the original `block`, which keeps all non-terminator ops and now ends
+///     in an unconditional `spirv.Branch` to the new block, and
+///   * a new block holding only the `spirv.BranchConditional`, which becomes
+///     the construct header in `blockMergeInfo`.
+/// This prevents structurization from sinking defs into the selection region
+/// when they are also used outside of it. Loop headers (entries with a
+/// non-null continue block) are left untouched.
+LogicalResult spirv::Deserializer::splitConditionalBlocks() {
+  // Iterate over a snapshot: the body re-keys entries of the live map
+  // (`blockMergeInfo`), which would invalidate iterators into it.
+  BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
+  for (auto &[block, snapshotInfo] : blockMergeInfoCopy) {
+    // Skip loop regions. For loop regions continueBlock is non-null.
+    if (snapshotInfo.continueBlock)
+      continue;
+
+    if (!block->mightHaveTerminator())
+      continue;
+
+    Operation *terminator = block->getTerminator();
+    assert(terminator);
+
+    if (!isa<spirv::BranchConditionalOp>(terminator))
+      continue;
+
+    // Do not split blocks that only contain the conditional branch: there is
+    // nothing to keep outside the selection region.
+    if (&block->front() == terminator)
+      continue;
+
+    // Keep every op except the terminator in `block`, move the conditional
+    // branch into `newBlock`, and connect the two with an unconditional branch.
+    Block *newBlock = block->splitBlock(terminator);
+    OpBuilder builder(block, block->end());
+    builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
+
+    // If the split block was the merge block of another construct, redirect
+    // that construct to the new block.
+    for (auto &entry : blockMergeInfo) {
+      if (entry.second.mergeBlock == block)
+        entry.second.mergeBlock = newBlock;
+    }
+
+    // Re-key the entry so `newBlock` becomes the header. The value is taken
+    // from the live map, not from the snapshot: an earlier iteration may have
+    // already redirected this construct's merge block (loop above), and
+    // copying the stale snapshot value would undo that update, making the
+    // result depend on map iteration order.
+    auto liveIt = blockMergeInfo.find(block);
+    assert(liveIt != blockMergeInfo.end() &&
+           "header present in snapshot must still be in the live map");
+    BlockMergeInfo liveInfo = liveIt->second;
+    blockMergeInfo.erase(liveIt);
+    blockMergeInfo.try_emplace(newBlock, liveInfo);
+  }
+
+  return success();
+}
+
LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
@@ -2165,6 +2212,18 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
logger.indent();
});
+ LLVM_DEBUG({
+ logger.startLine() << "[cf] split conditional blocks\n";
+ logger.startLine() << "\n";
+ });
+
+ if (failed(splitConditionalBlocks())) {
+ return failure();
+ }
+
+ // TODO: This loop is non-deterministic. Iteration order may vary between runs
+ // for the same shader as the key to the map is a pointer. See:
+ // https://github.com/llvm/llvm-project/issues/128547
while (!blockMergeInfo.empty()) {
Block *headerBlock = blockMergeInfo.begin()->first;
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 264d580c40f09..8dd35aa876726 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -246,6 +246,10 @@ class Deserializer {
return opBuilder.getStringAttr(attrName);
}
+ // Moves the conditional branch that terminates a selection-header block
+ // into a separate basic block, so that structurization only sinks that
+ // branch and not defs that are also used outside the selection region.
+ // The original block keeps the remaining ops and ends in an unconditional
+ // branch to the new block; the new block becomes the header in
+ // `blockMergeInfo`, and any construct whose merge block was the split block
+ // is redirected to the new block. Loop headers and blocks that contain only
+ // the conditional branch are not split.
+ LogicalResult splitConditionalBlocks();
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index f1d35d74dba15..24abb12998d06 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%var = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
// CHECK-NEXT: ^[[BB]]:
+// CHECK: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
// CHECK-NEXT: spirv.mlir.selection {
spirv.mlir.selection {
diff --git a/mlir/test/Target/SPIRV/selection.spv b/mlir/test/Target/SPIRV/selection.spv
new file mode 100644
index 0000000000000..9642d0a44fb59
--- /dev/null
+++ b/mlir/test/Target/SPIRV/selection.spv
@@ -0,0 +1,60 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+
+; COM: The purpose of this test is to check that a variable (in this case %color) that
+; COM: is defined before a selection region and used both in the selection region and
+; COM: after it, is not sunk into that selection region by the deserializer. If the
+; COM: variable is sunk, then it cannot be accessed outside the region and causes
+; COM: control-flow structurization to fail.
+
+; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+; CHECK: spirv.func @main() "None" {
+; CHECK: spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
+; CHECK: spirv.mlir.selection {
+; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
+; CHECK-NEXT: ^[[bb:.+]]
+; CHECK: spirv.Branch ^[[bb:.+]]
+; CHECK-NEXT: ^[[bb:.+]]:
+; CHECK-NEXT: spirv.mlir.merge
+; CHECK-NEXT: }
+; CHECK: spirv.Return
+; CHECK-NEXT: }
+; CHECK: }
+
+ OpCapability Shader
+ %2 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %colorOut
+ OpExecutionMode %main OriginUpperLeft
+ OpDecorate %colorOut Location 0
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+%fun_v4float = OpTypePointer Function %v4float
+ %float_1 = OpConstant %float 1
+ %float_0 = OpConstant %float 0
+ %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+%out_v4float = OpTypePointer Output %v4float
+ %colorOut = OpVariable %out_v4float Output
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %out_float = OpTypePointer Output %float
+ %bool = OpTypeBool
+ %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
+ %main = OpFunction %void None %4
+; Entry block (%6): defines %color and stores to it, then computes the
+; condition %22 and ends in OpBranchConditional. It is the header of the
+; selection construct whose merge block is %24. Because %color is used both
+; in %23 and after the construct (in %24), it must stay outside the region.
+ %6 = OpLabel
+ %color = OpVariable %fun_v4float Function
+ OpStore %color %13
+ %19 = OpAccessChain %out_float %colorOut %uint_0
+ %20 = OpLoad %float %19
+ %22 = OpFOrdEqual %bool %20 %float_1
+ OpSelectionMerge %24 None
+ OpBranchConditional %22 %23 %24
+; %23: taken when %22 is true; overwrites %color inside the selection.
+ %23 = OpLabel
+ OpStore %color %25
+ OpBranch %24
+; %24: merge block; reads %color after the selection construct.
+ %24 = OpLabel
+ %26 = OpLoad %v4float %color
+ OpStore %colorOut %26
+ OpReturn
+ OpFunctionEnd
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 32b2f8b53d5fa..8578c76969e74 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -43,6 +43,7 @@
".test",
".pdll",
".c",
+ ".spv",
]
# test_source_root: The root path where tests are located.
More information about the Mlir-commits
mailing list