[llvm-branch-commits] [mlir] [mlir][Interfaces] Add generic pattern for region inlining (PR #176641)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 21 00:53:19 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-arith
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add a new canonicalization pattern that inlines the body of acyclic `RegionBranchOpInterface` ops. This pattern is a generalization and replacement for the following existing patterns:
* `SingleBlockExecuteInliner`: inlines `scf.execute_region` ops with a single block.
* `SimplifyTrivialLoops`: inlines / folds away `scf.for` ops with 0 or 1 iterations.
* `RemoveStaticCondition`: inlines `scf.if` ops with a static condition.
* `FoldConstantCase`: inlines `scf.index_switch` ops with a constant operand.
Additionally, this new pattern is enabled for `scf.while` ops: loops with `scf.condition(%false)` are now inlined as well. (New test case added.)
The new pattern looks for region branch ops with a single acyclic path through the operation (starting from and ending at "parent"). All regions on that path can be inlined into the enclosing block.
Depends on #<!-- -->177116.
---
Patch is 27.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/176641.diff
6 Files Affected:
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+39)
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+26-145)
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+233)
- (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+4-2)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+20)
- (modified) mlir/test/Dialect/SCF/one-shot-bufferize.mlir (+4-2)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index d764089f5ccc8..a76dce6f2ffc5 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -320,6 +320,45 @@ Region *getEnclosingRepetitiveRegion(Value value);
void populateRegionBranchOpInterfaceCanonicalizationPatterns(
RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit = 1);
+/// Helper function for the region branch op inlining pattern that builds
+/// replacement values for non-successor-input values.
+using NonSuccessorInputReplacementBuilderFn =
+ std::function<Value(OpBuilder &, Location, Value)>;
+/// Helper function for the region branch op inlining pattern that checks if the
+/// pattern is applicable to the given operation.
+using PatternMatcherFn = std::function<LogicalResult(Operation *)>;
+
+namespace detail {
+/// Default implementation of the non-successor-input replacement builder
+/// function. This default implementation assumes that all block arguments and
+/// op results are successor inputs.
+static inline Value defaultReplBuilderFn(OpBuilder &builder, Location loc,
+ Value value) {
+ llvm_unreachable("defaultReplBuilderFn not implemented");
+}
+
+/// Default implementation of the pattern matcher function.
+static inline LogicalResult defaultMatcherFn(Operation *op) {
+ return success();
+}
+} // namespace detail
+
+/// Populate a pattern that inlines the body of region branch ops when there is
+/// a single acyclic path through the region branch op, starting from "parent"
+/// and ending at "parent". For details, refer to the documentation of the
+/// pattern.
+///
+/// `replBuilderFn` is a function that builds replacement values for
+/// non-successor-input values of the region branch op. `matcherFn` is a
+/// function that checks if the pattern is applicable to the given operation.
+/// Both functions are optional.
+void populateRegionBranchOpInterfaceInliningPattern(
+ RewritePatternSet &patterns, StringRef opName,
+ NonSuccessorInputReplacementBuilderFn replBuilderFn =
+ detail::defaultReplBuilderFn,
+ PatternMatcherFn matcherFn = detail::defaultMatcherFn,
+ PatternBenefit benefit = 1);
+
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 86e66dbaf6171..2ebece4bdedb7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -132,19 +132,6 @@ std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
-/// Replaces the given op with the contents of the given single-block region,
-/// using the operands of the block terminator to replace operation results.
-static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
- Region &region, ValueRange blockArgs = {}) {
- assert(region.hasOneBlock() && "expected single-block region");
- Block *block = &region.front();
- Operation *terminator = block->getTerminator();
- ValueRange results = terminator->getOperands();
- rewriter.inlineBlockBefore(block, op, blockArgs);
- rewriter.replaceOp(op, results);
- rewriter.eraseOp(terminator);
-}
-
///
/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
/// block+
@@ -192,32 +179,6 @@ LogicalResult ExecuteRegionOp::verify() {
return success();
}
-// Inline an ExecuteRegionOp if it only contains one block.
-// "test.foo"() : () -> ()
-// %v = scf.execute_region -> i64 {
-// %x = "test.val"() : () -> i64
-// scf.yield %x : i64
-// }
-// "test.bar"(%v) : (i64) -> ()
-//
-// becomes
-//
-// "test.foo"() : () -> ()
-// %x = "test.val"() : () -> i64
-// "test.bar"(%x) : (i64) -> ()
-//
-struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
- using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExecuteRegionOp op,
- PatternRewriter &rewriter) const override {
- if (!op.getRegion().hasOneBlock() || op.getNoInline())
- return failure();
- replaceOpWithRegion(rewriter, op, op.getRegion());
- return success();
- }
-};
-
// Inline an ExecuteRegionOp if its parent can contain multiple blocks.
// TODO generalize the conditions for operations which can be inlined into.
// func @func_execute_region_elim() {
@@ -293,9 +254,15 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+ results.add<MultiBlockExecuteInliner>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, ExecuteRegionOp::getOperationName());
+ // Inline ops with a single block that are not marked as "no_inline".
+ populateRegionBranchOpInterfaceInliningPattern(
+ results, ExecuteRegionOp::getOperationName(),
+ mlir::detail::defaultReplBuilderFn, [](Operation *op) {
+ return failure(cast<ExecuteRegionOp>(op).getNoInline());
+ });
}
void ExecuteRegionOp::getSuccessorRegions(
@@ -962,54 +929,6 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
}
namespace {
-/// Rewriting pattern that erases loops that are known not to iterate, replaces
-/// single-iteration loops with their bodies, and removes empty loops that
-/// iterate at least once and only return values defined outside of the loop.
-struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
- using OpRewritePattern<ForOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ForOp op,
- PatternRewriter &rewriter) const override {
- std::optional<APInt> tripCount = op.getStaticTripCount();
- if (!tripCount.has_value())
- return rewriter.notifyMatchFailure(op,
- "can't compute constant trip count");
-
- if (tripCount->isZero()) {
- LDBG() << "SimplifyTrivialLoops tripCount is 0 for loop "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- rewriter.replaceOp(op, op.getInitArgs());
- return success();
- }
-
- if (tripCount->getSExtValue() == 1) {
- LDBG() << "SimplifyTrivialLoops tripCount is 1 for loop "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- SmallVector<Value, 4> blockArgs;
- blockArgs.reserve(op.getInitArgs().size() + 1);
- blockArgs.push_back(op.getLowerBound());
- llvm::append_range(blockArgs, op.getInitArgs());
- replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
- return success();
- }
-
- // Now we are left with loops that have more than 1 iterations.
- Block &block = op.getRegion().front();
- if (!llvm::hasSingleElement(block))
- return failure();
- // The loop is empty and iterates at least once, if it only returns values
- // defined outside of the loop, remove it and replace it with yield values.
- if (llvm::any_of(op.getYieldedValues(),
- [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
- return failure();
- LDBG() << "SimplifyTrivialLoops empty body loop allows replacement with "
- "yield operands for loop "
- << OpWithFlags(op, OpPrintingFlags().skipRegions());
- rewriter.replaceOp(op, op.getYieldedValues());
- return success();
- }
-};
-
/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
///
@@ -1072,9 +991,20 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SimplifyTrivialLoops, ForOpTensorCastFolder>(context);
+ results.add<ForOpTensorCastFolder>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, ForOp::getOperationName());
+ populateRegionBranchOpInterfaceInliningPattern(
+ results, ForOp::getOperationName(),
+ /*replBuilderFn=*/[](OpBuilder &builder, Location loc, Value value) {
+ // scf.for has only one non-successor input value: the loop induction
+ // variable. In case of a single acyclic path through the op, the IV can
+ // be safely replaced with the lower bound.
+ auto blockArg = cast<BlockArgument>(value);
+ assert(blockArg.getArgNumber() == 0 && "expected induction variable");
+ auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
+ return forOp.getLowerBound();
+ });
}
std::optional<APInt> ForOp::getConstantStep() {
@@ -2218,26 +2148,6 @@ void IfOp::getRegionInvocationBounds(
}
namespace {
-struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
- using OpRewritePattern<IfOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IfOp op,
- PatternRewriter &rewriter) const override {
- BoolAttr condition;
- if (!matchPattern(op.getCondition(), m_Constant(&condition)))
- return failure();
-
- if (condition.getValue())
- replaceOpWithRegion(rewriter, op, op.getThenRegion());
- else if (!op.getElseRegion().empty())
- replaceOpWithRegion(rewriter, op, op.getElseRegion());
- else
- rewriter.eraseOp(op);
-
- return success();
- }
-};
-
/// Hoist any yielded results whose operands are defined outside
/// the if, to a select instruction.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
@@ -2788,10 +2698,11 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- RemoveStaticCondition, ReplaceIfYieldWithConditionOrValue>(
- context);
+ ReplaceIfYieldWithConditionOrValue>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, IfOp::getOperationName());
+ populateRegionBranchOpInterfaceInliningPattern(results,
+ IfOp::getOperationName());
}
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
@@ -3796,6 +3707,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
WhileMoveIfDown>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, WhileOp::getOperationName());
+ populateRegionBranchOpInterfaceInliningPattern(results,
+ WhileOp::getOperationName());
}
//===----------------------------------------------------------------------===//
@@ -3942,44 +3855,12 @@ void IndexSwitchOp::getRegionInvocationBounds(
bounds.emplace_back(/*lb=*/0, /*ub=*/i == liveIndex);
}
-struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
- using OpRewritePattern<scf::IndexSwitchOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(scf::IndexSwitchOp op,
- PatternRewriter &rewriter) const override {
- // If `op.getArg()` is a constant, select the region that matches with
- // the constant value. Use the default region if no matche is found.
- std::optional<int64_t> maybeCst = getConstantIntValue(op.getArg());
- if (!maybeCst.has_value())
- return failure();
- int64_t cst = *maybeCst;
- int64_t caseIdx, e = op.getNumCases();
- for (caseIdx = 0; caseIdx < e; ++caseIdx) {
- if (cst == op.getCases()[caseIdx])
- break;
- }
-
- Region &r = (caseIdx < op.getNumCases()) ? op.getCaseRegions()[caseIdx]
- : op.getDefaultRegion();
- Block &source = r.front();
- Operation *terminator = source.getTerminator();
- SmallVector<Value> results = terminator->getOperands();
-
- rewriter.inlineBlockBefore(&source, op);
- rewriter.eraseOp(terminator);
- // Replace the operation with a potentially empty list of results.
- // Fold mechanism doesn't support the case where the result list is empty.
- rewriter.replaceOp(op, results);
-
- return success();
- }
-};
-
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, IndexSwitchOp::getOperationName());
+ populateRegionBranchOpInterfaceInliningPattern(
+ results, IndexSwitchOp::getOperationName());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index ebf78d8bd60ce..8ed32ddf39a53 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,6 +9,7 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -1025,6 +1026,230 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
return success(changed);
}
};
+
+/// Given a range of values, return a vector of attributes of the same size,
+/// where the i-th attribute is the constant value of the i-th value. If a
+/// value is not constant, the corresponding attribute is null.
+static SmallVector<Attribute> extractConstants(ValueRange values) {
+ return llvm::map_to_vector(values, [](Value value) {
+ Attribute attr;
+ matchPattern(value, m_Constant(&attr));
+ return attr;
+ });
+}
+
+/// Return all successor regions when branching from the given region branch
+/// point. This helper function extracts all constant operand values and
+/// passes them to the `RegionBranchOpInterface`.
+static SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+ RegionBranchPoint point) {
+ SmallVector<RegionSuccessor> successors;
+ if (point.isParent()) {
+ op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
+ successors);
+ return successors;
+ }
+ RegionBranchTerminatorOpInterface terminator =
+ point.getTerminatorPredecessorOrNull();
+ terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
+ successors);
+ return successors;
+}
+
+/// Find the single acyclic path through the given region branch op. Return an
+/// empty vector if no such path or multiple such paths exist.
+///
+/// Example: "scf.if %true" has a single path: parent => then_region => parent
+///
+/// Example: "scf.if ???" has multiple paths:
+/// (1) parent => then_region => parent
+/// (2) parent => else_region => parent
+///
+/// Example: "scf.while with scf.condition(%false)" has a single path:
+/// parent => before_region => parent
+///
+/// Example: "scf.for with 0 iterations" has a single path: parent => parent
+///
+/// Note: Each path starts and ends with "parent". The "parent" at the beginning
+/// of the path is omitted from the result.
+///
+/// Note: This function also returns an "empty" path when a region with multiple
+/// blocks was found.
+static SmallVector<RegionSuccessor>
+computeSingleAcyclicRegionBranchPath(RegionBranchOpInterface op) {
+ llvm::SmallDenseSet<Region *> visited;
+ SmallVector<RegionSuccessor> path;
+
+ // Path starts with "parent".
+ RegionBranchPoint next = RegionBranchPoint::parent();
+ do {
+ SmallVector<RegionSuccessor> successors =
+ getSuccessorRegionsWithAttrs(op, next);
+ if (successors.size() != 1) {
+ // There is not exactly one region successor (zero or multiple). I.e.,
+ // there is no unique path through the region branch op.
+ return {};
+ }
+ path.push_back(successors.front());
+ if (successors.front().isParent()) {
+ // Found path that ends with "parent".
+ return path;
+ }
+ Region *region = successors.front().getSuccessor();
+ if (!region->hasOneBlock()) {
+ // Entering a region with multiple blocks. Such regions are not supported
+ // at the moment.
+ return {};
+ }
+ if (!visited.insert(region).second) {
+ // We have already visited this region. I.e., we have found a cycle.
+ return {};
+ }
+ auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back());
+ if (!terminator) {
+ // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the
+ // terminator could be a "ub.unreachable" op. Such IR is not supported.
+ return {};
+ }
+ next = RegionBranchPoint(terminator);
+ } while (true);
+ llvm_unreachable("expected to return from loop");
+}
+
+/// Inline the body of the matched region branch op into the enclosing block if
+/// there is exactly one acyclic path through the region branch op, starting
+/// from "parent", and if that path ends with "parent".
+///
+/// Example: This pattern can inline "scf.for" operations that are guaranteed to
+/// have a single iteration, as indicated by the region branch path "parent =>
+/// region => parent". "scf.for" operations have a non-successor-input: the loop
+/// induction variable. Non-successor-input values have op-specific semantics
+/// and cannot be reasoned about through the `RegionBranchOpInterface`. A
+/// replacement value for non-successor-inputs is injected by the user-specified
+/// lambda: in the case of the loop induction variable of an "scf.for", the
+/// lower bound of the loop is used as a replacement value.
+///
+/// Before pattern application:
+/// %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
+/// %1 = "producer"(%arg0, %iv)
+/// scf.yield %1
+/// }
+/// "user"(%r)
+///
+/// After pattern application:
+/// %1 = "producer"(%0, %c5)
+/// "user"(%1)
+///
+/// This pattern is limited to the following cases:
+/// - Only regions with a single block are supported. This could be generalized.
+/// - Region branch ops with side effects are not supported. (Recursive side
+/// effects are fine.)
+///
+/// Note: This pattern queries the region dataflow from the
+/// `RegionBranchOpInterface`. Replacement values for block arguments / op
+/// results are determined based on region dataflow. In case of
+/// non-successor-inputs (whose values are not modeled by the
+/// `RegionBranchOpInterface`), a user-specified lambda is queried.
+struct InlineRegionBranchOp : public RewritePattern {
+ InlineRegionBranchOp(MLIRContext *context, StringRef name,
+ NonSuccessorInputReplacementBuilderFn replBuilderFn,
+ PatternMatcherFn matcherFn, PatternBenefit benefit = 1)
+ : RewritePattern(name, benefit, context), replBuilderFn(replBuilderFn),
+ matcherFn(matcherFn) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Check if the pattern is applicable to the given operation.
+ if (failed(matcherFn(op)))
+ return rewriter.notifyMatchFailure(op, "pattern not applicable");
+
+ // Ops without recursive memory effects could have side effects of their
+ // own, so it is not safe to fold such ops away.
+ if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+ return rewriter.notifyMatchFailure(
+ op, "pattern not applicable to ops without recursive memory effects");
+
+ // Find the single acyclic path through the region branch op.
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+ SmallVector<RegionSuccessor> path =
+ computeSingleAcyclicRegionBranchPath(regionBranchOp);
+ if (path.empty())
+ return rewriter.notifyMatchFailure(
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/176641
More information about the llvm-branch-commits
mailing list