[Mlir-commits] [mlir] 14f2415 - [mlir][LLVMIR] Add 'llvm.switch' op
Brian Gesiak
llvmlistbot at llvm.org
Thu Dec 17 11:16:59 PST 2020
Author: Brian Gesiak
Date: 2020-12-17T14:11:21-05:00
New Revision: 14f24155a5915a295bd965bb6062bfeab217b9c8
URL: https://github.com/llvm/llvm-project/commit/14f24155a5915a295bd965bb6062bfeab217b9c8
DIFF: https://github.com/llvm/llvm-project/commit/14f24155a5915a295bd965bb6062bfeab217b9c8.diff
LOG: [mlir][LLVMIR] Add 'llvm.switch' op
The LLVM IR 'switch' instruction allows control flow to be transferred
to one of any number of branches depending on an integer control value,
or to a default destination if the control value does not match any case value. This patch
adds `llvm.switch` to the MLIR LLVMIR dialect, as well as translation routines
for lowering it to LLVM IR.
To store a variable number of operands for a variable number of branch
destinations, the new op makes use of the `AttrSizedOperandSegments`
trait. It stores its default branch operands as one segment, and all
remaining case branches' operands as another. It also stores the begin offset
of each case branch's operands within that segment; each case's sub-range ends
where the next case's begin offset starts (or at the end of the segment for the
last case). There's probably a better way to implement this, since the
offset computation complicates several parts of the op definition. This is the
approach I settled on because in doing so I was able to delegate to the default
op builder member functions. However, it may be preferable to instead specify
`skipDefaultBuilders` in the op's ODS, or use a completely separate
approach; feedback is welcome!
Another contentious part of this patch may be the custom printer and
parser functions for the op. Ideally I would have liked the MLIR to be
printed in this way:
```
llvm.switch %0, ^bb1(%1 : !llvm.i32) [
1: ^bb2,
2: ^bb3(%2, %3 : !llvm.i32, !llvm.i32)
]
```
The above would resemble how LLVM IR is formatted for the 'switch'
instruction. But I found it difficult to print and parse something like
this, whether I used the declarative assembly format or custom functions.
I also was not sure a multi-line format would be welcome -- it seems
like most MLIR ops do not use newlines. Again, I'd be happy to hear any
feedback here as well, or on any other aspect of the patch.
Differential Revision: https://reviews.llvm.org/D93005
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Dialect/LLVMIR/roundtrip.mlir
mlir/test/Target/llvmir.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 0088fe38246b..9608e15bb81a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -635,6 +635,50 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
let printer = [{ p << getOperationName(); }];
}
+// `llvm.switch`: terminator that branches to `defaultDestination` or to one of
+// `caseDestinations`, selected by the i32 `value`.
+//
+// Operand/attribute layout:
+//  - `defaultOperands` are forwarded to `defaultDestination` (successor 0).
+//  - `caseOperands` holds the operands of all case destinations flattened into
+//    a single segment; `case_operand_offsets` records the offset at which each
+//    case's operands begin in that segment (the end of a case's range is the
+//    next case's begin offset, or the end of the segment for the last case).
+//  - `case_values` holds one value per case destination (enforced by the
+//    verifier); the custom parser and the builder omit it when there are no
+//    cases.
+//  - `branch_weights`, when present, holds one weight per successor in
+//    successor order (default destination first) and is translated to LLVM
+//    branch weight metadata.
+def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
+ [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+ NoSideEffect]> {
+ let arguments = (ins LLVM_i32:$value,
+ Variadic<AnyType>:$defaultOperands,
+ Variadic<AnyType>:$caseOperands,
+ OptionalAttr<ElementsAttr>:$case_values,
+ OptionalAttr<ElementsAttr>:$case_operand_offsets,
+ OptionalAttr<ElementsAttr>:$branch_weights);
+ let successors = (successor
+ AnySuccessor:$defaultDestination,
+ VariadicSuccessor<AnySuccessor>:$caseDestinations);
+
+ let verifier = [{ return ::verify(*this); }];
+ // Cases are printed/parsed by the custom<SwitchOpCases> directive, one per
+ // line, between square brackets.
+ let assemblyFormat = [{
+ $value `,`
+ $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
+ `[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
+ $caseOperands, type($caseOperands),
+ $case_operand_offsets) `]`
+ attr-dict
+ }];
+
+ // The custom builder takes per-case operand ranges and computes the
+ // flattened `caseOperands` segment and `case_operand_offsets` itself.
+ let builders = [
+ OpBuilderDAG<(ins "Value":$value,
+ "Block *":$defaultDestination,
+ "ValueRange":$defaultOperands,
+ CArg<"ArrayRef<int32_t>", "{}">:$caseValues,
+ CArg<"BlockRange", "{}">:$caseDestinations,
+ CArg<"ArrayRef<ValueRange>", "{}">:$caseOperands,
+ CArg<"ArrayRef<int32_t>", "{}">:$branchWeights)>,
+ LLVM_TerminatorPassthroughOpBuilder
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return the operands for the case destination block at the given index.
+ OperandRange getCaseOperands(unsigned index);
+
+ /// Return a mutable range of operands for the case destination block at the
+ /// given index.
+ MutableOperandRange getCaseOperandsMutable(unsigned index);
+ }];
+}
+
////////////////////////////////////////////////////////////////////////////////
// Auxiliary operations (do not appear in LLVM IR but necessary for the dialect
// to work correctly).
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index d70a327824b7..9b2c88c30a86 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -189,6 +189,150 @@ CondBrOp::getMutableSuccessorOperands(unsigned index) {
return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
}
+//===----------------------------------------------------------------------===//
+// LLVM::SwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Builds an `llvm.switch` from per-case operand ranges. The case operands are
+/// flattened into a single operand segment and the offset at which each case's
+/// operands begin is stored in `case_operand_offsets`. Optional attributes are
+/// left unset when the corresponding input array is empty.
+void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
+                     Block *defaultDestination, ValueRange defaultOperands,
+                     ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
+                     ArrayRef<ValueRange> caseOperands,
+                     ArrayRef<int32_t> branchWeights) {
+  // Case operands are either omitted entirely or given once per case
+  // destination; any other shape cannot be split back into per-case ranges.
+  assert((caseOperands.empty() ||
+          caseOperands.size() == caseDestinations.size()) &&
+         "expected one operand range per case destination");
+
+  // Flatten the per-case operand ranges and record each case's begin offset.
+  SmallVector<Value> flattenedCaseOperands;
+  SmallVector<int32_t> caseOperandOffsets;
+  caseOperandOffsets.reserve(caseOperands.size());
+  int32_t offset = 0;
+  for (ValueRange operands : caseOperands) {
+    flattenedCaseOperands.append(operands.begin(), operands.end());
+    caseOperandOffsets.push_back(offset);
+    offset += operands.size();
+  }
+
+  // A null attribute means "absent" for the optional attributes.
+  ElementsAttr caseValuesAttr;
+  if (!caseValues.empty())
+    caseValuesAttr = builder.getI32VectorAttr(caseValues);
+  ElementsAttr caseOperandOffsetsAttr;
+  if (!caseOperandOffsets.empty())
+    caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
+  ElementsAttr weightsAttr;
+  if (!branchWeights.empty())
+    weightsAttr = builder.getI32VectorAttr(branchWeights);
+
+  build(builder, result, value, defaultOperands, flattenedCaseOperands,
+        caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination,
+        caseDestinations);
+}
+
+/// Parses the comma-separated switch cases that appear between `[` and `]` in
+/// the custom assembly format:
+///
+///   <cases> ::= (<case> (`,` <case>)*)?
+///   <case>  ::= integer `:` bb-id (`(` ssa-use-list `:` type-list `)`)?
+///
+/// On success with at least one case, `caseValues` and `caseOperandOffsets`
+/// are set to i32 vector attributes; with no cases they are left unset. The
+/// destinations, flattened case operands and their types are appended to the
+/// corresponding output vectors.
+static ParseResult
+parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
+                   SmallVectorImpl<Block *> &caseDestinations,
+                   SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
+                   SmallVectorImpl<Type> &caseOperandTypes,
+                   ElementsAttr &caseOperandOffsets) {
+  SmallVector<int32_t> values;
+  SmallVector<int32_t> offsets;
+  int32_t value, offset = 0;
+  do {
+    OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
+    if (!integerParseResult.hasValue()) {
+      // An empty case list is valid: leave both attributes unset.
+      if (values.empty())
+        return success();
+      // A comma must be followed by another case; report it instead of
+      // failing without any diagnostic.
+      return parser.emitError(parser.getCurrentLocation(),
+                              "expected integer case value");
+    }
+    // An integer token was present but could not be parsed into an i32; the
+    // parser has already reported the error.
+    if (failed(integerParseResult.getValue()))
+      return failure();
+    values.push_back(value);
+
+    Block *destination;
+    SmallVector<OpAsmParser::OperandType> operands;
+    if (parser.parseColon() || parser.parseSuccessor(destination))
+      return failure();
+    if (!parser.parseOptionalLParen()) {
+      if (parser.parseRegionArgumentList(operands) ||
+          parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen())
+        return failure();
+    }
+    caseDestinations.push_back(destination);
+    caseOperands.append(operands.begin(), operands.end());
+    // Each case records where its operands begin in the flattened segment.
+    offsets.push_back(offset);
+    offset += operands.size();
+  } while (!parser.parseOptionalComma());
+
+  Builder &builder = parser.getBuilder();
+  caseValues = builder.getI32VectorAttr(values);
+  caseOperandOffsets = builder.getI32VectorAttr(offsets);
+
+  return success();
+}
+
+/// Prints the switch cases in the form accepted by parseSwitchOpCases: one
+/// `value: ^dest(operands : types)` entry per line, separated by commas.
+/// Prints nothing when the op has no `case_values` attribute.
+static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
+                               ElementsAttr caseValues,
+                               SuccessorRange caseDestinations,
+                               OperandRange caseOperands,
+                               TypeRange caseOperandTypes,
+                               ElementsAttr caseOperandOffsets) {
+  if (!caseValues)
+    return;
+
+  size_t index = 0;
+  llvm::interleave(
+      llvm::zip(caseValues.cast<DenseIntElementsAttr>(), caseDestinations),
+      [&](auto i) {
+        p << " ";
+        // Case values are signed i32s. Print them sign-extended so that
+        // negative values (e.g. -1) round-trip through the parser instead of
+        // being printed as large unsigned numbers.
+        p << std::get<0>(i).getSExtValue();
+        p << ": ";
+        p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
+      },
+      [&] {
+        p << ',';
+        p.printNewline();
+      });
+  p.printNewline();
+}
+
+/// Verifies that the per-case attributes are consistent with the successor
+/// list: one case value and one operand offset per case destination, and one
+/// branch weight per successor when weights are given.
+static LogicalResult verify(SwitchOp op) {
+  int64_t numCaseDestinations = op.caseDestinations().size();
+  if ((!op.case_values() && numCaseDestinations != 0) ||
+      (op.case_values() && op.case_values()->size() != numCaseDestinations))
+    return op.emitOpError("expects number of case values to match number of "
+                          "case destinations");
+
+  // getCaseOperandsMutable slices the flattened case operands using one begin
+  // offset per case destination; reject ops that would make that slicing
+  // assert or read out of range.
+  if (auto offsets = op.case_operand_offsets()) {
+    if (offsets->size() != numCaseDestinations)
+      return op.emitOpError("expects number of case operand offsets to match "
+                            "number of case destinations");
+  } else if (!op.caseOperands().empty()) {
+    return op.emitOpError("expects case operand offsets when case operands "
+                          "are present");
+  }
+
+  if (op.branch_weights() &&
+      op.branch_weights()->size() !=
+          static_cast<int64_t>(op.getNumSuccessors()))
+    return op.emitOpError("expects number of branch weights to match number "
+                          "of successors: ")
+           << op.branch_weights()->size() << " vs " << op.getNumSuccessors();
+  return success();
+}
+
+/// Read-only view of the operands forwarded to the case destination at
+/// `index`; shares the slicing logic of getCaseOperandsMutable.
+OperandRange SwitchOp::getCaseOperands(unsigned index) {
+  MutableOperandRange caseRange = getCaseOperandsMutable(index);
+  return caseRange;
+}
+
+/// Returns the sub-range of the flattened `caseOperands` segment that belongs
+/// to the case destination at `index`, using `case_operand_offsets` as the
+/// begin offsets of each case.
+MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
+  MutableOperandRange allCaseOperands = caseOperandsMutable();
+  auto maybeOffsets = case_operand_offsets();
+
+  // Without offsets the segment cannot be split per case, which is only
+  // consistent when there are no case operands at all.
+  if (!maybeOffsets) {
+    assert(allCaseOperands.size() == 0 &&
+           "non-empty case operands must have offsets");
+    return allCaseOperands;
+  }
+
+  ElementsAttr offsets = maybeOffsets.getValue();
+  assert(index < offsets.size() && "invalid case operand offset index");
+
+  // A case's range ends where the next case's range begins, or at the end of
+  // the segment for the last case.
+  auto offsetAt = [&offsets](unsigned position) -> int64_t {
+    return offsets.getValue(position).cast<IntegerAttr>().getInt();
+  };
+  int64_t begin = offsetAt(index);
+  int64_t end = (index + 1 != offsets.size()) ? offsetAt(index + 1)
+                                                 : allCaseOperands.size();
+  return allCaseOperands.slice(begin, end - begin);
+}
+
+/// BranchOpInterface hook: successor 0 is the default destination and
+/// successors 1..N are the case destinations in order.
+Optional<MutableOperandRange>
+SwitchOp::getMutableSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index != 0)
+    return getCaseOperandsMutable(index - 1);
+  return defaultOperandsMutable();
+}
+
//===----------------------------------------------------------------------===//
// Builder, printer and parser for for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 057f57409940..d9094f8763d9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -323,19 +323,31 @@ static Value getPHISourceValue(Block *current, Block *pred,
if (isa<LLVM::BrOp>(terminator))
return terminator.getOperand(index);
- // For conditional branches, we need to check if the current block is reached
- // through the "true" or the "false" branch and take the relevant operands.
- auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
- assert(condBranchOp &&
- "only branch operations can be terminators of a block that "
- "has successors");
- assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
- "successors with arguments in LLVM conditional branches must be "
- "
diff erent blocks");
-
- return condBranchOp.getSuccessor(0) == current
- ? condBranchOp.trueDestOperands()[index]
- : condBranchOp.falseDestOperands()[index];
+  // A block may appear several times among a terminator's successors (e.g.
+  // switch cases sharing an argument-less destination). The PHI source for
+  // `current` is only well defined if `current` is reached through exactly one
+  // successor slot, so check that block specifically rather than looking for
+  // any (adjacent) duplicate. The check is inlined so that release builds do
+  // not end up with an unused local.
+  assert(llvm::count(terminator.getSuccessors(), current) == 1 &&
+         "successors with arguments in LLVM branches must be different "
+         "blocks");
+
+  // For instructions that branch based on a condition value, we need to take
+  // the operands for the branch that was taken.
+  if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
+    // For conditional branches, we take the operands from either the "true" or
+    // the "false" branch.
+    return condBranchOp.getSuccessor(0) == current
+               ? condBranchOp.trueDestOperands()[index]
+               : condBranchOp.falseDestOperands()[index];
+  }
+
+  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
+    // For switches, we take the operands from either the default case, or from
+    // the case branch that was taken.
+    if (switchOp.defaultDestination() == current)
+      return switchOp.defaultOperands()[index];
+    for (auto i : llvm::enumerate(switchOp.caseDestinations()))
+      if (i.value() == current)
+        return switchOp.getCaseOperands(i.index())[index];
+  }
+
+  llvm_unreachable("only branch or switch operations can be terminators of a "
+                   "block that has successors");
}
/// Connect the PHI nodes to the results of preceding blocks.
@@ -717,6 +729,34 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
branchMapping.try_emplace(&opInst, branch);
return success();
}
+  if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
+    // Translate the optional per-successor weights into branch weight
+    // metadata attached to the switch.
+    llvm::MDNode *branchWeights = nullptr;
+    if (auto weights = switchOp.branch_weights()) {
+      llvm::SmallVector<uint32_t> weightValues;
+      weightValues.reserve(weights->size());
+      for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
+        weightValues.push_back(weight.getLimitedValue());
+      branchWeights = llvm::MDBuilder(llvmModule->getContext())
+                          .createBranchWeights(weightValues);
+    }
+
+    llvm::SwitchInst *switchInst =
+        builder.CreateSwitch(valueMapping[switchOp.value()],
+                             blockMapping[switchOp.defaultDestination()],
+                             switchOp.caseDestinations().size(), branchWeights);
+
+    // `case_values` is optional and is omitted for a switch without cases
+    // (e.g. `llvm.switch %0, ^bb1 [ ]`), so it must not be dereferenced
+    // unconditionally.
+    if (auto caseValues = switchOp.case_values()) {
+      auto *ty = llvm::cast<llvm::IntegerType>(
+          convertType(switchOp.value().getType().cast<LLVMType>()));
+      // Case values are signed; build the constants sign-extended so that
+      // negative values are encoded correctly for the switch value's width.
+      for (auto i : llvm::zip(caseValues->cast<DenseIntElementsAttr>(),
+                              switchOp.caseDestinations()))
+        switchInst->addCase(
+            llvm::ConstantInt::get(ty, std::get<0>(i).getSExtValue(),
+                                   /*isSigned=*/true),
+            blockMapping[std::get<1>(i)]);
+    }
+
+    branchMapping.try_emplace(&opInst, switchInst);
+    return success();
+  }
// Emit addressof. We need to look up the global value referenced by the
// operation and store it in the MLIR-to-LLVM value mapping. This does not
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 1686da2fdba7..9461ebbd9ede 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -652,3 +652,18 @@ func @invalid_ordering_in_fence() {
module attributes {llvm.data_layout = "#vjkr32"} {
func @invalid_data_layout()
}
+
+// -----
+
+// The switch below has two successors (the default ^bb1 and the single case
+// ^bb2) but three branch weights, which the verifier must reject.
+func @switch_wrong_number_of_weights(%arg0 : !llvm.i32) {
+ // expected-error at +1 {{expects number of branch weights to match number of successors: 3 vs 2}}
+ llvm.switch %arg0, ^bb1 [
+ 42: ^bb2(%arg0, %arg0 : !llvm.i32, !llvm.i32)
+ ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+
+^bb1: // pred: ^bb0
+ llvm.return
+
+^bb2(%1: !llvm.i32, %2: !llvm.i32): // pred: ^bb0
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 6f794b7d9fe4..fc9ff686d78f 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -71,8 +71,8 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
// CHECK: ^[[BB1]]
^bb1:
-// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB1]]
- llvm.cond_br %7, ^bb2, ^bb1
+// CHECK: llvm.cond_br %7, ^[[BB2:.*]], ^[[BB3:.*]]
+ llvm.cond_br %7, ^bb2, ^bb3
// CHECK: ^[[BB2]]
^bb2:
@@ -80,7 +80,41 @@ func @ops(%arg0: !llvm.i32, %arg1: !llvm.float,
// CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : !llvm.i47
%22 = llvm.mlir.undef : !llvm.struct<(i32, double, i32)>
%23 = llvm.mlir.constant(42) : !llvm.i47
+ // CHECK: llvm.switch %0, ^[[BB3]] [
+ // CHECK-NEXT: 1: ^[[BB4:.*]],
+ // CHECK-NEXT: 2: ^[[BB5:.*]],
+ // CHECK-NEXT: 3: ^[[BB6:.*]]
+ // CHECK-NEXT: ]
+ llvm.switch %0, ^bb3 [
+ 1: ^bb4,
+ 2: ^bb5,
+ 3: ^bb6
+ ]
+
+// CHECK: ^[[BB3]]
+^bb3:
+// CHECK: llvm.switch %0, ^[[BB7:.*]] [
+// CHECK-NEXT: ]
+ llvm.switch %0, ^bb7 [
+ ]
+
+// CHECK: ^[[BB4]]
+^bb4:
+ llvm.switch %0, ^bb7 [
+ ]
+
+// CHECK: ^[[BB5]]
+^bb5:
+ llvm.switch %0, ^bb7 [
+ ]
+
+// CHECK: ^[[BB6]]
+^bb6:
+ llvm.switch %0, ^bb7 [
+ ]
+// CHECK: ^[[BB7]]
+^bb7:
// Misc operations.
// CHECK: %{{.*}} = llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i32
%24 = llvm.select %7, %0, %1 : !llvm.i1, !llvm.i32
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 9f8a37198558..099b8c96cb16 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1358,3 +1358,60 @@ llvm.func @useInlineAsm(%arg0: !llvm.i32) {
llvm.return
}
+// -----
+
+// Checks that case values (including a negative one) and per-case block
+// arguments are translated into switch cases and PHI nodes. The function
+// declares an i32 result because every block returns an i32 value.
+// CHECK-LABEL: @switch_args
+llvm.func @switch_args(%arg0: !llvm.i32) -> !llvm.i32 {
+ %0 = llvm.mlir.constant(5 : i32) : !llvm.i32
+ %1 = llvm.mlir.constant(7 : i32) : !llvm.i32
+ %2 = llvm.mlir.constant(11 : i32) : !llvm.i32
+ // CHECK: switch i32 %[[SWITCH_arg0:[0-9]+]], label %[[SWITCHDEFAULT_bb1:[0-9]+]] [
+ // CHECK-NEXT: i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]]
+ // CHECK-NEXT: i32 1, label %[[SWITCHCASE_bb3:[0-9]+]]
+ // CHECK-NEXT: ]
+ llvm.switch %arg0, ^bb1 [
+ -1: ^bb2(%0 : !llvm.i32),
+ 1: ^bb3(%1, %2 : !llvm.i32, !llvm.i32)
+ ]
+
+// CHECK: [[SWITCHDEFAULT_bb1]]:
+// CHECK-NEXT: ret i32 %[[SWITCH_arg0]]
+^bb1: // pred: ^bb0
+ llvm.return %arg0 : !llvm.i32
+
+// CHECK: [[SWITCHCASE_bb2]]:
+// CHECK-NEXT: phi i32 [ 5, %1 ]
+// CHECK-NEXT: ret i32
+^bb2(%3: !llvm.i32): // pred: ^bb0
+ llvm.return %1 : !llvm.i32
+
+// CHECK: [[SWITCHCASE_bb3]]:
+// CHECK-NEXT: phi i32 [ 7, %1 ]
+// CHECK-NEXT: phi i32 [ 11, %1 ]
+// CHECK-NEXT: ret i32
+^bb3(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0
+ llvm.return %4 : !llvm.i32
+}
+
+// Checks that `branch_weights` (one per successor, default first) becomes
+// branch weight metadata on the switch. The function declares an i32 result
+// because every block returns an i32 value.
+// CHECK-LABEL: @switch_weights
+llvm.func @switch_weights(%arg0: !llvm.i32) -> !llvm.i32 {
+ %0 = llvm.mlir.constant(19 : i32) : !llvm.i32
+ %1 = llvm.mlir.constant(23 : i32) : !llvm.i32
+ %2 = llvm.mlir.constant(29 : i32) : !llvm.i32
+ // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]]
+ llvm.switch %arg0, ^bb1(%0 : !llvm.i32) [
+ 9: ^bb2(%1, %2 : !llvm.i32, !llvm.i32),
+ 99: ^bb3
+ ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+
+^bb1(%3: !llvm.i32): // pred: ^bb0
+ llvm.return %3 : !llvm.i32
+
+^bb2(%4: !llvm.i32, %5: !llvm.i32): // pred: ^bb0
+ llvm.return %5 : !llvm.i32
+
+^bb3: // pred: ^bb0
+ llvm.return %arg0 : !llvm.i32
+}
+
+// CHECK: ![[SWITCH_WEIGHT_NODE]] = !{!"branch_weights", i32 13, i32 17, i32 19}
More information about the Mlir-commits
mailing list