[Mlir-commits] [mlir] 5aa6038 - [mlir] Make topologicalSort iterative and consider op regions

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 10 10:05:17 PST 2021


Author: thomasraoux
Date: 2021-11-10T10:05:01-08:00
New Revision: 5aa6038a407451da2ca5438c5b03c40aa4c72aad

URL: https://github.com/llvm/llvm-project/commit/5aa6038a407451da2ca5438c5b03c40aa4c72aad
DIFF: https://github.com/llvm/llvm-project/commit/5aa6038a407451da2ca5438c5b03c40aa4c72aad.diff

LOG: [mlir] Make topologicalSort iterative and consider op regions

When doing a topological sort, we need to make sure an op is scheduled before
any of the ops within its regions.
Also make the algorithm iterative instead of recursive, to avoid a potential
stack overflow on deep use-def chains.

Differential Revision: https://reviews.llvm.org/D113423

Added: 
    mlir/test/Analysis/test-topoligical-sort.mlir
    mlir/test/lib/Analysis/TestSlice.cpp

Modified: 
    mlir/lib/Analysis/SliceAnalysis.cpp
    mlir/test/lib/Analysis/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 4ab5b49b597a4..a29315a3938ce 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -168,20 +168,47 @@ struct DFSState {
 };
 } // namespace
 
-static void DFSPostorder(Operation *current, DFSState *state) {
-  for (Value result : current->getResults()) {
-    for (Operation *op : result.getUsers())
-      DFSPostorder(op, state);
-  }
-  bool inserted;
-  using IterTy = decltype(state->seen.begin());
-  IterTy iter;
-  std::tie(iter, inserted) = state->seen.insert(current);
-  if (inserted) {
-    if (state->toSort.count(current) > 0) {
-      state->topologicalCounts.push_back(current);
-    }
-  }
+/// Iterative post-order DFS from `root`. The successors of an op are the users
+/// of its results and the ops nested directly in its regions, so an op always
+/// finishes after everything that must be scheduled after it. Ops that belong
+/// to `state->toSort` are appended to `state->topologicalCounts` as they
+/// finish; reversing that list yields a topological order.
+///
+/// `state->seen` is shared across calls: an op already visited, from this or
+/// a previous root, is never expanded again. This keeps the traversal linear
+/// in the number of ops and uses (instead of revisiting an op once per path
+/// that reaches it) and guarantees termination even if the use graph contains
+/// a cycle, in which case no topological order exists.
+static void DFSPostorder(Operation *root, DFSState *state) {
+  // Each entry pairs an op with a flag recording whether its successors have
+  // already been pushed; popping a flagged entry means the op is finished.
+  SmallVector<std::pair<Operation *, bool>> stack;
+  stack.emplace_back(root, /*expanded=*/false);
+  while (!stack.empty()) {
+    Operation *current;
+    bool expanded;
+    std::tie(current, expanded) = stack.pop_back_val();
+    if (expanded) {
+      if (state->toSort.count(current) > 0)
+        state->topologicalCounts.push_back(current);
+      continue;
+    }
+    // Already reached through another path or from a previous root.
+    if (!state->seen.insert(current).second)
+      continue;
+    stack.emplace_back(current, /*expanded=*/true);
+    for (Value result : current->getResults()) {
+      for (Operation *user : result.getUsers())
+        if (!state->seen.count(user))
+          stack.emplace_back(user, /*expanded=*/false);
+    }
+    // An op must be scheduled before any op nested in its regions.
+    for (Region &region : current->getRegions()) {
+      for (Operation &nested : region.getOps())
+        if (!state->seen.count(&nested))
+          stack.emplace_back(&nested, /*expanded=*/false);
+    }
+  }
 }
 
 SetVector<Operation *>

diff  --git a/mlir/test/Analysis/test-topoligical-sort.mlir b/mlir/test/Analysis/test-topoligical-sort.mlir
new file mode 100644
index 0000000000000..a93468580fa68
--- /dev/null
+++ b/mlir/test/Analysis/test-topoligical-sort.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -mlir-disable-threading -test-print-topological-sort 2>&1 | FileCheck %s
+
+// Ops are fed to the sort in the order given by `__test_sort_original_idx__`,
+// which deliberately differs from their order in the IR.
+
+// An op must be scheduled before the ops nested in its regions: the scf.for
+// (idx 2) has to come before the arith.subi in its body (idx 1).
+// CHECK-LABEL: Testing : region
+//       CHECK: arith.addi {{.*}} : index
+//  CHECK-NEXT: scf.for
+//       CHECK: } {__test_sort_original_idx__ = 2 : i64}
+//  CHECK-NEXT: arith.addi {{.*}} : i32
+//  CHECK-NEXT: arith.subi {{.*}} : i32
+func @region(
+  %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
+  %arg4 : i32, %arg5 : i32, %arg6 : i32,
+  %buffer : memref<i32>) {
+  %0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
+  %idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
+  scf.for %arg7 = %idx to %arg2 step %arg3  {
+    %2 = arith.addi %0, %arg5 : i32
+    %3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
+    memref.store %3, %buffer[] : memref<i32>
+  } {__test_sort_original_idx__ = 2}
+  return
+}
+
+// An op reachable through several use paths must be emitted exactly once and
+// only after all of its producers; the only valid order here is
+// addi, muli, subi.
+// CHECK-LABEL: Testing : reconvergent
+//       CHECK: arith.addi
+//  CHECK-NEXT: arith.muli
+//  CHECK-NEXT: arith.subi
+func @reconvergent(%arg0 : i32, %arg1 : i32) -> i32 {
+  %0 = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 2} : i32
+  %1 = arith.muli %0, %arg1 {__test_sort_original_idx__ = 1} : i32
+  %2 = arith.subi %0, %1 {__test_sort_original_idx__ = 0} : i32
+  return %2 : i32
+}

diff  --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index aa9eadb6706c7..00321034429d6 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRTestAnalysis
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
   TestNumberOfExecutions.cpp
+  TestSlice.cpp
 
 
   EXCLUDE_FROM_LIBMLIR

diff  --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
new file mode 100644
index 0000000000000..962df7602a29f
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -0,0 +1,71 @@
+//===------------- TestSlice.cpp - Test slice related analysis ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Pass/Pass.h"
+
+#include <cstdint>
+#include <map>
+
+using namespace mlir;
+
+/// Integer attribute that tags the ops handed to topologicalSort. Its value
+/// fixes the order in which the ops are inserted into the input SetVector, so
+/// a test can present ops in an order that differs from the IR order.
+static constexpr StringLiteral kOrderMarker = "__test_sort_original_idx__";
+
+namespace {
+
+/// Collects every op tagged with `kOrderMarker`, sorts them with
+/// topologicalSort and prints them to stderr, one per line, after a
+/// "Testing : <function name>" header.
+struct TestTopologicalSortPass
+    : public PassWrapper<TestTopologicalSortPass, FunctionPass> {
+  StringRef getArgument() const final { return "test-print-topological-sort"; }
+  StringRef getDescription() const final {
+    return "Print operations in topological order";
+  }
+  void runOnFunction() override {
+    // Keyed by the marker value; int64_t matches IntegerAttr::getInt() so
+    // large marker values are not truncated.
+    std::map<int64_t, Operation *> ops;
+    WalkResult result = getFunction().walk([&ops](Operation *op) {
+      auto originalOrderAttr = op->getAttrOfType<IntegerAttr>(kOrderMarker);
+      if (!originalOrderAttr)
+        return WalkResult::advance();
+      // A duplicate value would silently drop an op from the sorted set.
+      if (!ops.emplace(originalOrderAttr.getInt(), op).second) {
+        op->emitError() << "duplicate '" << kOrderMarker << "' value";
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+    if (result.wasInterrupted())
+      return signalPassFailure();
+
+    SetVector<Operation *> sortedOp;
+    for (const auto &entry : ops)
+      sortedOp.insert(entry.second);
+    sortedOp = topologicalSort(sortedOp);
+    llvm::errs() << "Testing : " << getFunction().getName() << "\n";
+    for (Operation *op : sortedOp) {
+      op->print(llvm::errs());
+      llvm::errs() << "\n";
+    }
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestSliceAnalysisPass() {
+  PassRegistration<TestTopologicalSortPass>();
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index ec8d002ba7b2b..dd24b4e507523 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestPDLByteCodePass();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
+void registerTestSliceAnalysisPass();
 void registerTestVectorConversions();
 } // namespace test
 } // namespace mlir
@@ -195,6 +196,7 @@ void registerTestPasses() {
   mlir::test::registerTestPDLByteCodePass();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
+  mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestVectorConversions();
 }
 #endif


        


More information about the Mlir-commits mailing list