[Mlir-commits] [mlir] c9986f8 - [mlir][Pass] Handle spaces in pipeline strings

Min-Yih Hsu llvmlistbot at llvm.org
Mon Jan 30 12:51:21 PST 2023


Author: Bruno Schmitt
Date: 2023-01-30T12:49:19-08:00
New Revision: c9986f8263980ce6bc9ef87d96d97911eaad547d

URL: https://github.com/llvm/llvm-project/commit/c9986f8263980ce6bc9ef87d96d97911eaad547d
DIFF: https://github.com/llvm/llvm-project/commit/c9986f8263980ce6bc9ef87d96d97911eaad547d.diff

LOG: [mlir][Pass] Handle spaces in pipeline strings

A user might want to add extra spaces for better readability, e.g.:
```
mypm = pm.PassManager.parse(f"""builtin.module(
    mypass1,
        func.func(mypass2,mypass3)
)""")
```
GitHub issue #59151

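The parser did not account for the possibility of spaces after `)` or `}`.
This change also trims whitespace around the whole pipeline string and
between the anchor operation name and its opening `(`.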

Differential Revision: https://reviews.llvm.org/D142821

Added: 
    

Modified: 
    mlir/lib/Pass/PassRegistry.cpp
    mlir/test/Pass/pipeline-parsing.mlir
    mlir/test/python/pass_manager.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 3d10fe14d43d9..42d65418344ee 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -588,6 +588,9 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
       pipeline.back().options = text.substr(0, close);
       text = text.substr(close + 1);
 
+      // Consume space characters that an user might add for readability.
+      text = text.ltrim();
+
       // Skip checking for '(' because nested pipelines cannot have options.
     } else if (sep == '(') {
       text = text.substr(1);
@@ -607,6 +610,8 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
                             "parentheses while parsing pipeline");
 
       pipelineStack.pop_back();
+      // Consume space characters that an user might add for readability.
+      text = text.ltrim();
     }
 
     // Check if we've finished parsing.
@@ -703,6 +708,7 @@ LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
 
 FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
                                                  raw_ostream &errorStream) {
+  pipeline = pipeline.trim();
   // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
   size_t pipelineStart = pipeline.find_first_of('(');
   if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
@@ -712,7 +718,7 @@ FailureOr<OpPassManager> mlir::parsePassPipeline(StringRef pipeline,
     return failure();
   }
 
-  StringRef opName = pipeline.take_front(pipelineStart);
+  StringRef opName = pipeline.take_front(pipelineStart).rtrim();
   OpPassManager pm(opName);
   if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm,
                                errorStream)))

diff  --git a/mlir/test/Pass/pipeline-parsing.mlir b/mlir/test/Pass/pipeline-parsing.mlir
index 6291dd647391b..f41553d2669f0 100644
--- a/mlir/test/Pass/pipeline-parsing.mlir
+++ b/mlir/test/Pass/pipeline-parsing.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(builtin.module(test-module-pass,func.func(test-function-pass)),func.func(test-function-pass),func.func(cse,canonicalize))' -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s
 // RUN: mlir-opt %s -mlir-disable-threading -test-textual-pm-nested-pipeline -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=TEXTUAL_CHECK
 // RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(builtin.module(test-module-pass),any(test-interface-pass),any(test-interface-pass),func.func(test-function-pass),any(canonicalize),func.func(cse))' -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=GENERIC_MERGE_CHECK
+// RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline=' builtin.module ( builtin.module( func.func( test-function-pass, print-op-stats{ json=false } ) ) ) ' -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=PIPELINE_STR_WITH_SPACES_CHECK
 // RUN: not mlir-opt %s -pass-pipeline='any(builtin.module(test-module-pass)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s
 // RUN: not mlir-opt %s -pass-pipeline='builtin.module(test-module-pass))' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
 // RUN: not mlir-opt %s -pass-pipeline='any(builtin.module()()' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
@@ -51,6 +52,11 @@ module {
 // TEXTUAL_CHECK-NEXT:     'func.func' Pipeline
 // TEXTUAL_CHECK-NEXT:       TestFunctionPass
 
+// PIPELINE_STR_WITH_SPACES_CHECK:   'builtin.module' Pipeline
+// PIPELINE_STR_WITH_SPACES_CHECK-NEXT:   'func.func' Pipeline
+// PIPELINE_STR_WITH_SPACES_CHECK-NEXT:     TestFunctionPass
+// PIPELINE_STR_WITH_SPACES_CHECK-NEXT:     PrintOpStats
+
 // Check that generic pass pipelines are only merged when they aren't
 // going to overlap with op-specific pipelines.
 // GENERIC_MERGE_CHECK:      Pipeline Collection : ['builtin.module', 'any']

diff  --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 492c7e09ec5ae..2943881ec85eb 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -59,6 +59,18 @@ def testParseSuccess():
     log("Roundtrip: ", pm)
 run(testParseSuccess)
 
+# Verify successful round-trip.
+# CHECK-LABEL: TEST: testParseSpacedPipeline
+def testParseSpacedPipeline():
+  with Context():
+    # A registered pass should parse successfully even if has extras spaces for readability
+    pm = PassManager.parse("""builtin.module(
+        func.func( print-op-stats{ json=false } )
+    )""")
+    # CHECK: Roundtrip: builtin.module(func.func(print-op-stats{json=false}))
+    log("Roundtrip: ", pm)
+run(testParseSpacedPipeline)
+
 # Verify failure on unregistered pass.
 # CHECK-LABEL: TEST: testParseFail
 def testParseFail():


        


More information about the Mlir-commits mailing list