[Mlir-commits] [mlir] [mlir][spirv] Support (de)serialization of block operands in `spirv.Switch` (PR #168899)

Igor Wodiany llvmlistbot at llvm.org
Thu Nov 20 08:23:26 PST 2025


https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/168899

(No pull-request description was provided.)

From 1f6bd906d7e2faf45f4f8f1631a90b40055e891e Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Thu, 20 Nov 2025 15:52:35 +0000
Subject: [PATCH] [mlir][spirv] Support (de)serialization of block operands in
 `spirv.Switch`

---
 .../SPIRV/Deserialization/Deserializer.cpp    | 17 ++++++
 .../Target/SPIRV/Serialization/Serializer.cpp | 15 ++++-
 mlir/test/Target/SPIRV/selection.mlir         | 58 +++++++++++++++++++
 3 files changed, 89 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 252be796488c5..c91e7fc0e748c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2831,6 +2831,23 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
             branchCondOp.getFalseBlock());
 
       branchCondOp.erase();
+    } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
+      if (target == switchOp.getDefaultTarget()) {
+        SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
+        DenseIntElementsAttr literals =
+            switchOp.getLiterals().value_or(DenseIntElementsAttr());
+        spirv::SwitchOp::create(
+            opBuilder, switchOp.getLoc(), switchOp.getSelector(),
+            switchOp.getDefaultTarget(), blockArgs, literals,
+            switchOp.getTargets(), targetOperands);
+        switchOp.erase();
+      } else {
+        SuccessorRange targets = switchOp.getTargets();
+        auto it = llvm::find(targets, target);
+        assert(it != targets.end());
+        size_t index = std::distance(targets.begin(), it);
+        switchOp.getTargetOperandsMutable(index).assign(blockArgs);
+      }
     } else {
       return emitError(unknownLoc, "unimplemented terminator for Phi creation");
     }
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4e03a809bd0bc..153e9c770e464 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1443,7 +1443,20 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
         assert(branchCondOp.getFalseTarget() == block);
         blockOperands = branchCondOp.getFalseTargetOperands();
       }
-
+      assert(!blockOperands->empty() &&
+             "expected non-empty block operand range");
+      predecessors.emplace_back(spirvPredecessor, *blockOperands);
+    } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
+      std::optional<OperandRange> blockOperands;
+      if (block == switchOp.getDefaultTarget()) {
+        blockOperands = switchOp.getDefaultOperands();
+      } else {
+        SuccessorRange targets = switchOp.getTargets();
+        auto it = llvm::find(targets, block);
+        assert(it != targets.end());
+        size_t index = std::distance(targets.begin(), it);
+        blockOperands = switchOp.getTargetOperands(index);
+      }
       assert(!blockOperands->empty() &&
              "expected non-empty block operand range");
       predecessors.emplace_back(spirvPredecessor, *blockOperands);
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 3f762920015aa..d0ad118b01c8a 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -288,3 +288,61 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   spirv.EntryPoint "GLCompute" @main
   spirv.ExecutionMode @main "LocalSize", 1, 1, 1
 }
+
+// -----
+
+// Selection with switch and block operands
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader], []> {
+// CHECK-LABEL: @selection_switch_operands
+  spirv.func @selection_switch_operands(%selector : si32) "None" {
+    %cst1 = spirv.Constant 1.000000e+00 : f32
+    %vec0 = spirv.Undef : vector<3xf32>
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
+    %vec1 = spirv.CompositeInsert %cst1, %vec0[0 : i32] : f32 into vector<3xf32>
+    spirv.Branch ^bb1
+  ^bb1:
+// CHECK: {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
+    %vec4 = spirv.mlir.selection -> vector<3xf32> {
+// CHECK-NEXT: spirv.Switch {{%.*}} : si32, [
+// CHECK-NEXT: default: ^[[DEFAULT:.+]]({{%.*}} : vector<3xf32>),
+// CHECK-NEXT: 0: ^[[CASE0:.+]]({{%.*}} : vector<3xf32>),
+// CHECK-NEXT: 1: ^[[CASE1:.+]]({{%.*}} : vector<3xf32>)
+      spirv.Switch %selector : si32, [
+        default: ^bb3(%vec1 : vector<3xf32>),
+        0: ^bb1(%vec1 : vector<3xf32>),
+        1: ^bb2(%vec1 : vector<3xf32>)
+      ]
+// CHECK: ^[[CASE0]]({{%.*}}: vector<3xf32>)
+    ^bb1(%vecbb1: vector<3xf32>):
+      %cst3 = spirv.Constant 3.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
+      %vec2 = spirv.CompositeInsert %cst3, %vecbb1[1 : i32] : f32 into vector<3xf32>
+// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
+      spirv.Branch ^bb3(%vec2 : vector<3xf32>)
+// CHECK-NEXT: ^[[CASE1]]({{%.*}}: vector<3xf32>)
+    ^bb2(%vecbb2: vector<3xf32>):
+      %cst4 = spirv.Constant 4.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
+      %vec3 = spirv.CompositeInsert %cst4, %vecbb2[1 : i32] : f32 into vector<3xf32>
+// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
+      spirv.Branch ^bb3(%vec3 : vector<3xf32>)
+// CHECK-NEXT: ^[[DEFAULT]]({{%.*}}: vector<3xf32>)
+    ^bb3(%vecbb3: vector<3xf32>):
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : vector<3xf32>
+      spirv.mlir.merge %vecbb3 : vector<3xf32>
+// CHECK-NEXT: }
+    }
+    %cst2 = spirv.Constant 2.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[2 : i32] : f32 into vector<3xf32>
+    %vec5 = spirv.CompositeInsert %cst2, %vec4[2 : i32] : f32 into vector<3xf32>
+    spirv.Return
+  }
+
+  spirv.func @main() -> () "None" {
+    spirv.Return
+  }
+
+  spirv.EntryPoint "GLCompute" @main
+  spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}



More information about the Mlir-commits mailing list