[Mlir-commits] [mlir] ae33eef - [MLIR] Add a switch operation to the standard dialect

Geoffrey Martin-Noble llvmlistbot at llvm.org
Mon Apr 12 18:46:11 PDT 2021


Author: Geoffrey Martin-Noble
Date: 2021-04-12T18:46:02-07:00
New Revision: ae33eef5055ef1f55df0df3be0b8851aaf9f4efd

URL: https://github.com/llvm/llvm-project/commit/ae33eef5055ef1f55df0df3be0b8851aaf9f4efd
DIFF: https://github.com/llvm/llvm-project/commit/ae33eef5055ef1f55df0df3be0b8851aaf9f4efd.diff

LOG: [MLIR] Add a switch operation to the standard dialect

This is similar to the definition of llvm.switch, providing
unstructured branch-based control flow. It differs from the LLVM
operation in that it accepts any signless integer (not only an i32),
takes no branch weights (the same as the Branch and CondBranch ops),
and has a slightly different syntax for the default case that includes
it in the list of cases with an explicit `default` keyword.

Also included are several canonicalizers.

See https://llvm.discourse.group/t/rfc-add-std-switch-and-scf-switch/3090

Reviewed By: rriddle, bondhugula

Differential Revision: https://reviews.llvm.org/D99925

Added: 
    mlir/test/Dialect/Standard/parser.mlir

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize-cf.mlir
    mlir/test/Dialect/Standard/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6d058f4261416..060062e30fb67 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2030,6 +2030,89 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
   let hasFolder = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+// Multi-way unstructured branch on an integer flag.
+//
+// Operand layout (AttrSizedOperandSegments): the flag, then the operands
+// forwarded to the default destination, then the operands of every case
+// destination concatenated into one segment. `case_operand_offsets` records
+// where each case's operands start inside that last segment, and
+// `case_values` holds one value per case destination. When the op has no
+// explicit cases these attributes are absent.
+def SwitchOp : Std_Op<"switch",
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
+     NoSideEffect, Terminator]> {
+  let summary = "switch operation";
+  let description = [{
+    The `switch` terminator operation represents a switch on a signless integer
+    value. If the flag matches one of the specified cases, then the
+    corresponding destination is jumped to. If the flag does not match any of
+    the cases, the default destination is jumped to. The count and types of
+    operands must align with the arguments in the corresponding target blocks.
+
+    Example:
+
+    ```mlir
+    switch %flag : i32, [
+      default: ^bb1(%a : i32),
+      42: ^bb1(%b : i32),
+      43: ^bb3(%c : i32)
+    ]
+    ```
+  }];
+
+  // NOTE(review): the description says "signless integer", but `AnyInteger`
+  // also admits signed/unsigned integer types — confirm which is intended.
+  let arguments = (ins AnyInteger:$flag,
+                       Variadic<AnyType>:$defaultOperands,
+                       Variadic<AnyType>:$caseOperands,
+                       OptionalAttr<AnyIntElementsAttr>:$case_values,
+                       OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
+  // Successor 0 is the default destination; the case destinations follow in
+  // the same order as the entries of `case_values`.
+  let successors = (successor
+                        AnySuccessor:$defaultDestination,
+                        VariadicSuccessor<AnySuccessor>:$caseDestinations);
+  // The three builders differ only in how the case values are supplied
+  // (APInts, int32_t values, or a prebuilt DenseIntElementsAttr). Element i of
+  // `caseOperands` is the operand list for case destination i.
+  // NOTE(review): all three default `caseValues` to `{}`, so a call that
+  // omits the case arguments is ambiguous between them.
+  let builders = [
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<APInt>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>,
+    OpBuilder<(ins "Value":$flag,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"DenseIntElementsAttr", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands)>
+  ];
+
+  // The bracketed case list is handled by the custom directive functions
+  // parseSwitchOpCases / printSwitchOpCases in Ops.cpp.
+  let assemblyFormat = [{
+    $flag `:` type($flag) `,` `[` `\n`
+      custom<SwitchOpCases>(ref(type($flag)),$defaultDestination,
+                            $defaultOperands,
+                            type($defaultOperands),
+                            $case_values,
+                            $caseDestinations,
+                            $caseOperands,
+                            type($caseOperands),
+                            $case_operand_offsets)
+   `]`
+    attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return the operands for the case destination block at the given index.
+    OperandRange getCaseOperands(unsigned index);
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    MutableOperandRange getCaseOperandsMutable(unsigned index);
+  }];
+
+  // Patterns are registered in SwitchOp::getCanonicalizationPatterns.
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index a2469dc5bee32..f0b741ee4dee2 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1333,13 +1333,15 @@ def IndexElementsAttr
                                       .isIndex()}]>,
                           "index elements attribute">;
 
-class AnyIntElementsAttr<int width> : IntElementsAttrBase<
+// Integer elements attribute with no constraint on the element bit width
+// (the extra predicate is `true`).
+// NOTE(review): any check that the attribute is a DenseIntElementsAttr comes
+// from IntElementsAttrBase, which is defined outside this hunk — confirm.
+def AnyIntElementsAttr : IntElementsAttrBase<CPred<"true">, "integer elements attribute">;
+
+// Integer elements attribute whose element type is exactly `width` bits wide.
+class IntElementsAttrOf<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
         "getElementType().isInteger(" # width # ")">,
   width # "-bit integer elements attribute">;
 
-def AnyI32ElementsAttr : AnyIntElementsAttr<32>;
-def AnyI64ElementsAttr : AnyIntElementsAttr<64>;
+def AnyI32ElementsAttr : IntElementsAttrOf<32>;
+def AnyI64ElementsAttr : IntElementsAttrOf<64>;
 
 class SignlessIntElementsAttr<int width> : IntElementsAttrBase<
   CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 3dad958887b52..b538cbabd80fb 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -441,8 +441,9 @@ static LogicalResult verify(AtomicYieldOp op) {
 /// Given a successor, try to collapse it to a new destination if it only
 /// contains a passthrough unconditional branch. If the successor is
 /// collapsable, `successor` and `successorOperands` are updated to reference
-/// the new destination and values. `argStorage` is an optional storage to use
-/// if operands to the collapsed successor need to be remapped.
+/// the new destination and values. `argStorage` is used as storage if operands
+/// to the collapsed successor need to be remapped. It must outlive uses of
+/// successorOperands.
 static LogicalResult collapseBranch(Block *&successor,
                                     ValueRange &successorOperands,
                                     SmallVectorImpl<Value> &argStorage) {
@@ -2160,6 +2161,490 @@ void SubTensorInsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
               SubTensorInsertOpCastFolder>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a switch from an already materialized case value attribute.
+///
+/// The per-case operand lists are concatenated into a single operand segment,
+/// and the index at which each case's slice starts is stored in
+/// `case_operand_offsets`. When `caseOperands` is empty, no offsets attribute
+/// is attached at all.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     DenseIntElementsAttr caseValues,
+                     BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  SmallVector<Value> allCaseOperands;
+  SmallVector<int32_t> sliceStarts;
+  sliceStarts.reserve(caseOperands.size());
+  for (ValueRange perCaseOperands : caseOperands) {
+    // Each case's slice starts where the previously appended operands end.
+    sliceStarts.push_back(static_cast<int32_t>(allCaseOperands.size()));
+    allCaseOperands.append(perCaseOperands.begin(), perCaseOperands.end());
+  }
+
+  DenseIntElementsAttr sliceStartsAttr =
+      sliceStarts.empty() ? DenseIntElementsAttr()
+                          : builder.getI32VectorAttr(sliceStarts);
+
+  build(builder, result, value, defaultOperands, allCaseOperands, caseValues,
+        sliceStartsAttr, defaultDestination, caseDestinations);
+}
+
+/// Builds a switch from case values given as APInts.
+///
+/// The values are packed into a rank-1 vector attribute whose element type is
+/// the flag's type. An empty `caseValues` list produces no `case_values`
+/// attribute.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<APInt> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  DenseIntElementsAttr caseValuesAttr;
+  if (!caseValues.empty()) {
+    ShapedType caseValueType = VectorType::get(
+        static_cast<int64_t>(caseValues.size()), value.getType());
+    caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
+  }
+  build(builder, result, value, defaultDestination, defaultOperands,
+        caseValuesAttr, caseDestinations, caseOperands);
+}
+
+/// Builds a switch from case values given as plain 32-bit integers.
+///
+/// This is the definition of the `ArrayRef<int32_t>` builder declared in ODS.
+/// Each value is sign-extended (or truncated) to the flag's bit width and
+/// forwarded to the APInt builder. The flag therefore does not need to be
+/// exactly i32.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands) {
+  unsigned bitWidth = value.getType().getIntOrFloatBitWidth();
+  SmallVector<APInt> apCaseValues;
+  apCaseValues.reserve(caseValues.size());
+  for (int32_t caseValue : caseValues)
+    apCaseValues.push_back(APInt(bitWidth, caseValue, /*isSigned=*/true));
+  build(builder, result, value, defaultDestination, defaultOperands,
+        apCaseValues, caseDestinations, caseOperands);
+}
+
+/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
+///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
+///
+/// Custom-directive parser for the bracketed case list of a switch.
+/// `flagType` is the already parsed flag type and must be an integer type.
+/// Each case value becomes an APInt of that bit width. All case operands and
+/// their types are appended to `caseOperands` / `caseOperandTypes`, and the
+/// position where each case's operands start is recorded in
+/// `caseOperandOffsets`. When there are no explicit cases, `caseValues` and
+/// `caseOperandOffsets` are left untouched.
+static ParseResult
+parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
+                   Block *&defaultDestination,
+                   SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
+                   SmallVectorImpl<Type> &defaultOperandTypes,
+                   DenseIntElementsAttr &caseValues,
+                   SmallVectorImpl<Block *> &caseDestinations,
+                   SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
+                   SmallVectorImpl<Type> &caseOperandTypes,
+                   DenseIntElementsAttr &caseOperandOffsets) {
+  // Case values are materialized with the flag's bit width, so the flag must
+  // be an integer type. This runs before the op verifier, so diagnose other
+  // types here rather than asserting in the bit-width query (e.g. `index`).
+  auto flagIntType = flagType.dyn_cast<IntegerType>();
+  if (!flagIntType)
+    return parser.emitError(parser.getNameLoc(),
+                            "expected integer type for switch flag, got ")
+           << flagType;
+
+  if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
+      failed(parser.parseSuccessor(defaultDestination)))
+    return failure();
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
+        failed(parser.parseColonTypeList(defaultOperandTypes)) ||
+        failed(parser.parseRParen()))
+      return failure();
+  }
+
+  SmallVector<APInt> values;
+  SmallVector<int32_t> offsets;
+  unsigned bitWidth = flagIntType.getWidth();
+  int64_t offset = 0;
+  while (succeeded(parser.parseOptionalComma())) {
+    int64_t value = 0;
+    if (failed(parser.parseInteger(value)))
+      return failure();
+    // Sign-extend so that a negative literal stays negative for flags wider
+    // than 64 bits; for narrower flags the value is truncated to the width.
+    values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
+
+    Block *destination;
+    SmallVector<OpAsmParser::OperandType> operands;
+    if (failed(parser.parseColon()) ||
+        failed(parser.parseSuccessor(destination)))
+      return failure();
+    if (succeeded(parser.parseOptionalLParen())) {
+      if (failed(parser.parseRegionArgumentList(operands)) ||
+          failed(parser.parseColonTypeList(caseOperandTypes)) ||
+          failed(parser.parseRParen()))
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.append(operands.begin(), operands.end());
+    offsets.push_back(offset);
+    offset += operands.size();
+  }
+
+  if (values.empty())
+    return success();
+
+  Builder &builder = parser.getBuilder();
+  ShapedType caseValueType =
+      VectorType::get(static_cast<int64_t>(values.size()), flagType);
+  caseValues = DenseIntElementsAttr::get(caseValueType, values);
+  caseOperandOffsets = builder.getI32VectorAttr(offsets);
+
+  return success();
+}
+
+/// Custom-directive printer for the bracketed case list of a switch.
+///
+/// Prints the `default` destination first, then one `value: ^dest(operands)`
+/// entry per case, each on its own line. Case values are printed as signed
+/// integers (i1 as 0/1) so that they round-trip through parseSwitchOpCases,
+/// which reads them as int64_t.
+static void printSwitchOpCases(
+    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
+    OperandRange defaultOperands, TypeRange defaultOperandTypes,
+    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
+    OperandRange caseOperands, TypeRange caseOperandTypes,
+    ElementsAttr caseOperandOffsets) {
+  p << "  default: ";
+  p.printSuccessorAndUseList(defaultDestination, defaultOperands);
+
+  if (!caseValues)
+    return;
+
+  for (int64_t i = 0, size = caseValues.size(); i < size; ++i) {
+    p << ',';
+    p.printNewline();
+    p << "  ";
+    // Zero-extending here (e.g. via getLimitedValue) would print -1 on an i64
+    // flag as 18446744073709551615, which the parser cannot read back.
+    APInt caseValue = caseValues.getValue<APInt>(i);
+    if (caseValue.getBitWidth() == 1)
+      p << caseValue.getZExtValue();
+    else
+      p << caseValue.getSExtValue();
+    p << ": ";
+    p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
+  }
+  p.printNewline();
+}
+
+/// Verifies the SwitchOp invariants that the accessors and canonicalizers in
+/// this file rely on:
+///  - without `case_values` there must be no case destinations;
+///  - the element type of `case_values` equals the flag type;
+///  - there is exactly one case value per case destination;
+///  - `case_operand_offsets`, when present, has one entry per case
+///    destination, and it is present whenever there are case operands
+///    (getCaseOperandsMutable indexes it by case number).
+static LogicalResult verify(SwitchOp op) {
+  auto caseValues = op.case_values();
+  auto caseDestinations = op.caseDestinations();
+
+  if (!caseValues) {
+    if (caseDestinations.empty())
+      return success();
+    // Case destinations without case values: report the count mismatch
+    // instead of dereferencing the empty `case_values` Optional below.
+    return op.emitOpError() << "number of case values (0) should match number "
+                               "of case destinations ("
+                            << caseDestinations.size() << ")";
+  }
+
+  Type flagType = op.flag().getType();
+  Type caseValueType = caseValues->getType().getElementType();
+  if (caseValueType != flagType)
+    return op.emitOpError()
+           << "'flag' type (" << flagType << ") should match case value type ("
+           << caseValueType << ")";
+
+  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
+    return op.emitOpError() << "number of case values (" << caseValues->size()
+                            << ") should match number of "
+                               "case destinations ("
+                            << caseDestinations.size() << ")";
+
+  if (auto caseOperandOffsets = op.case_operand_offsets()) {
+    if (caseOperandOffsets->size() !=
+        static_cast<int64_t>(caseDestinations.size()))
+      return op.emitOpError()
+             << "number of case operand offsets ("
+             << caseOperandOffsets->size()
+             << ") should match number of case destinations ("
+             << caseDestinations.size() << ")";
+  } else if (!op.caseOperands().empty()) {
+    return op.emitOpError()
+           << "non-empty case operands require 'case_operand_offsets'";
+  }
+  return success();
+}
+
+/// Returns a read-only view of the operands forwarded to the case destination
+/// at `index`; see getCaseOperandsMutable for how the slice is located.
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+  MutableOperandRange operandsForCase = getCaseOperandsMutable(index);
+  return operandsForCase;
+}
+
+/// Returns a mutable view of the operands forwarded to the case destination
+/// at `index`.
+///
+/// Case operands live in one flattened operand segment, and
+/// `case_operand_offsets` holds the start of each case's slice. A slice ends
+/// where the next one begins, or at the end of the segment for the last case.
+/// Without an offsets attribute the (necessarily empty) segment is returned.
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+  MutableOperandRange allCaseOperands = caseOperandsMutable();
+  auto maybeOffsets = case_operand_offsets();
+  if (!maybeOffsets) {
+    assert(allCaseOperands.size() == 0 &&
+           "non-empty case operands must have offsets");
+    return allCaseOperands;
+  }
+
+  ElementsAttr offsets = maybeOffsets.getValue();
+  assert(index < offsets.size() && "invalid case operand offset index");
+
+  auto offsetAt = [&](unsigned i) -> int64_t {
+    return offsets.getValue(i).cast<IntegerAttr>().getInt();
+  };
+  int64_t sliceBegin = offsetAt(index);
+  bool isLastCase = index + 1 == offsets.size();
+  int64_t sliceEnd =
+      isLastCase ? allCaseOperands.size() : offsetAt(index + 1);
+  return caseOperandsMutable().slice(sliceBegin, sliceEnd - sliceBegin);
+}
+
+/// BranchOpInterface hook. Successor 0 is the default destination, and
+/// successor i (i >= 1) is case destination i - 1.
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index == 0)
+    return defaultOperandsMutable();
+  return getCaseOperandsMutable(index - 1);
+}
+
+/// BranchOpInterface hook. `operands` holds a (possibly null) constant
+/// attribute per operand, with the flag first.
+///
+/// Returns:
+///  - the default destination when the switch has no case values;
+///  - nullptr when the flag is not a known IntegerAttr;
+///  - otherwise the destination of the matching case, or the default
+///    destination if no case matches.
+Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  Optional<DenseIntElementsAttr> caseValues = case_values();
+  if (!caseValues)
+    return defaultDestination();
+
+  auto constantFlag = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!constantFlag)
+    return nullptr;
+
+  SuccessorRange caseDests = caseDestinations();
+  for (int64_t i = 0, e = caseValues->size(); i < e; ++i) {
+    if (caseValues->getValue<IntegerAttr>(i) == constantFlag)
+      return caseDests[i];
+  }
+  return defaultDestination();
+}
+
+/// Canonicalization: a switch with no case destinations always takes its
+/// default edge, so it is replaced by an unconditional branch.
+///
+/// switch %flag : i32, [
+///   default:  ^bb1
+/// ]
+///  -> br ^bb1
+static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
+                                                   PatternRewriter &rewriter) {
+  bool hasCaseDestinations = !op.caseDestinations().empty();
+  if (hasCaseDestinations)
+    return failure();
+
+  Block *target = op.defaultDestination();
+  rewriter.replaceOpWithNewOp<BranchOp>(op, target, op.defaultOperands());
+  return success();
+}
+
+/// Canonicalization: removes cases that branch to the default destination
+/// with exactly the default operands. Taking such a case is indistinguishable
+/// from taking the default edge.
+///
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb1,
+///   43: ^bb2
+/// ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   43: ^bb2
+/// ]
+static LogicalResult
+dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
+  auto caseValues = op.case_values();
+  // A switch without explicit cases has nothing to drop. Bail out instead of
+  // dereferencing the empty Optional in the loop below.
+  if (!caseValues)
+    return failure();
+
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<APInt> newCaseValues;
+  bool requiresChange = false;
+  auto caseDests = op.caseDestinations();
+
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    if (caseDests[i] == op.defaultDestination() &&
+        op.getCaseOperands(i) == op.defaultOperands()) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(caseDests[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+/// Helper for folding a switch whose flag is known to equal `caseValue`.
+/// Replaces `op` with a `br` to the matching case destination (with that
+/// case's operands). If no case matches, or the switch has no explicit cases,
+/// the `br` targets the default destination.
+///
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
+                       APInt caseValue) {
+  // `case_values` is absent when the switch only has a default destination.
+  // In that case fall through to the default branch below instead of
+  // dereferencing the empty Optional.
+  if (auto caseValues = op.case_values()) {
+    for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+      if (caseValues->getValue<APInt>(i) == caseValue) {
+        rewriter.replaceOpWithNewOp<BranchOp>(op, op.caseDestinations()[i],
+                                              op.getCaseOperands(i));
+        return;
+      }
+    }
+  }
+  rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                        op.defaultOperands());
+}
+
+/// Canonicalization: when the flag is produced by an integer constant, the
+/// switch is resolved statically via foldSwitch.
+///
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+///   43: ^bb3
+/// ]
+/// -> br ^bb2
+static LogicalResult simplifyConstSwitchValue(SwitchOp op,
+                                              PatternRewriter &rewriter) {
+  APInt constantFlag;
+  bool flagIsConstant = matchPattern(op.flag(), m_ConstantInt(&constantFlag));
+  if (!flagIsConstant)
+    return failure();
+
+  foldSwitch(op, rewriter, constantFlag);
+  return success();
+}
+
+/// Canonicalization: a destination (case or default) that only forwards to
+/// another block with an unconditional branch is bypassed. The switch targets
+/// the final block directly (see collapseBranch).
+///
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+/// ->
+/// switch %c_42 : i32, [
+///   default: ^bb1,
+///   42: ^bb3,
+/// ]
+static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+  SmallVector<Block *> newCaseDests;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<SmallVector<Value>> argStorage;
+  auto caseValues = op.case_values();
+  auto caseDests = op.caseDestinations();
+  // `case_values` is absent when the switch only has a default destination.
+  int64_t numCases = caseValues ? caseValues->size() : 0;
+
+  // collapseBranch may point the operand ValueRange into the storage it is
+  // given. Reserve one slot per case plus one for the default up front so
+  // that emplace_back never reallocates `argStorage`. Moving an inner
+  // SmallVector that uses its inline buffer would leave earlier ValueRanges
+  // dangling.
+  argStorage.reserve(numCases + 1);
+
+  bool requiresChange = false;
+  for (int64_t i = 0; i < numCases; ++i) {
+    Block *caseDest = caseDests[i];
+    ValueRange caseOperands = op.getCaseOperands(i);
+    argStorage.emplace_back();
+    if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
+      requiresChange = true;
+
+    newCaseDests.push_back(caseDest);
+    newCaseOperands.push_back(caseOperands);
+  }
+
+  Block *defaultDest = op.defaultDestination();
+  ValueRange defaultOperands = op.defaultOperands();
+  argStorage.emplace_back();
+
+  if (succeeded(
+          collapseBranch(defaultDest, defaultOperands, argStorage.back())))
+    requiresChange = true;
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(
+      op, op.flag(), defaultDest, defaultOperands,
+      caseValues ? *caseValues : DenseIntElementsAttr(), newCaseDests,
+      newCaseOperands);
+  return success();
+}
+
+/// Canonicalization: when this block's single predecessor is a switch on the
+/// same flag that reaches it through a case edge, the flag's value on entry
+/// is known. This switch is then folded to an unconditional branch.
+///
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     42: ^bb4
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb4
+///
+///  and
+///
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     43: ^bb4
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb2:
+///   br ^bb3
+static LogicalResult
+simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
+                                        PatternRewriter &rewriter) {
+  Block *thisBlock = op->getBlock();
+  Block *onlyPred = thisBlock->getSinglePredecessor();
+  if (!onlyPred)
+    return failure();
+
+  // The predecessor must end in a switch on the same flag that enters this
+  // block through a case edge, not through its default edge.
+  auto predSwitch = dyn_cast<SwitchOp>(onlyPred->getTerminator());
+  if (!predSwitch)
+    return failure();
+  if (op.flag() != predSwitch.flag())
+    return failure();
+  if (predSwitch.defaultDestination() == thisBlock)
+    return failure();
+
+  // Look up the case value under which the predecessor branches here.
+  SuccessorRange predDests = predSwitch.caseDestinations();
+  Optional<DenseIntElementsAttr> predCaseValues = predSwitch.case_values();
+  Optional<APInt> knownFlag;
+  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) {
+    if (predDests[i] != thisBlock)
+      continue;
+    knownFlag = predCaseValues->getValue<APInt>(i);
+    break;
+  }
+
+  if (!knownFlag) {
+    rewriter.replaceOpWithNewOp<BranchOp>(op, op.defaultDestination(),
+                                          op.defaultOperands());
+    return success();
+  }
+  foldSwitch(op, rewriter, *knownFlag);
+  return success();
+}
+
+/// Canonicalization: when this block is reached only through the default edge
+/// of a switch on the same flag, the flag cannot equal any value that the
+/// predecessor switch sends elsewhere. Those cases are removed here.
+///
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2
+/// ]
+/// ^bb1:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     42: ^bb4,
+///     43: ^bb5
+///   ]
+/// ->
+/// switch %flag : i32, [
+///   default: ^bb1,
+///   42: ^bb2,
+/// ]
+/// ^bb1:
+///   switch %flag : i32, [
+///     default: ^bb3,
+///     43: ^bb5
+///   ]
+static LogicalResult
+simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
+                                               PatternRewriter &rewriter) {
+  // Check that this block has a single predecessor.
+  Block *currentBlock = op->getBlock();
+  Block *predecessor = currentBlock->getSinglePredecessor();
+  if (!predecessor)
+    return failure();
+
+  // Check that the predecessor terminates with a switch branch to this block
+  // and that it branches on the same condition and that this branch is the
+  // default destination.
+  auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
+  if (!predSwitch || op.flag() != predSwitch.flag() ||
+      predSwitch.defaultDestination() != currentBlock)
+    return failure();
+
+  // If either switch has no explicit cases there is nothing to prune. Bail
+  // out before dereferencing the empty `case_values` Optionals below.
+  auto predCaseValues = predSwitch.case_values();
+  auto caseValues = op.case_values();
+  if (!predCaseValues || !caseValues)
+    return failure();
+
+  // Delete case values that are not possible here.
+  DenseSet<APInt> caseValuesToRemove;
+  auto predDests = predSwitch.caseDestinations();
+  for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
+    if (currentBlock != predDests[i])
+      caseValuesToRemove.insert(predCaseValues->getValue<APInt>(i));
+
+  SmallVector<Block *> newCaseDestinations;
+  SmallVector<ValueRange> newCaseOperands;
+  SmallVector<APInt> newCaseValues;
+  bool requiresChange = false;
+
+  auto caseDests = op.caseDestinations();
+  for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
+    if (caseValuesToRemove.contains(caseValues->getValue<APInt>(i))) {
+      requiresChange = true;
+      continue;
+    }
+    newCaseDestinations.push_back(caseDests[i]);
+    newCaseOperands.push_back(op.getCaseOperands(i));
+    newCaseValues.push_back(caseValues->getValue<APInt>(i));
+  }
+
+  if (!requiresChange)
+    return failure();
+
+  rewriter.replaceOpWithNewOp<SwitchOp>(op, op.flag(), op.defaultDestination(),
+                                        op.defaultOperands(), newCaseValues,
+                                        newCaseDestinations, newCaseOperands);
+  return success();
+}
+
+/// Registers the SwitchOp canonicalization patterns. The registration order
+/// is kept unchanged.
+void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add(&simplifySwitchWithOnlyDefault);
+  results.add(&dropSwitchCasesThatMatchDefault);
+  results.add(&simplifyConstSwitchValue);
+  results.add(&simplifyPassThroughSwitch);
+  results.add(&simplifySwitchFromSwitchOnSameCondition);
+  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
+}
+
 //===----------------------------------------------------------------------===//
 // TruncateIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize-cf.mlir b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
index 5f18562b7ad5f..d7d0a6a4fdd4d 100644
--- a/mlir/test/Dialect/Standard/canonicalize-cf.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize-cf.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck --dump-input-context 20 %s
 
 /// Test the folding of BranchOp.
 
@@ -139,6 +139,268 @@ func @cond_br_pass_through_fail(%cond : i1) {
   return
 }
 
+
+/// Test the folding of SwitchOp
+
+// A switch whose only successor is its default destination is replaced by a
+// `br` to that destination, forwarding the default operands.
+// CHECK-LABEL: func @switch_only_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+func @switch_only_default(%flag : i32, %caseOperand0 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32)
+    ]
+  // CHECK: ^[[BB2]]({{.*}}):
+  ^bb2(%bb2Arg : f32):
+    // CHECK-NEXT: "foo.bb2Terminator"
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+}
+
+
+// Cases 42 and 17 branch to the default destination with the default
+// operands, so they are dropped; case 10 goes elsewhere and is kept.
+// CHECK-LABEL: func @switch_case_matching_default(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_case_matching_default(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3] : () -> ()
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    // CHECK-NEXT:   default: ^[[BB1:.+]](%[[CASE_OPERAND_0]] : f32)
+    // CHECK-NEXT:   10: ^[[BB2:.+]](%[[CASE_OPERAND_1]] : f32)
+    // CHECK-NEXT: ]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      42: ^bb2(%caseOperand0 : f32),
+      10: ^bb3(%caseOperand1 : f32),
+      17: ^bb2(%caseOperand0 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+}
+
+
+// The flag is the constant 0, which matches none of the cases (-1, 1), so
+// the switch folds to a `br` to the default destination.
+// CHECK-LABEL: func @switch_on_const_no_match(
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_on_const_no_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB2:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    %c0_i32 = constant 0 : i32
+    switch %c0_i32 : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      -1: ^bb3(%caseOperand1 : f32),
+      1: ^bb4(%caseOperand2 : f32)
+    ]
+  // CHECK: ^[[BB2]]({{.*}}):
+  // CHECK-NEXT: "foo.bb2Terminator"
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+}
+
+// The flag is the constant 1, which matches case 1, so the switch folds to a
+// `br` to that case's destination with that case's operands.
+// CHECK-LABEL: func @switch_on_const_with_match(
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_on_const_with_match(%caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4] : () -> ()
+  ^bb1:
+    // CHECK-NOT: switch
+    // CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+    %c1_i32 = constant 1 : i32
+    switch %c1_i32 : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      -1: ^bb3(%caseOperand1 : f32),
+      1: ^bb4(%caseOperand2 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    "foo.bb2Terminator"(%bb2Arg) : (f32) -> ()
+  ^bb3(%bb3Arg : f32):
+    "foo.bb3Terminator"(%bb3Arg) : (f32) -> ()
+  // CHECK: ^[[BB4]]({{.*}}):
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+}
+
+// ^bb2 and ^bb3 only forward their argument with an unconditional `br`, so
+// the switch is rewritten to target ^bb5 and ^bb6 directly. ^bb4 has a real
+// terminator and stays a destination.
+// CHECK-LABEL: func @switch_passthrough(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_3:[a-zA-Z0-9_]+]]
+func @switch_passthrough(%flag : i32,
+                         %caseOperand0 : f32,
+                         %caseOperand1 : f32,
+                         %caseOperand2 : f32,
+                         %caseOperand3 : f32) {
+  // add predecessors for all blocks to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb3, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+  //      CHECK: switch %[[FLAG]]
+  // CHECK-NEXT:   default: ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+  // CHECK-NEXT:   43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
+  // CHECK-NEXT:   44: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+  // CHECK-NEXT: ]
+    switch %flag : i32, [
+      default: ^bb2(%caseOperand0 : f32),
+      43: ^bb3(%caseOperand1 : f32),
+      44: ^bb4(%caseOperand2 : f32)
+    ]
+  ^bb2(%bb2Arg : f32):
+    br ^bb5(%bb2Arg : f32)
+  ^bb3(%bb3Arg : f32):
+    br ^bb6(%bb3Arg : f32)
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB5]]({{.*}}):
+  // CHECK-NEXT: "foo.bb5Terminator"
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB6]]({{.*}}):
+  // CHECK-NEXT: "foo.bb6Terminator"
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
+// ^bb3 is entered only via case 42 of an outer switch on the same flag, so
+// the inner switch in ^bb3 folds to a `br` to its own case-42 destination.
+// CHECK-LABEL: func @switch_from_switch_with_same_value_with_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_with_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32) {
+  // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb2,
+      42: ^bb3
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    // prevent this block from being simplified away
+    "foo.op"() : () -> ()
+    // CHECK-NOT: switch %[[FLAG]]
+    // CHECK: br ^[[BB5:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_1]]
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      42: ^bb5(%caseOperand1 : f32)
+    ]
+
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB5]]({{.*}}):
+  // CHECK-NEXT: "foo.bb5Terminator"
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+}
+
+// ^bb3 is entered only via case 42 of an outer switch on the same flag. The
+// inner switch has no case 42, so it folds to a `br` to its default
+// destination.
+// CHECK-LABEL: func @switch_from_switch_with_same_value_no_match(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_from_switch_with_same_value_no_match(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb2,
+      42: ^bb3
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    "foo.op"() : () -> ()
+    // CHECK-NOT: switch %[[FLAG]]
+    // CHECK: br ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      0: ^bb5(%caseOperand1 : f32),
+      43: ^bb6(%caseOperand2 : f32)
+    ]
+
+  // CHECK: ^[[BB4]]({{.*}})
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
+// ^bb3 is the default destination of an outer switch on the same flag, so
+// the flag cannot be 42 there. The inner switch's case 42 is removed and the
+// other entries are kept.
+// CHECK-LABEL: func @switch_from_switch_default_with_same_value(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND_2:[a-zA-Z0-9_]+]]
+func @switch_from_switch_default_with_same_value(%flag : i32, %caseOperand0 : f32, %caseOperand1 : f32, %caseOperand2 : f32) {
+  // add predecessors for all blocks except ^bb3 to avoid other canonicalizations.
+  "foo.pred"() [^bb1, ^bb2, ^bb4, ^bb5, ^bb6] : () -> ()
+
+  ^bb1:
+    // CHECK: switch %[[FLAG]]
+    switch %flag : i32, [
+      default: ^bb3,
+      42: ^bb2
+    ]
+
+  ^bb2:
+    "foo.bb2Terminator"() : () -> ()
+  ^bb3:
+    "foo.op"() : () -> ()
+    // CHECK: switch %[[FLAG]]
+    // CHECK-NEXT: default: ^[[BB4:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_0]]
+    // CHECK-NEXT: 43: ^[[BB6:[a-zA-Z0-9_]+]](%[[CASE_OPERAND_2]]
+    // CHECK-NOT: 42
+    switch %flag : i32, [
+      default: ^bb4(%caseOperand0 : f32),
+      42: ^bb5(%caseOperand1 : f32),
+      43: ^bb6(%caseOperand2 : f32)
+    ]
+
+  // CHECK: ^[[BB4]]({{.*}}):
+  // CHECK-NEXT: "foo.bb4Terminator"
+  ^bb4(%bb4Arg : f32):
+    "foo.bb4Terminator"(%bb4Arg) : (f32) -> ()
+
+  ^bb5(%bb5Arg : f32):
+    "foo.bb5Terminator"(%bb5Arg) : (f32) -> ()
+
+  // CHECK: ^[[BB6]]({{.*}}):
+  // CHECK-NEXT: "foo.bb6Terminator"
+  ^bb6(%bb6Arg : f32):
+    "foo.bb6Terminator"(%bb6Arg) : (f32) -> ()
+}
+
 /// Test folding conditional branches that are successors of conditional
 /// branches with the same condition.
 

diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index 02ec47f96ac6e..53a4ad5550f9f 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -96,3 +96,47 @@ func @read_global_memref() {
   %1 = memref.tensor_load %0 : memref<2xf32>
   return
 }
+
+// CHECK-LABEL: func @switch(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND:[a-zA-Z0-9_]+]]
+func @switch(%flag : i32, %caseOperand : i32) {
+  // CHECK: switch %[[FLAG]] : i32
+  // CHECK-NEXT: default: ^bb{{[0-9]+}}(%[[CASE_OPERAND]] : i32)
+  // CHECK-NEXT: 42: ^bb{{[0-9]+}}(%[[CASE_OPERAND]] : i32)
+  // CHECK-NEXT: 43: ^bb{{[0-9]+}}(%[[CASE_OPERAND]] : i32)
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32),
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// CHECK-LABEL: func @switch_i64(
+// CHECK-SAME: %[[FLAG:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[CASE_OPERAND:[a-zA-Z0-9_]+]]
+func @switch_i64(%flag : i64, %caseOperand : i32) {
+  // CHECK: switch %[[FLAG]] : i64
+  // CHECK-NEXT: default: ^bb{{[0-9]+}}(%[[CASE_OPERAND]] : i32)
+  // CHECK-NEXT: 42: ^bb{{[0-9]+}}(%[[CASE_OPERAND]] : i32)
+  // CHECK-NEXT: 43: ^bb{{[0-9]+}}(%[[CASE_OPERAND]] : i32)
+  switch %flag : i64, [
+    default: ^bb1(%caseOperand : i32),
+    42: ^bb2(%caseOperand : i32),
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}

diff  --git a/mlir/test/Dialect/Standard/parser.mlir b/mlir/test/Dialect/Standard/parser.mlir
new file mode 100644
index 0000000000000..9fcf9529a4a78
--- /dev/null
+++ b/mlir/test/Dialect/Standard/parser.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt -verify-diagnostics -split-input-file %s
+
+func @switch_missing_case_value(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    45: ^bb2(%caseOperand : i32),
+    // expected-error@+1 {{expected integer value}}
+    : ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_wrong_type_case_value(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    // expected-error@+1 {{expected integer value}}
+    "hello": ^bb2(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_missing_comma(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    default: ^bb1(%caseOperand : i32),
+    45: ^bb2(%caseOperand : i32)
+    // expected-error@+1 {{expected ']'}}
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}
+
+// -----
+
+func @switch_missing_default(%flag : i32, %caseOperand : i32) {
+  switch %flag : i32, [
+    // expected-error@+1 {{expected 'default'}}
+    45: ^bb2(%caseOperand : i32)
+    43: ^bb3(%caseOperand : i32)
+  ]
+
+  ^bb1(%bb1arg : i32):
+    return
+  ^bb2(%bb2arg : i32):
+    return
+  ^bb3(%bb3arg : i32):
+    return
+}


        


More information about the Mlir-commits mailing list