[Mlir-commits] [mlir] c8457eb - [mlir][transforms] Add a topological sort utility and pass
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 16 13:47:35 PDT 2022
Author: Mogball
Date: 2022-05-16T20:47:30Z
New Revision: c8457eb5323ca99361c1748a22a16edb3160ae5f
URL: https://github.com/llvm/llvm-project/commit/c8457eb5323ca99361c1748a22a16edb3160ae5f
DIFF: https://github.com/llvm/llvm-project/commit/c8457eb5323ca99361c1748a22a16edb3160ae5f.diff
LOG: [mlir][transforms] Add a topological sort utility and pass
This patch adds a topological sort utility and pass. A topological sort reorders
the operations in a block without SSA dominance such that, as much as possible,
users of values come after their producers.
The utility function sorts topologically the operation range in a given block
with an optional user-provided callback that can be used to virtually break cycles.
The toposort pass itself recursively sorts graph regions under the target op.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D125063
Added:
mlir/include/mlir/Transforms/TopologicalSortUtils.h
mlir/lib/Transforms/TopologicalSort.cpp
mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
mlir/test/Transforms/test-toposort.mlir
Modified:
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/Utils/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index a74f65a81583f..f6e1feeaf05ef 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -92,6 +92,11 @@ std::unique_ptr<Pass> createSymbolDCEPass();
std::unique_ptr<Pass>
createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
+/// Creates a pass that recursively sorts nested regions without SSA dominance
+/// topologically such that, as much as possible, users of values appear after
+/// their producers.
+std::unique_ptr<Pass> createTopologicalSortPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index ee802269b3034..7ac71ceeb588a 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -260,4 +260,21 @@ def ViewOpGraph : Pass<"view-op-graph"> {
let constructor = "mlir::createPrintOpGraphPass()";
}
+def TopologicalSort : Pass<"topological-sort"> {
+ let summary = "Sort regions without SSA dominance in topological order";
+ let description = [{
+ Recursively sorts all nested regions without SSA dominance in topological
+ order. The main purpose is readability, as well as potentially faster
+ processing of certain transformations and analyses. The function sorts the
+ operations in all nested regions such that, as much as possible, all users
+ appear after their producers.
+
+ This sort is stable. If the block is already topologically sorted, the IR
+ is not changed. Operations that form a cycle are moved to the end of the
+ regions in a stable order.
+ }];
+
+ let constructor = "mlir::createTopologicalSortPass()";
+}
+
#endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Transforms/TopologicalSortUtils.h
new file mode 100644
index 0000000000000..828726df00a53
--- /dev/null
+++ b/mlir/include/mlir/Transforms/TopologicalSortUtils.h
@@ -0,0 +1,100 @@
+//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+
+#include "mlir/IR/Block.h"
+
+namespace mlir {
+
+/// Given a block, sort a range of its operations in topological order.
+/// The main purpose is readability of graph regions, potentially faster
+/// processing of certain transformations and analyses, or fixing the SSA
+/// dominance of blocks that require it after transformations. The function
+/// sorts the given operations such that, as much as possible, all users appear
+/// after their producers.
+///
+/// For example:
+///
+/// ```mlir
+/// %0 = test.foo
+/// %1 = test.bar %0, %2
+/// %2 = test.baz
+/// ```
+///
+/// Will become:
+///
+/// ```mlir
+/// %0 = test.foo
+/// %1 = test.baz
+/// %2 = test.bar %0, %1
+/// ```
+///
+/// The sort also works on operations with regions and implicit captures. For
+/// example:
+///
+/// ```mlir
+/// %0 = test.foo {
+/// test.baz %1
+/// %1 = test.bar %2
+/// }
+/// %2 = test.foo
+/// ```
+///
+/// Will become:
+///
+/// ```mlir
+/// %0 = test.foo
+/// %1 = test.foo {
+/// test.baz %2
+/// %2 = test.bar %0
+/// }
+/// ```
+///
+/// Note that the sort is not recursive on nested regions. This sort is stable;
+/// if the operations are already topologically sorted, nothing changes.
+///
+/// Operations that form cycles are moved to the end of the block in order. If
+/// the sort is left with only operations that form a cycle, it breaks the cycle
+/// by marking the first encountered operation as ready and moving on.
+///
+/// The function optionally accepts a callback that can be provided by users to
+/// virtually break cycles early. It is called on top-level operations in the
+/// block with value uses at or below those operations. The function should
+/// return true to mark that value as ready to be scheduled.
+///
+/// For example, if `isOperandReady` is set to always mark edges from `foo.A` to
+/// `foo.B` as ready, these operations:
+///
+/// ```mlir
+/// %0 = foo.B(%1)
+/// %1 = foo.C(%2)
+/// %2 = foo.A(%0)
+/// ```
+///
+/// Are sorted as:
+///
+/// ```mlir
+/// %0 = foo.A(%2)
+/// %1 = foo.C(%0)
+/// %2 = foo.B(%1)
+/// ```
+bool sortTopologically(
+ Block *block, iterator_range<Block::iterator> ops,
+ function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
+
+/// Given a block, sort its operations in topological order, excluding its
+/// terminator if it has one.
+bool sortTopologically(
+ Block *block,
+ function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index fef0feb379e57..52976c7040c42 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTransforms
StripDebugInfo.cpp
SymbolDCE.cpp
SymbolPrivatize.cpp
+ TopologicalSort.cpp
ViewOpGraph.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Transforms/TopologicalSort.cpp b/mlir/lib/Transforms/TopologicalSort.cpp
new file mode 100644
index 0000000000000..afa0b78fbf255
--- /dev/null
+++ b/mlir/lib/Transforms/TopologicalSort.cpp
@@ -0,0 +1,33 @@
+//===- TopologicalSort.cpp - Topological sort pass ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
+
+using namespace mlir;
+
+namespace {
+struct TopologicalSortPass : public TopologicalSortBase<TopologicalSortPass> {
+ void runOnOperation() override {
+ // Topologically sort every nested region that does not have SSA dominance.
+ getOperation()->walk([](RegionKindInterface op) {
+ for (auto &it : llvm::enumerate(op->getRegions())) {
+ if (op.hasSSADominance(it.index()))
+ continue;
+ for (Block &block : it.value())
+ sortTopologically(&block);
+ }
+ });
+ }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::createTopologicalSortPass() {
+ return std::make_unique<TopologicalSortPass>();
+}
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index a9d410cb21c0a..755e3196837d2 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
RegionUtils.cpp
SideEffectUtils.cpp
+ TopologicalSortUtils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
new file mode 100644
index 0000000000000..992db294a5650
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
@@ -0,0 +1,98 @@
+//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/TopologicalSortUtils.h"
+#include "mlir/IR/OpDefinition.h"
+
+using namespace mlir;
+
+bool mlir::sortTopologically(
+ Block *block, llvm::iterator_range<Block::iterator> ops,
+ function_ref<bool(Value, Operation *)> isOperandReady) {
+ if (ops.empty())
+ return true;
+
+ // The set of operations that have not yet been scheduled.
+ DenseSet<Operation *> unscheduledOps;
+ // Mark all operations as unscheduled.
+ for (Operation &op : ops)
+ unscheduledOps.insert(&op);
+
+ Block::iterator nextScheduledOp = ops.begin();
+ Block::iterator end = ops.end();
+
+ // An operation is ready to be scheduled if all its operands are ready. An
+ // operand is ready if:
+ const auto isReady = [&](Value value, Operation *top) {
+ // - the user-provided callback marks it as ready,
+ if (isOperandReady && isOperandReady(value, top))
+ return true;
+ Operation *parent = value.getDefiningOp();
+ // - it is a block argument,
+ if (!parent)
+ return true;
+ Operation *ancestor = block->findAncestorOpInBlock(*parent);
+ // - it is an implicit capture,
+ if (!ancestor)
+ return true;
+ // - it is defined in a nested region, or
+ if (ancestor == top)
+ return true;
+ // - its ancestor in the block is scheduled.
+ return !unscheduledOps.contains(ancestor);
+ };
+
+ bool allOpsScheduled = true;
+ while (!unscheduledOps.empty()) {
+ bool scheduledAtLeastOnce = false;
+
+ // Loop over the ops that are not sorted yet, try to find the ones "ready",
+ // i.e. the ones for which no operand is produced by an op still in the set,
+ // and "schedule" them (move them before `nextScheduledOp`).
+ for (Operation &op :
+ llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
+ // An operation is recursively ready to be scheduled if it and its nested
+ // operations are ready.
+ WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
+ return llvm::all_of(
+ nestedOp->getOperands(),
+ [&](Value operand) { return isReady(operand, &op); })
+ ? WalkResult::advance()
+ : WalkResult::interrupt();
+ });
+ if (readyToSchedule.wasInterrupted())
+ continue;
+
+ // Schedule the operation by moving it to the start.
+ unscheduledOps.erase(&op);
+ op.moveBefore(block, nextScheduledOp);
+ scheduledAtLeastOnce = true;
+ // Move the iterator forward if we schedule the operation at the front.
+ if (&op == &*nextScheduledOp)
+ ++nextScheduledOp;
+ }
+ // If no operations were scheduled, give up and advance the iterator.
+ if (!scheduledAtLeastOnce) {
+ allOpsScheduled = false;
+ unscheduledOps.erase(&*nextScheduledOp);
+ ++nextScheduledOp;
+ }
+ }
+
+ return allOpsScheduled;
+}
+
+bool mlir::sortTopologically(
+ Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
+ if (block->empty())
+ return true;
+ if (block->back().hasTrait<OpTrait::IsTerminator>())
+ return sortTopologically(block, block->without_terminator(),
+ isOperandReady);
+ return sortTopologically(block, *block, isOperandReady);
+}
diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Transforms/test-toposort.mlir
new file mode 100644
index 0000000000000..c792add375bcb
--- /dev/null
+++ b/mlir/test/Transforms/test-toposort.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -topological-sort %s | FileCheck %s
+
+// Test producer is after user.
+// CHECK-LABEL: test.graph_region
+test.graph_region {
+ // CHECK-NEXT: test.foo
+ // CHECK-NEXT: test.baz
+ // CHECK-NEXT: test.bar
+ %0 = "test.foo"() : () -> i32
+ "test.bar"(%1, %0) : (i32, i32) -> ()
+ %1 = "test.baz"() : () -> i32
+}
+
+// Test cycles.
+// CHECK-LABEL: test.graph_region
+test.graph_region {
+ // CHECK-NEXT: test.d
+ // CHECK-NEXT: test.a
+ // CHECK-NEXT: test.c
+ // CHECK-NEXT: test.b
+ %2 = "test.c"(%1) : (i32) -> i32
+ %1 = "test.b"(%0, %2) : (i32, i32) -> i32
+ %0 = "test.a"(%3) : (i32) -> i32
+ %3 = "test.d"() : () -> i32
+}
+
+// Test block arguments.
+// CHECK-LABEL: test.graph_region
+test.graph_region {
+// CHECK-NEXT: (%{{.*}}:
+^entry(%arg0: i32):
+ // CHECK-NEXT: test.foo
+ // CHECK-NEXT: test.baz
+ // CHECK-NEXT: test.bar
+ %0 = "test.foo"(%arg0) : (i32) -> i32
+ "test.bar"(%1, %0) : (i32, i32) -> ()
+ %1 = "test.baz"(%arg0) : (i32) -> i32
+}
+
+// Test implicit block capture (and sort nested region).
+// CHECK-LABEL: test.graph_region
+func.func @test_graph_cfg() -> () {
+ %0 = "test.foo"() : () -> i32
+ cf.br ^next(%0 : i32)
+
+^next(%1: i32):
+ test.graph_region {
+ // CHECK-NEXT: test.foo
+ // CHECK-NEXT: test.baz
+ // CHECK-NEXT: test.bar
+ %2 = "test.foo"(%1) : (i32) -> i32
+ "test.bar"(%3, %2) : (i32, i32) -> ()
+ %3 = "test.baz"(%0) : (i32) -> i32
+ }
+ return
+}
+
+// Test region ops (and recursive sort).
+// CHECK-LABEL: test.graph_region
+test.graph_region {
+ // CHECK-NEXT: test.baz
+ // CHECK-NEXT: test.graph_region attributes {a} {
+ // CHECK-NEXT: test.b
+ // CHECK-NEXT: test.a
+ // CHECK-NEXT: }
+ // CHECK-NEXT: test.bar
+ // CHECK-NEXT: test.foo
+ %0 = "test.foo"(%1) : (i32) -> i32
+ test.graph_region attributes {a} {
+ %a = "test.a"(%b) : (i32) -> i32
+ %b = "test.b"(%2) : (i32) -> i32
+ }
+ %1 = "test.bar"(%2) : (i32) -> i32
+ %2 = "test.baz"() : () -> i32
+}
More information about the Mlir-commits
mailing list