[Mlir-commits] [mlir] [mlir][python] simplify extensions (PR #69642)

Maksim Levental llvmlistbot at llvm.org
Thu Oct 19 14:33:29 PDT 2023


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/69642

PR https://github.com/llvm/llvm-project/pull/68853 enabled a lot of nice cleanup. Note: I made sure each of the touched extensions has tests.

From 725f960c1fe390e9291560cca934dde2cfd9440d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 19 Oct 2023 15:11:58 -0500
Subject: [PATCH] [mlir][python] simplify extensions

---
 mlir/python/mlir/dialects/affine.py        | 45 --------------
 mlir/python/mlir/dialects/bufferization.py | 36 -----------
 mlir/python/mlir/dialects/func.py          |  3 -
 mlir/python/mlir/dialects/memref.py        | 38 ------------
 mlir/python/mlir/dialects/pdl.py           | 69 ----------------------
 mlir/python/mlir/dialects/scf.py           | 33 +++--------
 mlir/test/python/dialects/affine.py        |  2 +-
 mlir/test/python/dialects/func.py          |  4 ++
 8 files changed, 13 insertions(+), 217 deletions(-)

diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 1eaccfa73a85cbf..80d3873e19a05cb 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,48 +3,3 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect
-
-try:
-    from ..ir import *
-    from ._ods_common import (
-        get_op_result_or_value as _get_op_result_or_value,
-        get_op_results_or_values as _get_op_results_or_values,
-        _cext as _ods_cext,
-    )
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class AffineStoreOp(AffineStoreOp):
-    """Specialization for the Affine store operation."""
-
-    def __init__(
-        self,
-        value: Union[Operation, OpView, Value],
-        memref: Union[Operation, OpView, Value],
-        map: AffineMap = None,
-        *,
-        map_operands=None,
-        loc=None,
-        ip=None,
-    ):
-        """Creates an affine store operation.
-
-        - `value`: the value to store into the memref.
-        - `memref`: the buffer to store into.
-        - `map`: the affine map that maps the map_operands to the index of the
-          memref.
-        - `map_operands`: the list of arguments to substitute the dimensions,
-          then symbols in the affine map, in increasing order.
-        """
-        map = map if map is not None else []
-        map_operands = map_operands if map_operands is not None else []
-        indicies = [_get_op_result_or_value(op) for op in map_operands]
-        _ods_successors = None
-        super().__init__(
-            value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
-        )
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 0ce5448ace4b14c..759b6aa24a9ff73 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -3,40 +3,4 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._bufferization_ops_gen import *
-from ._bufferization_ops_gen import _Dialect
 from ._bufferization_enum_gen import *
-
-try:
-    from typing import Sequence, Union
-    from ..ir import *
-    from ._ods_common import get_default_loc_context, _cext as _ods_cext
-
-    from typing import Any, List, Union
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class AllocTensorOp(AllocTensorOp):
-    """Extends the bufferization.alloc_tensor op."""
-
-    def __init__(
-        self,
-        tensor_type: Type,
-        dynamic_sizes: Sequence[Value],
-        copy: Value,
-        size_hint: Value,
-        escape: BoolAttr,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
-        super().__init__(
-            tensor_type,
-            dynamic_sizes,
-            copy=copy,
-            size_hint=size_hint,
-            loc=loc,
-            ip=ip,
-        )
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 9c6c4c9092c7a88..6599f67b7078777 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -26,9 +26,6 @@
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
 
-    def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
-        super().__init__(result, value, loc=loc, ip=ip)
-
     @property
     def type(self):
         return self.results[0].type
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 111ad2178703d28..3afb6a70cb9e0db 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -3,41 +3,3 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._memref_ops_gen import *
-from ._memref_ops_gen import _Dialect
-
-try:
-    from ..ir import *
-    from ._ods_common import (
-        get_op_result_or_value as _get_op_result_or_value,
-        get_op_results_or_values as _get_op_results_or_values,
-        _cext as _ods_cext,
-    )
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class LoadOp(LoadOp):
-    """Specialization for the MemRef load operation."""
-
-    def __init__(
-        self,
-        memref: Union[Operation, OpView, Value],
-        indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        """Creates a memref load operation.
-
-        Args:
-          memref: the buffer to load from.
-          indices: the list of subscripts, may be empty for zero-dimensional
-            buffers.
-          loc: user-visible location of the operation.
-          ip: insertion point.
-        """
-        indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
-        super().__init__(memref, indices_resolved, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index a8d9c56f4233d9e..90d7d706238e649 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -21,43 +21,6 @@
 )
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
-    """Specialization for PDL apply native constraint op class."""
-
-    def __init__(
-        self,
-        name: Union[str, StringAttr],
-        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        if args is None:
-            args = []
-        args = _get_values(args)
-        super().__init__(name, args, loc=loc, ip=ip)
-
-
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
-    """Specialization for PDL apply native rewrite op class."""
-
-    def __init__(
-        self,
-        results: Sequence[Type],
-        name: Union[str, StringAttr],
-        args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        if args is None:
-            args = []
-        args = _get_values(args)
-        super().__init__(results, name, args, loc=loc, ip=ip)
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class AttributeOp(AttributeOp):
     """Specialization for PDL attribute op class."""
@@ -75,21 +38,6 @@ def __init__(
         super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class EraseOp(EraseOp):
-    """Specialization for PDL erase op class."""
-
-    def __init__(
-        self,
-        operation: Optional[Union[OpView, Operation, Value]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        operation = _get_value(operation)
-        super().__init__(operation, loc=loc, ip=ip)
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class OperandOp(OperandOp):
     """Specialization for PDL operand op class."""
@@ -216,23 +164,6 @@ def __init__(
         super().__init__(result, parent, index, loc=loc, ip=ip)
 
 
-@_ods_cext.register_operation(_Dialect, replace=True)
-class ResultsOp(ResultsOp):
-    """Specialization for PDL results op class."""
-
-    def __init__(
-        self,
-        result: Type,
-        parent: Union[OpView, Operation, Value],
-        index: Optional[Union[IntegerAttr, int]] = None,
-        *,
-        loc=None,
-        ip=None,
-    ):
-        parent = _get_value(parent)
-        super().__init__(result, parent, index=index, loc=loc, ip=ip)
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class RewriteOp(RewriteOp):
     """Specialization for PDL rewrite op class."""
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 43ad9f4e2d65f51..71c80cab76dfb86 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -20,11 +20,8 @@
 from typing import Optional, Sequence, Union
 
 
-_ForOp = ForOp
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
-class ForOp(_ForOp):
+class ForOp(ForOp):
     """Specialization for the SCF for op class."""
 
     def __init__(
@@ -50,17 +47,8 @@ def __init__(
         iter_args = _get_op_results_or_values(iter_args)
 
         results = [arg.type for arg in iter_args]
-        super(_ForOp, self).__init__(
-            self.build_generic(
-                regions=1,
-                results=results,
-                operands=[
-                    _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
-                ]
-                + list(iter_args),
-                loc=loc,
-                ip=ip,
-            )
+        super().__init__(
+            results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
         )
         self.regions[0].blocks.append(self.operands[0].type, *results)
 
@@ -83,28 +71,23 @@ def inner_iter_args(self):
         return self.body.arguments[1:]
 
 
-_IfOp = IfOp
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
-class IfOp(_IfOp):
+class IfOp(IfOp):
     """Specialization for the SCF if op class."""
 
-    def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
+    def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
         """Creates an SCF `if` operation.
 
         - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
         - `hasElse` determines whether the if operation has the else branch.
         """
+        if results_ is None:
+            results_ = []
         operands = []
         operands.append(cond)
         results = []
         results.extend(results_)
-        super(_IfOp, self).__init__(
-            self.build_generic(
-                regions=2, results=results, operands=operands, loc=loc, ip=ip
-            )
-        )
+        super().__init__(results, cond)
         self.regions[0].blocks.append(*[])
         if hasElse:
             self.regions[1].blocks.append(*[])
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index d2e664d4653420f..c5ec85457493b42 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -37,7 +37,7 @@ def affine_store_test(arg0):
                 a1 = arith.ConstantOp(f32, 2.1)
 
                 # CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
-                affine.AffineStoreOp(a1, mem, map, map_operands=[arg0, arg0])
+                affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
 
                 return mem
 
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 161a12d78776a0e..a2014c64d2fa53b 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -84,6 +84,9 @@ def testFunctionCalls():
     qux = func.FuncOp("qux", ([], [F32Type.get()]))
     qux.sym_visibility = StringAttr.get("private")
 
+    con = func.ConstantOp(qux.type, qux.sym_name.value)
+    assert con.type == qux.type
+
     with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
         func.CallOp(foo, [])
         func.CallOp([IndexType.get()], "bar", [])
@@ -94,6 +97,7 @@ def testFunctionCalls():
 # CHECK: func private @foo()
 # CHECK: func private @bar() -> index
 # CHECK: func private @qux() -> f32
+# CHECK: %f = func.constant @qux : () -> f32
 # CHECK: func @caller() {
 # CHECK:   call @foo() : () -> ()
 # CHECK:   %0 = call @bar() : () -> index



More information about the Mlir-commits mailing list