[Mlir-commits] [mlir] [MLIR][Mem2Reg] Add support for region control flow and SCF (PR #185036)
Tobias Gysi
llvmlistbot at llvm.org
Mon Mar 9 07:38:37 PDT 2026
================
@@ -0,0 +1,347 @@
+//===- MemorySlot.cpp - Memory Slot interface implementations for SCF -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Appends the reaching definition at the end of `block` as an extra operand
+/// of the block's terminator, provided that terminator is a `TermTy`.
+/// Terminators of any other type are left untouched. When no reaching
+/// definition was recorded for `block` (dead code, or the region does not use
+/// the slot), `defaultReachingDef` is forwarded instead.
+template <typename TermTy>
+static void
+updateTerminator(Block *block, Value defaultReachingDef,
+                 llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  Operation *terminator = block->getTerminator();
+  if (!isa<TermTy>(terminator))
+    return;
+  // Use `lookup` rather than `operator[]`: the latter default-inserts a null
+  // entry for every block without a recorded definition, silently mutating
+  // the caller's map.
+  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
+  if (!blockReachingDef)
+    blockReachingDef = defaultReachingDef;
+  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
+/// Creates a shallow copy of `op` whose result types are `resultTypes` and
+/// moves the regions of `op` into the copy. `resultTypes` must start with the
+/// original result types of `op`. Those leading results of the copy replace
+/// all uses of the results of `op`, which is then erased. Any trailing
+/// results beyond the original count are left without uses for the caller to
+/// consume. Returns the newly created operation.
+static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                        TypeRange resultTypes) {
+  unsigned numOrigResults = op->getNumResults();
+  assert(resultTypes.size() >= numOrigResults &&
+         "expected the new result types to extend the original ones");
+
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  Operation *newOp =
+      mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned int i : llvm::seq(op->getNumRegions()))
+    newOp->getRegion(i).takeBody(op->getRegion(i));
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  // Only the results that existed on `op` can have uses to forward. Derive
+  // the count from `op` instead of assuming exactly one appended result.
+  rewriter.replaceAllOpUsesWith(op,
+                                newOp->getResults().take_front(numOrigResults));
+  rewriter.eraseOp(op);
+  return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+bool ExecuteRegionOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                         bool hasValueStores) {
+  // The region is accepted whether or not it stores to the slot. When it
+  // does, finalizePromotion exposes the final slot value as an extra result.
+  return true;
+}
+
+void ExecuteRegionOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  // The region is entered with the definition that reaches the operation.
+  regionsToProcess.insert({&getRegion(), reachingDef});
+}
+
+Value ExecuteRegionOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  // Without stores in the region, the incoming definition is still the slot
+  // value after the operation.
+  if (!hasValueStores)
+    return reachingDef;
+
+  // Every block terminated by scf.yield additionally yields the slot value
+  // that reaches its end.
+  for (Block &blk : getRegion())
+    updateTerminator<YieldOp>(&blk, reachingDef, reachingAtBlockEnd);
+
+  // Rebuild the operation with one extra trailing result carrying that value.
+  SmallVector<Type> newResultTypes = llvm::to_vector(getResultTypes());
+  newResultTypes.push_back(slot.elemType);
+
+  IRRewriter rewriter(builder);
+  Operation *replacement =
+      replaceWithNewResults(rewriter, getOperation(), newResultTypes);
+  return replacement->getResult(replacement->getNumResults() - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+bool ForOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                               bool hasValueStores) {
+  // Stores are supported. setupPromotion threads the slot value through an
+  // extra iteration argument, and finalizePromotion yields it back.
+  return true;
+}
+
+void ForOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  Region &bodyRegion = getBodyRegion();
+  // Without stores in the body, every iteration observes the definition that
+  // reaches the loop.
+  if (!hasValueStores) {
+    regionsToProcess.insert({&bodyRegion, reachingDef});
+    return;
+  }
+
+  // With stores, carry the slot value as a new iteration argument. The
+  // incoming definition initializes it, and the body starts from the new
+  // block argument. finalizePromotion later yields the value reached at the
+  // end of the body.
+  getInitArgsMutable().append(reachingDef);
+  BlockArgument iterArg =
+      bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc());
+  regionsToProcess.insert({&bodyRegion, iterArg});
+}
+
+Value ForOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  // No iteration argument was added in setupPromotion, so the incoming
+  // definition is still the slot value after the loop.
+  if (!hasValueStores)
+    return reachingDef;
+
+  // The body's scf.yield forwards the slot value reached at the end of an
+  // iteration, matching the iteration argument added in setupPromotion.
+  Block *body = getBody();
+  updateTerminator<YieldOp>(body, reachingDef, reachingAtBlockEnd);
+
+  // Add the matching trailing result, which is the slot value after the loop.
+  SmallVector<Type> newResultTypes = llvm::to_vector(getResultTypes());
+  newResultTypes.push_back(slot.elemType);
+
+  IRRewriter rewriter(builder);
+  Operation *replacement =
+      replaceWithNewResults(rewriter, getOperation(), newResultTypes);
+  return replacement->getResult(replacement->getNumResults() - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// ForallOp
+//===----------------------------------------------------------------------===//
+
+bool ForallOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                  bool hasValueStores) {
+  // Iterations of the ForallOp body may run in parallel, so there is no
+  // sequential order in which a stored value could be passed from one
+  // iteration to the next. Therefore only loads can be handled.
+  return !hasValueStores;
+}
+
+void ForallOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  // isRegionPromotable rejects bodies that store to the slot, so the body
+  // only observes the definition that reaches the operation.
+  assert(!hasValueStores && "ForallOp does not support stores");
+  regionsToProcess.insert({&getBodyRegion(), reachingDef});
+}
+
+Value ForallOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  assert(!hasValueStores && "ForallOp does not support stores");
+  // Load-only body: the operation is left unchanged and the incoming
+  // definition remains the slot value after it.
+  return reachingDef;
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
+bool IfOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                              bool hasValueStores) {
+  // Both branches may load or store. When stores are present,
+  // finalizePromotion makes each branch yield the slot value as an extra
+  // result.
+  return true;
+}
+
+void IfOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  // Both branches start from the definition that reaches the scf.if.
+  regionsToProcess.insert({&getThenRegion(), reachingDef});
+  regionsToProcess.insert({&getElseRegion(), reachingDef});
+}
+
+Value IfOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  // Neither branch stores, so the incoming definition still holds after the
+  // scf.if.
+  if (!hasValueStores)
+    return reachingDef;
+
+  IRRewriter rewriter(builder);
+  Region &thenRegion = getThenRegion();
+  Region &elseRegion = getElseRegion();
+
+  // Each branch yields the slot value reaching its end. A branch that does
+  // not touch the slot yields the incoming definition.
+  updateTerminator<YieldOp>(&thenRegion.back(), reachingDef,
+                            reachingAtBlockEnd);
+  if (!elseRegion.hasOneBlock()) {
+    // The else region is expected to be empty here (no else branch).
+    // Materialize one that forwards the incoming definition so that both
+    // branches yield the new result.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.createBlock(&elseRegion);
+    YieldOp::create(rewriter, getLoc(), reachingDef);
+  } else {
+    updateTerminator<YieldOp>(&elseRegion.back(), reachingDef,
+                              reachingAtBlockEnd);
+  }
+
+  SmallVector<Type> newResultTypes = llvm::to_vector(getResultTypes());
+  newResultTypes.push_back(slot.elemType);
+
+  Operation *replacement =
+      replaceWithNewResults(rewriter, getOperation(), newResultTypes);
+  return replacement->getResult(replacement->getNumResults() - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// IndexSwitchOp
+//===----------------------------------------------------------------------===//
+
+bool IndexSwitchOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                       bool hasValueStores) {
+  // The default region and every case region may load or store. Stores are
+  // surfaced as an extra yielded result in finalizePromotion.
+  return true;
+}
+
+void IndexSwitchOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  // The default region and all case regions start from the definition that
+  // reaches the switch.
+  regionsToProcess.insert({&getDefaultRegion(), reachingDef});
+  for (Region &caseRegion : getCaseRegions())
+    regionsToProcess.insert({&caseRegion, reachingDef});
+}
+
+Value IndexSwitchOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  // No region stores, so the incoming definition remains valid afterwards.
+  if (!hasValueStores)
+    return reachingDef;
+
+  // The default region and every case region yield the slot value reaching
+  // the end of their (last) block.
+  auto appendYieldedValue = [&](Region &region) {
+    updateTerminator<YieldOp>(&region.back(), reachingDef, reachingAtBlockEnd);
+  };
+  appendYieldedValue(getDefaultRegion());
+  for (Region &caseRegion : getCaseRegions())
+    appendYieldedValue(caseRegion);
+
+  SmallVector<Type> newResultTypes = llvm::to_vector(getResultTypes());
+  newResultTypes.push_back(slot.elemType);
+
+  IRRewriter rewriter(builder);
+  Operation *replacement =
+      replaceWithNewResults(rewriter, getOperation(), newResultTypes);
+  return replacement->getResult(replacement->getNumResults() - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+bool ParallelOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                    bool hasValueStores) {
+  // Iterations of the ParallelOp body may run in parallel, so there is no
+  // sequential order in which a stored value could be passed from one
+  // iteration to the next. Therefore only loads can be handled.
+  return !hasValueStores;
+}
+
+void ParallelOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  // isRegionPromotable rejects bodies that store to the slot, so the body
+  // only observes the definition that reaches the operation.
+  assert(!hasValueStores && "ParallelOp does not support stores");
+  regionsToProcess.insert({&getBodyRegion(), reachingDef});
+}
+
+Value ParallelOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  assert(!hasValueStores && "ParallelOp does not support stores");
+  // Load-only body: the operation is left unchanged and the incoming
+  // definition remains the slot value after it.
+  return reachingDef;
+}
----------------
gysit wrote:
No strong opinion from my side; just a suggestion.
https://github.com/llvm/llvm-project/pull/185036
More information about the Mlir-commits
mailing list