[Mlir-commits] [mlir] [mlir][Interfaces] Add generic pattern for region inlining (PR #176641)

lonely eagle llvmlistbot at llvm.org
Thu Jan 22 07:28:28 PST 2026


================
@@ -1025,6 +1026,230 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
     return success(changed);
   }
 };
+
+/// Given a range of values, return a vector of attributes of the same size,
+/// where the i-th attribute is the constant value of the i-th value. If a
+/// value is not constant, the corresponding attribute is null.
+static SmallVector<Attribute> extractConstants(ValueRange values) {
+  return llvm::map_to_vector(values, [](Value value) {
+    Attribute attr; // Stays null unless a constant is matched below.
+    matchPattern(value, m_Constant(&attr)); // Match result is not inspected.
+    return attr;
+  });
+}
+
+/// Return all successor regions when branching from the given region branch
+/// point. This helper function extracts the constant operand values of the op
+/// (or of the branching terminator) and passes them to the queried interface.
+static SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+                             RegionBranchPoint point) {
+  SmallVector<RegionSuccessor> successors;
+  if (point.isParent()) {
+    op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
+                                successors);
+    return successors;
+  }
+  RegionBranchTerminatorOpInterface terminator =
+      point.getTerminatorPredecessorOrNull();
+  terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
+                                 successors);
+  return successors;
+}
+
+/// Find the single acyclic path through the given region branch op. Return an
+/// empty vector if no such path or multiple such paths exist.
+///
+/// Example: "scf.if %true" has a single path: parent => then_region => parent
+///
+/// Example: "scf.if ???" has multiple paths:
+///          (1) parent => then_region => parent
+///          (2) parent => else_region => parent
+///
+/// Example: "scf.while with scf.condition(%false)" has a single path:
+///          parent => before_region => parent
+///
+/// Example: "scf.for with 0 iterations" has a single path: parent => parent
+///
+/// Note: Each path starts and ends with "parent". The "parent" at the beginning
+/// of the path is omitted from the result.
+///
+/// Note: An "empty" path is also returned when a region has multiple blocks or
+/// lacks a `RegionBranchTerminatorOpInterface` terminator.
+static SmallVector<RegionSuccessor>
+computeSingleAcyclicRegionBranchPath(RegionBranchOpInterface op) {
+  llvm::SmallDenseSet<Region *> visited;
+  SmallVector<RegionSuccessor> path;
+
+  // Path starts with "parent".
+  RegionBranchPoint next = RegionBranchPoint::parent();
+  do {
+    SmallVector<RegionSuccessor> successors =
+        getSuccessorRegionsWithAttrs(op, next);
+    if (successors.size() != 1) {
+      // There is not exactly one region successor. I.e., there are zero or
+      // multiple ways to continue the path through the region branch op.
+      return {};
+    }
+    path.push_back(successors.front());
+    if (successors.front().isParent()) {
+      // Found path that ends with "parent".
+      return path;
+    }
+    Region *region = successors.front().getSuccessor();
+    if (!region->hasOneBlock()) {
+      // Entering a region with multiple blocks. Such regions are not supported
+      // at the moment.
+      return {};
+    }
+    if (!visited.insert(region).second) {
+      // We have already visited this region. I.e., we have found a cycle.
+      return {};
+    }
+    auto terminator =
+        dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back());
+    if (!terminator) {
+      // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the
+      // terminator could be a "ub.unreachable" op. Such IR is not supported.
+      return {};
+    }
+    next = RegionBranchPoint(terminator);
+  } while (true);
+  llvm_unreachable("expected to return from loop");
+}
+
+/// Inline the body of the matched region branch op into the enclosing block if
+/// there is exactly one acyclic path through the region branch op, starting
+/// from "parent", and if that path ends with "parent".
+///
+/// Example: This pattern can inline "scf.for" operations that are guaranteed to
+/// have a single iteration, as indicated by the region branch path "parent =>
+/// region => parent". "scf.for" operations have a non-successor-input: the loop
+/// induction variable. Non-successor-input values have op-specific semantics
+/// and cannot be reasoned about through the `RegionBranchOpInterface`. A
+/// replacement value for non-successor-inputs is injected by the user-specified
+/// lambda: in the case of the loop induction variable of an "scf.for", the
+/// lower bound of the loop is used as a replacement value.
+///
+/// Before pattern application:
+/// %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
+///   %1 = "producer"(%arg0, %iv)
+///   scf.yield %1
+/// }
+/// "user"(%r)
+///
+/// After pattern application:
+/// %1 = "producer"(%0, %c5)
+/// "user"(%1)
+///
+/// This pattern is limited to the following cases:
+/// - Only regions with a single block are supported. This could be generalized.
+/// - Region branch ops with side effects are not supported. (Recursive side
+///   effects are fine.)
+///
+/// Note: This pattern queries the region dataflow from the
+/// `RegionBranchOpInterface`. Replacement values for block arguments / op
+/// results are determined based on region dataflow. In case of
+/// non-successor-inputs (whose values are not modeled by the
+/// `RegionBranchOpInterface`), a user-specified lambda is queried.
+struct InlineRegionBranchOp : public RewritePattern {
+  InlineRegionBranchOp(MLIRContext *context, StringRef name,
+                       NonSuccessorInputReplacementBuilderFn replBuilderFn,
+                       PatternMatcherFn matcherFn, PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context), replBuilderFn(replBuilderFn),
+        matcherFn(matcherFn) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // Check if the pattern is applicable to the given operation.
+    if (failed(matcherFn(op)))
+      return rewriter.notifyMatchFailure(op, "pattern not applicable");
+
+    // Ops without recursive memory effects could have side effects, so
+    // it is not safe to fold such ops away.
+    if (!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
+      return rewriter.notifyMatchFailure(
+          op, "pattern not applicable to ops without recursive memory effects");
+
+    // Find the single acyclic path through the region branch op.
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    SmallVector<RegionSuccessor> path =
+        computeSingleAcyclicRegionBranchPath(regionBranchOp);
+    if (path.empty())
+      return rewriter.notifyMatchFailure(
+          op, "failed to find acyclic region branch path");
+
+    // Inline all regions on the path into the enclosing block.
+    rewriter.setInsertionPoint(op);
+    ArrayRef remainingPath = path;
+    OperandRange successorOperands =
+        regionBranchOp.getEntrySuccessorOperands(remainingPath.front());
+    while (!remainingPath.empty()) {
+      RegionSuccessor nextSuccessor = remainingPath.consume_front();
----------------
linuxlonelyeagle wrote:

> Yes, region -> region -> parent is possible. E.g., for `scf.while` loops: before_region -> after_region -> parent.

I think it would be best to add this as a test case.



https://github.com/llvm/llvm-project/pull/176641


More information about the Mlir-commits mailing list