[Mlir-commits] [mlir] 47424f2 - [mlir][transforms] TopologicalSort: Support ops from different blocks
Matthias Springer
llvmlistbot at llvm.org
Wed Oct 12 18:36:43 PDT 2022
Author: Matthias Springer
Date: 2022-10-13T10:36:06+09:00
New Revision: 47424f22d46b340b9f9204647168fb4190c64472
URL: https://github.com/llvm/llvm-project/commit/47424f22d46b340b9f9204647168fb4190c64472
DIFF: https://github.com/llvm/llvm-project/commit/47424f22d46b340b9f9204647168fb4190c64472.diff
LOG: [mlir][transforms] TopologicalSort: Support ops from different blocks
This change allows analyzing ops from different blocks, in particular when used in programs that have `cf` branches.
Differential Revision: https://reviews.llvm.org/D135644
Added:
Modified:
mlir/include/mlir/Transforms/TopologicalSortUtils.h
mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
mlir/test/Transforms/test-toposort.mlir
mlir/test/lib/Transforms/TestTopologicalSort.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Transforms/TopologicalSortUtils.h
index 1a50d4dcf549c..74e44b1dc485d 100644
--- a/mlir/include/mlir/Transforms/TopologicalSortUtils.h
+++ b/mlir/include/mlir/Transforms/TopologicalSortUtils.h
@@ -95,16 +95,13 @@ bool sortTopologically(
Block *block,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
-/// Compute a topological ordering of the given ops. All ops must belong to the
-/// specified block.
-///
-/// This sort is not stable.
+/// Compute a topological ordering of the given ops. This sort is not stable.
///
/// Note: If the specified ops contain incomplete/interrupted SSA use-def
/// chains, the result may not actually be a topological sorting with respect to
/// the entire program.
bool computeTopologicalSorting(
- Block *block, MutableArrayRef<Operation *> ops,
+ MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady = nullptr);
} // end namespace mlir
diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
index 8767877515d9f..f3a9d217f2c98 100644
--- a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
+++ b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp
@@ -12,12 +12,11 @@
using namespace mlir;
/// Return `true` if the given operation is ready to be scheduled.
-static bool isOpReady(Block *block, Operation *op,
- DenseSet<Operation *> &unscheduledOps,
+static bool isOpReady(Operation *op, DenseSet<Operation *> &unscheduledOps,
function_ref<bool(Value, Operation *)> isOperandReady) {
// An operation is ready to be scheduled if all its operands are ready. An
// operation is ready if:
- const auto isReady = [&](Value value, Operation *top) {
+ const auto isReady = [&](Value value) {
// - the user-provided callback marks it as ready,
if (isOperandReady && isOperandReady(value, op))
return true;
@@ -25,22 +24,24 @@ static bool isOpReady(Block *block, Operation *op,
// - it is a block argument,
if (!parent)
return true;
- Operation *ancestor = block->findAncestorOpInBlock(*parent);
- // - it is an implicit capture,
- if (!ancestor)
- return true;
- // - it is defined in a nested region, or
- if (ancestor == op)
- return true;
- // - its ancestor in the block is scheduled.
- return !unscheduledOps.contains(ancestor);
+ // - or it is not defined by an unscheduled op (and also not nested within
+ // an unscheduled op).
+ do {
+ // Stop traversal when op under examination is reached.
+ if (parent == op)
+ return true;
+ if (unscheduledOps.contains(parent))
+ return false;
+ } while ((parent = parent->getParentOp()));
+ // No unscheduled op found.
+ return true;
};
// An operation is recursively ready to be scheduled of it and its nested
// operations are ready.
WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) {
return llvm::all_of(nestedOp->getOperands(),
- [&](Value operand) { return isReady(operand, op); })
+ [&](Value operand) { return isReady(operand); })
? WalkResult::advance()
: WalkResult::interrupt();
});
@@ -71,7 +72,7 @@ bool mlir::sortTopologically(
// set, and "schedule" it (move it before the `nextScheduledOp`).
for (Operation &op :
llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) {
- if (!isOpReady(block, &op, unscheduledOps, isOperandReady))
+ if (!isOpReady(&op, unscheduledOps, isOperandReady))
continue;
// Schedule the operation by moving it to the start.
@@ -104,7 +105,7 @@ bool mlir::sortTopologically(
}
bool mlir::computeTopologicalSorting(
- Block *block, MutableArrayRef<Operation *> ops,
+ MutableArrayRef<Operation *> ops,
function_ref<bool(Value, Operation *)> isOperandReady) {
if (ops.empty())
return true;
@@ -113,10 +114,8 @@ bool mlir::computeTopologicalSorting(
DenseSet<Operation *> unscheduledOps;
// Mark all operations as unscheduled.
- for (Operation *op : ops) {
- assert(op->getBlock() == block && "op must belong to block");
+ for (Operation *op : ops)
unscheduledOps.insert(op);
- }
unsigned nextScheduledOp = 0;
@@ -128,7 +127,7 @@ bool mlir::computeTopologicalSorting(
// i.e. the ones for which there aren't any operand produced by an op in the
// set, and "schedule" it (swap it with the op at `nextScheduledOp`).
for (unsigned i = nextScheduledOp; i < ops.size(); ++i) {
- if (!isOpReady(block, ops[i], unscheduledOps, isOperandReady))
+ if (!isOpReady(ops[i], unscheduledOps, isOperandReady))
continue;
// Schedule the operation by moving it to the start.
diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Transforms/test-toposort.mlir
index 2ebf35c753bab..c47b885dbec78 100644
--- a/mlir/test/Transforms/test-toposort.mlir
+++ b/mlir/test/Transforms/test-toposort.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -topological-sort %s | FileCheck %s
-// RUN: mlir-opt -test-topological-sort-analysis %s | FileCheck %s -check-prefix=CHECK-ANALYSIS
+// RUN: mlir-opt %s -topological-sort | FileCheck %s
+// RUN: mlir-opt %s -test-topological-sort-analysis -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ANALYSIS
// Test producer is after user.
// CHECK-LABEL: test.graph_region
@@ -36,6 +36,35 @@ test.graph_region attributes{"root"} {
%3 = "test.d"() {selected} : () -> i32
}
+// Test that a cyclic dependency among the selected ops is reported as "not all scheduled".
+// CHECK-LABEL: test.graph_region
+// CHECK-ANALYSIS-LABEL: test.graph_region
+// expected-error @+1 {{could not schedule all ops}}
+test.graph_region attributes{"root"} {
+  %0 = "test.a"(%1) {selected} : (i32) -> i32
+  %1 = "test.b"(%0) {selected} : (i32) -> i32
+}
+
+// CHECK-LABEL: func @test_multiple_blocks
+// CHECK-ANALYSIS-LABEL: func @test_multiple_blocks
+func.func @test_multiple_blocks() -> (i32) attributes{"root", "ordered"} {
+ // CHECK-ANALYSIS-NEXT: test.foo{{.*}} {pos = 0
+ %0 = "test.foo"() {selected = 2} : () -> (i32)
+ // CHECK-ANALYSIS-NEXT: test.foo
+ %1 = "test.foo"() : () -> (i32)
+ cf.br ^bb0
+^bb0:
+ // CHECK-ANALYSIS: test.foo{{.*}} {pos = 1
+ %2 = "test.foo"() {selected = 3} : () -> (i32)
+ // CHECK-ANALYSIS-NEXT: test.bar{{.*}} {pos = 2
+ %3 = "test.bar"(%0, %1, %2) {selected = 0} : (i32, i32, i32) -> (i32)
+ cf.br ^bb1 (%2 : i32)
+^bb1(%arg0: i32):
+ // CHECK-ANALYSIS: test.qux{{.*}} {pos = 3
+ %4 = "test.qux"(%arg0, %0) {selected = 1} : (i32, i32) -> (i32)
+ return %4 : i32
+}
+
// Test block arguments.
// CHECK-LABEL: test.graph_region
test.graph_region {
diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Transforms/TestTopologicalSort.cpp
index 9ed64eae4b423..4ad5b5c2608fc 100644
--- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp
+++ b/mlir/test/lib/Transforms/TestTopologicalSort.cpp
@@ -30,25 +30,47 @@ struct TestTopologicalSortAnalysisPass
Operation *op = getOperation();
OpBuilder builder(op->getContext());
- op->walk([&](Operation *root) {
+ WalkResult result = op->walk([&](Operation *root) {
if (!root->hasAttr("root"))
return WalkResult::advance();
- assert(root->getNumRegions() == 1 && root->getRegion(0).hasOneBlock() &&
- "expected one block");
- Block *block = &root->getRegion(0).front();
SmallVector<Operation *> selectedOps;
- block->walk([&](Operation *op) {
- if (op->hasAttr("selected"))
- selectedOps.push_back(op);
+ root->walk([&](Operation *selected) {
+ if (!selected->hasAttr("selected"))
+ return WalkResult::advance();
+ if (root->hasAttr("ordered")) {
+ // If the root has an "ordered" attribute, we fill the selectedOps
+ // vector in a certain order.
+ int64_t pos =
+ selected->getAttr("selected").cast<IntegerAttr>().getInt();
+ if (pos >= static_cast<int64_t>(selectedOps.size()))
+ selectedOps.append(pos + 1 - selectedOps.size(), nullptr);
+ selectedOps[pos] = selected;
+ } else {
+ selectedOps.push_back(selected);
+ }
+ return WalkResult::advance();
});
- computeTopologicalSorting(block, selectedOps);
+ if (llvm::find(selectedOps, nullptr) != selectedOps.end()) {
+ root->emitError("invalid test case: some indices are missing among the "
+ "selected ops");
+ return WalkResult::skip();
+ }
+
+ if (!computeTopologicalSorting(selectedOps)) {
+ root->emitError("could not schedule all ops");
+ return WalkResult::skip();
+ }
+
for (const auto &it : llvm::enumerate(selectedOps))
it.value()->setAttr("pos", builder.getIndexAttr(it.index()));
return WalkResult::advance();
});
+
+ if (result.wasSkipped())
+ signalPassFailure();
}
};
} // namespace
More information about the Mlir-commits
mailing list