[Mlir-commits] [mlir] [MLIR][Mem2Reg] Extract shared utilities for PromotableRegionOpInterface (PR #188514)
Berke Ates
llvmlistbot at llvm.org
Thu Mar 26 00:10:21 PDT 2026
https://github.com/Berke-Ates updated https://github.com/llvm/llvm-project/pull/188514
>From 2c7fb33c3711d23beacb7857c0db7c9128fb3b2c Mon Sep 17 00:00:00 2001
From: Berke-Ates <berke at ates.ch>
Date: Wed, 25 Mar 2026 16:21:05 +0100
Subject: [PATCH] [MLIR][Mem2Reg] Extract shared utilities for
PromotableRegionOpInterface
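The SCF PromotableRegionOpInterface implementations use two helpers: one that
appends reaching definitions to region terminators, and one that replaces an op
with a copy carrying additional result types while preserving its regions.
These helpers are useful for any dialect implementing this interface, but
downstream dialects previously had to reimplement them. This patch extracts
them into a common utility library (MLIRMemorySlotUtils, header
mlir/Interfaces/Utils/MemorySlotUtils.h).
Note that updateTerminator is no longer templated on the terminator type:
callers that must filter by terminator kind (e.g. scf.execute_region) now
perform the isa<> check themselves. replaceWithNewResults inlines the former
cloneWithoutRegions call so the library does not depend on the Dialect utils.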
---
.../mlir/Interfaces/Utils/MemorySlotUtils.h | 36 ++++++++
mlir/lib/Dialect/SCF/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/SCF/IR/MemorySlot.cpp | 84 +++++--------------
mlir/lib/Interfaces/Utils/CMakeLists.txt | 15 ++++
mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp | 49 +++++++++++
5 files changed, 122 insertions(+), 63 deletions(-)
create mode 100644 mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
create mode 100644 mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
diff --git a/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
new file mode 100644
index 0000000000000..f057e419a209f
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
@@ -0,0 +1,36 @@
+//===- MemorySlotUtils.h - Utilities for MemorySlot interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+#define MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace memoryslot {
+
+/// Appends the reaching definition for \p block as a trailing operand of the
+/// block's terminator. If \p reachingAtBlockEnd has no entry for \p block
+/// (e.g. the block is dead code or the region does not use the slot),
+/// \p defaultReachingDef is appended instead.
+///
+/// The block must end in a terminator. The kind of the terminator is not
+/// checked: callers are responsible for only passing blocks whose terminator
+/// is expected to forward the reaching definition (e.g. a yield-like op).
+void updateTerminator(Block *block, Value defaultReachingDef,
+                      const DenseMap<Block *, Value> &reachingAtBlockEnd);
+
+/// Creates a shallow copy of \p op whose result types are \p resultTypes,
+/// moves all regions of \p op into the copy, redirects all uses of \p op to
+/// the leading results of the copy, and erases \p op.
+///
+/// \p resultTypes is expected to be the result types of \p op with a trailing
+/// type appended (e.g. the element type of the promoted slot). The copy is
+/// created right before \p op; the rewriter's insertion point is restored
+/// before returning.
+///
+/// Returns the newly created operation.
+Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                 TypeRange resultTypes);
+
+} // namespace memoryslot
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index fca28c5209e2d..5e7d10e4a9d6a 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRFunctionInterfaces
MLIRIR
MLIRLoopLikeInterface
+ MLIRMemorySlotUtils
MLIRSideEffectInterfaces
MLIRTensorDialect
MLIRValueBoundsOpInterface
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index 3d61476df6014..92fc3452d4629 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -7,54 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
using namespace mlir;
using namespace mlir::scf;
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// Adds the corresponding reaching definition to the terminator of the block if
-/// the terminator is of the provided type.
-template <typename TermTy>
-static void
-updateTerminator(Block *block, Value defaultReachingDef,
- const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
- Operation *terminator = block->getTerminator();
- if (!isa<TermTy>(terminator))
- return;
- Value blockReachingDef = reachingAtBlockEnd.lookup(block);
- if (!blockReachingDef) {
- // Block is dead code or the region is not using the slot, so we use the
- // default provided reaching definition.
- blockReachingDef = defaultReachingDef;
- }
- terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
-}
-
-/// Creates a shallow copy of an operation with new result types, moving the
-/// regions out of the original operation and deleting the original operation.
-static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
- TypeRange resultTypes) {
- RewriterBase::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(op);
- Operation *newOp =
- mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
- rewriter.startOpModification(newOp);
- rewriter.startOpModification(op);
- for (unsigned int i : llvm::seq(op->getNumRegions()))
- newOp->getRegion(i).takeBody(op->getRegion(i));
- rewriter.finalizeOpModification(op);
- rewriter.finalizeOpModification(newOp);
-
- SmallVector<Value> replacementValues(newOp->getResults().drop_back());
- rewriter.replaceAllOpUsesWith(op, replacementValues);
- rewriter.eraseOp(op);
- return newOp;
-}
-
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
@@ -80,14 +37,15 @@ Value ExecuteRegionOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
for (Block &block : getRegion().getBlocks())
- updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
+ if (isa<YieldOp>(block.getTerminator()))
+ memoryslot::updateTerminator(&block, reachingDef, reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
IRRewriter rewriter(builder);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@@ -123,14 +81,14 @@ Value ForOp::finalizePromotion(
// Update the yield terminator to return the newly defined reaching
// definition.
- updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
+ memoryslot::updateTerminator(getBody(), reachingDef, reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
IRRewriter rewriter(builder);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@@ -187,11 +145,11 @@ Value IfOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
- updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
- reachingAtBlockEnd);
+ memoryslot::updateTerminator(&getThenRegion().back(), reachingDef,
+ reachingAtBlockEnd);
if (getElseRegion().hasOneBlock()) {
- updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
- reachingAtBlockEnd);
+ memoryslot::updateTerminator(&getElseRegion().back(), reachingDef,
+ reachingAtBlockEnd);
} else {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&getElseRegion());
@@ -202,7 +160,7 @@ Value IfOp::finalizePromotion(
resultTypes.push_back(slot.elemType);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@@ -234,17 +192,17 @@ Value IndexSwitchOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
- updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
- reachingAtBlockEnd);
+ memoryslot::updateTerminator(&getDefaultRegion().back(), reachingDef,
+ reachingAtBlockEnd);
for (Region &caseRegion : getCaseRegions())
- updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
- reachingAtBlockEnd);
+ memoryslot::updateTerminator(&caseRegion.back(), reachingDef,
+ reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@@ -339,10 +297,10 @@ Value WhileOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
- updateTerminator<ConditionOp>(&getBefore().back(),
- getBefore().getArguments().back(),
- reachingAtBlockEnd);
- updateTerminator<YieldOp>(
+ memoryslot::updateTerminator(&getBefore().back(),
+ getBefore().getArguments().back(),
+ reachingAtBlockEnd);
+ memoryslot::updateTerminator(
&getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
@@ -350,6 +308,6 @@ Value WhileOp::finalizePromotion(
IRRewriter rewriter(builder);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index 8c45f66997427..c722fbf4ece09 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(LLVM_OPTIONAL_SOURCES
+ InferIntRangeCommon.cpp
+ MemorySlotUtils.cpp
+ )
+
add_mlir_library(MLIRInferIntRangeCommon
InferIntRangeCommon.cpp
@@ -12,3 +17,13 @@ add_mlir_library(MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
MLIRIR
)
+
+# Helpers shared by implementations of PromotableRegionOpInterface; linked by
+# the SCF dialect (MLIRSCFDialect).
+add_mlir_library(MLIRMemorySlotUtils
+ MemorySlotUtils.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+)
diff --git a/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp b/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
new file mode 100644
index 0000000000000..1ae389df219fe
--- /dev/null
+++ b/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
@@ -0,0 +1,49 @@
+//===- MemorySlotUtils.cpp - Utilities for MemorySlot interfaces ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
+
+using namespace mlir;
+
+void mlir::memoryslot::updateTerminator(
+    Block *block, Value defaultReachingDef,
+    const DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  // A block without a recorded definition is either dead code or belongs to a
+  // region that never touches the slot; forward the caller's default instead.
+  Value recorded = reachingAtBlockEnd.lookup(block);
+  Value toForward = recorded ? recorded : defaultReachingDef;
+
+  // Append the value after all operands the terminator already carries.
+  Operation *terminator = block->getTerminator();
+  unsigned endPos = terminator->getNumOperands();
+  terminator->insertOperands(endPos, {toForward});
+}
+
+Operation *mlir::memoryslot::replaceWithNewResults(RewriterBase &rewriter,
+                                                   Operation *op,
+                                                   TypeRange resultTypes) {
+  // The leading results of the new op take over the uses of `op`, so the new
+  // result list must cover at least all of the original results.
+  unsigned numOldResults = op->getNumResults();
+  assert(resultTypes.size() >= numOldResults &&
+         "expected the new result types to extend the original ones");
+
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Build a copy of `op` with empty regions (the equivalent of
+  // `cloneWithoutRegions`, inlined to avoid depending on the Dialect
+  // utilities). Successors are forwarded too so the copy is complete.
+  OperationState state(op->getLoc(), op->getName(), op->getOperands(),
+                       resultTypes, op->getAttrs());
+  state.addSuccessors(op->getSuccessors());
+  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+    state.addRegion();
+  Operation *newOp = rewriter.create(state);
+
+  // Move the region bodies over, notifying the rewriter about the in-place
+  // modification of both operations.
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned i : llvm::seq(op->getNumRegions()))
+    newOp->getRegion(i).takeBody(op->getRegion(i));
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  // Redirect uses of the original results to the matching leading results of
+  // the new op; the additional trailing results start out unused.
+  SmallVector<Value> replacementValues(
+      newOp->getResults().take_front(numOldResults));
+  rewriter.replaceAllOpUsesWith(op, replacementValues);
+  rewriter.eraseOp(op);
+  return newOp;
+}
More information about the Mlir-commits
mailing list