[Mlir-commits] [mlir] [mlir][Interfaces] Add generic pattern for region inlining (PR #176641)
Matthias Springer
llvmlistbot at llvm.org
Sun Jan 18 06:06:32 PST 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/176641
>From e11fe6939dbacf858fd94ab9e9360d849c15e405 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 18 Jan 2026 11:07:39 +0000
Subject: [PATCH] [mlir][Interfaces] Add generic pattern for region inlining
---
.../mlir/Interfaces/ControlFlowInterfaces.h | 21 ++
mlir/lib/Dialect/SCF/IR/SCF.cpp | 60 +++--
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 209 ++++++++++++++++++
3 files changed, 267 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a3382e15fb76d..5acb33be89e23 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -345,6 +345,27 @@ Region *getEnclosingRepetitiveRegion(Value value);
void populateRegionBranchOpInterfaceCanonicalizationPatterns(
RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit = 1);
+/// Callback that produces a replacement value for a non-successor-input, i.e.,
+/// a region block argument or op result that is not a successor input of the
+/// region branch op. It receives a builder, the location of the region branch
+/// op, and the value that needs a replacement.
+using NonSuccessorInputReplacementBuilder =
+    std::function<Value(OpBuilder &, Location, Value)>;
+/// Predicate over the matched operation. Returning `success()` is meant to
+/// indicate that the region inlining pattern may be applied to the op.
+using PatternMatcherFn = std::function<LogicalResult(Operation *)>;
+namespace detail {
+/// Default replacement builder: aborts via `llvm_unreachable`. Callers whose
+/// ops have non-successor-inputs must supply their own builder.
+///
+/// Declared `inline` (not `static inline`) so that the default argument below
+/// names the same entity in every translation unit.
+inline Value defaultReplBuilderFn(OpBuilder &builder, Location loc,
+                                  Value value) {
+  llvm_unreachable("defaultReplBuilderFn not implemented");
+}
+
+/// Default matcher: accepts every operation.
+inline LogicalResult defaultMatcherFn(Operation *op) {
+  return success();
+}
+} // namespace detail
+
+/// Add the region inlining pattern for ops named `opName` to `patterns`.
+///
+/// \param matcherFn      predicate over the matched op (defaults to accepting
+///                       every op).
+/// \param replBuilderFn  builds replacement values for non-successor-inputs
+///                       (the default aborts when called).
+/// \param benefit        benefit of the registered pattern.
+void populateRegionBranchOpInterfaceInliningPattern(
+    RewritePatternSet &patterns, StringRef opName,
+    PatternMatcherFn matcherFn = detail::defaultMatcherFn,
+    NonSuccessorInputReplacementBuilder replBuilderFn =
+        detail::defaultReplBuilderFn,
+    PatternBenefit benefit = 1);
+
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5b6e9304de505..61fd8e3fd258d 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -704,6 +704,22 @@ OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
void ForOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
+ std::optional<APInt> tripCount = getStaticTripCount();
+ if (point.isParent()) {
+ if (tripCount.has_value() && *tripCount == 0) {
+ regions.push_back(RegionSuccessor::parent());
+ return;
+ } else if (tripCount.has_value()) {
+ regions.push_back(RegionSuccessor(&getRegion()));
+ return;
+ }
+ } else {
+ if (tripCount.has_value() && *tripCount == 1) {
+ regions.push_back(RegionSuccessor::parent());
+ return;
+ }
+ }
+
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
@@ -1049,11 +1065,26 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
};
} // namespace
+/// Replacement builder for the non-successor-input of an "scf.for" op: the
+/// value is expected to be block argument #0 of the loop body, and it is
+/// replaced with the lower bound of the enclosing loop. `builder` and `loc`
+/// are unused because no new IR is created.
+static Value forOpReplacementBuilder(OpBuilder &builder, Location loc,
+                                     Value value) {
+  auto blockArg = cast<BlockArgument>(value);
+  assert(blockArg.getArgNumber() == 0 &&
+         "expected block argument #0 of the loop body");
+  auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
+  return forOp.getLowerBound();
+}
+
+/// Matcher passed to the region inlining pattern for "scf.for". It accepts
+/// every operation unconditionally.
+static LogicalResult forOpReplacementMatcher(Operation *op) {
+ return success();
+}
+
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context);
+ // NOTE(review): SimplifyTrivialLoops is only commented out here, presumably
+ // because the inlining pattern registered below is meant to cover it --
+ // confirm, then delete the dead pattern instead of leaving it commented out.
+ results.add</*SimplifyTrivialLoops,*/ ForOpTensorCastFolder>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, ForOp::getOperationName());
+ populateRegionBranchOpInterfaceInliningPattern(
+ results, ForOp::getOperationName(), forOpReplacementMatcher,
+ forOpReplacementBuilder);
}
std::optional<APInt> ForOp::getConstantStep() {
@@ -2197,26 +2228,6 @@ void IfOp::getRegionInvocationBounds(
}
namespace {
-struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter &rewriter) const override {
- BoolAttr condition;
- if (!matchPattern(op.getCondition(), m_Constant(&condition)))
- return failure();
-
- if (condition.getValue())
- replaceOpWithRegion(rewriter, op, op.getThenRegion());
- else if (!op.getElseRegion().empty())
- replaceOpWithRegion(rewriter, op, op.getElseRegion());
- else
- rewriter.eraseOp(op);
-
- return success();
- }
-};
-
/// Hoist any yielded results whose operands are defined outside
/// the if, to a select instruction.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
@@ -2767,10 +2778,11 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
- context);
+ ReplaceIfYieldWithConditionOrValue>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, IfOp::getOperationName());
+ populateRegionBranchOpInterfaceInliningPattern(results,
+ IfOp::getOperationName());
}
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
@@ -3775,6 +3787,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
WhileMoveIfDown>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, WhileOp::getOperationName());
+ populateRegionBranchOpInterfaceInliningPattern(results,
+ WhileOp::getOperationName());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index ebd4b63145f92..75facf1fed773 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,6 +9,7 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -1027,6 +1028,207 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
return success(changed);
}
};
+
+/// Return one attribute per entry of `values`: whatever `m_Constant` binds for
+/// that value, or a null attribute if nothing is bound.
+static SmallVector<Attribute> extractAttributes(ValueRange values) {
+  // Every slot starts out as a null attribute and is only overwritten by
+  // whatever the constant matcher binds for the corresponding value.
+  SmallVector<Attribute> constants(values.size());
+  for (unsigned idx = 0, numValues = values.size(); idx < numValues; ++idx)
+    matchPattern(values[idx], m_Constant(&constants[idx]));
+  return constants;
+}
+
+/// Query the region successors of `point`, passing the attributes extracted
+/// from the operands of the branching op (see `extractAttributes`). For the
+/// "parent" branch point, the operands of `op` are used; otherwise, the
+/// operands of the predecessor terminator are used.
+static SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+                             RegionBranchPoint point) {
+  SmallVector<RegionSuccessor> result;
+  if (!point.isParent()) {
+    // Branching out of a region: ask the terminator.
+    SmallVector<Attribute> terminatorAttrs = extractAttributes(
+        point.getTerminatorPredecessorOrNull()->getOperands());
+    auto terminator = cast<RegionBranchTerminatorOpInterface>(
+        point.getTerminatorPredecessorOrNull());
+    terminator.getSuccessorRegions(terminatorAttrs, result);
+    return result;
+  }
+  // Branching from the region branch op itself.
+  SmallVector<Attribute> entryAttrs = extractAttributes(op->getOperands());
+  op.getEntrySuccessorRegions(entryAttrs, result);
+  return result;
+}
+
+/// Compute the single acyclic path through `op`, starting from "parent" and
+/// ending with "parent". Each entry of the returned vector is the region
+/// successor taken at that step; the last entry is the "parent" successor.
+///
+/// An empty vector is returned if no such path can be determined:
+/// - a branch point has zero or multiple successors,
+/// - a region on the path has more than one block,
+/// - a region is visited twice (cycle), or
+/// - a region's terminator is not a `RegionBranchTerminatorOpInterface`.
+static SmallVector<RegionSuccessor>
+computeAcyclicRegionBranchPath(RegionBranchOpInterface op) {
+  llvm::SmallDenseSet<Region *> visited;
+  SmallVector<RegionSuccessor> path;
+
+  // Path starts with "parent".
+  RegionBranchPoint next = RegionBranchPoint::parent();
+  while (true) {
+    SmallVector<RegionSuccessor> successors =
+        getSuccessorRegionsWithAttrs(op, next);
+    if (successors.size() != 1) {
+      // There is not exactly one region successor. I.e., there is no unique
+      // path through the region branch op.
+      return {};
+    }
+    path.push_back(successors.front());
+    if (successors.front().isParent()) {
+      // Found path that ends with "parent".
+      return path;
+    }
+    Region *region = successors.front().getSuccessor();
+    if (!region->hasOneBlock()) {
+      // Entering a region with multiple blocks. Such regions are not supported
+      // at the moment.
+      return {};
+    }
+    if (!visited.insert(region).second) {
+      // We have already visited this region. I.e., we have found a cycle.
+      return {};
+    }
+    auto terminator =
+        dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back());
+    if (!terminator) {
+      // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the
+      // terminator could be a "ub.unreachable" op. Such IR is not supported.
+      return {};
+    }
+    next = RegionBranchPoint(terminator);
+  }
+  llvm_unreachable("expected to return from loop");
+}
+
+/// Inline the body of the matched region branch op into the enclosing block if
+/// there is exactly one acyclic path through the region branch op, starting
+/// from "parent", and if that path ends with "parent".
+///
+/// Example: This pattern can inline "scf.for" operations that are guaranteed to
+/// have a single iteration, as indicated by the region branch path "parent =>
+/// region => parent". "scf.for" operations have a non-successor-input: the loop
+/// induction variable. Non-successor-input values have op-specific semantics
+/// and cannot be reasoned about through the `RegionBranchOpInterface`. A
+/// replacement value for non-successor-inputs is injected by the user-specified
+/// lambda: in the case of the loop induction variable of an "scf.for", the
+/// lower bound of the loop is used as a replacement value.
+///
+/// Before pattern application:
+///   %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
+///     %1 = "producer"(%arg0, %iv)
+///     scf.yield %1
+///   }
+///   "user"(%r)
+///
+/// After pattern application:
+///   %1 = "producer"(%0, %c5)
+///   "user"(%1)
+///
+/// This pattern is limited to the following cases:
+/// - Only regions with a single block are supported. This could be generalized.
+/// - Region branch ops with side effects are not supported. (Recursive side
+///   effects are fine.)
+/// - Ops rejected by the user-specified matcher are left untouched.
+///
+/// Note: This pattern queries the region dataflow from the
+/// `RegionBranchOpInterface`. Replacement values for block arguments / op
+/// results are determined based on region dataflow. In case of
+/// non-successor-inputs (whose values are not modeled by the
+/// `RegionBranchOpInterface`), a user-specified lambda is queried.
+struct InlineRegionBranchOp : public RewritePattern {
+  /// \param name           name of the op this pattern is anchored on.
+  /// \param matcherFn      user-specified predicate; the pattern is only
+  ///                       applied to ops for which it returns success.
+  /// \param replBuilderFn  builds replacement values for
+  ///                       non-successor-inputs.
+  InlineRegionBranchOp(MLIRContext *context, StringRef name,
+                       PatternMatcherFn matcherFn,
+                       NonSuccessorInputReplacementBuilder replBuilderFn,
+                       PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context),
+        matcherFn(std::move(matcherFn)),
+        replBuilderFn(std::move(replBuilderFn)) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Give the user-specified matcher a chance to reject the op. An empty
+    // matcher accepts everything.
+    if (matcherFn && failed(matcherFn(op)))
+      return rewriter.notifyMatchFailure(op, "pattern not applicable");
+
+    // Ops without recursive memory effects could have side effects of their
+    // own, so it is not safe to fold such ops away.
+    if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+      return rewriter.notifyMatchFailure(
+          op, "pattern not applicable to ops without recursive memory effects");
+
+    // Find the single acyclic path through the region branch op.
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    SmallVector<RegionSuccessor> path =
+        computeAcyclicRegionBranchPath(regionBranchOp);
+    if (path.empty())
+      return rewriter.notifyMatchFailure(
+          op, "failed to find acyclic region branch path");
+
+    // Inline all regions on the path into the enclosing block.
+    rewriter.setInsertionPoint(op);
+    ArrayRef<RegionSuccessor> remainingPath = path;
+    // The successor operands are copied into a vector: the terminator that
+    // owns them is erased before they are consumed in the next iteration, so
+    // an `OperandRange` into its operand storage would dangle.
+    SmallVector<Value> successorOperands = llvm::to_vector(
+        regionBranchOp.getEntrySuccessorOperands(remainingPath.front()));
+    while (!remainingPath.empty()) {
+      RegionSuccessor nextSuccessor = remainingPath.consume_front();
+      ValueRange successorInputs =
+          regionBranchOp.getSuccessorInputs(nextSuccessor);
+      assert(successorInputs.size() == successorOperands.size() &&
+             "size mismatch");
+      // Find the index of the first block argument / op result that is a
+      // successor input.
+      unsigned firstSuccessorInputIdx = 0;
+      if (!successorInputs.empty())
+        firstSuccessorInputIdx =
+            nextSuccessor.isParent()
+                ? cast<OpResult>(successorInputs.front()).getResultNumber()
+                : cast<BlockArgument>(successorInputs.front()).getArgNumber();
+      // Query the total number of block arguments / op results.
+      unsigned numValues =
+          nextSuccessor.isParent()
+              ? op->getNumResults()
+              : nextSuccessor.getSuccessor()->getNumArguments();
+      // Compute replacement values for all block arguments / op results.
+      SmallVector<Value> replacements;
+      // Helper function to get the i-th block argument / op result.
+      auto getValue = [&](unsigned idx) {
+        return nextSuccessor.isParent()
+                   ? Value(op->getResult(idx))
+                   : Value(nextSuccessor.getSuccessor()->getArgument(idx));
+      };
+      // Compute replacement values for all non-successor-input values that
+      // precede the first successor input.
+      for (unsigned i = 0; i < firstSuccessorInputIdx; ++i)
+        replacements.push_back(
+            replBuilderFn(rewriter, op->getLoc(), getValue(i)));
+      // Use the successor operands of the predecessor as replacement values for
+      // the successor inputs.
+      llvm::append_range(replacements, successorOperands);
+      // Compute replacement values for all block arguments / op results that
+      // follow the last successor input.
+      for (unsigned i = replacements.size(); i < numValues; ++i)
+        replacements.push_back(
+            replBuilderFn(rewriter, op->getLoc(), getValue(i)));
+      if (nextSuccessor.isParent()) {
+        // The path ends with "parent". Replace the region branch op with the
+        // computed replacement values.
+        assert(remainingPath.empty() && "expected that the path ended");
+        rewriter.replaceOp(op, replacements);
+        return success();
+      }
+      // We are inside of a region: inline the region into the enclosing block,
+      // query the successor operands from the terminator (after inlining, so
+      // that block arguments have already been replaced), and erase the
+      // terminator.
+      auto terminator = cast<RegionBranchTerminatorOpInterface>(
+          &nextSuccessor.getSuccessor()->front().back());
+      rewriter.inlineBlockBefore(&nextSuccessor.getSuccessor()->front(),
+                                 op->getBlock(), op->getIterator(),
+                                 replacements);
+      successorOperands = llvm::to_vector(
+          terminator.getSuccessorOperands(remainingPath.front()));
+      rewriter.eraseOp(terminator);
+    }
+
+    llvm_unreachable("expected that paths ends with parent");
+  }
+
+  /// User-specified predicate that decides whether the pattern applies.
+  PatternMatcherFn matcherFn;
+  /// User-specified builder for replacements of non-successor-inputs.
+  NonSuccessorInputReplacementBuilder replBuilderFn;
+};
} // namespace
void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
@@ -1036,3 +1238,10 @@ void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(),
opName, benefit);
}
+
+/// Register the `InlineRegionBranchOp` pattern for ops named `opName` in
+/// `patterns`, forwarding the given matcher, non-successor-input replacement
+/// builder and pattern benefit to the pattern's constructor.
+void mlir::populateRegionBranchOpInterfaceInliningPattern(
+ RewritePatternSet &patterns, StringRef opName, PatternMatcherFn matcherFn,
+ NonSuccessorInputReplacementBuilder replBuilderFn, PatternBenefit benefit) {
+ patterns.add<InlineRegionBranchOp>(patterns.getContext(), opName, matcherFn,
+ replBuilderFn, benefit);
+}
More information about the Mlir-commits
mailing list