[Mlir-commits] [mlir] Reland: [MLIR][Transforms] Fix Mem2Reg removal order to respect dominance (PR #68877)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 12 04:25:51 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-openacc

Author: Christian Ulmann (Dinistro)

<details>
<summary>Changes</summary>

This commit fixes a bug in the Mem2Reg operation erasure order. Replacing the use-def based topological order with a dominance-based weak order ensures that no operation is removed before all its uses have been replaced. The order relation combines the topological order of blocks with the order of operations within each block to obtain a deterministic operation order.

Additionally, the reliance on the `DenseMap` key order was eliminated by switching to a `MapVector`, which provides a deterministic iteration order.

Example:

```
%ptr = alloca ...
...
%val0 = load %ptr ... // LOAD0
store %val0 %ptr ...
%val1 = load %ptr ... // LOAD1
```

When promoting the slot backing %ptr, LOAD0 could be cleaned up before LOAD1. In that case, all uses of LOAD0 are replaced by its reaching definition before LOAD1's result is replaced by LOAD0's result (the value stored to %ptr). Replacing LOAD1 then introduces new uses of LOAD0, so the subsequent erasure of LOAD0 cannot succeed because it still has remaining uses.

---
Full diff: https://github.com/llvm/llvm-project/pull/68877.diff


11 Files Affected:

- (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (-3) 
- (modified) mlir/include/mlir/Transforms/RegionUtils.h (+3) 
- (modified) mlir/lib/Target/LLVMIR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt (+1) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp (+2-2) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt (+1) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+1-2) 
- (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+2-19) 
- (modified) mlir/lib/Transforms/Mem2Reg.cpp (+38-12) 
- (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+16) 
- (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+13) 


``````````diff
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 68b522405535a69..f9026f84935be52 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -380,9 +380,6 @@ namespace detail {
 /// to the results of preceding blocks.
 void connectPHINodes(Region &region, const ModuleTranslation &state);
 
-/// Get a topologically sorted list of blocks of the given region.
-SetVector<Block *> getTopologicallySortedBlocks(Region &region);
-
 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
 /// This currently supports integer, floating point, splat and dense element
 /// attributes and combinations thereof. Also, an array attribute with two
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 06eebff201d1b4d..192ff7138405973 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -87,6 +87,9 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
 LogicalResult runRegionDCE(RewriterBase &rewriter,
                            MutableArrayRef<Region> regions);
 
+/// Get a topologically sorted list of blocks of the given region.
+SetVector<Block *> getTopologicallySortedBlocks(Region &region);
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_REGIONUTILS_H_
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 868ccbbb10620d9..5db0885d70d6e7a 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -39,6 +39,7 @@ add_mlir_translation_library(MLIRTargetLLVMIRExport
   MLIRLLVMDialect
   MLIRLLVMIRTransforms
   MLIRTranslateLib
+  MLIRTransformUtils
   )
 
 add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
index 5e79e4be58f2965..e43581e28c77afa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_translation_library(MLIROpenACCToLLVMIRTranslation
   MLIROpenACCDialect
   MLIRSupport
   MLIRTargetLLVMIRExport
+  MLIRTransformUtils
   )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index 392d34cd6f91353..37fec190d6f401d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -395,8 +396,7 @@ static LogicalResult convertDataOp(acc::DataOp &op,
   llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
       ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
 
-  SetVector<Block *> blocks =
-      LLVM::detail::getTopologicallySortedBlocks(op.getRegion());
+  SetVector<Block *> blocks = getTopologicallySortedBlocks(op.getRegion());
   for (Block *bb : blocks) {
     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
     if (bb->isEntryBlock()) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
index 41227f80cdfc8dd..0a5d7c6e22058d4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/CMakeLists.txt
@@ -10,4 +10,5 @@ add_mlir_translation_library(MLIROpenMPToLLVMIRTranslation
   MLIROpenMPDialect
   MLIRSupport
   MLIRTargetLLVMIRExport
+  MLIRTransformUtils
   )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1ec3bb8e7562a9e..208c3d690e5532c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -194,8 +194,7 @@ static llvm::BasicBlock *convertOmpOpRegions(
 
   // Convert blocks one by one in topological order to ensure
   // defs are converted before uses.
-  SetVector<Block *> blocks =
-      LLVM::detail::getTopologicallySortedBlocks(region);
+  SetVector<Block *> blocks = getTopologicallySortedBlocks(region);
   for (Block *bb : blocks) {
     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
     // Retarget the branch of the entry block to the entry block of the
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ee73b04e020fd26..7312388bc9b4dd2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
@@ -571,24 +572,6 @@ void mlir::LLVM::detail::connectPHINodes(Region &region,
   }
 }
 
-/// Sort function blocks topologically.
-SetVector<Block *>
-mlir::LLVM::detail::getTopologicallySortedBlocks(Region &region) {
-  // For each block that has not been visited yet (i.e. that has no
-  // predecessors), add it to the list as well as its successors.
-  SetVector<Block *> blocks;
-  for (Block &b : region) {
-    if (blocks.count(&b) == 0) {
-      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
-      blocks.insert(traversal.begin(), traversal.end());
-    }
-  }
-  assert(blocks.size() == region.getBlocks().size() &&
-         "some blocks are not sorted");
-
-  return blocks;
-}
-
 llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
     llvm::IRBuilderBase &builder, llvm::Intrinsic::ID intrinsic,
     ArrayRef<llvm::Value *> args, ArrayRef<llvm::Type *> tys) {
@@ -922,7 +905,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
   // Then, convert blocks one by one in topological order to ensure defs are
   // converted before uses.
-  auto blocks = detail::getTopologicallySortedBlocks(func.getBody());
+  auto blocks = getTopologicallySortedBlocks(func.getBody());
   for (Block *bb : blocks) {
     llvm::IRBuilder<> builder(llvmContext);
     if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 65de25dd2f32663..1794d12ae2768db 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -96,6 +98,9 @@ using namespace mlir;
 
 namespace {
 
+using BlockingUsesMap =
+    llvm::MapVector<Operation *, SmallPtrSet<OpOperand *, 4>>;
+
 /// Information computed during promotion analysis used to perform actual
 /// promotion.
 struct MemorySlotPromotionInfo {
@@ -106,7 +111,7 @@ struct MemorySlotPromotionInfo {
   /// its uses, it is because the defining ops of the blocking uses requested
   /// it. The defining ops therefore must also have blocking uses or be the
   /// starting point of the bloccking uses.
-  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+  BlockingUsesMap userToBlockingUses;
 };
 
 /// Computes information for basic slot promotion. This will check that direct
@@ -129,8 +134,7 @@ class MemorySlotPromotionAnalyzer {
   /// uses (typically, removing its users because it will delete itself to
   /// resolve its own blocking uses). This will fail if one of the transitive
   /// users cannot remove a requested use, and should prevent promotion.
-  LogicalResult computeBlockingUses(
-      DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses);
+  LogicalResult computeBlockingUses(BlockingUsesMap &userToBlockingUses);
 
   /// Computes in which blocks the value stored in the slot is actually used,
   /// meaning blocks leading to a load. This method uses `definingBlocks`, the
@@ -233,7 +237,7 @@ Value MemorySlotPromoter::getLazyDefaultValue() {
 }
 
 LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
-    DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> &userToBlockingUses) {
+    BlockingUsesMap &userToBlockingUses) {
   // The promotion of an operation may require the promotion of further
   // operations (typically, removing operations that use an operation that must
   // delete itself). We thus need to start from the use of the slot pointer and
@@ -243,7 +247,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
   // use it.
   for (OpOperand &use : slot.ptr.getUses()) {
     SmallPtrSet<OpOperand *, 4> &blockingUses =
-        userToBlockingUses.getOrInsertDefault(use.getOwner());
+        userToBlockingUses[use.getOwner()];
     blockingUses.insert(&use);
   }
 
@@ -281,7 +285,7 @@ LogicalResult MemorySlotPromotionAnalyzer::computeBlockingUses(
       assert(llvm::is_contained(user->getResults(), blockingUse->get()));
 
       SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
-          userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
+          userToBlockingUses[blockingUse->getOwner()];
       newUserBlockingUseSet.insert(blockingUse);
     }
   }
@@ -515,15 +519,37 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
   }
 }
 
+/// Sorts `ops` according to dominance. Relies on the topological order of basic
+/// blocks to get a deterministic ordering.
+static void dominanceSort(SmallVector<Operation *> &ops, Region &region) {
+  // Produce a topological block order and construct a map to lookup the indices
+  // of blocks.
+  DenseMap<Block *, size_t> topoBlockIndices;
+  SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks(region);
+  for (auto [index, block] : llvm::enumerate(topologicalOrder))
+    topoBlockIndices[block] = index;
+
+  // Combining the topological order of the basic blocks together with block
+  // internal operation order guarentees a deterministic, dominance respecting
+  // order.
+  llvm::sort(ops, [&](Operation *lhs, Operation *rhs) {
+    size_t lhsBlockIndex = topoBlockIndices.at(lhs->getBlock());
+    size_t rhsBlockIndex = topoBlockIndices.at(rhs->getBlock());
+    if (lhsBlockIndex == rhsBlockIndex)
+      return lhs->isBeforeInBlock(rhs);
+    return lhsBlockIndex < rhsBlockIndex;
+  });
+}
+
 void MemorySlotPromoter::removeBlockingUses() {
-  llvm::SetVector<Operation *> usersToRemoveUses;
-  for (auto &user : llvm::make_first_range(info.userToBlockingUses))
-    usersToRemoveUses.insert(user);
-  SetVector<Operation *> sortedUsersToRemoveUses =
-      mlir::topologicalSort(usersToRemoveUses);
+  llvm::SmallVector<Operation *> usersToRemoveUses(
+      llvm::make_first_range(info.userToBlockingUses));
+
+  // Sort according to dominance.
+  dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
 
   llvm::SmallVector<Operation *> toErase;
-  for (Operation *toPromote : llvm::reverse(sortedUsersToRemoveUses)) {
+  for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
     if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
       Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
       // If no reaching definition is known, this use is outside the reach of
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index b95af9ca0299649..1f2344677e6515c 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -836,3 +836,19 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }
+
+SetVector<Block *> mlir::getTopologicallySortedBlocks(Region &region) {
+  // For each block that has not been visited yet (i.e. that has no
+  // predecessors), add it to the list as well as its successors.
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.count(&b) == 0) {
+      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+      blocks.insert(traversal.begin(), traversal.end());
+    }
+  }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+
+  return blocks;
+}
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 30ba459d07a49f3..32e3fed7e5485df 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -683,3 +683,16 @@ llvm.func @no_inner_alloca_promotion(%arg: i64) -> i64 {
   // CHECK: llvm.return %[[RES]] : i64
   llvm.return %2 : i64
 }
+
+// -----
+
+// CHECK-LABEL: @transitive_reaching_def
+llvm.func @transitive_reaching_def() -> !llvm.ptr {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: alloca
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.store %2, %1 {alignment = 8 : i64} : !llvm.ptr, !llvm.ptr
+  %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
+  llvm.return %3 : !llvm.ptr
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/68877


More information about the Mlir-commits mailing list