[Mlir-commits] [mlir] dd1b1d4 - [mlir][python] Allow adding to existing pass manager

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 4 09:04:32 PDT 2022


Author: rkayaith
Date: 2022-11-04T12:04:26-04:00
New Revision: dd1b1d44503288f98b298f1ec8b374137d2812e2

URL: https://github.com/llvm/llvm-project/commit/dd1b1d44503288f98b298f1ec8b374137d2812e2
DIFF: https://github.com/llvm/llvm-project/commit/dd1b1d44503288f98b298f1ec8b374137d2812e2.diff

LOG: [mlir][python] Allow adding to existing pass manager

This adds a `PassManager.add` method which adds pipeline elements to the
pass manager. This allows for progressively building up a pipeline from
python without string manipulation.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D137344

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/test/python/integration/dialects/linalg/opsrun.py
    mlir/test/python/pass_manager.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 13f1cfa3536a..cb3c1586eb99 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -100,6 +100,20 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           "Parse a textual pass-pipeline and return a top-level PassManager "
           "that can be applied on a Module. Throw a ValueError if the pipeline "
           "can't be parsed")
+      .def(
+          "add",
+          [](PyPassManager &passManager, const std::string &pipeline) {
+            PyPrintAccumulator errorMsg;
+            MlirLogicalResult status = mlirOpPassManagerAddPipeline(
+                mlirPassManagerGetAsOpPassManager(passManager.get()),
+                mlirStringRefCreate(pipeline.data(), pipeline.size()),
+                errorMsg.getCallback(), errorMsg.getUserData());
+            if (mlirLogicalResultIsFailure(status))
+              throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
+          },
+          py::arg("pipeline"),
+          "Add textual pipeline elements to the pass manager. Throws a "
+          "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
           [](PyPassManager &passManager, PyModule &module) {

diff  --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 2075ecfc21d0..585741ae9336 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -191,11 +191,17 @@ def transform(module, boilerplate):
   ops = module.operation.regions[0].blocks[0].operations
   mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
 
-  pm = PassManager.parse(
-      "builtin.module(func.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
-      + "convert-vector-to-llvm, convert-memref-to-llvm, convert-func-to-llvm," +
-      "reconcile-unrealized-casts)")
+  pm = PassManager('builtin.module')
+  pm.add("func.func(convert-linalg-to-loops)")
+  pm.add("func.func(lower-affine)")
+  pm.add("func.func(convert-math-to-llvm)")
+  pm.add("func.func(convert-scf-to-cf)")
+  pm.add("func.func(arith-expand)")
+  pm.add("func.func(memref-expand)")
+  pm.add("convert-vector-to-llvm")
+  pm.add("convert-memref-to-llvm")
+  pm.add("convert-func-to-llvm")
+  pm.add("reconcile-unrealized-casts")
   pm.run(mod)
   return mod
 

diff  --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 04e325e13e78..492c7e09ec5a 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -75,6 +75,20 @@ def testParseFail():
       log("Exception not produced")
 run(testParseFail)
 
+# Check that adding to a pass manager works
+# CHECK-LABEL: TEST: testAdd
+@run
+def testAdd():
+  pm = PassManager("any", Context())
+  # CHECK: pm: 'any()'
+  log(f"pm: '{pm}'")
+  # CHECK: pm: 'any(cse)'
+  pm.add("cse")
+  log(f"pm: '{pm}'")
+  # CHECK: pm: 'any(cse,cse)'
+  pm.add("cse")
+  log(f"pm: '{pm}'")
+
 
 # Verify failure on incorrect level of nesting.
 # CHECK-LABEL: TEST: testInvalidNesting


        


More information about the Mlir-commits mailing list