[Mlir-commits] [mlir] 91effec - [mlir][scf] Add scf-to-cf lowering for `scf.index_switch`
Jeff Niu
llvmlistbot at llvm.org
Mon Oct 31 12:01:29 PDT 2022
Author: Jeff Niu
Date: 2022-10-31T12:01:22-07:00
New Revision: 91effec852879ed3588fa6b54015ac128642b8fd
URL: https://github.com/llvm/llvm-project/commit/91effec852879ed3588fa6b54015ac128642b8fd
DIFF: https://github.com/llvm/llvm-project/commit/91effec852879ed3588fa6b54015ac128642b8fd.diff
LOG: [mlir][scf] Add scf-to-cf lowering for `scf.index_switch`
This patch adds lowering from `scf.index_switch` to `cf.switch`.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D136883
Added:
Modified:
mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 72f483c38f1e2..c15832a8efad6 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -290,6 +290,14 @@ struct DoWhileLowering : public OpRewritePattern<WhileOp> {
LogicalResult matchAndRewrite(WhileOp whileOp,
PatternRewriter &rewriter) const override;
};
+
+/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
+struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexSwitchOp op,
+ PatternRewriter &rewriter) const override;
+};
} // namespace
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -615,10 +623,68 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
return success();
}
+LogicalResult
+IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
+ PatternRewriter &rewriter) const {
+ // Split the block at the op.
+ Block *condBlock = rewriter.getInsertionBlock();
+ Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
+
+ // Create the arguments on the continue block with which to replace the
+ // results of the op.
+ SmallVector<Value> results;
+ results.reserve(op.getNumResults());
+ for (Type resultType : op.getResultTypes())
+ results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
+
+ // Handle the regions.
+ auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
+ Block *block = &region.front();
+
+ // Convert the yield terminator to a branch to the continue block.
+ auto yield = cast<scf::YieldOp>(block->getTerminator());
+ rewriter.setInsertionPoint(yield);
+ rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
+ yield.getOperands());
+
+ // Inline the region.
+ rewriter.inlineRegionBefore(region, continueBlock);
+ return block;
+ };
+
+ // Convert the case regions.
+ SmallVector<Block *> caseSuccessors;
+ SmallVector<int32_t> caseValues;
+ caseSuccessors.reserve(op.getCases().size());
+ caseValues.reserve(op.getCases().size());
+ for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
+ FailureOr<Block *> block = convertRegion(region);
+ if (failed(block))
+ return failure();
+ caseSuccessors.push_back(*block);
+ caseValues.push_back(value);
+ }
+
+ // Convert the default region.
+ FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
+ if (failed(defaultBlock))
+ return failure();
+
+ // Create the switch.
+ rewriter.setInsertionPointToEnd(condBlock);
+ SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
+ rewriter.create<cf::SwitchOp>(
+ op.getLoc(), op.getArg(), *defaultBlock, ValueRange(),
+ rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+ rewriter.replaceOp(op, continueBlock->getArguments());
+ return success();
+}
+
void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
- ExecuteRegionLowering>(patterns.getContext());
+ ExecuteRegionLowering, IndexSwitchLowering>(
+ patterns.getContext());
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index df5d60cbbb929..94bacd258470f 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -473,7 +473,7 @@ func.func @while_values(%arg0: i32, %arg1: f32) {
scf.condition(%0) %2, %3 : i64, f64
} do {
// CHECK: ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
- ^bb0(%arg2: i64, %arg3: f64):
+ ^bb0(%arg2: i64, %arg3: f64):
// CHECK: cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
scf.yield %c0_i32, %cst : i32, f32
}
@@ -620,3 +620,30 @@ func.func @func_execute_region_elim_multi_yield() {
// CHECK: ^[[bb3]](%[[z:.+]]: i64):
// CHECK: "test.bar"(%[[z]])
// CHECK: return
+
+// SWITCH-LABEL: @index_switch
+func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
+ // SWITCH: cf.switch %arg0 : index
+ // SWITCH-NEXT: default: ^bb3
+ // SWITCH-NEXT: 0: ^bb1
+ // SWITCH-NEXT: 1: ^bb2
+ %0 = scf.index_switch %i -> i32
+ // SWITCH: ^bb1:
+ case 0 {
+ // SWITCH-NEXT: llvm.br ^bb4(%arg1
+ scf.yield %a : i32
+ }
+ // SWITCH: ^bb2:
+ case 1 {
+ // SWITCH-NEXT: llvm.br ^bb4(%arg2
+ scf.yield %b : i32
+ }
+ // SWITCH: ^bb3:
+ default {
+ // SWITCH-NEXT: llvm.br ^bb4(%arg3
+ scf.yield %c : i32
+ }
+ // SWITCH: ^bb4(%[[V:.*]]: i32
+ // SWITCH-NEXT: return %[[V]]
+ return %0 : i32
+}
More information about the Mlir-commits
mailing list