[Mlir-commits] [mlir] [MLIR][Mem2Reg] Extract shared utilities for PromotableRegionOpInterface (PR #188514)
Berke Ates
llvmlistbot at llvm.org
Thu Mar 26 02:04:54 PDT 2026
https://github.com/Berke-Ates updated https://github.com/llvm/llvm-project/pull/188514
>From f05e8eb4086c28dfca8252c80881598c176766ab Mon Sep 17 00:00:00 2001
From: Berke-Ates <berke at ates.ch>
Date: Wed, 25 Mar 2026 16:21:05 +0100
Subject: [PATCH] [MLIR][Mem2Reg] Extract shared utilities for
PromotableRegionOpInterface
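The SCF implementations of PromotableRegionOpInterface use two helpers,
`updateTerminator` and `replaceWithNewResults`, that are likely useful for
other dialects implementing this interface as well. This extracts them into a
common utility library (MLIRMemorySlotUtils) so that downstream dialects can
reuse them directly. As part of this, `updateTerminator` no longer takes the
terminator type as a template parameter; callers now check the terminator kind
themselves where needed.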
---
.../mlir/Interfaces/Utils/MemorySlotUtils.h | 36 ++++++++
mlir/lib/Dialect/SCF/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/SCF/IR/MemorySlot.cpp | 84 +++++--------------
mlir/lib/Interfaces/Utils/CMakeLists.txt | 15 ++++
mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp | 49 +++++++++++
5 files changed, 122 insertions(+), 63 deletions(-)
create mode 100644 mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
create mode 100644 mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
diff --git a/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
new file mode 100644
index 0000000000000..66be79b5ad285
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/Utils/MemorySlotUtils.h
@@ -0,0 +1,48 @@
+//===- MemorySlotUtils.h - Utilities for MemorySlot interfaces --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+#define MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace memoryslot {
+
+/// Appends the reaching definition for `block` as a trailing operand of its
+/// terminator. If the block has no entry in `reachingAtBlockEnd` (e.g. dead
+/// code, or the region does not use the slot), `defaultReachingDef` is used
+/// instead.
+///
+/// The terminator is modified unconditionally and without notifying any
+/// rewriter: callers must only pass blocks whose terminator is meant to
+/// forward the promoted value to the parent operation (e.g. `scf.yield` or
+/// `scf.condition`).
+void updateTerminator(Block *block, Value defaultReachingDef,
+                      const DenseMap<Block *, Value> &reachingAtBlockEnd);
+
+/// Creates a shallow copy of `op` with `resultTypes` as its result types,
+/// moving the regions out of `op` into the copy and erasing `op` afterwards.
+/// Operands and attributes are carried over; successors are not.
+///
+/// Uses of the results of `op` are replaced with the leading
+/// `op->getNumResults()` results of the new operation, so `resultTypes` is
+/// expected to start with the result types of `op` (additional results, e.g.
+/// the promoted value, are appended at the end). Returns the new operation.
+Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                 TypeRange resultTypes);
+
+} // namespace memoryslot
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_UTILS_MEMORYSLOTUTILS_H
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index fca28c5209e2d..5e7d10e4a9d6a 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRFunctionInterfaces
MLIRIR
MLIRLoopLikeInterface
+ MLIRMemorySlotUtils
MLIRSideEffectInterfaces
MLIRTensorDialect
MLIRValueBoundsOpInterface
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index 3d61476df6014..92fc3452d4629 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -7,54 +7,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
using namespace mlir;
using namespace mlir::scf;
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// Adds the corresponding reaching definition to the terminator of the block if
-/// the terminator is of the provided type.
-template <typename TermTy>
-static void
-updateTerminator(Block *block, Value defaultReachingDef,
- const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
- Operation *terminator = block->getTerminator();
- if (!isa<TermTy>(terminator))
- return;
- Value blockReachingDef = reachingAtBlockEnd.lookup(block);
- if (!blockReachingDef) {
- // Block is dead code or the region is not using the slot, so we use the
- // default provided reaching definition.
- blockReachingDef = defaultReachingDef;
- }
- terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
-}
-
-/// Creates a shallow copy of an operation with new result types, moving the
-/// regions out of the original operation and deleting the original operation.
-static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
- TypeRange resultTypes) {
- RewriterBase::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(op);
- Operation *newOp =
- mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
- rewriter.startOpModification(newOp);
- rewriter.startOpModification(op);
- for (unsigned int i : llvm::seq(op->getNumRegions()))
- newOp->getRegion(i).takeBody(op->getRegion(i));
- rewriter.finalizeOpModification(op);
- rewriter.finalizeOpModification(newOp);
-
- SmallVector<Value> replacementValues(newOp->getResults().drop_back());
- rewriter.replaceAllOpUsesWith(op, replacementValues);
- rewriter.eraseOp(op);
- return newOp;
-}
-
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
@@ -80,14 +37,15 @@ Value ExecuteRegionOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
for (Block &block : getRegion().getBlocks())
- updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
+ if (isa<YieldOp>(block.getTerminator()))
+ memoryslot::updateTerminator(&block, reachingDef, reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
IRRewriter rewriter(builder);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@@ -123,14 +81,14 @@ Value ForOp::finalizePromotion(
// Update the yield terminator to return the newly defined reaching
// definition.
- updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
+ memoryslot::updateTerminator(getBody(), reachingDef, reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
IRRewriter rewriter(builder);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@@ -187,11 +145,11 @@ Value IfOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
- updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
- reachingAtBlockEnd);
+ memoryslot::updateTerminator(&getThenRegion().back(), reachingDef,
+ reachingAtBlockEnd);
if (getElseRegion().hasOneBlock()) {
- updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
- reachingAtBlockEnd);
+ memoryslot::updateTerminator(&getElseRegion().back(), reachingDef,
+ reachingAtBlockEnd);
} else {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&getElseRegion());
@@ -202,7 +160,7 @@ Value IfOp::finalizePromotion(
resultTypes.push_back(slot.elemType);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@@ -234,17 +192,17 @@ Value IndexSwitchOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
- updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
- reachingAtBlockEnd);
+ memoryslot::updateTerminator(&getDefaultRegion().back(), reachingDef,
+ reachingAtBlockEnd);
for (Region &caseRegion : getCaseRegions())
- updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
- reachingAtBlockEnd);
+ memoryslot::updateTerminator(&caseRegion.back(), reachingDef,
+ reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
@@ -339,10 +297,10 @@ Value WhileOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
- updateTerminator<ConditionOp>(&getBefore().back(),
- getBefore().getArguments().back(),
- reachingAtBlockEnd);
- updateTerminator<YieldOp>(
+ memoryslot::updateTerminator(&getBefore().back(),
+ getBefore().getArguments().back(),
+ reachingAtBlockEnd);
+ memoryslot::updateTerminator(
&getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
@@ -350,6 +308,6 @@ Value WhileOp::finalizePromotion(
IRRewriter rewriter(builder);
Operation *newOp =
- replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
return newOp->getResults().back();
}
diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt
index 8c45f66997427..c722fbf4ece09 100644
--- a/mlir/lib/Interfaces/Utils/CMakeLists.txt
+++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt
@@ -1,3 +1,8 @@
+set(LLVM_OPTIONAL_SOURCES
+ InferIntRangeCommon.cpp
+ MemorySlotUtils.cpp
+ )
+
add_mlir_library(MLIRInferIntRangeCommon
InferIntRangeCommon.cpp
@@ -12,3 +17,13 @@ add_mlir_library(MLIRInferIntRangeCommon
MLIRInferIntRangeInterface
MLIRIR
)
+
+add_mlir_library(MLIRMemorySlotUtils
+ MemorySlotUtils.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+)
diff --git a/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp b/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
new file mode 100644
index 0000000000000..32a8d898ae3c6
--- /dev/null
+++ b/mlir/lib/Interfaces/Utils/MemorySlotUtils.cpp
@@ -0,0 +1,68 @@
+//===- MemorySlotUtils.cpp - Utilities for MemorySlot interfaces ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements common utilities for implementing MemorySlot interfaces,
+// in particular PromotableRegionOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/Utils/MemorySlotUtils.h"
+
+using namespace mlir;
+
+void mlir::memoryslot::updateTerminator(
+    Block *block, Value defaultReachingDef,
+    const DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
+  if (!blockReachingDef) {
+    // The block is dead code or the region does not use the slot, so fall
+    // back to the provided default reaching definition.
+    blockReachingDef = defaultReachingDef;
+  }
+  // Appending a null operand would produce invalid IR; fail fast here while
+  // the offending block is still known.
+  assert(blockReachingDef &&
+         "expected a reaching definition or a non-null default");
+  Operation *terminator = block->getTerminator();
+  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
+Operation *mlir::memoryslot::replaceWithNewResults(RewriterBase &rewriter,
+                                                   Operation *op,
+                                                   TypeRange resultTypes) {
+  // Uses of `op` are redirected to the leading results of the new operation,
+  // so it must provide at least as many results. Check this before mutating
+  // any IR.
+  assert(resultTypes.size() >= op->getNumResults() &&
+         "expected at least as many result types as `op` has results");
+
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Build a copy of `op` with the new result types and empty regions.
+  OperationState state(op->getLoc(), op->getName(), op->getOperands(),
+                       resultTypes, op->getAttrs());
+  for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
+    state.addRegion();
+  Operation *newOp = rewriter.create(state);
+
+  // Move the region bodies over, notifying the rewriter that both operations
+  // are modified in place.
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned i : llvm::seq(op->getNumRegions()))
+    newOp->getRegion(i).takeBody(op->getRegion(i));
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  // Any results beyond the original ones are left for the caller to use.
+  rewriter.replaceAllOpUsesWith(
+      op, newOp->getResults().take_front(op->getNumResults()));
+  rewriter.eraseOp(op);
+  return newOp;
+}
More information about the Mlir-commits
mailing list