[Mlir-commits] [mlir] 06e2c78 - [MLIR][Python] Pass OpView subclasses instead of Operation in rewrite patterns (#163080)

llvmlistbot at llvm.org
Sun Oct 12 20:57:00 PDT 2025


Author: Twice
Date: 2025-10-13T11:56:57+08:00
New Revision: 06e2c78680d753d97b0cd6b7a86b4dbd0dbfb1e9

URL: https://github.com/llvm/llvm-project/commit/06e2c78680d753d97b0cd6b7a86b4dbd0dbfb1e9
DIFF: https://github.com/llvm/llvm-project/commit/06e2c78680d753d97b0cd6b7a86b4dbd0dbfb1e9.diff

LOG: [MLIR][Python] Pass OpView subclasses instead of Operation in rewrite patterns (#163080)

This is a follow-up PR for #162699.

Currently, the `op` passed to a rewrite pattern callback is a generic
`ir.Operation` rather than a specific `OpView` subclass (such as
`arith.AddIOp`). As a result, the callback cannot use the operation's
named accessors: for example, it has to write `op.operands[0]` instead
of `op.lhs`. The following example code illustrates this situation.

```python
def to_muli(op, rewriter):
  # op is typed ir.Operation instead of arith.AddIOp
  pass

patterns.add(arith.AddIOp, to_muli)
```

In this PR, we convert the operation to its corresponding `OpView`
subclass before invoking the rewrite pattern callback, making it much
easier to write patterns.
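
The conversion step can be pictured with a small pure-Python model. This is
only a sketch of the idea, not the MLIR bindings: `Operation`, `OpView`,
`AddIOp`, `OP_VIEWS`, `register_view`, `create_op_view` and `invoke_pattern`
below are hypothetical stand-ins defined here for illustration, and the
fallback to a plain `OpView` for unregistered names is an assumption of the
toy model.

```python
# Toy model only: these classes are hypothetical stand-ins, not the mlir.ir API.


class Operation:
    """Generic operation: a name plus a positional operand list."""

    def __init__(self, name, operands):
        self.name = name
        self.operands = list(operands)


class OpView:
    """Base class for typed wrappers around a generic Operation."""

    def __init__(self, operation):
        self.operation = operation


# Hypothetical registry mapping operation names to OpView subclasses.
OP_VIEWS = {}


def register_view(cls):
    OP_VIEWS[cls.OPERATION_NAME] = cls
    return cls


@register_view
class AddIOp(OpView):
    OPERATION_NAME = "arith.addi"

    @property
    def lhs(self):
        return self.operation.operands[0]

    @property
    def rhs(self):
        return self.operation.operands[1]


def create_op_view(operation):
    """Wrap in the registered subclass, or in plain OpView if none is registered."""
    return OP_VIEWS.get(operation.name, OpView)(operation)


def invoke_pattern(callback, operation):
    """Convert to the typed view before calling the user's callback."""
    return callback(create_op_view(operation))


def describe(op):
    # The callback now sees an AddIOp, so named accessors are available.
    assert isinstance(op, AddIOp)
    return (op.lhs, op.rhs)


result = invoke_pattern(describe, Operation("arith.addi", ["%a", "%b"]))
```

In this model the callback never touches `operands[0]` directly; the wrapper
chosen by operation name exposes `lhs` and `rhs`, which mirrors the
convenience the PR is after.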

---------

Co-authored-by: Maksim Levental <maksim.levental at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/Rewrite.cpp
    mlir/python/mlir/dialects/arith.py
    mlir/test/python/rewrite.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 47685567d5355..5ddb3fbbb1317 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -197,7 +197,12 @@ class PyRewritePatternSet {
                                    MlirPatternRewriter rewriter,
                                    void *userData) -> MlirLogicalResult {
       nb::handle f(static_cast<PyObject *>(userData));
-      nb::object res = f(op, PyPatternRewriter(rewriter));
+
+      PyMlirContextRef ctx =
+          PyMlirContext::forContext(mlirOperationGetContext(op));
+      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
+
+      nb::object res = f(opView, PyPatternRewriter(rewriter));
       return logicalResultFromObject(res);
     };
     MlirRewritePattern pattern = mlirOpRewritePattenCreate(

diff  --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 92da5df9bce66..88e8502a29eae 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -92,7 +92,7 @@ def type(self):
 
     @property
     def value(self):
-        return Attribute(self.operation.attributes["value"])
+        return self.operation.attributes["value"]
 
     @property
     def literal_value(self) -> Union[int, float]:

diff  --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
index acf7db23db914..821e47085a5bd 100644
--- a/mlir/test/python/rewrite.py
+++ b/mlir/test/python/rewrite.py
@@ -17,15 +17,16 @@ def run(f):
 def testRewritePattern():
     def to_muli(op, rewriter):
         with rewriter.ip:
-            new_op = arith.muli(op.operands[0], op.operands[1], loc=op.location)
+            assert isinstance(op, arith.AddIOp)
+            new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
         rewriter.replace_op(op, new_op.owner)
 
     def constant_1_to_2(op, rewriter):
-        c = op.attributes["value"].value
+        c = op.value.value
         if c != 1:
             return True  # failed to match
         with rewriter.ip:
-            new_op = arith.constant(op.result.type, 2, loc=op.location)
+            new_op = arith.constant(op.type, 2, loc=op.location)
         rewriter.replace_op(op, [new_op])
 
     with Context():


        

