[Mlir-commits] [mlir] [MLIR][Mem2Reg] Add support for region control flow and SCF (PR #185036)
llvmlistbot at llvm.org
Fri Mar 20 19:31:15 PDT 2026
https://github.com/tdegioanni-nvidia updated https://github.com/llvm/llvm-project/pull/185036
From 2292a90ba74ccc1d25518833dbb5973c1d32266f Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Mon, 2 Mar 2026 19:10:05 +0100
Subject: [PATCH 01/15] mem2reg for region control flow
---
.../mlir/Interfaces/MemorySlotInterfaces.td | 98 ++++
mlir/lib/Transforms/Mem2Reg.cpp | 487 ++++++++++++++----
2 files changed, 473 insertions(+), 112 deletions(-)
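A minimal standalone sketch, not taken from the patch, of the merge-point computation described in the Mem2Reg.cpp header comment: merge points are the iterated dominance frontier (IDF) of the blocks that store to the slot, restricted to blocks where the slot is live-in. The patch now runs this once per region. Everything here is hypothetical: `ToyCfg`, `dominanceFrontiers` and `mergePoints` are illustration-only names, immediate dominators are supplied by hand rather than coming from `DominanceInfo`, and frontiers are computed with the Cooper-Harvey-Kennedy "runner" walk instead of LLVM's `IDFCalculator`.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical toy CFG for a single region: blocks are numbered 0..n-1 and
// block 0 is the entry block. `preds[b]` lists the predecessors of block b and
// `idom[b]` is its immediate dominator (-1 for the entry block). In MLIR this
// information would come from the region's dominator tree.
struct ToyCfg {
  std::vector<std::vector<int>> preds;
  std::vector<int> idom;
};

// Dominance frontiers via the Cooper-Harvey-Kennedy "runner" walk: a join
// block b belongs to DF(x) for every x on the dominator-tree path from a
// predecessor of b up to, but excluding, idom(b).
std::vector<std::set<int>> dominanceFrontiers(const ToyCfg &cfg) {
  std::vector<std::set<int>> df(cfg.preds.size());
  for (int b = 0; b < static_cast<int>(cfg.preds.size()); ++b) {
    if (cfg.preds[b].size() < 2)
      continue;
    for (int pred : cfg.preds[b]) {
      int runner = pred;
      while (runner != cfg.idom[b]) {
        df[runner].insert(b);
        runner = cfg.idom[runner];
      }
    }
  }
  return df;
}

// Merge points: the iterated dominance frontier of the defining (storing)
// blocks, pruned to the blocks where the slot value is live-in. A block that
// receives a new block argument acts as a definition itself, so the walk
// continues from it.
std::set<int> mergePoints(const ToyCfg &cfg, const std::set<int> &defining,
                          const std::set<int> &liveIn) {
  std::vector<std::set<int>> df = dominanceFrontiers(cfg);
  std::set<int> result;
  std::set<int> visited(defining.begin(), defining.end());
  std::vector<int> worklist(defining.begin(), defining.end());
  while (!worklist.empty()) {
    int block = worklist.back();
    worklist.pop_back();
    for (int frontier : df[block]) {
      if (!liveIn.count(frontier))
        continue;
      result.insert(frontier);
      if (visited.insert(frontier).second)
        worklist.push_back(frontier);
    }
  }
  return result;
}
```

In a diamond `0 -> {1, 2} -> 3` with a store only in block 1, block 3 is the single merge point, but only if the slot is live-in there; this is why the patch passes both `definingBlocks` and `slotLiveIn` to `computeMergePoints`. A region operation whose nested regions store to the slot marks its own block as defining in the parent region (the `definingBlocks` loop in `computeInfo`), which in this toy simply means adding that block to `defining`.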
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index fbce2fa1d043d..b09d88bd171fd 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -263,6 +263,104 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
];
}
+def PromotableRegionOpInterface
+ : OpInterface<"PromotableRegionOpInterface"> {
+ let description = [{
+ Describes an operation for which memory slots can be promoted to SSA values
+ within the operation's regions.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns true when the provided region of the operation can be analyzed
+ for promotion. The provided region must be one of the operation's
+ regions.
+ The `hasValueStores` flag indicates whether the region contains
+ store-like operations that write to the memory slot.
+ }], "bool", "isRegionPromotable",
+ (ins
+ "const ::mlir::MemorySlot &":$slot,
+ "::mlir::Region *":$region,
+ "bool":$hasValueStores
+ )
+ >,
+ InterfaceMethod<[{
+ Specifies which operations should be marked as live-in with respect to
+ the value of the provided slot assuming the provided region is live-in.
+
+ In other words, it adds to `operationsLiveIn` the operations that
+ immediately pass control to the region without storing to the slot.
+ }], "void", "propagateLiveIn",
+ (ins
+ "const ::mlir::MemorySlot &":$slot,
+ "::mlir::Region *":$regionLiveIn,
+ "::mlir::SmallPtrSetImpl<::mlir::Operation *> &":$operationsLiveIn
+ )
+ >,
+ InterfaceMethod<[{
+ Called before processing the nested regions in the operation.
+
+ Based on the `reachingDef` value representing the value in the memory
+ slot at the entry into the operation, `setupPromotion` fills in
+ `regionsToProcess` with the reaching definition at the entry of
+ all its promotable regions.
+
+ `setupPromotion` is allowed to mutate
+ the operation in place, including its nested regions, but cannot
+ delete existing operations or modify successor-bearing terminators.
+ Other mutations are not allowed.
+
+ The `hasValueStores` flag indicates whether the regions contain
+ `store`-like operations that write to the memory slot. This field can be
+ used to reduce the amount of book-keeping required to track the reaching
+ definitions, but it is always correct to treat it as true.
+ }], "void", "setupPromotion",
+ (ins
+ "const ::mlir::MemorySlot &":$slot,
+ "::mlir::Value":$reachingDef,
+ "bool":$hasValueStores,
+ "::llvm::SmallMapVector<::mlir::Region *, ::mlir::Value, 2> &":$regionsToProcess
+ )
+ >,
+ InterfaceMethod<[{
+ Called after promotion has been completed in all the relevant regions.
+
+ Returns the new reaching definition at the exit of the operation. For
+ this purpose, it is allowed to mutate the operation using the provided
+ `builder`, along with its region contents. However, all blocks within
+ the existing regions must remain valid and no new blocks may be added.
+ As a result, the operation is allowed to be cloned and replaced only
+ if its region content is moved from the original operation and not
+ copied. Operations with an effect on the value of the slot must not
+ change said effect (for example, new control flow that could change
+ reaching definitions for a block is not allowed).
+
+ The `entryReachingDef` is the reaching definition at the entry of the
+ region operation.
+
+ The `reachingAtBlockEnd` map contains the reaching definitions after all
+ the terminators within the regions of the operation. If a block of the
+ region is not present in the map, it is either dead code or within a
+ region that does not interact with the value of the slot.
+
+ The `hasValueStores` flag indicates whether the regions contain
+ `store`-like operations that write to the memory slot. This field can be
+ used to reduce the amount of book-keeping required to track the reaching
+ definitions, but it is always correct to treat it as true.
+ }],
+ "::mlir::Value", "finalizePromotion",
+ (ins
+ "const ::mlir::MemorySlot &":$slot,
+ "::mlir::Value":$entryReachingDef,
+ "bool":$hasValueStores,
+ "::llvm::DenseMap<::mlir::Block *, ::mlir::Value> &":$reachingAtBlockEnd,
+ "::mlir::OpBuilder &":$builder
+ )
+ >,
+ ];
+}
+
def DestructurableAllocationOpInterface
: OpInterface<"DestructurableAllocationOpInterface"> {
let description = [{
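A minimal standalone sketch, not taken from the patch, of the contract that `PromotableRegionOpInterface` describes for `setupPromotion` and `finalizePromotion`, on an `scf.if`-like operation with two single-block regions. It does not use MLIR and is not the SCF implementation in this PR; everything in it is hypothetical: values are plain SSA-name strings, regions are identified by index (so the block-keyed `reachingAtBlockEnd` map becomes a region-keyed map), and "adding a result" is modelled as appending each region's final definition to that region's yield operands and returning a fresh name.

```cpp
#include <array>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an SSA value: just its printed name.
using ToyValue = std::string;

// Toy model of an `scf.if`-like operation with two single-block regions
// (index 0 = "then", index 1 = "else"). None of this is MLIR API; it only
// mirrors the shape of the setupPromotion / finalizePromotion contract.
struct ToyIfOp {
  // Operands appended to each region's terminator during finalization.
  std::array<std::vector<ToyValue>, 2> yieldOperands;
  // Number of results the operation has after finalization.
  int numResults = 0;

  // setupPromotion: control enters both regions without any store, so the
  // definition reaching the operation also reaches the entry of each region.
  void setupPromotion(const ToyValue &reachingDef,
                      std::map<int, ToyValue> &regionsToProcess) {
    regionsToProcess[0] = reachingDef;
    regionsToProcess[1] = reachingDef;
  }

  // finalizePromotion: given the definition reaching the end of each region,
  // return the definition reaching the point right after the operation.
  ToyValue finalizePromotion(const ToyValue &entryReachingDef,
                             bool hasValueStores,
                             const std::map<int, ToyValue> &reachingAtBlockEnd) {
    // Without stores in the regions, the slot still holds the entry value.
    if (!hasValueStores)
      return entryReachingDef;
    // Otherwise, each region yields its final definition and a new result
    // carries the merged value. In this toy, a region missing from the map is
    // treated as not touching the slot, so it forwards the entry definition.
    for (int region = 0; region < 2; ++region) {
      auto it = reachingAtBlockEnd.find(region);
      yieldOperands[region].push_back(
          it == reachingAtBlockEnd.end() ? entryReachingDef : it->second);
    }
    return "%if.result" + std::to_string(numResults++);
  }
};
```

With `hasValueStores` false the toy returns the entry definition unchanged, which is the book-keeping shortcut the flag's documentation mentions. Treating the flag as always true stays correct here: it only costs an extra result whose yields forward the entry value.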
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index b3057129fb9fd..26ab65c68b259 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -52,8 +52,19 @@ using namespace mlir;
/// this, the value stored can be well defined at block boundaries, allowing
/// the propagation of replacement through blocks.
///
-/// This pass computes this transformation in four main steps. The two first
-/// steps are performed during an analysis phase that does not mutate IR.
+/// Regions are handled in the transformation through an interface that
+/// expresses the behavior of the allocation value at the edges of the
+/// regions: given the definition reaching the region operation, the
+/// operation specifies the reaching definition at the entry of each of its
+/// regions (potentially mutating itself, for example to add region
+/// arguments). Likewise, given the reaching definitions at the end of the
+/// blocks in its regions, the region operation provides the reaching
+/// definition right after itself.
+///
+/// This pass computes this transformation in two main phases: an analysis
+/// phase that does not mutate IR, and a transformation phase where mutation
+/// happens. These phases are handled by the `MemorySlotPromotionAnalyzer`
+/// and `MemorySlotPromoter` classes, respectively.
///
/// The two steps of the analysis phase are the following:
/// - A first step computes the list of operations that transitively use the
@@ -62,36 +73,43 @@ using namespace mlir;
/// the user or deleting it. Naturally, direct uses of the slot must be removed.
/// Sometimes additional uses must also be removed: this is notably the case
/// when a direct user of the slot cannot rewire its use and must delete itself,
-/// and thus must make its users no longer use it. If any of those uses cannot
-/// be removed by their users in any way, promotion cannot continue: this is
-/// decided at this step.
+/// and thus must make its users no longer use it. If the allocation is used in
+/// nested regions, this step also checks that the region operations provide
+/// the right interface to analyze the value of the allocation at the edges of
+/// their regions. If any of those constraints cannot be satisfied, promotion
+/// cannot continue: this is decided at this step.
/// - A second step computes the list of blocks where a block argument will be
/// needed ("merge points") without mutating the IR. These blocks are the blocks
/// leading to a definition clash between two predecessors. Such blocks happen
/// to be the Iterated Dominance Frontier (IDF) of the set of blocks containing
-/// a store, as they represent the point where a clear defining dominator stops
+/// a store, as they represent the points where a clear defining dominator stops
/// existing. Computing this information in advance allows making sure the
/// terminators that will forward values are capable of doing so (inability to
/// do so aborts promotion at this step).
///
-/// At this point, promotion is guaranteed to happen, and the mutation phase can
-/// begin with the following steps:
-/// - A third step computes the reaching definition of the memory slot at each
-/// blocking user. This is the core of the mem2reg algorithm, also known as
-/// load-store forwarding. This analyses loads and stores and propagates which
-/// value must be stored in the slot at each blocking user. This is achieved by
-/// doing a depth-first walk of the dominator tree of the function. This is
-/// sufficient because the reaching definition at the beginning of a block is
-/// either its new block argument if it is a merge block, or the definition
-/// reaching the end of its immediate dominator (parent in the dominator tree).
-/// We can therefore propagate this information down the dominator tree to
-/// proceed with renaming within blocks.
-/// - The final fourth step uses the reaching definition to remove blocking uses
-/// in topological order.
+/// At this point, promotion is guaranteed to happen, and the transformation
+/// phase can begin. For each region of the program, a two-step process is
+/// carried out.
+/// - The first step of the per-region process computes the reaching definition
+/// of the memory slot at each blocking user. This is the core of the mem2reg
+/// algorithm, also known as load-store forwarding. This analyses loads and
+/// stores and propagates which value must be stored in the slot at each
+/// blocking user. This is achieved by doing a depth-first walk of the dominator
+/// tree of the region. This is sufficient because the reaching definition at
+/// the beginning of a block is either its new block argument if it is a merge
+/// block, or the definition reaching the end of its immediate dominator (parent
+/// in the dominator tree). We can therefore propagate this information down the
+/// dominator tree to proceed with renaming within blocks. If at any point a
+/// region operation that contains a use of the allocation is encountered, the
+/// transformation process is triggered on the child regions of the encountered
+/// operation, to obtain the reaching definition at its end and carry on with
+/// the value forwarding.
+/// - The second step of the per-region process uses the reaching definition to
+/// remove blocking uses in topological order.
///
/// For further reading, chapter three of SSA-based Compiler Design [1]
-/// showcases SSA construction, where mem2reg is an adaptation of the same
-/// process.
+/// showcases SSA construction for control-flow graphs, where mem2reg is an
+/// adaptation of the same process.
///
/// [1]: Rastello F. & Bouchez Tichadou F., SSA-based Compiler Design (2022),
/// Springer.
@@ -100,18 +118,34 @@ namespace {
using BlockingUsesMap =
llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
+using RegionBlockingUsesMap =
+ llvm::SmallMapVector<Region *, BlockingUsesMap, 2>;
+
+using RegionSet = SmallPtrSet<Region *, 32>;
+
+/// Information about regions that will be traversed for promotion, computed
+/// during promotion analysis.
+struct RegionPromotionInfo {
+ /// True if an operation storing to the slot is present in the region.
+ bool hasValueStores;
+};
/// Information computed during promotion analysis used to perform actual
/// promotion.
struct MemorySlotPromotionInfo {
/// Blocks for which at least two definitions of the slot values clash.
SmallPtrSet<Block *, 8> mergePoints;
- /// Contains, for each operation, which uses must be eliminated by promotion.
- /// This is a DAG structure because if an operation must eliminate some of
- /// its uses, it is because the defining ops of the blocking uses requested
- /// it. The defining ops therefore must also have blocking uses or be the
- /// starting point of the blocking uses.
- BlockingUsesMap userToBlockingUses;
+ /// Contains, for each region, the blocking uses of its operations. The
+ /// blocking uses are the uses that must be eliminated by promotion. For each
+ /// region, this is a DAG structure because if an operation must eliminate
+ /// some of its uses, it is because the defining ops of the blocking uses
+ /// requested it. The defining ops therefore must also have blocking uses or
+ /// be the starting point of the blocking uses.
+ RegionBlockingUsesMap userToBlockingUses;
+ /// Regions whose edges must be analyzed for promotion. All regions
+ /// are guaranteed to be held by a PromotableRegionOpInterface, and to be
+ /// nested within the parent region of the slot pointer.
+ DenseMap<Region *, RegionPromotionInfo> regionsToPromote;
};
/// Computes information for basic slot promotion. This will check that direct
@@ -135,18 +169,32 @@ class MemorySlotPromotionAnalyzer {
/// uses (typically, removing its users because it will delete itself to
/// resolve its own blocking uses). This will fail if one of the transitive
/// users cannot remove a requested use, and should prevent promotion.
- LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
+ /// Resulting blocking uses are grouped by region.
+ /// This also ensures all the uses are within promotable regions, adding
+ /// information about regions to be promoted to the `regionsToPromote` map.
+ LogicalResult computeBlockingUses(
+ RegionBlockingUsesMap &userToBlockingUses,
+ DenseMap<Region *, RegionPromotionInfo> &regionsToPromote);
/// Computes in which blocks the value stored in the slot is actually used,
/// meaning blocks leading to a load. This method uses `definingBlocks`, the
/// set of blocks containing a store to the slot (defining the value of the
/// slot).
- SmallPtrSet<Block *, 16>
- computeSlotLiveIn(SmallPtrSetImpl<Block *> &definingBlocks);
-
- /// Computes the points in which multiple re-definitions of the slot's value
- /// (stores) may conflict.
- void computeMergePoints(SmallPtrSetImpl<Block *> &mergePoints);
+ /// The analysis is aware of regions and uses region promotion information
+ /// to determine the effect of nested regions on slot value liveness.
+ SmallPtrSet<Block *, 16> computeSlotLiveIn(
+ DenseMap<Region *, SmallPtrSet<Block *, 16>> &definingBlocksByRegion,
+ DenseMap<Region *, RegionPromotionInfo> &regionsToPromote);
+
+ /// Computes the points in the provided region where multiple re-definitions
+ /// of the slot's value (stores) may conflict.
+ /// `definingBlocks` is the set of blocks containing a store to the slot,
+ /// either directly or inherited from a nested region.
+ /// `slotLiveIn` is the set of blocks where the memory slot is live-in.
+ void computeMergePoints(Region *region,
+ SmallPtrSetImpl<Block *> &definingBlocks,
+ SmallPtrSetImpl<Block *> &slotLiveIn,
+ SmallPtrSetImpl<Block *> &mergePoints);
/// Ensures predecessors of merge points can properly provide their current
/// definition of the value stored in the slot to the merge point. This can
@@ -155,6 +203,7 @@ class MemorySlotPromotionAnalyzer {
bool areMergePointsUsable(SmallPtrSetImpl<Block *> &mergePoints);
MemorySlot slot;
+
DominanceInfo &dominance;
const DataLayout &dataLayout;
};
@@ -181,19 +230,31 @@ class MemorySlotPromoter {
private:
/// Computes the reaching definition for all the operations that require
- /// promotion. `reachingDef` is the value the slot should contain at the
- /// beginning of the block. This method returns the reached definition at the
- /// end of the block. This method must only be called at most once per block.
- Value computeReachingDefInBlock(Block *block, Value reachingDef);
+ /// promotion, including within nested regions needing promotion.
+ /// `reachingDef` is the value the slot contains at the beginning of the
+ /// block. This method returns the reached definition at the end of the block.
+ ///
+ /// The `reachingDef` may be a null value. In that case, a lazily-created
+ /// default value will be used.
+ ///
+ /// This method must only be called at most once per block.
+ Value promoteInBlock(Block *block, Value reachingDef);
/// Computes the reaching definition for all the operations that require
- /// promotion. `reachingDef` corresponds to the initial value the
- /// slot will contain before any write, typically a poison value.
+ /// promotion, including within nested regions needing promotion, and removes
+ /// the blocking uses of the slot within the region.
+ /// `reachingDef` is the value the slot contains at the beginning of the
+ /// region.
+ ///
+ /// The `reachingDef` may be a null value. In that case, a lazily-created
+ /// default value will be used.
+ ///
/// This method must only be called at most once per region.
- void computeReachingDefInRegion(Region *region, Value reachingDef);
+ void promoteInRegion(Region *region, Value reachingDef);
- /// Removes the blocking uses of the slot, in topological order.
- void removeBlockingUses();
+ /// Removes the blocking uses of the slot within the given region, in
+ /// topological order.
+ void removeBlockingUses(Region *region);
/// Lazily-constructed default value representing the content of the slot when
/// no store has been executed. This function may mutate IR.
@@ -209,6 +270,20 @@ class MemorySlotPromoter {
/// are only computed for promotable memory operations with blocking uses.
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
+
+ /// Contains the reaching definition at the end of the blocks visited so far.
+ DenseMap<Block *, Value> reachingAtBlockEnd;
+
+ /// Lists all the values that have been set by a memory operation as a
+ /// reaching definition at one point during the promotion. The accompanying
+ /// operation is the memory operation that originally stored the value.
+ llvm::SmallVector<std::pair<Operation *, Value>> replacedValues;
+ /// Operations to visit with the `visitReplacedValues` method at the end of
+ /// the promotion.
+ llvm::SmallVector<PromotableOpInterface> toVisitReplacedValues;
+ /// Operations to be erased at the end of the promotion.
+ llvm::SmallVector<Operation *> toErase;
+
DominanceInfo &dominance;
const DataLayout &dataLayout;
MemorySlotPromotionInfo info;
@@ -251,16 +326,14 @@ Value MemorySlotPromoter::getOrCreateDefaultValue() {
}
LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
- BlockingUsesMap &userToBlockingUses) {
+ RegionBlockingUsesMap &userToBlockingUses,
+ DenseMap<Region *, RegionPromotionInfo> &regionsToPromote) {
// The promotion of an operation may require the promotion of further
// operations (typically, removing operations that use an operation that must
// delete itself). We thus need to start from the use of the slot pointer and
// propagate further requests through the forward slice.
- // Because this pass currently only supports analysing the parent region of
- // the slot pointer, if a promotable memory op that needs promotion is within
- // a graph region, the slot may only be used in a graph region and should
- // therefore be ignored.
+ // Graph regions are not supported.
Region *slotPtrRegion = slot.ptr.getParentRegion();
auto slotPtrRegionOp =
dyn_cast<RegionKindInterface>(slotPtrRegion->getParentOp());
@@ -273,10 +346,15 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
// use it.
for (OpOperand &use : slot.ptr.getUses()) {
SmallPtrSet<OpOperand *, 4> &blockingUses =
- userToBlockingUses[use.getOwner()];
+ userToBlockingUses[use.getOwner()->getParentRegion()][use.getOwner()];
blockingUses.insert(&use);
}
+ // Regions that immediately contain a slot memory use that is not a store.
+ RegionSet regionsWithDirectUse;
+ // Regions that immediately contain a slot memory use that is a store.
+ RegionSet regionsWithDirectStore;
+
// Then, propagate the requirements for the removal of uses. The
// topologically-sorted forward slice allows for all blocking uses of an
// operation to have been computed before it is reached. Operations are
@@ -286,8 +364,12 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
mlir::getForwardSlice(slot.ptr, &forwardSlice);
for (Operation *user : forwardSlice) {
// If the next operation has no blocking uses, everything is fine.
- auto *it = userToBlockingUses.find(user);
- if (it == userToBlockingUses.end())
+ auto *blockingUsesMapIt = userToBlockingUses.find(user->getParentRegion());
+ if (blockingUsesMapIt == userToBlockingUses.end())
+ continue;
+ BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second;
+ auto *it = blockingUsesMap.find(user);
+ if (it == blockingUsesMap.end())
continue;
SmallPtrSet<OpOperand *, 4> &blockingUses = it->second;
@@ -303,6 +385,14 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
dataLayout))
return failure();
+
+ // Operations that interact with the slot's memory will be promoted using
+ // a reaching definition. Therefore, the operation must be within a region
+ // where the reaching definition can be computed.
+ if (promotable.storesTo(slot))
+ regionsWithDirectStore.insert(user->getParentRegion());
+ else
+ regionsWithDirectUse.insert(user->getParentRegion());
} else {
// An operation that has blocking uses must be promoted. If it is not
// promotable, promotion must fail.
@@ -314,25 +404,84 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
assert(llvm::is_contained(user->getResults(), blockingUse->get()));
SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
- userToBlockingUses[blockingUse->getOwner()];
+ blockingUsesMap[blockingUse->getOwner()];
newUserBlockingUseSet.insert(blockingUse);
}
}
- // Because this pass currently only supports analysing the parent region of
- // the slot pointer, if a promotable memory op that needs promotion is outside
- // of this region, promotion must fail because it will be impossible to
- // provide a valid `reachingDef` for it.
- for (auto &[toPromote, _] : userToBlockingUses)
- if (isa<PromotableMemOpInterface>(toPromote) &&
- toPromote->getParentRegion() != slot.ptr.getParentRegion())
- return failure();
+ // Finally, check that all the regions needed are promotable, and propagate
+ // the constraint to their parent regions.
+ auto visitRegions = [&](SmallVector<Region *> &regionsToPropagateFrom,
+ bool hasValueStores) {
+ while (!regionsToPropagateFrom.empty()) {
+ Region *region = regionsToPropagateFrom.pop_back_val();
+
+ if (region == slot.ptr.getParentRegion() ||
+ regionsToPromote.contains(region))
+ continue;
+
+ RegionPromotionInfo &regionInfo = regionsToPromote[region];
+ regionInfo.hasValueStores = hasValueStores;
+
+ auto promotableParentOp =
+ dyn_cast<PromotableRegionOpInterface>(region->getParentOp());
+ if (!promotableParentOp)
+ return failure();
+
+ if (!promotableParentOp.isRegionPromotable(slot, region, hasValueStores))
+ return failure();
+
+ regionsToPropagateFrom.push_back(region->getParentRegion());
+ }
+
+ return success();
+ };
+
+ // Start with the regions that directly contain a store to give priority
+ // to stores in the propagation of `hasValueStores` information.
+ SmallVector<Region *> regionsToPropagateFrom(regionsWithDirectStore.begin(),
+ regionsWithDirectStore.end());
+ if (failed(visitRegions(regionsToPropagateFrom, true)))
+ return failure();
+
+ // Then, propagate from the regions that directly contain non-store uses.
+ regionsToPropagateFrom.clear();
+ regionsToPropagateFrom.append(regionsWithDirectUse.begin(),
+ regionsWithDirectUse.end());
+ if (failed(visitRegions(regionsToPropagateFrom, false)))
+ return failure();
return success();
}
+/// Returns true if the operation stores to the slot, either directly or
+/// within one of its nested regions.
+static bool
+isStoreLike(Operation *op, MemorySlot &slot,
+ DenseMap<Region *, RegionPromotionInfo> &regionsToPromote) {
+ auto promotableMemOp = dyn_cast<PromotableMemOpInterface>(op);
+ if (promotableMemOp && promotableMemOp.storesTo(slot))
+ return true;
+
+ auto promotableRegionOp = dyn_cast<PromotableRegionOpInterface>(op);
+ if (!promotableRegionOp)
+ return false;
+
+ for (Region &region : op->getRegions()) {
+ auto regionInfoIt = regionsToPromote.find(&region);
+ if (regionInfoIt == regionsToPromote.end())
+ continue;
+
+ if (regionInfoIt->second.hasValueStores)
+ return true;
+ }
+
+ return false;
+}
+
SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
- SmallPtrSetImpl<Block *> &definingBlocks) {
+ DenseMap<Region *, SmallPtrSet<Block *, 16>> &definingBlocksByRegion,
+ DenseMap<Region *, RegionPromotionInfo> &regionsToPromote) {
SmallPtrSet<Block *, 16> liveIn;
// The worklist contains blocks in which it is known that the slot value is
@@ -340,6 +489,8 @@ SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
// from these.
SmallVector<Block *> liveInWorkList;
+ SmallPtrSet<Operation *, 4> regionPredecessorScratch;
+
// Blocks with a load before any other store to the slot are the starting
// points of the analysis. The slot value is definitely live-in in those
// blocks.
@@ -380,8 +531,33 @@ SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
// We can thus at this stage insert to the worklist only predecessors that
// are not defining blocks.
for (Block *pred : liveInBlock->getPredecessors())
- if (!definingBlocks.contains(pred))
+ if (!definingBlocksByRegion[pred->getParent()].contains(pred))
liveInWorkList.push_back(pred);
+
+ // The logic is a little more complicated for region predecessors as they
+ // could be in the middle of a block. We thus need to look for a store
+ // within the predecessor block specifically before the region predecessor
+ // operation.
+ if (liveInBlock->isEntryBlock() &&
+ liveInBlock->getParent() != slot.ptr.getParentRegion()) {
+ regionPredecessorScratch.clear();
+ auto parentOp =
+ cast<PromotableRegionOpInterface>(liveInBlock->getParentOp());
+ parentOp.propagateLiveIn(slot, liveInBlock->getParent(),
+ regionPredecessorScratch);
+ for (Operation *pred : regionPredecessorScratch) {
+ if (liveIn.contains(pred->getBlock()))
+ continue;
+
+ Operation *storeCandidate = pred;
+ while (storeCandidate &&
+ !isStoreLike(storeCandidate, slot, regionsToPromote))
+ storeCandidate = storeCandidate->getPrevNode();
+
+ if (!storeCandidate)
+ liveInWorkList.push_back(pred->getBlock());
+ }
+ }
}
return liveIn;
@@ -389,22 +565,16 @@ SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
void MemorySlotPromotionAnalyzer::computeMergePoints(
+ Region *region, SmallPtrSetImpl<Block *> &definingBlocks,
+ SmallPtrSetImpl<Block *> &slotLiveIn,
SmallPtrSetImpl<Block *> &mergePoints) {
- if (slot.ptr.getParentRegion()->hasOneBlock())
+ if (region->hasOneBlock())
return;
- IDFCalculator idfCalculator(dominance.getDomTree(slot.ptr.getParentRegion()));
-
- SmallPtrSet<Block *, 16> definingBlocks;
- for (Operation *user : slot.ptr.getUsers())
- if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
- if (storeOp.storesTo(slot))
- definingBlocks.insert(user->getBlock());
+ IDFCalculator idfCalculator(dominance.getDomTree(region));
idfCalculator.setDefiningBlocks(definingBlocks);
-
- SmallPtrSet<Block *, 16> liveIn = computeSlotLiveIn(definingBlocks);
- idfCalculator.setLiveInBlocks(liveIn);
+ idfCalculator.setLiveInBlocks(slotLiveIn);
SmallVector<Block *> mergePointsVec;
idfCalculator.calculate(mergePointsVec);
@@ -430,13 +600,36 @@ MemorySlotPromotionAnalyzer::computeInfo() {
// promotion to happen. These operations need to resolve some of their uses,
// either by rewiring them or simply deleting themselves. If any of them
// cannot find a way to resolve their blocking uses, we abort the promotion.
- if (failed(computeBlockingUses(info.userToBlockingUses)))
+ // We also compute at this stage the regions that will be analyzed for
+ // reaching definition information.
+ if (failed(
+ computeBlockingUses(info.userToBlockingUses, info.regionsToPromote)))
return {};
+ // Compute the blocks containing a store for each region, either directly or
+ // inherited from a nested region. As a side effect, `definingBlocks` contains
+ // all regions with at least one store.
+ DenseMap<Region *, SmallPtrSet<Block *, 16>> definingBlocks;
+ for (Operation *user : slot.ptr.getUsers())
+ if (auto storeOp = dyn_cast<PromotableMemOpInterface>(user))
+ if (storeOp.storesTo(slot))
+ definingBlocks[user->getParentRegion()].insert(user->getBlock());
+ for (auto &[region, regionInfo] : info.regionsToPromote)
+ if (regionInfo.hasValueStores)
+ definingBlocks[region->getParentRegion()].insert(
+ region->getParentOp()->getBlock());
+
+ // TODO: When all regions involved are single-block (fairly common in
+ // region-based control-flow), there cannot be any merge points, so we could
+ // skip this costly analysis and its dependencies.
+ SmallPtrSet<Block *, 16> slotLiveIn =
+ computeSlotLiveIn(definingBlocks, info.regionsToPromote);
+
// Then, compute blocks in which two or more definitions of the allocated
// variable may conflict. These blocks will need a new block argument to
// accommodate this.
- computeMergePoints(info.mergePoints);
+ for (auto &[region, defBlocks] : definingBlocks)
+ computeMergePoints(region, defBlocks, slotLiveIn, info.mergePoints);
// The slot can be promoted if the block arguments to be created can
// actually be populated with values, which may not be possible depending
@@ -447,18 +640,24 @@ MemorySlotPromotionAnalyzer::computeInfo() {
return info;
}
-Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
- Value reachingDef) {
+Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) {
+ llvm::SmallMapVector<Region *, Value, 2> regionsToProcess;
SmallVector<Operation *> blockOps;
for (Operation &op : block->getOperations())
blockOps.push_back(&op);
for (Operation *op : blockOps) {
+ // Promote operations that interact with the slot's memory.
if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
- if (info.userToBlockingUses.contains(memOp))
+ if (info.userToBlockingUses[memOp->getParentRegion()].contains(memOp))
reachingDefs.insert({memOp, reachingDef});
if (memOp.storesTo(slot)) {
builder.setInsertionPointAfter(memOp);
+ // To avoid exposing default value creation to the interfaces, fall
+ // back to the default value if there is no reaching definition yet.
+ // This is slightly too eager, as `getStored` may not need it.
+ if (!reachingDef)
+ reachingDef = getOrCreateDefaultValue();
Value stored = memOp.getStored(slot, builder, reachingDef, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
@@ -466,16 +665,71 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
replacedValuesMap[memOp] = stored;
}
}
+
+ // Promote regions that contain operations that interact with the slot's
+ // memory.
+ if (auto promotableRegionOp = dyn_cast<PromotableRegionOpInterface>(op)) {
+ bool needsPromotion = false;
+ bool hasValueStores = false;
+ for (Region &region : op->getRegions()) {
+ auto regionInfoIt = info.regionsToPromote.find(&region);
+ if (regionInfoIt == info.regionsToPromote.end())
+ continue;
+ needsPromotion = true;
+ if (!regionInfoIt->second.hasValueStores)
+ continue;
+
+ hasValueStores = true;
+ break;
+ }
+
+ if (needsPromotion) {
+ regionsToProcess.clear();
+
+ // To avoid exposing default value creation to the interfaces, fall
+ // back to the default value if there is no reaching definition yet.
+ // This is slightly too eager, as `setupPromotion` may not need it.
+ if (!reachingDef)
+ reachingDef = getOrCreateDefaultValue();
+
+ promotableRegionOp.setupPromotion(slot, reachingDef, hasValueStores,
+ regionsToProcess);
+
+#ifndef NDEBUG
+ for (Region &region : op->getRegions())
+ if (info.regionsToPromote.contains(&region))
+ assert(
+ regionsToProcess.contains(&region) &&
+ "reaching definition must be provided for a required region");
+#endif // NDEBUG
+
+ for (auto &[region, reachingDef] : regionsToProcess) {
+#ifndef NDEBUG
+ Region *regionCapture = region;
+ assert(llvm::any_of(op->getRegions(),
+ [&](Region &r) { return &r == regionCapture; }) &&
+ "region must be part of the operation");
+#endif // NDEBUG
+ if (!info.regionsToPromote.contains(region))
+ continue;
+ promoteInRegion(region, reachingDef);
+ }
+
+ builder.setInsertionPointAfter(op);
+ reachingDef = promotableRegionOp.finalizePromotion(
+ slot, reachingDef, hasValueStores, reachingAtBlockEnd, builder);
+ }
+ }
}
+ reachingAtBlockEnd[block] = reachingDef;
return reachingDef;
}
-void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
- Value reachingDef) {
- assert(reachingDef && "expected an initial reaching def to be provided");
+void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
if (region->hasOneBlock()) {
- computeReachingDefInBlock(&region->front(), reachingDef);
+ promoteInBlock(&region->front(), reachingDef);
+ removeBlockingUses(region);
return;
}
@@ -486,7 +740,7 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
SmallVector<DfsJob> dfsStack;
- auto &domTree = dominance.getDomTree(slot.ptr.getParentRegion());
+ auto &domTree = dominance.getDomTree(region);
dfsStack.emplace_back<DfsJob>(
{domTree.getNode(&region->front()), reachingDef});
@@ -506,12 +760,14 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
(*statistics.newBlockArgumentAmount)++;
}
- job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
- assert(job.reachingDef);
+ job.reachingDef = promoteInBlock(block, job.reachingDef);
if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
if (info.mergePoints.contains(blockOperand.get())) {
+ if (!job.reachingDef)
+ job.reachingDef = getOrCreateDefaultValue();
+
terminator.getSuccessorOperands(blockOperand.getOperandNumber())
.append(job.reachingDef);
}
@@ -521,6 +777,8 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
for (auto *child : job.block->children())
dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
}
+
+ removeBlockingUses(region);
}
/// Gets or creates a block index mapping for `region`.
@@ -559,49 +817,63 @@ static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
});
}
-void MemorySlotPromoter::removeBlockingUses() {
+void MemorySlotPromoter::removeBlockingUses(Region *region) {
+ auto *blockingUsesMapIt = info.userToBlockingUses.find(region);
+ if (blockingUsesMapIt == info.userToBlockingUses.end())
+ return;
+ BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second;
+
llvm::SmallVector<Operation *> usersToRemoveUses(
- llvm::make_first_range(info.userToBlockingUses));
+ llvm::make_first_range(blockingUsesMap));
// Sort according to dominance.
- dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent(),
- blockIndexCache);
+ dominanceSort(usersToRemoveUses, *region, blockIndexCache);
- llvm::SmallVector<Operation *> toErase;
- // List of all replaced values in the slot.
- llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
- // Ops to visit with the `visitReplacedValues` method.
- llvm::SmallVector<PromotableOpInterface> toVisit;
for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
// If no reaching definition is known, this use is outside the reach of
// the slot. The default value should thus be used.
+ // FIXME: This is too eager, and will generate default values even for
+ // pure stores. This cannot be removed easily as partial stores may
+ // still require a default value to complete.
if (!reachingDef)
reachingDef = getOrCreateDefaultValue();
builder.setInsertionPointAfter(toPromote);
- if (toPromoteMemOp.removeBlockingUses(
- slot, info.userToBlockingUses[toPromote], builder, reachingDef,
- dataLayout) == DeletionKind::Delete)
+ if (toPromoteMemOp.removeBlockingUses(slot, blockingUsesMap[toPromote],
+ builder, reachingDef,
+ dataLayout) == DeletionKind::Delete)
toErase.push_back(toPromote);
if (toPromoteMemOp.storesTo(slot))
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
- replacedValuesList.push_back({toPromoteMemOp, replacedValue});
+ replacedValues.push_back({toPromoteMemOp, replacedValue});
continue;
}
auto toPromoteBasic = cast<PromotableOpInterface>(toPromote);
builder.setInsertionPointAfter(toPromote);
- if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
+ if (toPromoteBasic.removeBlockingUses(blockingUsesMap[toPromote],
builder) == DeletionKind::Delete)
toErase.push_back(toPromote);
if (toPromoteBasic.requiresReplacedValues())
- toVisit.push_back(toPromoteBasic);
+ toVisitReplacedValues.push_back(toPromoteBasic);
}
- for (PromotableOpInterface op : toVisit) {
+}
+
+std::optional<PromotableAllocationOpInterface>
+MemorySlotPromoter::promoteSlot() {
+ // Perform the promotion recursively through nested regions. The reaching
+ // definition starts with a null value that will be replaced by a
+ // lazily-created default value if the value must be passed to a promotion
+ // interface while no store has been encountered yet.
+ promoteInRegion(slot.ptr.getParentRegion(), nullptr);
+
+ // Notify operations that requested it of the reaching definitions set by
+ // storing memory operations.
+ for (PromotableOpInterface op : toVisitReplacedValues) {
builder.setInsertionPointAfter(op);
- op.visitReplacedValues(replacedValuesList, builder);
+ op.visitReplacedValues(replacedValues, builder);
}
for (Operation *toEraseOp : toErase)
@@ -609,15 +881,6 @@ void MemorySlotPromoter::removeBlockingUses() {
assert(slot.ptr.use_empty() &&
"after promotion, the slot pointer should not be used anymore");
-}
-
-std::optional<PromotableAllocationOpInterface>
-MemorySlotPromoter::promoteSlot() {
- computeReachingDefInRegion(slot.ptr.getParentRegion(),
- getOrCreateDefaultValue());
-
- // Now that reaching definitions are known, remove all users.
- removeBlockingUses();
// Update terminators in dead branches to forward default if they are
// succeeded by a merge points.
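The lazy default-value scheme described in `promoteSlot` above (a null reaching definition that is only replaced by the default value when something actually needs it) can be illustrated with a small standalone model. This is a hypothetical sketch, not the MLIR API: `MemOp`, `PromotionResult`, and `promoteBlock` are invented names, plain `int`s stand in for SSA values, and only a single straight-line block is modeled.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for a memory operation on the slot (not MLIR).
struct MemOp {
  enum Kind { Store, Load } kind;
  int storedValue = 0; // only meaningful for stores
};

struct PromotionResult {
  std::vector<int> loadResults; // value each load is replaced with
  int defaultCreations = 0;     // times the default value was materialized
};

// Walks the block in order while tracking the reaching definition. It starts
// empty ("null") and is only set to the default value when a load needs a
// value before any store has been seen.
PromotionResult promoteBlock(const std::vector<MemOp> &ops, int defaultValue) {
  PromotionResult result;
  std::optional<int> reachingDef;   // null reaching definition
  std::optional<int> cachedDefault; // default is created at most once
  auto getOrCreateDefault = [&]() {
    if (!cachedDefault) {
      cachedDefault = defaultValue;
      ++result.defaultCreations;
    }
    return *cachedDefault;
  };
  for (const MemOp &op : ops) {
    if (op.kind == MemOp::Store) {
      reachingDef = op.storedValue; // a store defines the new value
      continue;
    }
    if (!reachingDef)
      reachingDef = getOrCreateDefault();
    result.loadResults.push_back(*reachingDef);
  }
  return result;
}
```

In this model, a slot that is always stored before being loaded never materializes a default value, which is the benefit of starting from a null reaching definition instead of eagerly calling `getOrCreateDefaultValue` up front.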
>From 6e8baa7e0ccf20b949f02ace3322be4b70f59946 Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Mon, 2 Mar 2026 19:10:16 +0100
Subject: [PATCH 02/15] begin implementation for SCF
---
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 1 +
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 9 +++
mlir/lib/Dialect/SCF/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/SCF/IR/MemorySlot.cpp | 86 ++++++++++++++++++++++
4 files changed, 97 insertions(+)
create mode 100644 mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
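The SCF loop implementations in this series promote a slot by turning it into a loop-carried value: a new body block argument, an extra yield operand, and an extra result. The hypothetical model below (plain C++, not MLIR) compares the memory semantics of a counted loop with the promoted form, assuming the definition reaching the loop is passed as the initial value of the new iteration argument. `Body`, `runWithMemory`, and `runPromoted` are illustrative names only.

```cpp
#include <cassert>
#include <functional>

// Hypothetical loop body: reads the slot, computes a new value from it and
// the induction variable, and writes it back.
using Body = std::function<int(int slotValue, int iv)>;

// Before promotion: the slot lives in memory across iterations.
int runWithMemory(int storedBefore, int lb, int ub, int step,
                  const Body &body) {
  int memory = storedBefore; // store before the loop
  for (int iv = lb; iv < ub; iv += step) {
    int loaded = memory;       // load from the slot
    memory = body(loaded, iv); // store to the slot
  }
  return memory; // load after the loop
}

// After promotion: `iterArg` models the new body block argument, `yielded`
// the extra yield operand, and the return value the extra loop result.
int runPromoted(int entryDef, int lb, int ub, int step, const Body &body) {
  int iterArg = entryDef; // assumed initial value of the iteration argument
  for (int iv = lb; iv < ub; iv += step) {
    int yielded = body(iterArg, iv);
    iterArg = yielded; // the yield feeds the next iteration
  }
  return iterArg; // a zero-trip loop returns the entry definition
}
```

The zero-trip case is why the initial value matters: when the body never runs, the extra result must equal the definition that reached the loop.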
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..44cbb458d94fe 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -22,6 +22,7 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..abc6f79bb09b2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -19,6 +19,7 @@ include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -78,6 +79,7 @@ def ConditionOp : SCF_Op<"condition", [
def ExecuteRegionOp : SCF_Op<"execute_region", [
DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>,
+ DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
RecursiveMemoryEffects]> {
let summary = "operation that executes its region exactly once";
let description = [{
@@ -161,6 +163,7 @@ def ForOp : SCF_Op<"for",
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+ DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveMemoryEffects]> {
let summary = "for operation";
@@ -329,6 +332,7 @@ def ForallOp : SCF_Op<"forall", [
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
DestinationStyleOpInterface,
HasParallelRegion
]> {
@@ -701,6 +705,7 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
"getEntrySuccessorRegions", "getSuccessorInputs"]>,
+ DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
let summary = "if-then-else operation";
@@ -806,6 +811,7 @@ def ParallelOp : SCF_Op<"parallel",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps"]>,
RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
SingleBlockImplicitTerminator<"scf::ReduceOp">,
HasParallelRegion]> {
let summary = "parallel for operation";
@@ -904,6 +910,7 @@ def ParallelOp : SCF_Op<"parallel",
def ReduceOp : SCF_Op<"reduce", [
Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects,
+ DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
let summary = "reduce operation for scf.parallel";
let description = [{
@@ -986,6 +993,7 @@ def WhileOp : SCF_Op<"while",
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
@@ -1135,6 +1143,7 @@ def WhileOp : SCF_Op<"while",
def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::YieldOp">,
+ DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getRegionInvocationBounds",
"getEntrySuccessorRegions",
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..fca28c5209e2d 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRSCFDialect
SCF.cpp
DeviceMappingInterface.cpp
+ MemorySlot.cpp
ValueBoundsOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
new file mode 100644
index 0000000000000..ef3e6bf01dbc1
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -0,0 +1,86 @@
+//===- MemorySlot.cpp - Memory Slot interface implementations for SCF -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Creates a shallow copy of an operation with new result types moving the
+/// regions out of the original operation, then deletes the original operation.
+template <typename OpTy>
+static OpTy replaceWithNewResults(OpBuilder &builder, Operation *op,
+ TypeRange resultTypes) {
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPoint(op);
+ builder.
+ auto newOp = OpTy::create(builder, op->getLoc(), resultTypes,
+ op->getOperands(), op->getProperties(),
+ op->getSuccessors(), op->getNumRegions());
+ builder.create()
+ op.erase();
+ return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+bool ExecuteRegionOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+ bool hasValueStores) {
+ return true;
+}
+
+void ExecuteRegionOp::propagateLiveIn(
+ const MemorySlot &slot, Region *regionLiveIn,
+ SmallPtrSetImpl<Operation *> &operationsLiveIn) {
+ assert(regionLiveIn == &getRegion() &&
+ "regionLiveIn must be the region of the ExecuteRegionOp");
+ operationsLiveIn.insert(getOperation());
+}
+
+void ExecuteRegionOp::setupPromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+ regionsToProcess.insert({&getRegion(), reachingDef});
+}
+
+Value ExecuteRegionOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ if (!hasValueStores)
+ return reachingDef;
+
+ // Update the yield terminators to return the newly defined reaching
+ // definition.
+ for (Block &block : getRegion().getBlocks()) {
+ Operation *terminator = block.getTerminator();
+ if (!isa<YieldOp>(terminator))
+ continue;
+ Value blockReachingDef = reachingAtBlockEnd[block];
+ if (!blockReachingDef) {
+ // Block is dead code or the region is not using the slot, so the reaching
+ // definition is the entry reaching definition.
+ blockReachingDef = reachingDef;
+ }
+ terminator->insertOperands(terminator->getNumOperands(),
+ {blockReachingDef});
+ }
+
+ SmallVector<Type> resultTypes(getResultTypes());
+ resultTypes.push_back(slot.elemType);
+
+ auto newOp = replaceWithNewResults<ExecuteRegionOp>(builder, getOperation(),
+ resultTypes);
+
+ return reachingDef;
+}
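`finalizePromotion` appends, to every block ending in a yield, the definition reaching the end of that block, and falls back to the definition that reached the operation when the block recorded none (dead code, or a block that does not use the slot). A minimal sketch of that selection, using hypothetical integer block IDs and values in place of `Block *` and `Value` (`yieldedDefinitions` is an illustrative name, not part of the patch):

```cpp
#include <cassert>
#include <map>
#include <vector>

// For each block ending in a yield, picks the value that the yield should
// carry for the promoted slot.
std::vector<int> yieldedDefinitions(const std::vector<int> &yieldBlocks,
                                    const std::map<int, int> &reachingAtBlockEnd,
                                    int entryReachingDef) {
  std::vector<int> yielded;
  for (int block : yieldBlocks) {
    auto it = reachingAtBlockEnd.find(block);
    // Blocks without a recorded definition forward the entry definition.
    yielded.push_back(it == reachingAtBlockEnd.end() ? entryReachingDef
                                                     : it->second);
  }
  return yielded;
}
```

Using `find` rather than `operator[]` avoids inserting empty entries into the map as a side effect; the patch code indexes the `DenseMap` directly, which default-inserts a null `Value` before the fallback is applied.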
>From b2e4b980f84d77bae33a955ddb49e32cb5bc02e6 Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Tue, 3 Mar 2026 15:34:30 +0100
Subject: [PATCH 03/15] implement interfaces for all SCF ops
---
mlir/lib/Dialect/SCF/IR/MemorySlot.cpp | 401 +++++++++++++++++++++++--
1 file changed, 375 insertions(+), 26 deletions(-)
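Promotion through `scf.if` follows the same pattern, with one wrinkle handled in `IfOp::finalizePromotion` below: if the else region has no block, one is created that yields the entry definition, so both paths produce a value for the new result. A hypothetical model of the resulting value (`IfBranches`, `ifYields`, and `ifResult` are illustrative names; integers stand in for SSA values):

```cpp
#include <cassert>
#include <optional>
#include <utility>

// Definition recorded at the end of each branch, if any. `hasElse == false`
// models an `scf.if` whose else region is empty.
struct IfBranches {
  std::optional<int> thenEnd;
  bool hasElse;
  std::optional<int> elseEnd;
};

// Values carried by the then/else yields for the promoted slot. A missing
// else region gets a synthesized yield of the entry definition.
std::pair<int, int> ifYields(const IfBranches &branches, int entryDef) {
  int thenYield = branches.thenEnd.value_or(entryDef);
  int elseYield =
      branches.hasElse ? branches.elseEnd.value_or(entryDef) : entryDef;
  return {thenYield, elseYield};
}

// Runtime value of the new `scf.if` result for a given condition.
int ifResult(const IfBranches &branches, int entryDef, bool condition) {
  auto [thenYield, elseYield] = ifYields(branches, entryDef);
  return condition ? thenYield : elseYield;
}
```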
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index ef3e6bf01dbc1..c1b2e122c45a2 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
using namespace mlir;
using namespace mlir::scf;
@@ -15,19 +16,42 @@ using namespace mlir::scf;
// Helper functions
//===----------------------------------------------------------------------===//
+/// Appends the definition reaching the end of `block` as a new operand of the
+/// block's terminator if the terminator is of type `TermTy`. Blocks without a
+/// recorded definition forward `reachingDef` instead.
+template <typename TermTy>
+static void
+updateTerminator(Block *block, Value reachingDef,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
+ Operation *terminator = block->getTerminator();
+ if (!isa<TermTy>(terminator))
+ return;
+ Value blockReachingDef = reachingAtBlockEnd[block];
+ if (!blockReachingDef) {
+ // Block is dead code or the region is not using the slot, so the reaching
+ // definition is the entry reaching definition.
+ blockReachingDef = reachingDef;
+ }
+ terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
/// Creates a shallow copy of an operation with new result types moving the
/// regions out of the original operation, then deletes the original operation.
-template <typename OpTy>
-static OpTy replaceWithNewResults(OpBuilder &builder, Operation *op,
- TypeRange resultTypes) {
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPoint(op);
- builder.
- auto newOp = OpTy::create(builder, op->getLoc(), resultTypes,
- op->getOperands(), op->getProperties(),
- op->getSuccessors(), op->getNumRegions());
- builder.create()
- op.erase();
+static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+ TypeRange resultTypes) {
+ RewriterBase::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(op);
+ Operation *newOp =
+ mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
+ rewriter.startOpModification(newOp);
+ rewriter.startOpModification(op);
+ for (unsigned int i : llvm::seq(op->getNumRegions()))
+ newOp->getRegion(i).takeBody(op->getRegion(i));
+ rewriter.finalizeOpModification(op);
+ rewriter.finalizeOpModification(newOp);
+
+ SmallVector<Value> replacementValues(newOp->getResults().drop_back());
+ rewriter.replaceAllOpUsesWith(op, replacementValues);
+ rewriter.eraseOp(op);
return newOp;
}
@@ -44,7 +68,7 @@ void ExecuteRegionOp::propagateLiveIn(
const MemorySlot &slot, Region *regionLiveIn,
SmallPtrSetImpl<Operation *> &operationsLiveIn) {
assert(regionLiveIn == &getRegion() &&
- "regionLiveIn must be the region of the ExecuteRegionOp");
+ "regionLiveIn can only be the region of the ExecuteRegionOp");
operationsLiveIn.insert(getOperation());
}
@@ -62,25 +86,350 @@ Value ExecuteRegionOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
- for (Block &block : getRegion().getBlocks()) {
- Operation *terminator = block.getTerminator();
- if (!isa<YieldOp>(terminator))
- continue;
- Value blockReachingDef = reachingAtBlockEnd[block];
- if (!blockReachingDef) {
- // Block is dead code or the region is not using the slot, so the reaching
- // definition is the entry reaching definition.
- blockReachingDef = reachingDef;
- }
- terminator->insertOperands(terminator->getNumOperands(),
- {blockReachingDef});
+ for (Block &block : getRegion().getBlocks())
+ updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
+
+ SmallVector<Type> resultTypes(getResultTypes());
+ resultTypes.push_back(slot.elemType);
+
+ IRRewriter rewriter(builder);
+ Operation *newOp =
+ replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ return newOp->getResults().back();
+}
+
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+bool ForOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+ bool hasValueStores) {
+ return true;
+}
+
+void ForOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
+ SmallPtrSetImpl<Operation *> &operationsLiveIn) {
+ assert(regionLiveIn == &getBodyRegion() &&
+ "regionLiveIn can only be the region of the ForOp");
+ operationsLiveIn.insert(getOperation());
+ operationsLiveIn.insert(getBody()->getTerminator());
+}
+
+void ForOp::setupPromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+ Region &bodyRegion = getBodyRegion();
+ if (!hasValueStores) {
+ regionsToProcess.insert({&bodyRegion, reachingDef});
+ return;
+ }
+
+ bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc());
+ regionsToProcess.insert({&bodyRegion, bodyRegion.getArguments().back()});
+}
+
+Value ForOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ if (!hasValueStores)
+ return reachingDef;
+
+ // Update the yield terminator to return the newly defined reaching
+ // definition.
+ updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
+
+ SmallVector<Type> resultTypes(getResultTypes());
+ resultTypes.push_back(slot.elemType);
+
+ IRRewriter rewriter(builder);
+ Operation *newOp =
+ replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ return newOp->getResults().back();
+}
+
+//===----------------------------------------------------------------------===//
+// ForallOp
+//===----------------------------------------------------------------------===//
+
+bool ForallOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+ bool hasValueStores) {
+ // The ForallOp body may run in parallel, so values cannot be passed
+ // sequentially between iterations. Therefore, only loads can be handled.
+ return !hasValueStores;
+}
+
+void ForallOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
+ SmallPtrSetImpl<Operation *> &operationsLiveIn) {
+ assert(regionLiveIn == &getBodyRegion() &&
+ "regionLiveIn can only be the region of the ForallOp");
+ // Due to the parallel semantics of ForallOp, there is no liveness dependency
+ // on the body region as liveness cannot be influenced by neighboring
+ // iterations.
+ operationsLiveIn.insert(getOperation());
+}
+
+void ForallOp::setupPromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+ assert(!hasValueStores && "ForallOp does not support stores");
+ regionsToProcess.insert({&getBodyRegion(), reachingDef});
+}
+
+Value ForallOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ assert(!hasValueStores && "ForallOp does not support stores");
+ return reachingDef;
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
+bool IfOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+ bool hasValueStores) {
+ return true;
+}
+
+void IfOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
+ SmallPtrSetImpl<Operation *> &operationsLiveIn) {
+ assert((regionLiveIn == &getThenRegion() ||
+ regionLiveIn == &getElseRegion()) &&
+ "regionLiveIn can only be the then or else region of the IfOp");
+ operationsLiveIn.insert(getOperation());
+}
+
+void IfOp::setupPromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+ regionsToProcess.insert({&getThenRegion(), reachingDef});
+ regionsToProcess.insert({&getElseRegion(), reachingDef});
+}
+
+Value IfOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ if (!hasValueStores)
+ return reachingDef;
+
+ IRRewriter rewriter(builder);
+
+ // Update the yield terminators to return the newly defined reaching
+ // definition.
+ updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
+ reachingAtBlockEnd);
+ if (getElseRegion().hasOneBlock()) {
+ updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
+ reachingAtBlockEnd);
+ } else {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.createBlock(&getElseRegion());
+ YieldOp::create(rewriter, getOperation()->getLoc(), reachingDef);
}
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
- auto newOp = replaceWithNewResults<ExecuteRegionOp>(builder, getOperation(),
- resultTypes);
+ Operation *newOp =
+ replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ return newOp->getResults().back();
+}
+
+//===----------------------------------------------------------------------===//
+// IndexSwitchOp
+//===----------------------------------------------------------------------===//
+
+bool IndexSwitchOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+ bool hasValueStores) {
+ return true;
+}
+
+void IndexSwitchOp::propagateLiveIn(
+ const MemorySlot &slot, Region *regionLiveIn,
+ SmallPtrSetImpl<Operation *> &operationsLiveIn) {
+#ifndef NDEBUG
+ auto checkRegionValid = [&](IndexSwitchOp op) {
+ if (regionLiveIn == &op.getDefaultRegion())
+ return;
+
+ for (Region &caseRegion : op.getCaseRegions())
+ if (regionLiveIn == &caseRegion)
+ return;
+
+ assert(false && "regionLiveIn can only be the default region or a case "
+ "region of the IndexSwitchOp");
+ };
+
+ checkRegionValid(*this);
+#endif
+ operationsLiveIn.insert(getOperation());
+}
+
+void IndexSwitchOp::setupPromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+ regionsToProcess.insert({&getDefaultRegion(), reachingDef});
+ for (Region &caseRegion : getCaseRegions())
+ regionsToProcess.insert({&caseRegion, reachingDef});
+}
+
+Value IndexSwitchOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ if (!hasValueStores)
+ return reachingDef;
+
+ IRRewriter rewriter(builder);
+
+ // Update the yield terminators to return the newly defined reaching
+ // definition.
+ updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
+ reachingAtBlockEnd);
+ for (Region &caseRegion : getCaseRegions())
+ updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
+ reachingAtBlockEnd);
+
+ SmallVector<Type> resultTypes(getResultTypes());
+ resultTypes.push_back(slot.elemType);
+
+ Operation *newOp =
+ replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ return newOp->getResults().back();
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+bool ParallelOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+ bool hasValueStores) {
+ // The ParallelOp body may run in parallel, so values cannot be passed
+ // sequentially between iterations. Therefore, only loads can be handled.
+ return !hasValueStores;
+}
+
+void ParallelOp::propagateLiveIn(
+ const MemorySlot &slot, Region *regionLiveIn,
+ SmallPtrSetImpl<Operation *> &operationsLiveIn) {
+ assert(regionLiveIn == &getBodyRegion() &&
+ "regionLiveIn can only be the region of the ParallelOp");
+ // Due to the parallel semantics of ParallelOp, there is no liveness
+ // dependency on the body region as liveness cannot be influenced by
+ // neighboring iterations.
+ operationsLiveIn.insert(getOperation());
+}
+
+void ParallelOp::setupPromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+ assert(!hasValueStores && "ParallelOp does not support stores");
+ regionsToProcess.insert({&getBodyRegion(), reachingDef});
+}
+
+Value ParallelOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ assert(!hasValueStores && "ParallelOp does not support stores");
return reachingDef;
}
+
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+bool ReduceOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+ bool hasValueStores) {
+ // The ReduceOp body may run in parallel, so values cannot be passed
+ // sequentially between iterations. Therefore, only loads can be handled.
+ return !hasValueStores;
+}
+
+void ReduceOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
+ SmallPtrSetImpl<Operation *> &operationsLiveIn) {
+#ifndef NDEBUG
+ auto checkRegionValid = [&](ReduceOp op) {
+ for (Region &reduction : op.getReductions())
+ if (regionLiveIn == &reduction)
+ return;
+
+ assert(false &&
+ "regionLiveIn can only be a reduction region of the ReduceOp");
+ };
+
+ checkRegionValid(*this);
+#endif
+
+ operationsLiveIn.insert(getOperation());
+}
+
+void ReduceOp::setupPromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+ assert(!hasValueStores && "ReduceOp does not support stores");
+ for (Region &reduction : getReductions())
+ regionsToProcess.insert({&reduction, reachingDef});
+}
+
+Value ReduceOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ assert(!hasValueStores && "ReduceOp does not support stores");
+ return reachingDef;
+}
+
+//===----------------------------------------------------------------------===//
+// WhileOp
+//===----------------------------------------------------------------------===//
+
+bool WhileOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+ bool hasValueStores) {
+ return true;
+}
+
+void WhileOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
+ SmallPtrSetImpl<Operation *> &operationsLiveIn) {
+ if (regionLiveIn == &getBefore()) {
+ operationsLiveIn.insert(getOperation());
+ operationsLiveIn.insert(getAfterBody()->getTerminator());
+ return;
+ }
+
+ assert(regionLiveIn == &getAfter() &&
+ "regionLiveIn can only be a region of the WhileOp");
+ operationsLiveIn.insert(getBeforeBody()->getTerminator());
+}
+
+void WhileOp::setupPromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+ Region &beforeRegion = getBefore();
+ Region &afterRegion = getAfter();
+ if (!hasValueStores) {
+ regionsToProcess.insert({&beforeRegion, reachingDef});
+ regionsToProcess.insert({&afterRegion, reachingDef});
+ return;
+ }
+
+ beforeRegion.addArgument(slot.elemType, slot.ptr.getLoc());
+ regionsToProcess.insert({&beforeRegion, beforeRegion.getArguments().back()});
+
+ afterRegion.addArgument(slot.elemType, slot.ptr.getLoc());
+ regionsToProcess.insert({&afterRegion, afterRegion.getArguments().back()});
+}
+
+Value WhileOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ if (!hasValueStores)
+ return reachingDef;
+
+ // Update the yield terminators to return the newly defined reaching
+ // definition.
+ updateTerminator<ConditionOp>(&getBefore().back(), reachingDef,
+ reachingAtBlockEnd);
+ updateTerminator<YieldOp>(&getAfter().back(), reachingDef,
+ reachingAtBlockEnd);
+
+ SmallVector<Type> resultTypes(getResultTypes());
+ resultTypes.push_back(slot.elemType);
+
+ IRRewriter rewriter(builder);
+ Operation *newOp =
+ replaceWithNewResults(rewriter, getOperation(), resultTypes);
+ return newOp->getResults().back();
+}
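The `isRegionPromotable` implementations for `ForallOp`, `ParallelOp`, and `ReduceOp` above reject regions that store to the slot. One way to see why: with no ordering between iterations, a storing body leaves no single well-defined value to carry out of the operation. The hypothetical sketch below (plain C++, not MLIR; `Update`, `runInOrder`, and `orderIndependent` are illustrative names) runs the same iterations in two orders and compares the final slot value. Agreement between two orders is only evidence, not a proof of order independence.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical per-iteration effect on the slot: returns the slot's new value.
using Update = std::function<int(int slotValue, int iv)>;

// Applies the iterations sequentially in the given order.
int runInOrder(int init, const std::vector<int> &ivs, const Update &update) {
  int slot = init;
  for (int iv : ivs)
    slot = update(slot, iv);
  return slot;
}

// Compares forward and reverse iteration orders. A load-only body leaves the
// slot unchanged and always agrees; a storing body generally does not.
bool orderIndependent(int init, int n, const Update &update) {
  std::vector<int> forward(n);
  for (int i = 0; i < n; ++i)
    forward[i] = i;
  std::vector<int> reverse(forward.rbegin(), forward.rend());
  return runInOrder(init, forward, update) == runInOrder(init, reverse, update);
}
```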
>From 3d2bf5119bbb8d3808c0144cb7f008588b399ed7 Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Thu, 5 Mar 2026 14:49:43 +0100
Subject: [PATCH 04/15] remove liveness analysis from mem2reg
---
.../mlir/Interfaces/MemorySlotInterfaces.td | 15 +-
mlir/lib/Dialect/SCF/IR/MemorySlot.cpp | 94 -------
mlir/lib/Transforms/Mem2Reg.cpp | 218 ++++-----------
mlir/test/Dialect/LLVMIR/mem2reg.mlir | 10 +-
mlir/test/Dialect/SCF/mem2reg.mlir | 254 ++++++++++++++++++
mlir/test/Transforms/mem2reg.mlir | 51 ++++
6 files changed, 366 insertions(+), 276 deletions(-)
create mode 100644 mlir/test/Dialect/SCF/mem2reg.mlir
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index b09d88bd171fd..ab2cb39227525 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -53,6 +53,8 @@ def PromotableAllocationOpInterface
InterfaceMethod<[{
Hook triggered for every new block argument added to a block.
This will only be called for slots declared by this operation.
+ This function is called after removal of blocking uses, meaning
+ only operations that will be deleted remain users of the slot.
The builder is located at the beginning of the block on call. All IR
mutations must happen through the builder.
@@ -285,19 +287,6 @@ def PromotableRegionOpInterface
"bool":$hasValueStores
)
>,
- InterfaceMethod<[{
- Specifies which operations should be marked as live-in with respect to
- the value of the provided slot assuming the provided region is live-in.
-
- In other words, specifies which operation immediately passes control to
- the region without storing to the value of the slot.
- }], "void", "propagateLiveIn",
- (ins
- "const ::mlir::MemorySlot &":$slot,
- "::mlir::Region *":$regionLiveIn,
- "::mlir::SmallPtrSetImpl<::mlir::Operation *> &":$operationsLiveIn
- )
- >,
InterfaceMethod<[{
Called before processing the nested regions in the operation.
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index c1b2e122c45a2..0b1bbfb17dce0 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -64,14 +64,6 @@ bool ExecuteRegionOp::isRegionPromotable(const MemorySlot &slot, Region *region,
return true;
}
-void ExecuteRegionOp::propagateLiveIn(
- const MemorySlot &slot, Region *regionLiveIn,
- SmallPtrSetImpl<Operation *> &operationsLiveIn) {
- assert(regionLiveIn == &getRegion() &&
- "regionLiveIn can only be the region of the ExecuteRegionOp");
- operationsLiveIn.insert(getOperation());
-}
-
void ExecuteRegionOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
@@ -107,14 +99,6 @@ bool ForOp::isRegionPromotable(const MemorySlot &slot, Region *region,
return true;
}
-void ForOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
- SmallPtrSetImpl<Operation *> &operationsLiveIn) {
- assert(regionLiveIn == &getBodyRegion() &&
- "regionLiveIn can only be the region of the ForOp");
- operationsLiveIn.insert(getOperation());
- operationsLiveIn.insert(getBody()->getTerminator());
-}
-
void ForOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
@@ -156,16 +140,6 @@ bool ForallOp::isRegionPromotable(const MemorySlot &slot, Region *region,
return !hasValueStores;
}
-void ForallOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
- SmallPtrSetImpl<Operation *> &operationsLiveIn) {
- assert(regionLiveIn == &getBodyRegion() &&
- "regionLiveIn can only be the region of the ForallOp");
- // Due to the parallel semantics of ForallOp, there is no liveness dependency
- // on the body region as liveness cannot be influenced by neighboring
- // iterations.
- operationsLiveIn.insert(getOperation());
-}
-
void ForallOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
@@ -189,12 +163,6 @@ bool IfOp::isRegionPromotable(const MemorySlot &slot, Region *region,
return true;
}
-void IfOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
- SmallPtrSetImpl<Operation *> &operationsLiveIn) {
- assert(regionLiveIn == &getThenRegion() || regionLiveIn == &getElseRegion());
- operationsLiveIn.insert(getOperation());
-}
-
void IfOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
@@ -240,28 +208,6 @@ bool IndexSwitchOp::isRegionPromotable(const MemorySlot &slot, Region *region,
return true;
}
-void IndexSwitchOp::propagateLiveIn(
- const MemorySlot &slot, Region *regionLiveIn,
- SmallPtrSetImpl<Operation *> &operationsLiveIn) {
-#ifndef NDEBUG
- auto checkRegionValid = [&](IndexSwitchOp op) {
- if (regionLiveIn == &op.getDefaultRegion())
- return;
-
- for (Region &caseRegion : op.getCaseRegions())
- if (regionLiveIn == &caseRegion)
- return;
-
- assert(false && "regionLiveIn can only be the default region or a case "
- "region of the IndexSwitchOp");
- };
-
- checkRegionValid(*this);
-#endif
-
- operationsLiveIn.insert(getOperation());
-}
-
void IndexSwitchOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
@@ -305,17 +251,6 @@ bool ParallelOp::isRegionPromotable(const MemorySlot &slot, Region *region,
return !hasValueStores;
}
-void ParallelOp::propagateLiveIn(
- const MemorySlot &slot, Region *regionLiveIn,
- SmallPtrSetImpl<Operation *> &operationsLiveIn) {
- assert(regionLiveIn == &getBodyRegion() &&
- "regionLiveIn can only be the region of the ParallelOp");
- // Due to the parallel semantics of ParallelOp, there is no liveness
- // dependency on the body region as liveness cannot be influenced by
- // neighboring iterations.
- operationsLiveIn.insert(getOperation());
-}
-
void ParallelOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
@@ -341,24 +276,6 @@ bool ReduceOp::isRegionPromotable(const MemorySlot &slot, Region *region,
return !hasValueStores;
}
-void ReduceOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
- SmallPtrSetImpl<Operation *> &operationsLiveIn) {
-#ifndef NDEBUG
- auto checkRegionValid = [&](ReduceOp op) {
- for (Region &reduction : op.getReductions())
- if (regionLiveIn == &reduction)
- return;
-
- assert(false &&
- "regionLiveIn can only be a reduction region of the ReduceOp");
- };
-
- checkRegionValid(*this);
-#endif
-
- operationsLiveIn.insert(getOperation());
-}
-
void ReduceOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
@@ -383,17 +300,6 @@ bool WhileOp::isRegionPromotable(const MemorySlot &slot, Region *region,
return true;
}
-void WhileOp::propagateLiveIn(const MemorySlot &slot, Region *regionLiveIn,
- SmallPtrSetImpl<Operation *> &operationsLiveIn) {
- if (regionLiveIn == &getBefore()) {
- operationsLiveIn.insert(getOperation());
- operationsLiveIn.insert(getAfterBody()->getTerminator());
- }
-
- assert(regionLiveIn == &getAfter());
- operationsLiveIn.insert(getBeforeBody()->getTerminator());
-}
-
void WhileOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 26ab65c68b259..6c1704331587e 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -176,24 +176,12 @@ class MemorySlotPromotionAnalyzer {
RegionBlockingUsesMap &userToBlockingUses,
DenseMap<Region *, RegionPromotionInfo> &regionsToPromote);
- /// Computes in which blocks the value stored in the slot is actually used,
- /// meaning blocks leading to a load. This method uses `definingBlocks`, the
- /// set of blocks containing a store to the slot (defining the value of the
- /// slot).
- /// The analysis is aware of regions and uses region promotion information
- /// to determine the effect of nested regions on slot value liveness.
- SmallPtrSet<Block *, 16> computeSlotLiveIn(
- DenseMap<Region *, SmallPtrSet<Block *, 16>> &definingBlocksByRegion,
- DenseMap<Region *, RegionPromotionInfo> &regionsToPromote);
-
/// Computes the points in the provided region where multiple re-definitions
/// of the slot's value (stores) may conflict.
/// `definingBlocks` is the set of blocks containing a store to the slot,
/// either directly or inherited from a nested region.
- /// `slotLiveIn` is the set of blocks where the memory slot is live-in.
void computeMergePoints(Region *region,
SmallPtrSetImpl<Block *> &definingBlocks,
- SmallPtrSetImpl<Block *> &slotLiveIn,
SmallPtrSetImpl<Block *> &mergePoints);
/// Ensures predecessors of merge points can properly provide their current
@@ -256,6 +244,10 @@ class MemorySlotPromoter {
/// topological order.
void removeBlockingUses(Region *region);
+ /// Links merge point block arguments to the terminators targeting the merge
+ /// point, or removes the argument if it is not used.
+ void linkMergePoints();
+
/// Lazily-constructed default value representing the content of the slot when
/// no store has been executed. This function may mutate IR.
Value getOrCreateDefaultValue();
@@ -454,127 +446,15 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
return success();
}
-/// Returns true if the operation contains a store, whether itself or in a
-/// nested region.
-static bool
-isStoreLike(Operation *op, MemorySlot &slot,
- DenseMap<Region *, RegionPromotionInfo> &regionsToPromote) {
- auto promotableMemOp = dyn_cast<PromotableMemOpInterface>(op);
- if (promotableMemOp && promotableMemOp.storesTo(slot))
- return true;
-
- auto promotableRegionOp = dyn_cast<PromotableRegionOpInterface>(op);
- if (!promotableRegionOp)
- return false;
-
- for (Region &region : op->getRegions()) {
- auto regionInfoIt = regionsToPromote.find(&region);
- if (regionInfoIt == regionsToPromote.end())
- continue;
-
- if (regionInfoIt->second.hasValueStores)
- return true;
- }
-
- return false;
-}
-
-SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
- DenseMap<Region *, SmallPtrSet<Block *, 16>> &definingBlocksByRegion,
- DenseMap<Region *, RegionPromotionInfo> &regionsToPromote) {
- SmallPtrSet<Block *, 16> liveIn;
-
- // The worklist contains blocks in which it is known that the slot value is
- // live-in. The further blocks where this value is live-in will be inferred
- // from these.
- SmallVector<Block *> liveInWorkList;
-
- SmallPtrSet<Operation *, 4> regionPredecessorScratch;
-
- // Blocks with a load before any other store to the slot are the starting
- // points of the analysis. The slot value is definitely live-in in those
- // blocks.
- SmallPtrSet<Block *, 16> visited;
- for (Operation *user : slot.ptr.getUsers()) {
- if (!visited.insert(user->getBlock()).second)
- continue;
-
- for (Operation &op : user->getBlock()->getOperations()) {
- if (auto memOp = dyn_cast<PromotableMemOpInterface>(op)) {
- // If this operation loads the slot, it is loading from it before
- // ever writing to it, so the value is live-in in this block.
- if (memOp.loadsFrom(slot)) {
- liveInWorkList.push_back(user->getBlock());
- break;
- }
-
- // If we store to the slot, further loads will see that value.
- // Because we did not meet any load before, the value is not live-in.
- if (memOp.storesTo(slot))
- break;
- }
- }
- }
-
- // The information is then propagated to the predecessors until a def site
- // (store) is found.
- while (!liveInWorkList.empty()) {
- Block *liveInBlock = liveInWorkList.pop_back_val();
-
- if (!liveIn.insert(liveInBlock).second)
- continue;
-
- // If a predecessor is a defining block, either:
- // - It has a load before its first store, in which case it is live-in but
- // has already been processed in the initialisation step.
- // - It has a store before any load, in which case it is not live-in.
- // We can thus at this stage insert to the worklist only predecessors that
- // are not defining blocks.
- for (Block *pred : liveInBlock->getPredecessors())
- if (!definingBlocksByRegion[pred->getParent()].contains(pred))
- liveInWorkList.push_back(pred);
-
- // The logic is a little more complicated for region predecessors as they
- // could be in the middle of a block. We thus need to look for a store
- // within the predecessor block specifically before the region predecessor
- // operation.
- if (liveInBlock->isEntryBlock() &&
- liveInBlock->getParent() != slot.ptr.getParentRegion()) {
- regionPredecessorScratch.clear();
- auto parentOp =
- cast<PromotableRegionOpInterface>(liveInBlock->getParentOp());
- parentOp.propagateLiveIn(slot, liveInBlock->getParent(),
- regionPredecessorScratch);
- for (Operation *pred : regionPredecessorScratch) {
- if (liveIn.contains(pred->getBlock()))
- continue;
-
- Operation *storeCandidate = pred;
- while (storeCandidate &&
- !isStoreLike(storeCandidate, slot, regionsToPromote))
- storeCandidate = storeCandidate->getPrevNode();
-
- if (!storeCandidate)
- liveInWorkList.push_back(pred->getBlock());
- }
- }
- }
-
- return liveIn;
-}
-
using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
void MemorySlotPromotionAnalyzer::computeMergePoints(
Region *region, SmallPtrSetImpl<Block *> &definingBlocks,
- SmallPtrSetImpl<Block *> &slotLiveIn,
SmallPtrSetImpl<Block *> &mergePoints) {
if (region->hasOneBlock())
return;
IDFCalculator idfCalculator(dominance.getDomTree(region));
-
idfCalculator.setDefiningBlocks(definingBlocks);
- idfCalculator.setLiveInBlocks(slotLiveIn);
SmallVector<Block *> mergePointsVec;
idfCalculator.calculate(mergePointsVec);
@@ -619,17 +499,11 @@ MemorySlotPromotionAnalyzer::computeInfo() {
definingBlocks[region->getParentRegion()].insert(
region->getParentOp()->getBlock());
- // TODO: When all regions involved are single-block (fairly common in
- // region-based control-flow), there cannot be any merge points, so we could
- // skip this costly analysis and its dependencies.
- SmallPtrSet<Block *, 16> slotLiveIn =
- computeSlotLiveIn(definingBlocks, info.regionsToPromote);
-
// Then, compute blocks in which two or more definitions of the allocated
// variable may conflict. These blocks will need a new block argument to
// accommodate this.
for (auto &[region, defBlocks] : definingBlocks)
- computeMergePoints(region, defBlocks, slotLiveIn, info.mergePoints);
+ computeMergePoints(region, defBlocks, info.mergePoints);
// The slot can be promoted if the block arguments to be created can
// actually be populated with values, which may not be possible depending
@@ -752,28 +626,11 @@ void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
if (info.mergePoints.contains(block)) {
BlockArgument blockArgument =
block->addArgument(slot.elemType, slot.ptr.getLoc());
- builder.setInsertionPointToStart(block);
- allocator.handleBlockArgument(slot, blockArgument, builder);
job.reachingDef = blockArgument;
-
- if (statistics.newBlockArgumentAmount)
- (*statistics.newBlockArgumentAmount)++;
}
job.reachingDef = promoteInBlock(block, job.reachingDef);
- if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
- for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
- if (info.mergePoints.contains(blockOperand.get())) {
- if (!job.reachingDef)
- job.reachingDef = getOrCreateDefaultValue();
-
- terminator.getSuccessorOperands(blockOperand.getOperandNumber())
- .append(job.reachingDef);
- }
- }
- }
-
for (auto *child : job.block->children())
dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
}
@@ -861,6 +718,54 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
}
}
+void MemorySlotPromoter::linkMergePoints() {
+ // We want to eliminate unused block arguments. Linking a merge point's
+ // argument to its predecessors may make another merge point's previously
+ // unused argument used, so merge points are processed with an expanding
+ // worklist, mergePointArgsToProcess.
+
+ SmallPtrSet<BlockArgument, 8> mergePointArgsUnused;
+ SmallVector<BlockArgument> mergePointArgsToProcess;
+ for (Block *mergePoint : info.mergePoints) {
+ BlockArgument arg = mergePoint->getArguments().back();
+ if (arg.use_empty())
+ mergePointArgsUnused.insert(arg);
+ else
+ mergePointArgsToProcess.push_back(arg);
+ }
+
+ while (!mergePointArgsToProcess.empty()) {
+ BlockArgument arg = mergePointArgsToProcess.pop_back_val();
+ Block *mergePoint = arg.getOwner();
+
+ for (BlockOperand &use : mergePoint->getUses()) {
+ Value reachingDef = reachingAtBlockEnd[use.getOwner()->getBlock()];
+ if (!reachingDef)
+ reachingDef = getOrCreateDefaultValue();
+
+ // If the reaching definition is a block argument of an unused merge
+ // point, mark it as used and process it as such later.
+ auto reachingDefArgument = dyn_cast<BlockArgument>(reachingDef);
+ if (reachingDefArgument &&
+ mergePointArgsUnused.erase(reachingDefArgument))
+ mergePointArgsToProcess.push_back(reachingDefArgument);
+
+ BranchOpInterface user = cast<BranchOpInterface>(use.getOwner());
+ user.getSuccessorOperands(use.getOperandNumber()).append(reachingDef);
+ }
+
+ builder.setInsertionPointToStart(mergePoint);
+ allocator.handleBlockArgument(slot, arg, builder);
+ if (statistics.newBlockArgumentAmount)
+ (*statistics.newBlockArgumentAmount)++;
+ }
+
+ for (BlockArgument arg : mergePointArgsUnused) {
+ Block *mergePoint = arg.getOwner();
+ mergePoint->eraseArgument(mergePoint->getNumArguments() - 1);
+ }
+}
+
std::optional<PromotableAllocationOpInterface>
MemorySlotPromoter::promoteSlot() {
// Perform the promotion recursively through nested regions. The reaching
@@ -876,26 +781,15 @@ MemorySlotPromoter::promoteSlot() {
op.visitReplacedValues(replacedValues, builder);
}
+ // Finally, connect merge points to their predecessors' reaching definitions.
+ linkMergePoints();
+
for (Operation *toEraseOp : toErase)
toEraseOp->erase();
assert(slot.ptr.use_empty() &&
"after promotion, the slot pointer should not be used anymore");
- // Update terminators in dead branches to forward default if they are
- // succeeded by a merge points.
- for (Block *mergePoint : info.mergePoints) {
- for (BlockOperand &use : mergePoint->getUses()) {
- auto user = cast<BranchOpInterface>(use.getOwner());
- SuccessorOperands succOperands =
- user.getSuccessorOperands(use.getOperandNumber());
- assert(succOperands.size() == mergePoint->getNumArguments() ||
- succOperands.size() + 1 == mergePoint->getNumArguments());
- if (succOperands.size() + 1 == mergePoint->getNumArguments())
- succOperands.append(getOrCreateDefaultValue());
- }
- }
-
LDBG() << "Promoted memory slot: " << slot.ptr;
if (statistics.promotedAmount)
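With `setLiveInBlocks` no longer called, `computeMergePoints` asks `IDFCalculator` for the plain iterated dominance frontier of the defining blocks, i.e. minimal rather than pruned SSA placement. A minimal sketch of that computation on a toy CFG of integer block ids follows; the graph representation and the `computeDominators` / `iteratedDominanceFrontier` names are hypothetical, this is not LLVM's `IDFCalculatorBase` implementation, and the naive iterative dominator computation is only meant for small examples.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Toy CFG (hypothetical representation): succs[b] lists the successors of
// block b, and block 0 is the entry. All blocks are assumed reachable.
using CFG = std::vector<std::vector<int>>;

// Naive iterative dominators: dom[b] is the set of blocks dominating b.
static std::vector<std::set<int>> computeDominators(const CFG &succs) {
  const int n = static_cast<int>(succs.size());
  std::vector<std::vector<int>> preds(n);
  for (int b = 0; b < n; ++b)
    for (int s : succs[b])
      preds[s].push_back(b);

  std::set<int> all;
  for (int b = 0; b < n; ++b)
    all.insert(b);
  std::vector<std::set<int>> dom(n, all);
  dom[0] = {0};

  bool changed = true;
  while (changed) {
    changed = false;
    for (int b = 1; b < n; ++b) {
      // Intersect the dominator sets of all predecessors, then add b itself.
      std::set<int> meet = all;
      for (int p : preds[b]) {
        std::set<int> kept;
        for (int d : meet)
          if (dom[p].count(d))
            kept.insert(d);
        meet = kept;
      }
      meet.insert(b);
      if (meet != dom[b]) {
        dom[b] = meet;
        changed = true;
      }
    }
  }
  return dom;
}

// Iterated dominance frontier of `defs`: the blocks where definitions coming
// from different paths may meet, i.e. where a block argument is placed.
std::set<int> iteratedDominanceFrontier(const CFG &succs,
                                        const std::set<int> &defs) {
  const int n = static_cast<int>(succs.size());
  std::vector<std::set<int>> dom = computeDominators(succs);

  // y is in DF(x) iff x dominates a predecessor p of y but does not strictly
  // dominate y.
  std::vector<std::set<int>> frontier(n);
  for (int p = 0; p < n; ++p)
    for (int y : succs[p])
      for (int x : dom[p])
        if (x == y || !dom[y].count(x))
          frontier[x].insert(y);

  std::set<int> result;
  std::set<int> queued(defs.begin(), defs.end());
  std::vector<int> worklist(defs.begin(), defs.end());
  while (!worklist.empty()) {
    int x = worklist.back();
    worklist.pop_back();
    for (int y : frontier[x]) {
      result.insert(y);
      // A merge point is itself a new definition; propagate from it once.
      if (queued.insert(y).second)
        worklist.push_back(y);
    }
  }
  return result;
}
```

For a diamond whose two arms both store (blocks 1 and 2 joining at 3), the only merge point is block 3, which mirrors the `^bb3` argument expected in `@execute_region_cfg`. Because liveness no longer prunes this set, a merge point may receive an argument even when no load follows it, which is why `linkMergePoints` erases the arguments that end up unused.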
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 716a5860a0c07..779d72000d543 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -613,24 +613,20 @@ llvm.func @use(i64)
// -----
-// This test should no longer be an issue once promotion within subregions
-// is supported.
// CHECK-LABEL: llvm.func @subregion_block_promotion
// CHECK-SAME: (%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) -> i64
llvm.func @subregion_block_promotion(%arg0: i64, %arg1: i64) -> i64 {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[ALLOCA:.*]] = llvm.alloca
+ // CHECK-NOT: = llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
- // CHECK: llvm.store %[[ARG1]], %[[ALLOCA]]
llvm.store %arg1, %1 {alignment = 4 : i64} : i64, !llvm.ptr
- // CHECK: scf.execute_region {
+ // CHECK: %[[RES:.*]] = scf.execute_region -> i64 {
scf.execute_region {
- // CHECK: llvm.store %[[ARG0]], %[[ALLOCA]]
llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
scf.yield
}
+ // CHECK: scf.yield %[[ARG0]] : i64
// CHECK: }
- // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
// CHECK: llvm.return %[[RES]] : i64
llvm.return %2 : i64
diff --git a/mlir/test/Dialect/SCF/mem2reg.mlir b/mlir/test/Dialect/SCF/mem2reg.mlir
new file mode 100644
index 0000000000000..c8be552a71f36
--- /dev/null
+++ b/mlir/test/Dialect/SCF/mem2reg.mlir
@@ -0,0 +1,254 @@
+// RUN: mlir-opt %s --mem2reg --split-input-file | FileCheck %s \
+// RUN: -implicit-check-not "memref.alloca" \
+// RUN: -implicit-check-not "memref.load" \
+// RUN: -implicit-check-not "memref.store"
+
+// Check that loads within the regions of an scf.if are promoted.
+
+// CHECK-LABEL: func.func @if_load_only
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.if %[[COND]] -> (i32)
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @if_load_only(%cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ %res = scf.if %cond -> i32 {
+ %load = memref.load %alloca[] : memref<i32>
+ scf.yield %load : i32
+ } else {
+ scf.yield %c5 : i32
+ }
+ return %res : i32
+}
+
+// -----
+
+// Check load promotion through an if with no else branch.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @if_no_else_load
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: scf.if %[[COND]] {
+// CHECK: call @use(%[[C5]])
+// CHECK: }
+// CHECK: call @use(%[[C5]])
+func.func @if_no_else_load(%cond: i1) {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.if %cond {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ func.call @use(%load2) : (i32) -> ()
+ return
+}
+
+// -----
+
+// Check store promotion through an if with no else branch.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @if_no_else_store
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (i32)
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: call @use(%[[IF]])
+func.func @if_no_else_store(%cond: i1) {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.if %cond {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ return
+}
+
+// -----
+
+// Check store promotion through nested ifs with no else branches.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @if_nested_store
+// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[OUTER:.*]] = scf.if %[[COND0]] -> (i32)
+// CHECK: %[[INNER:.*]] = scf.if %[[COND1]] -> (i32)
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: scf.yield %[[INNER]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: call @use(%[[OUTER]])
+func.func @if_nested_store(%cond0: i1, %cond1: i1) {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.if %cond0 {
+ scf.if %cond1 {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ }
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ return
+}
+
+// -----
+
+// Check load promotion through execute_region.
+
+// CHECK-LABEL: func.func @execute_region_load
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @execute_region_load() -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ %res = scf.execute_region -> i32 {
+ %load = memref.load %alloca[] : memref<i32>
+ scf.yield %load : i32
+ }
+ return %res : i32
+}
+
+// -----
+
+// Check store promotion through execute_region.
+
+// CHECK-LABEL: func.func @execute_region_store
+// CHECK: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @execute_region_store() -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.execute_region {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check promotion through an execute_region with CFG control flow and a
+// nested if containing a load. This ensures a block argument is created
+// even in blocks with no direct slot use.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @execute_region_cfg
+// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1)
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: cf.cond_br %[[COND0]], ^[[BB1:.*]], ^[[BB2:.*]]
+// CHECK: ^[[BB1]]:
+// CHECK: cf.br ^[[BB3:.*]](%[[C7]] : i32)
+// CHECK: ^[[BB2]]:
+// CHECK: cf.br ^[[BB3]](%[[C9]] : i32)
+// CHECK: ^[[BB3]](%[[VAL:.*]]: i32):
+// CHECK: scf.if %[[COND1]] {
+// CHECK: call @use(%[[VAL]])
+// CHECK: }
+// CHECK: scf.yield %[[VAL]] : i32
+// CHECK: }
+// CHECK: call @use(%[[RES]])
+func.func @execute_region_cfg(%cond0: i1, %cond1: i1) {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %c9 = arith.constant 9 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.execute_region {
+ cf.cond_br %cond0, ^bb1, ^bb2
+ ^bb1:
+ memref.store %c7, %alloca[] : memref<i32>
+ cf.br ^bb3
+ ^bb2:
+ memref.store %c9, %alloca[] : memref<i32>
+ cf.br ^bb3
+ ^bb3:
+ scf.if %cond1 {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.yield
+ }
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ func.call @use(%load2) : (i32) -> ()
+ return
+}
+
+// CHECK-LABEL: func.func @execute_region_cfg_no_use_at_all
+// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1)
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: cf.cond_br %[[COND0]], ^[[BB1:.*]], ^[[BB2:.*]]
+// CHECK: ^[[BB1]]:
+// CHECK: cf.br ^[[BB3:.*]](%[[C7]] : i32)
+// CHECK: ^[[BB2]]:
+// CHECK: cf.br ^[[BB3]](%[[C9]] : i32)
+// CHECK: ^[[BB3]](%[[VAL:.*]]: i32):
+// CHECK: scf.yield %[[VAL]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @execute_region_cfg_no_use_at_all(%cond0: i1, %cond1: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %c9 = arith.constant 9 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.execute_region {
+ cf.cond_br %cond0, ^bb1, ^bb2
+ ^bb1:
+ memref.store %c7, %alloca[] : memref<i32>
+ cf.br ^bb3
+ ^bb2:
+ memref.store %c9, %alloca[] : memref<i32>
+ cf.br ^bb3
+ ^bb3:
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ return %load2 : i32
+}
diff --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir
index 4b27f3305e89d..2128a5cd9ffd3 100644
--- a/mlir/test/Transforms/mem2reg.mlir
+++ b/mlir/test/Transforms/mem2reg.mlir
@@ -39,3 +39,54 @@ test.isolated_graph_region {
%a = memref.load %slot[] : memref<i32>
"test.foo"() : () -> ()
}
+
+// -----
+
+// Verifies that pre-existing block arguments of merge points are not mistaken
+// for the slot arguments added by mem2reg. Here, ^merge has a pre-existing
+// block argument (%genuine) and mem2reg adds a second one for the promoted
+// slot. %genuine, not the slot argument, is then stored and thus becomes a
+// reaching definition for the follow-up merge point ^final. If the unused
+// merge point propagation logic identified merge points by block rather than
+// by specific block argument, it would take the use of %genuine as a use of
+// the slot argument of ^merge and keep it alive. In other words, the used
+// genuine argument would mask that the actual slot argument is unused.
+
+// CHECK-LABEL: func.func @merge_point_arg_not_confused
+// CHECK-SAME: (%[[COND:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32) -> i32
+// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]]
+// CHECK: ^[[BB1]]:
+// CHECK: cf.br ^[[MERGE:.*]](%[[A]] : i32)
+// CHECK: ^[[BB2]]:
+// CHECK: cf.br ^[[MERGE]](%[[B]] : i32)
+// CHECK: ^[[MERGE]](%[[GENUINE:.*]]: i32):
+// CHECK: cf.cond_br %[[COND]], ^[[BB3:.*]], ^[[BB4:.*]]
+// CHECK: ^[[BB3]]:
+// CHECK: cf.br ^[[FINAL:.*]](%[[GENUINE]] : i32)
+// CHECK: ^[[BB4]]:
+// CHECK: %[[DUMMY:.*]] = arith.constant 0 : i32
+// CHECK: cf.br ^[[FINAL]](%[[DUMMY]] : i32)
+// CHECK: ^[[FINAL]](%[[FINAL_SLOT:.*]]: i32):
+// CHECK: return %[[FINAL_SLOT]] : i32
+func.func @merge_point_arg_not_confused(%cond: i1, %a: i32, %b: i32) -> i32 {
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %a, %alloca[] : memref<i32>
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ memref.store %b, %alloca[] : memref<i32>
+ cf.br ^merge(%a : i32)
+^bb2:
+ cf.br ^merge(%b : i32)
+^merge(%genuine: i32):
+ cf.cond_br %cond, ^bb3, ^bb4
+^bb3:
+ memref.store %genuine, %alloca[] : memref<i32>
+ cf.br ^final
+^bb4:
+ %dummy = arith.constant 0 : i32
+ memref.store %dummy, %alloca[] : memref<i32>
+ cf.br ^final
+^final:
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
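The unused-argument elimination performed by `linkMergePoints` can be modeled as a small worklist over slot arguments. In this sketch, values are plain integer ids and each merge point lists the reaching definition forwarded along each incoming edge; the `MergePoint` struct and `liveSlotArgs` function are hypothetical stand-ins, not MLIR API. Indexing by the slot argument's id rather than by its owning block is what keeps a pre-existing argument such as `%genuine` from being mistaken for the slot argument.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Toy model (hypothetical, not the MLIR API): every SSA value is an int id.
struct MergePoint {
  int slotArg;               // Id of the block argument added for the slot.
  std::vector<int> incoming; // Reaching definition forwarded by each
                             // predecessor edge.
};

// Returns the ids of the slot arguments that must be kept; the remaining slot
// arguments are unused and can be erased.
std::set<int> liveSlotArgs(const std::vector<MergePoint> &mergePoints,
                           const std::set<int> &directlyUsed) {
  // Index merge points by their slot argument value rather than by block, so
  // that another argument of the same block is never mistaken for it.
  std::map<int, const MergePoint *> bySlotArg;
  for (const MergePoint &mp : mergePoints)
    bySlotArg[mp.slotArg] = &mp;

  // Seed the worklist with slot arguments that already have uses.
  std::set<int> live;
  std::vector<int> worklist;
  for (const MergePoint &mp : mergePoints)
    if (directlyUsed.count(mp.slotArg) && live.insert(mp.slotArg).second)
      worklist.push_back(mp.slotArg);

  while (!worklist.empty()) {
    int arg = worklist.back();
    worklist.pop_back();
    // Linking `arg` to its predecessors uses every forwarded value; a value
    // that is another merge point's slot argument becomes live in turn.
    for (int def : bySlotArg.at(arg)->incoming)
      if (bySlotArg.count(def) && live.insert(def).second)
        worklist.push_back(def);
  }
  return live;
}
```

Encoding `@merge_point_arg_not_confused` with ids a=1, b=2, %genuine=3, the slot argument of ^merge=4, %dummy=5 and the slot argument of ^final=6, and marking 6 as directly used by the replaced load, leaves only 6 live, so the slot argument of ^merge is erased as the CHECK lines expect.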
>From 47993aa7cfc294d6960e09037c46f10e940da062 Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Thu, 5 Mar 2026 17:26:59 +0100
Subject: [PATCH 05/15] add more tests
---
mlir/lib/Dialect/SCF/IR/MemorySlot.cpp | 26 +-
mlir/test/Dialect/SCF/mem2reg-reject.mlir | 160 ++++++
mlir/test/Dialect/SCF/mem2reg.mlir | 579 +++++++++++++++++++++-
mlir/test/Transforms/mem2reg.mlir | 23 +
4 files changed, 771 insertions(+), 17 deletions(-)
create mode 100644 mlir/test/Dialect/SCF/mem2reg-reject.mlir
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index 0b1bbfb17dce0..89f816ff8bf69 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -20,22 +20,22 @@ using namespace mlir::scf;
/// the terminator is of the provided type.
template <typename TermTy>
static void
-updateTerminator(Block *block, Value reachingDef,
+updateTerminator(Block *block, Value defaultReachingDef,
llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
Operation *terminator = block->getTerminator();
if (!isa<TermTy>(terminator))
return;
Value blockReachingDef = reachingAtBlockEnd[block];
if (!blockReachingDef) {
- // Block is dead code or the region is not using the slot, so the reaching
- // definition is the entry reaching definition.
- blockReachingDef = reachingDef;
+ // Block is dead code or the region is not using the slot, so we use the
+ // default provided reaching definition.
+ blockReachingDef = defaultReachingDef;
}
terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
}
-/// Creates a shallow copy of an operation with new result types moving the
-/// regions out of the original operation, then deletes the original operation.
+/// Creates a shallow copy of an operation with new result types, moving the
+/// regions out of the original operation and deleting the original operation.
static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
TypeRange resultTypes) {
RewriterBase::InsertionGuard guard(rewriter);
@@ -103,9 +103,12 @@ void ForOp::setupPromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
Region &bodyRegion = getBodyRegion();
- if (!hasValueStores)
+ if (!hasValueStores) {
regionsToProcess.insert({&bodyRegion, reachingDef});
+ return;
+ }
+ getInitArgsMutable().append(reachingDef);
bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc());
regionsToProcess.insert({&bodyRegion, bodyRegion.getArguments().back()});
}
@@ -311,6 +314,8 @@ void WhileOp::setupPromotion(
return;
}
+ getInitsMutable().append(reachingDef);
+
beforeRegion.addArgument(slot.elemType, slot.ptr.getLoc());
regionsToProcess.insert({&beforeRegion, beforeRegion.getArguments().back()});
@@ -326,10 +331,11 @@ Value WhileOp::finalizePromotion(
// Update the yield terminators to return the newly defined reaching
// definition.
- updateTerminator<ConditionOp>(&getBefore().back(), reachingDef,
+ updateTerminator<ConditionOp>(&getBefore().back(),
+ getBefore().getArguments().back(),
reachingAtBlockEnd);
- updateTerminator<YieldOp>(&getAfter().back(), reachingDef,
- reachingAtBlockEnd);
+ updateTerminator<YieldOp>(
+ &getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
SmallVector<Type> resultTypes(getResultTypes());
resultTypes.push_back(slot.elemType);
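The `ForOp::setupPromotion` hunk above changes two things. First, the load-only path now returns early. Before this change it fell through and also added a body block argument. Second, the store path now appends the entry definition to the init operands, so every new body argument has a matching iter_arg. The sketch below is a minimal standalone model of that logic. `ToyForLoop` and `ValueId` are hypothetical stand-ins, not MLIR types.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for an SSA value: just an integer id.
using ValueId = int;

// Toy model of an scf.for: loop-carried init operands and the body's
// block arguments (index 0 is the induction variable).
struct ToyForLoop {
  std::vector<ValueId> initArgs;
  std::vector<ValueId> bodyArgs;
  ValueId nextId = 1000; // Source of fresh ids for new block arguments.
};

// Returns the reaching definition at the entry of the loop body.
ValueId setupForPromotion(ToyForLoop &loop, ValueId entryDef,
                          bool hasValueStores) {
  // Load-only: the body reads the entry definition directly and no
  // iter_arg is created (the early return added by the patch).
  if (!hasValueStores)
    return entryDef;
  // With stores: thread the slot through a new iter_arg. The init operand
  // and the block argument are appended together so they stay paired.
  loop.initArgs.push_back(entryDef);
  ValueId arg = loop.nextId++;
  loop.bodyArgs.push_back(arg);
  return arg;
}
```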
diff --git a/mlir/test/Dialect/SCF/mem2reg-reject.mlir b/mlir/test/Dialect/SCF/mem2reg-reject.mlir
new file mode 100644
index 0000000000000..9497bf47ff09b
--- /dev/null
+++ b/mlir/test/Dialect/SCF/mem2reg-reject.mlir
@@ -0,0 +1,160 @@
+// RUN: mlir-opt %s --mem2reg --split-input-file | FileCheck %s
+
+// Check that a store inside a forall prevents promotion.
+
+// CHECK-LABEL: func.func @forall_store
+// CHECK-SAME: (%[[UB:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+// CHECK: memref.store %[[C5]], %[[ALLOCA]][]
+// CHECK: scf.forall (%{{.*}}) in (%[[UB]]) {
+// CHECK: memref.store %[[C7]], %[[ALLOCA]][]
+// CHECK: }
+// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][]
+// CHECK: return %[[LOAD]] : i32
+func.func @forall_store(%ub: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.forall (%i) in (%ub) {
+ memref.store %c7, %alloca[] : memref<i32>
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check that a store inside an if inside a forall prevents promotion.
+
+// CHECK-LABEL: func.func @forall_if_store
+// CHECK-SAME: (%[[UB:.*]]: index, %[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+// CHECK: memref.store %[[C5]], %[[ALLOCA]][]
+// CHECK: scf.forall (%{{.*}}) in (%[[UB]]) {
+// CHECK: scf.if %[[COND]] {
+// CHECK: memref.store %[[C7]], %[[ALLOCA]][]
+// CHECK: }
+// CHECK: }
+// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][]
+// CHECK: return %[[LOAD]] : i32
+func.func @forall_if_store(%ub: index, %cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.forall (%i) in (%ub) {
+ scf.if %cond {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ }
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check that a store inside a parallel prevents promotion.
+
+// CHECK-LABEL: func.func @parallel_store
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+// CHECK: memref.store %[[C5]], %[[ALLOCA]][]
+// CHECK: scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) {
+// CHECK: memref.store %[[C7]], %[[ALLOCA]][]
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][]
+// CHECK: return %[[LOAD]] : i32
+func.func @parallel_store(%lb: index, %ub: index, %step: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.parallel (%i) = (%lb) to (%ub) step (%step) {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.reduce
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check that a store inside an if inside a parallel prevents promotion.
+
+// CHECK-LABEL: func.func @parallel_if_store
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+// CHECK: memref.store %[[C5]], %[[ALLOCA]][]
+// CHECK: scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) {
+// CHECK: scf.if %[[COND]] {
+// CHECK: memref.store %[[C7]], %[[ALLOCA]][]
+// CHECK: }
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][]
+// CHECK: return %[[LOAD]] : i32
+func.func @parallel_if_store(%lb: index, %ub: index, %step: index, %cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.parallel (%i) = (%lb) to (%ub) step (%step) {
+ scf.if %cond {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ }
+ scf.reduce
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check that a store inside a reduce region prevents promotion.
+
+// CHECK-LABEL: func.func @parallel_reduce_store
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+// CHECK: memref.store %[[C5]], %[[ALLOCA]][]
+// CHECK: %[[RES:.*]] = scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) init (%[[C0]]) -> i32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: scf.reduce(%[[C1]] : i32) {
+// CHECK: ^{{.*}}(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK: %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32
+// CHECK: memref.store %[[SUM]], %[[ALLOCA]][]
+// CHECK: scf.reduce.return %[[SUM]] : i32
+// CHECK: }
+// CHECK: }
+// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][]
+// CHECK: return %[[LOAD]] : i32
+func.func @parallel_reduce_store(%lb: index, %ub: index, %step: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c0 = arith.constant 0 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ %res = scf.parallel (%i) = (%lb) to (%ub) step (%step) init (%c0) -> i32 {
+ %c1 = arith.constant 1 : i32
+ scf.reduce(%c1 : i32) {
+ ^bb0(%lhs: i32, %rhs: i32):
+ %sum = arith.addi %lhs, %rhs : i32
+ memref.store %sum, %alloca[] : memref<i32>
+ scf.reduce.return %sum : i32
+ }
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
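Together, these rejection tests pin down a conservative rule. A store to the slot anywhere under an `scf.forall` or `scf.parallel` blocks promotion. This includes a store inside a nested `scf.if` and one inside the `scf.reduce` region. The likely reason is that iterations may execute concurrently, so there is no single loop-carried value to thread through. Loads in those regions are still promoted, as the positive tests show.

The sketch below expresses the rule over a hypothetical op-kind enum. It is not the MLIR API and not how the pass is structured internally.

```cpp
#include <cassert>
#include <vector>

// Hypothetical kinds of region-bearing ops that may enclose a store.
enum class OpKind {
  For, While, If, IndexSwitch, ExecuteRegion, Forall, Parallel, Reduce
};

// A store is modeled by the chain of ops enclosing it, outermost first,
// between the slot's allocation and the store itself.
bool storeAllowsPromotion(const std::vector<OpKind> &enclosingOps) {
  for (OpKind kind : enclosingOps) {
    // Any parallel construct on the path blocks promotion, regardless of
    // what is nested inside it (e.g. an scf.if).
    if (kind == OpKind::Forall || kind == OpKind::Parallel ||
        kind == OpKind::Reduce)
      return false;
  }
  return true;
}
```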
diff --git a/mlir/test/Dialect/SCF/mem2reg.mlir b/mlir/test/Dialect/SCF/mem2reg.mlir
index c8be552a71f36..2143bae9d4f23 100644
--- a/mlir/test/Dialect/SCF/mem2reg.mlir
+++ b/mlir/test/Dialect/SCF/mem2reg.mlir
@@ -126,21 +126,24 @@ func.func @if_nested_store(%cond0: i1, %cond1: i1) {
// Check load promotion through execute_region.
+func.func private @use(i32)
+
// CHECK-LABEL: func.func @execute_region_load
// CHECK: %[[C5:.*]] = arith.constant 5 : i32
-// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
-// CHECK: scf.yield %[[C5]] : i32
+// CHECK: scf.execute_region {
+// CHECK: call @use(%[[C5]])
+// CHECK: scf.yield
// CHECK: }
-// CHECK: return %[[RES]] : i32
-func.func @execute_region_load() -> i32 {
+func.func @execute_region_load() {
%c5 = arith.constant 5 : i32
%alloca = memref.alloca() : memref<i32>
memref.store %c5, %alloca[] : memref<i32>
- %res = scf.execute_region -> i32 {
+ scf.execute_region {
%load = memref.load %alloca[] : memref<i32>
- scf.yield %load : i32
+ func.call @use(%load) : (i32) -> ()
+ scf.yield
}
- return %res : i32
+ return
}
// -----
@@ -252,3 +255,565 @@ func.func @execute_region_cfg_no_use_at_all(%cond0: i1, %cond1: i1) -> i32 {
%load2 = memref.load %alloca[] : memref<i32>
return %load2 : i32
}
+
+// CHECK-LABEL: func.func @execute_region_cfg_with_store
+// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1)
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : i32
+// CHECK-DAG: %[[C11:.*]] = arith.constant 11 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: cf.cond_br %[[COND0]], ^[[BB1:.*]], ^[[BB2:.*]]
+// CHECK: ^[[BB1]]:
+// CHECK: cf.br ^[[BB3:.*]]
+// CHECK: ^[[BB2]]:
+// CHECK: cf.br ^[[BB3]]
+// CHECK: ^[[BB3]]:
+// CHECK: scf.yield %[[C11]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @execute_region_cfg_with_store(%cond0: i1, %cond1: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %c9 = arith.constant 9 : i32
+ %c11 = arith.constant 11 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.execute_region {
+ cf.cond_br %cond0, ^bb1, ^bb2
+ ^bb1:
+ memref.store %c7, %alloca[] : memref<i32>
+ cf.br ^bb3
+ ^bb2:
+ memref.store %c9, %alloca[] : memref<i32>
+ cf.br ^bb3
+ ^bb3:
+ memref.store %c11, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ return %load2 : i32
+}
+
+// -----
+
+// Check promotion through a for loop with a load and store in the body.
+
+// CHECK-LABEL: func.func @for_load_and_store
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[RES:.*]] = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ARG:.*]] = %[[C5]]) -> (i32) {
+// CHECK: %[[NEW:.*]] = arith.addi %[[ARG]], %[[C1]] : i32
+// CHECK: scf.yield %[[NEW]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @for_load_and_store(%lb: index, %ub: index, %step: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c1 = arith.constant 1 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.for %i = %lb to %ub step %step {
+ %load = memref.load %alloca[] : memref<i32>
+ %new = arith.addi %load, %c1 : i32
+ memref.store %new, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ return %load2 : i32
+}
+
+// -----
+
+// Check promotion adds a second iter_arg when one already exists.
+
+// CHECK-LABEL: func.func @for_existing_iter_arg
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[INIT:.*]]: i32)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[RES:.*]]:2 = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[MUL_ARG:.*]] = %[[INIT]], %[[SLOT_ARG:.*]] = %[[C5]]) -> (i32, i32) {
+// CHECK: %[[MUL:.*]] = arith.muli %[[MUL_ARG]], %[[MUL_ARG]] : i32
+// CHECK: %[[NEW:.*]] = arith.addi %[[SLOT_ARG]], %[[C1]] : i32
+// CHECK: scf.yield %[[MUL]], %[[NEW]] : i32, i32
+// CHECK: }
+// CHECK: return %[[RES]]#1 : i32
+func.func @for_existing_iter_arg(%lb: index, %ub: index, %step: index, %init: i32) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c1 = arith.constant 1 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ %mul_res = scf.for %i = %lb to %ub step %step iter_args(%mul_arg = %init) -> i32 {
+ %mul = arith.muli %mul_arg, %mul_arg : i32
+ %load = memref.load %alloca[] : memref<i32>
+ %new = arith.addi %load, %c1 : i32
+ memref.store %new, %alloca[] : memref<i32>
+ scf.yield %mul : i32
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ return %load2 : i32
+}
+
+// -----
+
+// Check load-only promotion through a for loop generates no iter_arg.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @for_load_only
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK: call @use(%[[C5]])
+// CHECK: }
+// CHECK: return %[[C5]] : i32
+func.func @for_load_only(%lb: index, %ub: index, %step: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.for %i = %lb to %ub step %step {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ return %load2 : i32
+}
+
+// -----
+
+// Check promotion through a for loop with a store inside an if in the body.
+
+// CHECK-LABEL: func.func @for_if_store
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[COND:.*]]: i1)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[RES:.*]] = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ARG:.*]] = %[[C5]]) -> (i32) {
+// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (i32) {
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[ARG]] : i32
+// CHECK: }
+// CHECK: scf.yield %[[IF]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @for_if_store(%lb: index, %ub: index, %step: index, %cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.for %i = %lb to %ub step %step {
+ scf.if %cond {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ }
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check load promotion through a forall.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @forall_load
+// CHECK-SAME: (%[[UB:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: scf.forall (%{{.*}}) in (%[[UB]]) {
+// CHECK: call @use(%[[C5]])
+// CHECK: }
+func.func @forall_load(%ub: index) {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.forall (%i) in (%ub) {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ }
+ return
+}
+
+// -----
+
+// Check promotion through a forall nested inside an if with a store.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @forall_in_if
+// CHECK-SAME: (%[[UB:.*]]: index, %[[COND:.*]]: i1)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[RES:.*]] = scf.if %[[COND]] -> (i32) {
+// CHECK: scf.forall (%{{.*}}) in (%[[UB]]) {
+// CHECK: call @use(%[[C7]])
+// CHECK: }
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @forall_in_if(%ub: index, %cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.if %cond {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.forall (%i) in (%ub) {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ }
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ return %load2 : i32
+}
+
+// -----
+
+// Check load promotion through a parallel.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @parallel_load
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) {
+// CHECK: call @use(%[[C5]])
+// CHECK: scf.reduce
+// CHECK: }
+func.func @parallel_load(%lb: index, %ub: index, %step: index) {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.parallel (%i) = (%lb) to (%ub) step (%step) {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.reduce
+ }
+ return
+}
+
+// -----
+
+// Check promotion through a parallel nested inside an if with a store.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @parallel_in_if
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[COND:.*]]: i1)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[RES:.*]] = scf.if %[[COND]] -> (i32) {
+// CHECK: scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) {
+// CHECK: call @use(%[[C7]])
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @parallel_in_if(%lb: index, %ub: index, %step: index, %cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.if %cond {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.parallel (%i) = (%lb) to (%ub) step (%step) {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.reduce
+ }
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ return %load2 : i32
+}
+
+// -----
+
+// Check load promotion inside a reduce region.
+
+// CHECK-LABEL: func.func @parallel_reduce_load
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[RES:.*]] = scf.parallel (%{{.*}}) = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) init (%[[C0]]) -> i32 {
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: scf.reduce(%[[C1]] : i32) {
+// CHECK: ^{{.*}}(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32):
+// CHECK: %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32
+// CHECK: %[[MUL:.*]] = arith.muli %[[SUM]], %[[C5]] : i32
+// CHECK: scf.reduce.return %[[MUL]] : i32
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @parallel_reduce_load(%lb: index, %ub: index, %step: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c0 = arith.constant 0 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ %res = scf.parallel (%i) = (%lb) to (%ub) step (%step) init (%c0) -> i32 {
+ %c1 = arith.constant 1 : i32
+ scf.reduce(%c1 : i32) {
+ ^bb0(%lhs: i32, %rhs: i32):
+ %sum = arith.addi %lhs, %rhs : i32
+ %load = memref.load %alloca[] : memref<i32>
+ %mul = arith.muli %sum, %load : i32
+ scf.reduce.return %mul : i32
+ }
+ }
+ return %res : i32
+}
+
+// -----
+
+// Check load promotion in the before region of a while.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @while_load_before
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: scf.while : () -> () {
+// CHECK: call @use(%[[C5]])
+// CHECK: scf.condition(%[[COND]])
+// CHECK: } do {
+// CHECK: scf.yield
+// CHECK: }
+func.func @while_load_before(%cond: i1) {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.while : () -> () {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.condition(%cond)
+ } do {
+ scf.yield
+ }
+ return
+}
+
+// -----
+
+// Check load promotion in the after region of a while.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @while_load_after
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: scf.while : () -> () {
+// CHECK: scf.condition(%[[COND]])
+// CHECK: } do {
+// CHECK: call @use(%[[C5]])
+// CHECK: scf.yield
+// CHECK: }
+func.func @while_load_after(%cond: i1) {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.while : () -> () {
+ scf.condition(%cond)
+ } do {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.yield
+ }
+ return
+}
+
+// -----
+
+// Check promotion with a store in the before region and a load in the after.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @while_store_before_load_after
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: scf.while (%[[BEFORE:.*]] = %[[C5]]) : (i32) -> i32 {
+// CHECK: scf.condition(%[[COND]]) %[[C7]] : i32
+// CHECK: } do {
+// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32):
+// CHECK: call @use(%[[AFTER]])
+// CHECK: scf.yield %[[AFTER]] : i32
+// CHECK: }
+func.func @while_store_before_load_after(%cond: i1) {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.while : () -> () {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.condition(%cond)
+ } do {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.yield
+ }
+ return
+}
+
+// -----
+
+// Check promotion with a store in the before region and a load after the loop.
+
+// CHECK-LABEL: func.func @while_store_before_load_after_loop
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[RES:.*]] = scf.while (%[[BEFORE:.*]] = %[[C5]]) : (i32) -> i32 {
+// CHECK: scf.condition(%[[COND]]) %[[C7]] : i32
+// CHECK: } do {
+// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32):
+// CHECK: scf.yield %[[AFTER]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @while_store_before_load_after_loop(%cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.while : () -> () {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.condition(%cond)
+ } do {
+ scf.yield
+ }
+ %res = memref.load %alloca[] : memref<i32>
+ return %res : i32
+}
+
+// -----
+
+// Check store promotion through a while implementing a for loop from 0 to 10.
+
+// CHECK-LABEL: func.func @while_store
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : i32
+// CHECK: %[[RES:.*]] = scf.while (%[[BEFORE:.*]] = %[[C0]]) : (i32) -> i32 {
+// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[BEFORE]], %[[C10]] : i32
+// CHECK: scf.condition(%[[COND]]) %[[BEFORE]] : i32
+// CHECK: } do {
+// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32):
+// CHECK: %[[NEW:.*]] = arith.addi %[[AFTER]], %[[C1]] : i32
+// CHECK: scf.yield %[[NEW]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @while_store() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c10 = arith.constant 10 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c0, %alloca[] : memref<i32>
+ scf.while : () -> () {
+ %val = memref.load %alloca[] : memref<i32>
+ %cond = arith.cmpi slt, %val, %c10 : i32
+ scf.condition(%cond)
+ } do {
+ %val = memref.load %alloca[] : memref<i32>
+ %new = arith.addi %val, %c1 : i32
+ memref.store %new, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %res = memref.load %alloca[] : memref<i32>
+ return %res : i32
+}
+
+// -----
+
+// Check load promotion through an index_switch default branch.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @index_switch_load_default
+// CHECK-SAME: (%[[IDX:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: scf.index_switch %[[IDX]]
+// CHECK: default {
+// CHECK: call @use(%[[C5]])
+// CHECK: }
+func.func @index_switch_load_default(%idx: index) {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.index_switch %idx
+ default {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.yield
+ }
+ return
+}
+
+// -----
+
+// Check store promotion through an index_switch default branch.
+
+// CHECK-LABEL: func.func @index_switch_store_default
+// CHECK-SAME: (%[[IDX:.*]]: index)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[RES:.*]] = scf.index_switch %[[IDX]] -> i32
+// CHECK: default {
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @index_switch_store_default(%idx: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.index_switch %idx
+ default {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check promotion with a store in a case and a load in the default branch.
+
+func.func private @use(i32)
+
+// CHECK-LABEL: func.func @index_switch_store_case_load_default
+// CHECK-SAME: (%[[IDX:.*]]: index)
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[RES:.*]] = scf.index_switch %[[IDX]] -> i32
+// CHECK: case 0 {
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: }
+// CHECK: default {
+// CHECK: call @use(%[[C5]])
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @index_switch_store_case_load_default(%idx: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.index_switch %idx
+ case 0 {
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ }
+ default {
+ %load = memref.load %alloca[] : memref<i32>
+ func.call @use(%load) : (i32) -> ()
+ scf.yield
+ }
+ %load2 = memref.load %alloca[] : memref<i32>
+ return %load2 : i32
+}
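To sanity-check the expected IR, the promoted forms can be modeled as plain C++ loops in which the slot lives in a local variable instead of memory. This is a hand-written sketch, not generated code:

- `forIfStore` mirrors `@for_if_store`: the `scf.if` yields either `7` or the incoming iter_arg.
- `whileStore` mirrors `@while_store`: `scf.condition` forwards the before-region argument both to the after region and to the loop result.

`scf.for` bounds are `index` values; `int` is used here for brevity.

```cpp
#include <cassert>

// Promoted @for_if_store: the slot value is the loop-carried `arg`.
int forIfStore(int lb, int ub, int step, bool cond) {
  int arg = 5; // iter_args(%arg = %c5)
  for (int i = lb; i < ub; i += step)
    arg = cond ? 7 : arg; // scf.if yields %c7 or the incoming %arg
  return arg;             // the loop result replaces the final load
}

// Promoted @while_store: a while loop counting from 0 to 10.
int whileStore() {
  int before = 0; // scf.while (%before = %c0)
  while (true) {
    bool c = before < 10;   // arith.cmpi slt
    int forwarded = before; // scf.condition(%cond) %before
    if (!c)
      return forwarded;    // on exit, the forwarded value is the result
    int after = forwarded; // after-region block argument
    before = after + 1;    // scf.yield %new feeds the next before region
  }
}
```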
diff --git a/mlir/test/Transforms/mem2reg.mlir b/mlir/test/Transforms/mem2reg.mlir
index 2128a5cd9ffd3..70fbddcb25b2a 100644
--- a/mlir/test/Transforms/mem2reg.mlir
+++ b/mlir/test/Transforms/mem2reg.mlir
@@ -90,3 +90,26 @@ func.func @merge_point_arg_not_confused(%cond: i1, %a: i32, %b: i32) -> i32 {
%load = memref.load %alloca[] : memref<i32>
return %load : i32
}
+
+// -----
+
+// Check that a load inside an unknown region-bearing op prevents promotion.
+
+// CHECK-LABEL: func.func @unknown_region_op_load
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+// CHECK: memref.store %[[C5]], %[[ALLOCA]][]
+// CHECK: "test.one_region_op"() ({
+// CHECK: %[[LOAD:.*]] = memref.load %[[ALLOCA]][]
+// CHECK: "test.finish"() : () -> ()
+// CHECK: }) : () -> ()
+func.func @unknown_region_op_load() {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ "test.one_region_op"() ({
+ %load = memref.load %alloca[] : memref<i32>
+ "test.finish"() : () -> ()
+ }) : () -> ()
+ return
+}
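`@unknown_region_op_load` exercises a conservative check. A use nested in a region can be promoted only if every op between the use and the slot's scope opts in, presumably via `PromotableRegionOpInterface`. An unknown op such as `test.one_region_op` may hide arbitrary control flow, so it blocks promotion.

The sketch below models that check over a hypothetical parent-linked op tree. `ToyOp` and `implementsRegionInterface` are illustrative names, not MLIR API.

```cpp
#include <cassert>

// Hypothetical op node: each op knows its parent op (nullptr at the top)
// and whether it opts into region promotion.
struct ToyOp {
  const ToyOp *parent;
  bool implementsRegionInterface;
};

// Returns true if every ancestor of `use`, up to (but not including)
// `slotScope`, opts into promotion.
bool ancestorsArePromotable(const ToyOp *use, const ToyOp *slotScope) {
  for (const ToyOp *op = use->parent; op && op != slotScope; op = op->parent)
    if (!op->implementsRegionInterface)
      return false; // An unknown region op hides control flow: reject.
  return true;
}
```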
>From 24f7e0475f61ce96f0d6716d32fd48c963d6be9b Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Fri, 6 Mar 2026 17:30:00 +0100
Subject: [PATCH 06/15] fix various typos
---
mlir/include/mlir/Interfaces/MemorySlotInterfaces.td | 12 ++++++------
mlir/lib/Transforms/Mem2Reg.cpp | 2 +-
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index ab2cb39227525..5c2c0c9248317 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -242,7 +242,7 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
operation will be called after the main mutation stage finishes
(i.e., after all ops have been processed with `removeBlockingUses`).
- Operations should only the replaced values if the intended
+ Operations should only visit the replaced values if the intended
transformation applies to all the replaced values. Furthermore, replaced
values must not be deleted.
}], "bool", "requiresReplacedValues", (ins), [{}],
@@ -292,7 +292,7 @@ def PromotableRegionOpInterface
Based on the `reachingDef` value representing the value in the memory
slot at the entry into the operation, `setupPromotion` fills in the
- `regionsToProcess` with the the reaching definition at the entry of
+ `regionsToProcess` with the reaching definition at the entry of
all its promotable regions.
`setupPromotion` is allowed to mutate
@@ -303,7 +303,7 @@ def PromotableRegionOpInterface
The `hasValueStores` flag indicates whether the regions contain
`store`-like operations that write to the memory slot. This field can be
used to reduce the amount of book-keeping required to track the reaching
- definitions, but is correct to consider it always true.
+ definitions.
}], "void", "setupPromotion",
(ins
"const ::mlir::MemorySlot &":$slot,
@@ -336,7 +336,7 @@ def PromotableRegionOpInterface
The `hasValueStores` flag indicates whether the regions contain
`store`-like operations that write to the memory slot. This field can be
used to reduce the amount of book-keeping required to track the reaching
- definitions, but is correct to consider it always true.
+ definitions.
}],
"::mlir::Value", "finalizePromotion",
(ins
@@ -391,7 +391,7 @@ def DestructurableAllocationOpInterface
>,
InterfaceMethod<[{
Hook triggered once the destructuring of a slot is complete, meaning the
- original slot is no longer being refered to and could be deleted.
+ original slot is no longer being referred to and could be deleted.
This will only be called for slots declared by this operation.
Must return a new destructurable allocation op if this hook creates
@@ -415,7 +415,7 @@ def SafeMemorySlotAccessOpInterface
let methods = [
InterfaceMethod<[{
Returns whether all accesses in this operation to the provided slot are
- done in a safe manner. To be safe, the access most only access the slot
+ done in a safe manner. To be safe, the access must only access the slot
inside the bounds that its type implies.
If the safety of the accesses depends on the safety of the accesses to
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 6c1704331587e..f17b186c03f1d 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -722,7 +722,7 @@ void MemorySlotPromoter::linkMergePoints() {
// We want to eliminate unused block arguments. In case connecting a block
// argument to its predecessor would trigger the use of the predecessor's
// unused block argument, we need to process merge points in an expanding
- // worklist, mergePointsToProcess.
+ // worklist, `mergePointArgsToProcess`.
SmallPtrSet<BlockArgument, 8> mergePointArgsUnused;
SmallVector<BlockArgument> mergePointArgsToProcess;
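The `linkMergePoints` comment describes a fixpoint computation. Marking a merge-point block argument as used also uses the predecessor values connected to it, and any of those may be another, so far unused, block argument. The set of used arguments therefore grows until nothing new is reached, and whatever remains unused can be erased.

The sketch below shows that expanding worklist. It assumes a hypothetical map from each block argument to the block arguments its predecessors pass to it; these are not the actual `MemorySlotPromoter` data structures.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Hypothetical ids for merge-point block arguments.
using ArgId = int;

// `incoming[a]` lists the merge-point arguments passed as operands to `a`
// by its predecessors. Starting from arguments with direct uses, returns
// every argument that must be kept; the rest can be erased.
std::set<ArgId>
findUsedArgs(const std::map<ArgId, std::vector<ArgId>> &incoming,
             const std::vector<ArgId> &directlyUsed) {
  std::set<ArgId> used(directlyUsed.begin(), directlyUsed.end());
  std::vector<ArgId> worklist(directlyUsed.begin(), directlyUsed.end());
  while (!worklist.empty()) {
    ArgId arg = worklist.back();
    worklist.pop_back();
    auto it = incoming.find(arg);
    if (it == incoming.end())
      continue;
    // Connecting `arg` to its predecessors uses their arguments too; the
    // worklist expands with any argument that was unused until now.
    for (ArgId pred : it->second)
      if (used.insert(pred).second)
        worklist.push_back(pred);
  }
  return used;
}
```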
>From 00774f37e92964349d847c8d7c022b8d4b07ce67 Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Mon, 9 Mar 2026 16:54:30 +0100
Subject: [PATCH 07/15] address various comments
---
.../mlir/Interfaces/MemorySlotInterfaces.td | 2 +-
mlir/lib/Dialect/SCF/IR/MemorySlot.cpp | 28 +++++---
mlir/lib/Transforms/Mem2Reg.cpp | 9 +--
mlir/test/Dialect/SCF/mem2reg.mlir | 69 +++++++++++++++++++
4 files changed, 90 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 5c2c0c9248317..350adce2963ad 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -343,7 +343,7 @@ def PromotableRegionOpInterface
"const ::mlir::MemorySlot &":$slot,
"::mlir::Value":$entryReachingDef,
"bool":$hasValueStores,
- "::llvm::DenseMap<::mlir::Block *, ::mlir::Value> &":$reachingAtBlockEnd,
+ "const ::llvm::DenseMap<::mlir::Block *, ::mlir::Value> &":$reachingAtBlockEnd,
"::mlir::OpBuilder &":$builder
)
>,
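Making `reachingAtBlockEnd` a `const` reference rules out `operator[]`, which default-inserts a missing key. Callers must use a non-mutating lookup instead, which is why `updateTerminator` in `MemorySlot.cpp` switches to `DenseMap::lookup`. That method returns a default-constructed value when the key is absent.

The sketch below illustrates the difference using `std::unordered_map` as a stand-in for `llvm::DenseMap`. The helpers `lookupDef` and `blockReachingDef` are hypothetical, and string names stand in for `Block *` and `Value`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Stand-in for llvm::DenseMap<Block *, Value>: block name -> reaching def.
using ReachingMap = std::unordered_map<std::string, std::string>;

// Mirrors DenseMap::lookup: returns a default-constructed value for a
// missing key and never inserts, so it works on a const map.
std::string lookupDef(const ReachingMap &map, const std::string &block) {
  auto it = map.find(block);
  return it == map.end() ? std::string() : it->second;
}

// Mirrors updateTerminator's choice: use the block's reaching definition,
// or fall back to the default (dead block or region not using the slot).
std::string blockReachingDef(const ReachingMap &map, const std::string &block,
                             const std::string &defaultDef) {
  std::string def = lookupDef(map, block);
  return def.empty() ? defaultDef : def;
}
```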
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
index 89f816ff8bf69..3d61476df6014 100644
--- a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -21,11 +21,11 @@ using namespace mlir::scf;
template <typename TermTy>
static void
updateTerminator(Block *block, Value defaultReachingDef,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
Operation *terminator = block->getTerminator();
if (!isa<TermTy>(terminator))
return;
- Value blockReachingDef = reachingAtBlockEnd[block];
+ Value blockReachingDef = reachingAtBlockEnd.lookup(block);
if (!blockReachingDef) {
// Block is dead code or the region is not using the slot, so we use the
// default provided reaching definition.
@@ -72,7 +72,8 @@ void ExecuteRegionOp::setupPromotion(
Value ExecuteRegionOp::finalizePromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
+ OpBuilder &builder) {
if (!hasValueStores)
return reachingDef;
@@ -115,7 +116,8 @@ void ForOp::setupPromotion(
Value ForOp::finalizePromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
+ OpBuilder &builder) {
if (!hasValueStores)
return reachingDef;
@@ -152,7 +154,8 @@ void ForallOp::setupPromotion(
Value ForallOp::finalizePromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
+ OpBuilder &builder) {
assert(!hasValueStores && "ForallOp does not support stores");
return reachingDef;
}
@@ -175,7 +178,8 @@ void IfOp::setupPromotion(
Value IfOp::finalizePromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
+ OpBuilder &builder) {
if (!hasValueStores)
return reachingDef;
@@ -221,7 +225,8 @@ void IndexSwitchOp::setupPromotion(
Value IndexSwitchOp::finalizePromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
+ OpBuilder &builder) {
if (!hasValueStores)
return reachingDef;
@@ -263,7 +268,8 @@ void ParallelOp::setupPromotion(
Value ParallelOp::finalizePromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
+ OpBuilder &builder) {
assert(!hasValueStores && "ParallelOp does not support stores");
return reachingDef;
}
@@ -289,7 +295,8 @@ void ReduceOp::setupPromotion(
Value ReduceOp::finalizePromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
+ OpBuilder &builder) {
assert(!hasValueStores && "ReduceOp does not support stores");
return reachingDef;
}
@@ -325,7 +332,8 @@ void WhileOp::setupPromotion(
Value WhileOp::finalizePromotion(
const MemorySlot &slot, Value reachingDef, bool hasValueStores,
- llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
+ OpBuilder &builder) {
if (!hasValueStores)
return reachingDef;
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index f17b186c03f1d..916b510f338a6 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -515,7 +515,6 @@ MemorySlotPromotionAnalyzer::computeInfo() {
}
Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) {
- llvm::SmallMapVector<Region *, Value, 2> regionsToProcess;
SmallVector<Operation *> blockOps;
for (Operation &op : block->getOperations())
blockOps.push_back(&op);
@@ -558,7 +557,7 @@ Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) {
}
if (needsPromotion) {
- regionsToProcess.clear();
+ llvm::SmallMapVector<Region *, Value, 2> regionsToProcess;
// To not expose default value creation to the interfaces, if we have
// no reaching definition by now, we set it to the default value.
@@ -578,12 +577,8 @@ Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) {
#endif // NDEBUG
for (auto &[region, reachingDef] : regionsToProcess) {
-#ifndef NDEBUG
- Region *regionCapture = region;
- assert(llvm::any_of(op->getRegions(),
- [&](Region &r) { return &r == regionCapture; }) &&
+ assert(region->getParentOp() == op &&
"region must be part of the operation");
-#endif // NDEBUG
if (!info.regionsToPromote.contains(region))
continue;
promoteInRegion(region, reachingDef);
diff --git a/mlir/test/Dialect/SCF/mem2reg.mlir b/mlir/test/Dialect/SCF/mem2reg.mlir
index 2143bae9d4f23..389fe33e308b3 100644
--- a/mlir/test/Dialect/SCF/mem2reg.mlir
+++ b/mlir/test/Dialect/SCF/mem2reg.mlir
@@ -296,6 +296,75 @@ func.func @execute_region_cfg_with_store(%cond0: i1, %cond1: i1) -> i32 {
// -----
+// Check promotion through an execute_region with multiple yield terminators
+// having different reaching definitions.
+
+// CHECK-LABEL: func.func @execute_region_multiple_yields
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]]
+// CHECK: ^[[BB1]]:
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: ^[[BB2]]:
+// CHECK: scf.yield %[[C9]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @execute_region_multiple_yields(%cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %c9 = arith.constant 9 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.execute_region {
+ cf.cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ memref.store %c7, %alloca[] : memref<i32>
+ scf.yield
+ ^bb2:
+ memref.store %c9, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check promotion when both yield terminators share the same reaching
+// definition from a store in the entry block.
+
+// CHECK-LABEL: func.func @execute_region_same_reaching_def
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C7:.*]] = arith.constant 7 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]]
+// CHECK: ^[[BB1]]:
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: ^[[BB2]]:
+// CHECK: scf.yield %[[C7]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @execute_region_same_reaching_def(%cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %c7 = arith.constant 7 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.execute_region {
+ memref.store %c7, %alloca[] : memref<i32>
+ cf.cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ scf.yield
+ ^bb2:
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
// Check promotion through a for loop with a load and store in the body.
// CHECK-LABEL: func.func @for_load_and_store
>From 0af89bcd0e4357621053781301d4a507d168aa3c Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Tue, 10 Mar 2026 15:32:59 +0100
Subject: [PATCH 08/15] fix a bug with load into store
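The reordering this patch introduces can be illustrated with a toy model. In this model, values are strings, and removing a load means rewriting every use of its result to the load's reaching definition and then erasing the load. This is a simplification assumed for illustration; `LoadToRemove`, `replaceAndErase` and `removeLoads` are hypothetical names, not Mem2Reg code. Consider a sequence such as `%loaded = load; store %loaded; %load2 = load; return %load2`. Here the reaching definition of `%load2` is `%loaded`, which is itself a load result that will be removed:

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy model: values are names, and a "use" is a string naming a value.
struct LoadToRemove {
  std::string result;      // result of the load, e.g. "%loaded"
  std::string reachingDef; // value the slot holds at that load
};

// Rewrites every use of `from` to `to`, then records `from` as erased,
// mimicking replacing a load's uses and erasing the load.
static void replaceAndErase(std::vector<std::string> &uses,
                            std::set<std::string> &erased,
                            const std::string &from, const std::string &to) {
  for (std::string &use : uses)
    if (use == from)
      use = to;
  erased.insert(from);
}

// Removes `loads` (given in topological order) either forwards or in
// reverse, returning the rewritten uses; `erased` collects removed values.
std::vector<std::string> removeLoads(std::vector<std::string> uses,
                                     const std::vector<LoadToRemove> &loads,
                                     bool reverseOrder,
                                     std::set<std::string> &erased) {
  if (reverseOrder) {
    for (auto it = loads.rbegin(); it != loads.rend(); ++it)
      replaceAndErase(uses, erased, it->result, it->reachingDef);
  } else {
    for (const LoadToRemove &load : loads)
      replaceAndErase(uses, erased, load.result, load.reachingDef);
  }
  return uses;
}
```

When processed in topological order, the use of `%load2` is redirected to `%loaded` only after `%loaded` has already been erased, which leaves a dangling use. When processed in reverse, the use is first redirected to `%loaded`, and the later rewrite of `%loaded` then carries it over to `%c5`.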
---
.../mlir/Interfaces/MemorySlotInterfaces.td | 6 +-
mlir/lib/Transforms/Mem2Reg.cpp | 52 ++++-
mlir/test/Dialect/SCF/mem2reg.mlir | 197 ++++++++++++++++++
3 files changed, 246 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 350adce2963ad..ac6c7e0459c3d 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -313,7 +313,8 @@ def PromotableRegionOpInterface
)
>,
InterfaceMethod<[{
- Called after promotion has been completed in all the relevant regions.
+ Called once the reaching definitions have been computed for all the
+ regions, but before the actual removal of the blocking uses.
Returns the new reaching definition at the exit of the operation. For
this purpose, it is allowed to mutate the operation using the provided
@@ -323,7 +324,8 @@ def PromotableRegionOpInterface
if its region content is moved from the original operation and not
copied. Operations with an effect on the value of the slot must not
change said effect (for example, new control flow that could change
- reaching definitions for a block is not allowed).
+ reaching definitions for a block is not allowed), and must remain
+ unmodified.
The `entryReachingDef` is the reaching definition at the entry of the
region operation.
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 916b510f338a6..8e2f31f85ffdf 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -105,7 +105,18 @@ using namespace mlir;
/// operation, to obtain the reaching definition at its end and carry on with
/// the value forwarding.
/// - The second step of the per-region process uses the reaching definition to
-/// remove blocking uses in topological order.
+/// remove blocking uses in topological order. Some reaching definitions may
+/// be values that will be removed or modified during the blocking use removal
+/// step (typically, in the case of a store that stores the result of a load).
+/// To properly handle such values, this step traverses the operations to modify
+/// in reverse topological order. This way, if a value that will disappear is
+/// used in place of a reaching definition, the logic making it disappear runs
+/// after the value has been used to replace an operation. For regions
+/// within a PromotableRegionOpInterface, in order to correctly handle cases
+/// where the finalization logic would use a reaching definition that will be
+/// replaced, the finalization logic must be called before the blocking use
+/// removal step, so that any use of a value that will be removed gets properly
+/// replaced.
///
/// For further reading, chapter three of SSA-based Compiler Design [1]
/// showcases SSA construction for control-flow graphs, where mem2reg is an
@@ -220,12 +231,15 @@ class MemorySlotPromoter {
/// Computes the reaching definition for all the operations that require
/// promotion, including within nested regions needing promotion.
/// `reachingDef` is the value the slot contains at the beginning of the
- /// block. This method returns the reached definition at the end of the block.
+ /// block. This member function returns the reached definition at the end of
+ /// the block. If the block contains a region that needs promotion, the
+ /// blocking uses of that region will have been removed. This member function
+ /// will not remove the blocking uses contained directly in the block.
///
/// The `reachingDef` may be a null value. In that case, a lazily-created
/// default value will be used.
///
- /// This method must only be called at most once per block.
+ /// This member function must only be called at most once per block.
Value promoteInBlock(Block *block, Value reachingDef);
/// Computes the reaching definition for all the operations that require
@@ -237,7 +251,7 @@ class MemorySlotPromoter {
/// The `reachingDef` may be a null value. In that case, a lazily-created
/// default value will be used.
///
- /// This method must only be called at most once per region.
+ /// This member function must only be called at most once per region.
void promoteInRegion(Region *region, Value reachingDef);
/// Removes the blocking uses of the slot within the given region, in
@@ -587,6 +601,13 @@ Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) {
builder.setInsertionPointAfter(op);
reachingDef = promotableRegionOp.finalizePromotion(
slot, reachingDef, hasValueStores, reachingAtBlockEnd, builder);
+
+ // Blocking uses can then be removed for the regions that were promoted.
+ // Even though `finalizePromotion` may have moved regions to a new operation,
+ // `removeBlockingUses` handles this case and will redirect processing to
+ // the correct region.
+ for (auto &[region, reachingDef] : regionsToProcess)
+ removeBlockingUses(region);
}
}
}
@@ -598,7 +619,6 @@ Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) {
void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
if (region->hasOneBlock()) {
promoteInBlock(&region->front(), reachingDef);
- removeBlockingUses(region);
return;
}
@@ -629,8 +649,6 @@ void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
for (auto *child : job.block->children())
dfsStack.emplace_back<DfsJob>({child, job.reachingDef});
}
-
- removeBlockingUses(region);
}
/// Gets or creates a block index mapping for `region`.
@@ -674,6 +692,18 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
if (blockingUsesMapIt == info.userToBlockingUses.end())
return;
BlockingUsesMap &blockingUsesMap = blockingUsesMapIt->second;
+ if (blockingUsesMap.empty())
+ return;
+
+ // Operations may have been moved to a different region at this point.
+ // To cover this, we process the region currently holding the operations to
+ // rewrite instead of the provided region.
+ region = blockingUsesMap.front().first->getParentRegion();
+#ifndef NDEBUG
+ for (auto &[op, blockingUses] : blockingUsesMap)
+ assert(op->getParentRegion() == region &&
+ "all operations must still be in the same region");
+#endif // NDEBUG
llvm::SmallVector<Operation *> usersToRemoveUses(
llvm::make_first_range(blockingUsesMap));
@@ -681,6 +711,7 @@ void MemorySlotPromoter::removeBlockingUses(Region *region) {
// Sort according to dominance.
dominanceSort(usersToRemoveUses, *region, blockIndexCache);
+ // Iterate over the operations to rewrite in reverse dominance order.
for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
@@ -767,8 +798,15 @@ MemorySlotPromoter::promoteSlot() {
// definition starts with a null value that will be replaced by a
// lazily-created default value if the value must be passed to a promotion
// interface while no store has been encountered yet.
+ // Nested regions will have their blocking uses removed, but not the
+ // outermost region, whose blocking uses must be removed afterwards. This is
+ // because PromotableRegionOpInterface::finalizePromotion must be called
+ // before removeBlockingUses.
promoteInRegion(slot.ptr.getParentRegion(), nullptr);
+ // Blocking uses can then be removed for the outermost region.
+ removeBlockingUses(slot.ptr.getParentRegion());
+
// Notify operations that requested it of the reaching definitions set by
// storing memory operations.
for (PromotableOpInterface op : toVisitReplacedValues) {
diff --git a/mlir/test/Dialect/SCF/mem2reg.mlir b/mlir/test/Dialect/SCF/mem2reg.mlir
index 389fe33e308b3..b6dbb8c015188 100644
--- a/mlir/test/Dialect/SCF/mem2reg.mlir
+++ b/mlir/test/Dialect/SCF/mem2reg.mlir
@@ -124,6 +124,32 @@ func.func @if_nested_store(%cond0: i1, %cond1: i1) {
// -----
+// Check that a store coming from a load of the same slot is correctly promoted.
+
+// CHECK-LABEL: func.func @if_load_into_store
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.if %[[COND]] -> (i32) {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @if_load_into_store(%arg1 : i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.if %arg1 {
+ %loaded = memref.load %alloca[] : memref<i32>
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %loaded2 = memref.load %alloca[] : memref<i32>
+ return %loaded2 : i32
+}
+
+// -----
+
// Check load promotion through execute_region.
func.func private @use(i32)
@@ -365,6 +391,61 @@ func.func @execute_region_same_reaching_def(%cond: i1) -> i32 {
// -----
+// Check that a load-then-store of the same slot in the same block is promoted.
+
+// CHECK-LABEL: func.func @execute_region_load_into_store_same_block
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @execute_region_load_into_store_same_block() -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.execute_region {
+ %loaded = memref.load %alloca[] : memref<i32>
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
+// Check that a load-then-store of the same slot across blocks is promoted.
+
+// CHECK-LABEL: func.func @execute_region_load_into_store_diff_block
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.execute_region -> i32 {
+// CHECK: cf.cond_br %[[COND]], ^[[BB1:.*]], ^[[BB2:.*]]
+// CHECK: ^[[BB1]]:
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: ^[[BB2]]:
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @execute_region_load_into_store_diff_block(%cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.execute_region {
+ %loaded = memref.load %alloca[] : memref<i32>
+ cf.cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.yield
+ ^bb2:
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
// Check promotion through a for loop with a load and store in the body.
// CHECK-LABEL: func.func @for_load_and_store
@@ -482,6 +563,30 @@ func.func @for_if_store(%lb: index, %ub: index, %step: index, %cond: i1) -> i32
// -----
+// Check that a load-then-store of the same slot in a for loop is promoted.
+
+// CHECK-LABEL: func.func @for_load_into_store
+// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ARG:.*]] = %[[C5]]) -> (i32) {
+// CHECK: scf.yield %[[ARG]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @for_load_into_store(%lb: index, %ub: index, %step: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.for %i = %lb to %ub step %step {
+ %loaded = memref.load %alloca[] : memref<i32>
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
// Check load promotion through a forall.
func.func private @use(i32)
@@ -798,6 +903,64 @@ func.func @while_store() -> i32 {
// -----
+// Check that a load-then-store in the before region of a while is promoted.
+
+// CHECK-LABEL: func.func @while_load_into_store_before
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.while (%[[BEFORE:.*]] = %[[C5]]) : (i32) -> i32 {
+// CHECK: scf.condition(%[[COND]]) %[[BEFORE]] : i32
+// CHECK: } do {
+// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32):
+// CHECK: scf.yield %[[AFTER]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @while_load_into_store_before(%cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.while : () -> () {
+ %loaded = memref.load %alloca[] : memref<i32>
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.condition(%cond)
+ } do {
+ scf.yield
+ }
+ %res = memref.load %alloca[] : memref<i32>
+ return %res : i32
+}
+
+// -----
+
+// Check that a load-then-store in the after region of a while is promoted.
+
+// CHECK-LABEL: func.func @while_load_into_store
+// CHECK-SAME: (%[[COND:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.while (%[[BEFORE:.*]] = %[[C5]]) : (i32) -> i32 {
+// CHECK: scf.condition(%[[COND]]) %[[BEFORE]] : i32
+// CHECK: } do {
+// CHECK: ^{{.*}}(%[[AFTER:.*]]: i32):
+// CHECK: scf.yield %[[AFTER]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @while_load_into_store(%cond: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.while : () -> () {
+ scf.condition(%cond)
+ } do {
+ %loaded = memref.load %alloca[] : memref<i32>
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %res = memref.load %alloca[] : memref<i32>
+ return %res : i32
+}
+
+// -----
+
// Check load promotion through an index_switch default branch.
func.func private @use(i32)
@@ -886,3 +1049,37 @@ func.func @index_switch_store_case_load_default(%idx: index) -> i32 {
%load2 = memref.load %alloca[] : memref<i32>
return %load2 : i32
}
+
+// -----
+
+// Check that load-then-store of the same slot in an index_switch is promoted.
+
+// CHECK-LABEL: func.func @index_switch_load_into_store
+// CHECK-SAME: (%[[IDX:.*]]: index)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.index_switch %[[IDX]] -> i32
+// CHECK: case 0 {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: default {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @index_switch_load_into_store(%idx: index) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.index_switch %idx
+ case 0 {
+ %loaded = memref.load %alloca[] : memref<i32>
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.yield
+ }
+ default {
+ %loaded = memref.load %alloca[] : memref<i32>
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
>From ecddee133d79015ce0349845746a452637f1c8ff Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Tue, 10 Mar 2026 15:36:34 +0100
Subject: [PATCH 09/15] format
---
mlir/lib/Transforms/Mem2Reg.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 8e2f31f85ffdf..29c5662c5331a 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -603,9 +603,9 @@ Value MemorySlotPromoter::promoteInBlock(Block *block, Value reachingDef) {
slot, reachingDef, hasValueStores, reachingAtBlockEnd, builder);
// Blocking uses can then be removed for the regions that were promoted.
- // Even though `finalizePromotion` may have moved regions to a new operation,
- // `removeBlockingUses` handles this case and will redirect processing to
- // the correct region.
+ // Even though `finalizePromotion` may have moved regions to a new
+ // operation, `removeBlockingUses` handles this case and will redirect
+ // processing to the correct region.
for (auto &[region, reachingDef] : regionsToProcess)
removeBlockingUses(region);
}
>From 3ea03349c6980980510e58cd18c16a03dd963a2d Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Tue, 10 Mar 2026 15:44:06 +0100
Subject: [PATCH 10/15] improve doc of removeBlockingUses
---
mlir/lib/Transforms/Mem2Reg.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 29c5662c5331a..1bdd1b5695efc 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -255,7 +255,8 @@ class MemorySlotPromoter {
void promoteInRegion(Region *region, Value reachingDef);
/// Removes the blocking uses of the slot within the given region, in
- /// topological order.
+ /// reverse topological order. If the content of the region was moved out
+ /// to a different region, the new region will be processed instead.
void removeBlockingUses(Region *region);
/// Links merge point block arguments to the terminators targeting the merge
>From 8618c3a369b836027452996b5791f043b0bb2aa1 Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Tue, 10 Mar 2026 15:50:05 +0100
Subject: [PATCH 11/15] add another test
---
mlir/test/Dialect/SCF/mem2reg.mlir | 35 ++++++++++++++++++++++++++++++
1 file changed, 35 insertions(+)
diff --git a/mlir/test/Dialect/SCF/mem2reg.mlir b/mlir/test/Dialect/SCF/mem2reg.mlir
index b6dbb8c015188..82098b2d306b2 100644
--- a/mlir/test/Dialect/SCF/mem2reg.mlir
+++ b/mlir/test/Dialect/SCF/mem2reg.mlir
@@ -150,6 +150,41 @@ func.func @if_load_into_store(%arg1 : i1) -> i32 {
// -----
+// Check promotion of a load followed by a nested if containing a store of
+// the loaded value.
+
+// CHECK-LABEL: func.func @if_load_then_nested_if_store
+// CHECK-SAME: (%[[COND0:.*]]: i1, %[[COND1:.*]]: i1)
+// CHECK: %[[C5:.*]] = arith.constant 5 : i32
+// CHECK: %[[RES:.*]] = scf.if %[[COND0]] -> (i32) {
+// CHECK: %[[INNER:.*]] = scf.if %[[COND1]] -> (i32) {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: scf.yield %[[INNER]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[C5]] : i32
+// CHECK: }
+// CHECK: return %[[RES]] : i32
+func.func @if_load_then_nested_if_store(%cond0: i1, %cond1: i1) -> i32 {
+ %c5 = arith.constant 5 : i32
+ %alloca = memref.alloca() : memref<i32>
+ memref.store %c5, %alloca[] : memref<i32>
+ scf.if %cond0 {
+ %loaded = memref.load %alloca[] : memref<i32>
+ scf.if %cond1 {
+ memref.store %loaded, %alloca[] : memref<i32>
+ scf.yield
+ }
+ scf.yield
+ }
+ %load = memref.load %alloca[] : memref<i32>
+ return %load : i32
+}
+
+// -----
+
// Check load promotion through execute_region.
func.func private @use(i32)
>From e01d9d6dfc8785d83793d9a18ea2b4e7c43d172a Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Sat, 14 Mar 2026 01:00:04 +0100
Subject: [PATCH 12/15] fix cache invalidation
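Keying the cache by entry block works because moving all blocks of a region into an empty region keeps the `Block` objects unchanged, along with their addresses and relative order, while the `Region *` changes. The following self-contained sketch shows that invariant. It assumes toy `Block` and `Region` types: the `takeBody` here only mimics moving a whole block list, and block indices follow list order rather than the dominance-based order the pass uses:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <utility>
#include <vector>

struct Block {}; // toy stand-in; only its address matters here

// Toy region owning its blocks in order; blocks[0] is the entry block.
struct Region {
  std::vector<std::unique_ptr<Block>> blocks;
  // Moves every block of `other` into this empty region. The Block objects
  // themselves (and thus their addresses) are preserved.
  void takeBody(Region &other) {
    assert(blocks.empty() && "destination region must be empty");
    blocks = std::move(other.blocks);
    other.blocks.clear();
  }
};

using BlockIndices = std::map<const Block *, std::size_t>;
// Keyed by the region's entry block rather than by the region itself.
using BlockIndexCache = std::map<const Block *, BlockIndices>;

const BlockIndices &getOrCreateBlockIndices(BlockIndexCache &cache,
                                            const Region &region) {
  assert(!region.blocks.empty() && "callers must skip empty regions");
  const Block *entry = region.blocks.front().get();
  auto [it, inserted] = cache.try_emplace(entry);
  if (!inserted)
    return it->second; // cached, even if the blocks moved to a new region
  for (std::size_t i = 0; i < region.blocks.size(); ++i)
    it->second[region.blocks[i].get()] = i;
  return it->second;
}
```

In this model, a second lookup made through the new region hits the entry cached for the old one and still yields correct indices. A cache keyed by `Region *` would have missed and left a stale entry behind.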
---
.../mlir/Interfaces/MemorySlotInterfaces.td | 22 ++++++++------
mlir/lib/Transforms/Mem2Reg.cpp | 29 +++++++++++++++----
2 files changed, 36 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index ac6c7e0459c3d..df7e6a7283705 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -317,15 +317,19 @@ def PromotableRegionOpInterface
regions, but before the actual removal of the blocking uses.
Returns the new reaching definition at the exit of the operation. For
- this purpose, it is allowed to mutate the operation using the provided
- `builder`, along with its region contents. However, all blocks within
- the existing regions must remain valid and no new blocks may be added.
- As a result, the operation is allowed to be cloned and replaced only
- if its region content is moved from the original operation and not
- copied. Operations with an effect on the value of the slot must not
- change said effect (for example, new control flow that could change
- reaching definitions for a block is not allowed), and must remain
- unmodified.
+ this purpose, mutation is allowed under the following constraints:
+ 1. If a region is deleted, all of its content must have been moved out
+ (not copied) to a new empty region that remains valid after the
+ deletion.
+ 2. Mutation must not change control flow within existing or moved
+ regions. This includes adding, removing or reordering blocks.
+ 3. Mutation must not modify or add operations that interact with the
+ value of the slot.
+
+ As an example, in order to add new results to the region operation, it
+ is allowed to clone the operation without regions, move (without
+ copying) the old region content into the new regions, and delete the
+ original operation.
The `entryReachingDef` is the reaching definition at the entry of the
region operation.
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 1bdd1b5695efc..74b9c59f677fa 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -207,7 +207,12 @@ class MemorySlotPromotionAnalyzer {
const DataLayout &dataLayout;
};
-using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t>>;
+/// Maps a region to a map of blocks to their index in the region.
+/// The region is identified by its entry block pointer instead of its region
+/// pointer so that the cache stays valid when region content is moved to
+/// a new region. This only supports moves of all the blocks of a region to
+/// an empty region.
+using BlockIndexCache = DenseMap<Block *, DenseMap<Block *, size_t>>;
/// The MemorySlotPromoter handles the state of promoting a memory slot. It
/// wraps a slot and its associated allocator. This will perform the mutation of
@@ -297,6 +302,10 @@ class MemorySlotPromoter {
const Mem2RegStatistics &statistics;
/// Shared cache of block indices of specific regions.
+ /// Cache entries must be invalidated before any addition, removal or
+ /// reordering of blocks in the corresponding region.
+ /// Cache entries are *NOT* invalidated if all the blocks of the corresponding
+ /// region are moved to an empty region.
BlockIndexCache &blockIndexCache;
};
@@ -652,15 +661,18 @@ void MemorySlotPromoter::promoteInRegion(Region *region, Value reachingDef) {
}
}
-/// Gets or creates a block index mapping for `region`.
+/// Gets or creates a block index mapping for the region of which the entry
+/// block is `regionEntryBlock`.
static const DenseMap<Block *, size_t> &
-getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
- auto [it, inserted] = blockIndexCache.try_emplace(region);
+getOrCreateBlockIndices(BlockIndexCache &blockIndexCache,
+ Block *regionEntryBlock) {
+ auto [it, inserted] = blockIndexCache.try_emplace(regionEntryBlock);
if (!inserted)
return it->second;
DenseMap<Block *, size_t> &blockIndices = it->second;
- SetVector<Block *> topologicalOrder = getBlocksSortedByDominance(*region);
+ SetVector<Block *> topologicalOrder =
+ getBlocksSortedByDominance(*regionEntryBlock->getParent());
for (auto [index, block] : llvm::enumerate(topologicalOrder))
blockIndices[block] = index;
return blockIndices;
@@ -669,12 +681,17 @@ getOrCreateBlockIndices(BlockIndexCache &blockIndexCache, Region *region) {
/// Sorts `ops` according to dominance. Relies on the topological order of basic
/// blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
/// potentially expensive recomputation of a block index map.
+/// This function assumes no block is ever deleted and no entry block changes
+/// during the lifetime of the block index cache.
static void dominanceSort(SmallVector<Operation *> &ops, Region &region,
BlockIndexCache &blockIndexCache) {
+ if (region.empty())
+ return;
+
// Produce a topological block order and construct a map to lookup the indices
// of blocks.
const DenseMap<Block *, size_t> &topoBlockIndices =
- getOrCreateBlockIndices(blockIndexCache, &region);
+ getOrCreateBlockIndices(blockIndexCache, &region.front());
// Combining the topological order of the basic blocks together with block
// internal operation order guarantees a deterministic, dominance respecting
>From 769e2e46f9c57782956986a4516b3ca0b9d01c0a Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Sat, 14 Mar 2026 01:25:30 +0100
Subject: [PATCH 13/15] fix default value check in memref alloca complete
---
.../include/mlir/Interfaces/MemorySlotInterfaces.td | 3 ++-
mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp | 2 +-
mlir/test/Dialect/MemRef/mem2reg.mlir | 13 +++++++++++++
3 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index df7e6a7283705..6be0084ebbe0e 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -68,7 +68,8 @@ def PromotableAllocationOpInterface
>,
InterfaceMethod<[{
Hook triggered once the promotion of a slot is complete. This can
- also clean up the created default value if necessary.
+ also clean up the created default value if necessary. The default
+ value may be a null value if no default value was created.
This will only be called for slots declared by this operation.
Must return a new promotable allocation op if this operation produced
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 540423831937e..6748e2cf71804 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -82,7 +82,7 @@ std::optional<PromotableAllocationOpInterface>
memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
Value defaultValue,
OpBuilder &builder) {
- if (defaultValue.use_empty())
+ if (defaultValue && defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
this->erase();
return std::nullopt;
diff --git a/mlir/test/Dialect/MemRef/mem2reg.mlir b/mlir/test/Dialect/MemRef/mem2reg.mlir
index dd68675cc4441..30a268521c69b 100644
--- a/mlir/test/Dialect/MemRef/mem2reg.mlir
+++ b/mlir/test/Dialect/MemRef/mem2reg.mlir
@@ -163,3 +163,16 @@ func.func @promotable_nonpromotable_intertwined() -> i32 {
}
func.func @use(%arg: memref<i32>) { return }
+
+// -----
+
+// CHECK-LABEL: func.func @unused_alloca_store_loop
+func.func @unused_alloca_store_loop() {
+ // CHECK-NOT: memref.alloca
+ %cst = arith.constant 1 : i32
+ %alloca = memref.alloca() : memref<i32>
+ cf.br ^bb1
+^bb1:
+ memref.store %cst, %alloca[] : memref<i32>
+ cf.br ^bb1
+}
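The fix guards against a null default value. Presumably, in the new `@unused_alloca_store_loop` test the slot is only ever stored to and never loaded, so no default value is materialized and `handlePromotionComplete` receives a null `Value`; calling `use_empty()` on it would then likely crash. The sketch below reproduces the guard with hypothetical `FakeOp` and `FakeValue` types standing in for `mlir::Operation` and `mlir::Value`; it illustrates the pattern and is not MLIR code.

```cpp
// Hypothetical stand-in for the op defining a default value.
struct FakeOp {
  int numUses = 0;
  bool erased = false;
};

// Hypothetical stand-in for mlir::Value: a possibly-null handle.
// A default-constructed FakeValue is the "null value" case.
struct FakeValue {
  FakeOp *def = nullptr;
  explicit operator bool() const { return def != nullptr; }
  // Dereferences `def`: must not be called on a null handle.
  bool use_empty() const { return def->numUses == 0; }
  FakeOp *getDefiningOp() const { return def; }
};

// Mirrors the patched check: only clean up the default value when one was
// actually created and nothing ended up using it.
void cleanupDefaultValue(FakeValue defaultValue) {
  if (defaultValue && defaultValue.use_empty())
    defaultValue.getDefiningOp()->erased = true;
}
```

Passing a default-constructed `FakeValue{}` is a no-op thanks to the short-circuiting `&&`, a default value that still has uses is kept, and an unused one is marked as erased.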
From fbb3874784c36004c557f19ced912cd44cc3a47a Mon Sep 17 00:00:00 2001
From: Theo Degioanni <tdegioanni at nvidia.com>
Date: Sat, 14 Mar 2026 01:49:19 +0100
Subject: [PATCH 14/15] clarify interface doc further
---
mlir/include/mlir/Interfaces/MemorySlotInterfaces.td | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 6be0084ebbe0e..801555fba4947 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -318,12 +318,13 @@ def PromotableRegionOpInterface
regions, but before the actual removal of the blocking uses.
Returns the new reaching definition at the exit of the operation. For
- this purpose, mutation is allowed under the following constraints:
+ this purpose, mutation of the operation is allowed under the following
+ constraints:
1. If a region is deleted, all of its content must have been moved out
(not copied) to a new empty region that remains valid after the
deletion.
- 2. Mutation must not change control flow within existing or moved
- regions. This includes adding, removing or reordering blocks.
+ 2. Mutation must not change control flow within or between existing or
+ moved regions. This includes adding, removing or reordering blocks.
3. Mutation must not modify or add operations that interact with the
value of the slot.
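Constraint 2 lends itself to a mechanical check. The sketch below is a hypothetical debugging aid, not something this patch adds. It models a region as an ordered list of blocks (`BlockModel`, `RegionModel`, both invented for illustration) together with their successor edges, and compares a snapshot taken before the mutation with one taken after it. Because constraint 1 requires region contents to be moved rather than copied, block identities survive the move, so comparing blocks by identifier stays meaningful. Control flow between regions (the "between" part of the constraint) is not modeled here.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical model of a block: an identifier plus the identifiers of its
// successor blocks (its intra-region control flow edges).
struct BlockModel {
  int id;
  std::vector<int> successors;
};

// Hypothetical model of a region: its blocks, in order.
using RegionModel = std::vector<BlockModel>;

// Returns true when `after` has the same blocks, in the same order, with the
// same successor edges as `before`, i.e. no block was added, removed or
// reordered and no intra-region control flow edge changed.
bool controlFlowUnchanged(const RegionModel &before, const RegionModel &after) {
  if (before.size() != after.size())
    return false;
  for (std::size_t i = 0; i < before.size(); ++i) {
    if (before[i].id != after[i].id ||
        before[i].successors != after[i].successors)
      return false;
  }
  return true;
}
```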
From 6fd365eac592575dfbdd995a87bd33e7cde69dc4 Mon Sep 17 00:00:00 2001
From: tdegioanni-nvidia <tdegioanni at nvidia.com>
Date: Sat, 21 Mar 2026 03:30:59 +0100
Subject: [PATCH 15/15] Apply suggestion from @vzakhari
Co-authored-by: Slava Zakharin <szakharin at nvidia.com>
---
mlir/lib/Transforms/Mem2Reg.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 74b9c59f677fa..53d38536cee00 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -88,7 +88,7 @@ using namespace mlir;
/// do so aborts promotion at this step).
///
/// At this point, promotion is guaranteed to happen, and the transformation
-/// phase can begin. For each region of the program, a two step procvess is
+/// phase can begin. For each region of the program, a two step process is
/// carried out.
/// - The first step of the per-region process computes the reaching definition
/// of the memory slot at each blocking user. This is the core of the mem2reg
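Within a single block, the first step reduces to a forward walk over the block: a store redefines the slot, and each load observes the definition reaching it. The sketch below shows only that core under simplifying assumptions: the IR is straight-line, `SlotOp` and `computeReachingDefs` are hypothetical, and SSA values are modeled as strings. The actual pass also has to merge definitions at control flow joins and, with this PR, across region boundaries, which the sketch omits.

```cpp
#include <string>
#include <vector>

// Hypothetical straight-line IR: each op either stores a named SSA value
// into the slot or loads from it.
struct SlotOp {
  enum Kind { Store, Load } kind;
  // For Store: the stored value. For Load: overwritten with the reaching
  // definition that replaces the load.
  std::string value;
};

// Walks one block in order while tracking the slot's reaching definition.
// `entryDef` is the definition reaching the block entry (for example a block
// argument or the slot's default value). Returns the definition reaching the
// block exit.
std::string computeReachingDefs(std::vector<SlotOp> &ops, std::string entryDef) {
  std::string reaching = entryDef;
  for (SlotOp &op : ops) {
    if (op.kind == SlotOp::Store)
      reaching = op.value; // a store redefines the slot
    else
      op.value = reaching; // a load observes the current definition
  }
  return reaching;
}
```

With an entry definition `%default` and the sequence load, store `%a`, load, store `%b`, load, the three loads resolve to `%default`, `%a` and `%b`, and `%b` reaches the block exit.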