[Mlir-commits] [mlir] [mlir][emitc] Add 'emitc.switch' op to the dialect (PR #102331)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 7 09:49:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-emitc

Author: Andrey Timonin (EtoAndruwa)

<details>
<summary>Changes</summary>

This PR is a continuation of the [previous one](https://github.com/llvm/llvm-project/pull/101478). As a result, the `emitc::SwitchOp` op was developed, inspired by `scf::IndexSwitchOp`.

Main points of PR:

- Added the `emitc::SwitchOp` op to the EmitC dialect and the CppEmitter
- Corresponding tests were added
- Conversion from the SCF dialect to the EmitC dialect will be supported once this (or an adjusted) version of the `emitc::SwitchOp` op is approved

---

Patch is 23.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102331.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.h (+3) 
- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+88-1) 
- (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+208) 
- (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+46-3) 
- (modified) mlir/test/Target/Cpp/invalid.mlir (+129) 
- (added) mlir/test/Target/Cpp/switch.mlir (+129) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 87a4078f280f65..212271ea813080 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -47,6 +47,9 @@ bool isSupportedFloatType(mlir::Type type);
 /// Determines whether \p type is a emitc.size_t/ssize_t type.
 bool isPointerWideType(mlir::Type type);
 
+/// Determines whether \p type is a valid integer type for SwitchOp.
+bool isIntegerSwitchOperandType(Type type);
+
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 16fa58ce47013d..4d0cedb91a3915 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -54,6 +54,8 @@ def CExpression : NativeOpTrait<"emitc::CExpression">;
 def IntegerIndexOrOpaqueType : Type<CPred<"emitc::isIntegerIndexOrOpaqueType($_self)">,
 "integer, index or opaque type supported by EmitC">;
 def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;
+def IntegerSwitchOperandType : Type<CPred<"emitc::isIntegerSwitchOperandType($_self)">,
+"integer type for switch operation">;
 
 def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
   let summary = "Addition operation";
@@ -1188,7 +1190,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
 }
 
 def EmitC_YieldOp : EmitC_Op<"yield",
-      [Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
+      [Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
   let summary = "Block termination operation";
   let description = [{
     The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
@@ -1302,5 +1304,90 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
   let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
 }
 
+def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
+    SingleBlockImplicitTerminator<"emitc::YieldOp">,
+    DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                              ["getRegionInvocationBounds",
+                               "getEntrySuccessorRegions"]>]> {
+  let summary = "Switch operation";
+  let description = [{
+    The `emitc.switch` is a control-flow operation that branches to one of
+    the given regions based on the values of the argument and the cases. The
+    argument is always of type integer (singed or unsigned), excluding i8 and i1.
+    If the type is not specified, then i32 will be used by default.
+
+    The operation always has a "default" region and any number of case regions
+    denoted by integer constants. Control-flow transfers to the case region
+    whose constant value equals the value of the argument. If the argument does
+    not equal any of the case values, control-flow transfer to the "default"
+    region.
+
+    The operation does not return any value. Moreover, case regions and 
+    default region must be terminated using the `emitc.yield` operation.
+
+    Example:
+
+    ```mlir
+    // Unspecified type for the cases (default i32 set).
+    %0 = "emitc.variable"(){value = 42 : i32} : () -> i32
+    emitc.switch %0
+    case 2: {
+      %1 = emitc.call_opaque "func_b" () : () -> i32
+      emitc.yield
+    }
+    case 5: {
+      %2 = emitc.call_opaque "func_a" () : () -> i32
+      emitc.yield
+    }
+    default : {
+      %3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+      %4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+
+      emitc.call_opaque "func2" (%3) : (f32) -> ()
+      emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+      emitc.yield
+    }
+    ...
+    // Specified type for the cases (i16).
+    %0 = "emitc.variable"(){value = 42 : i16} : () -> i16
+    emitc.switch %0 : i16
+    case 2: {
+      %1 = emitc.call_opaque "func_b" () : () -> i32
+      emitc.yield
+    }
+    case 5: {
+      %2 = emitc.call_opaque "func_a" () : () -> i32
+      emitc.yield
+    }
+    default : {
+      %3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+      %4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+
+      emitc.call_opaque "func2" (%3) : (f32) -> ()
+      emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+      emitc.yield
+    }
+    ```
+  }];
+
+  let arguments = (ins IntegerSwitchOperandType:$arg, DenseI64ArrayAttr:$cases);
+  let results = (outs);
+  let regions = (region SizedRegion<1>:$defaultRegion,
+                        VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let extraClassDeclaration = [{
+    /// Get the number of cases.
+    unsigned getNumCases();
+
+    /// Get the default region body.
+    Block &getDefaultBlock();
+
+    /// Get the body of a case region.
+    Block &getCaseBlock(unsigned idx);
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
 
 #endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 70e3e728e01950..216911b2870ebc 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -131,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) {
       type);
 }
 
+bool mlir::emitc::isIntegerSwitchOperandType(Type type) {
+  auto intType = llvm::dyn_cast<IntegerType>(type);
+  return isSupportedIntegerType(type) && intType.getWidth() != 1 &&
+         intType.getWidth() != 8;
+}
+
 /// Check that the type of the initial value is compatible with the operations
 /// result type.
 static LogicalResult verifyInitializationAttribute(Operation *op,
@@ -1096,6 +1102,208 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Parse the case regions and values.
+static ParseResult
+parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
+                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> caseValues;
+  while (succeeded(parser.parseOptionalKeyword("case"))) {
+    int64_t value;
+    Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
+
+    if (parser.parseInteger(value) || parser.parseColon() ||
+        parser.parseRegion(region, /*arguments=*/{}))
+      return failure();
+    caseValues.push_back(value);
+  }
+  cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
+  return success();
+}
+
+/// Print the case regions and values.
+static void printSwitchCases(OpAsmPrinter &parser, Operation *op,
+                             DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
+    parser.printNewline();
+    parser << "case " << value << ": ";
+    parser.printRegion(*region, /*printEntryBlockArgs=*/false);
+  }
+  return;
+}
+
+ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand arg;
+  DenseI64ArrayAttr casesAttr;
+  SmallVector<std::unique_ptr<Region>, 2> caseRegionsRegions;
+  std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();
+
+  if (parser.parseOperand(arg))
+    return failure();
+
+  Type argType = parser.getBuilder().getI32Type();
+  // Parse optional type, else assume i32.
+  if (!parser.parseOptionalColon() && parser.parseType(argType))
+    return failure();
+
+  {
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseOptionalAttrDict(result.attributes))
+      return failure();
+
+    if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+          return parser.emitError(loc)
+                 << "'" << result.name.getStringRef() << "' op ";
+        })))
+      return failure();
+  }
+  {
+    auto odsResult = parseSwitchCases(parser, casesAttr, caseRegionsRegions);
+
+    if (odsResult)
+      return failure();
+
+    result.getOrAddProperties<SwitchOp::Properties>().cases = casesAttr;
+  }
+  if (parser.parseKeyword("default") || parser.parseColon())
+    return failure();
+
+  if (parser.parseRegion(*defaultRegionRegion))
+    return failure();
+
+  result.addRegion(std::move(defaultRegionRegion));
+  result.addRegions(caseRegionsRegions);
+
+  if (parser.resolveOperand(arg, argType, result.operands))
+    return failure();
+
+  return success();
+}
+
+void SwitchOp::print(OpAsmPrinter &parser) {
+  parser << ' ';
+  parser << getArg();
+  SmallVector<StringRef, 2> elidedAttrs;
+  elidedAttrs.push_back("cases");
+  parser.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+  parser << ' ';
+  printSwitchCases(parser, *this, getCasesAttr(), getCaseRegions());
+  parser.printNewline();
+  parser << "default";
+  parser << ' ';
+  parser.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/true,
+                     /*printBlockTerminators=*/true);
+
+  return;
+}
+
+static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
+                                  const Twine &name) {
+  auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
+  if (!yield)
+    return op.emitOpError("expected region to end with emitc.yield, but got ")
+           << region.front().back().getName();
+
+  if (yield.getNumOperands() != 0) {
+    return (op.emitOpError("expected each region to return ")
+            << "0 values, but " << name << " returns "
+            << yield.getNumOperands())
+               .attachNote(yield.getLoc())
+           << "see yield operation here";
+  }
+  return success();
+}
+
+LogicalResult emitc::SwitchOp::verify() {
+  if (!isIntegerSwitchOperandType(getArg().getType()))
+    return emitOpError("unsupported type ") << getArg().getType();
+
+  if (getCases().size() != getCaseRegions().size()) {
+    return emitOpError("has ")
+           << getCaseRegions().size() << " case regions but "
+           << getCases().size() << " case values";
+  }
+
+  DenseSet<int64_t> valueSet;
+  for (int64_t value : getCases())
+    if (!valueSet.insert(value).second)
+      return emitOpError("has duplicate case value: ") << value;
+
+  if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
+    return failure();
+
+  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
+    if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
+      return failure();
+
+  return success();
+}
+
+unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
+
+Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
+
+Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
+  assert(idx < getNumCases() && "case index out-of-bounds");
+  return getCaseRegions()[idx].front();
+}
+
+void SwitchOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
+  llvm::copy(getRegions(), std::back_inserter(successors));
+  return;
+}
+
+void SwitchOp::getEntrySuccessorRegions(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  FoldAdaptor adaptor(operands, *this);
+
+  // If a constant was not provided, all regions are possible successors.
+  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
+  if (!arg) {
+    llvm::copy(getRegions(), std::back_inserter(successors));
+    return;
+  }
+
+  // Otherwise, try to find a case with a matching value. If not, the
+  // default region is the only successor.
+  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
+    if (caseValue == arg.getInt()) {
+      successors.emplace_back(&caseRegion);
+      return;
+    }
+  }
+  successors.emplace_back(&getDefaultRegion());
+  return;
+}
+
+void SwitchOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
+  if (!operandValue) {
+    // All regions are invoked at most once.
+    bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
+    return;
+  }
+
+  unsigned liveIndex = getNumRegions() - 1;
+  const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
+
+  liveIndex = iteratorToInt != getCases().end()
+                  ? std::distance(getCases().begin(), iteratorToInt)
+                  : liveIndex;
+
+  for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
+       ++regIndex)
+    bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
+
+  return;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 1dadb9dd691e70..f315e91015ef5d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -449,6 +449,49 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
   return printBinaryOperation(emitter, operation, "-");
 }
 
+static LogicalResult emitAllExceptLast(CppEmitter &emitter,
+                                       raw_indented_ostream &os,
+                                       Region &region) {
+  for (Region::OpIterator iteratorOp = region.op_begin(), end = region.op_end();
+       std::next(iteratorOp) != end; ++iteratorOp) {
+    if (failed(emitter.emitOperation(*iteratorOp, /*trailingSemicolon=*/true)))
+      return failure();
+  }
+  os << "break;\n";
+  return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::SwitchOp switchOp) {
+  raw_indented_ostream &os = emitter.ostream();
+  auto iteratorCaseValues = switchOp.getCases().begin();
+  size_t caseIndex = 0;
+
+  os << "\nswitch(" << emitter.getOrCreateName(switchOp.getArg()) << ") {";
+
+  while (caseIndex != switchOp.getNumCases()) {
+    os << "\ncase "
+       << "(" << *(iteratorCaseValues++) << ")"
+       << ": {\n";
+    os.indent();
+
+    if (failed(emitAllExceptLast(
+            emitter, os, *(switchOp.getCaseBlock(caseIndex++)).getParent())))
+      return failure();
+
+    os.unindent() << "}";
+  }
+
+  os << "\ndefault: {\n";
+  os.indent();
+
+  if (failed(emitAllExceptLast(emitter, os, switchOp.getDefaultRegion())))
+    return failure();
+
+  os.unindent() << "}";
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
   Operation *operation = cmpOp.getOperation();
 
@@ -998,7 +1041,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
       // trailing semicolon is handled within the printOperation function.
       bool trailingSemicolon =
           !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
-               emitc::IfOp, emitc::VerbatimOp>(op);
+               emitc::IfOp, emitc::SwitchOp, emitc::VerbatimOp>(op);
 
       if (failed(emitter.emitOperation(
               op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1509,8 +1552,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp,
                 emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
                 emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
-                emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
-                emitc::VerbatimOp>(
+                emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
+                emitc::VariableOp, emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Target/Cpp/invalid.mlir b/mlir/test/Target/Cpp/invalid.mlir
index 513371a09cde1d..1b40ac9a5f3c9b 100644
--- a/mlir/test/Target/Cpp/invalid.mlir
+++ b/mlir/test/Target/Cpp/invalid.mlir
@@ -80,8 +80,137 @@ func.func @array_as_result(%arg: !emitc.array<4xi8>) -> (!emitc.array<4xi8>) {
 }
 
 // -----
+
 func.func @ptr_to_array() {
   // expected-error at +1 {{cannot emit pointer to array type '!emitc.ptr<!emitc.array<9xi16>>'}}
   %v = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<!emitc.array<9xi16>>
   return
 }
+
+// -----
+
+func.func @emitc_switch() {
+  %0 = "emitc.variable"(){value = 1 : ui16} : () -> ui16
+
+  // expected-error at +1 {{'emitc.switch' op expected region to end with emitc.yield, but got emitc.call_opaque}}
+  emitc.switch %0 : ui16
+  case 2: {
+    %1 = emitc.call_opaque "func_b" () : () -> i32
+  }
+  case 5: {
+    %2 = emitc.call_opaque "func_a" () : () -> i32
+    emitc.yield
+  }
+  default : {
+    %3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+    %4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+
+    emitc.call_opaque "func2" (%3) : (f32) -> ()
+    emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+    emitc.yield
+  }
+  return
+}
+
+// -----
+
+func.func @emitc_switch() {
+  %0 = "emitc.variable"(){value = 1 : ui16} : () -> ui16
+
+  // expected-error at +1 {{'emitc.switch' op expected region to end with emitc.yield, but got emitc.call_opaque}}
+  emitc.switch %0 : ui16
+  case 2: {
+    %1 = emitc.call_opaque "func_b" () : () -> i32
+    emitc.yield
+  }
+  case 5: {
+    %2 = emitc.call_opaque "func_a" () : () -> i32
+    emitc.yield
+  }
+  default : {
+    %3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+    %4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+
+    emitc.call_opaque "func2" (%3) : (f32) -> ()
+    emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+  }
+  return
+}
+
+// -----
+
+func.func @emitc_switch() {
+  %0 = "emitc.variable"(){value = 1 : i8} : () -> i8
+
+  // expected-error at +1 {{'emitc.switch' op operand #0 must be integer type for switch operation, but got 'i8'}}
+  emitc.switch %0 : i8
+  case 2: {
+    %1 = emitc.call_opaque "func_b" () : () -> i32
+    emitc.yield
+  }
+  case 5: {
+    %2 = emitc.call_opaque "func_a" () : () -> i32
+    emitc.yield
+  }
+  default : {
+    %3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+    %4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+
+    emitc.call_opaque "func2" (%3) : (f32) -> ()
+    emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+    emitc.yield
+  }
+  return
+}
+
+// -----
+
+func.func @emitc_switch() {
+  %0 = "emitc.variable"(){value = 1 : i1} : () -> i1
+
+  // expected-error at +1 {{'emitc.switch' op operand #0 must be integer type for switch operation, but got 'i1'}}
+  emitc.switch %0 : i1
+  case 2: {
+    %1 = emitc.call_opaque "func_b" () : () -> i32
+    emitc.yield
+  }
+  case 5: {
+    %2 = emitc.call_opaque "func_a" () : () -> i32
+    emitc.yield
+  }
+  default : {
+    %3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+    %4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+
+    emitc.call_opaque "func2" (%3) : (f32) -> ()
+    emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+    emitc.yield
+  }
+  return
+}
+
+// -----
+
+func.func @emitc_switch() {
+  %0 = "emitc.variable"(){value = 1 : i16} : () -> i16
+
+  emitc.switch %0 : i16
+  case 2: {
+    %1 = emitc.call_opaque "func_b" () : () -> i32
+    emitc.yield
+  }
+  // expected-error at +1 {{custom op 'emitc.switch' expected integer value}}
+  case : {
+    %2 = emitc.call_opaque "func_a" () : () -> i32
+    emitc.yield
+  }
+  default : {
+    %3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+    %4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
+
+    emitc.call_opaque "func2" (%3) : (f32) -> ()
+    emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
+    emitc.yield
+  }
+  return
+}
diff --git a/mlir/test/Target/Cpp/switch.mlir b/mlir/test/Target/Cpp/switch.mlir
new file mode 100644
index 00000000000000..251c13bc188f8d
--- /dev/null
+++ b/mlir/test/Target/Cpp/switch.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+// CHECK-LABEL:   int32_t v1 = 1...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/102331


More information about the Mlir-commits mailing list