[Mlir-commits] [mlir] [MLIR][Mem2Reg] Add support for region control flow and SCF (PR #185036)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 9 08:25:20 PDT 2026
================
@@ -0,0 +1,347 @@
+//===- MemorySlot.cpp - Memory Slot interface implementations for SCF -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Appends the slot's reaching definition at the end of `block` as an extra
+/// operand of the block's terminator, but only if that terminator is a
+/// `TermTy`; other terminators are left untouched.
+///
+/// The appended value is the one recorded for `block` in `reachingAtBlockEnd`.
+/// Blocks with no recorded (or a null) definition fall back to
+/// `defaultReachingDef`.
+template <typename TermTy>
+static void
+updateTerminator(Block *block, Value defaultReachingDef,
+                 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  Operation *terminator = block->getTerminator();
+  if (!isa<TermTy>(terminator))
+    return;
+  // Use `lookup` rather than `operator[]`: the latter would insert a null
+  // entry for every block without a recorded definition, silently mutating
+  // the caller's map.
+  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
+  if (!blockReachingDef) {
+    // Block is dead code or the region is not using the slot, so we use the
+    // default provided reaching definition.
+    blockReachingDef = defaultReachingDef;
+  }
+  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
+/// Replaces `op` with a copy of itself whose result types are `resultTypes`.
+///
+/// The copy is created immediately before `op` via `cloneWithoutRegions`,
+/// reusing `op`'s operands, and the bodies of all of `op`'s regions are moved
+/// (not cloned) into it. `resultTypes` must begin with `op`'s current result
+/// types: each use of an old result is redirected to the new result at the
+/// same position, and any extra trailing results are left for the caller.
+/// `op` is erased and the new operation is returned.
+static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                        TypeRange resultTypes) {
+  unsigned numOldResults = op->getNumResults();
+  assert(resultTypes.size() >= numOldResults &&
+         "new result types must extend the original result types");
+
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  Operation *newOp =
+      mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
+
+  // Region bodies are moved rather than cloned; notify the rewriter that both
+  // operations are modified in place.
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned int i : llvm::seq(op->getNumRegions()))
+    newOp->getRegion(i).takeBody(op->getRegion(i));
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  // Redirect uses by position: the old results map onto the leading results
+  // of `newOp`. Slicing by the old result count (instead of dropping exactly
+  // one trailing result) keeps this correct for any number of appended
+  // results.
+  SmallVector<Value> replacementValues(
+      newOp->getResults().take_front(numOldResults));
+  rewriter.replaceAllOpUsesWith(op, replacementValues);
+  rewriter.eraseOp(op);
+  return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+/// Unconditionally reports the region of `scf.execute_region` as promotable;
+/// none of the slot, region, or store information is consulted.
+bool ExecuteRegionOp::isRegionPromotable(const MemorySlot & /*slot*/,
+                                         Region * /*region*/,
+                                         bool /*hasValueStores*/) {
+  constexpr bool alwaysPromotable = true;
+  return alwaysPromotable;
+}
+
+/// Schedules the single region of this op for promotion. The region is entered
+/// with the definition reaching the op itself, so no block arguments or
+/// operands need to be added here. An existing entry for the region in
+/// `regionsToProcess` is not overwritten.
+void ExecuteRegionOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  Region *body = &getRegion();
+  regionsToProcess.insert({body, reachingDef});
+}
+
+/// Completes promotion of `slot` through this op and returns the value of the
+/// slot after the op.
+///
+/// When `hasValueStores` is false, the op is left untouched and the incoming
+/// `reachingDef` is returned as-is. Otherwise every `scf.yield` in the region
+/// is extended with the definition reaching its block (falling back to
+/// `reachingDef` for blocks with none recorded). The op is then rebuilt with
+/// one extra trailing result of type `slot.elemType`, and that result is
+/// returned.
+Value ExecuteRegionOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (hasValueStores) {
+    // Thread the slot value out of the region through each yield.
+    Region &body = getRegion();
+    for (Block &yieldingBlock : body)
+      updateTerminator<YieldOp>(&yieldingBlock, reachingDef,
+                                reachingAtBlockEnd);
+
+    // Same results as before, plus one carrying the promoted slot value.
+    SmallVector<Type> extendedTypes(getResultTypes());
+    extendedTypes.push_back(slot.elemType);
+
+    IRRewriter rewriter(builder);
+    Operation *rebuilt =
+        replaceWithNewResults(rewriter, getOperation(), extendedTypes);
+    return rebuilt->getResult(rebuilt->getNumResults() - 1);
+  }
+  return reachingDef;
+}
+
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+/// Unconditionally reports the body region of `scf.for` as promotable; none
+/// of the slot, region, or store information is consulted.
+bool ForOp::isRegionPromotable(const MemorySlot & /*slot*/,
+                               Region * /*region*/,
+                               bool /*hasValueStores*/) {
+  constexpr bool alwaysPromotable = true;
+  return alwaysPromotable;
+}
+
+/// Schedules the loop body for promotion and records the value the slot holds
+/// on entry to the body.
+///
+/// Without stores in the body, the body is entered with `reachingDef`
+/// directly. With stores, the slot value is carried through the loop: the
+/// incoming `reachingDef` is appended to the init args, and a matching block
+/// argument of type `slot.elemType` is appended to the body; that argument is
+/// the definition on entry to the body. An existing entry for the body region
+/// in `regionsToProcess` is not overwritten.
+void ForOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  Region &bodyRegion = getBodyRegion();
+  Value bodyEntryDef = reachingDef;
+  if (hasValueStores) {
+    getInitArgsMutable().append(reachingDef);
+    // The freshly appended block argument receives the init value on the
+    // first iteration.
+    bodyEntryDef = bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc());
+  }
+  regionsToProcess.insert({&bodyRegion, bodyEntryDef});
+}
+
+Value ForOp::finalizePromotion(
+ const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+ llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+ if (!hasValueStores)
+ return reachingDef;
----------------
tdegioanni-nvidia wrote:
I have a strong preference for `hasValueStores`, because there is also the case of dead blocks, which would complicate the in-place checking a lot.
https://github.com/llvm/llvm-project/pull/185036
More information about the Mlir-commits
mailing list