[Mlir-commits] [mlir] 14f2415 - [mlir][LLVMIR] Add 'llvm.switch' op

Brian Gesiak llvmlistbot at llvm.org
Thu Dec 17 11:16:59 PST 2020


Author: Brian Gesiak
Date: 2020-12-17T14:11:21-05:00
New Revision: 14f24155a5915a295bd965bb6062bfeab217b9c8

URL: https://github.com/llvm/llvm-project/commit/14f24155a5915a295bd965bb6062bfeab217b9c8
DIFF: https://github.com/llvm/llvm-project/commit/14f24155a5915a295bd965bb6062bfeab217b9c8.diff

LOG: [mlir][LLVMIR] Add 'llvm.switch' op

The LLVM IR 'switch' instruction allows control flow to be transferred
to one of any number of branches depending on an integer control value,
or to a default branch if the control value does not match any case value.
This patch adds `llvm.switch` to the MLIR LLVMIR dialect, as well as
translation routines for lowering it to LLVM IR.

To store a variable number of operands for a variable number of branch
destinations, the new op makes use of the `AttrSizedOperandSegments`
trait. It stores its default branch operands as one segment, and all
remaining case branches' operands as another. It also stores the begin
offset of each case branch's operands within that second segment; each
case's sub-range ends where the next one begins (or, for the last case, at
the end of the segment). There's probably a better way to implement this, since the
offset computation complicates several parts of the op definition. This is the
approach I settled on because in doing so I was able to delegate to the default
op builder member functions. However, it may be preferable to instead specify
`skipDefaultBuilders` in the op's ODS, or use a completely separate
approach; feedback is welcome!

Another contentious part of this patch may be the custom printer and
parser functions for the op. Ideally I would have liked the MLIR to be
printed in this way:

```
llvm.switch %0, ^bb1(%1 : !llvm.i32) [
  1: ^bb2,
  2: ^bb3(%2, %3 : !llvm.i32, !llvm.i32)
]
```

The above would resemble how LLVM IR is formatted for the 'switch'
instruction. But I found it difficult to print and parse something like
this, whether I used the declarative assembly format or custom functions.
I also was not sure a multi-line format would be welcome -- it seems
like most MLIR ops do not use newlines. Again, I'd be happy to hear any
feedback here as well, or on any other aspect of the patch.

Differential Revision: https://reviews.llvm.org/D93005

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Target/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 0088fe38246b..9608e15bb81a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -635,6 +635,50 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
   let printer = [{ p << getOperationName(); }];
 }
 
+// The `llvm.switch` terminator: compares the i32 `value` against each entry
+// of `case_values` and branches to the matching case destination, or to
+// `defaultDestination` if no case value matches.
+//
+// Operand layout (AttrSizedOperandSegments): `value`, then the operands
+// forwarded to the default destination, then the operands of every case
+// destination flattened into a single `caseOperands` segment.
+// `case_operand_offsets` records where each case's operands start within that
+// flattened segment (see SwitchOp::getCaseOperandsMutable). `branch_weights`,
+// when present, must carry one weight per successor (checked by the verifier).
+//
+// NOTE: the order of `arguments` followed by `successors` fixes the parameter
+// order of the ODS-generated builder that the custom SwitchOp::build in
+// LLVMDialect.cpp delegates to; keep them in sync.
+def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
+    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+     NoSideEffect]> {
+  let arguments = (ins LLVM_i32:$value,
+                   Variadic<AnyType>:$defaultOperands,
+                   Variadic<AnyType>:$caseOperands,
+                   OptionalAttr<ElementsAttr>:$case_values,
+                   OptionalAttr<ElementsAttr>:$case_operand_offsets,
+                   OptionalAttr<ElementsAttr>:$branch_weights);
+  let successors = (successor
+                    AnySuccessor:$defaultDestination,
+                    VariadicSuccessor<AnySuccessor>:$caseDestinations);
+
+  let verifier = [{ return ::verify(*this); }];
+  // Custom syntax, e.g.:
+  //   llvm.switch %v, ^default(%a : !llvm.i32) [
+  //     1: ^bb1,
+  //     2: ^bb2(%b, %c : !llvm.i32, !llvm.i32)
+  //   ] {branch_weights = ...}
+  // The case list between the brackets is handled by parseSwitchOpCases /
+  // printSwitchOpCases.
+  let assemblyFormat = [{
+    $value `,`
+    $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
+    `[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
+                                   $caseOperands, type($caseOperands),
+                                   $case_operand_offsets) `]`
+    attr-dict
+  }];
+
+  // Convenience builder taking one operand list per case; SwitchOp::build
+  // flattens the lists and computes the per-case offsets.
+  let builders = [
+    OpBuilderDAG<(ins "Value":$value,
+      "Block *":$defaultDestination,
+      "ValueRange":$defaultOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
+      CArg<"BlockRange", "{}">:$caseDestinations,
+      CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+      CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
+    LLVM_TerminatorPassthroughOpBuilder
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the operands for the case destination block at the given index.
+    OperandRange getCaseOperands(unsigned index);
+
+    /// Return a mutable range of operands for the case destination block at the
+    /// given index.
+    MutableOperandRange getCaseOperandsMutable(unsigned index);
+  }];
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // Auxiliary operations (do not appear in LLVM IR but necessary for the dialect
 // to work correctly).

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d70a327824b7..9b2c88c30a86 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -189,6 +189,150 @@ CondBrOp::getMutableSuccessorOperands(unsigned index) {
   return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM::SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Builds an `llvm.switch` terminator.
+///
+/// `caseValues`, `caseDestinations` and (when non-empty) `caseOperands` are
+/// parallel arrays: case i compares `value` against `caseValues[i]` and
+/// branches to `caseDestinations[i]`, forwarding `caseOperands[i]`. The
+/// per-case operand lists are flattened into the single `caseOperands`
+/// operand segment, and the start offset of each case's sub-range is stored
+/// in the `case_operand_offsets` attribute. Empty arrays leave the
+/// corresponding optional attributes unset.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branchWeights) {
+  // getCaseOperandsMutable indexes the offsets by case position, so a partial
+  // list of operand ranges would only fail much later, on first access.
+  assert((caseOperands.empty() ||
+          caseOperands.size() == caseDestinations.size()) &&
+         "expected one operand list per case destination");
+
+  SmallVector<Value> flattenedCaseOperands;
+  SmallVector<int32_t> caseOperandOffsets;
+  caseOperandOffsets.reserve(caseOperands.size());
+  int32_t offset = 0;
+  for (ValueRange operands : caseOperands) {
+    flattenedCaseOperands.append(operands.begin(), operands.end());
+    caseOperandOffsets.push_back(offset);
+    offset += operands.size();
+  }
+
+  ElementsAttr caseValuesAttr;
+  if (!caseValues.empty())
+    caseValuesAttr = builder.getI32VectorAttr(caseValues);
+
+  ElementsAttr caseOperandOffsetsAttr;
+  if (!caseOperandOffsets.empty())
+    caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
+
+  // getI32VectorAttr takes an ArrayRef, so no temporary copy is needed.
+  ElementsAttr weightsAttr;
+  if (!branchWeights.empty())
+    weightsAttr = builder.getI32VectorAttr(branchWeights);
+
+  build(builder, result, value, defaultOperands, flattenedCaseOperands,
+        caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination,
+        caseDestinations);
+}
+
+/// Parses the case list between the brackets of an `llvm.switch`:
+///
+///   <cases> ::= (<case> (`,` <case>)*)?
+///   <case>  ::= integer `:` bb-id (`(` ssa-use-list `:` type-list `)`)?
+///
+/// Each case's value is collected into `caseValues` (an i32 vector attribute),
+/// its successor into `caseDestinations`, and its operands/types into the
+/// flattened `caseOperands`/`caseOperandTypes` lists. `caseOperandOffsets`
+/// receives the start offset of each case's operands in the flattened list.
+/// An empty case list leaves both attributes unset.
+static ParseResult
+parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
+                   SmallVectorImpl<Block *> &caseDestinations,
+                   SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
+                   SmallVectorImpl<Type> &caseOperandTypes,
+                   ElementsAttr &caseOperandOffsets) {
+  SmallVector<int32_t> values;
+  SmallVector<int32_t> offsets;
+  int32_t value, offset = 0;
+  do {
+    OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
+    // No leading integer at all: the case list is empty.
+    if (values.empty() && !integerParseResult.hasValue())
+      return success();
+
+    // After a `,` another case is required. Report it explicitly; returning a
+    // bare failure() here would reject the input without any diagnostic.
+    if (!integerParseResult.hasValue())
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected integer case value");
+    // An integer token was present but could not be parsed (e.g. it does not
+    // fit in 32 bits); parseOptionalInteger has already reported the error.
+    if (failed(integerParseResult.getValue()))
+      return failure();
+    values.push_back(value);
+
+    Block *destination;
+    SmallVector<OpAsmParser::OperandType> operands;
+    if (parser.parseColon() || parser.parseSuccessor(destination))
+      return failure();
+    if (succeeded(parser.parseOptionalLParen())) {
+      // Case operands are SSA value uses, not region argument definitions.
+      if (parser.parseOperandList(operands) ||
+          parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen())
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.append(operands.begin(), operands.end());
+    offsets.push_back(offset);
+    offset += operands.size();
+  } while (succeeded(parser.parseOptionalComma()));
+
+  Builder &builder = parser.getBuilder();
+  caseValues = builder.getI32VectorAttr(values);
+  caseOperandOffsets = builder.getI32VectorAttr(offsets);
+
+  return success();
+}
+
+/// Prints the case list of an `llvm.switch` as one `value: ^dest(operands)`
+/// entry per line, separated by commas. Prints nothing when the op has no
+/// `case_values` attribute. The operand/type/offset parameters are required
+/// by the custom<SwitchOpCases> directive signature; the per-case operands are
+/// fetched through `op.getCaseOperands` instead.
+static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
+                               ElementsAttr caseValues,
+                               SuccessorRange caseDestinations,
+                               OperandRange caseOperands,
+                               TypeRange caseOperandTypes,
+                               ElementsAttr caseOperandOffsets) {
+  if (!caseValues)
+    return;
+
+  size_t index = 0;
+  llvm::interleave(
+      llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
+      [&](auto i) {
+        p << "  ";
+        // Case values are signed 32-bit integers (parseSwitchOpCases reads
+        // them into an int32_t), so print them sign-extended. getLimitedValue()
+        // would print -1 as 4294967295, which the parser then rejects.
+        p << std::get<0>(i).getSExtValue();
+        p << ": ";
+        p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
+      },
+      [&] {
+        p << ',';
+        p.printNewline();
+      });
+  p.printNewline();
+}
+
+/// Verifies the structural invariants of an `llvm.switch`:
+///  - there is exactly one case value per case destination;
+///  - `case_operand_offsets`, when present, holds one offset per case
+///    destination, and may only be absent when there are no case operands
+///    (SwitchOp::getCaseOperandsMutable asserts on both);
+///  - `branch_weights`, when present, holds one weight per successor.
+static LogicalResult verify(SwitchOp op) {
+  int64_t numCaseDestinations = op.caseDestinations().size();
+  if ((!op.case_values() && numCaseDestinations != 0) ||
+      (op.case_values() && op.case_values()->size() != numCaseDestinations))
+    return op.emitOpError("expects number of case values to match number of "
+                          "case destinations");
+
+  if (auto offsets = op.case_operand_offsets()) {
+    if (offsets->size() != numCaseDestinations)
+      return op.emitOpError("expects number of case operand offsets to match "
+                            "number of case destinations");
+  } else if (!op.caseOperands().empty()) {
+    return op.emitOpError("expects case operand offsets when case operands "
+                          "are present");
+  }
+
+  int64_t numSuccessors = op.getNumSuccessors();
+  if (op.branch_weights() && op.branch_weights()->size() != numSuccessors)
+    return op.emitOpError("expects number of branch weights to match number "
+                          "of successors: ")
+           << op.branch_weights()->size() << " vs " << numSuccessors;
+  return success();
+}
+
+/// Returns a read-only view of the operands forwarded to the case successor
+/// at position `index` (0-based, not counting the default destination).
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+  MutableOperandRange mutableOperands = getCaseOperandsMutable(index);
+  return static_cast<OperandRange>(mutableOperands);
+}
+
+/// Returns the mutable sub-range of the flattened `caseOperands` segment that
+/// belongs to the case successor at position `index`. That case starts at
+/// `case_operand_offsets[index]` and ends where the next case starts, or at
+/// the end of the segment for the last case. Without offsets, the whole
+/// (necessarily empty) segment is returned.
+///
+/// NOTE(review): nothing here updates `case_operand_offsets`, so growing or
+/// shrinking the returned range shifts later cases' operands without
+/// adjusting their recorded offsets — confirm callers account for this.
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+  MutableOperandRange allCaseOperands = caseOperandsMutable();
+  Optional<ElementsAttr> maybeOffsets = case_operand_offsets();
+  if (!maybeOffsets) {
+    assert(allCaseOperands.size() == 0 &&
+           "non-empty case operands must have offsets");
+    return allCaseOperands;
+  }
+
+  ElementsAttr offsets = *maybeOffsets;
+  assert(index < offsets.size() && "invalid case operand offset index");
+
+  auto offsetAt = [&](unsigned i) -> int64_t {
+    return offsets.getValue(i).cast<IntegerAttr>().getInt();
+  };
+  int64_t begin = offsetAt(index);
+  bool isLastCase = index + 1 == offsets.size();
+  int64_t end = isLastCase ? allCaseOperands.size() : offsetAt(index + 1);
+  return allCaseOperands.slice(begin, end - begin);
+}
+
+/// BranchOpInterface hook. Successor 0 is the default destination; successor
+/// k > 0 is case destination k - 1.
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index == 0)
+    return defaultOperandsMutable();
+  return getCaseOperandsMutable(index - 1);
+}
+
 //===----------------------------------------------------------------------===//
 // Builder, printer and parser for for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 057f57409940..d9094f8763d9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -323,19 +323,31 @@ static Value getPHISourceValue(Block *current, Block *pred,
   if (isa<LLVM::BrOp>(terminator))
     return terminator.getOperand(index);
 
-  // For conditional branches, we need to check if the current block is reached
-  // through the "true" or the "false" branch and take the relevant operands.
-  auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
-  assert(condBranchOp &&
-         "only branch operations can be terminators of a block that "
-         "has successors");
-  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
-         "successors with arguments in LLVM conditional branches must be "
-         "
diff erent blocks");
-
-  return condBranchOp.getSuccessor(0) == current
-             ? condBranchOp.trueDestOperands()[index]
-             : condBranchOp.falseDestOperands()[index];
+  SuccessorRange successors = terminator.getSuccessors();
+  assert(std::adjacent_find(successors.begin(), successors.end()) ==
+             successors.end() &&
+         "successors with arguments in LLVM branches must be 
diff erent blocks");
+
+  // For instructions that branch based on a condition value, we need to take
+  // the operands for the branch that was taken.
+  if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
+    // For conditional branches, we take the operands from either the "true" or
+    // the "false" branch.
+    return condBranchOp.getSuccessor(0) == current
+               ? condBranchOp.trueDestOperands()[index]
+               : condBranchOp.falseDestOperands()[index];
+  } else if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
+    // For switches, we take the operands from either the default case, or from
+    // the case branch that was taken.
+    if (switchOp.defaultDestination() == current)
+      return switchOp.defaultOperands()[index];
+    for (auto i : llvm::enumerate(switchOp.caseDestinations()))
+      if (i.value() == current)
+        return switchOp.getCaseOperands(i.index())[index];
+  }
+
+  llvm_unreachable("only branch or switch operations can be terminators of a "
+                   "block that has successors");
 }
 
 /// Connect the PHI nodes to the results of preceding blocks.
@@ -717,6 +729,34 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
     branchMapping.try_emplace(&opInst, branch);
     return success();
   }
+  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
+    // Translate the optional branch weights (one per successor, as checked by
+    // the verifier) into `!prof` branch_weights metadata.
+    llvm::MDNode *branchWeights = nullptr;
+    if (auto weights = switchOp.branch_weights()) {
+      llvm::SmallVector<uint32_t> weightValues;
+      weightValues.reserve(weights->size());
+      for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
+        weightValues.push_back(weight.getLimitedValue());
+      branchWeights = llvm::MDBuilder(llvmModule->getContext())
+                          .createBranchWeights(weightValues);
+    }
+
+    llvm::SwitchInst *switchInst =
+        builder.CreateSwitch(valueMapping[switchOp.value()],
+                             blockMapping[switchOp.defaultDestination()],
+                             switchOp.caseDestinations().size(), branchWeights);
+
+    // `case_values` is absent for a switch without cases (`[ ]`), which the
+    // verifier accepts; dereferencing it unconditionally would crash there.
+    if (auto caseValues = switchOp.case_values()) {
+      auto *ty = llvm::cast<llvm::IntegerType>(
+          convertType(switchOp.value().getType().cast<LLVMType>()));
+      for (auto i : llvm::zip(caseValues->cast<DenseIntElementsAttr>(),
+                              switchOp.caseDestinations()))
+        switchInst->addCase(
+            llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
+            blockMapping[std::get<1>(i)]);
+    }
+
+    branchMapping.try_emplace(&opInst, switchInst);
+    return success();
+  }
 
   // Emit addressof.  We need to look up the global value referenced by the
   // operation and store it in the MLIR-to-LLVM value mapping.  This does not

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 1686da2fdba7..9461ebbd9ede 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -652,3 +652,18 @@ func @invalid_ordering_in_fence() {
 module attributes {llvm.data_layout = "#vjkr32"} {
   func @invalid_data_layout()
 }
+
+// -----
+
+func @switch_wrong_number_of_weights(%arg0 : !llvm.i32) {
+  // One case plus the default destination gives two successors, so three
+  // branch weights must be rejected by the verifier.
+  // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}}
+  llvm.switch %arg0, ^bb1 [
+    42: ^bb2(%arg0, %arg0 : !llvm.i32, !llvm.i32)
+  ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+
+^bb1: // pred: ^bb0
+  llvm.return
+
+^bb2(%1: !llvm.i32, %2: !llvm.i32): // pred: ^bb0
+  llvm.return
+}

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 6f794b7d9fe4..fc9ff686d78f 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -71,8 +71,8 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
 
 // CHECK: ^[[BB1]]
 ^bb1:
-// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB1]]
-  llvm.cond_br %7, ^bb2, ^bb1
+// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB3:.*]]
+  llvm.cond_br %7, ^bb2, ^bb3
 
 // CHECK: ^[[BB2]]
 ^bb2:
@@ -80,7 +80,41 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
 // CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : !llvm.i47
   %22 = llvm.mlir.undef : !llvm.struct<(i32, double, i32)>
   %23 = llvm.mlir.constant(42) : !llvm.i47
+  // CHECK:      llvm.switch %0, ^[[BB3]] [
+  // CHECK-NEXT:   1: ^[[BB4:.*]],
+  // CHECK-NEXT:   2: ^[[BB5:.*]],
+  // CHECK-NEXT:   3: ^[[BB6:.*]]
+  // CHECK-NEXT: ]
+  llvm.switch %0, ^bb3 [
+    1: ^bb4,
+    2: ^bb5,
+    3: ^bb6
+  ]
+
+// CHECK: ^[[BB3]]
+^bb3:
+// CHECK:      llvm.switch %0, ^[[BB7:.*]] [
+// CHECK-NEXT: ]
+  llvm.switch %0, ^bb7 [
+  ]
+
+// CHECK: ^[[BB4]]
+^bb4:
+  llvm.switch %0, ^bb7 [
+  ]
+
+// CHECK: ^[[BB5]]
+^bb5:
+  llvm.switch %0, ^bb7 [
+  ]
+
+// CHECK: ^[[BB6]]
+^bb6:
+  llvm.switch %0, ^bb7 [
+  ]
 
+// CHECK: ^[[BB7]]
+^bb7:
 // Misc operations.
 // CHECK: %{{.*}} = llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i32
   %24 = llvm.select %7, %0, %1 : !llvm.i1, !llvm.i32

diff  --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 9f8a37198558..099b8c96cb16 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1358,3 +1358,60 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) {
   llvm.return
 }
 
+// -----
+
+// Translation of llvm.switch with per-case block arguments: each case block
+// gets PHI nodes fed by that case's own operands (5 for the -1 case, 7 and 11
+// for the 1 case), and the negative case value is emitted as `i32 -1`.
+// NOTE(review): the function is declared without a result type, yet every
+// block returns an !llvm.i32 value; confirm whether `-> !llvm.i32` was meant.
+// CHECK-LABEL: @switch_args
+llvm.func @switch_args(%arg0: !llvm.i32) {
+  %0 = llvm.mlir.constant(5 : i32) : !llvm.i32
+  %1 = llvm.mlir.constant(7 : i32) : !llvm.i32
+  %2 = llvm.mlir.constant(11 : i32) : !llvm.i32
+  // CHECK:      switch i32 %[[SWITCH_arg0:[0-9]+]], label %[[SWITCHDEFAULT_bb1:[0-9]+]] [
+  // CHECK-NEXT:   i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]]
+  // CHECK-NEXT:   i32 1, label %[[SWITCHCASE_bb3:[0-9]+]]
+  // CHECK-NEXT: ]
+  llvm.switch %arg0, ^bb1 [
+    -1: ^bb2(%0 : !llvm.i32),
+    1: ^bb3(%1, %2 : !llvm.i32, !llvm.i32)
+  ]
+
+// CHECK:      [[SWITCHDEFAULT_bb1]]:
+// CHECK-NEXT:   ret i32 %[[SWITCH_arg0]]
+^bb1:  // pred: ^bb0
+  llvm.return %arg0 : !llvm.i32
+
+// CHECK:      [[SWITCHCASE_bb2]]:
+// CHECK-NEXT:   phi i32 [ 5, %1 ]
+// CHECK-NEXT:   ret i32
+^bb2(%3: !llvm.i32): // pred: ^bb0
+  llvm.return %1 : !llvm.i32
+
+// CHECK:      [[SWITCHCASE_bb3]]:
+// CHECK-NEXT:   phi i32 [ 7, %1 ]
+// CHECK-NEXT:   phi i32 [ 11, %1 ]
+// CHECK-NEXT:   ret i32
+^bb3(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0
+  llvm.return %4 : !llvm.i32
+}
+
+// Translation of llvm.switch `branch_weights`: the three weights (one per
+// successor, i.e. the default destination plus two cases) must appear, in
+// order, as `!prof` branch_weights metadata attached to the emitted switch.
+// NOTE(review): declared without a result type although its blocks return
+// !llvm.i32 values; confirm whether `-> !llvm.i32` was meant.
+// CHECK-LABEL: @switch_weights
+llvm.func @switch_weights(%arg0: !llvm.i32) {
+  %0 = llvm.mlir.constant(19 : i32) : !llvm.i32
+  %1 = llvm.mlir.constant(23 : i32) : !llvm.i32
+  %2 = llvm.mlir.constant(29 : i32) : !llvm.i32
+  // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]]
+  llvm.switch %arg0, ^bb1(%0 : !llvm.i32) [
+    9: ^bb2(%1, %2 : !llvm.i32, !llvm.i32),
+    99: ^bb3
+  ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+
+^bb1(%3: !llvm.i32):  // pred: ^bb0
+  llvm.return %3 : !llvm.i32
+
+^bb2(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0
+  llvm.return %5 : !llvm.i32
+
+^bb3: // pred: ^bb0
+  llvm.return %arg0 : !llvm.i32
+}
+
+// CHECK: ![[SWITCH_WEIGHT_NODE]] = !{!"branch_weights", i32 13, i32 17, i32 19}


        


More information about the Mlir-commits mailing list