[Mlir-commits] [mlir] ab6a66d - Reland: [MLIR][Transforms] Fix Mem2Reg removal order to respect dominance (#68877)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 12 07:47:10 PDT 2023


Author: Christian Ulmann
Date: 2023-10-12T16:47:06+02:00
New Revision: ab6a66dbec61654d0962f6abf6d6c5b776937584

URL: https://github.com/llvm/llvm-project/commit/ab6a66dbec61654d0962f6abf6d6c5b776937584
DIFF: https://github.com/llvm/llvm-project/commit/ab6a66dbec61654d0962f6abf6d6c5b776937584.diff

LOG: Reland: [MLIR][Transforms] Fix Mem2Reg removal order to respect dominance (#68877)

This commit fixes a bug in the Mem2Reg operation erasure order.
Replacing the use-def based topological order with a dominance-based
weak order ensures that no operation is removed before all its uses have
been replaced. The order relation uses the topological order of blocks
and block internal ordering to determine a deterministic operation
order.

Additionally, the reliance on the `DenseMap` key order was eliminated by
switching to a `MapVector`, which gives a deterministic iteration order.

Example:

```
%ptr = alloca ...
...
%val0 = load %ptr ... // LOAD0
store %val0 %ptr ...
%val1 = load %ptr ... // LOAD1
```

When promoting the slot backing %ptr, it can happen that LOAD0 is cleaned
up before LOAD1. This results in all uses of LOAD0 being replaced by its
reaching definition before LOAD1's result is replaced by LOAD0's result.
The subsequent erasure of LOAD0 then cannot succeed, as it still has
remaining uses.

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/include/mlir/Transforms/RegionUtils.h
    mlir/lib/Target/LLVMIR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/lib/Transforms/Mem2Reg.cpp
    mlir/lib/Transforms/Utils/RegionUtils.cpp
    mlir/test/Dialect/LLVMIR/mem2reg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 68b522405535a69..f9026f84935be52 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -380,9 +380,6 @@ namespace detail {
 /// to the results of preceding blocks.
 void connectPHINodes(Region &region, const ModuleTranslation &state);
 
-/// Get a topologically sorted list of blocks of the given region.
-SetVector<Block *> getTopologicallySortedBlocks(Region &region);
-
 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
 /// This currently supports integer, floating point, splat and dense element
 /// attributes and combinations thereof. Also, an array attribute with two

diff  --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 06eebff201d1b4d..192ff7138405973 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -87,6 +87,9 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
 LogicalResult runRegionDCE(RewriterBase &rewriter,
                            MutableArrayRef<Region> regions);
 
+/// Get a topologically sorted list of blocks of the given region.
+SetVector<Block *> getTopologicallySortedBlocks(Region &region);
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_REGIONUTILS_H_

diff  --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 868ccbbb10620d9..5db0885d70d6e7a 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -39,6 +39,7 @@ add_mlir_translation_library(MLIRTargetLLVMIRExport
   MLIRLLVMDialect
   MLIRLLVMIRTransforms
   MLIRTranslateLib
+  MLIRTransformUtils
   )
 
 add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
index 5e79e4be58f2965..e43581e28c77afa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_translation_library(MLIROpenACCToLLVMIRTranslation
   MLIROpenACCDialect
   MLIRSupport
   MLIRTargetLLVMIRExport
+  MLIRTransformUtils
   )

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index 392d34cd6f91353..37fec190d6f401d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -395,8 +396,7 @@ static LogicalResult convertDataOp(acc::DataOp &op,
   llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
       ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
 
-  SetVector<Block *> blocks =
-      LLVM::detail::getTopologicallySortedBlocks(op.getRegion());
+  SetVector<Block *> blocks = getTopologicallySortedBlocks(op.getRegion());
   for (Block *bb : blocks) {
     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
     if (bb->isEntryBlock()) {

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
index 41227f80cdfc8dd..0a5d7c6e22058d4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_translation_library(MLIROpenMPToLLVMIRTranslation
   MLIROpenMPDialect
   MLIRSupport
   MLIRTargetLLVMIRExport
+  MLIRTransformUtils
   )

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1ec3bb8e7562a9e..208c3d690e5532c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -194,8 +194,7 @@ static llvm::BasicBlock *convertOmpOpRegions(
 
   // Convert blocks one by one in topological order to ensure
   // defs are converted before uses.
-  SetVector<Block *> blocks =
-      LLVM::detail::getTopologicallySortedBlocks(region);
+  SetVector<Block *> blocks = getTopologicallySortedBlocks(region);
   for (Block *bb : blocks) {
     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
     // Retarget the branch of the entry block to the entry block of the

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ee73b04e020fd26..7312388bc9b4dd2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
@@ -571,24 +572,6 @@ void mlir::LLVM::detail::connectPHINodes(Region &region,
   }
 }
 
-/// Sort function blocks topologically.
-SetVector<Block *>
-mlir::LLVM::detail::getTopologicallySortedBlocks(Region &region) {
-  // For each block that has not been visited yet (i.e. that has no
-  // predecessors), add it to the list as well as its successors.
-  SetVector<Block *> blocks;
-  for (Block &b : region) {
-    if (blocks.count(&b) == 0) {
-      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
-      blocks.insert(traversal.begin(), traversal.end());
-    }
-  }
-  assert(blocks.size() == region.getBlocks().size() &&
-         "some blocks are not sorted");
-
-  return blocks;
-}
-
 llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
     llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
     ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
@@ -922,7 +905,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
   // Then, convert blocks one by one in topological order to ensure defs are
   // converted before uses.
-  auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
+  auto blocks = getTopologicallySortedBlocks(func.getBody());
   for (Block *bb : blocks) {
     llvm::IRBuilder<> builder(llvmContext);
     if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))

diff  --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 65de25dd2f32663..26a7ea5d5e219ef 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -96,6 +98,9 @@ using namespace mlir;
 
 namespace {
 
+using BlockingUsesMap =
+    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
+
 /// Information computed during promotion analysis used to perform actual
 /// promotion.
 struct MemorySlotPromotionInfo {
@@ -106,7 +111,7 @@ struct MemorySlotPromotionInfo {
   /// its uses, it is because the defining ops of the blocking uses requested
   /// it. The defining ops therefore must also have blocking uses or be the
   /// starting point of the bloccking uses.
-  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+  BlockingUsesMap userToBlockingUses;
 };
 
 /// Computes information for basic slot promotion. This will check that direct
@@ -129,8 +134,7 @@ class MemorySlotPromotionAnalyzer {
   /// uses (typically, removing its users because it will delete itself to
   /// resolve its own blocking uses). This will fail if one of the transitive
   /// users cannot remove a requested use, and should prevent promotion.
-  LogicalResult computeBlockingUses(
-      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
 
   /// Computes in which blocks the value stored in the slot is actually used,
   /// meaning blocks leading to a load. This method uses `definingBlocks`, the
@@ -233,7 +237,7 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
 }
 
 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
-    DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses) {
+    BlockingUsesMap &userToBlockingUses) {
   // The promotion of an operation may require the promotion of further
   // operations (typically, removing operations that use an operation that must
   // delete itself). We thus need to start from the use of the slot pointer and
@@ -243,7 +247,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
   // use it.
   for (OpOperand &use : slot.ptr.getUses()) {
     SmallPtrSet<OpOperand *, 4> &blockingUses =
-        userToBlockingUses.getOrInsertDefault(use.getOwner());
+        userToBlockingUses[use.getOwner()];
     blockingUses.insert(&use);
   }
 
@@ -281,7 +285,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
       assert(llvm::is_contained(user->getResults(), blockingUse->get()));
 
       SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
-          userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
+          userToBlockingUses[blockingUse->getOwner()];
       newUserBlockingUseSet.insert(blockingUse);
     }
   }
@@ -515,15 +519,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
   }
 }
 
+/// Sorts `ops` according to dominance. Relies on the topological order of basic
+/// blocks to get a deterministic ordering.
+static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
+  // Produce a topological block order and construct a map to lookup the indices
+  // of blocks.
+  DenseMap<Block *, size_t> topoBlockIndices;
+  SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
+  for (auto [index, block] : llvm::enumerate(topologicalOrder))
+    topoBlockIndices[block] = index;
+
+  // Combining the topological order of the basic blocks together with block
+  // internal operation order guarantees a deterministic, dominance respecting
+  // order.
+  llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
+    size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
+    size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
+    if (lhsBlockIndex == rhsBlockIndex)
+      return lhs->isBeforeInBlock(rhs);
+    return lhsBlockIndex < rhsBlockIndex;
+  });
+}
+
 void MemorySlotPromoter::removeBlockingUses() {
-  llvm::SetVector<Operation *> usersToRemoveUses;
-  for (auto &user : llvm::make_first_range(info.userToBlockingUses))
-    usersToRemoveUses.insert(user);
-  SetVector<Operation *> sortedUsersToRemoveUses =
-      mlir::topologicalSort(usersToRemoveUses);
+  llvm::SmallVector<Operation *> usersToRemoveUses(
+      llvm::make_first_range(info.userToBlockingUses));
+
+  // Sort according to dominance.
+  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
 
   llvm::SmallVector<Operation *> toErase;
-  for (Operation *toPromote : llvm::reverse(sortedUsersToRemoveUses)) {
+  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
     if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
       Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
       // If no reaching definition is known, this use is outside the reach of

diff  --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index b95af9ca0299649..1f2344677e6515c 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -836,3 +836,19 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }
+
+SetVector<Block *> mlir::getTopologicallySortedBlocks(Region &region) {
+  // For each block that has not been visited yet (i.e. that has no
+  // predecessors), add it to the list as well as its successors.
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.count(&b) == 0) {
+      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+      blocks.insert(traversal.begin(), traversal.end());
+    }
+  }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+
+  return blocks;
+}

diff  --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 30ba459d07a49f3..32e3fed7e5485df 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -683,3 +683,16 @@ llvm.func @no_inner_alloca_promotion(%arg: i64) -> i64 {
   // CHECK: llvm.return %[[RES]] : i64
   llvm.return %2 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @transitive_reaching_def
+llvm.func @transitive_reaching_def() -> !llvm.ptr {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: alloca
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.store %2, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.return %3 : !llvm.ptr
+}


        


More information about the Mlir-commits mailing list