[Mlir-commits] [mlir] [MLIR][Mem2Reg] Extract shared utilities for PromotableRegionOpInterface (PR #188514)

Berke Ates llvmlistbot at llvm.org
Wed Mar 25 08:39:59 PDT 2026


https://github.com/Berke-Ates created https://github.com/llvm/llvm-project/pull/188514

The SCF `PromotableRegionOpInterface` implementations use two helpers that are likely useful for other dialects implementing this interface as well:
- `updateTerminator`: Appends the reaching definition as an operand to a block's terminator if the terminator is of a given op type.
- `replaceWithNewResults`: Creates a copy of an operation with additional result types, moves (rather than clones) its regions into the copy, then replaces and erases the original.

This PR extracts them into a common utility header (`mlir/Interfaces/Utils/MemorySlotUtils.h`) in the `mlir::memoryslot` namespace so that downstream dialects can reuse them directly. I'm open to discussion about the location of these utilities.

>From eea9219601277fd1a2184307172495011dd441bd Mon Sep 17 00:00:00 2001
From: Berke-Ates <berke at ates.ch>
Date: Wed, 25 Mar 2026 16:21:05 +0100
Subject: [PATCH] [MLIR][Mem2Reg] Extract shared utilities for
 PromotableRegionOpInterface

The SCF PromotableRegionOpInterface implementations use two helpers: one
that appends reaching definitions to terminators, and one that replaces
an op with a copy carrying additional result types while moving over its
regions. These are useful for any dialect implementing this interface,
but downstream dialects previously had to reimplement them. This patch
extracts them into a common utility header.
---
 .../mlir/Interfaces/Utils/MemorySlotUtils.h   | 67 +++++++++++++++
 mlir/lib/Dialect/SCF/IR/MemorySlot.cpp        | 85 +++++--------------
 2 files changed, 89 insertions(+), 63 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h

diff --git a/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
new file mode 100644
index 0000000000000..1d3322426fe05
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
@@ -0,0 +1,89 @@
+//===- MemorySlotUtils.h - Utilities for MemorySlot interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+#define MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Sequence.h"
+
+namespace mlir {
+namespace memoryslot {
+
+/// If the terminator of `block` is a `TermTy`, appends the reaching definition
+/// at the end of `block` as a new trailing operand of that terminator.
+/// Terminators of any other type are left untouched.
+///
+/// The reaching definition is looked up in `reachingAtBlockEnd`. If `block`
+/// has no entry there (e.g. it is dead code or the region does not use the
+/// slot), `defaultReachingDef` is used instead; it must then be non-null.
+template <typename TermTy>
+void updateTerminator(
+    Block *block, Value defaultReachingDef,
+    const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  Operation *terminator = block->getTerminator();
+  if (!isa<TermTy>(terminator))
+    return;
+  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
+  if (!blockReachingDef) {
+    // Block is dead code or the region is not using the slot, so we use the
+    // default provided reaching definition.
+    blockReachingDef = defaultReachingDef;
+  }
+  // Guard against appending a null operand, e.g. when no default is given.
+  assert(blockReachingDef && "expected a non-null reaching definition");
+  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
+/// Creates a copy of `op` whose result types are `resultTypes`, moves (rather
+/// than clones) the regions of `op` into the copy, and replaces `op` with it.
+///
+/// `resultTypes` must begin with the result types of `op`; any further types
+/// become additional trailing results. Operands and attributes are copied from
+/// `op`. All uses of the original results are redirected to the corresponding
+/// leading results of the new operation, and `op` is erased.
+///
+/// Returns the new operation, whose trailing results are left for the caller.
+inline Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                        TypeRange resultTypes) {
+  unsigned numOldResults = op->getNumResults();
+  assert(resultTypes.size() >= numOldResults &&
+         "expected at least as many result types as the original op has");
+
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  OperationState state(op->getLoc(), op->getName(), op->getOperands(),
+                       resultTypes, op->getAttrs());
+  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+    state.addRegion();
+  Operation *newOp = rewriter.create(state);
+
+  // Move the region bodies over, notifying the rewriter about both ops.
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned i : llvm::seq(op->getNumRegions()))
+    newOp->getRegion(i).takeBody(op->getRegion(i));
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  // Only the leading results correspond to results of `op`; slicing by count
+  // (instead of dropping exactly one) supports any number of new results.
+  rewriter.replaceOp(op, newOp->getResults().take_front(numOldResults));
+  return newOp;
+}
+
+} // namespace memoryslot
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index 3d61476df6014..19a781542eb38 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -7,54 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
 
 using namespace mlir;
 using namespace mlir::scf;
 
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// Adds the corresponding reaching definition to the terminator of the block if
-/// the terminator is of the provided type.
-template <typename TermTy>
-static void
-updateTerminator(Block *block, Value defaultReachingDef,
-                 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
-  Operation *terminator = block->getTerminator();
-  if (!isa<TermTy>(terminator))
-    return;
-  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
-  if (!blockReachingDef) {
-    // Block is dead code or the region is not using the slot, so we use the
-    // default provided reaching definition.
-    blockReachingDef = defaultReachingDef;
-  }
-  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
-}
-
-/// Creates a shallow copy of an operation with new result types, moving the
-/// regions out of the original operation and deleting the original operation.
-static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
-                                        TypeRange resultTypes) {
-  RewriterBase::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(op);
-  Operation *newOp =
-      mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
-  rewriter.startOpModification(newOp);
-  rewriter.startOpModification(op);
-  for (unsigned int i : llvm::seq(op->getNumRegions()))
-    newOp->getRegion(i).takeBody(op->getRegion(i));
-  rewriter.finalizeOpModification(op);
-  rewriter.finalizeOpModification(newOp);
-
-  SmallVector<Value> replacementValues(newOp->getResults().drop_back());
-  rewriter.replaceAllOpUsesWith(op, replacementValues);
-  rewriter.eraseOp(op);
-  return newOp;
-}
-
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
@@ -80,14 +37,15 @@ Value ExecuteRegionOp::finalizePromotion(
   // Update the yield terminators to return the newly defined reaching
   // definition.
   for (Block &block : getRegion().getBlocks())
-    updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
+    memoryslot::updateTerminator<YieldOp>(&block, reachingDef,
+                                          reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -123,14 +81,15 @@ Value ForOp::finalizePromotion(
 
   // Update the yield terminator to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
+  memoryslot::updateTerminator<YieldOp>(getBody(), reachingDef,
+                                        reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -187,11 +146,11 @@ Value IfOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
-                            reachingAtBlockEnd);
+  memoryslot::updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
+                                        reachingAtBlockEnd);
   if (getElseRegion().hasOneBlock()) {
-    updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
-                              reachingAtBlockEnd);
+    memoryslot::updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
+                                          reachingAtBlockEnd);
   } else {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.createBlock(&getElseRegion());
@@ -202,7 +161,7 @@ Value IfOp::finalizePromotion(
   resultTypes.push_back(slot.elemType);
 
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -234,17 +193,17 @@ Value IndexSwitchOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
-                            reachingAtBlockEnd);
+  memoryslot::updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
+                                        reachingAtBlockEnd);
   for (Region &caseRegion : getCaseRegions())
-    updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
-                              reachingAtBlockEnd);
+    memoryslot::updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
+                                          reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -339,10 +298,10 @@ Value WhileOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<ConditionOp>(&getBefore().back(),
-                                getBefore().getArguments().back(),
-                                reachingAtBlockEnd);
-  updateTerminator<YieldOp>(
+  memoryslot::updateTerminator<ConditionOp>(&getBefore().back(),
+                                            getBefore().getArguments().back(),
+                                            reachingAtBlockEnd);
+  memoryslot::updateTerminator<YieldOp>(
       &getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
@@ -350,6 +309,6 @@ Value WhileOp::finalizePromotion(
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }



More information about the Mlir-commits mailing list