[Mlir-commits] [mlir] [MLIR][Mem2Reg] Extract shared utilities for PromotableRegionOpInterface (PR #188514)

Berke Ates llvmlistbot at llvm.org
Thu Mar 26 00:01:31 PDT 2026


https://github.com/Berke-Ates updated https://github.com/llvm/llvm-project/pull/188514

>From 68bdae3b2dfc6534f7774794e64f2fb66a5a57d3 Mon Sep 17 00:00:00 2001
From: Berke-Ates <berke at ates.ch>
Date: Wed, 25 Mar 2026 16:21:05 +0100
Subject: [PATCH] [MLIR][Mem2Reg] Extract shared utilities for
 PromotableRegionOpInterface

The SCF PromotableRegionOpInterface implementations use two helpers: one
appends reaching definitions to region terminators, and the other replaces
an op with a copy that has additional result types, moving its regions
over. These helpers are useful for any dialect implementing this
interface, but downstream dialects previously had to reimplement them.
This patch extracts them into a new MLIRMemorySlotUtils library with the
public header mlir/Interfaces/Utils/MemorySlotUtils.h.
---
 .../mlir/Interfaces/Utils/MemorySlotUtils.h   | 48 +++++++++++
 mlir/lib/Dialect/SCF/IR/CMakeLists.txt        |  1 +
 mlir/lib/Dialect/SCF/IR/MemorySlot.cpp        | 85 +++++--------------
 mlir/lib/Interfaces/Utils/CMakeLists.txt      | 15 ++++
 mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp | 39 +++++++++
 5 files changed, 125 insertions(+), 63 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
 create mode 100644 mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp

diff --git a/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
new file mode 100644
index 0000000000000..0de1eafa99fee
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
@@ -0,0 +1,64 @@
+//===- MemorySlotUtils.h - Utilities for MemorySlot interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+#define MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace memoryslot {
+
+/// Appends the reaching definition available at the end of `block` as a new
+/// trailing operand of the block's terminator, provided the terminator is one
+/// of the given op types (`TermTy` or any of `OtherTermTys`). Blocks with any
+/// other terminator are left untouched.
+///
+/// The reaching definition is looked up in `reachingAtBlockEnd`; if `block`
+/// has no entry there, `defaultReachingDef` is appended instead.
+///
+/// The terminator is modified in place, without going through a rewriter.
+template <typename TermTy, typename... OtherTermTys>
+void updateTerminator(
+    Block *block, Value defaultReachingDef,
+    const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  Operation *terminator = block->getTerminator();
+  if (!isa<TermTy, OtherTermTys...>(terminator))
+    return;
+  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
+  if (!blockReachingDef) {
+    // Block is dead code or the region is not using the slot, so we use the
+    // default provided reaching definition.
+    blockReachingDef = defaultReachingDef;
+  }
+  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
+/// Replaces `op` with a copy of itself that has the given `resultTypes` and
+/// returns the new operation.
+///
+/// The copy keeps the operands and attributes of `op`. Its regions are not
+/// cloned: their bodies are moved out of `op` into the new operation. All uses
+/// of the results of `op` are then redirected to the leading results of the
+/// new operation, and `op` is erased.
+///
+/// `resultTypes` is expected to be the result types of `op` followed by one
+/// additional trailing type (e.g. the element type of the promoted slot). The
+/// trailing result of the returned operation has no uses yet.
+Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                 TypeRange resultTypes);
+
+} // namespace memoryslot
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index fca28c5209e2d..5e7d10e4a9d6a 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRFunctionInterfaces
   MLIRIR
   MLIRLoopLikeInterface
+  MLIRMemorySlotUtils
   MLIRSideEffectInterfaces
   MLIRTensorDialect
   MLIRValueBoundsOpInterface
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index 3d61476df6014..19a781542eb38 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -7,54 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
 
 using namespace mlir;
 using namespace mlir::scf;
 
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// Adds the corresponding reaching definition to the terminator of the block if
-/// the terminator is of the provided type.
-template <typename TermTy>
-static void
-updateTerminator(Block *block, Value defaultReachingDef,
-                 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
-  Operation *terminator = block->getTerminator();
-  if (!isa<TermTy>(terminator))
-    return;
-  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
-  if (!blockReachingDef) {
-    // Block is dead code or the region is not using the slot, so we use the
-    // default provided reaching definition.
-    blockReachingDef = defaultReachingDef;
-  }
-  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
-}
-
-/// Creates a shallow copy of an operation with new result types, moving the
-/// regions out of the original operation and deleting the original operation.
-static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
-                                        TypeRange resultTypes) {
-  RewriterBase::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(op);
-  Operation *newOp =
-      mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
-  rewriter.startOpModification(newOp);
-  rewriter.startOpModification(op);
-  for (unsigned int i : llvm::seq(op->getNumRegions()))
-    newOp->getRegion(i).takeBody(op->getRegion(i));
-  rewriter.finalizeOpModification(op);
-  rewriter.finalizeOpModification(newOp);
-
-  SmallVector<Value> replacementValues(newOp->getResults().drop_back());
-  rewriter.replaceAllOpUsesWith(op, replacementValues);
-  rewriter.eraseOp(op);
-  return newOp;
-}
-
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
@@ -80,14 +37,15 @@ Value ExecuteRegionOp::finalizePromotion(
   // Update the yield terminators to return the newly defined reaching
   // definition.
   for (Block &block : getRegion().getBlocks())
-    updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
+    memoryslot::updateTerminator<YieldOp>(&block, reachingDef,
+                                          reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -123,14 +81,15 @@ Value ForOp::finalizePromotion(
 
   // Update the yield terminator to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
+  memoryslot::updateTerminator<YieldOp>(getBody(), reachingDef,
+                                        reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -187,11 +146,11 @@ Value IfOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
-                            reachingAtBlockEnd);
+  memoryslot::updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
+                                        reachingAtBlockEnd);
   if (getElseRegion().hasOneBlock()) {
-    updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
-                              reachingAtBlockEnd);
+    memoryslot::updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
+                                          reachingAtBlockEnd);
   } else {
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.createBlock(&getElseRegion());
@@ -202,7 +161,7 @@ Value IfOp::finalizePromotion(
   resultTypes.push_back(slot.elemType);
 
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -234,17 +193,17 @@ Value IndexSwitchOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
-                            reachingAtBlockEnd);
+  memoryslot::updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
+                                        reachingAtBlockEnd);
   for (Region &caseRegion : getCaseRegions())
-    updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
-                              reachingAtBlockEnd);
+    memoryslot::updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
+                                          reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
   resultTypes.push_back(slot.elemType);
 
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
 
@@ -339,10 +298,10 @@ Value WhileOp::finalizePromotion(
 
   // Update the yield terminators to return the newly defined reaching
   // definition.
-  updateTerminator<ConditionOp>(&getBefore().back(),
-                                getBefore().getArguments().back(),
-                                reachingAtBlockEnd);
-  updateTerminator<YieldOp>(
+  memoryslot::updateTerminator<ConditionOp>(&getBefore().back(),
+                                            getBefore().getArguments().back(),
+                                            reachingAtBlockEnd);
+  memoryslot::updateTerminator<YieldOp>(
       &getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
 
   SmallVector<Type> resultTypes(getResultTypes());
@@ -350,6 +309,6 @@ Value WhileOp::finalizePromotion(
 
   IRRewriter rewriter(builder);
   Operation *newOp =
-      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+      memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
   return newOp->getResults().back();
 }
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index 8c45f66997427..c722fbf4ece09 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(LLVM_OPTIONAL_SOURCES # Sources of the libraries defined below.
+  InferIntRangeCommon.cpp # MLIRInferIntRangeCommon
+  MemorySlotUtils.cpp     # MLIRMemorySlotUtils
+  )
+
 add_mlir_library(MLIRInferIntRangeCommon
     InferIntRangeCommon.cpp
 
@@ -12,3 +17,13 @@ add_mlir_library(MLIRInferIntRangeCommon
     MLIRInferIntRangeInterface
     MLIRIR
 )
+
+add_mlir_library(MLIRMemorySlotUtils
+    MemorySlotUtils.cpp # Helpers for implementing MemorySlot interfaces.
+
+    ADDITIONAL_HEADER_DIRS
+    ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils # MemorySlotUtils.h
+
+    LINK_LIBS PUBLIC
+    MLIRIR
+)
diff --git a/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp b/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
new file mode 100644
index 0000000000000..a891897b5870b
--- /dev/null
+++ b/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
@@ -0,0 +1,56 @@
+//===- MemorySlotUtils.cpp - Utilities for MemorySlot interfaces ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
+
+using namespace mlir;
+
+Operation *mlir::memoryslot::replaceWithNewResults(RewriterBase &rewriter,
+                                                   Operation *op,
+                                                   TypeRange resultTypes) {
+  // Uses of the original results are redirected to the leading results of the
+  // new operation, so `resultTypes` is expected to start with the original
+  // result types followed by the extra ones, e.g. the promoted slot's type.
+  unsigned numOriginalResults = op->getNumResults();
+  assert(resultTypes.size() >= numOriginalResults &&
+         "expected at least as many new result types as original results");
+
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Create a copy of `op` with the same operands and attributes, the new
+  // result types and empty regions. The region bodies are moved over below
+  // instead of being cloned. NOTE: successors are not carried over, and only
+  // what `getAttrs()` reports is copied.
+  OperationState state(op->getLoc(), op->getName(), op->getOperands(),
+                       resultTypes, op->getAttrs());
+  unsigned numRegions = op->getNumRegions();
+  for (unsigned i = 0; i < numRegions; ++i)
+    state.addRegion();
+  Operation *newOp = rewriter.create(state);
+
+  // Moving the bodies mutates both operations in place; report that to the
+  // rewriter.
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned i = 0; i < numRegions; ++i)
+    newOp->getRegion(i).takeBody(op->getRegion(i));
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  // Any additional trailing results of the new operation have no uses yet.
+  rewriter.replaceAllOpUsesWith(
+      op, newOp->getResults().take_front(numOriginalResults));
+  rewriter.eraseOp(op);
+  return newOp;
+}



More information about the Mlir-commits mailing list