[Mlir-commits] [mlir] [mlir][spirv] Support (de)serialization of block operands in `spirv.Switch` (PR #168899)
Igor Wodiany
llvmlistbot at llvm.org
Thu Nov 20 08:23:26 PST 2025
https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/168899
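(No description was provided with the pull request. Summary of the patch below: it teaches the SPIR-V deserializer's `wireUpBlockArgument` and the serializer's `emitPhiForBlockArguments` to handle `spirv.Switch` terminators that pass block operands to their default and case targets, and adds a round-trip test to `mlir/test/Target/SPIRV/selection.mlir`.)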
>From 1f6bd906d7e2faf45f4f8f1631a90b40055e891e Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Thu, 20 Nov 2025 15:52:35 +0000
Subject: [PATCH] [mlir][spirv] Support (de)serialization of block operands in
`spirv.Switch`
---
.../SPIRV/Deserialization/Deserializer.cpp | 17 ++++++
.../Target/SPIRV/Serialization/Serializer.cpp | 15 ++++-
mlir/test/Target/SPIRV/selection.mlir | 58 +++++++++++++++++++
3 files changed, 89 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 252be796488c5..c91e7fc0e748c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2831,6 +2831,23 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
branchCondOp.getFalseBlock());
branchCondOp.erase();
+ } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
+ if (target == switchOp.getDefaultTarget()) {
+ SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
+ DenseIntElementsAttr literals =
+ switchOp.getLiterals().value_or(DenseIntElementsAttr());
+ spirv::SwitchOp::create(
+ opBuilder, switchOp.getLoc(), switchOp.getSelector(),
+ switchOp.getDefaultTarget(), blockArgs, literals,
+ switchOp.getTargets(), targetOperands);
+ switchOp.erase();
+ } else {
+ SuccessorRange targets = switchOp.getTargets();
+ auto it = llvm::find(targets, target);
+ assert(it != targets.end());
+ size_t index = std::distance(targets.begin(), it);
+ switchOp.getTargetOperandsMutable(index).assign(blockArgs);
+ }
} else {
return emitError(unknownLoc, "unimplemented terminator for Phi creation");
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 4e03a809bd0bc..153e9c770e464 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1443,7 +1443,20 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
assert(branchCondOp.getFalseTarget() == block);
blockOperands = branchCondOp.getFalseTargetOperands();
}
-
+ assert(!blockOperands->empty() &&
+ "expected non-empty block operand range");
+ predecessors.emplace_back(spirvPredecessor, *blockOperands);
+ } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
+ std::optional<OperandRange> blockOperands;
+ if (block == switchOp.getDefaultTarget()) {
+ blockOperands = switchOp.getDefaultOperands();
+ } else {
+ SuccessorRange targets = switchOp.getTargets();
+ auto it = llvm::find(targets, block);
+ assert(it != targets.end());
+ size_t index = std::distance(targets.begin(), it);
+ blockOperands = switchOp.getTargetOperands(index);
+ }
assert(!blockOperands->empty() &&
"expected non-empty block operand range");
predecessors.emplace_back(spirvPredecessor, *blockOperands);
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 3f762920015aa..d0ad118b01c8a 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -288,3 +288,61 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.EntryPoint "GLCompute" @main
spirv.ExecutionMode @main "LocalSize", 1, 1, 1
}
+
+// -----
+
+// Selection with switch and block operands
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader], []> {
+// CHECK-LABEL: @selection_switch_operands
+ spirv.func @selection_switch_operands(%selector : si32) "None" {
+ %cst1 = spirv.Constant 1.000000e+00 : f32
+ %vec0 = spirv.Undef : vector<3xf32>
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
+ %vec1 = spirv.CompositeInsert %cst1, %vec0[0 : i32] : f32 into vector<3xf32>
+ spirv.Branch ^bb1
+ ^bb1:
+// CHECK: {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
+ %vec4 = spirv.mlir.selection -> vector<3xf32> {
+// CHECK-NEXT: spirv.Switch {{%.*}} : si32, [
+// CHECK-NEXT: default: ^[[DEFAULT:.+]]({{%.*}} : vector<3xf32>),
+// CHECK-NEXT: 0: ^[[CASE0:.+]]({{%.*}} : vector<3xf32>),
+// CHECK-NEXT: 1: ^[[CASE1:.+]]({{%.*}} : vector<3xf32>)
+ spirv.Switch %selector : si32, [
+ default: ^bb3(%vec1 : vector<3xf32>),
+ 0: ^bb1(%vec1 : vector<3xf32>),
+ 1: ^bb2(%vec1 : vector<3xf32>)
+ ]
+// CHECK: ^[[CASE0]]({{%.*}}: vector<3xf32>)
+ ^bb1(%vecbb1: vector<3xf32>):
+ %cst3 = spirv.Constant 3.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
+ %vec2 = spirv.CompositeInsert %cst3, %vecbb1[1 : i32] : f32 into vector<3xf32>
+// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
+ spirv.Branch ^bb3(%vec2 : vector<3xf32>)
+// CHECK-NEXT: ^[[CASE1]]({{%.*}}: vector<3xf32>)
+ ^bb2(%vecbb2: vector<3xf32>):
+ %cst4 = spirv.Constant 4.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
+ %vec3 = spirv.CompositeInsert %cst4, %vecbb2[1 : i32] : f32 into vector<3xf32>
+// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
+ spirv.Branch ^bb3(%vec3 : vector<3xf32>)
+// CHECK-NEXT: ^[[DEFAULT]]({{%.*}}: vector<3xf32>)
+ ^bb3(%vecbb3: vector<3xf32>):
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : vector<3xf32>
+ spirv.mlir.merge %vecbb3 : vector<3xf32>
+// CHECK-NEXT: }
+ }
+ %cst2 = spirv.Constant 2.000000e+00 : f32
+// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[2 : i32] : f32 into vector<3xf32>
+ %vec5 = spirv.CompositeInsert %cst2, %vec4[2 : i32] : f32 into vector<3xf32>
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {
+ spirv.Return
+ }
+
+ spirv.EntryPoint "GLCompute" @main
+ spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}
More information about the Mlir-commits
mailing list