[Mlir-commits] [mlir] [MLIR][Mem2Reg] Add support for region control flow and SCF (PR #185036)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 6 08:40:00 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (tdegioanni-nvidia)

<details>
<summary>Changes</summary>

This PR adds support for region control-flow. Region control-flow and CFG can be mixed together in the same program. See the accompanying RFC for some design considerations.

Beyond the considerations in the RFC, a few minor changes were introduced:

- Calling the visitor hook for defined values is now deferred to the end of promotion.
- The lazy creation of default values has been moved to the places where it happens, in preparation for a future change that makes it actually lazy. Documentation noting that it does not work as intended for now was also added.

All SCF operations are supported, including `forall` and `parallel`, which is pretty cool I think.

I apologize in advance: git displays a really bad diff for Mem2Reg.cpp around where the liveness analysis used to be. Consider simply reading that part of the code directly from the file instead.

As a disclaimer, I designed all the test cases myself, but I used a large amount of matrix multiplications to produce the corresponding IR and FileCheck tests. I have reviewed them carefully and they correspond to my intent.

---

Patch is 96.26 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/185036.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/IR/SCF.h (+1) 
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+9) 
- (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+90-3) 
- (modified) mlir/lib/Dialect/SCF/IR/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/SCF/IR/MemorySlot.cpp (+347) 
- (modified) mlir/lib/Transforms/Mem2Reg.cpp (+349-192) 
- (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+3-7) 
- (added) mlir/test/Dialect/SCF/mem2reg-reject.mlir (+160) 
- (added) mlir/test/Dialect/SCF/mem2reg.mlir (+819) 
- (modified) mlir/test/Transforms/mem2reg.mlir (+74) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..44cbb458d94fe 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -22,6 +22,7 @@
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..abc6f79bb09b2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -19,6 +19,7 @@ include "mlir/IR/RegionKindInterface.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -78,6 +79,7 @@ def ConditionOp : SCF_Op<"condition", [
 
 def ExecuteRegionOp : SCF_Op<"execute_region", [
     DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorInputs"]>,
+    DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
     RecursiveMemoryEffects]> {
   let summary = "operation that executes its region exactly once";
   let description = [{
@@ -161,6 +163,7 @@ def ForOp : SCF_Op<"for",
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+       DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
        SingleBlockImplicitTerminator<"scf::YieldOp">,
        RecursiveMemoryEffects]> {
   let summary = "for operation";
@@ -329,6 +332,7 @@ def ForallOp : SCF_Op<"forall", [
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
        DestinationStyleOpInterface,
        HasParallelRegion
      ]> {
@@ -701,6 +705,7 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
 def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
     "getNumRegionInvocations", "getRegionInvocationBounds",
     "getEntrySuccessorRegions", "getSuccessorInputs"]>,
+    DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
     InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
     RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
   let summary = "if-then-else operation";
@@ -806,6 +811,7 @@ def ParallelOp : SCF_Op<"parallel",
           "getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps"]>,
      RecursiveMemoryEffects,
      DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+     DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
      SingleBlockImplicitTerminator<"scf::ReduceOp">,
      HasParallelRegion]> {
   let summary = "parallel for operation";
@@ -904,6 +910,7 @@ def ParallelOp : SCF_Op<"parallel",
 
 def ReduceOp : SCF_Op<"reduce", [
     Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects,
+    DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
     DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
   let summary = "reduce operation for scf.parallel";
   let description = [{
@@ -986,6 +993,7 @@ def WhileOp : SCF_Op<"while",
         ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getRegionIterArgs", "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
@@ -1135,6 +1143,7 @@ def WhileOp : SCF_Op<"while",
 
 def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
     SingleBlockImplicitTerminator<"scf::YieldOp">,
+    DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
     DeclareOpInterfaceMethods<RegionBranchOpInterface,
                               ["getRegionInvocationBounds",
                                "getEntrySuccessorRegions",
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index fbce2fa1d043d..5c2c0c9248317 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -53,6 +53,8 @@ def PromotableAllocationOpInterface
     InterfaceMethod<[{
         Hook triggered for every new block argument added to a block.
         This will only be called for slots declared by this operation.
+        This function is called after removal of blocking uses, meaning
+        only operations that will be deleted remain users of the slot.
 
         The builder is located at the beginning of the block on call. All IR
         mutations must happen through the builder.
@@ -240,7 +242,7 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         operation will be called after the main mutation stage finishes
         (i.e., after all ops have been processed with `removeBlockingUses`).
 
-        Operations should only the replaced values if the intended
+        Operations should only visit the replaced values if the intended
         transformation applies to all the replaced values. Furthermore, replaced
         values must not be deleted.
       }], "bool", "requiresReplacedValues", (ins), [{}],
@@ -263,6 +265,91 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
   ];
 }
 
+def PromotableRegionOpInterface
+    : OpInterface<"PromotableRegionOpInterface"> {
+  let description = [{
+    Describes an operation for which memory slots can be promoted to SSA values
+    within the operation's regions.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns true when the provided region of the operation can be analyzed
+        for promotion. The provided region must be a child of the operation's
+        region.
+        The `hasValueStores` flag indicates whether the region contains
+        store-like operations that write to the memory slot.
+      }], "bool", "isRegionPromotable",
+      (ins
+        "const ::mlir::MemorySlot &":$slot,
+        "::mlir::Region *":$region,
+        "bool":$hasValueStores
+      )
+    >,
+    InterfaceMethod<[{
+        Called before processing the nested regions in the operation.
+
+        Based on the `reachingDef` value representing the value in the memory
+        slot at the entry into the operation, `setupPromotion` fills in the
+        `regionsToProcess` with the reaching definition at the entry of
+        all its promotable regions.
+
+        `setupPromotion` is allowed to mutate
+        the operation in place, including its nested regions, but cannot
+        delete existing operations or modify successor-bearing terminators.
+        Other mutations are not allowed.
+
+        The `hasValueStores` flag indicates whether the regions contain
+        `store`-like operations that write to the memory slot. This field can be
+        used to reduce the amount of book-keeping required to track the reaching
+        definitions.
+      }], "void", "setupPromotion",
+      (ins
+        "const ::mlir::MemorySlot &":$slot,
+        "::mlir::Value":$reachingDef,
+        "bool":$hasValueStores,
+        "::llvm::SmallMapVector<::mlir::Region *, ::mlir::Value, 2> &":$regionsToProcess
+      )
+    >,
+    InterfaceMethod<[{
+        Called after promotion has been completed in all the relevant regions.
+
+        Returns the new reaching definition at the exit of the operation. For
+        this purpose, it is allowed to mutate the operation using the provided
+        `builder`, along with its region contents. However, all blocks within
+        the existing regions must remain valid and no new blocks may be added.
+        As a result, the operation is allowed to be cloned and replaced only
+        if its region content is moved from the original operation and not
+        copied. Operations with an effect on the value of the slot must not
+        change said effect (for example, new control flow that could change
+        reaching definitions for a block is not allowed).
+
+        The `entryReachingDef` is the reaching definition at the entry of the
+        region operation.
+
+        The `reachingAtBlockEnd` map contains the reaching definitions after all
+        the terminators within the regions of the operation. If a block of the
+        region is not present in the map, it is either dead code or within a
+        region that does not interact with the value of the slot.
+
+        The `hasValueStores` flag indicates whether the regions contain
+        `store`-like operations that write to the memory slot. This field can be
+        used to reduce the amount of book-keeping required to track the reaching
+        definitions.
+      }],
+      "::mlir::Value", "finalizePromotion",
+      (ins
+        "const ::mlir::MemorySlot &":$slot,
+        "::mlir::Value":$entryReachingDef,
+        "bool":$hasValueStores,
+        "::llvm::DenseMap<::mlir::Block *, ::mlir::Value> &":$reachingAtBlockEnd,
+        "::mlir::OpBuilder &":$builder
+      )
+    >,
+  ];
+}
+
 def DestructurableAllocationOpInterface
   : OpInterface<"DestructurableAllocationOpInterface"> {
   let description = [{
@@ -304,7 +391,7 @@ def DestructurableAllocationOpInterface
     >,
     InterfaceMethod<[{
         Hook triggered once the destructuring of a slot is complete, meaning the
-        original slot is no longer being refered to and could be deleted.
+        original slot is no longer being referred to and could be deleted.
         This will only be called for slots declared by this operation.
 
         Must return a new destructurable allocation op if this hook creates
@@ -328,7 +415,7 @@ def SafeMemorySlotAccessOpInterface
   let methods = [
     InterfaceMethod<[{
         Returns whether all accesses in this operation to the provided slot are
-        done in a safe manner. To be safe, the access most only access the slot
+        done in a safe manner. To be safe, the access must only access the slot
         inside the bounds that its type implies.
 
         If the safety of the accesses depends on the safety of the accesses to
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..fca28c5209e2d 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRSCFDialect
   SCF.cpp
   DeviceMappingInterface.cpp
+  MemorySlot.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
new file mode 100644
index 0000000000000..89f816ff8bf69
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/IR/MemorySlot.cpp
@@ -0,0 +1,347 @@
+//===- MemorySlot.cpp - Memory Slot interface implementations for SCF -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+/// Adds the corresponding reaching definition to the terminator of the block if
+/// the terminator is of the provided type.
+template <typename TermTy>
+static void
+updateTerminator(Block *block, Value defaultReachingDef,
+                 llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
+  Operation *terminator = block->getTerminator();
+  if (!isa<TermTy>(terminator))
+    return;
+  Value blockReachingDef = reachingAtBlockEnd[block];
+  if (!blockReachingDef) {
+    // Block is dead code or the region is not using the slot, so we use the
+    // default provided reaching definition.
+    blockReachingDef = defaultReachingDef;
+  }
+  terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
+}
+
+/// Creates a shallow copy of an operation with new result types, moving the
+/// regions out of the original operation and deleting the original operation.
+static Operation *replaceWithNewResults(RewriterBase &rewriter, Operation *op,
+                                        TypeRange resultTypes) {
+  RewriterBase::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  Operation *newOp =
+      mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
+  rewriter.startOpModification(newOp);
+  rewriter.startOpModification(op);
+  for (unsigned int i : llvm::seq(op->getNumRegions()))
+    newOp->getRegion(i).takeBody(op->getRegion(i));
+  rewriter.finalizeOpModification(op);
+  rewriter.finalizeOpModification(newOp);
+
+  SmallVector<Value> replacementValues(newOp->getResults().drop_back());
+  rewriter.replaceAllOpUsesWith(op, replacementValues);
+  rewriter.eraseOp(op);
+  return newOp;
+}
+
+//===----------------------------------------------------------------------===//
+// ExecuteRegionOp
+//===----------------------------------------------------------------------===//
+
+bool ExecuteRegionOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                         bool hasValueStores) {
+  return true;
+}
+
+void ExecuteRegionOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  regionsToProcess.insert({&getRegion(), reachingDef});
+}
+
+Value ExecuteRegionOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (!hasValueStores)
+    return reachingDef;
+
+  // Update the yield terminators to return the newly defined reaching
+  // definition.
+  for (Block &block : getRegion().getBlocks())
+    updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
+
+  SmallVector<Type> resultTypes(getResultTypes());
+  resultTypes.push_back(slot.elemType);
+
+  IRRewriter rewriter(builder);
+  Operation *newOp =
+      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+  return newOp->getResults().back();
+}
+
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+bool ForOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                               bool hasValueStores) {
+  return true;
+}
+
+void ForOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  Region &bodyRegion = getBodyRegion();
+  if (!hasValueStores) {
+    regionsToProcess.insert({&bodyRegion, reachingDef});
+    return;
+  }
+
+  getInitArgsMutable().append(reachingDef);
+  bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc());
+  regionsToProcess.insert({&bodyRegion, bodyRegion.getArguments().back()});
+}
+
+Value ForOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (!hasValueStores)
+    return reachingDef;
+
+  // Update the yield terminator to return the newly defined reaching
+  // definition.
+  updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
+
+  SmallVector<Type> resultTypes(getResultTypes());
+  resultTypes.push_back(slot.elemType);
+
+  IRRewriter rewriter(builder);
+  Operation *newOp =
+      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+  return newOp->getResults().back();
+}
+
+//===----------------------------------------------------------------------===//
+// ForallOp
+//===----------------------------------------------------------------------===//
+
+bool ForallOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                  bool hasValueStores) {
+  // The ForallOp body can be run in parallel, thus does not support sequenced
+  // value passing. Therefore only loads can be handled.
+  return !hasValueStores;
+}
+
+void ForallOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  assert(!hasValueStores && "ForallOp does not support stores");
+  regionsToProcess.insert({&getBodyRegion(), reachingDef});
+}
+
+Value ForallOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  assert(!hasValueStores && "ForallOp does not support stores");
+  return reachingDef;
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
+bool IfOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                              bool hasValueStores) {
+  return true;
+}
+
+void IfOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  regionsToProcess.insert({&getThenRegion(), reachingDef});
+  regionsToProcess.insert({&getElseRegion(), reachingDef});
+}
+
+Value IfOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (!hasValueStores)
+    return reachingDef;
+
+  IRRewriter rewriter(builder);
+
+  // Update the yield terminators to return the newly defined reaching
+  // definition.
+  updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
+                            reachingAtBlockEnd);
+  if (getElseRegion().hasOneBlock()) {
+    updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
+                              reachingAtBlockEnd);
+  } else {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.createBlock(&getElseRegion());
+    YieldOp::create(rewriter, getOperation()->getLoc(), reachingDef);
+  }
+
+  SmallVector<Type> resultTypes(getResultTypes());
+  resultTypes.push_back(slot.elemType);
+
+  Operation *newOp =
+      replaceWithNewResults(rewriter, getOperation(), resultTypes);
+  return newOp->getResults().back();
+}
+
+//===----------------------------------------------------------------------===//
+// IndexSwitchOp
+//===----------------------------------------------------------------------===//
+
+bool IndexSwitchOp::isRegionPromotable(const MemorySlot &slot, Region *region,
+                                       bool hasValueStores) {
+  return true;
+}
+
+void IndexSwitchOp::setupPromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
+  regionsToProcess.insert({&getDefaultRegion(), reachingDef});
+  for (Region &caseRegion : getCaseRegions())
+    regionsToProcess.insert({&caseRegion, reachingDef});
+}
+
+Value IndexSwitchOp::finalizePromotion(
+    const MemorySlot &slot, Value reachingDef, bool hasValueStores,
+    llvm::DenseMap<Block *, Value> &reachingAtBlockEnd, OpBuilder &builder) {
+  if (!hasValueStores)
+    return reachingDef;
+
+  IRRewriter rewriter(builder);
+
+  // Update the yield terminators to return the newly defined reaching
+  // definition.
+  updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
+                            reachingAtBlockEnd);
+  for (Regi...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/185036


More information about the Mlir-commits mailing list