[Mlir-commits] [mlir] 8010d7e - [mlir] add an option to print op stats in JSON
Okwan Kwon
llvmlistbot at llvm.org
Wed Jun 15 10:07:46 PDT 2022
Author: Okwan Kwon
Date: 2022-06-15T10:07:36-07:00
New Revision: 8010d7e0446a09471d5349466b80651c3ad76af3
URL: https://github.com/llvm/llvm-project/commit/8010d7e0446a09471d5349466b80651c3ad76af3
DIFF: https://github.com/llvm/llvm-project/commit/8010d7e0446a09471d5349466b80651c3ad76af3.diff
LOG: [mlir] add an option to print op stats in JSON
Differential Revision: https://reviews.llvm.org/D127691
Added:
mlir/test/IR/op-stats-json.mlir
Modified:
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Transforms/OpStats.cpp
mlir/test/CAPI/pass.c
mlir/test/python/pass_manager.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 7ac71ceeb588a..8b8e6a1001574 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -145,6 +145,10 @@ def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> {
def PrintOpStats : Pass<"print-op-stats"> {
let summary = "Print statistics of operations";
let constructor = "mlir::createPrintOpStatsPass()";
+ // `json` (printAsJSON) switches the output from the default summary to a
+ // JSON object that maps each operation name to the number of times it occurs.
+ let options = [
+ Option<"printAsJSON", "json", "bool", /*default=*/"false",
+ "print the stats as JSON">
+ ];
}
def SCCP : Pass<"sccp"> {
diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
index 8adc4117f9542..e7740abb3d19f 100644
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -27,6 +27,9 @@ struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> {
// Print summary of op stats.
void printSummary();
+ // Print summary of op stats in JSON.
+ void printSummaryInJSON();
+
private:
llvm::StringMap<int64_t> opCount;
raw_ostream &os;
@@ -37,8 +40,12 @@ void PrintOpStatsPass::runOnOperation() {
opCount.clear();
// Compute the operation statistics for the currently visited operation.
- getOperation()->walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
- printSummary();
+  getOperation()->walk(
+      [&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
+
+  // The `json` pass option selects the JSON summary over the default one.
+  if (printAsJSON)
+    printSummaryInJSON();
+  else
+    printSummary();
}
void PrintOpStatsPass::printSummary() {
@@ -80,6 +87,23 @@ void PrintOpStatsPass::printSummary() {
}
}
+/// Write `str` to `os` as a JSON string literal. Double quotes and backslashes
+/// are backslash-escaped, and control characters (bytes below 0x20) become
+/// six-character `\u00hh` escapes, as JSON requires. Every other byte,
+/// including multi-byte UTF-8 sequences, is copied through unchanged.
+static void printJSONString(raw_ostream &os, StringRef str) {
+  static const char hexDigits[] = "0123456789abcdef";
+  os << '"';
+  for (unsigned char c : str) {
+    if (c == '"' || c == '\\')
+      os << '\\' << c;
+    else if (c < 0x20)
+      os << "\\u00" << hexDigits[c >> 4] << hexDigits[c & 0xF];
+    else
+      os << c;
+  }
+  os << '"';
+}
+
+/// Print the collected op statistics as a single JSON object that maps each
+/// operation name to the number of times it was encountered. Keys are emitted
+/// in sorted order, one member per line.
+void PrintOpStatsPass::printSummaryInJSON() {
+  SmallVector<StringRef, 64> sorted(opCount.keys());
+  llvm::sort(sorted);
+
+  os << "{\n";
+
+  for (unsigned i = 0, e = sorted.size(); i != e; ++i) {
+    StringRef key = sorted[i];
+    // Op names (notably those of unregistered ops) are not guaranteed to be
+    // JSON-safe, so escape them instead of writing them verbatim.
+    os << ' ';
+    printJSONString(os, key);
+    os << " : " << opCount[key];
+    // JSON does not allow a trailing comma after the last member.
+    if (i != e - 1)
+      os << ",\n";
+    else
+      os << "\n";
+  }
+  os << "}\n";
+}
+
+/// Create a PrintOpStatsPass bound to the output stream `os`.
std::unique_ptr<Pass> mlir::createPrintOpStatsPass(raw_ostream &os) {
return std::make_unique<PrintOpStatsPass>(os);
}
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 63aba29e56f61..f73398e817beb 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -138,14 +138,14 @@ void testPrintPassPipeline() {
mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
// Print the top level pass manager
- // CHECK: Top-level: builtin.module(func.func(print-op-stats))
+ // CHECK: Top-level: builtin.module(func.func(print-op-stats{json=false}))
fprintf(stderr, "Top-level: ");
mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
NULL);
fprintf(stderr, "\n");
// Print the pipeline nested one level down
- // CHECK: Nested Module: func.func(print-op-stats)
+ // CHECK: Nested Module: func.func(print-op-stats{json=false})
fprintf(stderr, "Nested Module: ");
mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
fprintf(stderr, "\n");
@@ -166,8 +166,9 @@ void testParsePassPipeline() {
// Try parse a pipeline.
MlirLogicalResult status = mlirParsePassPipeline(
mlirPassManagerGetAsOpPassManager(pm),
- mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats),"
- " func.func(print-op-stats))"));
+ mlirStringRefCreateFromCString(
+ "builtin.module(func.func(print-op-stats{json=false}),"
+ " func.func(print-op-stats{json=false}))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsSuccess(status)) {
fprintf(
@@ -179,8 +180,9 @@ void testParsePassPipeline() {
mlirRegisterTransformsPrintOpStats();
status = mlirParsePassPipeline(
mlirPassManagerGetAsOpPassManager(pm),
- mlirStringRefCreateFromCString("builtin.module(func.func(print-op-stats),"
- " func.func(print-op-stats))"));
+ mlirStringRefCreateFromCString(
+ "builtin.module(func.func(print-op-stats{json=false}),"
+ " func.func(print-op-stats{json=false}))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr,
@@ -188,8 +190,8 @@ void testParsePassPipeline() {
exit(EXIT_FAILURE);
}
- // CHECK: Round-trip: builtin.module(func.func(print-op-stats),
- // func.func(print-op-stats))
+ // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}),
+ // func.func(print-op-stats{json=false}))
fprintf(stderr, "Round-trip: ");
mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
NULL);
diff --git a/mlir/test/IR/op-stats-json.mlir b/mlir/test/IR/op-stats-json.mlir
new file mode 100644
index 0000000000000..40b0602a95897
--- /dev/null
+++ b/mlir/test/IR/op-stats-json.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt -allow-unregistered-dialect -print-op-stats=json %s -o=/dev/null 2>&1 | FileCheck %s
+
+// Checks the `json` option of -print-op-stats: one JSON object, keys in
+// sorted order, a comma after every member except the last, and the closing
+// brace on the line right after the last member.
+
+func.func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
+^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
+ %0 = arith.addf %arg0, %arg1 : tensor<4xf32>
+ %1 = arith.addf %arg0, %arg1 : tensor<4xf32>
+ %2 = arith.addf %arg0, %arg1 : tensor<4xf32>
+ %3 = arith.addf %arg0, %arg1 : tensor<4xf32>
+ %4 = arith.addf %arg0, %arg1 : tensor<4xf32>
+ %5 = arith.addf %arg0, %arg1 : tensor<4xf32>
+ %10 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %11 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %12 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %13 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %14 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %15 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %16 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %17 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %18 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %19 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %20 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %21 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %22 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %23 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %24 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %25 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %26 = "xla.add"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ %30 = "long_op_name"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+ return %1 : tensor<4xf32>
+}
+
+// CHECK: {
+// CHECK: "arith.addf" : 6,
+// CHECK: "func.return" : 1,
+// CHECK: "long_op_name" : 1,
+// CHECK: "xla.add" : 17{{$}}
+// CHECK-NEXT: }
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index c046bb818632e..6cc627d542337 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -36,19 +36,19 @@ def testParseSuccess():
# A first import is expected to fail because the pass isn't registered
# until we import mlir.transforms
try:
- pm = PassManager.parse("builtin.module(func.func(print-op-stats))")
+ pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
# TODO: this error should be propagate to Python but the C API does not help right now.
# CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline
except ValueError as e:
- # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats))'.
+ # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(print-op-stats{json=false}))'.
log("ValueError exception:", e)
else:
log("Exception not produced")
# This will register the pass and round-trip should be possible now.
import mlir.transforms
- pm = PassManager.parse("builtin.module(func.func(print-op-stats))")
- # CHECK: Roundtrip: builtin.module(func.func(print-op-stats))
+ pm = PassManager.parse("builtin.module(func.func(print-op-stats{json=false}))")
+ # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
log("Roundtrip: ", pm)
run(testParseSuccess)
@@ -86,7 +86,7 @@ def testInvalidNesting():
# CHECK-LABEL: TEST: testRun
def testRunPipeline():
with Context():
- pm = PassManager.parse("print-op-stats")
+ pm = PassManager.parse("print-op-stats{json=false}")
module = Module.parse(r"""func.func @successfulParse() { return }""")
pm.run(module)
# CHECK: Operations encountered:
More information about the Mlir-commits
mailing list