[Mlir-commits] [mlir] 5aa6038 - [mlir] Make topologicalSort iterative and consider op regions
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 10 10:05:17 PST 2021
Author: thomasraoux
Date: 2021-11-10T10:05:01-08:00
New Revision: 5aa6038a407451da2ca5438c5b03c40aa4c72aad
URL: https://github.com/llvm/llvm-project/commit/5aa6038a407451da2ca5438c5b03c40aa4c72aad
DIFF: https://github.com/llvm/llvm-project/commit/5aa6038a407451da2ca5438c5b03c40aa4c72aad.diff
LOG: [mlir] Make topologicalSort iterative and consider op regions
When doing a topological sort, we need to make sure an op is scheduled before any
of the ops within its regions.
Also change the algorithm to be iterative rather than recursive, in order to
prevent potential stack overflow.
Differential Revision: https://reviews.llvm.org/D113423
Added:
mlir/test/Analysis/test-topoligical-sort.mlir
mlir/test/lib/Analysis/TestSlice.cpp
Modified:
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/test/lib/Analysis/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 4ab5b49b597a4..a29315a3938ce 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -168,20 +168,26 @@ struct DFSState {
};
} // namespace
-static void DFSPostorder(Operation *current, DFSState *state) {
- for (Value result : current->getResults()) {
- for (Operation *op : result.getUsers())
- DFSPostorder(op, state);
- }
- bool inserted;
- using IterTy = decltype(state->seen.begin());
- IterTy iter;
- std::tie(iter, inserted) = state->seen.insert(current);
- if (inserted) {
- if (state->toSort.count(current) > 0) {
- state->topologicalCounts.push_back(current);
- }
- }
+/// Iterative post-order DFS rooted at `root`. The successors of an op are the
+/// users of its results and the ops nested in its regions, so an op holding a
+/// region is ordered before every op it contains. Ops that belong to
+/// `state->toSort` are appended to `state->topologicalCounts` only after all
+/// of their successors have been appended (post-order); the caller presumably
+/// reverses this list to obtain the topological order.
+///
+/// `state->seen` records every op the traversal has entered, including in
+/// earlier calls sharing the same `state`, so each op and each edge is
+/// processed at most once. Without this check an op reachable along many paths
+/// (e.g. a chain of diamond-shaped use chains) would be revisited once per
+/// path, which is exponential in the worst case, and a cyclic use graph would
+/// never terminate.
+static void DFSPostorder(Operation *root, DFSState *state) {
+  // Each entry is an op plus a flag telling whether its successors have
+  // already been pushed. An op is emitted when popped the second time, i.e.
+  // after everything pushed above it (its successors) was fully processed.
+  SmallVector<std::pair<Operation *, bool>> worklist;
+  worklist.push_back(std::make_pair(root, false));
+  while (!worklist.empty()) {
+    std::pair<Operation *, bool> entry = worklist.pop_back_val();
+    Operation *current = entry.first;
+    if (entry.second) {
+      // All successors are finished: emit in post-order.
+      if (state->toSort.count(current) > 0)
+        state->topologicalCounts.push_back(current);
+      continue;
+    }
+    // Already entered from this root or an earlier one: nothing left to do.
+    if (!state->seen.insert(current).second)
+      continue;
+    worklist.push_back(std::make_pair(current, true));
+    for (Value result : current->getResults()) {
+      for (Operation *user : result.getUsers())
+        worklist.push_back(std::make_pair(user, false));
+    }
+    for (Region &region : current->getRegions()) {
+      for (Operation &nested : region.getOps())
+        worklist.push_back(std::make_pair(&nested, false));
+    }
+  }
}
SetVector<Operation *>
diff --git a/mlir/test/Analysis/test-topoligical-sort.mlir b/mlir/test/Analysis/test-topoligical-sort.mlir
new file mode 100644
index 0000000000000..a93468580fa68
--- /dev/null
+++ b/mlir/test/Analysis/test-topoligical-sort.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -test-print-topological-sort 2>&1 | FileCheck %s
+
+// The test pass prints the ops tagged with __test_sort_original_idx__ after
+// running topologicalSort on them. Ordering constraints exercised here:
+//   - %idx (idx 3) is the scf.for lower bound, so it precedes the scf.for;
+//   - the scf.for (idx 2) precedes the arith.subi (idx 1) nested in its
+//     region, even though the subi uses no value produced by the scf.for;
+//   - %0 (idx 0) feeds %2, which feeds the arith.subi, so %0 also precedes it.
+
+// CHECK-LABEL: Testing : region
+// CHECK: arith.addi {{.*}} : index
+// CHECK-NEXT: scf.for
+// CHECK: } {__test_sort_original_idx__ = 2 : i64}
+// CHECK-NEXT: arith.addi {{.*}} : i32
+// CHECK-NEXT: arith.subi {{.*}} : i32
+func @region(
+ %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
+ %arg4 : i32, %arg5 : i32, %arg6 : i32,
+ %buffer : memref<i32>) {
+ %0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
+ %idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
+ scf.for %arg7 = %idx to %arg2 step %arg3 {
+ %2 = arith.addi %0, %arg5 : i32
+ %3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
+ memref.store %3, %buffer[] : memref<i32>
+ } {__test_sort_original_idx__ = 2}
+ return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index aa9eadb6706c7..00321034429d6 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRTestAnalysis
TestMemRefDependenceCheck.cpp
TestMemRefStrideCalculation.cpp
TestNumberOfExecutions.cpp
+ TestSlice.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
new file mode 100644
index 0000000000000..962df7602a29f
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -0,0 +1,50 @@
+//===------------- TestSlice.cpp - Test slice related analysis ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+/// Name of the integer attribute the test IR uses to tag the ops to be sorted
+/// and to give each of them its original position.
+static constexpr StringLiteral kOrderMarker = "__test_sort_original_idx__";
+
+namespace {
+
+/// Test pass: gathers every op of the function that carries an integer
+/// `kOrderMarker` attribute, orders them by that attribute's value, runs
+/// `topologicalSort` on them and prints the sorted ops to stderr.
+struct TestTopologicalSortPass
+    : public PassWrapper<TestTopologicalSortPass, FunctionPass> {
+  StringRef getArgument() const final { return "test-print-topological-sort"; }
+  StringRef getDescription() const final {
+    return "Print operations in topological order";
+  }
+
+  void runOnFunction() override {
+    auto func = getFunction();
+
+    // Key the tagged ops by their original index. std::map iterates keys in
+    // ascending order, which fixes the order ops are handed to topologicalSort.
+    // If two ops share an index, the one visited later by the walk wins.
+    std::map<int, Operation *> opsByIndex;
+    func.walk([&](Operation *op) {
+      IntegerAttr idxAttr = op->getAttrOfType<IntegerAttr>(kOrderMarker);
+      if (!idxAttr)
+        return;
+      opsByIndex[idxAttr.getInt()] = op;
+    });
+
+    SetVector<Operation *> toSort;
+    for (const auto &entry : opsByIndex)
+      toSort.insert(entry.second);
+    SetVector<Operation *> sorted = topologicalSort(toSort);
+
+    llvm::errs() << "Testing : " << func.getName() << "\n";
+    for (Operation *op : sorted) {
+      op->print(llvm::errs());
+      llvm::errs() << "\n";
+    }
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+/// Registers TestTopologicalSortPass (command-line argument
+/// `test-print-topological-sort`); mlir-opt calls this from registerTestPasses.
+void registerTestSliceAnalysisPass() {
+  PassRegistration<TestTopologicalSortPass> registration;
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index ec8d002ba7b2b..dd24b4e507523 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestPDLByteCodePass();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
+void registerTestSliceAnalysisPass();
void registerTestVectorConversions();
} // namespace test
} // namespace mlir
@@ -195,6 +196,7 @@ void registerTestPasses() {
mlir::test::registerTestPDLByteCodePass();
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();
+ mlir::test::registerTestSliceAnalysisPass();
mlir::test::registerTestVectorConversions();
}
#endif
More information about the Mlir-commits
mailing list