[Mlir-commits] [mlir] 8010d7e - [mlir] add an option to print op stats in JSON

Okwan Kwon llvmlistbot at llvm.org
Wed Jun 15 10:07:46 PDT 2022


Author: Okwan Kwon
Date: 2022-06-15T10:07:36-07:00
New Revision: 8010d7e0446a09471d5349466b80651c3ad76af3

URL: https://github.com/llvm/llvm-project/commit/8010d7e0446a09471d5349466b80651c3ad76af3
DIFF: https://github.com/llvm/llvm-project/commit/8010d7e0446a09471d5349466b80651c3ad76af3.diff

LOG: [mlir] add an option to print op stats in JSON

Differential Revision: https://reviews.llvm.org/D127691

Added: 
    mlir/test/IR/op-stats-json.mlir

Modified: 
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/OpStats.cpp
    mlir/test/CAPI/pass.c
    mlir/test/python/pass_manager.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 7ac71ceeb588a..8b8e6a1001574 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -145,6 +145,10 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
 def PrintOpStats : Pass<"print-op-stats"> {
   let summary = "Print statistics of operations";
   let constructor = "mlir::createPrintOpStatsPass()";
+  let options = [
+    Option<"printAsJSON", "json", "bool", /*default=*/"false",
+           "print the stats as JSON">
+  ];
 }
 
 def SCCP : Pass<"sccp"> {

diff  --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
index 8adc4117f9542..e7740abb3d19f 100644
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -27,6 +27,9 @@ struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> {
   // Print summary of op stats.
   void printSummary();
 
+  // Print summary of op stats in JSON.
+  void printSummaryInJSON();
+
 private:
   llvm::StringMap<int64_t> opCount;
   raw_ostream &os;
@@ -37,8 +40,12 @@ void PrintOpStatsPass::runOnOperation() {
   opCount.clear();
 
   // Compute the operation statistics for the currently visited operation.
-  getOperation()->walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
-  printSummary();
+  getOperation()->walk(
+      [&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
+  if (printAsJSON)
+    printSummaryInJSON();
+  else
+    printSummary();
 }
 
 void PrintOpStatsPass::printSummary() {
@@ -80,6 +87,42 @@ void PrintOpStatsPass::printSummary() {
   }
 }
 
+/// Writes `str` to `os` as a double-quoted JSON string literal. Op names of
+/// unregistered operations can be arbitrary strings (including quotes and
+/// backslashes), so they must be escaped for the output to be valid JSON.
+static void printJSONString(raw_ostream &os, StringRef str) {
+  static const char hexDigits[] = "0123456789abcdef";
+  os << '"';
+  for (char c : str) {
+    auto byte = static_cast<unsigned char>(c);
+    if (c == '"' || c == '\\')
+      os << '\\' << c;
+    else if (byte < 0x20) // JSON requires control characters to be escaped.
+      os << "\\u00" << hexDigits[byte >> 4] << hexDigits[byte & 0xF];
+    else
+      os << c;
+  }
+  os << '"';
+}
+
+/// Print the op statistics as a JSON object mapping each op name to the
+/// number of times it was encountered, with the keys sorted.
+void PrintOpStatsPass::printSummaryInJSON() {
+  SmallVector<StringRef, 64> sorted(opCount.keys());
+  llvm::sort(sorted);
+
+  os << "{\n";
+  for (unsigned i = 0, e = sorted.size(); i != e; ++i) {
+    StringRef key = sorted[i];
+    os << "  ";
+    printJSONString(os, key);
+    os << " : " << opCount.lookup(key);
+    // JSON does not allow a trailing comma after the last member.
+    os << (i + 1 != e ? ",\n" : "\n");
+  }
+  os << "}\n";
+}
+
 std::unique_ptr<Pass> mlir::createPrintOpStatsPass(raw_ostream &os) {
   return std::make_unique<PrintOpStatsPass>(os);
 }

diff  --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 63aba29e56f61..f73398e817beb 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -138,14 +138,14 @@ void testPrintPassPipeline() {
   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
 
   // Print the top level pass manager
-  // CHECK: Top-level: builtin.module(func.func(print-op-stats))
+  // CHECK: Top-level: builtin.module(func.func(print-op-stats{json=false}))
   fprintf(stderr, "Top-level: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
   fprintf(stderr, "\n");
 
   // Print the pipeline nested one level down
-  // CHECK: Nested Module: func.func(print-op-stats)
+  // CHECK: Nested Module: func.func(print-op-stats{json=false})
   fprintf(stderr, "Nested Module: ");
   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
   fprintf(stderr, "\n");
@@ -166,8 +166,9 @@ void testParsePassPipeline() {
   // Try parse a pipeline.
   MlirLogicalResult status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
-      mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats),"
-                                     " func.func(print-op-stats))"));
+      mlirStringRefCreateFromCString(
+          "builtin.module(func.func(print-op-stats{json=false}),"
+          " func.func(print-op-stats{json=false}))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
     fprintf(
@@ -179,8 +180,9 @@ void testParsePassPipeline() {
   mlirRegisterTransformsPrintOpStats();
   status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
-      mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats),"
-                                     " func.func(print-op-stats))"));
+      mlirStringRefCreateFromCString(
+          "builtin.module(func.func(print-op-stats{json=false}),"
+          " func.func(print-op-stats{json=false}))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsFailure(status)) {
     fprintf(stderr,
@@ -188,8 +190,8 @@ void testParsePassPipeline() {
     exit(EXIT_FAILURE);
   }
 
-  // CHECK: Round-trip: builtin.module(func.func(print-op-stats),
-  // func.func(print-op-stats))
+  // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}),
+  // func.func(print-op-stats{json=false}))
   fprintf(stderr, "Round-trip: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);

diff  --git a/mlir/test/IR/op-stats-json.mlir b/mlir/test/IR/op-stats-json.mlir
new file mode 100644
index 0000000000000..40b0602a95897
--- /dev/null
+++ b/mlir/test/IR/op-stats-json.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats=json %s -o=/dev/null 2>&1 | FileCheck %s
+
+func.func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
+^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
+  %0 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %1 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %2 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %3 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %4 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %5 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %10 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %11 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %12 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %13 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %14 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %15 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %16 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %17 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %18 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %19 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %20 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %21 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %22 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %23 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %24 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %25 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %26 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  %30 = "long_op_name"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
+// CHECK: {
+// CHECK-NEXT:   "arith.addf" : 6,
+// CHECK-NEXT:   "builtin.module" : 1,
+// CHECK-NEXT:   "func.func" : 1,
+// CHECK-NEXT:   "func.return" : 1,
+// CHECK-NEXT:   "long_op_name" : 1,
+// CHECK-NEXT:   "xla.add" : 17
+// CHECK-NEXT: }

diff  --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index c046bb818632e..6cc627d542337 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -36,19 +36,19 @@ def testParseSuccess():
     # A first import is expected to fail because the pass isn't registered
     # until we import mlir.transforms
     try:
-      pm = PassManager.parse("builtin.module(func.func(print-op-stats))")
+      pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
       # TODO: this error should be propagate to Python but the C API does not help right now.
       # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline
     except ValueError as e:
-      # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats))'.
+      # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats{json=false}))'.
       log("ValueError exception:", e)
     else:
       log("Exception not produced")
 
     # This will register the pass and round-trip should be possible now.
     import mlir.transforms
-    pm = PassManager.parse("builtin.module(func.func(print-op-stats))")
-    # CHECK: Roundtrip: builtin.module(func.func(print-op-stats))
+    pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
+    # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
     log("Roundtrip: ", pm)
 run(testParseSuccess)
 
@@ -86,7 +86,7 @@ def testInvalidNesting():
 # CHECK-LABEL: TEST: testRun
 def testRunPipeline():
   with Context():
-    pm = PassManager.parse("print-op-stats")
+    pm = PassManager.parse("print-op-stats{json=false}")
     module = Module.parse(r"""func.func @successfulParse() { return }""")
     pm.run(module)
 # CHECK: Operations encountered:


        


More information about the Mlir-commits mailing list