[Mlir-commits] [mlir] [mlir][Interfaces] Add generic pattern for region inlining (PR #176641)

Matthias Springer llvmlistbot at llvm.org
Sun Jan 18 03:08:19 PST 2026


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/176641

None

>From 104c30bcfef95b92d96a4e33c2a4b479e2faef79 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 18 Jan 2026 11:07:39 +0000
Subject: [PATCH] [mlir][Interfaces] Add generic pattern for region inlining

---
 .../mlir/Interfaces/ControlFlowInterfaces.h   |  21 +++
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  73 +++++---
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 160 ++++++++++++++++++
 3 files changed, 231 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a3382e15fb76d..5acb33be89e23 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -345,6 +345,27 @@ Region *getEnclosingRepetitiveRegion(Value value);
 void populateRegionBranchOpInterfaceCanonicalizationPatterns(
     RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit = 1);
 
+using NonSuccessorInputReplacementBuilder =
+    std::function<Value(OpBuilder &, Location, Value)>;
+using PatternMatcherFn = std::function<LogicalResult(Operation *)>;
+namespace detail {
+static inline Value defaultReplBuilderFn(OpBuilder &builder, Location loc,
+                                         Value value) {
+  llvm_unreachable("defaultReplBuilderFn not implemented");
+}
+
+static inline LogicalResult defaultMatcherFn(Operation *op) {
+  return success();
+}
+} // namespace detail
+
+void populateRegionBranchOpInterfaceInliningPattern(
+    RewritePatternSet &patterns, StringRef opName,
+    PatternMatcherFn matcherFn = detail::defaultMatcherFn,
+    NonSuccessorInputReplacementBuilder replBuilderFn =
+        detail::defaultReplBuilderFn,
+    PatternBenefit benefit = 1);
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5b6e9304de505..8bdf7014e531d 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -704,6 +704,22 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
 
 void ForOp::getSuccessorRegions(RegionBranchPoint point,
                                 SmallVectorImpl<RegionSuccessor> &regions) {
+  std::optional<APInt> tripCount = getStaticTripCount();
+  if (point.isParent()) {
+    if (tripCount.has_value() && *tripCount == 0) {
+      regions.push_back(RegionSuccessor::parent());
+      return;
+    } else if (tripCount.has_value()) {
+      regions.push_back(RegionSuccessor(&getRegion()));
+      return;
+    }
+  } else {
+    if (tripCount.has_value() && *tripCount == 1) {
+      regions.push_back(RegionSuccessor::parent());
+      return;
+    }
+  }
+
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
   // body.
@@ -1049,11 +1065,26 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
 };
 } // namespace
 
+static Value forOpReplacementBuilder(OpBuilder &builder, Location loc,
+                                     Value value) {
+  auto blockArg = cast<BlockArgument>(value);
+  assert(blockArg.getArgNumber() == 0);
+  auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
+  return forOp.getLowerBound();
+}
+
+static LogicalResult forOpReplacementMatcher(Operation *op) {
+  return success();
+}
+
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context);
+  results.add</*SimplifyTrivialLoops,*/ ForOpTensorCastFolder>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, ForOp::getOperationName());
+  populateRegionBranchOpInterfaceInliningPattern(
+      results, ForOp::getOperationName(), forOpReplacementMatcher,
+      forOpReplacementBuilder);
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -2127,6 +2158,21 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
     return;
   }
 
+  bool hasElseRegion = !getElseRegion().empty();
+  BoolAttr staticCondition;
+  if (matchPattern(getCondition(), m_Constant(&staticCondition))) {
+    if (staticCondition.getValue()) {
+      regions.push_back(RegionSuccessor(&getThenRegion()));
+      return;
+    }
+    if (hasElseRegion) {
+      regions.push_back(RegionSuccessor(&getElseRegion()));
+      return;
+    }
+    regions.push_back(RegionSuccessor::parent());
+    return;
+  }
+
   regions.push_back(RegionSuccessor(&getThenRegion()));
 
   // Don't consider the else region if it is empty.
@@ -2197,26 +2243,6 @@ void IfOp::getRegionInvocationBounds(
 }
 
 namespace {
-struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
-  using OpRewritePattern<IfOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IfOp op,
-                                PatternRewriter &rewriter) const override {
-    BoolAttr condition;
-    if (!matchPattern(op.getCondition(), m_Constant(&condition)))
-      return failure();
-
-    if (condition.getValue())
-      replaceOpWithRegion(rewriter, op, op.getThenRegion());
-    else if (!op.getElseRegion().empty())
-      replaceOpWithRegion(rewriter, op, op.getElseRegion());
-    else
-      rewriter.eraseOp(op);
-
-    return success();
-  }
-};
-
 /// Hoist any yielded results whose operands are defined outside
 /// the if, to a select instruction.
 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
@@ -2767,10 +2793,11 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
   results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
               ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
-              RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
-      context);
+              ReplaceIfYieldWithConditionOrValue>(context);
   populateRegionBranchOpInterfaceCanonicalizationPatterns(
       results, IfOp::getOperationName());
+  populateRegionBranchOpInterfaceInliningPattern(results,
+                                                 IfOp::getOperationName());
 }
 
 Block *IfOp::thenBlock() { return &getThenRegion().back(); }
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index ebd4b63145f92..ef44b54cb9547 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,6 +9,7 @@
 #include <utility>
 
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -1027,6 +1028,158 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
     return success(changed);
   }
 };
+
+// Returns the successor path: initial parent omitted, final parent included.
+static SmallVector<RegionSuccessor>
+computeAcyclicRegionBranchPath(RegionBranchOpInterface op) {
+  llvm::SmallDenseSet<Region *> visited;
+  SmallVector<RegionSuccessor> path;
+  RegionBranchPoint next = RegionBranchPoint::parent();
+  do {
+    SmallVector<RegionSuccessor> successors;
+    op.getSuccessorRegions(next, successors);
+    if (successors.size() != 1) {
+      return {};
+    }
+    path.push_back(successors.front());
+    if (successors.front().isParent()) {
+      next = RegionBranchPoint::parent();
+      continue;
+    }
+    if (!visited.insert(successors.front().getSuccessor()).second) {
+      return {};
+    }
+    SmallVector<RegionBranchPoint> branchPoints;
+    for (Block &b : *successors.front().getSuccessor()) {
+      if (auto terminator =
+              dyn_cast<RegionBranchTerminatorOpInterface>(&b.back())) {
+        branchPoints.push_back(RegionBranchPoint(terminator));
+      }
+    }
+    if (branchPoints.size() != 1) {
+      return {};
+    }
+    next = branchPoints.front();
+  } while (!next.isParent());
+  return path;
+}
+
+/// Inline the body of region branch ops into the enclosing block if the op has
+/// trivial region control flow: parent => region => parent.
+///
+/// The pattern queries the region dataflow from the `RegionBranchOpInterface`.
+/// Replacement values for block arguments / op results are determined based on
+/// region dataflow. In case of non-successor-inputs (whose values are not
+/// modeled by the `RegionBranchOpInterface`), a user-specified lambda is
+/// queried.
+///
+/// Example: This pattern can inline "scf.for" operations that are guaranteed to
+/// have a single iteration. "scf.for" has a non-successor-input: the loop
+/// induction variable. A replacement value for this block argument is
+/// produced by the user-specified lambda: the lower bound of the loop.
+///
+/// Before pattern application:
+/// %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
+///   %1 = "producer"(%arg0)
+///   scf.yield %1
+/// }
+/// "user"(%r)
+///
+/// After pattern application:
+/// %1 = "producer"(%0)
+/// "user"(%1)
+///
+/// This pattern is limited to the following cases:
+/// - Only ops with a single region and a single block are supported. (This
+///   could be generalized.)
+/// - Ops with side effects (on the region branch op or its terminator) are
+///   not supported.
+/// - Ops with complex control flow (more than one region) are not supported.
+///   This could be extended in the future, so that ops like "scf.while" can
+///   also be canonicalized by this pattern.
+struct InlineRegionBranchOp : public RewritePattern {
+  InlineRegionBranchOp(MLIRContext *context, StringRef name,
+                       PatternMatcherFn matcherFn,
+                       NonSuccessorInputReplacementBuilder replBuilderFn,
+                       PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context), matcherFn(matcherFn),
+        replBuilderFn(replBuilderFn) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Ops without recursive memory effects could have side effects, so it is
+    // not safe to fold such ops away.
+    if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+      return rewriter.notifyMatchFailure(
+          op, "pattern not applicable to ops without recursive memory effects");
+
+    // The enclosing op may not support unstructured control flow.
+    // if (op->getNumRegions() != 1 || !op->getRegion(0).hasOneBlock())
+    //  return rewriter.notifyMatchFailure(op, "pattern not applicable to ops
+    //  with multiple regions/blocks");
+
+    // Match only ops that always branch parent => region.
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    SmallVector<RegionSuccessor> path =
+        computeAcyclicRegionBranchPath(regionBranchOp);
+    if (path.empty())
+      return rewriter.notifyMatchFailure(
+          op, "failed to find acyclic region path branch path");
+    ArrayRef remainingPath = path;
+    // TODO: Ensure that all regions on the path have one block.
+
+    rewriter.setInsertionPoint(op);
+    OperandRange successorOperands =
+        regionBranchOp.getEntrySuccessorOperands(remainingPath.front());
+    while (!remainingPath.empty()) {
+      RegionSuccessor nextSuccessor = remainingPath.consume_front();
+      ValueRange successorInputs =
+          regionBranchOp.getSuccessorInputs(nextSuccessor);
+      assert(successorInputs.size() == successorOperands.size() &&
+             "size mismatch");
+      unsigned firstSuccessorInputIdx = 0;
+      if (!successorInputs.empty())
+        firstSuccessorInputIdx =
+            nextSuccessor.isParent()
+                ? cast<OpResult>(successorInputs.front()).getResultNumber()
+                : cast<BlockArgument>(successorInputs.front()).getArgNumber();
+      unsigned numValues =
+          nextSuccessor.isParent()
+              ? op->getNumResults()
+              : nextSuccessor.getSuccessor()->getNumArguments();
+      SmallVector<Value> replacements;
+      auto getValue = [&](unsigned idx) {
+        return nextSuccessor.isParent()
+                   ? Value(op->getResult(idx))
+                   : Value(nextSuccessor.getSuccessor()->getArgument(idx));
+      };
+      for (unsigned i = 0; i < firstSuccessorInputIdx; ++i)
+        replacements.push_back(
+            replBuilderFn(rewriter, op->getLoc(), getValue(i)));
+      llvm::append_range(replacements, successorOperands);
+      for (unsigned i = replacements.size(); i < numValues; ++i)
+        replacements.push_back(
+            replBuilderFn(rewriter, op->getLoc(), getValue(i)));
+      if (nextSuccessor.isParent()) {
+        rewriter.replaceOp(op, replacements);
+        return success();
+      }
+      auto terminator = cast<RegionBranchTerminatorOpInterface>(
+          &nextSuccessor.getSuccessor()->front().back());
+      rewriter.inlineBlockBefore(&nextSuccessor.getSuccessor()->front(),
+                                 op->getBlock(), op->getIterator(),
+                                 replacements);
+      successorOperands =
+          terminator.getSuccessorOperands(remainingPath.front());
+      rewriter.eraseOp(terminator);
+    }
+
+    llvm_unreachable("expected that paths ends with parent");
+  }
+
+  PatternMatcherFn matcherFn;
+  NonSuccessorInputReplacementBuilder replBuilderFn;
+};
 } // namespace
 
 void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
@@ -1036,3 +1189,10 @@ void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
                RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(),
                                                         opName, benefit);
 }
+
+void mlir::populateRegionBranchOpInterfaceInliningPattern(
+    RewritePatternSet &patterns, StringRef opName, PatternMatcherFn matcherFn,
+    NonSuccessorInputReplacementBuilder replBuilderFn, PatternBenefit benefit) {
+  patterns.add<InlineRegionBranchOp>(patterns.getContext(), opName, matcherFn,
+                                     replBuilderFn, benefit);
+}



More information about the Mlir-commits mailing list