[Mlir-commits] [mlir] 891b3cf - [mlir][spirv] Add support for SwitchOp (#168713)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 20 07:19:15 PST 2025
Author: Igor Wodiany
Date: 2025-11-20T15:19:10Z
New Revision: 891b3cf63e160c83309b90728034ab832184c964
URL: https://github.com/llvm/llvm-project/commit/891b3cf63e160c83309b90728034ab832184c964
DIFF: https://github.com/llvm/llvm-project/commit/891b3cf63e160c83309b90728034ab832184c964.diff
LOG: [mlir][spirv] Add support for SwitchOp (#168713)
The dialect implementation mostly copies the one of `cf.switch`, but
aligns naming to the SPIR-V spec.
Added:
Modified:
mlir/docs/Dialects/SPIR-V.md
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.h
mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
mlir/test/Target/SPIRV/selection.mlir
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 716dd7773aefa..dd68e6ec8b7b8 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -566,7 +566,7 @@ merge block.
For example, for the given function
```c++
-void loop(bool cond) {
+void if(bool cond) {
int x = 0;
if (cond) {
x = 1;
@@ -605,6 +605,62 @@ func.func @selection(%cond: i1) -> () {
}
```
+Similarly, for the given function with a `switch` statement
+
+```c++
+void switch(int selector) {
+ int x = 0;
+ switch (selector) {
+ case 0:
+ x = 2;
+ break;
+ case 1:
+ x = 3;
+ break;
+ default:
+ x = 1;
+ break;
+ }
+ // ...
+}
+```
+
+It will be represented as
+
+```mlir
+func.func @selection(%selector: i32) -> () {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %two = spirv.Constant 2: i32
+ %three = spirv.Constant 3: i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ spirv.mlir.selection {
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0: ^case0,
+ 1: ^case1
+ ]
+ ^default:
+ spirv.Store "Function" %var, %one : i32
+ spirv.Branch ^merge
+
+ ^case0:
+ spirv.Store "Function" %var, %two : i32
+ spirv.Branch ^merge
+
+ ^case1:
+ spirv.Store "Function" %var, %three : i32
+ spirv.Branch ^merge
+
+ ^merge:
+ spirv.mlir.merge
+ }
+
+ // ...
+}
+```
+
The selection can return values by yielding them with `spirv.mlir.merge`. This
mechanism allows values defined within the selection region to be used outside of it.
Without this, values that were sunk into the selection region, but used outside, would
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index b628f1a3f7b20..7b363fac6e627 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4531,6 +4531,7 @@ def SPIRV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerg
def SPIRV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
def SPIRV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
def SPIRV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
+def SPIRV_OC_OpSwitch : I32EnumAttrCase<"OpSwitch", 251>;
def SPIRV_OC_OpKill : I32EnumAttrCase<"OpKill", 252>;
def SPIRV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPIRV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
@@ -4681,7 +4682,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
- SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
+ SPIRV_OC_OpSwitch, SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index acb6467132be9..27c9add7d43af 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -242,6 +242,112 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
}];
}
+// -----
+
+// ODS definition of `spirv.Switch` (SPIR-V OpSwitch): a terminator with one
+// default successor plus a variadic list of (literal, target) cases, each of
+// which may forward operands to its destination block.
+def SPIRV_SwitchOp : SPIRV_Op<"Switch",
+ [AttrSizedOperandSegments, InFunctionScope,
+ DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+ Pure, Terminator]> {
+ let summary = [{
+ Multi-way branch to one of the operand label <id>.
+ }];
+
+ let description = [{
+ Selector must have a type of OpTypeInt. Selector is compared for equality to
+ the Target literals.
+
+ Default must be the <id> of a label. If Selector does not equal any of the
+ Target literals, control flow branches to the Default label <id>.
+
+ Target must be alternating scalar integer literals and the <id> of a label.
+ If Selector equals a literal, control flow branches to the following label
+ <id>. It is invalid for any two literal to be equal to each other. If Selector
+ does not equal any literal, control flow branches to the Default label <id>.
+ Each literal is interpreted with the type of Selector: The bit width of
+ Selector’s type is the width of each literal’s type. If this width is not a
+ multiple of 32-bits and the OpTypeInt Signedness is set to 1, the literal values
+ are interpreted as being sign extended.
+
+ If Selector is an OpUndef, behavior is undefined.
+
+ This instruction must be the last instruction in a block.
+
+ #### Example:
+
+ ```mlir
+ spirv.Switch %selector : si32, [
+ default: ^bb1(%a : i32),
+ 0: ^bb1(%b : i32),
+ 1: ^bb3(%c : i32)
+ ]
+ ```
+ }];
+
+ // `literals` is optional: the builders leave it unset when no case is given.
+ // `case_operand_segments` records how the flat `targetOperands` list is split
+ // into one operand group per case target.
+ let arguments = (ins
+ SPIRV_Integer:$selector,
+ Variadic<AnyType>:$defaultOperands,
+ VariadicOfVariadic<AnyType, "case_operand_segments">:$targetOperands,
+ OptionalAttr<AnyIntElementsAttr>:$literals,
+ DenseI32ArrayAttr:$case_operand_segments
+ );
+
+ let results = (outs);
+
+ // Successor 0 is the default target; case targets follow in literal order.
+ let successors = (successor AnySuccessor:$defaultTarget,
+ VariadicSuccessor<AnySuccessor>:$targets);
+
+ // Convenience builders that differ only in how the case literals are
+ // supplied: a list of APInt, a list of int32_t, or a prebuilt attribute.
+ let builders = [
+ OpBuilder<(ins "Value":$selector,
+ "Block *":$defaultTarget,
+ "ValueRange":$defaultOperands,
+ CArg<"ArrayRef<APInt>", "{}">:$literals,
+ CArg<"BlockRange", "{}">:$targets,
+ CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
+ OpBuilder<(ins "Value":$selector,
+ "Block *":$defaultTarget,
+ "ValueRange":$defaultOperands,
+ CArg<"ArrayRef<int32_t>", "{}">:$literals,
+ CArg<"BlockRange", "{}">:$targets,
+ CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
+ OpBuilder<(ins "Value":$selector,
+ "Block *":$defaultTarget,
+ "ValueRange":$defaultOperands,
+ CArg<"DenseIntElementsAttr", "{}">:$literals,
+ CArg<"BlockRange", "{}">:$targets,
+ CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>
+ ];
+
+ // The bracketed case list is handled by the custom `SwitchOpCases`
+ // directive (parseSwitchOpCases / printSwitchOpCases).
+ let assemblyFormat = [{
+ $selector `:` type($selector) `,` `[` `\n`
+ custom<SwitchOpCases>(ref(type($selector)),$defaultTarget,
+ $defaultOperands,
+ type($defaultOperands),
+ $literals,
+ $targets,
+ $targetOperands,
+ type($targetOperands))
+ `]`
+ attr-dict
+ }];
+
+ let extraClassDeclaration = [{
+ /// Return the operands for the target block at the given index.
+ OperandRange getTargetOperands(unsigned index) {
+ return getTargetOperands()[index];
+ }
+
+ /// Return a mutable range of operands for the target block at the
+ /// given index.
+ MutableOperandRange getTargetOperandsMutable(unsigned index) {
+ return getTargetOperandsMutable()[index];
+ }
+ }];
+
+ // (De)serialization is hand-written for this op; see
+ // Serializer::processSwitchOp and Deserializer::processSwitch.
+ let autogenSerialization = 0;
+ let hasVerifier = 1;
+}
+
+
// -----
def SPIRV_KillOp : SPIRV_Op<"Kill", [Terminator]> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index f0b46e61965f4..a846d7e60024c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -219,6 +219,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
return getArgumentsMutable();
}
+//===----------------------------------------------------------------------===//
+// spirv.Switch
+//===----------------------------------------------------------------------===//
+
+/// Builds a `spirv.Switch` from an already materialized `literals` attribute.
+/// This overload only reorders its arguments and forwards them to the `build`
+/// overload that takes operands first, then the attribute, then successors.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+                     Block *defaultTarget, ValueRange defaultOperands,
+                     DenseIntElementsAttr literals, BlockRange targets,
+                     ArrayRef<ValueRange> targetOperands) {
+  build(builder, result,
+        /*selector=*/selector,
+        /*defaultOperands=*/defaultOperands,
+        /*targetOperands=*/targetOperands,
+        /*literals=*/literals,
+        /*defaultTarget=*/defaultTarget,
+        /*targets=*/targets);
+}
+
+/// Builds a `spirv.Switch` whose case values are given as `APInt`s.
+///
+/// A non-empty `literals` list is packed into a dense integer attribute typed
+/// as a 1-D vector of the selector type; an empty list yields a null attribute,
+/// i.e. a switch with only a default target.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+                     Block *defaultTarget, ValueRange defaultOperands,
+                     ArrayRef<APInt> literals, BlockRange targets,
+                     ArrayRef<ValueRange> targetOperands) {
+  DenseIntElementsAttr packedLiterals =
+      literals.empty()
+          ? DenseIntElementsAttr()
+          : DenseIntElementsAttr::get(
+                VectorType::get(static_cast<int64_t>(literals.size()),
+                                selector.getType()),
+                literals);
+  build(builder, result, selector, defaultTarget, defaultOperands,
+        packedLiterals, targets, targetOperands);
+}
+
+/// Builds a `spirv.Switch` whose case values are given as `int32_t`s.
+///
+/// Each literal is converted to an `APInt` of the selector's bit width
+/// (sign-extending or truncating as needed) before it is packed into the dense
+/// attribute. Handing the raw 32-bit data to `DenseIntElementsAttr::get` is
+/// only valid for 32-bit selectors, because the attribute's element type is
+/// the selector type. An empty `literals` list yields a null attribute.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
+                     Block *defaultTarget, ValueRange defaultOperands,
+                     ArrayRef<int32_t> literals, BlockRange targets,
+                     ArrayRef<ValueRange> targetOperands) {
+  DenseIntElementsAttr literalsAttr;
+  if (!literals.empty()) {
+    unsigned bitWidth = selector.getType().getIntOrFloatBitWidth();
+    SmallVector<APInt> sizedLiterals;
+    sizedLiterals.reserve(literals.size());
+    for (int32_t literal : literals)
+      sizedLiterals.push_back(
+          APInt(32, literal, /*isSigned=*/true).sextOrTrunc(bitWidth));
+
+    ShapedType literalType = VectorType::get(
+        static_cast<int64_t>(sizedLiterals.size()), selector.getType());
+    literalsAttr =
+        DenseIntElementsAttr::get(literalType, ArrayRef<APInt>(sizedLiterals));
+  }
+  build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
+        targets, targetOperands);
+}
+
+/// Verifies that the case literals agree with the case targets.
+///
+/// Accepts a switch with neither literals nor targets (default only).
+/// Otherwise the `literals` attribute must be present, its element type must
+/// equal the selector type, and it must hold exactly one value per target.
+LogicalResult SwitchOp::verify() {
+  std::optional<DenseIntElementsAttr> literals = getLiterals();
+  BlockRange targets = getTargets();
+
+  if (!literals) {
+    if (targets.empty())
+      return success();
+    // Targets without literals: report it instead of dereferencing the empty
+    // optional below.
+    return emitOpError() << "number of literals (0) should match number of "
+                            "targets ("
+                         << targets.size() << ")";
+  }
+
+  Type selectorType = getSelector().getType();
+  Type literalType = literals->getType().getElementType();
+  if (literalType != selectorType)
+    return emitOpError() << "'selector' type (" << selectorType
+                         << ") should match literals type (" << literalType
+                         << ")";
+
+  if (literals->size() != static_cast<int64_t>(targets.size()))
+    return emitOpError() << "number of literals (" << literals->size()
+                         << ") should match number of targets ("
+                         << targets.size() << ")";
+  return success();
+}
+
+/// Returns the operands forwarded to successor `index`. Successor 0 is the
+/// default target; successor `i > 0` is case target `i - 1`.
+SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index == 0)
+    return SuccessorOperands(getDefaultOperandsMutable());
+  return SuccessorOperands(getTargetOperandsMutable(index - 1));
+}
+
+/// Returns the successor taken for the given constant operands, if it can be
+/// determined.
+///
+/// Without case literals the default target is always taken. Otherwise the
+/// first operand (the selector) must be a known `IntegerAttr`; the target of
+/// the first matching literal is returned, falling back to the default target.
+/// Returns null when the selector is not a known integer constant.
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  std::optional<DenseIntElementsAttr> caseLiterals = getLiterals();
+  if (!caseLiterals)
+    return getDefaultTarget();
+
+  auto selectorAttr = dyn_cast_or_null<IntegerAttr>(operands.front());
+  if (!selectorAttr)
+    return nullptr;
+
+  SuccessorRange caseTargets = getTargets();
+  APInt selectorValue = selectorAttr.getValue();
+  unsigned caseIndex = 0;
+  for (const APInt &literal : caseLiterals->getValues<APInt>()) {
+    if (literal == selectorValue)
+      return caseTargets[caseIndex];
+    ++caseIndex;
+  }
+  return getDefaultTarget();
+}
+
//===----------------------------------------------------------------------===//
// spirv.mlir.loop
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index 2f3a28ff16173..8575487ff52cc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
}
}
+/// Adapted from the cf.switch implementation.
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
+/// Parses the case list of `spirv.Switch`: a mandatory `default` entry
+/// followed by any number of `, <integer> : <successor>` entries, each with an
+/// optional parenthesized operand list and its types.
+///
+/// On success, `literals` is set to a dense integer attribute (a 1-D vector of
+/// `selectorType`) when at least one case was parsed and is left untouched
+/// otherwise. A diagnostic is emitted, rather than tripping an assertion, when
+/// a case literal is used with a non-integer selector type or does not fit in
+/// the selector's bit width. Literals may be written either as signed or as
+/// unsigned values of that width.
+static ParseResult parseSwitchOpCases(
+    OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
+    SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
+    SmallVectorImpl<Block *> &targets,
+    SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
+        &targetOperands,
+    SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
+  if (parser.parseKeyword("default") || parser.parseColon() ||
+      parser.parseSuccessor(defaultTarget))
+    return failure();
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
+                                /*allowResultNumber=*/false) ||
+        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
+      return failure();
+  }
+
+  // The selector type has not been verified yet, so do not assume it is an
+  // integer; it is only required once a case literal shows up.
+  auto selectorIntType = dyn_cast<IntegerType>(selectorType);
+
+  SmallVector<APInt> values;
+  while (succeeded(parser.parseOptionalComma())) {
+    auto literalLoc = parser.getCurrentLocation();
+    int64_t value = 0;
+    if (failed(parser.parseInteger(value)))
+      return failure();
+
+    if (!selectorIntType)
+      return parser.emitError(literalLoc)
+             << "case literals require an integer selector type, but got "
+             << selectorType;
+
+    // Accept the literal if it is representable in the selector width either
+    // as a signed or as an unsigned number, then normalize it to that width.
+    unsigned bitWidth = selectorIntType.getWidth();
+    APInt wideValue(64, static_cast<uint64_t>(value), /*isSigned=*/true);
+    if (bitWidth < 64 && !wideValue.isSignedIntN(bitWidth) &&
+        !wideValue.isIntN(bitWidth))
+      return parser.emitError(literalLoc)
+             << "case literal " << value << " does not fit in selector type "
+             << selectorType;
+    values.push_back(wideValue.sextOrTrunc(bitWidth));
+
+    Block *target;
+    SmallVector<OpAsmParser::UnresolvedOperand> operands;
+    SmallVector<Type> operandTypes;
+    if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
+      return failure();
+    if (succeeded(parser.parseOptionalLParen())) {
+      if (failed(parser.parseOperandList(operands,
+                                         OpAsmParser::Delimiter::None)) ||
+          failed(parser.parseColonTypeList(operandTypes)) ||
+          failed(parser.parseRParen()))
+        return failure();
+    }
+    targets.push_back(target);
+    targetOperands.emplace_back(operands);
+    targetOperandTypes.emplace_back(operandTypes);
+  }
+
+  if (!values.empty()) {
+    ShapedType literalType =
+        VectorType::get(static_cast<int64_t>(values.size()), selectorType);
+    literals = DenseIntElementsAttr::get(literalType, values);
+  }
+  return success();
+}
+
+/// Prints the case list of `spirv.Switch`: the `default` entry, then one
+/// `<literal>: <successor>` entry per line for every case.
+///
+/// Literals are printed as signed decimal numbers. The parser builds each
+/// literal as a signed value of the selector width, so printing the unsigned
+/// interpretation (e.g. 4294967295 for an i32 -1) would not round-trip.
+static void
+printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
+                   Block *defaultTarget, OperandRange defaultOperands,
+                   TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
+                   SuccessorRange targets, OperandRangeRange targetOperands,
+                   const TypeRangeRange &targetOperandTypes) {
+  p << "  default: ";
+  p.printSuccessorAndUseList(defaultTarget, defaultOperands);
+
+  // A null attribute means there are no cases besides the default.
+  if (!literals)
+    return;
+
+  for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) {
+    p << ',';
+    p.printNewline();
+    p << "  ";
+    literal.print(p.getStream(), /*isSigned=*/true);
+    p << ": ";
+    p.printSuccessorAndUseList(targets[index], targetOperands[index]);
+  }
+  p.printNewline();
+}
+
} // namespace mlir::spirv
// TablenGen'erated operation definitions.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index c27f9aa91332c..5b04a14a78036 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction(
return processLoopMerge(operands);
case spirv::Opcode::OpPhi:
return processPhi(operands);
+ case spirv::Opcode::OpSwitch:
+ return processSwitch(operands);
case spirv::Opcode::OpUndef:
return processUndef(operands);
default:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 6492708694cc5..252be796488c5 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2292,6 +2292,38 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
return success();
}
+/// Processes a SPIR-V OpSwitch instruction and creates a `spirv.Switch` op.
+///
+/// `operands` holds the selector <id>, the default label <id>, and then any
+/// number of (literal, label <id>) pairs. Each literal has the selector's bit
+/// width, so it takes one word for selectors of up to 32 bits and two words
+/// (low-order word first) for wider selectors of up to 64 bits. No block
+/// operands are attached here.
+LogicalResult spirv::Deserializer::processSwitch(ArrayRef<uint32_t> operands) {
+  if (!curBlock)
+    return emitError(unknownLoc, "OpSwitch must appear in a block");
+
+  if (operands.size() < 2)
+    return emitError(unknownLoc, "OpSwitch must at least specify selector and "
+                                 "a default target");
+
+  Value selector = getValue(operands[0]);
+  if (!selector)
+    return emitError(unknownLoc, "unknown <id> ")
+           << operands[0] << " used as OpSwitch selector";
+
+  auto selectorType = dyn_cast<IntegerType>(selector.getType());
+  if (!selectorType)
+    return emitError(unknownLoc,
+                     "OpSwitch selector must have an integer type, but got ")
+           << selector.getType();
+
+  unsigned bitWidth = selectorType.getWidth();
+  if (bitWidth > 64)
+    return emitError(unknownLoc,
+                     "OpSwitch selectors wider than 64 bits are unsupported");
+
+  // Number of words per literal, and per (literal, label <id>) pair.
+  unsigned literalWords = bitWidth > 32 ? 2 : 1;
+  unsigned pairWords = literalWords + 1;
+  if ((operands.size() - 2) % pairWords != 0)
+    return emitError(unknownLoc,
+                     "OpSwitch must have a selector, a default target and any "
+                     "number of literal and label <id> pairs, with each "
+                     "literal taking ")
+           << literalWords << " word(s)";
+
+  Block *defaultBlock = getOrCreateBlock(operands[1]);
+  Location loc = createFileLineColLoc(opBuilder);
+
+  SmallVector<APInt> literals;
+  SmallVector<Block *> blocks;
+  for (unsigned i = 2, e = operands.size(); i < e; i += pairWords) {
+    uint64_t rawLiteral = operands[i];
+    if (literalWords == 2)
+      rawLiteral |= static_cast<uint64_t>(operands[i + 1]) << 32;
+    // Keep only the low `bitWidth` bits; narrow literals arrive extended to a
+    // full word.
+    literals.push_back(APInt(64, rawLiteral).zextOrTrunc(bitWidth));
+    blocks.push_back(getOrCreateBlock(operands[i + literalWords]));
+  }
+
+  SmallVector<ValueRange> targetOperands(blocks.size(), {});
+  spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
+                          ArrayRef<Value>(), ArrayRef<APInt>(literals), blocks,
+                          targetOperands);
+
+  return success();
+}
+
namespace {
/// A class for putting all blocks in a structured selection/loop in a
/// spirv.mlir.selection/spirv.mlir.loop op.
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 6027f1ac94c23..243e6fd70ae43 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -472,6 +472,9 @@ class Deserializer {
/// Processes a SPIR-V OpPhi instruction with the given `operands`.
LogicalResult processPhi(ArrayRef<uint32_t> operands);
+ /// Processes a SPIR-V OpSwitch instruction with the given `operands`.
+ LogicalResult processSwitch(ArrayRef<uint32_t> operands);
+
/// Creates block arguments on predecessors previously recorded when handling
/// OpPhi instructions.
LogicalResult wireUpBlockArgument();
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 85e92c7ced394..6397d2c005c16 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -775,6 +775,27 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
return success();
}
+/// Serializes a `spirv.Switch` op as an OpSwitch instruction.
+///
+/// The instruction operands are the selector <id>, the default label <id>,
+/// and one (literal, label <id>) pair per case. Literals of up to 32 bits take
+/// one word; wider literals of up to 64 bits take two words, low-order word
+/// first. Literals of a signed selector type are sign-extended to fill their
+/// word(s), all others are zero-extended. Fails for literals wider than
+/// 64 bits.
+LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) {
+  uint32_t selectorID = getValueID(switchOp.getSelector());
+  uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget());
+  SmallVector<uint32_t> arguments{selectorID, defaultLabelID};
+
+  std::optional<mlir::DenseIntElementsAttr> literals = switchOp.getLiterals();
+  BlockRange targets = switchOp.getTargets();
+  if (literals) {
+    bool isSigned = switchOp.getSelector().getType().isSignedInteger();
+    for (auto [literal, target] : llvm::zip_equal(*literals, targets)) {
+      unsigned bitWidth = literal.getBitWidth();
+      if (bitWidth > 64)
+        return switchOp.emitError(
+            "cannot serialize literals wider than 64 bits");
+
+      uint64_t rawLiteral = isSigned
+                                ? static_cast<uint64_t>(literal.getSExtValue())
+                                : literal.getZExtValue();
+      arguments.push_back(static_cast<uint32_t>(rawLiteral));
+      // Literals wider than one word need their high-order word as well;
+      // dropping it would silently change the case value.
+      if (bitWidth > 32)
+        arguments.push_back(static_cast<uint32_t>(rawLiteral >> 32));
+
+      uint32_t targetLabelID = getOrCreateBlockID(target);
+      arguments.push_back(targetLabelID);
+    }
+  }
+
+  if (failed(emitDebugLine(functionBody, switchOp.getLoc())))
+    return failure();
+  encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments);
+  return success();
+}
+
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
auto varName = addressOfOp.getVariable();
auto variableID = getVariableID(varName);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 29ed5a4fc139e..4e03a809bd0bc 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1579,6 +1579,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
.Case([&](spirv::SpecConstantOperationOp op) {
return processSpecConstantOperationOp(op);
})
+ .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); })
.Case([&](spirv::UndefOp op) { return processUndefOp(op); })
.Case([&](spirv::VariableOp op) { return processVariableOp(op); })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index add372b19b5af..6e79c133eb6af 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -304,6 +304,8 @@ class Serializer {
LogicalResult processBranchOp(spirv::BranchOp branchOp);
+ LogicalResult processSwitchOp(spirv::SwitchOp switchOp);
+
//===--------------------------------------------------------------------===//
// Operations
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8e29ff6679068..b70bb40dae97f 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -795,6 +795,53 @@ func.func @selection(%cond: i1) -> () {
// -----
+func.func @selection_switch(%selector: i32) -> () {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %two = spirv.Constant 2: i32
+ %three = spirv.Constant 3: i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ // CHECK: spirv.mlir.selection {
+ spirv.mlir.selection {
+ // CHECK-NEXT: spirv.Switch {{%.*}} : i32, [
+ // CHECK-NEXT: default: ^bb1,
+ // CHECK-NEXT: 0: ^bb2,
+ // CHECK-NEXT: 1: ^bb3
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0: ^case0,
+ 1: ^case1
+ ]
+ // CHECK: ^bb1
+ ^default:
+ spirv.Store "Function" %var, %one : i32
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^merge
+
+ // CHECK: ^bb2
+ ^case0:
+ spirv.Store "Function" %var, %two : i32
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^merge
+
+ // CHECK: ^bb3
+ ^case1:
+ spirv.Store "Function" %var, %three : i32
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^merge
+
+ // CHECK: ^bb4
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ spirv.Return
+}
+
+// -----
+
// CHECK-LABEL: @empty_region
func.func @empty_region() -> () {
// CHECK: spirv.mlir.selection
@@ -918,3 +965,171 @@ func.func @kill() {
// CHECK: spirv.Kill
spirv.Kill
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Switch
+//===----------------------------------------------------------------------===//
+
+func.func @switch(%selector: i32) -> () {
+ // CHECK: spirv.Switch {{%.*}} : i32, [
+ // CHECK-NEXT: default: ^bb1,
+ // CHECK-NEXT: 0: ^bb2,
+ // CHECK-NEXT: 1: ^bb3,
+ // CHECK-NEXT: 2: ^bb4
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0: ^case0,
+ 1: ^case1,
+ 2: ^case2
+ ]
+^default:
+ spirv.Branch ^merge
+
+^case0:
+ spirv.Branch ^merge
+
+^case1:
+ spirv.Branch ^merge
+
+^case2:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+func.func @switch_only_default(%selector: i32) -> () {
+ // CHECK: spirv.Switch {{%.*}} : i32, [
+ // CHECK-NEXT: default: ^bb1
+ spirv.Switch %selector : i32, [
+ default: ^default
+ ]
+^default:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+func.func @switch_operands(%selector : i32, %operand : i32) {
+ // CHECK: spirv.Switch {{%.*}} : i32, [
+ // CHECK-NEXT: default: ^bb1({{%.*}} : i32),
+ // CHECK-NEXT: 0: ^bb2({{%.*}} : i32),
+ // CHECK-NEXT: 1: ^bb3({{%.*}} : i32)
+ spirv.Switch %selector : i32, [
+ default: ^default(%operand : i32),
+ 0: ^case0(%operand : i32),
+ 1: ^case1(%operand : i32)
+ ]
+^default(%argd : i32):
+ spirv.Branch ^merge
+
+^case0(%arg0 : i32):
+ spirv.Branch ^merge
+
+^case1(%arg1 : i32):
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_float_selector(%selector: f32) -> () {
+ // expected-error@+1 {{expected builtin.integer, but found 'f32'}}
+ spirv.Switch %selector : f32, [
+ default: ^default
+ ]
+^default:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_float_selector(%selector: i32) -> () {
+ // expected-error@+3 {{expected integer value}}
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0.0: ^case0
+ ]
+^default:
+ spirv.Branch ^merge
+
+^case 0:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_missing_default(%selector: i32) -> () {
+ // expected-error@+2 {{expected 'default'}}
+ spirv.Switch %selector : i32, [
+ 0: ^case0
+ ]
+^case 0:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_default_no_target(%selector: i32) -> () {
+ // expected-error@+2 {{expected block name}}
+ spirv.Switch %selector : i32, [
+ default:
+ ]
+^default:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_case_no_target(%selector: i32) -> () {
+ // expected-error@+3 {{expected block name}}
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0:
+ ]
+^default:
+ spirv.Branch ^merge
+
+^case 0:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
+// -----
+
+func.func @switch_missing_operand_type(%selector: i32) -> () {
+ %0 = spirv.Constant 0 : i32
+ // expected-error@+2 {{expected ':'}}
+ spirv.Switch %selector : i32, [
+ default: ^default (%0),
+ 0.0: ^case0
+ ]
+^default(%argd : i32):
+ spirv.Branch ^merge
+
+^case 0:
+ spirv.Branch ^merge
+
+^merge:
+ spirv.Return
+}
+
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 12daf68538d0a..3f762920015aa 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -220,3 +220,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.EntryPoint "GLCompute" @main
spirv.ExecutionMode @main "LocalSize", 1, 1, 1
}
+
+// -----
+
+// Selection with switch
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @selection_switch
+ spirv.func @selection_switch(%selector: i32) -> () "None" {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %two = spirv.Constant 2: i32
+ %three = spirv.Constant 3: i32
+ %four = spirv.Constant 4: i32
+// CHECK: {{%.*}} = spirv.Variable init({{%.*}}) : !spirv.ptr<i32, Function>
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+// CHECK: spirv.mlir.selection {
+ spirv.mlir.selection {
+// CHECK-NEXT: spirv.Switch {{%.*}} : i32, [
+// CHECK-NEXT: default: ^[[DEFAULT:.+]],
+// CHECK-NEXT: 0: ^[[CASE0:.+]],
+// CHECK-NEXT: 1: ^[[CASE1:.+]],
+// CHECK-NEXT: 2: ^[[CASE2:.+]]
+ spirv.Switch %selector : i32, [
+ default: ^default,
+ 0: ^case0,
+ 1: ^case1,
+ 2: ^case2
+ ]
+// CHECK: ^[[DEFAULT]]
+ ^default:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %one : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+ spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE0]]
+ ^case0:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %two : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+ spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE1]]
+ ^case1:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %three : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+ spirv.Branch ^merge
+// CHECK-NEXT: ^[[CASE2]]
+ ^case2:
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %four : i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]
+ spirv.Branch ^merge
+// CHECK-NEXT: ^[[MERGE]]
+ ^merge:
+// CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+// CHECK-NEXT: }
+ }
+// CHECK-NEXT: spirv.Return
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @main
+ spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}
More information about the Mlir-commits
mailing list