[Mlir-commits] [mlir] [MLIR] Add single definition multiple regions for mem2reg (PR #89107)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 17 10:31:59 PDT 2024


https://github.com/fanfuqiang created https://github.com/llvm/llvm-project/pull/89107

Mem2reg currently only promotes slots whose promotable memory operations all live in the same MLIR region as the slot, because the iterated dominance frontier computation does not support multiple regions.

This commit adds support for a simple multi-region case to the existing mem2reg pass: when a slot has a single definition and multiple uses spread across several regions, and the definition dominates all of the uses, the stored value is forwarded directly to every use.

>From 8b7da9792179ce080544f958b0a0987845bc7059 Mon Sep 17 00:00:00 2001
From: fanfuqiang <fuqiang.fan at mthreads.com>
Date: Thu, 18 Apr 2024 00:57:56 +0800
Subject: [PATCH] [MLIR] Add single definition multiple regions for mem2reg

Mem2reg currently only promotes slots whose promotable memory operations
all live in the same MLIR region as the slot, because the iterated
dominance frontier computation does not support multiple regions.

This commit adds support for a simple multi-region case to the existing
mem2reg pass: when a slot has a single definition and multiple uses
spread across several regions, and the definition dominates all of the
uses, the stored value is forwarded directly to every use.

Signed-off-by: fanfuqiang <fuqiang.fan at mthreads.com>
---
 mlir/lib/Transforms/Mem2Reg.cpp               | 120 +++++++++++++++---
 .../Transform/mem2reg-hybird-dialects.mlir    |  80 ++++++++++++
 2 files changed, 180 insertions(+), 20 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/mem2reg-hybird-dialects.mlir

diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index abe565ea862f8f..3b8c3688a739cc 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -126,6 +126,9 @@ class MemorySlotPromotionAnalyzer {
   /// returns nothing otherwise.
   std::optional<MemorySlotPromotionInfo> computeInfo();
 
+  /// Returns the slot's unique store (or null) and the same-region flag.
+  auto getSingleDefining() const { return singleDefining; }
+
 private:
   /// Computes the transitive uses of the slot that block promotion. This finds
   /// uses that would block the promotion, checks that the operation has a
@@ -156,6 +159,11 @@ class MemorySlotPromotionAnalyzer {
   MemorySlot slot;
   DominanceInfo &dominance;
   const DataLayout &dataLayout;
+
+  /// The only operation storing to the slot, or null if no single store was
+  /// identified. The flag is true when all promotable memory ops are in the
+  /// slot's region, and false when that store is forwarded across regions.
+  llvm::PointerIntPair<Operation *, 1, bool> singleDefining{};
 };
 
 /// The MemorySlotPromoter handles the state of promoting a memory slot. It
@@ -166,7 +174,8 @@ class MemorySlotPromoter {
   MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                      RewriterBase &rewriter, DominanceInfo &dominance,
                      MemorySlotPromotionInfo info,
-                     const Mem2RegStatistics &statistics);
+                     const Mem2RegStatistics &statistics,
+                     Operation *singleDefining);
 
   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
   /// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
@@ -186,6 +195,10 @@ class MemorySlotPromoter {
   /// This method must only be called at most once per region.
   void computeReachingDefInRegion(Region *region, Value reachingDef);
 
+  /// Computes the reaching definitions when the slot has a single store that
+  /// dominates all of its users: every use reads the value that store writes.
+  void computeReachingDefOfSingleDefiningUses();
+
   /// Removes the blocking uses of the slot, in topological order.
   void removeBlockingUses();
 
@@ -206,6 +219,9 @@ class MemorySlotPromoter {
   DominanceInfo &dominance;
   MemorySlotPromotionInfo info;
   const Mem2RegStatistics &statistics;
+  /// When non-null, the only store to the slot; its value is forwarded to all
+  /// uses instead of calling computeReachingDefInRegion.
+  Operation *singleDefining;
 };
 
 } // namespace
@@ -213,9 +229,11 @@ class MemorySlotPromoter {
 MemorySlotPromoter::MemorySlotPromoter(
     MemorySlot slot, PromotableAllocationOpInterface allocator,
     RewriterBase &rewriter, DominanceInfo &dominance,
-    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
+    Operation *singleDefining)
     : slot(slot), allocator(allocator), rewriter(rewriter),
-      dominance(dominance), info(std::move(info)), statistics(statistics) {
+      dominance(dominance), info(std::move(info)), statistics(statistics),
+      singleDefining(singleDefining) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -253,6 +267,8 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
     blockingUses.insert(&use);
   }
 
+  // Number of promotable memory operations that store to the slot.
+  size_t totalStores = 0;
   // Then, propagate the requirements for the removal of uses. The
   // topologically-sorted forward slice allows for all blocking uses of an
   // operation to have been computed before it is reached. Operations are
@@ -275,8 +291,14 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
                                        dataLayout))
         return failure();
     } else if (auto promotable = dyn_cast<PromotableMemOpInterface>(user)) {
       if (!promotable.canUsesBeRemoved(slot, blockingUses, newBlockingUses,
                                        dataLayout))
         return failure();
+      if (promotable.storesTo(slot)) {
+        // An operation that also reads the slot would observe its own stored
+        // value if forwarded as the unique definition, so it cannot be one.
+        singleDefining.setPointer(promotable.loadsFrom(slot) ? nullptr : user);
+        ++totalStores;
+      }
     } else {
       // An operation that has blocking uses must be promoted. If it is not
@@ -294,14 +316,37 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
     }
   }
 
-  // Because this pass currently only supports analysing the parent region of
-  // the slot pointer, if a promotable memory op that needs promotion is outside
-  // of this region, promotion must fail because it will be impossible to
-  // provide a valid `reachingDef` for it.
-  for (auto &[toPromote, _] : userToBlockingUses)
+  // Only a single store to the slot can be forwarded across regions.
+  if (totalStores != 1)
+    singleDefining.setPointer(nullptr);
+
+  // Succeeds if the slot has a single store that dominates every user of the
+  // slot. Its value can then be forwarded to all uses directly, even across
+  // regions, without needing block arguments or the default value.
+  auto checkSingleDefinitionDominatesUsers = [&]() -> LogicalResult {
+    Operation *definingOp = singleDefining.getPointer();
+    if (!definingOp)
+      return failure();
+    for (auto &[user, _] : userToBlockingUses)
+      if (!dominance.dominates(definingOp, user))
+        return failure();
+    return success();
+  };
+
+  // The iterated dominance frontier computation only supports the parent
+  // region of the slot pointer. The flag of `singleDefining` records whether
+  // all promotable memory ops are in that region, in which case the regular
+  // algorithm is used, or whether the single store is forwarded instead.
+  singleDefining.setInt(true);
+  for (auto &[toPromote, _] : userToBlockingUses) {
     if (isa<PromotableMemOpInterface>(toPromote) &&
-        toPromote->getParentRegion() != slot.ptr.getParentRegion())
-      return failure();
+        toPromote->getParentRegion() != slot.ptr.getParentRegion()) {
+      // The regular algorithm cannot provide a valid `reachingDef` for this
+      // op; promotion is only possible by forwarding the single store.
+      singleDefining.setInt(false);
+      return checkSingleDefinitionDominatesUsers();
+    }
+  }
 
   return success();
 }
@@ -366,6 +422,10 @@ SmallPtrSet<Block *, 16> MemorySlotPromotionAnalyzer::computeSlotLiveIn(
 using IDFCalculator = llvm::IDFCalculatorBase<Block, false>;
 void MemorySlotPromotionAnalyzer::computeMergePoints(
     SmallPtrSetImpl<Block *> &mergePoints) {
+  // Skipped when a single store is forwarded across regions (flag unset).
+  if (!singleDefining.getInt())
+    return;
+
   if (slot.ptr.getParentRegion()->hasOneBlock())
     return;
 
@@ -423,6 +479,23 @@ MemorySlotPromotionAnalyzer::computeInfo() {
   return info;
 }
 
+void MemorySlotPromoter::computeReachingDefOfSingleDefiningUses() {
+  auto definingOp = cast<PromotableMemOpInterface>(singleDefining);
+  assert(definingOp.storesTo(slot) &&
+         "the single defining operation must store to the slot");
+  rewriter.setInsertionPointAfter(definingOp);
+  Value stored = definingOp.getStored(slot, rewriter);
+  assert(stored && "a memory operation storing to a slot must provide a "
+                   "new definition of the slot");
+  replacedValuesMap[definingOp] = stored;
+
+  // The analysis only selects this path when `singleDefining` dominates every
+  // user of the slot, so the stored value reaches all of them.
+  for (auto &[op, _] : info.userToBlockingUses)
+    if (auto memOp = dyn_cast<PromotableMemOpInterface>(op))
+      reachingDefs.insert({memOp, stored});
+}
+
 Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
                                                     Value reachingDef) {
   SmallVector<Operation *> blockOps;
@@ -550,8 +617,10 @@ void MemorySlotPromoter::removeBlockingUses() {
   llvm::SmallVector<Operation *> usersToRemoveUses(
       llvm::make_first_range(info.userToBlockingUses));
 
-  // Sort according to dominance.
-  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
+  // Sort according to dominance. Skipped for single-definition promotion,
+  // whose users may be nested in regions other than the one sorted here.
+  if (!singleDefining)
+    dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
 
   llvm::SmallVector<Operation *> toErase;
   // List of all replaced values in the slot.
@@ -561,6 +630,11 @@ void MemorySlotPromoter::removeBlockingUses() {
   for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
     if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
       Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
+      // Single-definition promotion assigns a reaching definition to every
+      // promotable memory op, so the default value is never needed there.
+      assert((!singleDefining || reachingDef) &&
+             "must have a reaching definition in single definition multiple "
+             "regions case");
       // If no reaching definition is known, this use is outside the reach of
       // the slot. The default value should thus be used.
       if (!reachingDef)
@@ -598,7 +674,12 @@ void MemorySlotPromoter::removeBlockingUses() {
 }
 
 void MemorySlotPromoter::promoteSlot() {
-  computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
+  // A single dominating store is forwarded directly; otherwise, compute the
+  // reaching definitions within the slot's region.
+  if (singleDefining)
+    computeReachingDefOfSingleDefiningUses();
+  else
+    computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
 
   // Now that reaching definitions are known, remove all users.
   removeBlockingUses();
@@ -643,7 +720,12 @@ LogicalResult mlir::tryToPromoteMemorySlots(
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
       if (info) {
+        // The promoter only receives the single store when it is forwarded
+        // across regions; otherwise it computes reaching definitions itself.
+        auto singleDefining = analyzer.getSingleDefining();
+        Operation *crossRegionDef =
+            singleDefining.getInt() ? nullptr : singleDefining.getPointer();
         MemorySlotPromoter(slot, allocator, rewriter, dominance,
-                           std::move(*info), statistics)
+                           std::move(*info), statistics, crossRegionDef)
             .promoteSlot();
         promotedAny = true;
       }
diff --git a/mlir/test/Dialect/Transform/mem2reg-hybird-dialects.mlir b/mlir/test/Dialect/Transform/mem2reg-hybird-dialects.mlir
new file mode 100644
index 00000000000000..c739eec9507734
--- /dev/null
+++ b/mlir/test/Dialect/Transform/mem2reg-hybird-dialects.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @single_define_multiple_regions_with_for
+// CHECK: %[[CST:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK-NOT: llvm.alloca
+// CHECK: arith.index_cast %[[CST]] : i32 to index
+func.func @single_define_multiple_regions_with_for(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = llvm.mlir.constant(4 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  scf.for %i0 = %arg0 to %arg1 step %arg2 {
+    scf.for %i1 = %arg0 to %arg1 step %arg2 {
+      llvm.store %0, %1 {alignment = 8 : i64} : i32, !llvm.ptr
+      %min_cmp = arith.cmpi slt, %i0, %i1 : index
+      %min = arith.select %min_cmp, %i0, %i1 : index
+      %max_cmp = arith.cmpi sge, %i0, %i1 : index
+      %max = arith.select %max_cmp, %i0, %i1 : index
+      %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+      %3 = arith.index_cast %2 : i32 to index
+      scf.for %i2 = %min to %max step %i1 {
+        %val = arith.addi %3, %3 : index
+      }
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @single_define_multiple_regions_with_if
+// CHECK: %[[CST:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK-NOT: llvm.alloca
+// CHECK: arith.addi %{{.*}}, %[[CST]] : i32
+// CHECK: arith.subi %{{.*}}, %[[CST]] : i32
+func.func @single_define_multiple_regions_with_if(%arg0 : i1, %arg1 : i32) {
+  %0 = llvm.mlir.constant(4 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %0, %1 {alignment = 8 : i64} : i32, !llvm.ptr
+  scf.if %arg0 {
+    %3 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+    %4 = arith.addi %arg1, %3 : i32
+  } else {
+    %5 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+    %6 = arith.subi %arg1, %5 : i32
+  }
+
+  return
+}
+
+// -----
+
+// The definition doesn't dominate all uses, so promotion fails.
+// CHECK-LABEL: @single_define_multiple_regions_with_if_fail
+// CHECK: llvm.alloca
+func.func @single_define_multiple_regions_with_if_fail(%arg0 : i1, %arg1 : i32) {
+  %0 = llvm.mlir.constant(4 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  scf.if %arg0 {
+    llvm.store %0, %1 {alignment = 8 : i64} : i32, !llvm.ptr
+    %3 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+    %4 = arith.addi %arg1, %3 : i32
+  } else {
+    %5 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+    %6 = arith.subi %arg1, %5 : i32
+  }
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @single_define_multiple_regions_with_while
+// CHECK: %[[CST:.*]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK-NOT: llvm.alloca
+// CHECK: arith.cmpi sge, %[[CST]], %{{.*}} : i32
+func.func @single_define_multiple_regions_with_while(%arg0 : i32) {
+  %0 = llvm.mlir.constant(4 : i32) : i32
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %0, %1 {alignment = 8 : i64} : i32, !llvm.ptr
+  scf.while : () -> () {
+    %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+    %3 = arith.cmpi sge, %2, %arg0 : i32
+    scf.condition(%3)
+  } do {
+    scf.yield
+  }
+  return
+}



More information about the Mlir-commits mailing list