[Mlir-commits] [mlir] 891b3cf - [mlir][spirv] Add support for SwitchOp (#168713)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 20 07:19:15 PST 2025


Author: Igor Wodiany
Date: 2025-11-20T15:19:10Z
New Revision: 891b3cf63e160c83309b90728034ab832184c964

URL: https://github.com/llvm/llvm-project/commit/891b3cf63e160c83309b90728034ab832184c964
DIFF: https://github.com/llvm/llvm-project/commit/891b3cf63e160c83309b90728034ab832184c964.diff

LOG: [mlir][spirv] Add support for SwitchOp (#168713)

The dialect implementation mostly copies the one of `cf.switch`, but
aligns naming to the SPIR-V spec.

Added: 
    

Modified: 
    mlir/docs/Dialects/SPIR-V.md
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
    mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
    mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.h
    mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
    mlir/test/Target/SPIRV/selection.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 716dd7773aefa..dd68e6ec8b7b8 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -566,7 +566,7 @@ merge block.
 For example, for the given function
 
 ```c++
-void loop(bool cond) {
+void selection(bool cond) {
   int x = 0;
   if (cond) {
     x = 1;
@@ -605,6 +605,62 @@ func.func @selection(%cond: i1) -> () {
 }
 ```
 
+Similarly, for the given function with a `switch` statement
+
+```c++
+void selection(int selector) {
+  int x = 0;
+  switch (selector) {
+  case 0:
+    x = 2;
+    break;
+  case 1:
+    x = 3;
+    break;
+  default:
+    x = 1;
+    break;
+  }
+  // ...
+}
+```
+
+It will be represented as
+
+```mlir
+func.func @selection(%selector: i32) -> () {
+  %zero = spirv.Constant 0: i32
+  %one = spirv.Constant 1: i32
+  %two = spirv.Constant 2: i32
+  %three = spirv.Constant 3: i32
+  %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+  spirv.mlir.selection {
+    spirv.Switch %selector : i32, [
+      default: ^default,
+      0: ^case0,
+      1: ^case1
+    ]
+  ^default:
+    spirv.Store "Function" %var, %one : i32
+    spirv.Branch ^merge
+
+  ^case0:
+    spirv.Store "Function" %var, %two : i32
+    spirv.Branch ^merge
+
+  ^case1:
+    spirv.Store "Function" %var, %three : i32
+    spirv.Branch ^merge
+
+  ^merge:
+    spirv.mlir.merge
+  }
+
+  // ...
+}
+```
+
 The selection can return values by yielding them with `spirv.mlir.merge`. This
 mechanism allows values defined within the selection region to be used outside of it.
 Without this, values that were sunk into the selection region, but used outside, would

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index b628f1a3f7b20..7b363fac6e627 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4531,6 +4531,7 @@ def SPIRV_OC_OpSelectionMerge                 : I32EnumAttrCase<"OpSelectionMerg
 def SPIRV_OC_OpLabel                          : I32EnumAttrCase<"OpLabel", 248>;
 def SPIRV_OC_OpBranch                         : I32EnumAttrCase<"OpBranch", 249>;
 def SPIRV_OC_OpBranchConditional              : I32EnumAttrCase<"OpBranchConditional", 250>;
+def SPIRV_OC_OpSwitch                         : I32EnumAttrCase<"OpSwitch", 251>; // Emitted/parsed for spirv.Switch.
 def SPIRV_OC_OpKill                           : I32EnumAttrCase<"OpKill", 252>;
 def SPIRV_OC_OpReturn                         : I32EnumAttrCase<"OpReturn", 253>;
 def SPIRV_OC_OpReturnValue                    : I32EnumAttrCase<"OpReturnValue", 254>;
@@ -4681,7 +4682,7 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
       SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
       SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
-      SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
+      SPIRV_OC_OpSwitch, SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
       SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
       SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
       SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index acb6467132be9..27c9add7d43af 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -242,6 +242,112 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
   }];
 }
 
+// -----
+
+def SPIRV_SwitchOp : SPIRV_Op<"Switch",
+    [AttrSizedOperandSegments, InFunctionScope,
+     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     Pure, Terminator]> {
+  let summary = [{
+    Multi-way branch to one of the operand label <id>.
+  }];
+
+  let description = [{
+    Selector must have a type of OpTypeInt. Selector is compared for equality to
+    the Target literals.
+
+    Default must be the <id> of a label. If Selector does not equal any of the
+    Target literals, control flow branches to the Default label <id>.
+
+    Target must be alternating scalar integer literals and the <id> of a label.
+    If Selector equals a literal, control flow branches to the following label
+    <id>. It is invalid for any two literal to be equal to each other. If Selector
+    does not equal any literal, control flow branches to the Default label <id>.
+    Each literal is interpreted with the type of Selector: The bit width of
+    Selector’s type is the width of each literal’s type. If this width is not a
+    multiple of 32-bits and the OpTypeInt Signedness is set to 1, the literal values
+    are interpreted as being sign extended.
+
+    If Selector is an OpUndef, behavior is undefined.
+
+    This instruction must be the last instruction in a block.
+
+    #### Example:
+
+    ```mlir
+    spirv.Switch %selector : si32, [
+      default: ^bb1(%a : i32),
+      0: ^bb1(%b : i32),
+      1: ^bb3(%c : i32)
+    ]
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_Integer:$selector,
+    Variadic<AnyType>:$defaultOperands, // Forwarded to $defaultTarget.
+    VariadicOfVariadic<AnyType, "case_operand_segments">:$targetOperands,
+    OptionalAttr<AnyIntElementsAttr>:$literals, // One per $targets entry.
+    DenseI32ArrayAttr:$case_operand_segments // Group sizes of $targetOperands.
+  );
+
+  let results = (outs);
+
+  let successors = (successor AnySuccessor:$defaultTarget,
+                              VariadicSuccessor<AnySuccessor>:$targets);
+
+  let builders = [ // Defined in ControlFlowOps.cpp.
+    OpBuilder<(ins "Value":$selector,
+      "Block *":$defaultTarget,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<APInt>", "{}">:$literals,
+      CArg<"BlockRange", "{}">:$targets,
+      CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
+    OpBuilder<(ins "Value":$selector,
+      "Block *":$defaultTarget,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$literals,
+      CArg<"BlockRange", "{}">:$targets,
+      CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
+    OpBuilder<(ins "Value":$selector,
+      "Block *":$defaultTarget,
+      "ValueRange":$defaultOperands,
+      CArg<"DenseIntElementsAttr", "{}">:$literals,
+      CArg<"BlockRange", "{}">:$targets,
+      CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>
+  ];
+
+  let assemblyFormat = [{
+    $selector `:` type($selector) `,` `[` `\n`
+      custom<SwitchOpCases>(ref(type($selector)),$defaultTarget,
+                            $defaultOperands,
+                            type($defaultOperands),
+                            $literals,
+                            $targets,
+                            $targetOperands,
+                            type($targetOperands))
+   `]`
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return the operands for the target block at the given index.
+    OperandRange getTargetOperands(unsigned index) {
+      return getTargetOperands()[index];
+    }
+
+    /// Return a mutable range of operands for the target block at the
+    /// given index.
+    MutableOperandRange getTargetOperandsMutable(unsigned index) {
+      return getTargetOperandsMutable()[index];
+    }
+  }];
+
+  let autogenSerialization = 0; // Hand-written (de)serialization.
+  let hasVerifier = 1; // Literal count/type checks in SwitchOp::verify().
+}
+
+
 // -----
 
 def SPIRV_KillOp : SPIRV_Op<"Kill", [Terminator]> {

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index f0b46e61965f4..a846d7e60024c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -219,6 +219,106 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
   return getArgumentsMutable();
 }

+//===----------------------------------------------------------------------===//
+// spirv.Switch
+//===----------------------------------------------------------------------===//
+
+/// Builds a switch from a literals attribute, which may be null when the
+/// switch only has a default target.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+                     Block *defaultTarget, ValueRange defaultOperands,
+                     DenseIntElementsAttr literals, BlockRange targets,
+                     ArrayRef<ValueRange> targetOperands) {
+  build(builder, result, selector, defaultOperands, targetOperands, literals,
+        defaultTarget, targets);
+}
+
+/// Builds a switch from APInt literals. They are stored in a vector whose
+/// element type is the selector type, so their bit width must match it.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+                     Block *defaultTarget, ValueRange defaultOperands,
+                     ArrayRef<APInt> literals, BlockRange targets,
+                     ArrayRef<ValueRange> targetOperands) {
+  DenseIntElementsAttr literalsAttr;
+  if (!literals.empty()) {
+    ShapedType literalType = VectorType::get(
+        static_cast<int64_t>(literals.size()), selector.getType());
+    literalsAttr = DenseIntElementsAttr::get(literalType, literals);
+  }
+  build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
+        targets, targetOperands);
+}
+
+/// Builds a switch from 32-bit literals. Each literal is sign-extended or
+/// truncated to the selector bit width, so selectors that are not 32 bits wide
+/// are supported too.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+                     Block *defaultTarget, ValueRange defaultOperands,
+                     ArrayRef<int32_t> literals, BlockRange targets,
+                     ArrayRef<ValueRange> targetOperands) {
+  SmallVector<APInt> literalValues;
+  if (!literals.empty()) {
+    unsigned bitWidth = selector.getType().getIntOrFloatBitWidth();
+    for (int32_t literal : literals)
+      literalValues.push_back(
+          APInt(32, literal, /*isSigned=*/true).sextOrTrunc(bitWidth));
+  }
+  build(builder, result, selector, defaultTarget, defaultOperands,
+        literalValues, targets, targetOperands);
+}
+
+/// Checks that each case target has exactly one literal and that the literals
+/// have the same type as the selector.
+LogicalResult SwitchOp::verify() {
+  std::optional<DenseIntElementsAttr> literals = getLiterals();
+  BlockRange targets = getTargets();
+
+  // Without literals only the default target may exist; this must be checked
+  // before `literals` is dereferenced below.
+  int64_t numLiterals = literals ? literals->size() : 0;
+  if (numLiterals != static_cast<int64_t>(targets.size()))
+    return emitOpError() << "number of literals (" << numLiterals
+                         << ") should match number of targets ("
+                         << targets.size() << ")";
+  if (!literals)
+    return success();
+
+  Type selectorType = getSelector().getType();
+  Type literalType = literals->getType().getElementType();
+  if (literalType != selectorType)
+    return emitOpError() << "'selector' type (" << selectorType
+                         << ") should match literals type (" << literalType
+                         << ")";
+  return success();
+}
+
+/// Returns the operands forwarded to successor `index`: successor 0 is the
+/// default target and successor i > 0 is case target i - 1.
+SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
+                                      : getTargetOperandsMutable(index - 1));
+}
+
+/// Returns the successor taken for the constant selector value in
+/// `operands.front()`, the default target if there are no literals, or
+/// nullptr if the selector is not a known integer constant.
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  std::optional<DenseIntElementsAttr> literals = getLiterals();
+
+  if (!literals)
+    return getDefaultTarget();
+
+  SuccessorRange targets = getTargets();
+  if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) {
+    for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>()))
+      if (literal == value.getValue())
+        return targets[index];
+    return getDefaultTarget();
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.loop
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index 2f3a28ff16173..8575487ff52cc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -81,6 +81,100 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
   }
 }

+/// Adapted from the cf.switch implementation.
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
+///
+/// Each literal is parsed as a signed 64-bit integer and must be representable
+/// in the selector bit width as either a signed or an unsigned value.
+static ParseResult parseSwitchOpCases(
+    OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
+    SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
+    SmallVectorImpl<Block *> &targets,
+    SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
+        &targetOperands,
+    SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
+  if (parser.parseKeyword("default") || parser.parseColon() ||
+      parser.parseSuccessor(defaultTarget))
+    return failure();
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
+                                /*allowResultNumber=*/false) ||
+        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
+      return failure();
+  }
+
+  SmallVector<APInt> values;
+  unsigned bitWidth = selectorType.getIntOrFloatBitWidth();
+  while (succeeded(parser.parseOptionalComma())) {
+    SMLoc literalLoc = parser.getCurrentLocation();
+    int64_t value = 0;
+    if (failed(parser.parseInteger(value)))
+      return failure();
+    // Reject literals that would lose bits when narrowed to the selector
+    // width (e.g. 300 for an 8-bit selector) instead of truncating them.
+    if (!llvm::isIntN(bitWidth, value) && !llvm::isUIntN(bitWidth, value))
+      return parser.emitError(literalLoc, "literal ")
+             << value << " does not fit in the selector type " << selectorType;
+    values.push_back(
+        APInt(64, value, /*isSigned=*/true).sextOrTrunc(bitWidth));
+
+    Block *target;
+    SmallVector<OpAsmParser::UnresolvedOperand> operands;
+    SmallVector<Type> operandTypes;
+    if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
+      return failure();
+    if (succeeded(parser.parseOptionalLParen())) {
+      if (failed(parser.parseOperandList(operands,
+                                         OpAsmParser::Delimiter::None)) ||
+          failed(parser.parseColonTypeList(operandTypes)) ||
+          failed(parser.parseRParen()))
+        return failure();
+    }
+    targets.push_back(target);
+    targetOperands.emplace_back(operands);
+    targetOperandTypes.emplace_back(operandTypes);
+  }
+
+  if (!values.empty()) {
+    ShapedType literalType =
+        VectorType::get(static_cast<int64_t>(values.size()), selectorType);
+    literals = DenseIntElementsAttr::get(literalType, values);
+  }
+  return success();
+}
+
+/// Prints the cases in the form accepted by parseSwitchOpCases. Literals
+/// are printed so that they round-trip through its signed 64-bit parsing:
+/// unsigned literals narrower than 64 bits are zero-extended, all others are
+/// sign-extended.
+static void
+printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
+                   Block *defaultTarget, OperandRange defaultOperands,
+                   TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
+                   SuccessorRange targets, OperandRangeRange targetOperands,
+                   const TypeRangeRange &targetOperandTypes) {
+  p << "  default: ";
+  p.printSuccessorAndUseList(defaultTarget, defaultOperands);
+
+  if (!literals)
+    return;
+
+  for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) {
+    p << ',';
+    p.printNewline();
+    p << "  ";
+    if (selectorType.isUnsignedInteger() && literal.getBitWidth() < 64)
+      p << literal.getZExtValue();
+    else
+      p << literal.getSExtValue();
+    p << ": ";
+    p.printSuccessorAndUseList(targets[index], targetOperands[index]);
+  }
+  p.printNewline();
+}
+
 } // namespace mlir::spirv
 
 // TablenGen'erated operation definitions.

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index c27f9aa91332c..5b04a14a78036 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processLoopMerge(operands);
   case spirv::Opcode::OpPhi:
     return processPhi(operands);
+  case spirv::Opcode::OpSwitch:
+    return processSwitch(operands);
   case spirv::Opcode::OpUndef:
     return processUndef(operands);
   default:

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 6492708694cc5..252be796488c5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2292,6 +2292,59 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
   return success();
 }

+/// Processes an OpSwitch instruction. `operands` holds the selector <id>, the
+/// default label <id> and then (literal, label <id>) pairs, where each literal
+/// occupies one word for selectors of up to 32 bits and two words (low-order
+/// word first) for 64-bit selectors.
+LogicalResult spirv::Deserializer::processSwitch(ArrayRef<uint32_t> operands) {
+  if (!curBlock)
+    return emitError(unknownLoc, "OpSwitch must appear in a block");
+
+  if (operands.size() < 2)
+    return emitError(unknownLoc, "OpSwitch must at least specify selector and "
+                                 "a default target");
+
+  Value selector = getValue(operands[0]);
+  if (!selector)
+    return emitError(unknownLoc, "unknown OpSwitch selector <id> ")
+           << operands[0];
+  auto selectorType = dyn_cast<IntegerType>(selector.getType());
+  if (!selectorType || selectorType.getWidth() > 64)
+    return emitError(unknownLoc, "OpSwitch selector must be an integer of at "
+                                 "most 64 bits, got ")
+           << selector.getType();
+
+  unsigned bitWidth = selectorType.getWidth();
+  unsigned literalWords = bitWidth > 32 ? 2 : 1;
+  if ((operands.size() - 2) % (literalWords + 1))
+    return emitError(unknownLoc,
+                     "OpSwitch must have a selector, a default target and any "
+                     "number of literal and label <id> pairs");
+
+  Block *defaultBlock = getOrCreateBlock(operands[1]);
+  Location loc = createFileLineColLoc(opBuilder);
+
+  SmallVector<APInt> literals;
+  SmallVector<Block *> blocks;
+  for (unsigned i = 2, e = operands.size(); i < e; i += literalWords + 1) {
+    uint64_t word = operands[i];
+    if (literalWords == 2)
+      word |= static_cast<uint64_t>(operands[i + 1]) << 32;
+    // Literals narrower than 32 bits are sign- or zero-extended to a full
+    // word; only the low `bitWidth` bits carry the value.
+    literals.push_back(APInt(64, word).zextOrTrunc(bitWidth));
+    blocks.push_back(getOrCreateBlock(operands[i + literalWords]));
+  }
+
+  SmallVector<ValueRange> targetOperands(blocks.size(), {});
+  spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
+                          ArrayRef<Value>(), literals, blocks, targetOperands);
+
+  // A block terminator ends the scope of any preceding OpLine.
+  clearDebugLine();
+  return success();
+}
+
 namespace {
 /// A class for putting all blocks in a structured selection/loop in a
 /// spirv.mlir.selection/spirv.mlir.loop op.

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 6027f1ac94c23..243e6fd70ae43 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -472,6 +472,9 @@ class Deserializer {
   /// Processes a SPIR-V OpPhi instruction with the given `operands`.
   LogicalResult processPhi(ArrayRef<uint32_t> operands);
 
+  /// Processes OpSwitch (selector, default label, literal/label pairs).
+  LogicalResult processSwitch(ArrayRef<uint32_t> operands);
+
   /// Creates block arguments on predecessors previously recorded when handling
   /// OpPhi instructions.
   LogicalResult wireUpBlockArgument();

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 85e92c7ced394..6397d2c005c16 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -775,6 +775,36 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
   return success();
 }

+/// Emits OpSwitch: the selector <id>, the default label <id>, then one
+/// (literal, label <id>) pair per case. Literals narrower than 32 bits are
+/// sign-extended for signed selectors and zero-extended otherwise to fill a
+/// word; wider literals are emitted low-order word first.
+LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) {
+  uint32_t selectorID = getValueID(switchOp.getSelector());
+  uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget());
+  SmallVector<uint32_t> arguments{selectorID, defaultLabelID};
+
+  bool isSignedSelector = switchOp.getSelector().getType().isSignedInteger();
+  std::optional<mlir::DenseIntElementsAttr> literals = switchOp.getLiterals();
+  BlockRange targets = switchOp.getTargets();
+  if (literals) {
+    for (auto [literal, target] : llvm::zip_equal(*literals, targets)) {
+      unsigned wordBits = (literal.getBitWidth() + 31) / 32 * 32;
+      APInt extended = isSignedSelector ? literal.sextOrTrunc(wordBits)
+                                        : literal.zextOrTrunc(wordBits);
+      for (unsigned bit = 0; bit < wordBits; bit += 32)
+        arguments.push_back(
+            static_cast<uint32_t>(extended.extractBitsAsZExtValue(32, bit)));
+      arguments.push_back(getOrCreateBlockID(target));
+    }
+  }
+
+  if (failed(emitDebugLine(functionBody, switchOp.getLoc())))
+    return failure();
+  encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments);
+  return success();
+}
+
 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
   auto varName = addressOfOp.getVariable();
   auto variableID = getVariableID(varName);

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 29ed5a4fc139e..4e03a809bd0bc 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1579,6 +1579,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
       .Case([&](spirv::SpecConstantOperationOp op) {
         return processSpecConstantOperationOp(op);
       })
+      .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); })
       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
 

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index add372b19b5af..6e79c133eb6af 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -304,6 +304,8 @@ class Serializer {
 
   LogicalResult processBranchOp(spirv::BranchOp branchOp);
 
+  LogicalResult processSwitchOp(spirv::SwitchOp switchOp); // Emits OpSwitch.
+
   //===--------------------------------------------------------------------===//
   // Operations
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8e29ff6679068..b70bb40dae97f 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -795,6 +795,54 @@ func.func @selection(%cond: i1) -> () {

 // -----

+// CHECK-LABEL: @selection_switch
+func.func @selection_switch(%selector: i32) -> () {
+  %zero = spirv.Constant 0: i32
+  %one = spirv.Constant 1: i32
+  %two = spirv.Constant 2: i32
+  %three = spirv.Constant 3: i32
+  %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+  // CHECK: spirv.mlir.selection {
+  spirv.mlir.selection {
+    // CHECK-NEXT: spirv.Switch {{%.*}} : i32, [
+    // CHECK-NEXT: default: ^bb1,
+    // CHECK-NEXT: 0: ^bb2,
+    // CHECK-NEXT: 1: ^bb3
+    spirv.Switch %selector : i32, [
+      default: ^default,
+      0: ^case0,
+      1: ^case1
+    ]
+  // CHECK: ^bb1
+  ^default:
+    spirv.Store "Function" %var, %one : i32
+    // CHECK: spirv.Branch ^bb4
+    spirv.Branch ^merge
+
+  // CHECK: ^bb2
+  ^case0:
+    spirv.Store "Function" %var, %two : i32
+    // CHECK: spirv.Branch ^bb4
+    spirv.Branch ^merge
+
+  // CHECK: ^bb3
+  ^case1:
+    spirv.Store "Function" %var, %three : i32
+    // CHECK: spirv.Branch ^bb4
+    spirv.Branch ^merge
+
+  // CHECK: ^bb4
+  ^merge:
+    // CHECK-NEXT: spirv.mlir.merge
+    spirv.mlir.merge
+  }
+
+  spirv.Return
+}
+
+// -----
+
 // CHECK-LABEL: @empty_region
 func.func @empty_region() -> () {
   // CHECK: spirv.mlir.selection
@@ -918,3 +966,174 @@ func.func @kill() {
   // CHECK: spirv.Kill
   spirv.Kill
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Switch
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @switch(
+func.func @switch(%selector: i32) -> () {
+  // CHECK: spirv.Switch {{%.*}} : i32, [
+  // CHECK-NEXT: default: ^bb1,
+  // CHECK-NEXT: 0: ^bb2,
+  // CHECK-NEXT: 1: ^bb3,
+  // CHECK-NEXT: 2: ^bb4
+  spirv.Switch %selector : i32, [
+    default: ^default,
+    0: ^case0,
+    1: ^case1,
+    2: ^case2
+  ]
+^default:
+  spirv.Branch ^merge
+
+^case0:
+  spirv.Branch ^merge
+
+^case1:
+  spirv.Branch ^merge
+
+^case2:
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+
+// CHECK-LABEL: @switch_only_default
+func.func @switch_only_default(%selector: i32) -> () {
+  // CHECK: spirv.Switch {{%.*}} : i32, [
+  // CHECK-NEXT: default: ^bb1
+  spirv.Switch %selector : i32, [
+    default: ^default
+  ]
+^default:
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+
+// CHECK-LABEL: @switch_operands
+func.func @switch_operands(%selector : i32, %operand : i32) {
+  // CHECK: spirv.Switch {{%.*}} : i32, [
+  // CHECK-NEXT: default: ^bb1({{%.*}} : i32),
+  // CHECK-NEXT: 0: ^bb2({{%.*}} : i32),
+  // CHECK-NEXT: 1: ^bb3({{%.*}} : i32)
+  spirv.Switch %selector : i32, [
+    default: ^default(%operand : i32),
+    0: ^case0(%operand : i32),
+    1: ^case1(%operand : i32)
+  ]
+^default(%argd : i32):
+  spirv.Branch ^merge
+
+^case0(%arg0 : i32):
+  spirv.Branch ^merge
+
+^case1(%arg1 : i32):
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+
+// -----
+
+func.func @switch_float_selector(%selector: f32) -> () {
+  // expected-error@+1 {{expected builtin.integer, but found 'f32'}}
+  spirv.Switch %selector : f32, [
+    default: ^default
+  ]
+^default:
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+
+// -----
+
+func.func @switch_float_literal(%selector: i32) -> () {
+  // expected-error@+3 {{expected integer value}}
+  spirv.Switch %selector : i32, [
+    default: ^default,
+    0.0: ^case0
+  ]
+^default:
+  spirv.Branch ^merge
+
+^case0:
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+
+// -----
+
+func.func @switch_missing_default(%selector: i32) -> () {
+  // expected-error@+2 {{expected 'default'}}
+  spirv.Switch %selector : i32, [
+    0: ^case0
+  ]
+^case0:
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+
+// -----
+
+func.func @switch_default_no_target(%selector: i32) -> () {
+  // expected-error@+2 {{expected block name}}
+  spirv.Switch %selector : i32, [
+    default:
+  ]
+^default:
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+
+// -----
+
+func.func @switch_case_no_target(%selector: i32) -> () {
+  // expected-error@+3 {{expected block name}}
+  spirv.Switch %selector : i32, [
+    default: ^default,
+    0:
+  ]
+^default:
+  spirv.Branch ^merge
+
+^case0:
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+
+// -----
+
+func.func @switch_missing_operand_type(%selector: i32) -> () {
+  %0 = spirv.Constant 0 : i32
+  // expected-error@+2 {{expected ':'}}
+  spirv.Switch %selector : i32, [
+    default: ^default (%0),
+    0: ^case0
+  ]
+^default(%argd : i32):
+  spirv.Branch ^merge
+
+^case0:
+  spirv.Branch ^merge
+
+^merge:
+  spirv.Return
+}
+

diff  --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 12daf68538d0a..3f762920015aa 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -220,3 +220,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   spirv.EntryPoint "GLCompute" @main
   spirv.ExecutionMode @main "LocalSize", 1, 1, 1
 }
+
+// -----
+
+// Selection with switch
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @selection_switch
+  spirv.func @selection_switch(%selector: i32) -> () "None" {
+    %zero = spirv.Constant 0: i32
+    %one = spirv.Constant 1: i32
+    %two = spirv.Constant 2: i32
+    %three = spirv.Constant 3: i32
+    %four = spirv.Constant 4: i32
+// CHECK: {{%.*}} = spirv.Variable init({{%.*}}) : !spirv.ptr<i32, Function>
+    %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+// CHECK: spirv.mlir.selection {
+    spirv.mlir.selection {
+// CHECK-NEXT: spirv.Switch {{%.*}} : i32, [
+// CHECK-NEXT: default: ^[[DEFAULT:.+]],
+// CHECK-NEXT: 0: ^[[CASE0:.+]],
+// CHECK-NEXT: 1: ^[[CASE1:.+]],
+// CHECK-NEXT: 2: ^[[CASE2:.+]]
+      spirv.Switch %selector : i32, [
+        default: ^default,
+        0: ^case0,
+        1: ^case1,
+        2: ^case2
+      ]
+// CHECK: ^[[DEFAULT]]
+    ^default:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+      spirv.Store "Function" %var, %one : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+      spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE0]]
+    ^case0:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+      spirv.Store "Function" %var, %two : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE]]
+      spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE1]]
+    ^case1:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+      spirv.Store "Function" %var, %three : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE]]
+      spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE2]]
+    ^case2:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+      spirv.Store "Function" %var, %four : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE]]
+      spirv.Branch ^merge
+// CHECK-NEXT: ^[[MERGE]]
+    ^merge:
+// CHECK-NEXT: spirv.mlir.merge
+      spirv.mlir.merge
+// CHECK-NEXT: }
+    }
+// CHECK-NEXT: spirv.Return
+    spirv.Return
+  }
+
+  spirv.func @main() -> () "None" {
+    spirv.Return
+  }
+  spirv.EntryPoint "GLCompute" @main
+  spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}


        


More information about the Mlir-commits mailing list