[Mlir-commits] [mlir] [mlir][Interfaces] Add generic pattern for region inlining (PR #176641)
lonely eagle
llvmlistbot at llvm.org
Thu Jan 22 07:26:19 PST 2026
================
@@ -1025,6 +1026,230 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
return success(changed);
}
};
+
+/// Map every value in `values` to the constant attribute that defines it.
+/// The result has exactly one entry per input value, in the same order. An
+/// entry is a null attribute when the corresponding value is not matched by
+/// `m_Constant`.
+static SmallVector<Attribute> extractConstants(ValueRange values) {
+  SmallVector<Attribute> constants;
+  constants.reserve(values.size());
+  for (Value value : values) {
+    // Left as a null attribute if `value` is not a constant.
+    Attribute constant;
+    matchPattern(value, m_Constant(&constant));
+    constants.push_back(constant);
+  }
+  return constants;
+}
+
+/// Query the `RegionBranchOpInterface` for all region successors reachable
+/// when branching from `point`. The constant values of the operands that
+/// drive the branch are passed along: the op's own operands when `point` is
+/// the parent op, and the terminator's operands otherwise. Operands that are
+/// not constant are passed as null attributes.
+static SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+                             RegionBranchPoint point) {
+  SmallVector<RegionSuccessor> successors;
+  if (!point.isParent()) {
+    // Branching out of a region: the successors are queried from the
+    // region's terminator, based on the terminator's operands.
+    RegionBranchTerminatorOpInterface terminator =
+        point.getTerminatorPredecessorOrNull();
+    SmallVector<Attribute> terminatorConstants =
+        extractConstants(terminator->getOperands());
+    terminator.getSuccessorRegions(terminatorConstants, successors);
+  } else {
+    // Entering the op from its parent: the successors are queried from the
+    // op itself, based on the op's operands.
+    SmallVector<Attribute> entryConstants =
+        extractConstants(op->getOperands());
+    op.getEntrySuccessorRegions(entryConstants, successors);
+  }
+  return successors;
+}
+
+/// Find the single acyclic path through the given region branch op. Return an
+/// empty vector if no such path or multiple such paths exist.
+///
+/// Example: "scf.if %true" has a single path: parent => then_region => parent
+///
+/// Example: "scf.if ???" has multiple paths:
+///   (1) parent => then_region => parent
+///   (2) parent => else_region => parent
+///
+/// Example: "scf.while with scf.condition(%false)" has a single path:
+///   parent => before_region => parent
+///
+/// Example: "scf.for with 0 iterations" has a single path: parent => parent
+///
+/// Note: Each path starts and ends with "parent". The "parent" at the beginning
+/// of the path is omitted from the result.
+///
+/// Note: This function also gives up and returns an "empty" path when the path
+/// enters a region that does not consist of exactly one block, or a region
+/// whose last operation does not implement
+/// `RegionBranchTerminatorOpInterface`.
+static SmallVector<RegionSuccessor>
+computeSingleAcyclicRegionBranchPath(RegionBranchOpInterface op) {
+  llvm::SmallDenseSet<Region *> visited;
+  SmallVector<RegionSuccessor> path;
+
+  // Path starts with "parent".
+  RegionBranchPoint next = RegionBranchPoint::parent();
+  do {
+    SmallVector<RegionSuccessor> successors =
+        getSuccessorRegionsWithAttrs(op, next);
+    if (successors.size() != 1) {
+      // There are zero or multiple region successors. I.e., there is no
+      // unique path through the region branch op.
+      return {};
+    }
+    path.push_back(successors.front());
+    if (successors.front().isParent()) {
+      // Found path that ends with "parent".
+      return path;
+    }
+    Region *region = successors.front().getSuccessor();
+    if (!region->hasOneBlock()) {
+      // Entering a region that does not have exactly one block. Such regions
+      // are not supported at the moment.
+      return {};
+    }
+    if (!visited.insert(region).second) {
+      // We have already visited this region. I.e., we have found a cycle.
+      return {};
+    }
+    // Inspect the last operation of the region's only block.
+    auto terminator =
+        dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back());
+    if (!terminator) {
+      // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the
+      // terminator could be a "ub.unreachable" op. Such IR is not supported.
+      return {};
+    }
+    next = RegionBranchPoint(terminator);
+  } while (true);
+  llvm_unreachable("expected to return from loop");
+}
+
+/// Inline the body of the matched region branch op into the enclosing block if
+/// there is exactly one acyclic path through the region branch op, starting
+/// from "parent", and if that path ends with "parent".
+///
+/// Example: This pattern can inline "scf.for" operations that are guaranteed to
+/// have a single iteration, as indicated by the region branch path "parent =>
+/// region => parent". "scf.for" operations have a non-successor-input: the loop
+/// induction variable. Non-successor-input values have op-specific semantics
+/// and cannot be reasoned about through the `RegionBranchOpInterface`. A
+/// replacement value for non-successor-inputs is injected by the user-specified
+/// lambda: in the case of the loop induction variable of an "scf.for", the
+/// lower bound of the loop is used as a replacement value.
+///
+/// Before pattern application:
+/// %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
+/// %1 = "producer"(%arg0, %iv)
+/// scf.yield %1
+/// }
+/// "user"(%r)
+///
+/// After pattern application:
+/// %1 = "producer"(%0, %c5)
+/// "user"(%1)
+///
+/// This pattern is limited to the following cases:
+/// - Only regions with a single block are supported. This could be generalized.
+/// - Region branch ops with side effects are not supported. (Recursive side
+/// effects are fine.)
+///
+/// Note: This pattern queries the region dataflow from the
+/// `RegionBranchOpInterface`. Replacement values for block arguments / op
+/// results are determined based on region dataflow. In case of
+/// non-successor-inputs (whose values are not modeled by the
+/// `RegionBranchOpInterface`), a user-specified lambda is queried.
+struct InlineRegionBranchOp : public RewritePattern {
+ InlineRegionBranchOp(MLIRContext *context, StringRef name,
+ NonSuccessorInputReplacementBuilderFn replBuilderFn,
+ PatternMatcherFn matcherFn, PatternBenefit benefit = 1)
+ : RewritePattern(name, benefit, context), replBuilderFn(replBuilderFn),
+ matcherFn(matcherFn) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ // Check if the pattern is applicable to the given operation.
+ if (failed(matcherFn(op)))
+ return rewriter.notifyMatchFailure(op, "pattern not applicable");
+
+ // Patterns without recursive memory effects could have side effects, so
+ // it is not safe to fold such ops away.
+ if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+ return rewriter.notifyMatchFailure(
+ op, "pattern not applicable to ops without recursive memory effects");
+
+ // Find the single acyclic path through the region branch op.
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+ SmallVector<RegionSuccessor> path =
+ computeSingleAcyclicRegionBranchPath(regionBranchOp);
+ if (path.empty())
+ return rewriter.notifyMatchFailure(
+ op, "failed to find acyclic region branch path");
+
+ // Inline all regions on the path into the enclosing block.
+ rewriter.setInsertionPoint(op);
+ ArrayRef remainingPath = path;
+ OperandRange successorOperands =
+ regionBranchOp.getEntrySuccessorOperands(remainingPath.front());
+ while (!remainingPath.empty()) {
+ RegionSuccessor nextSuccessor = remainingPath.consume_front();
----------------
linuxlonelyeagle wrote:
My understanding is that since the path stores the successors' inputs, it is named 'next'.
https://github.com/llvm/llvm-project/pull/176641
More information about the Mlir-commits
mailing list