[Mlir-commits] [mlir] [mlir][python] simplify extensions (PR #69642)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 19 14:35:07 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
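https://github.com/llvm/llvm-project/pull/68853 enabled a lot of cleanup: Python extension classes whose constructors only wrapped the generated op builders are removed, and the remaining `scf.ForOp`/`scf.IfOp` extensions now call the generated builders via `super().__init__` instead of `build_generic`. As a result, `affine.AffineStoreOp` uses the generated signature (e.g. `indices=` instead of `map_operands=`). Note: I made sure each of the touched extensions has tests.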
---
Full diff: https://github.com/llvm/llvm-project/pull/69642.diff
8 Files Affected:
- (modified) mlir/python/mlir/dialects/affine.py (-45)
- (modified) mlir/python/mlir/dialects/bufferization.py (-36)
- (modified) mlir/python/mlir/dialects/func.py (-3)
- (modified) mlir/python/mlir/dialects/memref.py (-38)
- (modified) mlir/python/mlir/dialects/pdl.py (-69)
- (modified) mlir/python/mlir/dialects/scf.py (+8-25)
- (modified) mlir/test/python/dialects/affine.py (+1-1)
- (modified) mlir/test/python/dialects/func.py (+4)
``````````diff
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 1eaccfa73a85cbf..80d3873e19a05cb 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -3,48 +3,3 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
-from ._affine_ops_gen import _Dialect
-
-try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- _cext as _ods_cext,
- )
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
- at _ods_cext.register_operation(_Dialect, replace=True)
-class AffineStoreOp(AffineStoreOp):
- """Specialization for the Affine store operation."""
-
- def __init__(
- self,
- value: Union[Operation, OpView, Value],
- memref: Union[Operation, OpView, Value],
- map: AffineMap = None,
- *,
- map_operands=None,
- loc=None,
- ip=None,
- ):
- """Creates an affine store operation.
-
- - `value`: the value to store into the memref.
- - `memref`: the buffer to store into.
- - `map`: the affine map that maps the map_operands to the index of the
- memref.
- - `map_operands`: the list of arguments to substitute the dimensions,
- then symbols in the affine map, in increasing order.
- """
- map = map if map is not None else []
- map_operands = map_operands if map_operands is not None else []
- indicies = [_get_op_result_or_value(op) for op in map_operands]
- _ods_successors = None
- super().__init__(
- value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
- )
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 0ce5448ace4b14c..759b6aa24a9ff73 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -3,40 +3,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._bufferization_ops_gen import *
-from ._bufferization_ops_gen import _Dialect
from ._bufferization_enum_gen import *
-
-try:
- from typing import Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context, _cext as _ods_cext
-
- from typing import Any, List, Union
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-
- at _ods_cext.register_operation(_Dialect, replace=True)
-class AllocTensorOp(AllocTensorOp):
- """Extends the bufferization.alloc_tensor op."""
-
- def __init__(
- self,
- tensor_type: Type,
- dynamic_sizes: Sequence[Value],
- copy: Value,
- size_hint: Value,
- escape: BoolAttr,
- *,
- loc=None,
- ip=None,
- ):
- """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
- super().__init__(
- tensor_type,
- dynamic_sizes,
- copy=copy,
- size_hint=size_hint,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 9c6c4c9092c7a88..6599f67b7078777 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -26,9 +26,6 @@
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
- def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
- super().__init__(result, value, loc=loc, ip=ip)
-
@property
def type(self):
return self.results[0].type
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 111ad2178703d28..3afb6a70cb9e0db 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -3,41 +3,3 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._memref_ops_gen import *
-from ._memref_ops_gen import _Dialect
-
-try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- _cext as _ods_cext,
- )
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
- at _ods_cext.register_operation(_Dialect, replace=True)
-class LoadOp(LoadOp):
- """Specialization for the MemRef load operation."""
-
- def __init__(
- self,
- memref: Union[Operation, OpView, Value],
- indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates a memref load operation.
-
- Args:
- memref: the buffer to load from.
- indices: the list of subscripts, may be empty for zero-dimensional
- buffers.
- loc: user-visible location of the operation.
- ip: insertion point.
- """
- indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
- super().__init__(memref, indices_resolved, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index a8d9c56f4233d9e..90d7d706238e649 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -21,43 +21,6 @@
)
- at _ods_cext.register_operation(_Dialect, replace=True)
-class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
- """Specialization for PDL apply native constraint op class."""
-
- def __init__(
- self,
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(name, args, loc=loc, ip=ip)
-
-
- at _ods_cext.register_operation(_Dialect, replace=True)
-class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
- """Specialization for PDL apply native rewrite op class."""
-
- def __init__(
- self,
- results: Sequence[Type],
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(results, name, args, loc=loc, ip=ip)
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class AttributeOp(AttributeOp):
"""Specialization for PDL attribute op class."""
@@ -75,21 +38,6 @@ def __init__(
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
- at _ods_cext.register_operation(_Dialect, replace=True)
-class EraseOp(EraseOp):
- """Specialization for PDL erase op class."""
-
- def __init__(
- self,
- operation: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- operation = _get_value(operation)
- super().__init__(operation, loc=loc, ip=ip)
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class OperandOp(OperandOp):
"""Specialization for PDL operand op class."""
@@ -216,23 +164,6 @@ def __init__(
super().__init__(result, parent, index, loc=loc, ip=ip)
- at _ods_cext.register_operation(_Dialect, replace=True)
-class ResultsOp(ResultsOp):
- """Specialization for PDL results op class."""
-
- def __init__(
- self,
- result: Type,
- parent: Union[OpView, Operation, Value],
- index: Optional[Union[IntegerAttr, int]] = None,
- *,
- loc=None,
- ip=None,
- ):
- parent = _get_value(parent)
- super().__init__(result, parent, index=index, loc=loc, ip=ip)
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
class RewriteOp(RewriteOp):
"""Specialization for PDL rewrite op class."""
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 43ad9f4e2d65f51..71c80cab76dfb86 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -20,11 +20,8 @@
from typing import Optional, Sequence, Union
-_ForOp = ForOp
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
-class ForOp(_ForOp):
+class ForOp(ForOp):
"""Specialization for the SCF for op class."""
def __init__(
@@ -50,17 +47,8 @@ def __init__(
iter_args = _get_op_results_or_values(iter_args)
results = [arg.type for arg in iter_args]
- super(_ForOp, self).__init__(
- self.build_generic(
- regions=1,
- results=results,
- operands=[
- _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
- ]
- + list(iter_args),
- loc=loc,
- ip=ip,
- )
+ super().__init__(
+ results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip
)
self.regions[0].blocks.append(self.operands[0].type, *results)
@@ -83,28 +71,23 @@ def inner_iter_args(self):
return self.body.arguments[1:]
-_IfOp = IfOp
-
-
@_ods_cext.register_operation(_Dialect, replace=True)
-class IfOp(_IfOp):
+class IfOp(IfOp):
"""Specialization for the SCF if op class."""
- def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
+ def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
"""Creates an SCF `if` operation.
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- `hasElse` determines whether the if operation has the else branch.
"""
+ if results_ is None:
+ results_ = []
operands = []
operands.append(cond)
results = []
results.extend(results_)
- super(_IfOp, self).__init__(
- self.build_generic(
- regions=2, results=results, operands=operands, loc=loc, ip=ip
- )
- )
+ super().__init__(results, cond)
self.regions[0].blocks.append(*[])
if hasElse:
self.regions[1].blocks.append(*[])
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index d2e664d4653420f..c5ec85457493b42 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -37,7 +37,7 @@ def affine_store_test(arg0):
a1 = arith.ConstantOp(f32, 2.1)
# CHECK: affine.store %[[A1]], %alloc[symbol(%[[ARG0]]) * 3, %[[ARG0]] + symbol(%[[ARG0]]) + 1] : memref<12x12xf32>
- affine.AffineStoreOp(a1, mem, map, map_operands=[arg0, arg0])
+ affine.AffineStoreOp(a1, mem, indices=[arg0, arg0], map=map)
return mem
diff --git a/mlir/test/python/dialects/func.py b/mlir/test/python/dialects/func.py
index 161a12d78776a0e..a2014c64d2fa53b 100644
--- a/mlir/test/python/dialects/func.py
+++ b/mlir/test/python/dialects/func.py
@@ -84,6 +84,9 @@ def testFunctionCalls():
qux = func.FuncOp("qux", ([], [F32Type.get()]))
qux.sym_visibility = StringAttr.get("private")
+ con = func.ConstantOp(qux.type, qux.sym_name.value)
+ assert con.type == qux.type
+
with InsertionPoint(func.FuncOp("caller", ([], [])).add_entry_block()):
func.CallOp(foo, [])
func.CallOp([IndexType.get()], "bar", [])
@@ -94,6 +97,7 @@ def testFunctionCalls():
# CHECK: func private @foo()
# CHECK: func private @bar() -> index
# CHECK: func private @qux() -> f32
+# CHECK: %f = func.constant @qux : () -> f32
# CHECK: func @caller() {
# CHECK: call @foo() : () -> ()
# CHECK: %0 = call @bar() : () -> index
``````````
</details>
https://github.com/llvm/llvm-project/pull/69642
More information about the Mlir-commits mailing list