[Mlir-commits] [mlir] 730f4a1 - [mlir][spirv] Split header and merge block in `mlir.selection`s (#134875)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 30 05:38:40 PDT 2025
Author: Igor Wodiany
Date: 2025-04-30T13:38:36+01:00
New Revision: 730f4a1ab3b642aa0ce1c88a9a2f80efbbb33539
URL: https://github.com/llvm/llvm-project/commit/730f4a1ab3b642aa0ce1c88a9a2f80efbbb33539
DIFF: https://github.com/llvm/llvm-project/commit/730f4a1ab3b642aa0ce1c88a9a2f80efbbb33539.diff
LOG: [mlir][spirv] Split header and merge block in `mlir.selection`s (#134875)
In the example below, with the current code, the first selection construct
(`if`/`else` in GLSL for simplicity) shares its merge block with the header
block of the second construct.
```
bool _115;
if (_107)
{
// ...
_115 = _200 < _174;
}
else
{
_115 = _107;
}
bool _123;
if (_115)
{
// ...
_123 = _213 < _174;
}
else
{
_123 = _115;
}
```
This results in malformed nesting of `mlir.selection` instructions,
where one selection ends up inside the header block of another selection
construct. For example:
```
%61 = spirv.mlir.selection -> i1 {
%80 = spirv.mlir.selection -> i1 {
spirv.BranchConditional %60, ^bb1, ^bb2(%60 : i1)
^bb1: // pred: ^bb0
// ...
spirv.Branch ^bb2(%101 : i1)
^bb2(%102: i1): // 2 preds: ^bb0, ^bb1
spirv.mlir.merge %102 : i1
}
spirv.BranchConditional %80, ^bb1, ^bb2(%80 : i1)
^bb1: // pred: ^bb0
// ...
spirv.Branch ^bb2(%90 : i1)
^bb2(%91: i1): // 2 preds: ^bb0, ^bb1
spirv.mlir.merge %91 : i1
}
```
This change ensures that the merge block of one selection is never also
the header block of another, splitting blocks where necessary. The existing
block-splitting mechanism is updated to handle this case.
Added:
mlir/test/Target/SPIRV/consecutive-selection.spv
Modified:
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
Removed:
################################################################################
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 46607c3e6a98f..b0220bc16e15e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2300,23 +2300,22 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
if (!isa<spirv::BranchConditionalOp>(terminator))
continue;
- // Do not split blocks that only contain a conditional branch, i.e., block
- // size is <= 1.
- if (block->begin() != block->end() &&
- std::next(block->begin()) != block->end()) {
+ // Check if the current header block is a merge block of another construct.
+ bool splitHeaderMergeBlock = false;
+ for (const auto &[_, mergeInfo] : blockMergeInfo) {
+ if (mergeInfo.mergeBlock == block)
+ splitHeaderMergeBlock = true;
+ }
+
+ // Do not split a block that only contains a conditional branch, unless it
+ // is also the merge block of another construct - in that case we must
+ // split the block, because two constructs must not share a single header /
+ // merge block.
+ if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
Block *newBlock = block->splitBlock(terminator);
OpBuilder builder(block, block->end());
builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
- // If the split block was a merge block of another region we need to
- // update the map.
- for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) {
- auto &[ignore, mergeInfo] = *it;
- if (mergeInfo.mergeBlock == block) {
- mergeInfo.mergeBlock = newBlock;
- }
- }
-
// After splitting we need to update the map to use the new block as a
// header.
blockMergeInfo.erase(block);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 8dd35aa876726..bcc78e3e6508d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -246,8 +246,10 @@ class Deserializer {
return opBuilder.getStringAttr(attrName);
}
- // Move a conditional branch into a separate basic block to avoid sinking
- // defs that are required outside a selection region.
+ /// Move a conditional branch into a separate basic block to avoid unnecessary
+ /// sinking of defs that may be required outside a selection region. This
+ /// function also ensures that a single block cannot be a header block of one
+ /// selection construct and the merge block of another.
LogicalResult splitConditionalBlocks();
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/consecutive-selection.spv b/mlir/test/Target/SPIRV/consecutive-selection.spv
new file mode 100644
index 0000000000000..37520586d041b
--- /dev/null
+++ b/mlir/test/Target/SPIRV/consecutive-selection.spv
@@ -0,0 +1,71 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+
+; COM: The purpose of this test is to check that, when two selection regions
+; COM: share a header / merge block, this block is split and the selection
+; COM: regions are not incorrectly nested.
+
+; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+; CHECK: spirv.func @main() "None" {
+; CHECK: spirv.mlir.selection {
+; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
+; CHECK-NEXT: ^[[bb:.+]]
+; CHECK: spirv.Branch ^[[bb:.+]]
+; CHECK-NEXT: ^[[bb:.+]]:
+; CHECK-NEXT: spirv.mlir.merge
+; CHECK-NEXT: }
+; CHECK: spirv.mlir.selection {
+; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
+; CHECK-NEXT: ^[[bb:.+]]
+; CHECK: spirv.Branch ^[[bb:.+]]
+; CHECK-NEXT: ^[[bb:.+]]:
+; CHECK-NEXT: spirv.mlir.merge
+; CHECK-NEXT: }
+; CHECK: spirv.Return
+; CHECK-NEXT: }
+; CHECK: }
+
+ OpCapability Shader
+ %2 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint Fragment %main "main" %colorOut
+ OpExecutionMode %main OriginUpperLeft
+ OpDecorate %colorOut Location 0
+ %void = OpTypeVoid
+ %4 = OpTypeFunction %void
+ %float = OpTypeFloat 32
+ %v4float = OpTypeVector %float 4
+%fun_v4float = OpTypePointer Function %v4float
+ %float_1 = OpConstant %float 1
+ %float_0 = OpConstant %float 0
+ %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+%out_v4float = OpTypePointer Output %v4float
+ %colorOut = OpVariable %out_v4float Output
+ %uint = OpTypeInt 32 0
+ %uint_0 = OpConstant %uint 0
+ %out_float = OpTypePointer Output %float
+ %bool = OpTypeBool
+ %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
+ %main = OpFunction %void None %4
+ %6 = OpLabel
+ %color = OpVariable %fun_v4float Function
+ OpStore %color %13
+ %19 = OpAccessChain %out_float %colorOut %uint_0
+ %20 = OpLoad %float %19
+ %22 = OpFOrdEqual %bool %20 %float_1
+ OpSelectionMerge %24 None
+ OpBranchConditional %22 %23 %24
+ %23 = OpLabel
+ OpStore %color %25
+ OpBranch %24
+ %24 = OpLabel
+ %30 = OpFOrdEqual %bool %20 %float_1
+ OpSelectionMerge %32 None
+ OpBranchConditional %30 %31 %32
+ %31 = OpLabel
+ OpStore %color %25
+ OpBranch %32
+ %32 = OpLabel
+ %26 = OpLoad %v4float %color
+ OpStore %colorOut %26
+ OpReturn
+ OpFunctionEnd
More information about the Mlir-commits
mailing list