[Mlir-commits] [llvm] [mlir] [MLIR][Analysis] Consolidate topological sort utilities (PR #92563)
Christian Ulmann
llvmlistbot at llvm.org
Tue May 21 02:17:28 PDT 2024
https://github.com/Dinistro updated https://github.com/llvm/llvm-project/pull/92563
>From c64c0b7708a363960eacfe66e631fbef7ef5bb00 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 11:16:56 +0000
Subject: [PATCH 1/5] first impl of the new topo sort
---
mlir/include/mlir/Analysis/SliceAnalysis.h | 1 +
mlir/lib/Analysis/SliceAnalysis.cpp | 126 +++++++++-----
mlir/test/Analysis/test-topoligical-sort.mlir | 53 ++++--
mlir/test/Dialect/Affine/slicing-utils.mlir | 160 +++++++++---------
mlir/test/lib/Analysis/TestSlice.cpp | 28 ++-
5 files changed, 211 insertions(+), 157 deletions(-)
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index d5cdf72c3889f..19571fc1946be 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -226,6 +226,7 @@ getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
/// Returns a topologically sorted SetVector.
+/// Does not support multi-sets.
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
/// Utility to match a generic reduction given a list of iteration-carried
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 26fe8e3dc0819..f93183749dfd2 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -11,10 +11,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -164,60 +167,95 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
return topologicalSort(slice);
}
-namespace {
-/// DFS post-order implementation that maintains a global count to work across
-/// multiple invocations, to help implement topological sort on multi-root DAGs.
-/// We traverse all operations but only record the ones that appear in
-/// `toSort` for the final result.
-struct DFSState {
- DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
- const SetVector<Operation *> &toSort;
- SmallVector<Operation *, 16> topologicalCounts;
- DenseSet<Operation *> seen;
-};
-} // namespace
-
-static void dfsPostorder(Operation *root, DFSState *state) {
- SmallVector<Operation *> queue(1, root);
- std::vector<Operation *> ops;
- while (!queue.empty()) {
- Operation *current = queue.pop_back_val();
- ops.push_back(current);
- for (Operation *op : current->getUsers())
- queue.push_back(op);
- for (Region &region : current->getRegions()) {
- for (Operation &op : region.getOps())
- queue.push_back(&op);
+/// Returns the blocks of `region` in topological order. TODO: deduplicate.
+static SetVector<Block *> getTopologicallySortedBlocks(Region &region) {
+ // Start a reverse post-order traversal from every block not visited yet,
+ // so that blocks unreachable from the entry block are included as well.
+ SetVector<Block *> blocks;
+ for (Block &b : region) {
+ if (blocks.count(&b) == 0) {
+ llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+ blocks.insert(traversal.begin(), traversal.end());
}
}
+ assert(blocks.size() == region.getBlocks().size() &&
+ "some blocks are not sorted");
- for (Operation *op : llvm::reverse(ops)) {
- if (state->seen.insert(op).second && state->toSort.count(op) > 0)
- state->topologicalCounts.push_back(op);
+ return blocks;
+}
+
+/// Returns the innermost region containing every op in `ops` (or nullptr).
+/// All regions visited on the way are inserted into `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+ DenseSet<Region *> &traversedRegions) {
+ // Counts, per region, how many ops of `ops` are nested inside it.
+ llvm::DenseMap<Region *, size_t> regionCounts;
+ size_t expectedCount = ops.size();
+
+ // Walk from each op up to the root; a region whose count reaches
+ // `expectedCount` is an ancestor of every op, and the first one hit.
+ Region *res = nullptr;
+ for (Operation *op : ops) {
+ Region *current = op->getParentRegion();
+ while (current) {
+ // Insert or get the count.
+ auto it = regionCounts.try_emplace(current, 0).first;
+ size_t count = ++it->getSecond();
+ if (count == expectedCount) {
+ res = current;
+ break;
+ }
+ current = current->getParentRegion();
+ }
+ }
+ auto firstRange = llvm::make_first_range(regionCounts);
+ traversedRegions.insert(firstRange.begin(), firstRange.end());
+ return res;
+}
+
+/// Topologically traverses `region` and inserts all encountered operations in
+/// `toSort` into the result. Recursively traverses regions when they are
+/// present in `relevantRegions`.
+static void topoSortRegion(Region &region,
+ const DenseSet<Region *> &relevantRegions,
+ const SetVector<Operation *> &toSort,
+ SetVector<Operation *> &result) {
+ SetVector<Block *> sortedBlocks = getTopologicallySortedBlocks(region);
+ for (Block *block : sortedBlocks) {
+ for (Operation &op : *block) {
+ if (toSort.contains(&op))
+ result.insert(&op);
+ for (Region &subRegion : op.getRegions()) {
+ // Skip regions that do not contain operations from `toSort`.
+ if (!relevantRegions.contains(&subRegion))
+ continue;
+ topoSortRegion(subRegion, relevantRegions, toSort, result);
+ }
+ }
+ }
}
}
SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
- if (toSort.empty()) {
+ if (toSort.size() <= 1)
return toSort;
- }
- // Run from each root with global count and `seen` set.
- DFSState state(toSort);
- for (auto *s : toSort) {
- assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
- dfsPostorder(s, &state);
- }
-
- // Reorder and return.
- SetVector<Operation *> res;
- for (auto it = state.topologicalCounts.rbegin(),
- eit = state.topologicalCounts.rend();
- it != eit; ++it) {
- res.insert(*it);
- }
- return res;
+ assert(llvm::all_of(toSort,
+ [&](Operation *op) { return toSort.count(op) == 1; }) &&
+ "expected only unique set entries");
+
+ // First, find the innermost region containing all ops; it is the root of
+ // the recursive traversal through the IR.
+ DenseSet<Region *> relevantRegions;
+ Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+ assert(rootRegion && "expected all ops to have a common ancestor");
+
+ // Sort all elements in `toSort` by recursively traversing the IR.
+ SetVector<Operation *> result;
+ topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+ assert(result.size() == toSort.size() &&
+ "expected all operations to be present in the result");
+ return result;
}
/// Returns true if `value` (transitively) depends on iteration-carried values
diff --git a/mlir/test/Analysis/test-topoligical-sort.mlir b/mlir/test/Analysis/test-topoligical-sort.mlir
index 8608586402055..150aff854fc8f 100644
--- a/mlir/test/Analysis/test-topoligical-sort.mlir
+++ b/mlir/test/Analysis/test-topoligical-sort.mlir
@@ -1,21 +1,38 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" --split-input-file | FileCheck %s
-// CHECK-LABEL: Testing : region
-// CHECK: arith.addi {{.*}} : index
-// CHECK-NEXT: scf.for
-// CHECK: } {__test_sort_original_idx__ = 2 : i64}
-// CHECK-NEXT: arith.addi {{.*}} : i32
-// CHECK-NEXT: arith.subi {{.*}} : i32
-func.func @region(
- %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
- %arg4 : i32, %arg5 : i32, %arg6 : i32,
- %buffer : memref<i32>) {
- %0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
- %idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
- scf.for %arg7 = %idx to %arg2 step %arg3 {
- %2 = arith.addi %0, %arg5 : i32
- %3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
- memref.store %3, %buffer[] : memref<i32>
- } {__test_sort_original_idx__ = 2}
+// CHECK-LABEL: single_element
+func.func @single_element() {
+ // CHECK: test_sort_index = 0
+ return {test_to_sort}
+}
+
+// -----
+
+// CHECK-LABEL: @simple_region
+func.func @simple_region(%cond: i1) {
+ // CHECK: test_sort_index = 0
+ %0 = arith.constant {test_to_sort} 42 : i32
+ scf.if %cond {
+ %1 = arith.addi %0, %0 : i32
+ // CHECK: test_sort_index = 2
+ %2 = arith.subi %0, %1 {test_to_sort} : i32
+ // CHECK: test_sort_index = 1
+ } {test_to_sort}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @multi_region
+func.func @multi_region(%cond: i1) {
+ scf.if %cond {
+ // CHECK: test_sort_index = 0
+ %0 = arith.constant {test_to_sort} 42 : i32
+ }
+
+ scf.if %cond {
+ // CHECK: test_sort_index = 1
+ %0 = arith.constant {test_to_sort} 24 : i32
+ }
return
}
diff --git a/mlir/test/Dialect/Affine/slicing-utils.mlir b/mlir/test/Dialect/Affine/slicing-utils.mlir
index 74379978fdf8c..0848a924b9d96 100644
--- a/mlir/test/Dialect/Affine/slicing-utils.mlir
+++ b/mlir/test/Dialect/Affine/slicing-utils.mlir
@@ -28,15 +28,15 @@ func.func @slicing_test() {
// BWD: matched: %[[v1:.*]] {{.*}} backward static slice:
//
// FWDBWD: matched: %[[v1:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%1 = "slicing-test-op" () : () -> i1
@@ -49,15 +49,15 @@ func.func @slicing_test() {
// BWD: matched: %[[v2:.*]] {{.*}} backward static slice:
//
// FWDBWD-NEXT: matched: %[[v2:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%2 = "slicing-test-op" () : () -> i2
@@ -69,15 +69,15 @@ func.func @slicing_test() {
// BWD: matched: %[[v3:.*]] {{.*}} backward static slice:
//
// FWDBWD-NEXT: matched: %[[v3:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%3 = "slicing-test-op" () : () -> i3
@@ -89,15 +89,15 @@ func.func @slicing_test() {
// BWD: matched: %[[v4:.*]] {{.*}} backward static slice:
//
// FWDBWD-NEXT: matched: %[[v4:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%4 = "slicing-test-op" () : () -> i4
@@ -111,15 +111,15 @@ func.func @slicing_test() {
// BWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
//
// FWDBWD-NEXT: matched: %[[v5:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%5 = "slicing-test-op" (%1, %2) : (i1, i2) -> i5
@@ -132,15 +132,15 @@ func.func @slicing_test() {
// BWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
//
// FWDBWD-NEXT: matched: %[[v6:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%6 = "slicing-test-op" (%3, %4) : (i3, i4) -> i6
@@ -153,15 +153,15 @@ func.func @slicing_test() {
// BWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
//
// FWDBWD-NEXT: matched: %[[v7:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
// FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%7 = "slicing-test-op" (%1, %5) : (i1, i5) -> i7
@@ -177,15 +177,15 @@ func.func @slicing_test() {
// BWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
//
// FWDBWD-NEXT: matched: %[[v8:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%8 = "slicing-test-op" (%5, %6) : (i5, i6) -> i8
@@ -202,15 +202,15 @@ func.func @slicing_test() {
// BWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
//
// FWDBWD-NEXT: matched: %[[v9:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%9 = "slicing-test-op" (%7, %8) : (i7, i8) -> i9
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
index b445febde5971..06c41d8c4a110 100644
--- a/mlir/test/lib/Analysis/TestSlice.cpp
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -1,4 +1,4 @@
-//===------------- TestSlice.cpp - Test slice related analisis ------------===//
+//===- TestSlice.cpp - Test slice related analysis ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,12 +7,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
-static const StringLiteral kOrderMarker = "__test_sort_original_idx__";
+static const StringLiteral kToSortMark = "test_to_sort";
+static const StringLiteral kOrderIndex = "test_sort_index";
namespace {
@@ -26,20 +28,16 @@ struct TestTopologicalSortPass
return "Print operations in topological order";
}
void runOnOperation() override {
- std::map<int, Operation *> ops;
- getOperation().walk([&ops](Operation *op) {
- if (auto originalOrderAttr = op->getAttrOfType<IntegerAttr>(kOrderMarker))
- ops[originalOrderAttr.getInt()] = op;
+ SetVector<Operation *> toSort;
+ getOperation().walk([&](Operation *op) {
+ if (op->hasAttrOfType<UnitAttr>(kToSortMark))
+ toSort.insert(op);
});
- SetVector<Operation *> sortedOp;
- for (auto op : ops)
- sortedOp.insert(op.second);
- sortedOp = topologicalSort(sortedOp);
- llvm::errs() << "Testing : " << getOperation().getName() << "\n";
- for (Operation *op : sortedOp) {
- op->print(llvm::errs());
- llvm::errs() << "\n";
- }
+ // Sort the marked ops and record each op's sorted position as an attribute.
+ auto i32Type = IntegerType::get(&getContext(), 32);
+ SetVector<Operation *> sortedOps = topologicalSort(toSort);
+ for (auto [index, op] : llvm::enumerate(sortedOps))
+ op->setAttr(kOrderIndex, IntegerAttr::get(i32Type, index));
}
};
>From 5e181bd00cbe1cab023e496f00d3340fbf31cdbe Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 11:42:08 +0000
Subject: [PATCH 2/5] move topo utils to analysis
---
.../mlir/{Transforms => Analysis}/TopologicalSortUtils.h | 6 +++---
mlir/lib/Analysis/CMakeLists.txt | 2 ++
.../{Transforms/Utils => Analysis}/TopologicalSortUtils.cpp | 4 ++--
mlir/lib/Transforms/TopologicalSort.cpp | 2 +-
mlir/lib/Transforms/Utils/CMakeLists.txt | 1 -
mlir/lib/Transforms/Utils/RegionUtils.cpp | 2 +-
mlir/lib/Transforms/ViewOpGraph.cpp | 2 +-
mlir/test/{Transforms => Analysis}/test-toposort.mlir | 0
mlir/test/lib/Analysis/CMakeLists.txt | 1 +
.../lib/{Transforms => Analysis}/TestTopologicalSort.cpp | 2 +-
mlir/test/lib/Transforms/CMakeLists.txt | 1 -
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 -
12 files changed, 12 insertions(+), 12 deletions(-)
rename mlir/include/mlir/{Transforms => Analysis}/TopologicalSortUtils.h (95%)
rename mlir/lib/{Transforms/Utils => Analysis}/TopologicalSortUtils.cpp (97%)
rename mlir/test/{Transforms => Analysis}/test-toposort.mlir (100%)
rename mlir/test/lib/{Transforms => Analysis}/TestTopologicalSort.cpp (98%)
diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
similarity index 95%
rename from mlir/include/mlir/Transforms/TopologicalSortUtils.h
rename to mlir/include/mlir/Analysis/TopologicalSortUtils.h
index 74e44b1dc485d..fb9441db119fd 100644
--- a/mlir/include/mlir/Transforms/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
-#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#ifndef MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
+#define MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
#include "mlir/IR/Block.h"
@@ -106,4 +106,4 @@ bool computeTopologicalSorting(
} // end namespace mlir
-#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 005814ddbec79..38d8415d81c72 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
+ TopologicalSortUtils.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
@@ -28,6 +29,7 @@ add_mlir_library(MLIRAnalysis
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
+ TopologicalSortUtils.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
similarity index 97%
rename from mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
rename to mlir/lib/Analysis/TopologicalSortUtils.cpp
index f3a9d217f2c98..4281beacee89e 100644
--- a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -1,4 +1,4 @@
-//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===//
+//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/TopologicalSortUtils.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/OpDefinition.h"
using namespace mlir;
diff --git a/mlir/lib/Transforms/TopologicalSort.cpp b/mlir/lib/Transforms/TopologicalSort.cpp
index 1219968fb3692..528f6ef676020 100644
--- a/mlir/lib/Transforms/TopologicalSort.cpp
+++ b/mlir/lib/Transforms/TopologicalSort.cpp
@@ -8,8 +8,8 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/RegionKindInterface.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
namespace mlir {
#define GEN_PASS_DEF_TOPOLOGICALSORT
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index d6aac0e2da4f5..b5788c679edc4 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -10,7 +10,6 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
OneToNTypeConversion.cpp
RegionUtils.cpp
- TopologicalSortUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 192f59b353295..b6a6dea5fe9a0 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
@@ -15,7 +16,6 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index c2eb2b893cea4..b3c0a06c96fea 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -8,12 +8,12 @@
#include "mlir/Transforms/ViewOpGraph.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Analysis/test-toposort.mlir
similarity index 100%
rename from mlir/test/Transforms/test-toposort.mlir
rename to mlir/test/Analysis/test-toposort.mlir
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index d168888c1e71e..7c6b31ae8b73e 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTestAnalysis
TestMemRefDependenceCheck.cpp
TestMemRefStrideCalculation.cpp
TestSlice.cpp
+ TestTopologicalSort.cpp
DataFlow/TestDeadCodeAnalysis.cpp
DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Analysis/TestTopologicalSort.cpp
similarity index 98%
rename from mlir/test/lib/Transforms/TestTopologicalSort.cpp
rename to mlir/test/lib/Analysis/TestTopologicalSort.cpp
index 3b110c7126200..c7e0206b2a4d7 100644
--- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp
+++ b/mlir/test/lib/Analysis/TestTopologicalSort.cpp
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
using namespace mlir;
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index a849b7ebd29e2..975a41ac3d5fe 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,7 +26,6 @@ add_mlir_library(MLIRTestTransforms
TestInlining.cpp
TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
- TestTopologicalSort.cpp
${MLIRTestTransformsPDLSrc}
EXCLUDE_FROM_LIBMLIR
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index fc449e9010ae4..971c851a5f89f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7597,7 +7597,6 @@ cc_library(
"include/mlir/Transforms/LoopInvariantCodeMotionUtils.h",
"include/mlir/Transforms/OneToNTypeConversion.h",
"include/mlir/Transforms/RegionUtils.h",
- "include/mlir/Transforms/TopologicalSortUtils.h",
],
includes = ["include"],
deps = [
>From 277ed0ac9f4e2b8963defea5c54e2f0868103a37 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 12:02:50 +0000
Subject: [PATCH 3/5] migrate block topo sort
---
.../mlir/Analysis/TopologicalSortUtils.h | 4 ++++
mlir/include/mlir/Transforms/RegionUtils.h | 4 ----
mlir/lib/Analysis/SliceAnalysis.cpp | 23 ++-----------------
mlir/lib/Analysis/TopologicalSortUtils.cpp | 21 +++++++++++++++++
.../ArmSME/Transforms/TileAllocation.cpp | 1 +
.../OpenACC/OpenACCToLLVMIRTranslation.cpp | 2 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 1 +
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 +-
mlir/lib/Transforms/Mem2Reg.cpp | 2 +-
mlir/lib/Transforms/Utils/RegionUtils.cpp | 17 --------------
10 files changed, 32 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
index fb9441db119fd..c2bc15ad3143f 100644
--- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -104,6 +104,10 @@ bool computeTopologicalSorting(
MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
+/// Get a list of blocks that is sorted according to dominance. This sort is
+/// stable.
+SetVector<Block *> getBlocksSortedByDominance(Region &region);
+
} // end namespace mlir
#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index f65d0d44eef42..06eebff201d1b 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -87,10 +87,6 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
LogicalResult runRegionDCE(RewriterBase &rewriter,
MutableArrayRef<Region> regions);
-/// Get a list of blocks that is sorted according to dominance. This sort is
-/// stable.
-SetVector<Block *> getBlocksSortedByDominance(Region &region);
-
} // namespace mlir
#endif // MLIR_TRANSFORMS_REGIONUTILS_H_
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index f93183749dfd2..37dae769dbc6d 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -11,13 +11,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
-#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -167,23 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
return topologicalSort(slice);
}
-/// TODO: deduplicate
-static SetVector<Block *> getTopologicallySortedBlocks(Region &region) {
- // For each block that has not been visited yet (i.e. that has no
- // predecessors), add it to the list as well as its successors.
- SetVector<Block *> blocks;
- for (Block &b : region) {
- if (blocks.count(&b) == 0) {
- llvm::ReversePostOrderTraversal<Block *> traversal(&b);
- blocks.insert(traversal.begin(), traversal.end());
- }
- }
- assert(blocks.size() == region.getBlocks().size() &&
- "some blocks are not sorted");
-
- return blocks;
-}
-
/// Computes the common ancestor region of all operations in `ops`. Remembers
/// all the traversed regions in `traversedRegions`.
static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
@@ -220,7 +201,7 @@ static void topoSortRegion(Region &region,
const DenseSet<Region *> &relevantRegions,
const SetVector<Operation *> &toSort,
SetVector<Operation *> &result) {
- SetVector<Block *> sortedBlocks = getTopologicallySortedBlocks(region);
+ SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
for (Block *block : sortedBlocks) {
for (Operation &op : *block) {
if (toSort.contains(&op))
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index 4281beacee89e..f5e16e9a91fe5 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -7,7 +7,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/RegionGraphTraits.h"
+
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
using namespace mlir;
@@ -146,3 +151,19 @@ bool mlir::computeTopologicalSorting(
return allOpsScheduled;
}
+
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
+ // For each block that has not been visited yet (i.e. that has no
+ // predecessors), add it to the list as well as its successors.
+ SetVector<Block *> blocks;
+ for (Block &b : region) {
+ if (blocks.count(&b) == 0) {
+ llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+ blocks.insert(traversal.begin(), traversal.end());
+ }
+ }
+ assert(blocks.size() == region.getBlocks().size() &&
+ "some blocks are not sorted");
+
+ return blocks;
+}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index acbbbe9932e19..733e758b43907 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -46,6 +46,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Liveness.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index eeda245ce969f..d9cf85e4aecab 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinOps.h"
@@ -19,7 +20,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 34b6903f8da07..9d125b7f11809 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index cf3257c8b9b87..1ec0736ec08bf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -16,6 +16,7 @@
#include "AttrKindDetail.h"
#include "DebugTranslation.h"
#include "LoopAnnotationTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
@@ -33,7 +34,6 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index e2e240ad865ce..a452cc3fae8ac 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
@@ -16,7 +17,6 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index b6a6dea5fe9a0..b5e641d39fc0a 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -19,7 +19,6 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/SmallSet.h"
#include <deque>
@@ -836,19 +835,3 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
}
-
-SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
- // For each block that has not been visited yet (i.e. that has no
- // predecessors), add it to the list as well as its successors.
- SetVector<Block *> blocks;
- for (Block &b : region) {
- if (blocks.count(&b) == 0) {
- llvm::ReversePostOrderTraversal<Block *> traversal(&b);
- blocks.insert(traversal.begin(), traversal.end());
- }
- }
- assert(blocks.size() == region.getBlocks().size() &&
- "some blocks are not sorted");
-
- return blocks;
-}
>From a269b086c64d009d2f3af2891b02c5daca7457a4 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 13:56:09 +0000
Subject: [PATCH 4/5] move all topo sorts into the utils file
---
mlir/include/mlir/Analysis/SliceAnalysis.h | 6 --
.../mlir/Analysis/TopologicalSortUtils.h | 4 +
mlir/lib/Analysis/SliceAnalysis.cpp | 74 -------------------
mlir/lib/Analysis/TopologicalSortUtils.cpp | 74 +++++++++++++++++++
.../Conversion/VectorToGPU/VectorToGPU.cpp | 1 +
.../Dialect/Affine/Utils/LoopFusionUtils.cpp | 1 +
mlir/lib/Transforms/SROA.cpp | 1 +
mlir/test/lib/Analysis/TestSlice.cpp | 2 +-
8 files changed, 82 insertions(+), 81 deletions(-)
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 19571fc1946be..99279fdfe427c 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -223,12 +223,6 @@ SetVector<Operation *>
getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
const ForwardSliceOptions &forwardSliceOptions = {});
-/// Multi-root DAG topological sort.
-/// Performs a topological sort of the Operation in the `toSort` SetVector.
-/// Returns a topologically sorted SetVector.
-/// Does not support multi-sets.
-SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
-
/// Utility to match a generic reduction given a list of iteration-carried
/// arguments, `iterCarriedArgs` and the position of the potential reduction
/// argument within the list, `redPos`. If a reduction is matched, returns the
diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
index c2bc15ad3143f..7aabc5ee457c0 100644
--- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -108,6 +108,10 @@ bool computeTopologicalSorting(
/// stable.
SetVector<Block *> getBlocksSortedByDominance(Region &region);
+/// Sorts all operation in `toSort` topologically while also region semantics.
+/// Does not support multi-sets.
+SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
+
} // end namespace mlir
#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 37dae769dbc6d..2b1cf411ceeee 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -165,80 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
return topologicalSort(slice);
}
-/// Computes the common ancestor region of all operations in `ops`. Remembers
-/// all the traversed regions in `traversedRegions`.
-static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
- DenseSet<Region *> &traversedRegions) {
- // Map to count the number of times a region was encountered.
- llvm::DenseMap<Region *, size_t> regionCounts;
- size_t expectedCount = ops.size();
-
- // Walk the region tree for each operation towards the root and add to the
- // region count.
- Region *res = nullptr;
- for (Operation *op : ops) {
- Region *current = op->getParentRegion();
- while (current) {
- // Insert or get the count.
- auto it = regionCounts.try_emplace(current, 0).first;
- size_t count = ++it->getSecond();
- if (count == expectedCount) {
- res = current;
- break;
- }
- current = current->getParentRegion();
- }
- }
- auto firstRange = llvm::make_first_range(regionCounts);
- traversedRegions.insert(firstRange.begin(), firstRange.end());
- return res;
-}
-
-/// Topologically traverses `region` and insers all encountered operations in
-/// `toSort` into the result. Recursively traverses regions when they are
-/// present in `relevantRegions`.
-static void topoSortRegion(Region &region,
- const DenseSet<Region *> &relevantRegions,
- const SetVector<Operation *> &toSort,
- SetVector<Operation *> &result) {
- SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
- for (Block *block : sortedBlocks) {
- for (Operation &op : *block) {
- if (toSort.contains(&op))
- result.insert(&op);
- for (Region &subRegion : op.getRegions()) {
- // Skip regions that do not contain operations from `toSort`.
-        if (!relevantRegions.contains(&region))
- continue;
- topoSortRegion(subRegion, relevantRegions, toSort, result);
- }
- }
- }
-}
-
-SetVector<Operation *>
-mlir::topologicalSort(const SetVector<Operation *> &toSort) {
- if (toSort.size() <= 1)
- return toSort;
-
- assert(llvm::all_of(toSort,
- [&](Operation *op) { return toSort.count(op) == 1; }) &&
- "expected only unique set entries");
-
- // First, find the root region to start the recursive traversal through the
- // IR.
- DenseSet<Region *> relevantRegions;
- Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
- assert(rootRegion && "expected all ops to have a common ancestor");
-
- // Sort all element in `toSort` by recursively traversing the IR.
- SetVector<Operation *> result;
- topoSortRegion(*rootRegion, relevantRegions, toSort, result);
- assert(result.size() == toSort.size() &&
- "expected all operations to be present in the result");
- return result;
-}
-
/// Returns true if `value` (transitively) depends on iteration-carried values
/// of the given `ancestorOp`.
static bool dependsOnCarriedVals(Value value,
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index f5e16e9a91fe5..2fc1bd582ef47 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -167,3 +167,77 @@ SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
return blocks;
}
+
+/// Computes the common ancestor region of all operations in `ops`. Remembers
+/// all the traversed regions in `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+ DenseSet<Region *> &traversedRegions) {
+ // Map to count the number of times a region was encountered.
+ llvm::DenseMap<Region *, size_t> regionCounts;
+ size_t expectedCount = ops.size();
+
+ // Walk the region tree for each operation towards the root and add to the
+ // region count.
+ Region *res = nullptr;
+ for (Operation *op : ops) {
+ Region *current = op->getParentRegion();
+ while (current) {
+ // Insert or get the count.
+ auto it = regionCounts.try_emplace(current, 0).first;
+ size_t count = ++it->getSecond();
+ if (count == expectedCount) {
+ res = current;
+ break;
+ }
+ current = current->getParentRegion();
+ }
+ }
+ auto firstRange = llvm::make_first_range(regionCounts);
+ traversedRegions.insert(firstRange.begin(), firstRange.end());
+ return res;
+}
+
+/// Topologically traverses `region` and insers all encountered operations in
+/// `toSort` into the result. Recursively traverses regions when they are
+/// present in `relevantRegions`.
+static void topoSortRegion(Region &region,
+ const DenseSet<Region *> &relevantRegions,
+ const SetVector<Operation *> &toSort,
+ SetVector<Operation *> &result) {
+ SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
+ for (Block *block : sortedBlocks) {
+ for (Operation &op : *block) {
+ if (toSort.contains(&op))
+ result.insert(&op);
+ for (Region &subRegion : op.getRegions()) {
+ // Skip regions that do not contain operations from `toSort`.
+        if (!relevantRegions.contains(&region))
+ continue;
+ topoSortRegion(subRegion, relevantRegions, toSort, result);
+ }
+ }
+ }
+}
+
+SetVector<Operation *>
+mlir::topologicalSort(const SetVector<Operation *> &toSort) {
+ if (toSort.size() <= 1)
+ return toSort;
+
+ assert(llvm::all_of(toSort,
+ [&](Operation *op) { return toSort.count(op) == 1; }) &&
+ "expected only unique set entries");
+
+ // First, find the root region to start the recursive traversal through the
+ // IR.
+ DenseSet<Region *> relevantRegions;
+ Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+ assert(rootRegion && "expected all ops to have a common ancestor");
+
+ // Sort all element in `toSort` by recursively traversing the IR.
+ SetVector<Operation *> result;
+ topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+ assert(result.size() == toSort.size() &&
+ "expected all operations to be present in the result");
+ return result;
+}
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 332f0a2eecfcf..4496c2bc5fe8b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -15,6 +15,7 @@
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 84ae4b52dcf4e..7f3e43d0b4cd3 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 67cbade07bc94..39f7256fb789d 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/SROA.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
index 06c41d8c4a110..fc367c07ad863 100644
--- a/mlir/test/lib/Analysis/TestSlice.cpp
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
>From 9ea5fb929db8c4566abca69f5940a30b13de148d Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Tue, 21 May 2024 09:01:26 +0000
Subject: [PATCH 5/5] address nit comments
---
.../mlir/Analysis/TopologicalSortUtils.h | 6 ++--
mlir/lib/Analysis/TopologicalSortUtils.cpp | 36 ++++++++------------
mlir/test/lib/Analysis/TestSlice.cpp | 3 +-
3 files changed, 20 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
index 7aabc5ee457c0..ee98cd8cb380e 100644
--- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -104,12 +104,12 @@ bool computeTopologicalSorting(
MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
-/// Get a list of blocks that is sorted according to dominance. This sort is
+/// Gets a list of blocks that is sorted according to dominance. This sort is
/// stable.
SetVector<Block *> getBlocksSortedByDominance(Region &region);
-/// Sorts all operation in `toSort` topologically while also region semantics.
-/// Does not support multi-sets.
+/// Sorts all operations in `toSort` topologically while also considering region
+/// semantics. Does not support multi-sets.
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
} // end namespace mlir
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index 2fc1bd582ef47..94c403c07385b 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -168,12 +168,12 @@ SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
return blocks;
}
-/// Computes the common ancestor region of all operations in `ops`. Remembers
-/// all the traversed regions in `traversedRegions`.
-static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
-                                      DenseSet<Region *> &traversedRegions) {
+/// Computes the closest common ancestor region of all operations in `ops`.
+/// Remembers all the traversed regions in `traversedRegions`.
+static Region *findCommonAncestorRegion(const SetVector<Operation *> &ops,
+                                        DenseSet<Region *> &traversedRegions) {
// Map to count the number of times a region was encountered.
-  llvm::DenseMap<Region *, size_t> regionCounts;
+  DenseMap<Region *, size_t> regionCounts;
size_t expectedCount = ops.size();
// Walk the region tree for each operation towards the root and add to the
@@ -182,10 +182,8 @@ static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
for (Operation *op : ops) {
Region *current = op->getParentRegion();
while (current) {
-      // Insert or get the count.
-      auto it = regionCounts.try_emplace(current, 0).first;
-      size_t count = ++it->getSecond();
-      if (count == expectedCount) {
+      // Insert or update the count and compare it.
+      if (++regionCounts[current] == expectedCount) {
res = current;
break;
}
@@ -197,11 +195,11 @@ static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
return res;
}
-/// Topologically traverses `region` and insers all encountered operations in
+/// Topologically traverses `region` and inserts all encountered operations in
/// `toSort` into the result. Recursively traverses regions when they are
-/// present in `relevantRegions`.
+/// present in `ancestorRegions`.
static void topoSortRegion(Region &region,
-                           const DenseSet<Region *> &relevantRegions,
+                           const DenseSet<Region *> &ancestorRegions,
const SetVector<Operation *> &toSort,
SetVector<Operation *> &result) {
SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
@@ -211,9 +209,9 @@ static void topoSortRegion(Region &region,
result.insert(&op);
for (Region &subRegion : op.getRegions()) {
// Skip regions that do not contain operations from `toSort`.
-        if (!relevantRegions.contains(&region))
+        if (!ancestorRegions.contains(&subRegion))
continue;
-        topoSortRegion(subRegion, relevantRegions, toSort, result);
+        topoSortRegion(subRegion, ancestorRegions, toSort, result);
}
}
}
@@ -224,19 +222,15 @@ mlir::topologicalSort(const SetVector<Operation *> &toSort) {
if (toSort.size() <= 1)
return toSort;
-  assert(llvm::all_of(toSort,
-                      [&](Operation *op) { return toSort.count(op) == 1; }) &&
-         "expected only unique set entries");
-
// First, find the root region to start the recursive traversal through the
// IR.
-  DenseSet<Region *> relevantRegions;
-  Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+  DenseSet<Region *> ancestorRegions;
+  Region *rootRegion = findCommonAncestorRegion(toSort, ancestorRegions);
assert(rootRegion && "expected all ops to have a common ancestor");
// Sort all element in `toSort` by recursively traversing the IR.
SetVector<Operation *> result;
-  topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+  topoSortRegion(*rootRegion, ancestorRegions, toSort, result);
assert(result.size() == toSort.size() &&
"expected all operations to be present in the result");
return result;
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
index fc367c07ad863..7e8320dbf3ec3 100644
--- a/mlir/test/lib/Analysis/TestSlice.cpp
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -25,7 +25,8 @@ struct TestTopologicalSortPass
StringRef getArgument() const final { return "test-print-topological-sort"; }
StringRef getDescription() const final {
- return "Print operations in topological order";
+ return "Sorts operations topologically and attaches attributes with their "
+ "corresponding index in the ordering to them";
}
void runOnOperation() override {
SetVector<Operation *> toSort;
More information about the Mlir-commits
mailing list