[Mlir-commits] [mlir] [MLIR][SCF] Actually use conversion interface in scf-to-cf conversion (PR #154075)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 18 01:35:22 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: William Moses (wsmoses)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/154075.diff


1 File Affected:

- (modified) mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp (+76-52) 


``````````diff
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 37cfc9f2c23e6..d9ec932244770 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -100,11 +100,13 @@ struct SCFToControlFlowPass
 //      |   <%init visible by dominance> |
 //      +--------------------------------+
 //
-struct ForLowering : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
+struct ForLowering : public OpConversionPattern<ForOp> {
+  using OpConversionPattern<ForOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(ForOp forOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(ForOp forOp,
+                  typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 // Create a CFG subgraph for the scf.if operation (including its "then" and
@@ -193,25 +195,31 @@ struct ForLowering : public OpRewritePattern<ForOp> {
 //      | <code after the IfOp>          |
 //      +--------------------------------+
 //
-struct IfLowering : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
+struct IfLowering : public OpConversionPattern<IfOp> {
+  using OpConversionPattern<IfOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(IfOp ifOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(IfOp ifOp,
+                  typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
-struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
-  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+struct ExecuteRegionLowering : public OpConversionPattern<ExecuteRegionOp> {
+  using OpConversionPattern<ExecuteRegionOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(ExecuteRegionOp op,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(
+      ExecuteRegionOp op,
+      typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
 };
 
-struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
-  using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
+struct ParallelLowering : public OpConversionPattern<mlir::scf::ParallelOp> {
+  using OpConversionPattern<mlir::scf::ParallelOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(
+      mlir::scf::ParallelOp parallelOp,
+      typename OpConversionPattern<mlir::scf::ParallelOp>::OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Create a CFG subgraph for this loop construct. The regions of the loop need
@@ -273,41 +281,49 @@ struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
 /// the results of the WhileOp are defined in the 'before' region, which is
 /// required to have a single existing block, and are therefore accessible in
 /// the continuation block due to dominance.
-struct WhileLowering : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+  using OpConversionPattern<WhileOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(WhileOp whileOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(WhileOp whileOp,
+                  typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Optimized version of the above for the case of the "after" region merely
 /// forwarding its arguments back to the "before" region (i.e., a "do-while"
 /// loop). This avoid inlining the "after" region completely and branches back
 /// to the "before" entry instead.
-struct DoWhileLowering : public OpRewritePattern<WhileOp> {
-  using OpRewritePattern<WhileOp>::OpRewritePattern;
+struct DoWhileLowering : public OpConversionPattern<WhileOp> {
+  using OpConversionPattern<WhileOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(WhileOp whileOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(WhileOp whileOp,
+                  typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Lower an `scf.index_switch` operation to a `cf.switch` operation.
-struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct IndexSwitchLowering : public OpConversionPattern<IndexSwitchOp> {
+  using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(IndexSwitchOp op,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(
+      IndexSwitchOp op,
+      typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
 /// has no shared outputs. Ops with shared outputs should be bufferized first.
 /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
 /// dialects/passes.
-struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
-  using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
+struct ForallLowering : public OpConversionPattern<mlir::scf::ForallOp> {
+  using OpConversionPattern<mlir::scf::ForallOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
-                                PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(
+      mlir::scf::ForallOp forallOp,
+      typename OpConversionPattern<mlir::scf::ForallOp>::OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
 };
 
 } // namespace
@@ -325,8 +341,9 @@ static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
   brOp->setDiscardableAttrs(llvmAttrs);
 }
 
-LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
-                                           PatternRewriter &rewriter) const {
+LogicalResult ForLowering::matchAndRewrite(
+    ForOp forOp, typename OpConversionPattern<ForOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = forOp.getLoc();
 
   // Start by splitting the block containing the 'scf.for' into two parts.
@@ -397,8 +414,9 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
   return success();
 }
 
-LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
-                                          PatternRewriter &rewriter) const {
+LogicalResult IfLowering::matchAndRewrite(
+    IfOp ifOp, typename OpConversionPattern<IfOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto loc = ifOp.getLoc();
 
   // Start by splitting the block containing the 'scf.if' into two parts.
@@ -453,9 +471,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
   return success();
 }
 
-LogicalResult
-ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
-                                       PatternRewriter &rewriter) const {
+LogicalResult ExecuteRegionLowering::matchAndRewrite(
+    ExecuteRegionOp op,
+    typename OpConversionPattern<ExecuteRegionOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto loc = op.getLoc();
 
   auto *condBlock = rewriter.getInsertionBlock();
@@ -487,9 +506,10 @@ ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
   return success();
 }
 
-LogicalResult
-ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
-                                  PatternRewriter &rewriter) const {
+LogicalResult ParallelLowering::matchAndRewrite(
+    ParallelOp parallelOp,
+    typename OpConversionPattern<ParallelOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Location loc = parallelOp.getLoc();
   auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
   if (!reductionOp) {
@@ -563,8 +583,9 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
   return success();
 }
 
-LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
-                                             PatternRewriter &rewriter) const {
+LogicalResult WhileLowering::matchAndRewrite(
+    WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   OpBuilder::InsertionGuard guard(rewriter);
   Location loc = whileOp.getLoc();
 
@@ -606,9 +627,9 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
   return success();
 }
 
-LogicalResult
-DoWhileLowering::matchAndRewrite(WhileOp whileOp,
-                                 PatternRewriter &rewriter) const {
+LogicalResult DoWhileLowering::matchAndRewrite(
+    WhileOp whileOp, typename OpConversionPattern<WhileOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   Block &afterBlock = *whileOp.getAfterBody();
   if (!llvm::hasSingleElement(afterBlock))
     return rewriter.notifyMatchFailure(whileOp,
@@ -652,9 +673,10 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
   return success();
 }
 
-LogicalResult
-IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
-                                     PatternRewriter &rewriter) const {
+LogicalResult IndexSwitchLowering::matchAndRewrite(
+    IndexSwitchOp op,
+    typename OpConversionPattern<IndexSwitchOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   // Split the block at the op.
   Block *condBlock = rewriter.getInsertionBlock();
   Block *continueBlock = rewriter.splitBlock(condBlock, Block::iterator(op));
@@ -714,8 +736,10 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
   return success();
 }
 
-LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
-                                              PatternRewriter &rewriter) const {
+LogicalResult ForallLowering::matchAndRewrite(
+    ForallOp forallOp,
+    typename OpConversionPattern<ForallOp>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   return scf::forallToParallelLoop(rewriter, forallOp);
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/154075


More information about the Mlir-commits mailing list