[Mlir-commits] [mlir] [mlir][Interfaces] Add generic pattern for region inlining (PR #176641)

Matthias Springer llvmlistbot at llvm.org
Wed Jan 21 00:50:32 PST 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/176641

>From 39d0f71ea91c87a6567024b5fa3bc7821fb15249 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 21 Jan 2026 08:27:20 +0000
Subject: [PATCH 1/2] [mlir][SCF] Improve `ForOp::getSuccessorRegions`

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 21 ++++++++++
 .../DataFlow/test-dead-code-analysis.mlir     | 38 +++++++++++++++++++
 .../Dialect/Arith/int-range-narrowing.mlir    |  2 +-
 3 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5b6e9304de505..86e66dbaf6171 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -704,6 +704,27 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
 
 void ForOp::getSuccessorRegions(RegionBranchPoint point,
                                 SmallVectorImpl<RegionSuccessor> &regions) {
+  std::optional<APInt> tripCount = getStaticTripCount();
+  if (tripCount.has_value()) {
+    // The loop has a known static trip count.
+    if (point.isParent()) {
+      if (*tripCount == 0) {
+        // The loop has zero iterations. It branches directly back to the
+        // parent.
+        regions.push_back(RegionSuccessor::parent());
+      } else {
+        // The loop has at least one iteration. It branches into the body.
+        regions.push_back(RegionSuccessor(&getRegion()));
+      }
+      return;
+    } else if (*tripCount == 1) {
+      // The loop has exactly 1 iteration. Therefore, it branches from the
+      // region to the parent. (No further iteration.)
+      regions.push_back(RegionSuccessor::parent());
+      return;
+    }
+  }
+
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
   // body.
diff --git a/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir b/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
index 7ce5c0f9e3d5a..4d3a61601a85c 100644
--- a/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
@@ -283,3 +283,41 @@ func.func @test_forall_op_control_flow(%num_threads: index) {
   } {tag = "test_forall_op_control_flow"}
   return
 }
+
+func.func @test_for_op_control_flow() {
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+
+  // Test case 1: Zero loop iterations.
+  // CHECK: test_for_op_control_flow_zero:
+  // CHECK:  region #0
+  // CHECK:   ^bb0 = dead
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %c1 {...} {tag = "test_for_op_control_flow_zero"}
+  scf.for %iv = %c5 to %c5 step %c1 {} {tag = "test_for_op_control_flow_zero"}
+
+  // Test case 2: One loop iteration.
+  // CHECK: test_for_op_control_flow_one:
+  // CHECK:  region #0
+  // CHECK:   ^bb0 = live
+  // CHECK: region_preds: (all) predecessors:
+  // CHECK:   scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {...} {tag = "test_for_op_control_flow_one"}
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   scf.yield
+  scf.for %iv = %c5 to %c6 step %c1 {} {tag = "test_for_op_control_flow_one"}
+
+  // Test case 3: More than one loop iteration.
+  // CHECK: test_for_op_control_flow_multi:
+  // CHECK:  region #0
+  // CHECK:   ^bb0 = live
+  // CHECK: region_preds: (all) predecessors:
+  // CHECK:   scf.for %arg0 = %{{.*}} to %{{.*}} step %{{.*}} {...} {tag = "test_for_op_control_flow_multi"}
+  // CHECK:   scf.yield
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   scf.yield
+  scf.for %iv = %c5 to %c7 step %c1 {} {tag = "test_for_op_control_flow_multi"}
+
+  return
+}
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 9107bf649b561..e2cd9b50f6736 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -361,7 +361,7 @@ func.func private @use_i64(i64)
 // CHECK-LABEL: func.func @loop_with_iter_arg
 func.func @loop_with_iter_arg() {
   %c0 = arith.constant 0 : index
-  %c1 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
   %c16 = arith.constant 16 : index
 
   %cst = arith.constant dense<0.000000e+00> : vector<4xf32>

>From a3244e34fe024c0fa66d34ddcea91c3ae77bdc30 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 18 Jan 2026 11:07:39 +0000
Subject: [PATCH 2/2] [mlir][Interfaces] Add generic pattern for region
 inlining

---
 .../mlir/Interfaces/ControlFlowInterfaces.h   |  39 +++
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 171 ++-----------
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 233 ++++++++++++++++++
 .../Dialect/Arith/int-range-interface.mlir    |   6 +-
 mlir/test/Dialect/SCF/canonicalize.mlir       |  20 ++
 mlir/test/Dialect/SCF/one-shot-bufferize.mlir |   6 +-
 6 files changed, 326 insertions(+), 149 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index d764089f5ccc8..a76dce6f2ffc5 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -320,6 +320,45 @@ Region *getEnclosingRepetitiveRegion(Value value);
 void populateRegionBranchOpInterfaceCanonicalizationPatterns(
     RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit = 1);
 
+/// Callback for the region branch op inlining pattern that builds a
+/// replacement value for a non-successor-input value (e.g., a loop IV).
+using NonSuccessorInputReplacementBuilderFn =
+    std::function<Value(OpBuilder &, Location, Value)>;
+/// Helper function for the region branch op inlining pattern that checks if the
+/// pattern is applicable to the given operation.
+using PatternMatcherFn = std::function<LogicalResult(Operation *)>;
+
+namespace detail {
+/// Default non-successor-input replacement builder. It assumes that all block
+/// arguments and op results are successor inputs, so it is never expected to
+/// be called; reaching it triggers `llvm_unreachable`.
+static inline Value defaultReplBuilderFn(OpBuilder &builder, Location loc,
+                                         Value value) {
+  llvm_unreachable("defaultReplBuilderFn not implemented");
+}
+
+/// Default pattern matcher: the pattern is applicable to every operation.
+static inline LogicalResult defaultMatcherFn(Operation *op) {
+  return success();
+}
+} // namespace detail
+
+/// Populate a pattern that inlines the body of region branch ops when there is
+/// a single acyclic path through the region branch op, starting from "parent"
+/// and ending at "parent". For details, refer to the documentation of the
+/// pattern.
+///
+/// `replBuilderFn` is a function that builds replacement values for
+/// non-successor-input values of the region branch op. `matcherFn` is a
+/// function that checks if the pattern is applicable to the given operation.
+/// Both functions are optional.
+void populateRegionBranchOpInterfaceInliningPattern(
+    RewritePatternSet &patterns, StringRef opName,
+    NonSuccessorInputReplacementBuilderFn replBuilderFn =
+        detail::defaultReplBuilderFn,
+    PatternMatcherFn matcherFn = detail::defaultMatcherFn,
+    PatternBenefit benefit = 1);
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 86e66dbaf6171..2ebece4bdedb7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -132,19 +132,6 @@ std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
 
-/// Replaces the given op with the contents of the given single-block region,
-/// using the operands of the block terminator to replace operation results.
-static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
-                                Region &region, ValueRange blockArgs = {}) {
-  assert(region.hasOneBlock() && "expected single-block region");
-  Block *block = &region.front();
-  Operation *terminator = block->getTerminator();
-  ValueRange results = terminator->getOperands();
-  rewriter.inlineBlockBefore(block, op, blockArgs);
-  rewriter.replaceOp(op, results);
-  rewriter.eraseOp(terminator);
-}
-
 ///
 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
 ///    block+
@@ -192,32 +179,6 @@ LogicalResult ExecuteRegionOp::verify() {
   return success();
 }
 
-// Inline an ExecuteRegionOp if it only contains one block.
-//     "test.foo"() : () -> ()
-//      %v = scf.execute_region -> i64 {
-//        %x = "test.val"() : () -> i64
-//        scf.yield %x : i64
-//      }
-//      "test.bar"(%v) : (i64) -> ()
-//
-//  becomes
-//
-//     "test.foo"() : () -> ()
-//     %x = "test.val"() : () -> i64
-//     "test.bar"(%x) : (i64) -> ()
-//
-struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
-  using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExecuteRegionOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!op.getRegion().hasOneBlock() || op.getNoInline())
-      return failure();
-    replaceOpWithRegion(rewriter, op, op.getRegion());
-    return success();
-  }
-};
-
 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
 // TODO generalize the conditions for operations which can be inlined into.
 // func @func_execute_region_elim() {
@@ -293,9 +254,15 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
 
 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+  results.add<MultiBlockExecuteInliner>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, ExecuteRegionOp::getOperationName());
+  // Inline ops with a single block that are not marked as "no_inline".
+  populateRegionBranchOpInterfaceInliningPattern(
+      results, ExecuteRegionOp::getOperationName(),
+      mlir::detail::defaultReplBuilderFn, [](Operation *op) {
+        return failure(cast<ExecuteRegionOp>(op).getNoInline());
+      });
 }
 
 void ExecuteRegionOp::getSuccessorRegions(
@@ -962,54 +929,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
 }
 
 namespace {
-/// Rewriting pattern that erases loops that are known not to iterate, replaces
-/// single-iteration loops with their bodies, and removes empty loops that
-/// iterate at least once and only return values defined outside of the loop.
-struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ForOp op,
-                                PatternRewriter &rewriter) const override {
-    std::optional<APInt> tripCount = op.getStaticTripCount();
-    if (!tripCount.has_value())
-      return rewriter.notifyMatchFailure(op,
-                                         "can't compute constant trip count");
-
-    if (tripCount->isZero()) {
-      LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
-             << OpWithFlags(op, OpPrintingFlags().skipRegions());
-      rewriter.replaceOp(op, op.getInitArgs());
-      return success();
-    }
-
-    if (tripCount->getSExtValue() == 1) {
-      LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
-             << OpWithFlags(op, OpPrintingFlags().skipRegions());
-      SmallVector<Value, 4> blockArgs;
-      blockArgs.reserve(op.getInitArgs().size() + 1);
-      blockArgs.push_back(op.getLowerBound());
-      llvm::append_range(blockArgs, op.getInitArgs());
-      replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
-      return success();
-    }
-
-    // Now we are left with loops that have more than 1 iterations.
-    Block &block = op.getRegion().front();
-    if (!llvm::hasSingleElement(block))
-      return failure();
-    // The loop is empty and iterates at least once, if it only returns values
-    // defined outside of the loop, remove it and replace it with yield values.
-    if (llvm::any_of(op.getYieldedValues(),
-                     [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
-      return failure();
-    LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
-              "yield operands for loop "
-           << OpWithFlags(op, OpPrintingFlags().skipRegions());
-    rewriter.replaceOp(op, op.getYieldedValues());
-    return success();
-  }
-};
-
 /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
 /// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
 ///
@@ -1072,9 +991,20 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context);
+  results.add<ForOpTensorCastFolder>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, ForOp::getOperationName());
+  populateRegionBranchOpInterfaceInliningPattern(
+      results, ForOp::getOperationName(),
+      /*replBuilderFn=*/[](OpBuilder &builder, Location loc, Value value) {
+        // scf.for has only one non-successor input value: the loop induction
+        // variable. In case of a single acyclic path through the op, the IV can
+        // be safely replaced with the lower bound.
+        auto blockArg = cast<BlockArgument>(value);
+        assert(blockArg.getArgNumber() == 0 && "expected induction variable");
+        auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
+        return forOp.getLowerBound();
+      });
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -2218,26 +2148,6 @@ void IfOp::getRegionInvocationBounds(
 }
 
 namespace {
-struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter &rewriter) const override {
-    BoolAttr condition;
-    if (!matchPattern(op.getCondition(), m_Constant(&condition)))
-      return failure();
-
-    if (condition.getValue())
-      replaceOpWithRegion(rewriter, op, op.getThenRegion());
-    else if (!op.getElseRegion().empty())
-      replaceOpWithRegion(rewriter, op, op.getElseRegion());
-    else
-      rewriter.eraseOp(op);
-
-    return success();
-  }
-};
-
 /// Hoist any yielded results whose operands are defined outside
 /// the if, to a select instruction.
 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
@@ -2788,10 +2698,11 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
-              RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
-      context);
+              ReplaceIfYieldWithConditionOrValue>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, IfOp::getOperationName());
+  populateRegionBranchOpInterfaceInliningPattern(results,
+                                                 IfOp::getOperationName());
 }
 
 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
@@ -3796,6 +3707,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
               WhileMoveIfDown>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, WhileOp::getOperationName());
+  populateRegionBranchOpInterfaceInliningPattern(results,
+                                                 WhileOp::getOperationName());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3942,44 +3855,12 @@ void IndexSwitchOp::getRegionInvocationBounds(
     bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
 }
 
-struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
-  using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
-                                PatternRewriter &rewriter) const override {
-    // If `op.getArg()` is a constant, select the region that matches with
-    // the constant value. Use the default region if no matche is found.
-    std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
-    if (!maybeCst.has_value())
-      return failure();
-    int64_t cst = *maybeCst;
-    int64_t caseIdx, e = op.getNumCases();
-    for (caseIdx = 0; caseIdx < e; ++caseIdx) {
-      if (cst == op.getCases()[caseIdx])
-        break;
-    }
-
-    Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
-                                             : op.getDefaultRegion();
-    Block &source = r.front();
-    Operation *terminator = source.getTerminator();
-    SmallVector<Value> results = terminator->getOperands();
-
-    rewriter.inlineBlockBefore(&source, op);
-    rewriter.eraseOp(terminator);
-    // Replace the operation with a potentially empty list of results.
-    // Fold mechanism doesn't support the case where the result list is empty.
-    rewriter.replaceOp(op, results);
-
-    return success();
-  }
-};
-
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, IndexSwitchOp::getOperationName());
+  populateRegionBranchOpInterfaceInliningPattern(
+      results, IndexSwitchOp::getOperationName());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index ebf78d8bd60ce..8ed32ddf39a53 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,6 +9,7 @@
 #include <utility>
 
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -1025,6 +1026,230 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
     return success(changed);
   }
 };
+
+/// Given a range of values, return a vector of attributes of the same size,
+/// where the i-th attribute is the constant value of the i-th value. If a
+/// value is not constant, the corresponding attribute is null.
+static SmallVector<Attribute> extractConstants(ValueRange values) {
+  return llvm::map_to_vector(values, [](Value value) {
+    Attribute attr;
+    matchPattern(value, m_Constant(&attr));
+    return attr;
+  });
+}
+
+/// Return all successor regions when branching from the given region branch
+/// point. This helper function extracts all constant operand values of the op
+/// (or terminator) and passes them to the respective interface method.
+static SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+                             RegionBranchPoint point) {
+  SmallVector<RegionSuccessor> successors;
+  if (point.isParent()) {
+    op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
+                                successors);
+    return successors;
+  }
+  RegionBranchTerminatorOpInterface terminator =
+      point.getTerminatorPredecessorOrNull();
+  terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
+                                 successors);
+  return successors;
+}
+
+/// Find the single acyclic path through the given region branch op. Return an
+/// empty vector if no such path or multiple such paths exist.
+///
+/// Example: "scf.if %true" has a single path: parent => then_region => parent
+///
+/// Example: "scf.if ???" has multiple paths:
+///          (1) parent => then_region => parent
+///          (2) parent => else_region => parent
+///
+/// Example: "scf.while with scf.condition(%false)" has a single path:
+///          parent => before_region => parent
+///
+/// Example: "scf.for with 0 iterations" has a single path: parent => parent
+///
+/// Note: Each path starts and ends with "parent". The "parent" at the beginning
+/// of the path is omitted from the result.
+///
+/// Note: This function also returns an empty path when the path would enter a
+/// multi-block region or a region with an unsupported terminator.
+static SmallVector<RegionSuccessor>
+computeSingleAcyclicRegionBranchPath(RegionBranchOpInterface op) {
+  llvm::SmallDenseSet<Region *> visited;
+  SmallVector<RegionSuccessor> path;
+
+  // Path starts with "parent".
+  RegionBranchPoint next = RegionBranchPoint::parent();
+  do {
+    SmallVector<RegionSuccessor> successors =
+        getSuccessorRegionsWithAttrs(op, next);
+    if (successors.size() != 1) {
+      // There are zero or multiple region successors, i.e., there is no unique
+      // path through the region branch op.
+      return {};
+    }
+    path.push_back(successors.front());
+    if (successors.front().isParent()) {
+      // Found path that ends with "parent".
+      return path;
+    }
+    Region *region = successors.front().getSuccessor();
+    if (!region->hasOneBlock()) {
+      // Entering a region with multiple blocks. Such regions are not supported
+      // at the moment.
+      return {};
+    }
+    if (!visited.insert(region).second) {
+      // We have already visited this region. I.e., we have found a cycle.
+      return {};
+    }
+    auto terminator =
+        dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back());
+    if (!terminator) {
+      // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the
+      // terminator could be a "ub.unreachable" op. Such IR is not supported.
+      return {};
+    }
+    next = RegionBranchPoint(terminator);
+  } while (true);
+  llvm_unreachable("expected to return from loop");
+}
+
+/// Inline the body of the matched region branch op into the enclosing block if
+/// there is exactly one acyclic path through the region branch op, starting
+/// from "parent", and if that path ends with "parent".
+///
+/// Example: This pattern can inline "scf.for" operations that are guaranteed to
+/// have a single iteration, as indicated by the region branch path "parent =>
+/// region => parent". "scf.for" operations have a non-successor-input: the loop
+/// induction variable. Non-successor-input values have op-specific semantics
+/// and cannot be reasoned about through the `RegionBranchOpInterface`. A
+/// replacement value for non-successor-inputs is injected by the user-specified
+/// lambda: in the case of the loop induction variable of an "scf.for", the
+/// lower bound of the loop is used as a replacement value.
+///
+/// Before pattern application:
+/// %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
+///   %1 = "producer"(%arg0, %iv)
+///   scf.yield %1
+/// }
+/// "user"(%r)
+///
+/// After pattern application:
+/// %1 = "producer"(%0, %c5)
+/// "user"(%1)
+///
+/// This pattern is limited to the following cases:
+/// - Only regions with a single block are supported. This could be generalized.
+/// - Only ops with the `HasRecursiveMemoryEffects` trait are supported; ops
+///   that may have side effects of their own are rejected.
+///
+/// Note: This pattern queries the region dataflow from the
+/// `RegionBranchOpInterface`. Replacement values for block arguments / op
+/// results are determined based on region dataflow. In case of
+/// non-successor-inputs (whose values are not modeled by the
+/// `RegionBranchOpInterface`), a user-specified lambda is queried.
+struct InlineRegionBranchOp : public RewritePattern {
+  InlineRegionBranchOp(MLIRContext *context, StringRef name,
+                       NonSuccessorInputReplacementBuilderFn replBuilderFn,
+                       PatternMatcherFn matcherFn, PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context), replBuilderFn(replBuilderFn),
+        matcherFn(matcherFn) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Check if the pattern is applicable to the given operation.
+    if (failed(matcherFn(op)))
+      return rewriter.notifyMatchFailure(op, "pattern not applicable");
+
+    // Ops without recursive memory effects could have side effects of their
+    // own, so it is not safe to fold such ops away.
+    if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+      return rewriter.notifyMatchFailure(
+          op, "pattern not applicable to ops without recursive memory effects");
+
+    // Find the single acyclic path through the region branch op.
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    SmallVector<RegionSuccessor> path =
+        computeSingleAcyclicRegionBranchPath(regionBranchOp);
+    if (path.empty())
+      return rewriter.notifyMatchFailure(
+          op, "failed to find acyclic region branch path");
+
+    // Inline all regions on the path into the enclosing block.
+    rewriter.setInsertionPoint(op);
+    ArrayRef remainingPath = path;
+    SmallVector<Value> successorOperands = llvm::to_vector(
+        regionBranchOp.getEntrySuccessorOperands(remainingPath.front()));
+    while (!remainingPath.empty()) {
+      RegionSuccessor nextSuccessor = remainingPath.consume_front();
+      ValueRange successorInputs =
+          regionBranchOp.getSuccessorInputs(nextSuccessor);
+      assert(successorInputs.size() == successorOperands.size() &&
+             "size mismatch");
+      // Find the index of the first block argument / op result that is a
+      // successor input.
+      unsigned firstSuccessorInputIdx = 0;
+      if (!successorInputs.empty())
+        firstSuccessorInputIdx =
+            nextSuccessor.isParent()
+                ? cast<OpResult>(successorInputs.front()).getResultNumber()
+                : cast<BlockArgument>(successorInputs.front()).getArgNumber();
+      // Query the total number of block arguments / op results.
+      unsigned numValues =
+          nextSuccessor.isParent()
+              ? op->getNumResults()
+              : nextSuccessor.getSuccessor()->getNumArguments();
+      // Compute replacement values for all block arguments / op results.
+      SmallVector<Value> replacements;
+      // Helper function to get the i-th block argument / op result.
+      auto getValue = [&](unsigned idx) {
+        return nextSuccessor.isParent()
+                   ? Value(op->getResult(idx))
+                   : Value(nextSuccessor.getSuccessor()->getArgument(idx));
+      };
+      // Compute replacement values for all non-successor-input values that
+      // precede the first successor input.
+      for (unsigned i = 0; i < firstSuccessorInputIdx; ++i)
+        replacements.push_back(
+            replBuilderFn(rewriter, op->getLoc(), getValue(i)));
+      // Use the successor operands of the predecessor as replacement values for
+      // the successor inputs.
+      llvm::append_range(replacements, successorOperands);
+      // Compute replacement values for all block arguments / op results that
+      // follow the successor inputs.
+      for (unsigned i = replacements.size(); i < numValues; ++i)
+        replacements.push_back(
+            replBuilderFn(rewriter, op->getLoc(), getValue(i)));
+      if (nextSuccessor.isParent()) {
+        // The path ends with "parent". Replace the region branch op with the
+        // computed replacement values.
+        assert(remainingPath.empty() && "expected that the path ended");
+        rewriter.replaceOp(op, replacements);
+        return success();
+      }
+      // We are inside of a region: copy the successor operands out of the
+      // terminator (it is erased below, so an OperandRange would dangle),
+      // inline the region into the enclosing block, and erase the terminator.
+      auto terminator = cast<RegionBranchTerminatorOpInterface>(
+          &nextSuccessor.getSuccessor()->front().back());
+      rewriter.inlineBlockBefore(&nextSuccessor.getSuccessor()->front(),
+                                 op->getBlock(), op->getIterator(),
+                                 replacements);
+      successorOperands = llvm::to_vector(
+          terminator.getSuccessorOperands(remainingPath.front()));
+      rewriter.eraseOp(terminator);
+    }
+
+    llvm_unreachable("expected that paths ends with parent");
+  }
+
+  NonSuccessorInputReplacementBuilderFn replBuilderFn;
+  PatternMatcherFn matcherFn;
+};
 } // namespace
 
 void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
@@ -1034,3 +1259,11 @@ void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
                RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(),
                                                         opName, benefit);
 }
+
+void mlir::populateRegionBranchOpInterfaceInliningPattern(
+    RewritePatternSet &patterns, StringRef opName,
+    NonSuccessorInputReplacementBuilderFn replBuilderFn,
+    PatternMatcherFn matcherFn, PatternBenefit benefit) {
+  patterns.add<InlineRegionBranchOp>(patterns.getContext(), opName,
+                                     replBuilderFn, matcherFn, benefit);
+}
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 30b7128dab42c..dd8240299ef7e 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -118,8 +118,10 @@ func.func @for_loop_with_constant_result() -> i1 {
 
 // Test to catch a bug present in some versions of the data flow analysis
 // CHECK-LABEL: func @while_false
-// CHECK: %[[false:.*]] = arith.constant false
-// CHECK: scf.condition(%[[false]])
+// CHECK: %[[divui:.*]] = arith.divui
+// CHECK-NOT: scf.while
+// CHECK-NOT: scf.condition
+// CHECK: return %[[divui]]
 func.func @while_false(%arg0 : index) -> index {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index e770f595bd262..56ef00dcd6f8b 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -661,6 +661,26 @@ func.func @remove_zero_iteration_loop() {
 
 // -----
 
+// CHECK-LABEL: @remove_while_zero_iteration_loop
+//  CHECK-NEXT:   %[[init:.*]] = "test.init"
+//  CHECK-NEXT:   %[[inner1:.*]] = "test.before"(%[[init]])
+//  CHECK-NEXT:   return %[[inner1]]
+func.func @remove_while_zero_iteration_loop() -> i64 {
+  %init = "test.init"() : () -> i32
+  %false = arith.constant false
+  %0 = scf.while (%arg0 = %init) : (i32) -> (i64) {
+    %inner1 = "test.before"(%arg0) : (i32) -> i64
+    scf.condition(%false) %inner1 : i64
+  } do {
+  ^bb0(%arg1: i64):
+    %inner2 = "test.before"(%arg1) : (i64) -> i32
+    scf.yield %inner2 : i32
+  }
+  return %0 : i64
+}
+
+// -----
+
 // CHECK-LABEL: @remove_zero_iteration_loop_vals
 func.func @remove_zero_iteration_loop_vals(%arg0: index) {
   %c2 = arith.constant 2 : index
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index af09dc865e2de..b431a9e75c669 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -870,15 +870,17 @@ func.func @non_block_argument_yield() {
 // This is a regression test. Make sure that bufferization succeeds.
 
 // CHECK-LABEL: func @regression_cast_in_loop(
+//  CHECK-NEXT:   %[[alloc1:.*]] = memref.alloc()
+//  CHECK-NEXT:   %[[alloc2:.*]] = memref.alloc()
+//  CHECK-NEXT:   memref.copy %[[alloc1]], %[[alloc2]]
+//  CHECK-NEXT:   return %[[alloc2]]
 func.func @regression_cast_in_loop() -> tensor<2xindex> {
   %false = arith.constant false
   %c0 = arith.constant 0 : index
   %0 = bufferization.alloc_tensor() : tensor<2xindex>
-  // CHECK: scf.while (%{{.*}} = %{{.*}}) : (memref<2xindex>) -> memref<2xindex>
   %1 = scf.while (%arg0 = %0) : (tensor<2xindex>) -> tensor<2xindex> {
     scf.condition(%false) %arg0 : tensor<2xindex>
   } do {
-  // CHECK: ^bb0(%{{.*}}: memref<2xindex>):
   ^bb0(%arg0: tensor<2xindex>):
     %cast = tensor.cast %0 : tensor<2xindex> to tensor<?xindex>
     %inserted = tensor.insert %c0 into %cast[%c0] : tensor<?xindex>



More information about the Mlir-commits mailing list