[Mlir-commits] [mlir] 07d8fe9 - [mlir][scf] Add an IndexSwitchOp
Jeff Niu
llvmlistbot at llvm.org
Fri Oct 21 09:21:17 PDT 2022
Author: Jeff Niu
Date: 2022-10-21T09:21:10-07:00
New Revision: 07d8fe9391a1bda7bb5fdfd17a5b897df7a003f5
URL: https://github.com/llvm/llvm-project/commit/07d8fe9391a1bda7bb5fdfd17a5b897df7a003f5
DIFF: https://github.com/llvm/llvm-project/commit/07d8fe9391a1bda7bb5fdfd17a5b897df7a003f5.diff
LOG: [mlir][scf] Add an IndexSwitchOp
The `scf.index_switch` is a control-flow operation that branches to one of the
given regions based on the values of the argument and the cases. The
argument is always of type `index`.
Example:
```mlir
%0 = scf.index_switch %arg0 -> i32
case 2 {
%1 = arith.constant 10 : i32
scf.yield %1 : i32
}
case 5 {
%2 = arith.constant 20 : i32
scf.yield %2 : i32
}
default {
%3 = arith.constant 30 : i32
scf.yield %3 : i32
}
```
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D136003
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/invalid.mlir
mlir/test/Dialect/SCF/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 41b70992ac850..38dd0acc6a69e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -235,7 +235,7 @@ def ForOp : SCF_Op<"for",
}
/// Return the `index`-th region iteration argument.
BlockArgument getRegionIterArg(unsigned index) {
- assert(index < getNumRegionIterArgs() &&
+ assert(index < getNumRegionIterArgs() &&
"expected an index less than the number of region iter args");
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
}
@@ -434,7 +434,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
```
Example with thread_dim_mapping attribute:
-
+
```mlir
//
// Sequential context.
@@ -456,7 +456,7 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
```
Example with privatized tensors:
-
+
```mlir
%t0 = ...
%t1 = ...
@@ -527,8 +527,8 @@ def ForeachThreadOp : SCF_Op<"foreach_thread", [
return getBody()->getArguments().drop_front(getRank());
}
- /// Return the thread indices in the order specified by the
- /// thread_dim_mapping attribute. Return failure is
+ /// Return the thread indices in the order specified by the
+ /// thread_dim_mapping attribute. Return failure if
/// thread_dim_mapping is not a valid permutation.
FailureOr<SmallVector<Value>> getPermutedThreadIndices();
@@ -988,13 +988,77 @@ def WhileOp : SCF_Op<"while",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// IndexSwitchOp
+//===----------------------------------------------------------------------===//
+
+def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
+    SingleBlockImplicitTerminator<"scf::YieldOp">,
+    DeclareOpInterfaceMethods<RegionBranchOpInterface,
+                              ["getRegionInvocationBounds"]>]> {
+  let summary = "switch-case operation on an index argument";
+  let description = [{
+    The `scf.index_switch` is a control-flow operation that branches to one of
+    the given regions based on the values of the argument and the cases. The
+    argument is always of type `index`.
+
+    The operation always has a "default" region and any number of case regions
+    denoted by integer constants. Control-flow transfers to the case region
+    whose constant value equals the value of the argument. If the argument does
+    not equal any of the case values, control-flow transfers to the "default"
+    region.
+
+    Case values must be unique, and the `scf.yield` terminating every region
+    must return values whose number and types match the results of the
+    operation. The default region is region #0 of the operation; case region #i
+    is region #(i + 1).
+
+    Example:
+
+    ```mlir
+    %0 = scf.index_switch %arg0 -> i32
+    case 2 {
+      %1 = arith.constant 10 : i32
+      scf.yield %1 : i32
+    }
+    case 5 {
+      %2 = arith.constant 20 : i32
+      scf.yield %2 : i32
+    }
+    default {
+      %3 = arith.constant 30 : i32
+      scf.yield %3 : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  // Region order matters to the C++ implementation: the default region comes
+  // first, followed by one region per entry of `cases`, in the same order.
+  let regions = (region SizedRegion<1>:$defaultRegion,
+                        VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let assemblyFormat = [{
+    $arg attr-dict (`->` type($results)^)?
+    custom<SwitchCases>($cases, $caseRegions) `\n`
+    `` `default` $defaultRegion
+  }];
+
+  let extraClassDeclaration = [{
+    /// Get the number of cases.
+    unsigned getNumCases();
+
+    /// Get the default region body.
+    Block &getDefaultBlock();
+
+    /// Get the body of a case region.
+    Block &getCaseBlock(unsigned idx);
+  }];
+
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
- ParentOneOf<["ExecuteRegionOp, ForOp",
- "IfOp, ParallelOp, WhileOp"]>]> {
+ ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
+ "ParallelOp", "WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"scf.yield" yields an SSA value from the SCF dialect op region and
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index b3e84f0d6271a..cd1d3829cab9a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3386,6 +3386,137 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
WhileCmpCond, WhileUnusedResult>(context);
}
+//===----------------------------------------------------------------------===//
+// IndexSwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Parse the case regions and values.
+///
+/// Accepts zero or more `case <integer> <region>` entries. Each region is
+/// appended to `caseRegions` and each value to `cases`, in the same order, so
+/// case value #i belongs to case region #i.
+static ParseResult
+parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
+                 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
+  SmallVector<int64_t> caseValues;
+  while (succeeded(p.parseOptionalKeyword("case"))) {
+    int64_t value;
+    // The region is created before parsing so the parser can fill it in place.
+    Region &region =
+        *caseRegions.emplace_back(std::make_unique<Region>()).get();
+    if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{}))
+      return failure();
+    caseValues.push_back(value);
+  }
+  cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
+  return success();
+}
+
+/// Print the case regions and values, emitting one `case <value> <region>`
+/// entry per line. Entry block arguments of the case regions are not printed.
+static void printSwitchCases(OpAsmPrinter &p, Operation *op,
+                             DenseI64ArrayAttr cases, RegionRange caseRegions) {
+  for (auto valueAndRegion : llvm::zip(cases.asArrayRef(), caseRegions)) {
+    int64_t caseValue = std::get<0>(valueAndRegion);
+    Region *caseRegion = std::get<1>(valueAndRegion);
+    p.printNewline();
+    p << "case " << caseValue << ' ';
+    p.printRegion(*caseRegion, /*printEntryBlockArgs=*/false);
+  }
+}
+
+/// Verify the op:
+///  - the number of case values matches the number of case regions,
+///  - case values are unique,
+///  - the `scf.yield` terminating each region returns exactly the op's result
+///    types (both count and per-result type).
+LogicalResult scf::IndexSwitchOp::verify() {
+  if (getCases().size() != getCaseRegions().size()) {
+    return emitOpError("has ")
+           << getCaseRegions().size() << " case regions but "
+           << getCases().size() << " case values";
+  }
+
+  DenseSet<int64_t> valueSet;
+  for (int64_t value : getCases())
+    if (!valueSet.insert(value).second)
+      return emitOpError("has duplicate case value: ") << value;
+
+  // Check the terminator of `region` against the op's results. `name` only
+  // identifies the region in diagnostics.
+  auto verifyRegion = [&](Region &region, const Twine &name) -> LogicalResult {
+    auto yield = cast<YieldOp>(region.front().getTerminator());
+    if (yield.getNumOperands() != getNumResults()) {
+      return (emitOpError("expected each region to return ")
+              << getNumResults() << " values, but " << name << " returns "
+              << yield.getNumOperands())
+                 .attachNote(yield.getLoc())
+             << "see yield operation here";
+    }
+    for (auto [idx, result, operand] :
+         llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
+                   yield.getOperandTypes())) {
+      if (result == operand)
+        continue;
+      return (emitOpError("expected result #")
+              << idx << " of each region to be " << result)
+                 .attachNote(yield.getLoc())
+             << name << " returns " << operand << " here";
+    }
+    return success();
+  };
+
+  if (failed(verifyRegion(getDefaultRegion(), "default region")))
+    return failure();
+  // `auto &&` binds whether `llvm::enumerate` yields its elements by
+  // reference or by value.
+  for (auto &&[idx, caseRegion] : llvm::enumerate(getCaseRegions()))
+    if (failed(verifyRegion(caseRegion, "case region #" + Twine(idx))))
+      return failure();
+
+  return success();
+}
+
+/// Return the number of entries in the `cases` attribute.
+unsigned scf::IndexSwitchOp::getNumCases() {
+  return static_cast<unsigned>(getCases().size());
+}
+
+/// Return the (single) block of the default region.
+Block &scf::IndexSwitchOp::getDefaultBlock() {
+  Region &defaultRegion = getDefaultRegion();
+  return defaultRegion.front();
+}
+
+/// Return the (single) block of case region `idx`.
+Block &scf::IndexSwitchOp::getCaseBlock(unsigned idx) {
+  assert(idx < getNumCases() && "case index out-of-bounds");
+  Region &caseRegion = getCaseRegions()[idx];
+  return caseRegion.front();
+}
+
+/// Compute the regions control may flow to next.
+///
+/// From inside any region (`index` is set), control returns to the parent op
+/// and produces its results. From the parent op, a known constant argument
+/// selects exactly one region: the matching case region, or the default region
+/// when no case matches. An unknown argument makes every case region and the
+/// default region possible successors.
+void IndexSwitchOp::getSuccessorRegions(
+    Optional<unsigned> index, ArrayRef<Attribute> operands,
+    SmallVectorImpl<RegionSuccessor> &successors) {
+  // Every region yields back to the parent op.
+  if (index) {
+    successors.emplace_back(getResults());
+    return;
+  }
+
+  auto constArg = operands.front().dyn_cast_or_null<IntegerAttr>();
+
+  // Unknown argument: any case region, or the default region, may execute.
+  if (!constArg) {
+    for (Region &caseRegion : getCaseRegions())
+      successors.emplace_back(&caseRegion);
+    successors.emplace_back(&getDefaultRegion());
+    return;
+  }
+
+  // Known argument: the first matching case wins, else the default region.
+  Region *target = &getDefaultRegion();
+  for (auto caseAndRegion : llvm::zip(getCases(), getCaseRegions())) {
+    if (std::get<0>(caseAndRegion) == constArg.getInt()) {
+      target = &std::get<1>(caseAndRegion);
+      break;
+    }
+  }
+  successors.emplace_back(target);
+}
+
+/// Compute one invocation bound per region, in the op's region order.
+///
+/// If the argument is not a known constant, every region is invoked at most
+/// once. Otherwise only the matching case region (or the default region when
+/// no case matches) may be invoked, and every other region is never invoked.
+void IndexSwitchOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
+  auto operandValue = operands.front().dyn_cast_or_null<IntegerAttr>();
+  if (!operandValue) {
+    // All regions are invoked at most once.
+    bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
+    return;
+  }
+
+  // Regions are numbered in declaration order: the default region is region
+  // #0 and case #i is region #(i + 1). Identify the live region by pointer
+  // rather than by index so this cannot drift out of sync with that layout.
+  Region *liveRegion = &getDefaultRegion();
+  auto cases = getCases();
+  auto it = llvm::find(cases, operandValue.getInt());
+  if (it != cases.end())
+    liveRegion = &getCaseRegions()[std::distance(cases.begin(), it)];
+  for (Region &region : getOperation()->getRegions())
+    bounds.emplace_back(/*lb=*/0, /*ub=*/&region == liveRegion);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index b79ecb48d7d7f..fa91ba03a0aac 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -428,7 +428,7 @@ func.func @parallel_invalid_yield(
func.func @yield_invalid_parent_op() {
"my.op"() ({
- // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.parallel, scf.while'}}
+ // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.execute_region, scf.for, scf.if, scf.index_switch, scf.parallel, scf.while'}}
scf.yield
}) : () -> ()
return
@@ -572,3 +572,57 @@ func.func @wrong_terminator_op(%in: tensor<100xf32>, %out: tensor<100xf32>) {
}
return
}
+
+// -----
+
+// The generic form below has a single region (the default region) but one
+// case value, so the case value / case region counts disagree.
+func.func @switch_wrong_case_count(%arg0: index) {
+ // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}}
+ "scf.index_switch"(%arg0) ({
+ scf.yield
+ }) {cases = array<i64: 1>} : (index) -> ()
+ return
+}
+
+// -----
+
+// Case values must be unique: two `case 0` regions are rejected.
+func.func @switch_duplicate_case(%arg0: index) {
+ // expected-error @below {{'scf.index_switch' op has duplicate case value: 0}}
+ scf.index_switch %arg0
+ case 0 {
+ scf.yield
+ }
+ case 0 {
+ scf.yield
+ }
+ default {
+ scf.yield
+ }
+ return
+}
+
+// -----
+
+// The default region yields one value while the op declares no results, so
+// the verifier must report a result-count mismatch.
+func.func @switch_wrong_result_count(%arg0: index) {
+  // expected-error @below {{'scf.index_switch' op expected each region to return 0 values, but default region returns 1}}
+  scf.index_switch %arg0
+  default {
+    // expected-note @below {{see yield operation here}}
+    scf.yield %arg0 : index
+  }
+  return
+}
+
+// -----
+
+// Every region must yield the op's result types: case region #0 yields i32
+// while the op's single result is declared as index.
+func.func @switch_wrong_types(%arg0: index, %arg1: i32) {
+ // expected-error @below {{'scf.index_switch' op expected result #0 of each region to be 'index'}}
+ scf.index_switch %arg0 -> index
+ case 0 {
+ // expected-note @below {{case region #0 returns 'i32' here}}
+ scf.yield %arg1 : i32
+ }
+ default {
+ scf.yield %arg0 : index
+ }
+ return
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index c1fa4e65f6f5e..e5638388e566a 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -346,3 +346,36 @@ func.func @elide_terminator() -> () {
} {thread_dim_mapping = [42]}
return
}
+
+// Round-trip the custom assembly format of scf.index_switch: first with a
+// result and two case regions, then with only a default region and no results.
+// CHECK-LABEL: @switch
+func.func @switch(%arg0: index) -> i32 {
+ // CHECK: %{{.*}} = scf.index_switch %arg0 -> i32
+ %0 = scf.index_switch %arg0 -> i32
+ // CHECK-NEXT: case 2 {
+ case 2 {
+ // CHECK-NEXT: arith.constant
+ %c10_i32 = arith.constant 10 : i32
+ // CHECK-NEXT: scf.yield %{{.*}} : i32
+ scf.yield %c10_i32 : i32
+ // CHECK-NEXT: }
+ }
+ // CHECK-NEXT: case 5 {
+ case 5 {
+ %c20_i32 = arith.constant 20 : i32
+ scf.yield %c20_i32 : i32
+ }
+ // CHECK: default {
+ default {
+ %c30_i32 = arith.constant 30 : i32
+ scf.yield %c30_i32 : i32
+ }
+
+ // CHECK: scf.index_switch %arg0
+ scf.index_switch %arg0
+ // CHECK-NEXT: default {
+ default {
+ scf.yield
+ }
+
+ return %0 : i32
+}
More information about the Mlir-commits
mailing list