[Mlir-commits] [mlir] c00f81c - [mlir][python] Allow running pass manager on any operation

Rahul Kayaith llvmlistbot at llvm.org
Wed Mar 1 15:17:24 PST 2023


Author: rkayaith
Date: 2023-03-01T18:17:14-05:00
New Revision: c00f81cc46bd88b001110e9c6564c4848f82900c

URL: https://github.com/llvm/llvm-project/commit/c00f81cc46bd88b001110e9c6564c4848f82900c
DIFF: https://github.com/llvm/llvm-project/commit/c00f81cc46bd88b001110e9c6564c4848f82900c.diff

LOG: [mlir][python] Allow running pass manager on any operation

`PassManager.run` is currently restricted to running on `builtin.module`
ops, but this restriction doesn't exist on the C++ side. This updates it
to take an `ir.Operation` or `ir.OpView` instead of an `ir.Module`.
Callers that hold an `ir.Module` should now pass `module.operation`.

Depends on D143354

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D143356

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
    mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
    mlir/test/python/execution_engine.py
    mlir/test/python/integration/dialects/linalg/opsrun.py
    mlir/test/python/pass_manager.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 99f17a18bc83c..7e90d8be66cb6 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -116,16 +116,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
-          [](PyPassManager &passManager, PyModule &module) {
+          [](PyPassManager &passManager, PyOperationBase &op) {
             MlirLogicalResult status = mlirPassManagerRunOnOp(
-                passManager.get(), mlirModuleGetOperation(module.get()));
+                passManager.get(), op.getOperation().get());
             if (mlirLogicalResultIsFailure(status))
               throw SetPyError(PyExc_RuntimeError,
                                "Failure while executing pass pipeline.");
           },
-          py::arg("module"),
-          "Run the pass manager on the provided module, throw a RuntimeError "
-          "on failure.")
+          py::arg("operation"),
+          "Run the pass manager on the provided operation, throw a "
+          "RuntimeError on failure.")
       .def(
           "__str__",
           [](PyPassManager &self) {

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
index abdab9738def7..25004f9492dbc 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py
@@ -24,7 +24,7 @@ def __call__(self, module: ir.Module):
 
   def compile(self, module: ir.Module):
     """Compiles the module by invoking the sparse copmiler pipeline."""
-    passmanager.PassManager.parse(self.pipeline).run(module)
+    passmanager.PassManager.parse(self.pipeline).run(module.operation)
 
   def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
     """Wraps the module in a JIT execution engine."""

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
index 1ba0d393894b9..69db28d4bccd5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_sparse_compiler.py
@@ -27,7 +27,7 @@ def __call__(self, module: ir.Module):
 
   def compile(self, module: ir.Module):
     """Compiles the module by invoking the sparse copmiler pipeline."""
-    passmanager.PassManager.parse(self.pipeline).run(module)
+    passmanager.PassManager.parse(self.pipeline).run(module.operation)
 
   def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
     """Wraps the module in a JIT execution engine."""

diff  --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index c40d522b8ca2a..973810d5b82f4 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -64,7 +64,7 @@ def testInvalidModule():
 def lowerToLLVM(module):
   pm = PassManager.parse(
       "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)")
-  pm.run(module)
+  pm.run(module.operation)
   return module
 
 

diff  --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 2a890eafa80f7..2cba577b33266 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -202,7 +202,7 @@ def transform(module, boilerplate):
   pm.add("finalize-memref-to-llvm")
   pm.add("convert-func-to-llvm")
   pm.add("reconcile-unrealized-casts")
-  pm.run(mod)
+  pm.run(mod.operation)
   return mod
 
 

diff  --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 2943881ec85eb..b3acd359a207d 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -3,6 +3,7 @@
 import gc, sys
 from mlir.ir import *
 from mlir.passmanager import *
+from mlir.dialects.func import FuncOp
 
 # Log everything to stderr and flush so that we have a unified stream to match
 # errors/info emitted by MLIR to stderr.
@@ -120,11 +121,10 @@ def testInvalidNesting():
 # CHECK-LABEL: TEST: testRun
 def testRunPipeline():
   with Context():
-    pm = PassManager.parse("builtin.module(print-op-stats{json=false})")
-    module = Module.parse(r"""func.func @successfulParse() { return }""")
-    pm.run(module)
+    pm = PassManager.parse("any(print-op-stats{json=false})")
+    func = FuncOp.parse(r"""func.func @successfulParse() { return }""")
+    pm.run(func)
 # CHECK: Operations encountered:
-# CHECK: builtin.module    , 1
 # CHECK: func.func      , 1
 # CHECK: func.return        , 1
 run(testRunPipeline)


        


More information about the Mlir-commits mailing list