[Mlir-commits] [llvm] [mlir] [MLIR][Analysis] Consolidate topological sort utilities (PR #92563)

Christian Ulmann llvmlistbot at llvm.org
Tue May 21 02:17:28 PDT 2024


https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/92563

>From c64c0b7708a363960eacfe66e631fbef7ef5bb00 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 11:16:56 +0000
Subject: [PATCH 1/5] first impl of the new topo sort

---
 mlir/include/mlir/Analysis/SliceAnalysis.h    |   1 +
 mlir/lib/Analysis/SliceAnalysis.cpp           | 126 +++++++++-----
 mlir/test/Analysis/test-topoligical-sort.mlir |  53 ++++--
 mlir/test/Dialect/Affine/slicing-utils.mlir   | 160 +++++++++---------
 mlir/test/lib/Analysis/TestSlice.cpp          |  28 ++-
 5 files changed, 211 insertions(+), 157 deletions(-)

diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index d5cdf72c3889f..19571fc1946be 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -226,6 +226,7 @@ getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
 /// Multi-root DAG topological sort.
 /// Performs a topological sort of the Operation in the `toSort` SetVector.
 /// Returns a topologically sorted SetVector.
+/// Does not support multi-sets.
 SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
 
 /// Utility to match a generic reduction given a list of iteration-carried
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 26fe8e3dc0819..f93183749dfd2 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -11,10 +11,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -164,60 +167,95 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
   return topologicalSort(slice);
 }
 
-namespace {
-/// DFS post-order implementation that maintains a global count to work across
-/// multiple invocations, to help implement topological sort on multi-root DAGs.
-/// We traverse all operations but only record the ones that appear in
-/// `toSort` for the final result.
-struct DFSState {
-  DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
-  const SetVector<Operation *> &toSort;
-  SmallVector<Operation *, 16> topologicalCounts;
-  DenseSet<Operation *> seen;
-};
-} // namespace
-
-static void dfsPostorder(Operation *root, DFSState *state) {
-  SmallVector<Operation *> queue(1, root);
-  std::vector<Operation *> ops;
-  while (!queue.empty()) {
-    Operation *current = queue.pop_back_val();
-    ops.push_back(current);
-    for (Operation *op : current->getUsers())
-      queue.push_back(op);
-    for (Region &region : current->getRegions()) {
-      for (Operation &op : region.getOps())
-        queue.push_back(&op);
+/// Returns the blocks of `region` in reverse post-order. TODO: deduplicate.
+static SetVector<Block *> getTopologicallySortedBlocks(Region &region) {
+  // Start a reverse post-order traversal from every block that has not been
+  // visited yet (e.g. the entry block or unreachable blocks).
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.count(&b) == 0) {
+      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+      blocks.insert(traversal.begin(), traversal.end());
     }
   }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
 
-  for (Operation *op : llvm::reverse(ops)) {
-    if (state->seen.insert(op).second && state->toSort.count(op) > 0)
-      state->topologicalCounts.push_back(op);
+  return blocks;
+}
+
+/// Returns the innermost region enclosing all operations in `ops` (nullptr if
+/// none exists) and adds all regions visited on the way to `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+                                      DenseSet<Region *> &traversedRegions) {
+  // Counts, per region, how many operations of `ops` have reached it so far.
+  llvm::DenseMap<Region *, size_t> regionCounts;
+  size_t expectedCount = ops.size();
+
+  // Walk the region tree for each operation towards the root and add to the
+  // region count.
+  Region *res = nullptr;
+  for (Operation *op : ops) {
+    Region *current = op->getParentRegion();
+    while (current) {
+      // The first region reached by all ops is the innermost common one.
+      auto it = regionCounts.try_emplace(current, 0).first;
+      size_t count = ++it->getSecond();
+      if (count == expectedCount) {
+        res = current;
+        break;
+      }
+      current = current->getParentRegion();
+    }
+  }
+  auto firstRange = llvm::make_first_range(regionCounts);
+  traversedRegions.insert(firstRange.begin(), firstRange.end());
+  return res;
+}
+
+/// Topologically traverses `region` and inserts all encountered operations in
+/// `toSort` into `result`. Only recurses into nested regions contained in
+/// `relevantRegions`, which must include every region enclosing `toSort` ops.
+static void topoSortRegion(Region &region,
+                           const DenseSet<Region *> &relevantRegions,
+                           const SetVector<Operation *> &toSort,
+                           SetVector<Operation *> &result) {
+  SetVector<Block *> sortedBlocks = getTopologicallySortedBlocks(region);
+  for (Block *block : sortedBlocks) {
+    for (Operation &op : *block) {
+      if (toSort.contains(&op))
+        result.insert(&op);
+      for (Region &subRegion : op.getRegions()) {
+        // Skip regions that do not contain operations from `toSort`.
+        if (!relevantRegions.contains(&subRegion))
+          continue;
+        topoSortRegion(subRegion, relevantRegions, toSort, result);
+      }
+    }
   }
 }
 
 SetVector<Operation *>
 mlir::topologicalSort(const SetVector<Operation *> &toSort) {
-  if (toSort.empty()) {
+  if (toSort.size() <= 1)
     return toSort;
-  }
 
-  // Run from each root with global count and `seen` set.
-  DFSState state(toSort);
-  for (auto *s : toSort) {
-    assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
-    dfsPostorder(s, &state);
-  }
-
-  // Reorder and return.
-  SetVector<Operation *> res;
-  for (auto it = state.topologicalCounts.rbegin(),
-            eit = state.topologicalCounts.rend();
-       it != eit; ++it) {
-    res.insert(*it);
-  }
-  return res;
+  assert(llvm::all_of(toSort,
+                      [&](Operation *op) { return toSort.count(op) == 1; }) &&
+         "expected only unique set entries");
+
+  // First, find the innermost region enclosing all ops; the recursive
+  // traversal through the IR starts there.
+  DenseSet<Region *> relevantRegions;
+  Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+  assert(rootRegion && "expected all ops to have a common ancestor");
+
+  // Sort all elements in `toSort` by recursively traversing the IR.
+  SetVector<Operation *> result;
+  topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+  assert(result.size() == toSort.size() &&
+         "expected all operations to be present in the result");
+  return result;
 }
 
 /// Returns true if `value` (transitively) depends on iteration-carried values
diff --git a/mlir/test/Analysis/test-topoligical-sort.mlir b/mlir/test/Analysis/test-topoligical-sort.mlir
index 8608586402055..150aff854fc8f 100644
--- a/mlir/test/Analysis/test-topoligical-sort.mlir
+++ b/mlir/test/Analysis/test-topoligical-sort.mlir
@@ -1,21 +1,38 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" --split-input-file | FileCheck %s
 
-// CHECK-LABEL: Testing : region
-//       CHECK: arith.addi {{.*}} : index
-//  CHECK-NEXT: scf.for
-//       CHECK: } {__test_sort_original_idx__ = 2 : i64}
-//  CHECK-NEXT: arith.addi {{.*}} : i32
-//  CHECK-NEXT: arith.subi {{.*}} : i32
-func.func @region(
-  %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
-  %arg4 : i32, %arg5 : i32, %arg6 : i32,
-  %buffer : memref<i32>) {
-  %0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
-  %idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
-  scf.for %arg7 = %idx to %arg2 step %arg3  {
-    %2 = arith.addi %0, %arg5 : i32
-    %3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
-    memref.store %3, %buffer[] : memref<i32>
-  } {__test_sort_original_idx__ = 2}
+// CHECK-LABEL: single_element
+func.func @single_element() {
+  // CHECK: test_sort_index = 0
+  return {test_to_sort}
+}
+
+// -----
+
+// CHECK-LABEL: @simple_region
+func.func @simple_region(%cond: i1) {
+  // CHECK: test_sort_index = 0
+  %0 = arith.constant {test_to_sort} 42 : i32
+  scf.if %cond {
+    %1 = arith.addi %0, %0 : i32
+    // CHECK: test_sort_index = 2
+    %2 = arith.subi %0, %1 {test_to_sort} : i32
+  // CHECK: test_sort_index = 1
+  } {test_to_sort}
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @multi_region
+func.func @multi_region(%cond: i1) {
+  scf.if %cond {
+    // CHECK: test_sort_index = 0
+    %0 = arith.constant {test_to_sort} 42 : i32
+  }
+
+  scf.if %cond {
+    // CHECK: test_sort_index = 1
+    %0 = arith.constant {test_to_sort} 24 : i32
+  }
   return
 }
diff --git a/mlir/test/Dialect/Affine/slicing-utils.mlir b/mlir/test/Dialect/Affine/slicing-utils.mlir
index 74379978fdf8c..0848a924b9d96 100644
--- a/mlir/test/Dialect/Affine/slicing-utils.mlir
+++ b/mlir/test/Dialect/Affine/slicing-utils.mlir
@@ -28,15 +28,15 @@ func.func @slicing_test() {
   // BWD: matched: %[[v1:.*]] {{.*}} backward static slice:
   //
   // FWDBWD: matched: %[[v1:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
-  // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+  // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %1 = "slicing-test-op" () : () -> i1
 
@@ -49,15 +49,15 @@ func.func @slicing_test() {
   // BWD: matched: %[[v2:.*]] {{.*}} backward static slice:
   //
   // FWDBWD-NEXT: matched: %[[v2:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
-  // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+  // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %2 = "slicing-test-op" () : () -> i2
 
@@ -69,15 +69,15 @@ func.func @slicing_test() {
   // BWD: matched: %[[v3:.*]] {{.*}} backward static slice:
   //
   // FWDBWD-NEXT: matched: %[[v3:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
-  // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+  // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %3 = "slicing-test-op" () : () -> i3
 
@@ -89,15 +89,15 @@ func.func @slicing_test() {
   // BWD: matched: %[[v4:.*]] {{.*}} backward static slice:
   //
   // FWDBWD-NEXT: matched: %[[v4:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
-  // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+  // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %4 = "slicing-test-op" () : () -> i4
 
@@ -111,15 +111,15 @@ func.func @slicing_test() {
   // BWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
   //
   // FWDBWD-NEXT: matched: %[[v5:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
-  // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+  // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %5 = "slicing-test-op" (%1, %2) : (i1, i2) -> i5
 
@@ -132,15 +132,15 @@ func.func @slicing_test() {
   // BWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
   //
   // FWDBWD-NEXT: matched: %[[v6:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
-  // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+  // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %6 = "slicing-test-op" (%3, %4) : (i3, i4) -> i6
 
@@ -153,15 +153,15 @@ func.func @slicing_test() {
   // BWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
   //
   // FWDBWD-NEXT: matched: %[[v7:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
   // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %7 = "slicing-test-op" (%1, %5) : (i1, i5) -> i7
 
@@ -177,15 +177,15 @@ func.func @slicing_test() {
   // BWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
   //
   // FWDBWD-NEXT: matched: %[[v8:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
-  // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+  // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %8 = "slicing-test-op" (%5, %6) : (i5, i6) -> i8
 
@@ -202,15 +202,15 @@ func.func @slicing_test() {
   // BWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
   //
   // FWDBWD-NEXT: matched: %[[v9:.*]] {{.*}} static slice:
-  // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
-  // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
-  // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
-  // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
-  // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
-  // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
-  // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
-  // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
-  // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+  // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+  // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+  // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+  // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+  // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+  // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+  // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+  // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+  // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
 
   %9 = "slicing-test-op" (%7, %8) : (i7, i8) -> i9
 
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
index b445febde5971..06c41d8c4a110 100644
--- a/mlir/test/lib/Analysis/TestSlice.cpp
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -1,4 +1,4 @@
-//===------------- TestSlice.cpp - Test slice related analisis ------------===//
+//===- TestSlice.cpp - Test slice related analysis ------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,12 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 
-static const StringLiteral kOrderMarker = "__test_sort_original_idx__";
+static const StringLiteral kToSortMark = "test_to_sort";
+static const StringLiteral kOrderIndex = "test_sort_index";
 
 namespace {
 
@@ -26,20 +28,16 @@ struct TestTopologicalSortPass
     return "Print operations in topological order";
   }
   void runOnOperation() override {
-    std::map<int, Operation *> ops;
-    getOperation().walk([&ops](Operation *op) {
-      if (auto originalOrderAttr = op->getAttrOfType<IntegerAttr>(kOrderMarker))
-        ops[originalOrderAttr.getInt()] = op;
+    SetVector<Operation *> toSort;
+    getOperation().walk([&](Operation *op) {
+      if (op->hasAttrOfType<UnitAttr>(kToSortMark))
+        toSort.insert(op);
     });
-    SetVector<Operation *> sortedOp;
-    for (auto op : ops)
-      sortedOp.insert(op.second);
-    sortedOp = topologicalSort(sortedOp);
-    llvm::errs() << "Testing : " << getOperation().getName() << "\n";
-    for (Operation *op : sortedOp) {
-      op->print(llvm::errs());
-      llvm::errs() << "\n";
-    }
+    // Sort the marked ops and tag each with its index in the sorted order.
+    auto i32Type = IntegerType::get(&getContext(), 32);
+    SetVector<Operation *> sortedOps = topologicalSort(toSort);
+    for (auto [index, op] : llvm::enumerate(sortedOps))
+      op->setAttr(kOrderIndex, IntegerAttr::get(i32Type, index));
   }
 };
 

>From 5e181bd00cbe1cab023e496f00d3340fbf31cdbe Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 11:42:08 +0000
Subject: [PATCH 2/5] move topo utils to analysis

---
 .../mlir/{Transforms => Analysis}/TopologicalSortUtils.h    | 6 +++---
 mlir/lib/Analysis/CMakeLists.txt                            | 2 ++
 .../{Transforms/Utils => Analysis}/TopologicalSortUtils.cpp | 4 ++--
 mlir/lib/Transforms/TopologicalSort.cpp                     | 2 +-
 mlir/lib/Transforms/Utils/CMakeLists.txt                    | 1 -
 mlir/lib/Transforms/Utils/RegionUtils.cpp                   | 2 +-
 mlir/lib/Transforms/ViewOpGraph.cpp                         | 2 +-
 mlir/test/{Transforms => Analysis}/test-toposort.mlir       | 0
 mlir/test/lib/Analysis/CMakeLists.txt                       | 1 +
 .../lib/{Transforms => Analysis}/TestTopologicalSort.cpp    | 2 +-
 mlir/test/lib/Transforms/CMakeLists.txt                     | 1 -
 utils/bazel/llvm-project-overlay/mlir/BUILD.bazel           | 1 -
 12 files changed, 12 insertions(+), 12 deletions(-)
 rename mlir/include/mlir/{Transforms => Analysis}/TopologicalSortUtils.h (95%)
 rename mlir/lib/{Transforms/Utils => Analysis}/TopologicalSortUtils.cpp (97%)
 rename mlir/test/{Transforms => Analysis}/test-toposort.mlir (100%)
 rename mlir/test/lib/{Transforms => Analysis}/TestTopologicalSort.cpp (98%)

diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
similarity index 95%
rename from mlir/include/mlir/Transforms/TopologicalSortUtils.h
rename to mlir/include/mlir/Analysis/TopologicalSortUtils.h
index 74e44b1dc485d..fb9441db119fd 100644
--- a/mlir/include/mlir/Transforms/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
-#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#ifndef MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
+#define MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
 
 #include "mlir/IR/Block.h"
 
@@ -106,4 +106,4 @@ bool computeTopologicalSorting(
 
 } // end namespace mlir
 
-#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 005814ddbec79..38d8415d81c72 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
   Liveness.cpp
   CFGLoopInfo.cpp
   SliceAnalysis.cpp
+  TopologicalSortUtils.cpp
 
   AliasAnalysis/LocalAliasAnalysis.cpp
 
@@ -28,6 +29,7 @@ add_mlir_library(MLIRAnalysis
   Liveness.cpp
   CFGLoopInfo.cpp
   SliceAnalysis.cpp
+  TopologicalSortUtils.cpp
 
   AliasAnalysis/LocalAliasAnalysis.cpp
 
diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
similarity index 97%
rename from mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
rename to mlir/lib/Analysis/TopologicalSortUtils.cpp
index f3a9d217f2c98..4281beacee89e 100644
--- a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -1,4 +1,4 @@
-//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===//
+//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Transforms/TopologicalSortUtils.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/OpDefinition.h"
 
 using namespace mlir;
diff --git a/mlir/lib/Transforms/TopologicalSort.cpp b/mlir/lib/Transforms/TopologicalSort.cpp
index 1219968fb3692..528f6ef676020 100644
--- a/mlir/lib/Transforms/TopologicalSort.cpp
+++ b/mlir/lib/Transforms/TopologicalSort.cpp
@@ -8,8 +8,8 @@
 
 #include "mlir/Transforms/Passes.h"
 
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/RegionKindInterface.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_TOPOLOGICALSORT
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index d6aac0e2da4f5..b5788c679edc4 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -10,7 +10,6 @@ add_mlir_library(MLIRTransformUtils
   LoopInvariantCodeMotionUtils.cpp
   OneToNTypeConversion.cpp
   RegionUtils.cpp
-  TopologicalSortUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 192f59b353295..b6a6dea5fe9a0 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/RegionUtils.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
@@ -15,7 +16,6 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index c2eb2b893cea4..b3c0a06c96fea 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -8,12 +8,12 @@
 
 #include "mlir/Transforms/ViewOpGraph.h"
 
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/IndentedOstream.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
 #include "llvm/Support/Format.h"
 #include "llvm/Support/GraphWriter.h"
 #include <map>
diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Analysis/test-toposort.mlir
similarity index 100%
rename from mlir/test/Transforms/test-toposort.mlir
rename to mlir/test/Analysis/test-toposort.mlir
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index d168888c1e71e..7c6b31ae8b73e 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTestAnalysis
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
   TestSlice.cpp
+  TestTopologicalSort.cpp
 
   DataFlow/TestDeadCodeAnalysis.cpp
   DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Analysis/TestTopologicalSort.cpp
similarity index 98%
rename from mlir/test/lib/Transforms/TestTopologicalSort.cpp
rename to mlir/test/lib/Analysis/TestTopologicalSort.cpp
index 3b110c7126200..c7e0206b2a4d7 100644
--- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp
+++ b/mlir/test/lib/Analysis/TestTopologicalSort.cpp
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
 
 using namespace mlir;
 
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index a849b7ebd29e2..975a41ac3d5fe 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,7 +26,6 @@ add_mlir_library(MLIRTestTransforms
   TestInlining.cpp
   TestIntRangeInference.cpp
   TestMakeIsolatedFromAbove.cpp
-  TestTopologicalSort.cpp
   ${MLIRTestTransformsPDLSrc}
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index fc449e9010ae4..971c851a5f89f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7597,7 +7597,6 @@ cc_library(
         "include/mlir/Transforms/LoopInvariantCodeMotionUtils.h",
         "include/mlir/Transforms/OneToNTypeConversion.h",
         "include/mlir/Transforms/RegionUtils.h",
-        "include/mlir/Transforms/TopologicalSortUtils.h",
     ],
     includes = ["include"],
     deps = [

>From 277ed0ac9f4e2b8963defea5c54e2f0868103a37 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 12:02:50 +0000
Subject: [PATCH 3/5] migrate block topo sort

---
 .../mlir/Analysis/TopologicalSortUtils.h      |  4 ++++
 mlir/include/mlir/Transforms/RegionUtils.h    |  4 ----
 mlir/lib/Analysis/SliceAnalysis.cpp           | 23 ++-----------------
 mlir/lib/Analysis/TopologicalSortUtils.cpp    | 21 +++++++++++++++++
 .../ArmSME/Transforms/TileAllocation.cpp      |  1 +
 .../OpenACC/OpenACCToLLVMIRTranslation.cpp    |  2 +-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  1 +
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  2 +-
 mlir/lib/Transforms/Mem2Reg.cpp               |  2 +-
 mlir/lib/Transforms/Utils/RegionUtils.cpp     | 17 --------------
 10 files changed, 32 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
index fb9441db119fd..c2bc15ad3143f 100644
--- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -104,6 +104,10 @@ bool computeTopologicalSorting(
     MutableArrayRef<Operation *> ops,
     function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
 
+/// Get a list of blocks that is sorted according to dominance. This sort is
+/// stable.
+SetVector<Block *> getBlocksSortedByDominance(Region &region);
+
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index f65d0d44eef42..06eebff201d1b 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -87,10 +87,6 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
 LogicalResult runRegionDCE(RewriterBase &rewriter,
                            MutableArrayRef<Region> regions);
 
-/// Get a list of blocks that is sorted according to dominance. This sort is
-/// stable.
-SetVector<Block *> getBlocksSortedByDominance(Region &region);
-
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_REGIONUTILS_H_
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index f93183749dfd2..37dae769dbc6d 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -11,13 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
-#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -167,23 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
   return topologicalSort(slice);
 }
 
-/// TODO: deduplicate
-static SetVector<Block *> getTopologicallySortedBlocks(Region &region) {
-  // For each block that has not been visited yet (i.e. that has no
-  // predecessors), add it to the list as well as its successors.
-  SetVector<Block *> blocks;
-  for (Block &b : region) {
-    if (blocks.count(&b) == 0) {
-      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
-      blocks.insert(traversal.begin(), traversal.end());
-    }
-  }
-  assert(blocks.size() == region.getBlocks().size() &&
-         "some blocks are not sorted");
-
-  return blocks;
-}
-
 /// Computes the common ancestor region of all operations in `ops`. Remembers
 /// all the traversed regions in `traversedRegions`.
 static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
@@ -220,7 +201,7 @@ static void topoSortRegion(Region &region,
                            const DenseSet<Region *> &relevantRegions,
                            const SetVector<Operation *> &toSort,
                            SetVector<Operation *> &result) {
-  SetVector<Block *> sortedBlocks = getTopologicallySortedBlocks(region);
+  SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
   for (Block *block : sortedBlocks) {
     for (Operation &op : *block) {
       if (toSort.contains(&op))
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index 4281beacee89e..f5e16e9a91fe5 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -7,7 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/IR/Block.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/RegionGraphTraits.h"
+
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
 
@@ -146,3 +151,19 @@ bool mlir::computeTopologicalSorting(
 
   return allOpsScheduled;
 }
+
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
+  // For each block that has not been visited yet (i.e. that has no
+  // predecessors), add it to the list as well as its successors.
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.count(&b) == 0) {
+      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+      blocks.insert(traversal.begin(), traversal.end());
+    }
+  }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+
+  return blocks;
+}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index acbbbe9932e19..733e758b43907 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -46,6 +46,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Liveness.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index eeda245ce969f..d9cf85e4aecab 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -19,7 +20,6 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
-#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 34b6903f8da07..9d125b7f11809 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index cf3257c8b9b87..1ec0736ec08bf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -16,6 +16,7 @@
 #include "AttrKindDetail.h"
 #include "DebugTranslation.h"
 #include "LoopAnnotationTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
@@ -33,7 +34,6 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
-#include "mlir/Transforms/RegionUtils.h"
 
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index e2e240ad865ce..a452cc3fae8ac 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/Mem2Reg.h"
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
@@ -16,7 +17,6 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/Passes.h"
-#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/GenericIteratedDominanceFrontier.h"
 
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index b6a6dea5fe9a0..b5e641d39fc0a 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -19,7 +19,6 @@
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
 
@@ -836,19 +835,3 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }
-
-SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
-  // For each block that has not been visited yet (i.e. that has no
-  // predecessors), add it to the list as well as its successors.
-  SetVector<Block *> blocks;
-  for (Block &b : region) {
-    if (blocks.count(&b) == 0) {
-      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
-      blocks.insert(traversal.begin(), traversal.end());
-    }
-  }
-  assert(blocks.size() == region.getBlocks().size() &&
-         "some blocks are not sorted");
-
-  return blocks;
-}

>From a269b086c64d009d2f3af2891b02c5daca7457a4 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 13:56:09 +0000
Subject: [PATCH 4/5] move all topo sorts into the utils file

---
 mlir/include/mlir/Analysis/SliceAnalysis.h    |  6 --
 .../mlir/Analysis/TopologicalSortUtils.h      |  4 +
 mlir/lib/Analysis/SliceAnalysis.cpp           | 74 -------------------
 mlir/lib/Analysis/TopologicalSortUtils.cpp    | 74 +++++++++++++++++++
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  1 +
 .../Dialect/Affine/Utils/LoopFusionUtils.cpp  |  1 +
 mlir/lib/Transforms/SROA.cpp                  |  1 +
 mlir/test/lib/Analysis/TestSlice.cpp          |  2 +-
 8 files changed, 82 insertions(+), 81 deletions(-)

diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 19571fc1946be..99279fdfe427c 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -223,12 +223,6 @@ SetVector<Operation *>
 getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
          const ForwardSliceOptions &forwardSliceOptions = {});
 
-/// Multi-root DAG topological sort.
-/// Performs a topological sort of the Operation in the `toSort` SetVector.
-/// Returns a topologically sorted SetVector.
-/// Does not support multi-sets.
-SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
-
 /// Utility to match a generic reduction given a list of iteration-carried
 /// arguments, `iterCarriedArgs` and the position of the potential reduction
 /// argument within the list, `redPos`. If a reduction is matched, returns the
diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
index c2bc15ad3143f..7aabc5ee457c0 100644
--- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -108,6 +108,10 @@ bool computeTopologicalSorting(
 /// stable.
 SetVector<Block *> getBlocksSortedByDominance(Region &region);
 
+/// Sorts all operation in `toSort` topologically while also region semantics.
+/// Does not support multi-sets.
+SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
+
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 37dae769dbc6d..2b1cf411ceeee 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -165,80 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
   return topologicalSort(slice);
 }
 
-/// Computes the common ancestor region of all operations in `ops`. Remembers
-/// all the traversed regions in `traversedRegions`.
-static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
-                                      DenseSet<Region *> &traversedRegions) {
-  // Map to count the number of times a region was encountered.
-  llvm::DenseMap<Region *, size_t> regionCounts;
-  size_t expectedCount = ops.size();
-
-  // Walk the region tree for each operation towards the root and add to the
-  // region count.
-  Region *res = nullptr;
-  for (Operation *op : ops) {
-    Region *current = op->getParentRegion();
-    while (current) {
-      // Insert or get the count.
-      auto it = regionCounts.try_emplace(current, 0).first;
-      size_t count = ++it->getSecond();
-      if (count == expectedCount) {
-        res = current;
-        break;
-      }
-      current = current->getParentRegion();
-    }
-  }
-  auto firstRange = llvm::make_first_range(regionCounts);
-  traversedRegions.insert(firstRange.begin(), firstRange.end());
-  return res;
-}
-
-/// Topologically traverses `region` and insers all encountered operations in
-/// `toSort` into the result. Recursively traverses regions when they are
-/// present in `relevantRegions`.
-static void topoSortRegion(Region &region,
-                           const DenseSet<Region *> &relevantRegions,
-                           const SetVector<Operation *> &toSort,
-                           SetVector<Operation *> &result) {
-  SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
-  for (Block *block : sortedBlocks) {
-    for (Operation &op : *block) {
-      if (toSort.contains(&op))
-        result.insert(&op);
-      for (Region &subRegion : op.getRegions()) {
-        // Skip regions that do not contain operations from `toSort`.
-        if (!relevantRegions.contains(&region))
-          continue;
-        topoSortRegion(subRegion, relevantRegions, toSort, result);
-      }
-    }
-  }
-}
-
-SetVector<Operation *>
-mlir::topologicalSort(const SetVector<Operation *> &toSort) {
-  if (toSort.size() <= 1)
-    return toSort;
-
-  assert(llvm::all_of(toSort,
-                      [&](Operation *op) { return toSort.count(op) == 1; }) &&
-         "expected only unique set entries");
-
-  // First, find the root region to start the recursive traversal through the
-  // IR.
-  DenseSet<Region *> relevantRegions;
-  Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
-  assert(rootRegion && "expected all ops to have a common ancestor");
-
-  // Sort all element in `toSort` by recursively traversing the IR.
-  SetVector<Operation *> result;
-  topoSortRegion(*rootRegion, relevantRegions, toSort, result);
-  assert(result.size() == toSort.size() &&
-         "expected all operations to be present in the result");
-  return result;
-}
-
 /// Returns true if `value` (transitively) depends on iteration-carried values
 /// of the given `ancestorOp`.
 static bool dependsOnCarriedVals(Value value,
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index f5e16e9a91fe5..2fc1bd582ef47 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -167,3 +167,77 @@ SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
 
   return blocks;
 }
+
+/// Computes the common ancestor region of all operations in `ops`. Remembers
+/// all the traversed regions in `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+                                      DenseSet<Region *> &traversedRegions) {
+  // Map to count the number of times a region was encountered.
+  llvm::DenseMap<Region *, size_t> regionCounts;
+  size_t expectedCount = ops.size();
+
+  // Walk the region tree for each operation towards the root and add to the
+  // region count.
+  Region *res = nullptr;
+  for (Operation *op : ops) {
+    Region *current = op->getParentRegion();
+    while (current) {
+      // Insert or get the count.
+      auto it = regionCounts.try_emplace(current, 0).first;
+      size_t count = ++it->getSecond();
+      if (count == expectedCount) {
+        res = current;
+        break;
+      }
+      current = current->getParentRegion();
+    }
+  }
+  auto firstRange = llvm::make_first_range(regionCounts);
+  traversedRegions.insert(firstRange.begin(), firstRange.end());
+  return res;
+}
+
+/// Topologically traverses `region` and insers all encountered operations in
+/// `toSort` into the result. Recursively traverses regions when they are
+/// present in `relevantRegions`.
+static void topoSortRegion(Region &region,
+                           const DenseSet<Region *> &relevantRegions,
+                           const SetVector<Operation *> &toSort,
+                           SetVector<Operation *> &result) {
+  SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
+  for (Block *block : sortedBlocks) {
+    for (Operation &op : *block) {
+      if (toSort.contains(&op))
+        result.insert(&op);
+      for (Region &subRegion : op.getRegions()) {
+        // Skip regions that do not contain operations from `toSort`.
+        if (!relevantRegions.contains(&region))
+          continue;
+        topoSortRegion(subRegion, relevantRegions, toSort, result);
+      }
+    }
+  }
+}
+
+SetVector<Operation *>
+mlir::topologicalSort(const SetVector<Operation *> &toSort) {
+  if (toSort.size() <= 1)
+    return toSort;
+
+  assert(llvm::all_of(toSort,
+                      [&](Operation *op) { return toSort.count(op) == 1; }) &&
+         "expected only unique set entries");
+
+  // First, find the root region to start the recursive traversal through the
+  // IR.
+  DenseSet<Region *> relevantRegions;
+  Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+  assert(rootRegion && "expected all ops to have a common ancestor");
+
+  // Sort all element in `toSort` by recursively traversing the IR.
+  SetVector<Operation *> result;
+  topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+  assert(result.size() == toSort.size() &&
+         "expected all operations to be present in the result");
+  return result;
+}
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 332f0a2eecfcf..4496c2bc5fe8b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -15,6 +15,7 @@
 #include <type_traits>
 
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 84ae4b52dcf4e..7f3e43d0b4cd3 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 67cbade07bc94..39f7256fb789d 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/SROA.h"
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Transforms/Passes.h"
 
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
index 06c41d8c4a110..fc367c07ad863 100644
--- a/mlir/test/lib/Analysis/TestSlice.cpp
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"

>From 9ea5fb929db8c4566abca69f5940a30b13de148d Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Tue, 21 May 2024 09:01:26 +0000
Subject: [PATCH 5/5] address nit comments

---
 .../mlir/Analysis/TopologicalSortUtils.h      |  6 ++--
 mlir/lib/Analysis/TopologicalSortUtils.cpp    | 34 ++++++++-----------
 mlir/lib/Transforms/SROA.cpp                  |  9 +++--
 mlir/test/lib/Analysis/TestSlice.cpp          |  3 +-
 4 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
index 7aabc5ee457c0..ee98cd8cb380e 100644
--- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -104,12 +104,12 @@ bool computeTopologicalSorting(
     MutableArrayRef<Operation *> ops,
     function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
 
-/// Get a list of blocks that is sorted according to dominance. This sort is
+/// Gets a list of blocks that is sorted according to dominance. This sort is
 /// stable.
 SetVector<Block *> getBlocksSortedByDominance(Region &region);
 
-/// Sorts all operation in `toSort` topologically while also region semantics.
-/// Does not support multi-sets.
+/// Sorts all operations in `toSort` topologically while also considering region
+/// semantics. Does not support multi-sets.
 SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
 
 } // end namespace mlir
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index 2fc1bd582ef47..94c403c07385b 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -168,12 +168,12 @@ SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
   return blocks;
 }
 
-/// Computes the common ancestor region of all operations in `ops`. Remembers
-/// all the traversed regions in `traversedRegions`.
-static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
-                                      DenseSet<Region *> &traversedRegions) {
+/// Computes the closest common ancestor region of all operations in `ops`.
+/// Remembers all the traversed regions in `traversedRegions`.
+static Region *findCommonAncestorRegion(const SetVector<Operation *> &ops,
+                                        DenseSet<Region *> &traversedRegions) {
   // Map to count the number of times a region was encountered.
-  llvm::DenseMap<Region *, size_t> regionCounts;
+  DenseMap<Region *, size_t> regionCounts;
   size_t expectedCount = ops.size();
 
   // Walk the region tree for each operation towards the root and add to the
@@ -182,10 +182,8 @@ static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
   for (Operation *op : ops) {
     Region *current = op->getParentRegion();
     while (current) {
-      // Insert or get the count.
-      auto it = regionCounts.try_emplace(current, 0).first;
-      size_t count = ++it->getSecond();
-      if (count == expectedCount) {
+      // Insert or update the count and compare it.
+      if (++regionCounts[current] == expectedCount) {
         res = current;
         break;
       }
@@ -197,11 +195,11 @@ static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
   return res;
 }
 
-/// Topologically traverses `region` and insers all encountered operations in
+/// Topologically traverses `region` and inserts all encountered operations in
 /// `toSort` into the result. Recursively traverses regions when they are
 /// present in `relevantRegions`.
 static void topoSortRegion(Region &region,
-                           const DenseSet<Region *> &relevantRegions,
+                           const DenseSet<Region *> &ancestorRegions,
                            const SetVector<Operation *> &toSort,
                            SetVector<Operation *> &result) {
   SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
@@ -211,9 +209,9 @@ static void topoSortRegion(Region &region,
         result.insert(&op);
       for (Region &subRegion : op.getRegions()) {
         // Skip regions that do not contain operations from `toSort`.
-        if (!relevantRegions.contains(&region))
+        if (!ancestorRegions.contains(&subRegion))
           continue;
-        topoSortRegion(subRegion, relevantRegions, toSort, result);
+        topoSortRegion(subRegion, ancestorRegions, toSort, result);
       }
     }
   }
@@ -224,19 +222,15 @@ mlir::topologicalSort(const SetVector<Operation *> &toSort) {
   if (toSort.size() <= 1)
     return toSort;
 
-  assert(llvm::all_of(toSort,
-                      [&](Operation *op) { return toSort.count(op) == 1; }) &&
-         "expected only unique set entries");
-
   // First, find the root region to start the recursive traversal through the
   // IR.
-  DenseSet<Region *> relevantRegions;
-  Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+  DenseSet<Region *> ancestorRegions;
+  Region *rootRegion = findCommonAncestorRegion(toSort, ancestorRegions);
   assert(rootRegion && "expected all ops to have a common ancestor");
 
   // Sort all element in `toSort` by recursively traversing the IR.
   SetVector<Operation *> result;
-  topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+  topoSortRegion(*rootRegion, ancestorRegions, toSort, result);
   assert(result.size() == toSort.size() &&
          "expected all operations to be present in the result");
   return result;
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 39f7256fb789d..9a7f4db2afe00 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -108,14 +108,17 @@ computeDestructuringInfo(DestructurableMemorySlot &slot,
 
     // An operation that has blocking uses must be promoted. If it is not
     // promotable, destructuring must fail.
-    if (!promotable)
+    if (!promotable) {
       return {};
+    }
 
     SmallVector<OpOperand *> newBlockingUses;
     // If the operation decides it cannot deal with removing the blocking uses,
     // destructuring must fail.
-    if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
+    if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses,
+                                     dataLayout)) {
       return {};
+    }
 
     // Then, register any new blocking uses for coming operations.
     for (OpOperand *blockingUse : newBlockingUses) {
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
index fc367c07ad863..7e8320dbf3ec3 100644
--- a/mlir/test/lib/Analysis/TestSlice.cpp
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -25,7 +25,8 @@ struct TestTopologicalSortPass
 
   StringRef getArgument() const final { return "test-print-topological-sort"; }
   StringRef getDescription() const final {
-    return "Print operations in topological order";
+    return "Sorts operations topologically and attaches attributes with their "
+           "corresponding index in the ordering to them";
   }
   void runOnOperation() override {
     SetVector<Operation *> toSort;



More information about the Mlir-commits mailing list