[Mlir-commits] [mlir] c8457eb - [mlir][transforms] Add a topological sort utility and pass

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 16 13:47:35 PDT 2022


Author: Mogball
Date: 2022-05-16T20:47:30Z
New Revision: c8457eb5323ca99361c1748a22a16edb3160ae5f

URL: https://github.com/llvm/llvm-project/commit/c8457eb5323ca99361c1748a22a16edb3160ae5f
DIFF: https://github.com/llvm/llvm-project/commit/c8457eb5323ca99361c1748a22a16edb3160ae5f.diff

LOG: [mlir][transforms] Add a topological sort utility and pass

This patch adds a topological sort utility and pass. A topological sort reorders
the operations in a block without SSA dominance such that, as much as possible,
users of values come after their producers.

The utility function topologically sorts the operation range in a given block
with an optional user-provided callback that can be used to virtually break cycles.
The toposort pass itself recursively sorts graph regions under the target op.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D125063

Added: 
    mlir/include/mlir/Transforms/TopologicalSortUtils.h
    mlir/lib/Transforms/TopologicalSort.cpp
    mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
    mlir/test/Transforms/test-toposort.mlir

Modified: 
    mlir/include/mlir/Transforms/Passes.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Utils/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index a74f65a81583f..f6e1feeaf05ef 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -92,6 +92,11 @@ std::unique_ptr<Pass> createSymbolDCEPass();
 std::unique_ptr<Pass>
 createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
 
+/// Creates a pass that recursively sorts nested regions without SSA dominance
+/// topologically such that, as much as possible, users of values appear after
+/// their producers.
+std::unique_ptr<Pass> createTopologicalSortPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index ee802269b3034..7ac71ceeb588a 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -260,4 +260,21 @@ def ViewOpGraph : Pass<"view-op-graph"> {
   let constructor = "mlir::createPrintOpGraphPass()";
 }
 
+def TopologicalSort : Pass<"topological-sort"> {
+  let summary = "Sort regions without SSA dominance in topological order";
+  let description = [{
+    Recursively sorts all nested regions without SSA dominance in topological
+    order. The main purpose is readability, as well as potentially faster
+    processing of certain transformations and analyses. The function sorts the
+    operations in all nested regions such that, as much as possible, all users
+    appear after their producers.
+
+    This sort is stable. If the block is already topologically sorted, the IR
+    is not changed. Operations that form a cycle are moved to the end of the
+    regions in a stable order.
+  }];
+
+  let constructor = "mlir::createTopologicalSortPass()";
+}
+
 #endif // MLIR_TRANSFORMS_PASSES

diff  --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Transforms/TopologicalSortUtils.h
new file mode 100644
index 0000000000000..828726df00a53
--- /dev/null
+++ b/mlir/include/mlir/Transforms/TopologicalSortUtils.h
@@ -0,0 +1,100 @@
+//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H
+
+#include "mlir/IR/Block.h"
+
+namespace mlir {
+
+/// Given a block, sort a range of its operations in topological order.
+/// The main purpose is readability of graph regions, potentially faster
+/// processing of certain transformations and analyses, or fixing the SSA
+/// dominance of blocks that require it after transformations. The function
+/// sorts the given operations such that, as much as possible, all users appear
+/// after their producers.
+///
+/// For example:
+///
+/// ```mlir
+/// %0 = test.foo
+/// %1 = test.bar %0, %2
+/// %2 = test.baz
+/// ```
+///
+/// Will become:
+///
+/// ```mlir
+/// %0 = test.foo
+/// %1 = test.baz
+/// %2 = test.bar %0, %1
+/// ```
+///
+/// The sort also works on operations with regions and implicit captures. For
+/// example:
+///
+/// ```mlir
+/// %0 = test.foo {
+///   test.baz %1
+///   %1 = test.bar %2
+/// }
+/// %2 = test.foo
+/// ```
+///
+/// Will become:
+///
+/// ```mlir
+/// %0 = test.foo
+/// %1 = test.foo {
+///   test.baz %2
+///   %2 = test.bar %0
+/// }
+/// ```
+///
+/// Note that the sort is not recursive on nested regions. This sort is stable;
+/// if the operations are already topologically sorted, nothing changes.
+///
+/// Operations that form cycles are moved to the end of the block in order. If
+/// the sort is left with only operations that form a cycle, it breaks the cycle
+/// by marking the first encountered operation as ready and moving on.
+///
+/// The function optionally accepts a callback that can be provided by users to
+/// virtually break cycles early. It is called on top-level operations in the
+/// block with value uses at or below those operations. The function should
+/// return true to mark that value as ready to be scheduled.
+///
+/// For example, if `isOperandReady` is set to always mark edges from `foo.A` to
+/// `foo.B` as ready, these operations:
+///
+/// ```mlir
+/// %0 = foo.B(%1)
+/// %1 = foo.C(%2)
+/// %2 = foo.A(%0)
+/// ```
+///
+/// Are sorted as:
+///
+/// ```mlir
+/// %0 = foo.A(%2)
+/// %1 = foo.C(%0)
+/// %2 = foo.B(%1)
+/// ```
+bool sortTopologically(
+    Block *block, iterator_range<Block::iterator> ops,
+    function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
+
+/// Given a block, sort its operations in topological order, excluding its
+/// terminator if it has one.
+bool sortTopologically(
+    Block *block,
+    function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H

diff  --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index fef0feb379e57..52976c7040c42 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRTransforms
   StripDebugInfo.cpp
   SymbolDCE.cpp
   SymbolPrivatize.cpp
+  TopologicalSort.cpp
   ViewOpGraph.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Transforms/TopologicalSort.cpp b/mlir/lib/Transforms/TopologicalSort.cpp
new file mode 100644
index 0000000000000..afa0b78fbf255
--- /dev/null
+++ b/mlir/lib/Transforms/TopologicalSort.cpp
@@ -0,0 +1,33 @@
+//===- TopologicalSort.cpp - Topological sort pass ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
+
+using namespace mlir;
+
+namespace {
+struct TopologicalSortPass : public TopologicalSortBase<TopologicalSortPass> {
+  void runOnOperation() override {
+    // Topologically sort the regions of the operation without SSA dominance.
+    getOperation()->walk([](RegionKindInterface op) {
+      for (auto &it : llvm::enumerate(op->getRegions())) {
+        if (op.hasSSADominance(it.index()))
+          continue;
+        for (Block &block : it.value())
+          sortTopologically(&block);
+      }
+    });
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::createTopologicalSortPass() {
+  return std::make_unique<TopologicalSortPass>();
+}

diff  --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index a9d410cb21c0a..755e3196837d2 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_library(MLIRTransformUtils
   LoopInvariantCodeMotionUtils.cpp
   RegionUtils.cpp
   SideEffectUtils.cpp
+  TopologicalSortUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms

diff  --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
new file mode 100644
index 0000000000000..992db294a5650
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
@@ -0,0 +1,98 @@
+//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/TopologicalSortUtils.h"
+#include "mlir/IR/OpDefinition.h"
+
+using namespace mlir;
+
+bool mlir::sortTopologically(
+    Block *block, llvm::iterator_range<Block::iterator> ops,
+    function_ref<bool(Value, Operation *)> isOperandReady) {
+  if (ops.empty())
+    return true;
+
+  // The set of operations that have not yet been scheduled.
+  DenseSet<Operation *> unscheduledOps;
+  // Mark all operations as unscheduled.
+  for (Operation &op : ops)
+    unscheduledOps.insert(&op);
+
+  Block::iterator nextScheduledOp = ops.begin();
+  Block::iterator end = ops.end();
+
+  // An operation is ready to be scheduled if all its operands are ready. An
+  // operand is ready if:
+  const auto isReady = [&](Value value, Operation *top) {
+    // - the user-provided callback marks it as ready,
+    if (isOperandReady && isOperandReady(value, top))
+      return true;
+    Operation *parent = value.getDefiningOp();
+    // - it is a block argument,
+    if (!parent)
+      return true;
+    Operation *ancestor = block->findAncestorOpInBlock(*parent);
+    // - it is an implicit capture,
+    if (!ancestor)
+      return true;
+    // - it is defined in a nested region, or
+    if (ancestor == top)
+      return true;
+    // - its ancestor in the block is scheduled.
+    return !unscheduledOps.contains(ancestor);
+  };
+
+  bool allOpsScheduled = true;
+  while (!unscheduledOps.empty()) {
+    bool scheduledAtLeastOnce = false;
+
+    // Loop over the ops that are not sorted yet, try to find the ones "ready",
+    // i.e. the ones for which no operand is produced by an op in the set, and
+    // "schedule" them (move them before the `nextScheduledOp`).
+    for (Operation &op :
+         llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
+      // An operation is recursively ready to be scheduled if it and its nested
+      // operations are ready.
+      WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) {
+        return llvm::all_of(
+                   nestedOp->getOperands(),
+                   [&](Value operand) { return isReady(operand, &op); })
+                   ? WalkResult::advance()
+                   : WalkResult::interrupt();
+      });
+      if (readyToSchedule.wasInterrupted())
+        continue;
+
+      // Schedule the operation by moving it to the start.
+      unscheduledOps.erase(&op);
+      op.moveBefore(block, nextScheduledOp);
+      scheduledAtLeastOnce = true;
+      // Move the iterator forward if we schedule the operation at the front.
+      if (&op == &*nextScheduledOp)
+        ++nextScheduledOp;
+    }
+    // If no operations were scheduled, give up and advance the iterator.
+    if (!scheduledAtLeastOnce) {
+      allOpsScheduled = false;
+      unscheduledOps.erase(&*nextScheduledOp);
+      ++nextScheduledOp;
+    }
+  }
+
+  return allOpsScheduled;
+}
+
+bool mlir::sortTopologically(
+    Block *block, function_ref<bool(Value, Operation *)> isOperandReady) {
+  if (block->empty())
+    return true;
+  if (block->back().hasTrait<OpTrait::IsTerminator>())
+    return sortTopologically(block, block->without_terminator(),
+                             isOperandReady);
+  return sortTopologically(block, *block, isOperandReady);
+}

diff  --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Transforms/test-toposort.mlir
new file mode 100644
index 0000000000000..c792add375bcb
--- /dev/null
+++ b/mlir/test/Transforms/test-toposort.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -topological-sort %s | FileCheck %s
+
+// Test producer is after user.
+// CHECK-LABEL: test.graph_region
+test.graph_region {
+  // CHECK-NEXT: test.foo
+  // CHECK-NEXT: test.baz
+  // CHECK-NEXT: test.bar
+  %0 = "test.foo"() : () -> i32
+  "test.bar"(%1, %0) : (i32, i32) -> ()
+  %1 = "test.baz"() : () -> i32
+}
+
+// Test cycles.
+// CHECK-LABEL: test.graph_region
+test.graph_region {
+  // CHECK-NEXT: test.d
+  // CHECK-NEXT: test.a
+  // CHECK-NEXT: test.c
+  // CHECK-NEXT: test.b
+  %2 = "test.c"(%1) : (i32) -> i32
+  %1 = "test.b"(%0, %2) : (i32, i32) -> i32
+  %0 = "test.a"(%3) : (i32) -> i32
+  %3 = "test.d"() : () -> i32
+}
+
+// Test block arguments.
+// CHECK-LABEL: test.graph_region
+test.graph_region {
+// CHECK-NEXT: (%{{.*}}:
+^entry(%arg0: i32):
+  // CHECK-NEXT: test.foo
+  // CHECK-NEXT: test.baz
+  // CHECK-NEXT: test.bar
+  %0 = "test.foo"(%arg0) : (i32) -> i32
+  "test.bar"(%1, %0) : (i32, i32) -> ()
+  %1 = "test.baz"(%arg0) : (i32) -> i32
+}
+
+// Test implicit block capture (and sort nested region).
+// CHECK-LABEL: test.graph_region
+func.func @test_graph_cfg() -> () {
+  %0 = "test.foo"() : () -> i32
+  cf.br ^next(%0 : i32)
+
+^next(%1: i32):
+  test.graph_region {
+    // CHECK-NEXT: test.foo
+    // CHECK-NEXT: test.baz
+    // CHECK-NEXT: test.bar
+    %2 = "test.foo"(%1) : (i32) -> i32
+    "test.bar"(%3, %2) : (i32, i32) -> ()
+    %3 = "test.baz"(%0) : (i32) -> i32
+  }
+  return
+}
+
+// Test region ops (and recursive sort).
+// CHECK-LABEL: test.graph_region
+test.graph_region {
+  // CHECK-NEXT: test.baz
+  // CHECK-NEXT: test.graph_region attributes {a} {
+  // CHECK-NEXT:   test.b
+  // CHECK-NEXT:   test.a
+  // CHECK-NEXT: }
+  // CHECK-NEXT: test.bar
+  // CHECK-NEXT: test.foo
+  %0 = "test.foo"(%1) : (i32) -> i32
+  test.graph_region attributes {a} {
+    %a = "test.a"(%b) : (i32) -> i32
+    %b = "test.b"(%2) : (i32) -> i32
+  }
+  %1 = "test.bar"(%2) : (i32) -> i32
+  %2 = "test.baz"() : () -> i32
+}


        


More information about the Mlir-commits mailing list