[Mlir-commits] [mlir] 730f4a1 - [mlir][spirv] Split header and merge block in `mlir.selection`s (#134875)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 30 05:38:40 PDT 2025


Author: Igor Wodiany
Date: 2025-04-30T13:38:36+01:00
New Revision: 730f4a1ab3b642aa0ce1c88a9a2f80efbbb33539

URL: https://github.com/llvm/llvm-project/commit/730f4a1ab3b642aa0ce1c88a9a2f80efbbb33539
DIFF: https://github.com/llvm/llvm-project/commit/730f4a1ab3b642aa0ce1c88a9a2f80efbbb33539.diff

LOG: [mlir][spirv] Split header and merge block in `mlir.selection`s (#134875)

In the example below, with the current code, the first selection construct
(`if`/`else` in GLSL for simplicity) shares its merge block with the header
block of the second construct.

```
bool _115;
if (_107)
{
    // ...
    _115 = _200 < _174;
}
else
{
    _115 = _107;
}
bool _123;
if (_115)
{
    // ...
    _123 = _213 < _174;
}
else
{
    _123 = _115;
}
```

This results in malformed nesting of `mlir.selection` instructions,
where one selection ends up inside the header block of another selection
construct. For example:

```
%61 = spirv.mlir.selection -> i1 {
  %80 = spirv.mlir.selection -> i1 {
    spirv.BranchConditional %60, ^bb1, ^bb2(%60 : i1)
  ^bb1:  // pred: ^bb0
    // ...
    spirv.Branch ^bb2(%101 : i1)
  ^bb2(%102: i1):  // 2 preds: ^bb0, ^bb1
    spirv.mlir.merge %102 : i1
  }
  spirv.BranchConditional %80, ^bb1, ^bb2(%80 : i1)
^bb1:  // pred: ^bb0
  // ...
  spirv.Branch ^bb2(%90 : i1)
^bb2(%91: i1):  // 2 preds: ^bb0, ^bb1
  spirv.mlir.merge %91 : i1
}
```

This change ensures that the merge block of one selection is not a
header block of another, splitting blocks if necessary. The existing
block splitting mechanism is updated to handle this case.

Added: 
    mlir/test/Target/SPIRV/consecutive-selection.spv

Modified: 
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 46607c3e6a98f..b0220bc16e15e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2300,23 +2300,22 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
     if (!isa<spirv::BranchConditionalOp>(terminator))
       continue;
 
-    // Do not split blocks that only contain a conditional branch, i.e., block
-    // size is <= 1.
-    if (block->begin() != block->end() &&
-        std::next(block->begin()) != block->end()) {
+    // Check if the current header block is a merge block of another construct.
+    bool splitHeaderMergeBlock = false;
+    for (const auto &[_, mergeInfo] : blockMergeInfo) {
+      if (mergeInfo.mergeBlock == block)
+        splitHeaderMergeBlock = true;
+    }
+
+    // Do not split a block that only contains a conditional branch, unless it
+    // is also a merge block of another construct - in that case we want to
+    // split the block. We do not want two constructs to share header / merge
+    // block.
+    if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
       Block *newBlock = block->splitBlock(terminator);
       OpBuilder builder(block, block->end());
       builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
 
-      // If the split block was a merge block of another region we need to
-      // update the map.
-      for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) {
-        auto &[ignore, mergeInfo] = *it;
-        if (mergeInfo.mergeBlock == block) {
-          mergeInfo.mergeBlock = newBlock;
-        }
-      }
-
       // After splitting we need to update the map to use the new block as a
       // header.
       blockMergeInfo.erase(block);

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 8dd35aa876726..bcc78e3e6508d 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -246,8 +246,10 @@ class Deserializer {
     return opBuilder.getStringAttr(attrName);
   }
 
-  // Move a conditional branch into a separate basic block to avoid sinking
-  // defs that are required outside a selection region.
+  /// Move a conditional branch into a separate basic block to avoid unnecessary
+  /// sinking of defs that may be required outside a selection region. This
+  /// function also ensures that a single block cannot be a header block of one
+  /// selection construct and the merge block of another.
   LogicalResult splitConditionalBlocks();
 
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/SPIRV/consecutive-selection.spv b/mlir/test/Target/SPIRV/consecutive-selection.spv
new file mode 100644
index 0000000000000..37520586d041b
--- /dev/null
+++ b/mlir/test/Target/SPIRV/consecutive-selection.spv
@@ -0,0 +1,71 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+
+; COM: The purpose of this test is to check that in the case where two selection
+; COM: regions share a header / merge block, this block is split and the selection
+; COM: regions are not incorrectly nested.
+
+; CHECK:      spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+; CHECK:        spirv.func @main() "None" {
+; CHECK:          spirv.mlir.selection {
+; CHECK-NEXT:       spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
+; CHECK-NEXT:     ^[[bb:.+]]
+; CHECK:            spirv.Branch ^[[bb:.+]]
+; CHECK-NEXT:     ^[[bb:.+]]:
+; CHECK-NEXT:       spirv.mlir.merge
+; CHECK-NEXT:     }
+; CHECK:          spirv.mlir.selection {
+; CHECK-NEXT:       spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
+; CHECK-NEXT:     ^[[bb:.+]]
+; CHECK:            spirv.Branch ^[[bb:.+]]
+; CHECK-NEXT:     ^[[bb:.+]]:
+; CHECK-NEXT:       spirv.mlir.merge
+; CHECK-NEXT:     }
+; CHECK:          spirv.Return
+; CHECK-NEXT:   }
+; CHECK:      }
+
+               OpCapability Shader
+          %2 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint Fragment %main "main" %colorOut
+               OpExecutionMode %main OriginUpperLeft
+               OpDecorate %colorOut Location 0
+       %void = OpTypeVoid
+          %4 = OpTypeFunction %void
+      %float = OpTypeFloat 32
+    %v4float = OpTypeVector %float 4
+%fun_v4float = OpTypePointer Function %v4float
+    %float_1 = OpConstant %float 1
+    %float_0 = OpConstant %float 0
+         %13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
+%out_v4float = OpTypePointer Output %v4float
+   %colorOut = OpVariable %out_v4float Output   
+       %uint = OpTypeInt 32 0
+     %uint_0 = OpConstant %uint 0
+  %out_float = OpTypePointer Output %float
+       %bool = OpTypeBool
+         %25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
+       %main = OpFunction %void None %4
+          %6 = OpLabel
+      %color = OpVariable %fun_v4float Function
+               OpStore %color %13
+         %19 = OpAccessChain %out_float %colorOut %uint_0
+         %20 = OpLoad %float %19
+         %22 = OpFOrdEqual %bool %20 %float_1
+               OpSelectionMerge %24 None
+               OpBranchConditional %22 %23 %24
+         %23 = OpLabel
+               OpStore %color %25
+               OpBranch %24
+         %24 = OpLabel
+         %30 = OpFOrdEqual %bool %20 %float_1
+               OpSelectionMerge %32 None
+               OpBranchConditional %30 %31 %32
+         %31 = OpLabel
+               OpStore %color %25
+               OpBranch %32
+         %32 = OpLabel
+         %26 = OpLoad %v4float %color
+               OpStore %colorOut %26
+               OpReturn
+               OpFunctionEnd


        


More information about the Mlir-commits mailing list