[Mlir-commits] [mlir] 91effec - [mlir][scf] Add scf-to-cf lowering for `scf.index_switch`

Jeff Niu llvmlistbot at llvm.org
Mon Oct 31 12:01:29 PDT 2022


Author: Jeff Niu
Date: 2022-10-31T12:01:22-07:00
New Revision: 91effec852879ed3588fa6b54015ac128642b8fd

URL: https://github.com/llvm/llvm-project/commit/91effec852879ed3588fa6b54015ac128642b8fd
DIFF: https://github.com/llvm/llvm-project/commit/91effec852879ed3588fa6b54015ac128642b8fd.diff

LOG: [mlir][scf] Add scf-to-cf lowering for `scf.index_switch`

This patch adds lowering from `scf.index_switch` to `cf.switch`.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D136883

Added: 
    

Modified: 
    mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
    mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 72f483c38f1e2..c15832a8efad6 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -290,6 +290,14 @@ struct DoWhileLowering : public OpRewritePattern<WhileOp> {
   LogicalResult matchAndRewrite(WhileOp whileOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
+struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
@@ -615,10 +623,68 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
   return success();
 }
 
+LogicalResult
+IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
+                                     PatternRewriter &rewriter) const {
+  // Split the block at the op.
+  Block *condBlock = rewriter.getInsertionBlock();
+  Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
+
+  // Create the arguments on the continue block with which to replace the
+  // results of the op.
+  SmallVector<Value> results;
+  results.reserve(op.getNumResults());
+  for (Type resultType : op.getResultTypes())
+    results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
+
+  // Handle the regions.
+  auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
+    Block *block = &region.front();
+
+    // Convert the yield terminator to a branch to the continue block.
+    auto yield = cast<scf::YieldOp>(block->getTerminator());
+    rewriter.setInsertionPoint(yield);
+    rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
+                                              yield.getOperands());
+
+    // Inline the region.
+    rewriter.inlineRegionBefore(region, continueBlock);
+    return block;
+  };
+
+  // Convert the case regions.
+  SmallVector<Block *> caseSuccessors;
+  SmallVector<int32_t> caseValues;
+  caseSuccessors.reserve(op.getCases().size());
+  caseValues.reserve(op.getCases().size());
+  for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
+    FailureOr<Block *> block = convertRegion(region);
+    if (failed(block))
+      return failure();
+    caseSuccessors.push_back(*block);
+    caseValues.push_back(value);
+  }
+
+  // Convert the default region.
+  FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
+  if (failed(defaultBlock))
+    return failure();
+
+  // Create the switch.
+  rewriter.setInsertionPointToEnd(condBlock);
+  SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
+  rewriter.create<cf::SwitchOp>(
+      op.getLoc(), op.getArg(), *defaultBlock, ValueRange(),
+      rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
+  rewriter.replaceOp(op, continueBlock->getArguments());
+  return success();
+}
+
 void mlir::populateSCFToControlFlowConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
-               ExecuteRegionLowering>(patterns.getContext());
+               ExecuteRegionLowering, IndexSwitchLowering>(
+      patterns.getContext());
   patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
 }
 

diff  --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
index df5d60cbbb929..94bacd258470f 100644
--- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir
@@ -473,7 +473,7 @@ func.func @while_values(%arg0: i32, %arg1: f32) {
     scf.condition(%0) %2, %3 : i64, f64
   } do {
   // CHECK:   ^[[AFTER]](%[[ARG4:.*]]: i64, %[[ARG5:.*]]: f64):
-  ^bb0(%arg2: i64, %arg3: f64):  
+  ^bb0(%arg2: i64, %arg3: f64):
     // CHECK:   cf.br ^[[BEFORE]](%{{.*}}, %{{.*}} : i32, f32)
     scf.yield %c0_i32, %cst : i32, f32
   }
@@ -620,3 +620,30 @@ func.func @func_execute_region_elim_multi_yield() {
 // CHECK:   ^[[bb3]](%[[z:.+]]: i64):
 // CHECK:     "test.bar"(%[[z]])
 // CHECK:     return
+
+// SWITCH-LABEL: @index_switch
+func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 {
+  // SWITCH: cf.switch %arg0 : index
+  // SWITCH-NEXT: default: ^bb3
+  // SWITCH-NEXT: 0: ^bb1
+  // SWITCH-NEXT: 1: ^bb2
+  %0 = scf.index_switch %i -> i32
+  // SWITCH: ^bb1:
+  case 0 {
+    // SWITCH-NEXT: llvm.br ^bb4(%arg1
+    scf.yield %a : i32
+  }
+  // SWITCH: ^bb2:
+  case 1 {
+    // SWITCH-NEXT: llvm.br ^bb4(%arg2
+    scf.yield %b : i32
+  }
+  // SWITCH: ^bb3:
+  default {
+    // SWITCH-NEXT: llvm.br ^bb4(%arg3
+    scf.yield %c : i32
+  }
+  // SWITCH: ^bb4(%[[V:.*]]: i32
+  // SWITCH-NEXT: return %[[V]]
+  return %0 : i32
+}


        


More information about the Mlir-commits mailing list