[Mlir-commits] [mlir] [MLIR][Mem2Reg] Extract shared utilities for PromotableRegionOpInterface (PR #188514)

Berke Ates llvmlistbot at llvm.org
Thu Mar 26 02:04:54 PDT 2026


https://github.com/Berke-Ates updated https://github.com/llvm/llvm-project/pull/188514

>From f05e8eb4086c28dfca8252c80881598c176766ab Mon Sep 17 00:00:00 2001
From: Berke-Ates <berke at ates.ch>
Date: Wed, 25 Mar 2026 16:21:05 +0100
Subject: [PATCH] [MLIR][Mem2Reg] Extract shared utilities for
 PromotableRegionOpInterface

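The PromotableRegionOpInterface implementations in SCF use two helpers,
`updateTerminator` and `replaceWithNewResults`, that are likely useful
for other dialects implementing this interface as well. This extracts
them into a common utility library (MLIRMemorySlotUtils) so that
downstream dialects can reuse them directly. The shared
`updateTerminator` is no longer templated on the terminator type, so
callers that need that check now perform it themselves.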
---
 .../mlir/Interfaces/Utils/MemorySlotUtils.h   | 36 ++++++++
 mlir/lib/Dialect/SCF/IR/CMakeLists.txt        |  1 +
 mlir/lib/Dialect/SCF/IR/MemorySlot.cpp        | 84 +++++--------------
 mlir/lib/Interfaces/Utils/CMakeLists.txt      | 15 ++++
 mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp | 49 +++++++++++
 5 files changed, 122 insertions(+), 63 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
 create mode 100644 mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp

diff --git a/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
new file mode 100644
index 0000000000000..66be79b5ad285
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
@@ -0,0 +1,36 @@
+//===- MemorySlotUtils.h - Utilities for MemorySlot interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+#define MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace memoryslot {
+
+/// Appends the definition reaching the end of `block` as the last operand of
+/// its terminator (of any type). If `reachingAtBlockEnd` has no value for the
+/// block (dead code or slot unused in the region), uses `defaultReachingDef`.
+void updateTerminator(Block *block, Value defaultReachingDef,
+                      const DenseMap<Block *, Value> &reachingAtBlockEnd);
+
+/// Replaces `op` with a shallow copy that has `resultTypes`: moves the regions
+/// over, remaps uses of `op` to the copy's leading results, and erases `op`.
+Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                 TypeRange resultTypes);
+
+} // namespace memoryslot
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index fca28c5209e2d..5e7d10e4a9d6a 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRFunctionInterfaces
   MLIRIR
   MLIRLoopLikeInterface
+  MLIRMemorySlotUtils
   MLIRSideEffectInterfaces
   MLIRTensorDialect
   MLIRValueBoundsOpInterface
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index 3d61476df6014..92fc3452d4629 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -7,54 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
 
 using namespace mlir;
 using namespace mlir::scf;
 
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// Adds the corresponding reaching definition to the terminator of the block if
-/// the terminator is of the provided type.
-template <typename TermTy>
-static void
-updateTerminator(Block *block, Value defaultReachingDef,
-                 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
-  Operation *terminator = block->getTerminator();
-  if (!isa<TermTy>(terminator))
-    return;
-  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
-  if (!blockReachingDef) {
-    // Block is dead code or the region is not using the slot, so we use the
-    // default provided reaching definition.
-    blockReachingDef = defaultReachingDef;
-  }
-  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
-}
-
-/// Creates a shallow copy of an operation with new result types, moving the
-/// regions out of the original operation and deleting the original operation.
-static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
-                                        TypeRange resultTypes) {
-  RewriterBase::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(op);
-  Operation *newOp =
-      mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
-  rewriter.startOpModification(newOp);
-  rewriter.startOpModification(op);
-  for (unsigned int i : llvm::seq(op->getNumRegions()))
-    newOp->getRegion(i).takeBody(op->getRegion(i));
-  rewriter.finalizeOpModification(op);
-  rewriter.finalizeOpModification(newOp);
-
-  SmallVector<Value> replacementValues(newOp->getResults().drop_back());
-  rewriter.replaceAllOpUsesWith(op, replacementValues);
-  rewriter.eraseOp(op);
-  return newOp;
-}
-
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
@@ -80,14 +37,15 @@ Value ExecuteRegionOp::finalizePromotion(
   // Update the yield terminators to return the newly defined reaching
   // definition.
   for (Block &block : getRegion().getBlocks())
-    updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
+    if (isa<YieldOp>(block.getTerminator()))
+      memoryslot::updateTerminator(&block, reachingDef, reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -123,14 +81,14 @@ Value ForOp::finalizePromotion(
 
   // Update the yield terminator to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
+  memoryslot::updateTerminator(getBody(), reachingDef, reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -187,11 +145,11 @@ Value IfOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
-                            reachingAtBlockEnd);
+  memoryslot::updateTerminator(&getThenRegion().back(), reachingDef,
+                               reachingAtBlockEnd);
   if (getElseRegion().hasOneBlock()) {
-    updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
-                              reachingAtBlockEnd);
+    memoryslot::updateTerminator(&getElseRegion().back(), reachingDef,
+                                 reachingAtBlockEnd);
   } else {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.createBlock(&getElseRegion());
@@ -202,7 +160,7 @@ Value IfOp::finalizePromotion(
   resultTypes.push_back(slot.elemType);
 
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -234,17 +192,17 @@ Value IndexSwitchOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
-                            reachingAtBlockEnd);
+  memoryslot::updateTerminator(&getDefaultRegion().back(), reachingDef,
+                               reachingAtBlockEnd);
   for (Region &caseRegion : getCaseRegions())
-    updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
-                              reachingAtBlockEnd);
+    memoryslot::updateTerminator(&caseRegion.back(), reachingDef,
+                                 reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -339,10 +297,10 @@ Value WhileOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<ConditionOp>(&getBefore().back(),
-                                getBefore().getArguments().back(),
-                                reachingAtBlockEnd);
-  updateTerminator<YieldOp>(
+  memoryslot::updateTerminator(&getBefore().back(),
+                               getBefore().getArguments().back(),
+                               reachingAtBlockEnd);
+  memoryslot::updateTerminator(
       &getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
@@ -350,6 +308,6 @@ Value WhileOp::finalizePromotion(
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index 8c45f66997427..c722fbf4ece09 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(LLVM_OPTIONAL_SOURCES
+  InferIntRangeCommon.cpp
+  MemorySlotUtils.cpp
+  )
+
 add_mlir_library(MLIRInferIntRangeCommon
     InferIntRangeCommon.cpp
 
@@ -12,3 +17,13 @@ add_mlir_library(MLIRInferIntRangeCommon
     MLIRInferIntRangeInterface
     MLIRIR
 )
+
+add_mlir_library(MLIRMemorySlotUtils
+    MemorySlotUtils.cpp
+
+    ADDITIONAL_HEADER_DIRS
+    ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
+
+    LINK_LIBS PUBLIC
+    MLIRIR
+)
diff --git a/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp b/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
new file mode 100644
index 0000000000000..32a8d898ae3c6
--- /dev/null
+++ b/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
@@ -0,0 +1,49 @@
+//===- MemorySlotUtils.cpp - Utilities for MemorySlot interfaces ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
+
+using namespace mlir;
+
+void mlir::memoryslot::updateTerminator(
+    Block *block, Value defaultReachingDef,
+    const DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  Value blockReachingDef = reachingAtBlockEnd.lookup(block); // Null if absent.
+  if (!blockReachingDef) // E.g. dead code, or the slot is unused here.
+    blockReachingDef = defaultReachingDef;
+  Operation *terminator = block->getTerminator(); // Type is not checked here.
+  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
+Operation *mlir::memoryslot::replaceWithNewResults(RewriterBase &rewriter,
+                                                   Operation *op,
+                                                   TypeRange resultTypes) {
+  RewriterBase::InsertionGuard guard(rewriter); // Restores insertion on exit.
+  rewriter.setInsertionPoint(op); // The copy is created right before `op`.
+  OperationState state(op->getLoc(), op->getName(), op->getOperands(),
+                       resultTypes, op->getAttrs()); // Only results differ.
+  for (unsigned cnt = 0, e = op->getNumRegions(); cnt < e; ++cnt)
+    state.addRegion(); // Empty placeholders; bodies are moved in below.
+  Operation *newOp = rewriter.create(state);
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned i : llvm::seq(op->getNumRegions()))
+    newOp->getRegion(i).takeBody(op->getRegion(i)); // Move, do not clone.
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  rewriter.replaceAllOpUsesWith( // Map old results to the leading new ones.
+      op, newOp->getResults().take_front(op->getNumResults()));
+  rewriter.eraseOp(op);
+  return newOp;
+}



More information about the Mlir-commits mailing list