[Mlir-commits] [mlir] [mlir][spirv] Enable block splitting for `spirv.Switch` (PR #170147)

Igor Wodiany llvmlistbot at llvm.org
Mon Dec 1 06:49:38 PST 2025


https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/170147

This is not strictly necessary now that selection regions can yield values; however, splitting the block simplifies the code because it avoids sinking values unnecessarily only to yield them later.

>From 9a6fcb950cf848ad91161c747e463a6961917363 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Thu, 20 Nov 2025 17:17:06 +0000
Subject: [PATCH] [mlir][spirv] Enable block splitting for `spirv.Switch`

This is not strictly necessary now that selection regions can yield
values; however, splitting the block simplifies the code because it
avoids sinking values unnecessarily only to yield them later.
---
 .../SPIRV/Deserialization/Deserializer.cpp    | 14 ++--
 .../SPIRV/Deserialization/Deserializer.h      | 10 +--
 mlir/test/Target/SPIRV/selection_switch.spv   | 69 +++++++++++++++++++
 3 files changed, 81 insertions(+), 12 deletions(-)
 create mode 100644 mlir/test/Target/SPIRV/selection_switch.spv

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 00057a2269895..0dc3f27781bd3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2868,7 +2868,7 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
   return success();
 }
 
-LogicalResult spirv::Deserializer::splitConditionalBlocks() {
+LogicalResult spirv::Deserializer::splitSelectionHeader() {
   // Create a copy, so we can modify keys in the original.
   BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
   for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
@@ -2885,7 +2885,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
     Operation *terminator = block->getTerminator();
     assert(terminator);
 
-    if (!isa<spirv::BranchConditionalOp>(terminator))
+    if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
       continue;
 
     // Check if the current header block is a merge block of another construct.
@@ -2895,10 +2895,10 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
         splitHeaderMergeBlock = true;
     }
 
-    // Do not split a block that only contains a conditional branch, unless it
-    // is also a merge block of another construct - in that case we want to
-    // split the block. We do not want two constructs to share header / merge
-    // block.
+    // Do not split a block that only contains a conditional branch / switch,
+    // unless it is also a merge block of another construct - in that case we
+    // want to split the block. We do not want two constructs to share header /
+    // merge block.
     if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
       Block *newBlock = block->splitBlock(terminator);
       OpBuilder builder(block, block->end());
@@ -2936,7 +2936,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
     logger.startLine() << "\n";
   });
 
-  if (failed(splitConditionalBlocks())) {
+  if (failed(splitSelectionHeader())) {
     return failure();
   }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 6d09d556c4d02..50c935036158c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -280,11 +280,11 @@ class Deserializer {
     return opBuilder.getStringAttr(attrName);
   }
 
-  /// Move a conditional branch into a separate basic block to avoid unnecessary
-  /// sinking of defs that may be required outside a selection region. This
-  /// function also ensures that a single block cannot be a header block of one
-  /// selection construct and the merge block of another.
-  LogicalResult splitConditionalBlocks();
+  /// Move a conditional branch or a switch into a separate basic block to avoid
+  /// unnecessary sinking of defs that may be required outside a selection
+  /// region. This function also ensures that a single block cannot be a header
+  /// block of one selection construct and the merge block of another.
+  LogicalResult splitSelectionHeader();
 
   //===--------------------------------------------------------------------===//
   // Type
diff --git a/mlir/test/Target/SPIRV/selection_switch.spv b/mlir/test/Target/SPIRV/selection_switch.spv
new file mode 100644
index 0000000000000..81fecf307eb7a
--- /dev/null
+++ b/mlir/test/Target/SPIRV/selection_switch.spv
@@ -0,0 +1,69 @@
+; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
+
+; This test is analogous to selection.spv but tests switch op.
+
+; CHECK:      spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+; CHECK-NEXT:   spirv.func @switch({{%.*}}: si32) "None" {
+; CHECK:          {{%.*}} = spirv.Constant 1.000000e+00 : f32
+; CHECK-NEXT:     {{%.*}} = spirv.Undef : vector<3xf32>
+; CHECK-NEXT:     {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
+; CHECK-NEXT:     spirv.Branch ^[[bb:.+]]
+; CHECK-NEXT:   ^[[bb:.+]]:
+; CHECK-NEXT:     {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
+; CHECK-NEXT:     spirv.Switch {{%.*}} : si32, [
+; CHECK-NEXT:       default: ^[[bb:.+]]({{%.*}}: vector<3xf32>),
+; CHECK-NEXT:       0: ^[[bb:.+]]({{%.*}}: vector<3xf32>),
+; CHECK-NEXT:       1: ^[[bb:.+]]({{%.*}}: vector<3xf32>)
+; CHECK:          ^[[bb:.+]]({{%.*}}: vector<3xf32>):
+; CHECK:            spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>)
+; CHECK-NEXT:     ^[[bb:.+]]({{%.*}}: vector<3xf32>):
+; CHECK:            spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>)
+; CHECK-NEXT:     ^[[bb:.+]]({{%.*}}: vector<3xf32>):
+; CHECK-NEXT:       spirv.mlir.merge {{%.*}} : vector<3xf32>
+; CHECK-NEXT:     }
+; CHECK:          spirv.Return
+; CHECK-NEXT:   }
+; CHECK:      }
+
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpName %switch "switch"
+               OpName %main "main"
+       %void = OpTypeVoid
+        %int = OpTypeInt 32 1
+          %1 = OpTypeFunction %void %int
+      %float = OpTypeFloat 32
+    %float_1 = OpConstant %float 1
+    %v3float = OpTypeVector %float 3
+          %9 = OpUndef %v3float
+    %float_3 = OpConstant %float 3
+    %float_4 = OpConstant %float 4
+    %float_2 = OpConstant %float 2
+         %25 = OpTypeFunction %void
+     %switch = OpFunction %void None %1
+          %5 = OpFunctionParameter %int
+          %6 = OpLabel
+               OpBranch %12
+         %12 = OpLabel
+         %11 = OpCompositeInsert %v3float %float_1 %9 0
+               OpSelectionMerge %15 None
+               OpSwitch %5 %15 0 %13 1 %14
+         %13 = OpLabel
+         %16 = OpPhi %v3float %11 %12
+         %18 = OpCompositeInsert %v3float %float_3 %16 1
+               OpBranch %15
+         %14 = OpLabel
+         %19 = OpPhi %v3float %11 %12
+         %21 = OpCompositeInsert %v3float %float_4 %19 1
+               OpBranch %15
+         %15 = OpLabel
+         %22 = OpPhi %v3float %21 %14 %18 %13 %11 %12
+         %24 = OpCompositeInsert %v3float %float_2 %22 2
+               OpReturn
+               OpFunctionEnd
+       %main = OpFunction %void None %25
+         %27 = OpLabel
+               OpReturn
+               OpFunctionEnd



More information about the Mlir-commits mailing list