[Mlir-commits] [mlir] [MLIR][Mem2Reg] Add support for region control flow and SCF (PR #185036)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 9 06:56:51 PDT 2026
================
@@ -0,0 +1,347 @@
+//===- MemorySlot.cpp - Memory Slot interface implementations for SCF -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Appends a reaching definition as a new trailing operand of `block`'s
+/// terminator, but only if that terminator is a `TermTy`; otherwise the block
+/// is left untouched.
+///
+/// The appended value is the entry for `block` in `reachingAtBlockEnd`. When
+/// the map has no (or a null) entry for the block, `defaultReachingDef` is used
+/// instead.
+template <typename TermTy>
+static void
+updateTerminator(Block *block, Value defaultReachingDef,
+                 llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  Operation *terminator = block->getTerminator();
+  if (!isa<TermTy>(terminator))
+    return;
+  // Use `lookup` rather than `operator[]`: it yields a null Value for a missing
+  // key without inserting an empty entry into the caller's map.
+  Value blockReachingDef = reachingAtBlockEnd.lookup(block);
+  if (!blockReachingDef) {
+    // Block is dead code or the region is not using the slot, so we use the
+    // default provided reaching definition.
+    blockReachingDef = defaultReachingDef;
+  }
+  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
+/// Creates a shallow copy of `op` with the result types `resultTypes`, moves
+/// the bodies of all regions from `op` into the copy, redirects the uses of
+/// `op`'s results to the copy, and erases `op`.
+///
+/// `resultTypes` must hold exactly one more type than `op` has results: every
+/// result of the copy except the last one replaces the result of `op` at the
+/// same position, and the last one is the newly added result.
+///
+/// Returns the newly created operation. The rewriter's insertion point is
+/// restored on exit.
+static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                        TypeRange resultTypes) {
+  // `drop_back()` below is only a valid replacement list under this condition.
+  assert(resultTypes.size() == op->getNumResults() + 1 &&
+         "expected exactly one additional trailing result type");
+
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  Operation *newOp =
+      mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
+
+  // Moving region bodies mutates both operations in place, so bracket the
+  // transfer with modification notifications for each of them.
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned int i : llvm::seq(op->getNumRegions()))
+    newOp->getRegion(i).takeBody(op->getRegion(i));
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  // The trailing result is new and has no counterpart on the old operation.
+  SmallVector<Value> replacementValues(newOp->getResults().drop_back());
+  rewriter.replaceAllOpUsesWith(op, replacementValues);
+  rewriter.eraseOp(op);
+  return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+/// Unconditionally reports the region as promotable: neither the slot, the
+/// region, nor the presence of value stores is inspected.
+bool ExecuteRegionOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                         bool hasValueStores) {
+  return true;
+}
+
+/// Registers the single region of the op for processing, with `reachingDef` as
+/// the reaching definition associated with it. `slot` and `hasValueStores` are
+/// not consulted.
+void ExecuteRegionOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  regionsToProcess.insert({&getRegion(), reachingDef});
+}
+
+/// Finishes promotion for this op and returns the reaching definition that
+/// holds after it.
+///
+/// Without value stores, the incoming `reachingDef` is returned unchanged.
+/// Otherwise every block terminated by a `YieldOp` yields one more operand
+/// (see `updateTerminator`), the op is rebuilt with one extra result of type
+/// `slot.elemType`, and that extra result is returned.
+Value ExecuteRegionOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (!hasValueStores)
+    return reachingDef;
+
+  // Thread the reaching definition out through each yielding block.
+  for (Block &bodyBlock : getRegion())
+    updateTerminator<YieldOp>(&bodyBlock, reachingDef, reachingAtBlockEnd);
+
+  // Existing result types followed by the slot's element type.
+  SmallVector<Type> widenedTypes;
+  llvm::append_range(widenedTypes, getResultTypes());
+  widenedTypes.push_back(slot.elemType);
+
+  IRRewriter rewriter(builder);
+  Operation *rebuilt =
+      replaceWithNewResults(rewriter, getOperation(), widenedTypes);
+  return rebuilt->getResult(rebuilt->getNumResults() - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+/// Unconditionally reports the region as promotable: neither the slot, the
+/// region, nor the presence of value stores is inspected.
+bool ForOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                               bool hasValueStores) {
+  return true;
+}
+
+/// Registers the loop body for processing.
+///
+/// Without value stores, the body is registered with the incoming
+/// `reachingDef`. With value stores, `reachingDef` is appended to the loop's
+/// init args, a block argument of type `slot.elemType` is added to the body
+/// region, and the body is registered with that new argument as its reaching
+/// definition.
+void ForOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  Region &bodyRegion = getBodyRegion();
+  if (!hasValueStores) {
+    regionsToProcess.insert({&bodyRegion, reachingDef});
+    return;
+  }
+
+  // Carry the slot value through the loop: one extra init arg on the op and a
+  // matching extra argument on the body.
+  getInitArgsMutable().append(reachingDef);
+  bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc());
+  regionsToProcess.insert({&bodyRegion, bodyRegion.getArguments().back()});
+}
+
+/// Finishes promotion for this loop and returns the reaching definition that
+/// holds after it.
+///
+/// Without value stores, the incoming `reachingDef` is returned unchanged.
+/// Otherwise the body's `YieldOp` yields one more operand (see
+/// `updateTerminator`), the op is rebuilt with one extra result of type
+/// `slot.elemType`, and that extra result is returned.
+Value ForOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (!hasValueStores)
+    return reachingDef;
+
+  // Thread the reaching definition out through the body's yield.
+  updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
+
+  // Existing result types followed by the slot's element type.
+  SmallVector<Type> widenedTypes;
+  llvm::append_range(widenedTypes, getResultTypes());
+  widenedTypes.push_back(slot.elemType);
+
+  IRRewriter rewriter(builder);
+  Operation *rebuilt =
+      replaceWithNewResults(rewriter, getOperation(), widenedTypes);
+  return rebuilt->getResult(rebuilt->getNumResults() - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// ForallOp
+//===----------------------------------------------------------------------===//
+
+/// Reports the region as promotable only when there are no value stores.
+bool ForallOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                  bool hasValueStores) {
+  // The ForallOp body can be run in parallel, so it does not support sequenced
+  // value passing. Therefore only loads can be handled.
+  const bool onlyLoads = !hasValueStores;
+  return onlyLoads;
+}
+
+/// Registers the body region for processing, with `reachingDef` as the
+/// reaching definition associated with it. Asserts that there are no value
+/// stores, since this op does not support them.
+void ForallOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  assert(!hasValueStores && "ForallOp does not support stores");
+  regionsToProcess.insert({&getBodyRegion(), reachingDef});
+}
+
+/// Returns the incoming `reachingDef` unchanged and leaves the op untouched.
+/// Asserts that there are no value stores, since this op does not support
+/// them.
+Value ForallOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  assert(!hasValueStores && "ForallOp does not support stores");
+  // Only loads are handled, so the value reaching the op also leaves it.
+  Value reachingAfterOp = reachingDef;
+  return reachingAfterOp;
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
+/// Unconditionally reports the region as promotable: neither the slot, the
+/// region, nor the presence of value stores is inspected.
+bool IfOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                              bool hasValueStores) {
+  return true;
+}
+
+/// Registers both the "then" and the "else" region for processing, each with
+/// `reachingDef` as its associated reaching definition. `slot` and
+/// `hasValueStores` are not consulted.
+void IfOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  regionsToProcess.insert({&getThenRegion(), reachingDef});
+  regionsToProcess.insert({&getElseRegion(), reachingDef});
+}
+
+/// Finishes promotion for this op and returns the reaching definition that
+/// holds after it.
+///
+/// Without value stores, the incoming `reachingDef` is returned unchanged.
+/// Otherwise both branches are made to yield one more operand: the "then"
+/// branch through `updateTerminator`, and the "else" branch either the same
+/// way or, when the else region does not have exactly one block, through a
+/// newly created block that yields `reachingDef`. The op is then rebuilt with
+/// one extra result of type `slot.elemType`, and that extra result is returned.
+Value IfOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (!hasValueStores)
+    return reachingDef;
+
+  IRRewriter rewriter(builder);
+
+  // "then" branch: extend its yield.
+  Region &thenRegion = getThenRegion();
+  updateTerminator<YieldOp>(&thenRegion.back(), reachingDef,
+                            reachingAtBlockEnd);
+
+  // "else" branch: extend its yield, or create one that forwards the incoming
+  // reaching definition.
+  Region &elseRegion = getElseRegion();
+  if (!elseRegion.hasOneBlock()) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.createBlock(&elseRegion);
+    YieldOp::create(rewriter, getOperation()->getLoc(), reachingDef);
+  } else {
+    updateTerminator<YieldOp>(&elseRegion.back(), reachingDef,
+                              reachingAtBlockEnd);
+  }
+
+  // Existing result types followed by the slot's element type.
+  SmallVector<Type> widenedTypes;
+  llvm::append_range(widenedTypes, getResultTypes());
+  widenedTypes.push_back(slot.elemType);
+
+  Operation *rebuilt =
+      replaceWithNewResults(rewriter, getOperation(), widenedTypes);
+  return rebuilt->getResult(rebuilt->getNumResults() - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// IndexSwitchOp
+//===----------------------------------------------------------------------===//
+
+/// Unconditionally reports the region as promotable: neither the slot, the
+/// region, nor the presence of value stores is inspected.
+bool IndexSwitchOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                       bool hasValueStores) {
+  return true;
+}
+
+/// Registers the default region and then every case region for processing,
+/// each with `reachingDef` as its associated reaching definition. `slot` and
+/// `hasValueStores` are not consulted.
+void IndexSwitchOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  regionsToProcess.insert({&getDefaultRegion(), reachingDef});
+  for (Region &caseRegion : getCaseRegions())
+    regionsToProcess.insert({&caseRegion, reachingDef});
+}
+
+/// Finishes promotion for this op and returns the reaching definition that
+/// holds after it.
+///
+/// Without value stores, the incoming `reachingDef` is returned unchanged.
+/// Otherwise the last block of the default region and of every case region
+/// yields one more operand (see `updateTerminator`), the op is rebuilt with
+/// one extra result of type `slot.elemType`, and that extra result is returned.
+Value IndexSwitchOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (!hasValueStores)
+    return reachingDef;
+
+  IRRewriter rewriter(builder);
+
+  // Default branch first, then each case, mirroring the registration order.
+  Region &defaultRegion = getDefaultRegion();
+  updateTerminator<YieldOp>(&defaultRegion.back(), reachingDef,
+                            reachingAtBlockEnd);
+  for (Region &caseRegion : getCaseRegions()) {
+    Block *caseExit = &caseRegion.back();
+    updateTerminator<YieldOp>(caseExit, reachingDef, reachingAtBlockEnd);
+  }
+
+  // Existing result types followed by the slot's element type.
+  SmallVector<Type> widenedTypes;
+  llvm::append_range(widenedTypes, getResultTypes());
+  widenedTypes.push_back(slot.elemType);
+
+  Operation *rebuilt =
+      replaceWithNewResults(rewriter, getOperation(), widenedTypes);
+  return rebuilt->getResult(rebuilt->getNumResults() - 1);
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Reports the region as promotable only when there are no value stores.
+bool ParallelOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                    bool hasValueStores) {
+  // The ParallelOp body can be run in parallel, so it does not support
+  // sequenced value passing. Therefore only loads can be handled.
+  const bool onlyLoads = !hasValueStores;
+  return onlyLoads;
+}
+
+/// Registers the body region for processing, with `reachingDef` as the
+/// reaching definition associated with it. Asserts that there are no value
+/// stores, since this op does not support them.
+void ParallelOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  assert(!hasValueStores && "ParallelOp does not support stores");
+  regionsToProcess.insert({&getBodyRegion(), reachingDef});
+}
+
+/// Returns the incoming `reachingDef` unchanged and leaves the op untouched.
+/// Asserts that there are no value stores, since this op does not support
+/// them.
+Value ParallelOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  assert(!hasValueStores && "ParallelOp does not support stores");
+  // Only loads are handled, so the value reaching the op also leaves it.
+  Value reachingAfterOp = reachingDef;
+  return reachingAfterOp;
+}
----------------
tdegioanni-nvidia wrote:
You know I like it explicit so people don't forget about it and get surprising behavior, but I'm not dead set on the question. This implementation is wrong in most cases, so that's why I'm a bit reluctant to bless it as "default".
https://github.com/llvm/llvm-project/pull/185036
More information about the Mlir-commits
mailing list