[Mlir-commits] [llvm] [mlir] [MLIR][Analysis] Consolidate topological sort utilities (PR #92563)
Christian Ulmann
llvmlistbot at llvm.org
Fri May 17 08:45:45 PDT 2024
https://github.com/Dinistro created https://github.com/llvm/llvm-project/pull/92563
This PR attempts to consolidate the different topological sort utilities into one place. It moves them into the Analysis folder because `SliceAnalysis` uses some of them.
There are now two different sorting strategies:
1. Sort only according to SSA use-def chains.
2. Sort while taking regions into account. This requires a much more elaborate traversal and cannot easily be applied to graph regions.
This additionally reimplements the region-aware topological sort, because the previous implementation had exponential space complexity.
I'm open to suggestions on how to consolidate this further or how to merge the test passes.
>From 616b6c29d7163897b474a6ce83b068440afa5fee Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 11:16:56 +0000
Subject: [PATCH 1/4] first impl of the new topo sort
---
mlir/include/mlir/Analysis/SliceAnalysis.h | 1 +
mlir/lib/Analysis/SliceAnalysis.cpp | 126 +++++++++-----
mlir/test/Analysis/test-topoligical-sort.mlir | 53 ++++--
mlir/test/Dialect/Affine/slicing-utils.mlir | 160 +++++++++---------
mlir/test/lib/Analysis/TestSlice.cpp | 28 ++-
5 files changed, 211 insertions(+), 157 deletions(-)
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index d5cdf72c3889f..19571fc1946be 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -226,6 +226,7 @@ getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operation in the `toSort` SetVector.
/// Returns a topologically sorted SetVector.
+/// Does not support multi-sets.
SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
/// Utility to match a generic reduction given a list of iteration-carried
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 26fe8e3dc0819..f93183749dfd2 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -11,10 +11,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -164,60 +167,95 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
return topologicalSort(slice);
}
-namespace {
-/// DFS post-order implementation that maintains a global count to work across
-/// multiple invocations, to help implement topological sort on multi-root DAGs.
-/// We traverse all operations but only record the ones that appear in
-/// `toSort` for the final result.
-struct DFSState {
- DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
- const SetVector<Operation *> &toSort;
- SmallVector<Operation *, 16> topologicalCounts;
- DenseSet<Operation *> seen;
-};
-} // namespace
-
-static void dfsPostorder(Operation *root, DFSState *state) {
- SmallVector<Operation *> queue(1, root);
- std::vector<Operation *> ops;
- while (!queue.empty()) {
- Operation *current = queue.pop_back_val();
- ops.push_back(current);
- for (Operation *op : current->getUsers())
- queue.push_back(op);
- for (Region &region : current->getRegions()) {
- for (Operation &op : region.getOps())
- queue.push_back(&op);
+/// Orders the blocks of `region` by RPO traversals. TODO: deduplicate.
+static SetVector<Block *> getTopologicallySortedBlocks(Region &region) {
+ // For each block that has not been visited yet (i.e. that has no
+ // predecessors), add it to the list as well as its successors.
+ SetVector<Block *> blocks;
+ for (Block &b : region) {
+ if (blocks.count(&b) == 0) {
+ llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+ blocks.insert(traversal.begin(), traversal.end());
}
}
+ assert(blocks.size() == region.getBlocks().size() &&
+ "some blocks are not sorted");
- for (Operation *op : llvm::reverse(ops)) {
- if (state->seen.insert(op).second && state->toSort.count(op) > 0)
- state->topologicalCounts.push_back(op);
+ return blocks;
+}
+
+/// Computes the common ancestor region of all operations in `ops`. Remembers
+/// all the traversed regions in `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+ DenseSet<Region *> &traversedRegions) {
+ // Map to count the number of times a region was encountered.
+ llvm::DenseMap<Region *, size_t> regionCounts;
+ size_t expectedCount = ops.size();
+
+ // Walk the region tree for each operation towards the root and add to the
+ // region count.
+ Region *res = nullptr;
+ for (Operation *op : ops) {
+ Region *current = op->getParentRegion();
+ while (current) {
+ // Insert or get the count.
+ auto it = regionCounts.try_emplace(current, 0).first;
+ size_t count = ++it->getSecond();
+ if (count == expectedCount) {
+ res = current;
+ break;
+ }
+ current = current->getParentRegion();
+ }
+ }
+ auto firstRange = llvm::make_first_range(regionCounts);
+ traversedRegions.insert(firstRange.begin(), firstRange.end());
+ return res;
+}
+
+/// Topologically traverses `region` and inserts all encountered operations in
+/// `toSort` into the result. Recursively traverses regions when they are
+/// present in `relevantRegions`.
+static void topoSortRegion(Region &region,
+ const DenseSet<Region *> &relevantRegions,
+ const SetVector<Operation *> &toSort,
+ SetVector<Operation *> &result) {
+ SetVector<Block *> sortedBlocks = getTopologicallySortedBlocks(region);
+ for (Block *block : sortedBlocks) {
+ for (Operation &op : *block) {
+ if (toSort.contains(&op))
+ result.insert(&op);
+ for (Region &subRegion : op.getRegions()) {
+ // Skip regions that do not contain operations from `toSort`.
+ if (!relevantRegions.contains(&subRegion))
+ continue;
+ topoSortRegion(subRegion, relevantRegions, toSort, result);
+ }
+ }
}
}
SetVector<Operation *>
mlir::topologicalSort(const SetVector<Operation *> &toSort) {
- if (toSort.empty()) {
+ if (toSort.size() <= 1)
return toSort;
- }
- // Run from each root with global count and `seen` set.
- DFSState state(toSort);
- for (auto *s : toSort) {
- assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
- dfsPostorder(s, &state);
- }
-
- // Reorder and return.
- SetVector<Operation *> res;
- for (auto it = state.topologicalCounts.rbegin(),
- eit = state.topologicalCounts.rend();
- it != eit; ++it) {
- res.insert(*it);
- }
- return res;
+ assert(llvm::all_of(toSort,
+ [&](Operation *op) { return toSort.count(op) == 1; }) &&
+ "expected only unique set entries");
+
+ // First, find the common ancestor region of all ops; the recursive
+ // traversal through the IR starts there.
+ DenseSet<Region *> relevantRegions;
+ Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+ assert(rootRegion && "expected all ops to have a common ancestor");
+
+ // Sort all elements in `toSort` by recursively traversing the IR.
+ SetVector<Operation *> result;
+ topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+ assert(result.size() == toSort.size() &&
+ "expected all operations to be present in the result");
+ return result;
}
/// Returns true if `value` (transitively) depends on iteration-carried values
diff --git a/mlir/test/Analysis/test-topoligical-sort.mlir b/mlir/test/Analysis/test-topoligical-sort.mlir
index 8608586402055..150aff854fc8f 100644
--- a/mlir/test/Analysis/test-topoligical-sort.mlir
+++ b/mlir/test/Analysis/test-topoligical-sort.mlir
@@ -1,21 +1,38 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-print-topological-sort))" --split-input-file | FileCheck %s
-// CHECK-LABEL: Testing : region
-// CHECK: arith.addi {{.*}} : index
-// CHECK-NEXT: scf.for
-// CHECK: } {__test_sort_original_idx__ = 2 : i64}
-// CHECK-NEXT: arith.addi {{.*}} : i32
-// CHECK-NEXT: arith.subi {{.*}} : i32
-func.func @region(
- %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
- %arg4 : i32, %arg5 : i32, %arg6 : i32,
- %buffer : memref<i32>) {
- %0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
- %idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
- scf.for %arg7 = %idx to %arg2 step %arg3 {
- %2 = arith.addi %0, %arg5 : i32
- %3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
- memref.store %3, %buffer[] : memref<i32>
- } {__test_sort_original_idx__ = 2}
+// CHECK-LABEL: single_element
+func.func @single_element() {
+ // CHECK: test_sort_index = 0
+ return {test_to_sort}
+}
+
+// -----
+
+// CHECK-LABEL: @simple_region
+func.func @simple_region(%cond: i1) {
+ // CHECK: test_sort_index = 0
+ %0 = arith.constant {test_to_sort} 42 : i32
+ scf.if %cond {
+ %1 = arith.addi %0, %0 : i32
+ // CHECK: test_sort_index = 2
+ %2 = arith.subi %0, %1 {test_to_sort} : i32
+ // CHECK: test_sort_index = 1
+ } {test_to_sort}
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @multi_region
+func.func @multi_region(%cond: i1) {
+ scf.if %cond {
+ // CHECK: test_sort_index = 0
+ %0 = arith.constant {test_to_sort} 42 : i32
+ }
+
+ scf.if %cond {
+ // CHECK: test_sort_index = 1
+ %0 = arith.constant {test_to_sort} 24 : i32
+ }
return
}
diff --git a/mlir/test/Dialect/Affine/slicing-utils.mlir b/mlir/test/Dialect/Affine/slicing-utils.mlir
index 74379978fdf8c..0848a924b9d96 100644
--- a/mlir/test/Dialect/Affine/slicing-utils.mlir
+++ b/mlir/test/Dialect/Affine/slicing-utils.mlir
@@ -28,15 +28,15 @@ func.func @slicing_test() {
// BWD: matched: %[[v1:.*]] {{.*}} backward static slice:
//
// FWDBWD: matched: %[[v1:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%1 = "slicing-test-op" () : () -> i1
@@ -49,15 +49,15 @@ func.func @slicing_test() {
// BWD: matched: %[[v2:.*]] {{.*}} backward static slice:
//
// FWDBWD-NEXT: matched: %[[v2:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%2 = "slicing-test-op" () : () -> i2
@@ -69,15 +69,15 @@ func.func @slicing_test() {
// BWD: matched: %[[v3:.*]] {{.*}} backward static slice:
//
// FWDBWD-NEXT: matched: %[[v3:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%3 = "slicing-test-op" () : () -> i3
@@ -89,15 +89,15 @@ func.func @slicing_test() {
// BWD: matched: %[[v4:.*]] {{.*}} backward static slice:
//
// FWDBWD-NEXT: matched: %[[v4:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%4 = "slicing-test-op" () : () -> i4
@@ -111,15 +111,15 @@ func.func @slicing_test() {
// BWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
//
// FWDBWD-NEXT: matched: %[[v5:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%5 = "slicing-test-op" (%1, %2) : (i1, i2) -> i5
@@ -132,15 +132,15 @@ func.func @slicing_test() {
// BWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
//
// FWDBWD-NEXT: matched: %[[v6:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-NEXT: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%6 = "slicing-test-op" (%3, %4) : (i3, i4) -> i6
@@ -153,15 +153,15 @@ func.func @slicing_test() {
// BWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
//
// FWDBWD-NEXT: matched: %[[v7:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
// FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%7 = "slicing-test-op" (%1, %5) : (i1, i5) -> i7
@@ -177,15 +177,15 @@ func.func @slicing_test() {
// BWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
//
// FWDBWD-NEXT: matched: %[[v8:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%8 = "slicing-test-op" (%5, %6) : (i5, i6) -> i8
@@ -202,15 +202,15 @@ func.func @slicing_test() {
// BWD-NEXT: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
//
// FWDBWD-NEXT: matched: %[[v9:.*]] {{.*}} static slice:
- // FWDBWD-DAG: %[[v4:.*]] = "slicing-test-op"() : () -> i4
- // FWDBWD-DAG: %[[v3:.*]] = "slicing-test-op"() : () -> i3
- // FWDBWD-NEXT: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
- // FWDBWD-DAG: %[[v2:.*]] = "slicing-test-op"() : () -> i2
- // FWDBWD-DAG: %[[v1:.*]] = "slicing-test-op"() : () -> i1
- // FWDBWD-NEXT: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
- // FWDBWD-DAG: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
- // FWDBWD-DAG: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
- // FWDBWD-NEXT: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
+ // FWDBWD: %[[v1:.*]] = "slicing-test-op"() : () -> i1
+ // FWDBWD: %[[v2:.*]] = "slicing-test-op"() : () -> i2
+ // FWDBWD: %[[v3:.*]] = "slicing-test-op"() : () -> i3
+ // FWDBWD: %[[v4:.*]] = "slicing-test-op"() : () -> i4
+ // FWDBWD: %[[v5:.*]] = "slicing-test-op"(%[[v1]], %[[v2]]) : (i1, i2) -> i5
+ // FWDBWD: %[[v6:.*]] = "slicing-test-op"(%[[v3]], %[[v4]]) : (i3, i4) -> i6
+ // FWDBWD: %[[v7:.*]] = "slicing-test-op"(%[[v1]], %[[v5]]) : (i1, i5) -> i7
+ // FWDBWD: %[[v8:.*]] = "slicing-test-op"(%[[v5]], %[[v6]]) : (i5, i6) -> i8
+ // FWDBWD: %[[v9:.*]] = "slicing-test-op"(%[[v7]], %[[v8]]) : (i7, i8) -> i9
%9 = "slicing-test-op" (%7, %8) : (i7, i8) -> i9
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
index b445febde5971..06c41d8c4a110 100644
--- a/mlir/test/lib/Analysis/TestSlice.cpp
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -1,4 +1,4 @@
-//===------------- TestSlice.cpp - Test slice related analisis ------------===//
+//===- TestSlice.cpp - Test slice related analysis ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -7,12 +7,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
-static const StringLiteral kOrderMarker = "__test_sort_original_idx__";
+static const StringLiteral kToSortMark = "test_to_sort";
+static const StringLiteral kOrderIndex = "test_sort_index";
namespace {
@@ -26,20 +28,16 @@ struct TestTopologicalSortPass
return "Print operations in topological order";
}
void runOnOperation() override {
- std::map<int, Operation *> ops;
- getOperation().walk([&ops](Operation *op) {
- if (auto originalOrderAttr = op->getAttrOfType<IntegerAttr>(kOrderMarker))
- ops[originalOrderAttr.getInt()] = op;
+ SetVector<Operation *> toSort;
+ getOperation().walk([&](Operation *op) {
+ if (op->hasAttrOfType<UnitAttr>(kToSortMark))
+ toSort.insert(op);
});
- SetVector<Operation *> sortedOp;
- for (auto op : ops)
- sortedOp.insert(op.second);
- sortedOp = topologicalSort(sortedOp);
- llvm::errs() << "Testing : " << getOperation().getName() << "\n";
- for (Operation *op : sortedOp) {
- op->print(llvm::errs());
- llvm::errs() << "\n";
- }
+
+ auto i32Type = IntegerType::get(&getContext(), 32);
+ SetVector<Operation *> sortedOps = topologicalSort(toSort);
+ for (auto [index, op] : llvm::enumerate(sortedOps))
+ op->setAttr(kOrderIndex, IntegerAttr::get(i32Type, index));
}
};
>From de8f97bddb7e4077fa23dace44f10ebce6aa6f02 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 11:42:08 +0000
Subject: [PATCH 2/4] move topo utils to analysis
---
.../mlir/{Transforms => Analysis}/TopologicalSortUtils.h | 6 +++---
mlir/lib/Analysis/CMakeLists.txt | 2 ++
.../{Transforms/Utils => Analysis}/TopologicalSortUtils.cpp | 4 ++--
mlir/lib/Transforms/TopologicalSort.cpp | 2 +-
mlir/lib/Transforms/Utils/CMakeLists.txt | 1 -
mlir/lib/Transforms/Utils/RegionUtils.cpp | 2 +-
mlir/lib/Transforms/ViewOpGraph.cpp | 2 +-
mlir/test/{Transforms => Analysis}/test-toposort.mlir | 0
mlir/test/lib/Analysis/CMakeLists.txt | 1 +
.../lib/{Transforms => Analysis}/TestTopologicalSort.cpp | 2 +-
mlir/test/lib/Transforms/CMakeLists.txt | 1 -
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 -
12 files changed, 12 insertions(+), 12 deletions(-)
rename mlir/include/mlir/{Transforms => Analysis}/TopologicalSortUtils.h (95%)
rename mlir/lib/{Transforms/Utils => Analysis}/TopologicalSortUtils.cpp (97%)
rename mlir/test/{Transforms => Analysis}/test-toposort.mlir (100%)
rename mlir/test/lib/{Transforms => Analysis}/TestTopologicalSort.cpp (98%)
diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
similarity index 95%
rename from mlir/include/mlir/Transforms/TopologicalSortUtils.h
rename to mlir/include/mlir/Analysis/TopologicalSortUtils.h
index 74e44b1dc485d..fb9441db119fd 100644
--- a/mlir/include/mlir/Transforms/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
-#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#ifndef MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
+#define MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
#include "mlir/IR/Block.h"
@@ -106,4 +106,4 @@ bool computeTopologicalSorting(
} // end namespace mlir
-#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 005814ddbec79..38d8415d81c72 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
+ TopologicalSortUtils.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
@@ -28,6 +29,7 @@ add_mlir_library(MLIRAnalysis
Liveness.cpp
CFGLoopInfo.cpp
SliceAnalysis.cpp
+ TopologicalSortUtils.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
similarity index 97%
rename from mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
rename to mlir/lib/Analysis/TopologicalSortUtils.cpp
index f3a9d217f2c98..4281beacee89e 100644
--- a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -1,4 +1,4 @@
-//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===//
+//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Transforms/TopologicalSortUtils.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/OpDefinition.h"
using namespace mlir;
diff --git a/mlir/lib/Transforms/TopologicalSort.cpp b/mlir/lib/Transforms/TopologicalSort.cpp
index 1219968fb3692..528f6ef676020 100644
--- a/mlir/lib/Transforms/TopologicalSort.cpp
+++ b/mlir/lib/Transforms/TopologicalSort.cpp
@@ -8,8 +8,8 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/RegionKindInterface.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
namespace mlir {
#define GEN_PASS_DEF_TOPOLOGICALSORT
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index d6aac0e2da4f5..b5788c679edc4 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -10,7 +10,6 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
OneToNTypeConversion.cpp
RegionUtils.cpp
- TopologicalSortUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 192f59b353295..b6a6dea5fe9a0 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/RegionUtils.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
@@ -15,7 +16,6 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index c2eb2b893cea4..b3c0a06c96fea 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -8,12 +8,12 @@
#include "mlir/Transforms/ViewOpGraph.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Analysis/test-toposort.mlir
similarity index 100%
rename from mlir/test/Transforms/test-toposort.mlir
rename to mlir/test/Analysis/test-toposort.mlir
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index d168888c1e71e..7c6b31ae8b73e 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTestAnalysis
TestMemRefDependenceCheck.cpp
TestMemRefStrideCalculation.cpp
TestSlice.cpp
+ TestTopologicalSort.cpp
DataFlow/TestDeadCodeAnalysis.cpp
DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Analysis/TestTopologicalSort.cpp
similarity index 98%
rename from mlir/test/lib/Transforms/TestTopologicalSort.cpp
rename to mlir/test/lib/Analysis/TestTopologicalSort.cpp
index 3b110c7126200..c7e0206b2a4d7 100644
--- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp
+++ b/mlir/test/lib/Analysis/TestTopologicalSort.cpp
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/TopologicalSortUtils.h"
using namespace mlir;
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index a849b7ebd29e2..975a41ac3d5fe 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -26,7 +26,6 @@ add_mlir_library(MLIRTestTransforms
TestInlining.cpp
TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
- TestTopologicalSort.cpp
${MLIRTestTransformsPDLSrc}
EXCLUDE_FROM_LIBMLIR
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index f7495a202669c..566be0c730d79 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7579,7 +7579,6 @@ cc_library(
"include/mlir/Transforms/LoopInvariantCodeMotionUtils.h",
"include/mlir/Transforms/OneToNTypeConversion.h",
"include/mlir/Transforms/RegionUtils.h",
- "include/mlir/Transforms/TopologicalSortUtils.h",
],
includes = ["include"],
deps = [
>From 1c5e0eb214afb3f6339e562d0d64780cd5c07ae9 Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 12:02:50 +0000
Subject: [PATCH 3/4] migrate block topo sort
---
.../mlir/Analysis/TopologicalSortUtils.h | 4 ++++
mlir/include/mlir/Transforms/RegionUtils.h | 4 ----
mlir/lib/Analysis/SliceAnalysis.cpp | 23 ++-----------------
mlir/lib/Analysis/TopologicalSortUtils.cpp | 21 +++++++++++++++++
.../OpenACC/OpenACCToLLVMIRTranslation.cpp | 2 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 1 +
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 2 +-
mlir/lib/Transforms/Mem2Reg.cpp | 2 +-
mlir/lib/Transforms/Utils/RegionUtils.cpp | 17 --------------
9 files changed, 31 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
index fb9441db119fd..c2bc15ad3143f 100644
--- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -104,6 +104,10 @@ bool computeTopologicalSorting(
MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
+/// Get a list of blocks that is sorted according to dominance. This sort is
+/// stable.
+SetVector<Block *> getBlocksSortedByDominance(Region &region);
+
} // end namespace mlir
#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index f65d0d44eef42..06eebff201d1b 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -87,10 +87,6 @@ LogicalResult eraseUnreachableBlocks(RewriterBase &rewriter,
LogicalResult runRegionDCE(RewriterBase &rewriter,
MutableArrayRef<Region> regions);
-/// Get a list of blocks that is sorted according to dominance. This sort is
-/// stable.
-SetVector<Block *> getBlocksSortedByDominance(Region &region);
-
} // namespace mlir
#endif // MLIR_TRANSFORMS_REGIONUTILS_H_
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index f93183749dfd2..37dae769dbc6d 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -11,13 +11,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
-#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -167,23 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
return topologicalSort(slice);
}
-/// TODO: deduplicate
-static SetVector<Block *> getTopologicallySortedBlocks(Region &region) {
- // For each block that has not been visited yet (i.e. that has no
- // predecessors), add it to the list as well as its successors.
- SetVector<Block *> blocks;
- for (Block &b : region) {
- if (blocks.count(&b) == 0) {
- llvm::ReversePostOrderTraversal<Block *> traversal(&b);
- blocks.insert(traversal.begin(), traversal.end());
- }
- }
- assert(blocks.size() == region.getBlocks().size() &&
- "some blocks are not sorted");
-
- return blocks;
-}
-
/// Computes the common ancestor region of all operations in `ops`. Remembers
/// all the traversed regions in `traversedRegions`.
static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
@@ -220,7 +201,7 @@ static void topoSortRegion(Region &region,
const DenseSet<Region *> &relevantRegions,
const SetVector<Operation *> &toSort,
SetVector<Operation *> &result) {
- SetVector<Block *> sortedBlocks = getTopologicallySortedBlocks(region);
+ SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
for (Block *block : sortedBlocks) {
for (Operation &op : *block) {
if (toSort.contains(&op))
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index 4281beacee89e..f5e16e9a91fe5 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -7,7 +7,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/TopologicalSortUtils.h"
+#include "mlir/IR/Block.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/RegionGraphTraits.h"
+
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
using namespace mlir;
@@ -146,3 +151,19 @@ bool mlir::computeTopologicalSorting(
return allOpsScheduled;
}
+
+SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
+  // Seed a reverse post-order walk from each block not yet collected; blocks
+  // that a later walk revisits keep their first position (SetVector dedups).
+  SetVector<Block *> blocks;
+  for (Block &b : region) {
+    if (blocks.count(&b) == 0) {
+      llvm::ReversePostOrderTraversal<Block *> traversal(&b);
+      blocks.insert(traversal.begin(), traversal.end());
+    }
+  }
+  assert(blocks.size() == region.getBlocks().size() &&
+         "some blocks are not sorted");
+
+  return blocks;
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index eeda245ce969f..d9cf85e4aecab 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/BuiltinOps.h"
@@ -19,7 +20,6 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 34b6903f8da07..9d125b7f11809 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index cf3257c8b9b87..1ec0736ec08bf 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -16,6 +16,7 @@
#include "AttrKindDetail.h"
#include "DebugTranslation.h"
#include "LoopAnnotationTranslation.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
@@ -33,7 +34,6 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index e2e240ad865ce..a452cc3fae8ac 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/Mem2Reg.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
@@ -16,7 +17,6 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
-#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index b6a6dea5fe9a0..b5e641d39fc0a 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -19,7 +19,6 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/SmallSet.h"
#include <deque>
@@ -836,19 +835,3 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
mergedIdenticalBlocks);
}
-
-SetVector<Block *> mlir::getBlocksSortedByDominance(Region &region) {
- // For each block that has not been visited yet (i.e. that has no
- // predecessors), add it to the list as well as its successors.
- SetVector<Block *> blocks;
- for (Block &b : region) {
- if (blocks.count(&b) == 0) {
- llvm::ReversePostOrderTraversal<Block *> traversal(&b);
- blocks.insert(traversal.begin(), traversal.end());
- }
- }
- assert(blocks.size() == region.getBlocks().size() &&
- "some blocks are not sorted");
-
- return blocks;
-}
>From 8c6a17c1eb6782c5db75012e3d0858723ac38b6f Mon Sep 17 00:00:00 2001
From: Christian Ulmann <christian.ulmann at nextsilicon.com>
Date: Fri, 17 May 2024 13:56:09 +0000
Subject: [PATCH 4/4] move all topo sorts into the utils file
---
mlir/include/mlir/Analysis/SliceAnalysis.h | 6 --
.../mlir/Analysis/TopologicalSortUtils.h | 4 +
mlir/lib/Analysis/SliceAnalysis.cpp | 74 -------------------
mlir/lib/Analysis/TopologicalSortUtils.cpp | 74 +++++++++++++++++++
.../Conversion/VectorToGPU/VectorToGPU.cpp | 1 +
.../Dialect/Affine/Utils/LoopFusionUtils.cpp | 1 +
mlir/lib/Transforms/SROA.cpp | 1 +
mlir/test/lib/Analysis/TestSlice.cpp | 2 +-
8 files changed, 82 insertions(+), 81 deletions(-)
diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 19571fc1946be..99279fdfe427c 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -223,12 +223,6 @@ SetVector<Operation *>
getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions = {},
const ForwardSliceOptions &forwardSliceOptions = {});
-/// Multi-root DAG topological sort.
-/// Performs a topological sort of the Operation in the `toSort` SetVector.
-/// Returns a topologically sorted SetVector.
-/// Does not support multi-sets.
-SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
-
/// Utility to match a generic reduction given a list of iteration-carried
/// arguments, `iterCarriedArgs` and the position of the potential reduction
/// argument within the list, `redPos`. If a reduction is matched, returns the
diff --git a/mlir/include/mlir/Analysis/TopologicalSortUtils.h b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
index c2bc15ad3143f..7aabc5ee457c0 100644
--- a/mlir/include/mlir/Analysis/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Analysis/TopologicalSortUtils.h
@@ -108,6 +108,10 @@ bool computeTopologicalSorting(
/// stable.
SetVector<Block *> getBlocksSortedByDominance(Region &region);
+/// Sorts the operations in `toSort` topologically, taking regions into account.
+/// Does not support multi-sets.
+SetVector<Operation *> topologicalSort(const SetVector<Operation *> &toSort);
+
} // end namespace mlir
#endif // MLIR_ANALYSIS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 37dae769dbc6d..2b1cf411ceeee 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -165,80 +165,6 @@ mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions,
return topologicalSort(slice);
}
-/// Computes the common ancestor region of all operations in `ops`. Remembers
-/// all the traversed regions in `traversedRegions`.
-static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
- DenseSet<Region *> &traversedRegions) {
- // Map to count the number of times a region was encountered.
- llvm::DenseMap<Region *, size_t> regionCounts;
- size_t expectedCount = ops.size();
-
- // Walk the region tree for each operation towards the root and add to the
- // region count.
- Region *res = nullptr;
- for (Operation *op : ops) {
- Region *current = op->getParentRegion();
- while (current) {
- // Insert or get the count.
- auto it = regionCounts.try_emplace(current, 0).first;
- size_t count = ++it->getSecond();
- if (count == expectedCount) {
- res = current;
- break;
- }
- current = current->getParentRegion();
- }
- }
- auto firstRange = llvm::make_first_range(regionCounts);
- traversedRegions.insert(firstRange.begin(), firstRange.end());
- return res;
-}
-
-/// Topologically traverses `region` and insers all encountered operations in
-/// `toSort` into the result. Recursively traverses regions when they are
-/// present in `relevantRegions`.
-static void topoSortRegion(Region &region,
- const DenseSet<Region *> &relevantRegions,
- const SetVector<Operation *> &toSort,
- SetVector<Operation *> &result) {
- SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
- for (Block *block : sortedBlocks) {
- for (Operation &op : *block) {
- if (toSort.contains(&op))
- result.insert(&op);
- for (Region &subRegion : op.getRegions()) {
- // Skip regions that do not contain operations from `toSort`.
-      if (!relevantRegions.contains(&region))
- continue;
- topoSortRegion(subRegion, relevantRegions, toSort, result);
- }
- }
- }
-}
-
-SetVector<Operation *>
-mlir::topologicalSort(const SetVector<Operation *> &toSort) {
- if (toSort.size() <= 1)
- return toSort;
-
- assert(llvm::all_of(toSort,
- [&](Operation *op) { return toSort.count(op) == 1; }) &&
- "expected only unique set entries");
-
- // First, find the root region to start the recursive traversal through the
- // IR.
- DenseSet<Region *> relevantRegions;
- Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
- assert(rootRegion && "expected all ops to have a common ancestor");
-
- // Sort all element in `toSort` by recursively traversing the IR.
- SetVector<Operation *> result;
- topoSortRegion(*rootRegion, relevantRegions, toSort, result);
- assert(result.size() == toSort.size() &&
- "expected all operations to be present in the result");
- return result;
-}
-
/// Returns true if `value` (transitively) depends on iteration-carried values
/// of the given `ancestorOp`.
static bool dependsOnCarriedVals(Value value,
diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp
index f5e16e9a91fe5..2fc1bd582ef47 100644
--- a/mlir/lib/Analysis/TopologicalSortUtils.cpp
+++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp
@@ -167,3 +167,77 @@ SetVector<Block *> mlir::getBlocksSortedByDominance(Region ®ion) {
return blocks;
}
+
+/// Returns the innermost region that (transitively) contains every operation
+/// in `ops`, or null if no such region exists. Every region visited while
+/// walking up from the operations is recorded in `traversedRegions`.
+static Region *findCommonParentRegion(const SetVector<Operation *> &ops,
+                                      DenseSet<Region *> &traversedRegions) {
+  // Number of operations from `ops` found (transitively) inside each region.
+  llvm::DenseMap<Region *, size_t> hits;
+  const size_t numOps = ops.size();
+  Region *common = nullptr;
+  for (Operation *op : ops) {
+    // Climb from the op's immediate parent region up to the root.
+    for (Region *region = op->getParentRegion(); region;
+         region = region->getParentRegion()) {
+      size_t &seen = hits[region];
+      // Only the last op can complete a count; walking upwards, the first
+      // region it completes is therefore the closest shared ancestor.
+      if (++seen != numOps)
+        continue;
+      common = region;
+      break;
+    }
+  }
+
+  // Report every region touched during the walks.
+  for (const auto &entry : hits)
+    traversedRegions.insert(entry.first);
+  return common;
+}
+
+/// Walks the blocks of `region` in dominance order and inserts every visited
+/// operation that is in `toSort` into `result`. Only descends into nested
+/// regions that are present in `relevantRegions`.
+static void topoSortRegion(Region &region,
+                           const DenseSet<Region *> &relevantRegions,
+                           const SetVector<Operation *> &toSort,
+                           SetVector<Operation *> &result) {
+  SetVector<Block *> sortedBlocks = getBlocksSortedByDominance(region);
+  for (Block *block : sortedBlocks) {
+    for (Operation &op : *block) {
+      if (toSort.contains(&op))
+        result.insert(&op);
+      for (Region &subRegion : op.getRegions()) {
+        // Skip regions that do not contain operations from `toSort`.
+        if (!relevantRegions.contains(&subRegion))
+          continue;
+        topoSortRegion(subRegion, relevantRegions, toSort, result);
+      }
+    }
+  }
+}
+
+SetVector<Operation *>
+mlir::topologicalSort(const SetVector<Operation *> &toSort) {
+  // Zero or one operation is already sorted.
+  if (toSort.size() <= 1)
+    return toSort;
+
+  // Entries are unique by construction (`toSort` is a SetVector), so no
+  // duplicate check is needed; each one must appear exactly once in `result`.
+
+  // Find the innermost region containing all operations; the recursive
+  // traversal of the IR starts from there.
+  DenseSet<Region *> relevantRegions;
+  Region *rootRegion = findCommonParentRegion(toSort, relevantRegions);
+  assert(rootRegion && "expected all ops to have a common ancestor");
+
+  // Sort all elements in `toSort` by recursively traversing the IR.
+  SetVector<Operation *> result;
+  topoSortRegion(*rootRegion, relevantRegions, toSort, result);
+  assert(result.size() == toSort.size() &&
+         "expected all operations to be present in the result");
+  return result;
+}
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 332f0a2eecfcf..4496c2bc5fe8b 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -15,6 +15,7 @@
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 84ae4b52dcf4e..7f3e43d0b4cd3 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 67cbade07bc94..39f7256fb789d 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -9,6 +9,7 @@
#include "mlir/Transforms/SROA.h"
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
index 06c41d8c4a110..fc367c07ad863 100644
--- a/mlir/test/lib/Analysis/TestSlice.cpp
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
More information about the Mlir-commits
mailing list