[Mlir-commits] [mlir] [mlir][python] remove mixins (PR #68853)
Maksim Levental
llvmlistbot at llvm.org
Thu Oct 12 07:54:26 PDT 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/68853
>From 1b5e4bafc7cb7aab4aa4109ce1441568bbff1088 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 01:16:19 -0500
Subject: [PATCH 1/5] delete stubs
---
mlir/python/mlir/dialects/_linalg_ops_ext.py | 47 ---------------
mlir/python/mlir/dialects/_ods_common.py | 59 -------------------
mlir/python/mlir/dialects/arith.py | 6 --
mlir/python/mlir/dialects/bufferization.py | 6 --
mlir/python/mlir/dialects/builtin.py | 5 --
mlir/python/mlir/dialects/func.py | 5 --
mlir/python/mlir/dialects/memref.py | 5 --
mlir/python/mlir/dialects/ml_program.py | 5 --
mlir/python/mlir/dialects/pdl.py | 6 --
mlir/python/mlir/dialects/scf.py | 43 --------------
mlir/python/mlir/dialects/tensor.py | 5 --
.../mlir/dialects/transform/__init__.py | 7 ---
.../mlir/dialects/transform/bufferization.py | 5 --
mlir/python/mlir/dialects/transform/gpu.py | 5 --
mlir/python/mlir/dialects/transform/loop.py | 5 --
mlir/python/mlir/dialects/transform/memref.py | 5 --
mlir/python/mlir/dialects/transform/pdl.py | 5 --
.../mlir/dialects/transform/structured.py | 6 --
mlir/python/mlir/dialects/transform/tensor.py | 5 --
19 files changed, 235 deletions(-)
delete mode 100644 mlir/python/mlir/dialects/_linalg_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/arith.py
delete mode 100644 mlir/python/mlir/dialects/bufferization.py
delete mode 100644 mlir/python/mlir/dialects/builtin.py
delete mode 100644 mlir/python/mlir/dialects/func.py
delete mode 100644 mlir/python/mlir/dialects/memref.py
delete mode 100644 mlir/python/mlir/dialects/ml_program.py
delete mode 100644 mlir/python/mlir/dialects/pdl.py
delete mode 100644 mlir/python/mlir/dialects/scf.py
delete mode 100644 mlir/python/mlir/dialects/tensor.py
delete mode 100644 mlir/python/mlir/dialects/transform/__init__.py
delete mode 100644 mlir/python/mlir/dialects/transform/bufferization.py
delete mode 100644 mlir/python/mlir/dialects/transform/gpu.py
delete mode 100644 mlir/python/mlir/dialects/transform/loop.py
delete mode 100644 mlir/python/mlir/dialects/transform/memref.py
delete mode 100644 mlir/python/mlir/dialects/transform/pdl.py
delete mode 100644 mlir/python/mlir/dialects/transform/structured.py
delete mode 100644 mlir/python/mlir/dialects/transform/tensor.py
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
deleted file mode 100644
index 3f6d854ca3e2b14..000000000000000
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from typing import Optional, Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context
- from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-
-
-def isa(cls: Type, ty: Type):
- try:
- cls(ty)
- return True
- except ValueError:
- return False
-
-
-class StructuredOpMixin:
- """All structured ops use the same mixin class."""
-
- def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
- super().__init__(
- self.build_generic(
- results=list(results),
- operands=[list(inputs), list(outputs)],
- loc=loc,
- ip=ip,
- )
- )
-
-
-def select_opview_mixin(parent_opview_cls):
- # TODO: This shouldn't be a heuristic: we should have a way to annotate
- # the OpView to note that it is a structured op.
- if (
- "__init__" not in parent_opview_cls.__dict__
- and hasattr(parent_opview_cls, "inputs")
- and hasattr(parent_opview_cls, "outputs")
- and hasattr(parent_opview_cls, "result_tensors")
- ):
- return StructuredOpMixin
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 895c3228139b392..9cca7d659ec8cb3 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -9,7 +9,6 @@
__all__ = [
"equally_sized_accessor",
- "extend_opview_class",
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
@@ -18,64 +17,6 @@
]
-def extend_opview_class(ext_module):
- """Decorator to extend an OpView class from an extension module.
-
- Extension modules can expose various entry-points:
- Stand-alone class with the same name as a parent OpView class (i.e.
- "ReturnOp"). A name-based match is attempted first before falling back
- to a below mechanism.
-
- def select_opview_mixin(parent_opview_cls):
- If defined, allows an appropriate mixin class to be selected dynamically
- based on the parent OpView class. Should return NotImplemented if a
- decision is not made.
-
- Args:
- ext_module: A module from which to locate extensions. Can be None if not
- available.
-
- Returns:
- A decorator that takes an OpView subclass and further extends it as
- needed.
- """
-
- def class_decorator(parent_opview_cls: type):
- if ext_module is None:
- return parent_opview_cls
- mixin_cls = NotImplemented
- # First try to resolve by name.
- try:
- mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
- except AttributeError:
- # Fall back to a select_opview_mixin hook.
- try:
- select_mixin = getattr(ext_module, "select_opview_mixin")
- except AttributeError:
- pass
- else:
- mixin_cls = select_mixin(parent_opview_cls)
-
- if mixin_cls is NotImplemented or mixin_cls is None:
- return parent_opview_cls
-
- # Have a mixin_cls. Create an appropriate subclass.
- try:
-
- class LocalOpView(mixin_cls, parent_opview_cls):
- pass
-
- except TypeError as e:
- raise TypeError(
- f"Could not mixin {mixin_cls} into {parent_opview_cls}"
- ) from e
- LocalOpView.__name__ = parent_opview_cls.__name__
- LocalOpView.__qualname__ = parent_opview_cls.__qualname__
- return LocalOpView
-
- return class_decorator
-
-
def segmented_accessor(elements, raw_segments, idx):
"""
Returns a slice of elements corresponding to the idx-th segment.
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
deleted file mode 100644
index fb13beb63ca66c3..000000000000000
--- a/mlir/python/mlir/dialects/arith.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._arith_ops_gen import *
-from ._arith_enum_gen import *
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
deleted file mode 100644
index 759b6aa24a9ff73..000000000000000
--- a/mlir/python/mlir/dialects/bufferization.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._bufferization_ops_gen import *
-from ._bufferization_enum_gen import *
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
deleted file mode 100644
index 30279e1611f99aa..000000000000000
--- a/mlir/python/mlir/dialects/builtin.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._builtin_ops_gen import *
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
deleted file mode 100644
index dc554c22173bc60..000000000000000
--- a/mlir/python/mlir/dialects/func.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._func_ops_gen import *
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
deleted file mode 100644
index 3afb6a70cb9e0db..000000000000000
--- a/mlir/python/mlir/dialects/memref.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._memref_ops_gen import *
diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py
deleted file mode 100644
index a654529b4bb8843..000000000000000
--- a/mlir/python/mlir/dialects/ml_program.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._ml_program_ops_gen import *
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
deleted file mode 100644
index dda2b7d6521965f..000000000000000
--- a/mlir/python/mlir/dialects/pdl.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._pdl_ops_gen import *
-from .._mlir_libs._mlirDialectsPDL import *
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
deleted file mode 100644
index 49685ca2271fc61..000000000000000
--- a/mlir/python/mlir/dialects/scf.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from typing import Optional, Sequence
-
-from ._scf_ops_gen import *
-from .arith import constant
-from ..ir import *
-
-
-def for_(
- start,
- stop=None,
- step=None,
- iter_args: Optional[Sequence[Value]] = None,
- *,
- loc=None,
- ip=None,
-):
- if step is None:
- step = 1
- if stop is None:
- stop = start
- start = 0
- params = [start, stop, step]
- for i, p in enumerate(params):
- if isinstance(p, int):
- p = constant(p)
- elif isinstance(p, float):
- raise ValueError(f"{p=} must be int.")
- params[i] = p
-
- for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
- iv = for_op.induction_variable
- iter_args = tuple(for_op.inner_iter_args)
- with InsertionPoint(for_op.body):
- if len(iter_args) > 1:
- yield iv, iter_args
- elif len(iter_args) == 1:
- yield iv, iter_args[0]
- else:
- yield iv
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
deleted file mode 100644
index 26edf6b6436dad5..000000000000000
--- a/mlir/python/mlir/dialects/tensor.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._tensor_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
deleted file mode 100644
index b020ad35fcf062f..000000000000000
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._transform_enum_gen import *
-from .._transform_ops_gen import *
-from ..._mlir_libs._mlirDialectsTransform import *
diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py
deleted file mode 100644
index eb77b746cf864fa..000000000000000
--- a/mlir/python/mlir/dialects/transform/bufferization.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._bufferization_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py
deleted file mode 100644
index 8c3de0de7ea3f19..000000000000000
--- a/mlir/python/mlir/dialects/transform/gpu.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._gpu_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
deleted file mode 100644
index 86f72788d86c369..000000000000000
--- a/mlir/python/mlir/dialects/transform/loop.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._loop_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py
deleted file mode 100644
index 1ff04ef6a60a180..000000000000000
--- a/mlir/python/mlir/dialects/transform/memref.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._memref_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py
deleted file mode 100644
index b1515287a3f1ff0..000000000000000
--- a/mlir/python/mlir/dialects/transform/pdl.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._transform_pdl_extension_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
deleted file mode 100644
index cb3812301dbd4b5..000000000000000
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._structured_transform_ops_gen import *
-from .._structured_transform_enum_gen import *
diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py
deleted file mode 100644
index bf52255b3df7145..000000000000000
--- a/mlir/python/mlir/dialects/transform/tensor.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._tensor_transform_ops_gen import *
>From 66fdd1895808bda25f6144d6d62b0e7590f0cbe9 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 01:21:39 -0500
Subject: [PATCH 2/5] rename
---
mlir/python/mlir/dialects/{_arith_ops_ext.py => arith.py} | 0
.../mlir/dialects/{_bufferization_ops_ext.py => bufferization.py} | 0
mlir/python/mlir/dialects/{_builtin_ops_ext.py => builtin.py} | 0
mlir/python/mlir/dialects/{_func_ops_ext.py => func.py} | 0
mlir/python/mlir/dialects/{_memref_ops_ext.py => memref.py} | 0
.../mlir/dialects/{_ml_program_ops_ext.py => ml_program.py} | 0
mlir/python/mlir/dialects/{_pdl_ops_ext.py => pdl.py} | 0
mlir/python/mlir/dialects/{_scf_ops_ext.py => scf.py} | 0
mlir/python/mlir/dialects/{_tensor_ops_ext.py => tensor.py} | 0
.../dialects/{_transform_ops_ext.py => transform/__init__.py} | 0
.../bufferization.py} | 0
.../mlir/dialects/{_gpu_transform_ops_ext.py => transform/gpu.py} | 0
.../dialects/{_loop_transform_ops_ext.py => transform/loop.py} | 0
.../{_memref_transform_ops_ext.py => transform/memref.py} | 0
.../{_transform_pdl_extension_ops_ext.py => transform/pdl.py} | 0
.../{_structured_transform_ops_ext.py => transform/structured.py} | 0
.../{_tensor_transform_ops_ext.py => transform/tensor.py} | 0
17 files changed, 0 insertions(+), 0 deletions(-)
rename mlir/python/mlir/dialects/{_arith_ops_ext.py => arith.py} (100%)
rename mlir/python/mlir/dialects/{_bufferization_ops_ext.py => bufferization.py} (100%)
rename mlir/python/mlir/dialects/{_builtin_ops_ext.py => builtin.py} (100%)
rename mlir/python/mlir/dialects/{_func_ops_ext.py => func.py} (100%)
rename mlir/python/mlir/dialects/{_memref_ops_ext.py => memref.py} (100%)
rename mlir/python/mlir/dialects/{_ml_program_ops_ext.py => ml_program.py} (100%)
rename mlir/python/mlir/dialects/{_pdl_ops_ext.py => pdl.py} (100%)
rename mlir/python/mlir/dialects/{_scf_ops_ext.py => scf.py} (100%)
rename mlir/python/mlir/dialects/{_tensor_ops_ext.py => tensor.py} (100%)
rename mlir/python/mlir/dialects/{_transform_ops_ext.py => transform/__init__.py} (100%)
rename mlir/python/mlir/dialects/{_bufferization_transform_ops_ext.py => transform/bufferization.py} (100%)
rename mlir/python/mlir/dialects/{_gpu_transform_ops_ext.py => transform/gpu.py} (100%)
rename mlir/python/mlir/dialects/{_loop_transform_ops_ext.py => transform/loop.py} (100%)
rename mlir/python/mlir/dialects/{_memref_transform_ops_ext.py => transform/memref.py} (100%)
rename mlir/python/mlir/dialects/{_transform_pdl_extension_ops_ext.py => transform/pdl.py} (100%)
rename mlir/python/mlir/dialects/{_structured_transform_ops_ext.py => transform/structured.py} (100%)
rename mlir/python/mlir/dialects/{_tensor_transform_ops_ext.py => transform/tensor.py} (100%)
diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/arith.py
similarity index 100%
rename from mlir/python/mlir/dialects/_arith_ops_ext.py
rename to mlir/python/mlir/dialects/arith.py
diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/bufferization.py
similarity index 100%
rename from mlir/python/mlir/dialects/_bufferization_ops_ext.py
rename to mlir/python/mlir/dialects/bufferization.py
diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/builtin.py
similarity index 100%
rename from mlir/python/mlir/dialects/_builtin_ops_ext.py
rename to mlir/python/mlir/dialects/builtin.py
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/func.py
similarity index 100%
rename from mlir/python/mlir/dialects/_func_ops_ext.py
rename to mlir/python/mlir/dialects/func.py
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/memref.py
similarity index 100%
rename from mlir/python/mlir/dialects/_memref_ops_ext.py
rename to mlir/python/mlir/dialects/memref.py
diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/ml_program.py
similarity index 100%
rename from mlir/python/mlir/dialects/_ml_program_ops_ext.py
rename to mlir/python/mlir/dialects/ml_program.py
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/pdl.py
similarity index 100%
rename from mlir/python/mlir/dialects/_pdl_ops_ext.py
rename to mlir/python/mlir/dialects/pdl.py
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/scf.py
similarity index 100%
rename from mlir/python/mlir/dialects/_scf_ops_ext.py
rename to mlir/python/mlir/dialects/scf.py
diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/tensor.py
similarity index 100%
rename from mlir/python/mlir/dialects/_tensor_ops_ext.py
rename to mlir/python/mlir/dialects/tensor.py
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/__init__.py
similarity index 100%
rename from mlir/python/mlir/dialects/_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/__init__.py
diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/bufferization.py
similarity index 100%
rename from mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/bufferization.py
diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/gpu.py
similarity index 100%
rename from mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/gpu.py
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/loop.py
similarity index 100%
rename from mlir/python/mlir/dialects/_loop_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/loop.py
diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/memref.py
similarity index 100%
rename from mlir/python/mlir/dialects/_memref_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/memref.py
diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/transform/pdl.py
similarity index 100%
rename from mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
rename to mlir/python/mlir/dialects/transform/pdl.py
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/structured.py
similarity index 100%
rename from mlir/python/mlir/dialects/_structured_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/structured.py
diff --git a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/tensor.py
similarity index 100%
rename from mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/tensor.py
>From b3850f7b5ce192aafd3b82b669467f91ebd9605d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 01:23:57 -0500
Subject: [PATCH 3/5] fix
---
mlir/lib/Bindings/Python/Globals.h | 2 +-
mlir/lib/Bindings/Python/IRModule.cpp | 4 +-
mlir/lib/Bindings/Python/MainModule.cpp | 11 +--
mlir/python/CMakeLists.txt | 18 ----
mlir/python/mlir/dialects/arith.py | 10 ++-
mlir/python/mlir/dialects/bufferization.py | 25 +++---
mlir/python/mlir/dialects/builtin.py | 9 +-
mlir/python/mlir/dialects/func.py | 17 +++-
.../dialects/linalg/opdsl/lang/emitter.py | 2 +-
mlir/python/mlir/dialects/memref.py | 15 +++-
mlir/python/mlir/dialects/ml_program.py | 16 ++--
mlir/python/mlir/dialects/pdl.py | 48 +++++++---
mlir/python/mlir/dialects/scf.py | 63 +++++++++++--
mlir/python/mlir/dialects/tensor.py | 20 ++---
.../mlir/dialects/transform/__init__.py | 45 +++++-----
.../mlir/dialects/transform/bufferization.py | 24 +++--
mlir/python/mlir/dialects/transform/gpu.py | 26 +++---
mlir/python/mlir/dialects/transform/loop.py | 25 ++++--
mlir/python/mlir/dialects/transform/memref.py | 24 +++--
mlir/python/mlir/dialects/transform/pdl.py | 90 +++++++++----------
.../mlir/dialects/transform/structured.py | 54 +++++++----
mlir/python/mlir/dialects/transform/tensor.py | 17 ++--
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 49 ++++------
23 files changed, 366 insertions(+), 248 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 97cd70089a2e965..dea44bbd469dd3d 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -80,7 +80,7 @@ class PyGlobals {
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
- pybind11::object pyClass);
+ pybind11::object pyClass, bool replace = false);
/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 2cc66277abee0f0..a1c8ab7a09ce155 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
- py::object pyClass) {
+ py::object pyClass, bool replace) {
py::object &found = operationClassMap[operationName];
- if (found) {
+ if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
"' is already registered.")
.str());
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index cdddfbe50606d05..a936becf67bea75 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- "operation_name"_a, "operation_class"_a,
+ "operation_name"_a, "operation_class"_a, "replace"_a = false,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
@@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) {
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
- [](const py::object &dialectClass) -> py::cpp_function {
+ [](const py::object &dialectClass, bool replace) -> py::cpp_function {
return py::cpp_function(
- [dialectClass](py::object opClass) -> py::object {
+ [dialectClass, replace](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
- PyGlobals::get().registerOperationImpl(operationName, opClass);
+ PyGlobals::get().registerOperationImpl(operationName, opClass,
+ replace);
// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
@@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
- "dialect_class"_a,
+ "dialect_class"_a, "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c7b3c283a6b6dc1..586d39fc332fac5 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -78,7 +78,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BufferizationOps.td
SOURCES
dialects/bufferization.py
- dialects/_bufferization_ops_ext.py
DIALECT_NAME bufferization
GEN_ENUM_BINDINGS_TD_FILE
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
@@ -90,7 +89,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BuiltinOps.td
SOURCES
dialects/builtin.py
- dialects/_builtin_ops_ext.py
DIALECT_NAME builtin)
declare_mlir_dialect_python_bindings(
@@ -115,7 +113,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/FuncOps.td
SOURCES
dialects/func.py
- dialects/_func_ops_ext.py
DIALECT_NAME func)
declare_mlir_dialect_python_bindings(
@@ -131,7 +128,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgOps.td
SOURCES
- dialects/_linalg_ops_ext.py
SOURCES_GLOB
dialects/linalg/*.py
DIALECT_NAME linalg
@@ -152,7 +148,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformPDLExtensionOps.td
SOURCES
- dialects/_transform_pdl_extension_ops_ext.py
dialects/transform/pdl.py
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
@@ -162,7 +157,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformOps.td
SOURCES
- dialects/_transform_ops_ext.py
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform
@@ -175,7 +169,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/BufferizationTransformOps.td
SOURCES
- dialects/_bufferization_transform_ops_ext.py
dialects/transform/bufferization.py
DIALECT_NAME transform
EXTENSION_NAME bufferization_transform)
@@ -185,7 +178,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/GPUTransformOps.td
SOURCES
- dialects/_gpu_transform_ops_ext.py
dialects/transform/gpu.py
DIALECT_NAME transform
EXTENSION_NAME gpu_transform)
@@ -195,7 +187,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/SCFLoopTransformOps.td
SOURCES
- dialects/_loop_transform_ops_ext.py
dialects/transform/loop.py
DIALECT_NAME transform
EXTENSION_NAME loop_transform)
@@ -205,7 +196,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/MemRefTransformOps.td
SOURCES
- dialects/_memref_transform_ops_ext.py
dialects/transform/memref.py
DIALECT_NAME transform
EXTENSION_NAME memref_transform)
@@ -224,7 +214,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
- dialects/_structured_transform_ops_ext.py
dialects/transform/structured.py
DIALECT_NAME transform
EXTENSION_NAME structured_transform
@@ -246,7 +235,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TensorTransformOps.td
SOURCES
- dialects/_tensor_transform_ops_ext.py
dialects/transform/tensor.py
DIALECT_NAME transform
EXTENSION_NAME tensor_transform)
@@ -276,7 +264,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/ArithOps.td
SOURCES
dialects/arith.py
- dialects/_arith_ops_ext.py
DIALECT_NAME arith
GEN_ENUM_BINDINGS)
@@ -286,7 +273,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MemRefOps.td
SOURCES
dialects/memref.py
- dialects/_memref_ops_ext.py
DIALECT_NAME memref)
declare_mlir_dialect_python_bindings(
@@ -295,7 +281,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MLProgramOps.td
SOURCES
dialects/ml_program.py
- dialects/_ml_program_ops_ext.py
DIALECT_NAME ml_program)
declare_mlir_dialect_python_bindings(
@@ -339,7 +324,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/PDLOps.td
SOURCES
dialects/pdl.py
- dialects/_pdl_ops_ext.py
_mlir_libs/_mlir/dialects/pdl.pyi
DIALECT_NAME pdl)
@@ -357,7 +341,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/SCFOps.td
SOURCES
dialects/scf.py
- dialects/_scf_ops_ext.py
DIALECT_NAME scf)
declare_mlir_dialect_python_bindings(
@@ -383,7 +366,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/TensorOps.td
SOURCES
dialects/tensor.py
- dialects/_tensor_ops_ext.py
DIALECT_NAME tensor)
declare_mlir_dialect_python_bindings(
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index df38f871710fe8f..e3b6a428c879de5 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -2,9 +2,13 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ._arith_ops_gen import *
+from ._arith_ops_gen import _Dialect
+from ._arith_enum_gen import *
+
try:
from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
+ from ._ods_common import get_default_loc_context as _get_default_loc_context, _cext as _ods_cext
from typing import Any, List, Union
except ImportError as e:
@@ -30,8 +34,8 @@ def _is_integer_like_type(type: Type):
def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-class ConstantOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
def __init__(
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 1066cb4c775cab9..0ce5448ace4b14c 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -2,17 +2,22 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ._bufferization_ops_gen import *
+from ._bufferization_ops_gen import _Dialect
+from ._bufferization_enum_gen import *
+
try:
from typing import Sequence, Union
from ..ir import *
- from ._ods_common import get_default_loc_context
+ from ._ods_common import get_default_loc_context, _cext as _ods_cext
from typing import Any, List, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-class AllocTensorOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class AllocTensorOp(AllocTensorOp):
"""Extends the bufferization.alloc_tensor op."""
def __init__(
@@ -24,18 +29,14 @@ def __init__(
escape: BoolAttr,
*,
loc=None,
- ip=None
+ ip=None,
):
"""Constructs an `alloc_tensor` with static and/or dynamic sizes."""
- context = get_default_loc_context(loc)
- attributes = {}
- if escape:
- attributes["escape"] = escape
- op = self.build_generic(
- results=[tensor_type],
- operands=[dynamic_sizes, copy, size_hint],
- attributes=attributes,
+ super().__init__(
+ tensor_type,
+ dynamic_sizes,
+ copy=copy,
+ size_hint=size_hint,
loc=loc,
ip=ip,
)
- OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
index 27a60123050acb4..b71cc2466d464b3 100644
--- a/mlir/python/mlir/dialects/builtin.py
+++ b/mlir/python/mlir/dialects/builtin.py
@@ -2,17 +2,22 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ._builtin_ops_gen import *
+from ._builtin_ops_gen import _Dialect
+
try:
from ..ir import *
+ from ._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-class ModuleOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ModuleOp(ModuleOp):
"""Specialization for the module op class."""
def __init__(self, *, loc=None, ip=None):
- super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
+ super().__init__(loc=loc, ip=ip)
body = self.regions[0].blocks.append()
@property
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 6d264c33f1f9dae..9c6c4c9092c7a88 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -2,9 +2,15 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ._func_ops_gen import *
+from ._func_ops_gen import _Dialect
+
try:
from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
+ from ._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
import inspect
@@ -16,7 +22,8 @@
RESULT_ATTRIBUTE_NAME = "res_attrs"
-class ConstantOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
@@ -27,7 +34,8 @@ def type(self):
return self.results[0].type
-class FuncOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class FuncOp(FuncOp):
"""Specialization for the func op class."""
def __init__(
@@ -238,7 +246,8 @@ def emit_call_op(*call_args):
return decorator
-class CallOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class CallOp(CallOp):
"""Specialization for the call op class."""
def __init__(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 6f9d72164429eea..f91fc8b7160089b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -310,7 +310,7 @@ def emit_named_structured_op(
)
# Set the index attributes used to compute the indexing maps.
- named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
+ named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
for name, value in index_attrs.items():
named_op.operation.attributes[name] = value
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 825f1a0a7a6faf4..111ad2178703d28 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -2,17 +2,24 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ._memref_ops_gen import *
+from ._memref_ops_gen import _Dialect
+
try:
from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ._ods_common import get_op_results_or_values as _get_op_results_or_values
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
-class LoadOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoadOp(LoadOp):
"""Specialization for the MemRef load operation."""
def __init__(
@@ -21,7 +28,7 @@ def __init__(
indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
*,
loc=None,
- ip=None
+ ip=None,
):
"""Creates a memref load operation.
diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py
index c84d23c16ef93ab..dfb6d7f2c03b1cf 100644
--- a/mlir/python/mlir/dialects/ml_program.py
+++ b/mlir/python/mlir/dialects/ml_program.py
@@ -2,21 +2,27 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Union
+
+from ._ml_program_ops_gen import *
+from ._ml_program_ops_gen import _Dialect
+
try:
- from typing import Union
from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
+ from ._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from ._ml_program_ops_gen import *
-
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
-class FuncOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class FuncOp(FuncOp):
"""Specialization for the func op class."""
def __init__(
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index fc9de0b7f7db69c..a8d9c56f4233d9e 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -2,6 +2,11 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ._pdl_ops_gen import *
+from ._pdl_ops_gen import _Dialect
+from .._mlir_libs._mlirDialectsPDL import *
+
+
try:
from ..ir import *
from ..dialects import pdl
@@ -12,10 +17,12 @@
from ._ods_common import (
get_op_result_or_value as _get_value,
get_op_results_or_values as _get_values,
+ _cext as _ods_cext,
)
-class ApplyNativeConstraintOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
"""Specialization for PDL apply native constraint op class."""
def __init__(
@@ -32,7 +39,8 @@ def __init__(
super().__init__(name, args, loc=loc, ip=ip)
-class ApplyNativeRewriteOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
"""Specialization for PDL apply native rewrite op class."""
def __init__(
@@ -50,7 +58,8 @@ def __init__(
super().__init__(results, name, args, loc=loc, ip=ip)
-class AttributeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class AttributeOp(AttributeOp):
"""Specialization for PDL attribute op class."""
def __init__(
@@ -66,7 +75,8 @@ def __init__(
super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
-class EraseOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class EraseOp(EraseOp):
"""Specialization for PDL erase op class."""
def __init__(
@@ -80,7 +90,8 @@ def __init__(
super().__init__(operation, loc=loc, ip=ip)
-class OperandOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OperandOp(OperandOp):
"""Specialization for PDL operand op class."""
def __init__(
@@ -95,7 +106,8 @@ def __init__(
super().__init__(result, valueType=type, loc=loc, ip=ip)
-class OperandsOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OperandsOp(OperandsOp):
"""Specialization for PDL operands op class."""
def __init__(
@@ -110,7 +122,8 @@ def __init__(
super().__init__(result, valueType=types, loc=loc, ip=ip)
-class OperationOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
"""Specialization for PDL operand op class."""
def __init__(
@@ -143,7 +156,8 @@ def __init__(
)
-class PatternOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class PatternOp(PatternOp):
"""Specialization for PDL pattern op class."""
def __init__(
@@ -164,7 +178,8 @@ def body(self):
return self.regions[0].blocks[0]
-class ReplaceOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ReplaceOp(ReplaceOp):
"""Specialization for PDL replace op class."""
def __init__(
@@ -184,7 +199,8 @@ def __init__(
super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
-class ResultOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ResultOp(ResultOp):
"""Specialization for PDL result op class."""
def __init__(
@@ -200,7 +216,8 @@ def __init__(
super().__init__(result, parent, index, loc=loc, ip=ip)
-class ResultsOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ResultsOp(ResultsOp):
"""Specialization for PDL results op class."""
def __init__(
@@ -216,7 +233,8 @@ def __init__(
super().__init__(result, parent, index=index, loc=loc, ip=ip)
-class RewriteOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class RewriteOp(RewriteOp):
"""Specialization for PDL rewrite op class."""
def __init__(
@@ -245,7 +263,8 @@ def body(self):
return self.regions[0].blocks[0]
-class TypeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class TypeOp(TypeOp):
"""Specialization for PDL type op class."""
def __init__(
@@ -255,7 +274,8 @@ def __init__(
super().__init__(result, constantType=constantType, loc=loc, ip=ip)
-class TypesOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class TypesOp(TypesOp):
"""Specialization for PDL types op class."""
def __init__(
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 89cc8a19895c7b4..43ad9f4e2d65f51 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -2,20 +2,29 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._scf_ops_gen import *
+from ._scf_ops_gen import _Dialect
+from .arith import constant
+
try:
from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union
-from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
-)
+_ForOp = ForOp
-class ForOp:
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ForOp(_ForOp):
"""Specialization for the SCF for op class."""
def __init__(
@@ -41,7 +50,7 @@ def __init__(
iter_args = _get_op_results_or_values(iter_args)
results = [arg.type for arg in iter_args]
- super().__init__(
+ super(_ForOp, self).__init__(
self.build_generic(
regions=1,
results=results,
@@ -74,7 +83,11 @@ def inner_iter_args(self):
return self.body.arguments[1:]
-class IfOp:
+_IfOp = IfOp
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class IfOp(_IfOp):
"""Specialization for the SCF if op class."""
def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
@@ -87,7 +100,7 @@ def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
operands.append(cond)
results = []
results.extend(results_)
- super().__init__(
+ super(_IfOp, self).__init__(
self.build_generic(
regions=2, results=results, operands=operands, loc=loc, ip=ip
)
@@ -105,3 +118,37 @@ def then_block(self):
def else_block(self):
"""Returns the else block of the if operation."""
return self.regions[1].blocks[0]
+
+
+def for_(
+    start,
+    stop=None,
+    step=None,
+    iter_args: Optional[Sequence[Value]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    """Builds an `scf.for` op and yields its induction variable.
+
+    Meant to be consumed by a Python `for` statement; the loop body is then
+    emitted under `InsertionPoint(for_op.body)`.
+
+    Args:
+      start: Lower bound, or the upper bound when `stop` is omitted (the lower
+        bound then becomes 0).
+      stop: Upper bound.
+      step: Step; defaults to 1.
+      iter_args: Optional loop-carried values forwarded to `ForOp`.
+      loc, ip: Forwarded to `ForOp`.
+
+    Python ints among start/stop/step are passed through `arith.constant`
+    before the op is built.
+
+    Yields:
+      `iv` when there are no inner iter args, `(iv, arg)` when there is exactly
+      one, and `(iv, args_tuple)` when there are several.
+
+    Raises:
+      ValueError: if start, stop or step is a Python float.
+
+    NOTE(review): this helper does not itself create an `scf.yield`;
+    presumably the caller terminates the body — confirm against ForOp.
+    """
+    if step is None:
+        step = 1
+    if stop is None:
+        # Single-bound form: for_(n) means start=0, stop=n.
+        stop = start
+        start = 0
+    params = [start, stop, step]
+    for i, p in enumerate(params):
+        if isinstance(p, int):
+            # NOTE(review): confirm `constant` accepts a bare Python int here;
+            # it may need an explicit (index) result type.
+            p = constant(p)
+        elif isinstance(p, float):
+            raise ValueError(f"{p=} must be int.")
+        params[i] = p
+    # Rebind to the converted values so ForOp receives MLIR Values rather
+    # than the raw Python ints the caller passed in.
+    start, stop, step = params
+
+    for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
+    iv = for_op.induction_variable
+    iter_args = tuple(for_op.inner_iter_args)
+    with InsertionPoint(for_op.body):
+        if len(iter_args) > 1:
+            yield iv, iter_args
+        elif len(iter_args) == 1:
+            yield iv, iter_args[0]
+        else:
+            yield iv
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 09b9ec68db7d9c7..67248748eaf3ada 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -2,19 +2,20 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from ._tensor_ops_gen import *
+from ._tensor_ops_gen import _Dialect
+
try:
from ..ir import *
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Any, Optional, Sequence, Union
-from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
-)
+from typing import Sequence, Union
+from ._ods_common import _cext as _ods_cext
-class EmptyOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class EmptyOp(EmptyOp):
"""Extends the tensor.empty op."""
def __init__(
@@ -23,7 +24,7 @@ def __init__(
element_type: Type,
*,
loc=None,
- ip=None
+ ip=None,
):
"""Constructs an `empty` with mixed static/dynamic sizes."""
# TODO: Refactor the EmptyOp to take an element type attribute and
@@ -38,7 +39,4 @@ def __init__(
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(s)
result_type = RankedTensorType.get(static_sizes, element_type)
- op = self.build_generic(
- results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
- )
- OpView.__init__(self, op)
+ super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b1e7b892536f4a1..f7a2026e800aeb0 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -2,11 +2,17 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from .._transform_enum_gen import *
+from .._transform_ops_gen import *
+from .._transform_ops_gen import _Dialect
+from ..._mlir_libs._mlirDialectsTransform import *
+
try:
- from ..ir import *
- from ._ods_common import (
+ from ...ir import *
+ from .._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -14,7 +20,8 @@
from typing import Optional, Sequence, Union
-class CastOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class CastOp(CastOp):
def __init__(
self,
result_type: Type,
@@ -26,7 +33,8 @@ def __init__(
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
-class ApplyPatternsOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ApplyPatternsOp(ApplyPatternsOp):
def __init__(
self,
target: Union[Operation, Value, OpView],
@@ -34,19 +42,7 @@ def __init__(
loc=None,
ip=None,
):
- operands = []
- operands.append(_get_op_result_or_value(target))
- super().__init__(
- self.build_generic(
- attributes={},
- results=[],
- operands=operands,
- successors=None,
- regions=None,
- loc=loc,
- ip=ip,
- )
- )
+ super().__init__(target, loc=loc, ip=ip)
self.regions[0].blocks.append()
@property
@@ -54,7 +50,8 @@ def patterns(self) -> Block:
return self.regions[0].blocks[0]
-class testGetParentOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class GetParentOp(GetParentOp):
def __init__(
self,
result_type: Type,
@@ -77,7 +74,8 @@ def __init__(
)
-class MergeHandlesOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MergeHandlesOp(MergeHandlesOp):
def __init__(
self,
handles: Sequence[Union[Operation, Value]],
@@ -94,7 +92,8 @@ def __init__(
)
-class ReplicateOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ReplicateOp(ReplicateOp):
def __init__(
self,
pattern: Union[Operation, Value],
@@ -112,7 +111,8 @@ def __init__(
)
-class SequenceOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class SequenceOp(SequenceOp):
def __init__(
self,
failure_propagation_mode,
@@ -163,7 +163,8 @@ def bodyExtraArgs(self) -> BlockArgumentList:
return self.body.arguments[1:]
-class YieldOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class YieldOp(YieldOp):
def __init__(
self,
operands: Optional[Union[Operation, Sequence[Value]]] = None,
diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py
index 7e6c1b81cb350b7..485a8a36b6305e9 100644
--- a/mlir/python/mlir/dialects/transform/bufferization.py
+++ b/mlir/python/mlir/dialects/transform/bufferization.py
@@ -2,9 +2,13 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from .._bufferization_transform_ops_gen import *
+from .._bufferization_transform_ops_gen import _Dialect
+
try:
- from ..ir import *
- from ..dialects import transform
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -12,7 +16,8 @@
from typing import Optional, overload, Union
-class EmptyTensorToAllocTensorOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp):
"""Specialization for EmptyTensorToAllocTensorOp class."""
@overload
@@ -22,7 +27,7 @@ def __init__(
target: Union[Operation, OpView, Value],
*,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -36,7 +41,7 @@ def __init__(
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
- ip=None
+ ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
@@ -53,7 +58,8 @@ def __init__(
)
-class OneShotBufferizeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OneShotBufferizeOp(OneShotBufferizeOp):
"""Specialization for OneShotBufferizeOp class."""
@overload
@@ -70,7 +76,7 @@ def __init__(
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -87,7 +93,7 @@ def __init__(
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -104,7 +110,7 @@ def __init__(
print_conflicts: Optional[bool] = None,
test_analysis_only: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py
index ba72bac3a15264d..00cf0840eeae9e1 100644
--- a/mlir/python/mlir/dialects/transform/gpu.py
+++ b/mlir/python/mlir/dialects/transform/gpu.py
@@ -2,16 +2,21 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from .._gpu_transform_ops_gen import *
+from .._gpu_transform_ops_gen import _Dialect
+
try:
- from ..ir import *
- from ..dialects import transform
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union, overload
-class MapForallToBlocks:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MapForallToBlocks(MapForallToBlocks):
"""Specialization for MapForallToBlocks class."""
@overload
@@ -23,7 +28,7 @@ def __init__(
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -35,7 +40,7 @@ def __init__(
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -47,7 +52,7 @@ def __init__(
grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
loc=None,
- ip=None
+ ip=None,
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
@@ -66,7 +71,8 @@ def __init__(
)
-class MapNestedForallToThreads:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MapNestedForallToThreads(MapNestedForallToThreads):
"""Specialization for MapNestedForallToThreads class."""
@overload
@@ -79,7 +85,7 @@ def __init__(
warp_size: Optional[Sequence[int]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -92,7 +98,7 @@ def __init__(
warp_size: Optional[Sequence[int]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -105,7 +111,7 @@ def __init__(
warp_size: Optional[Union[Sequence[int], Attribute]] = None,
sync_after_distribute: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
if isinstance(result_type_or_target, Type):
result_type = result_type_or_target
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
index 1cdb2b9e77b5afe..6c89025f413839e 100644
--- a/mlir/python/mlir/dialects/transform/loop.py
+++ b/mlir/python/mlir/dialects/transform/loop.py
@@ -2,16 +2,23 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from .._loop_transform_ops_gen import *
+from .._loop_transform_ops_gen import _Dialect
+
try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+ from ...ir import *
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ _cext as _ods_cext,
+ )
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Union
-class GetParentForOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class GetParentForOp(GetParentForOp):
"""Extension for GetParentForOp."""
def __init__(
@@ -34,7 +41,8 @@ def __init__(
)
-class LoopOutlineOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoopOutlineOp(LoopOutlineOp):
"""Extension for LoopOutlineOp."""
def __init__(
@@ -61,7 +69,8 @@ def __init__(
)
-class LoopPeelOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoopPeelOp(LoopPeelOp):
"""Extension for LoopPeelOp."""
def __init__(
@@ -88,7 +97,8 @@ def __init__(
)
-class LoopPipelineOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoopPipelineOp(LoopPipelineOp):
"""Extension for LoopPipelineOp."""
def __init__(
@@ -115,7 +125,8 @@ def __init__(
)
-class LoopUnrollOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoopUnrollOp(LoopUnrollOp):
"""Extension for LoopUnrollOp."""
def __init__(
diff --git a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py
index 1cc00bdcbf381c9..56ea61eb817f89c 100644
--- a/mlir/python/mlir/dialects/transform/memref.py
+++ b/mlir/python/mlir/dialects/transform/memref.py
@@ -2,16 +2,21 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from .._memref_transform_ops_gen import *
+from .._memref_transform_ops_gen import _Dialect
+
try:
- from ..ir import *
- from ..dialects import transform
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, overload, Union
-class MemRefAllocaToGlobalOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
"""Specialization for MemRefAllocaToGlobalOp class."""
@overload
@@ -22,7 +27,7 @@ def __init__(
alloca: Union[Operation, OpView, Value],
*,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -37,7 +42,7 @@ def __init__(
alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
loc=None,
- ip=None
+ ip=None,
):
if isinstance(get_global_type_or_alloca, Type):
get_global_type = get_global_type_or_alloca
@@ -57,7 +62,8 @@ def __init__(
)
-class MemRefMultiBufferOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MemRefMultiBufferOp(MemRefMultiBufferOp):
"""Specialization for MemRefMultiBufferOp class."""
@overload
@@ -69,7 +75,7 @@ def __init__(
*,
skip_analysis: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -81,7 +87,7 @@ def __init__(
*,
skip_analysis: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -93,7 +99,7 @@ def __init__(
*,
skip_analysis: Optional[bool] = None,
loc=None,
- ip=None
+ ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py
index c4e4b4b4254b038..bb5fa7ffd306583 100644
--- a/mlir/python/mlir/dialects/transform/pdl.py
+++ b/mlir/python/mlir/dialects/transform/pdl.py
@@ -2,54 +2,54 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from .._transform_pdl_extension_ops_gen import *
+from .._transform_pdl_extension_ops_gen import _Dialect
+
try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- )
+ from ...ir import *
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
+ raise RuntimeError("Error loading imports from extension module") from e
from typing import Union
-class PDLMatchOp:
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- pattern_name: Union[Attribute, str],
- *,
- loc=None,
- ip=None,
- ):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- pattern_name,
- loc=loc,
- ip=ip,
- )
-
-
-class WithPDLPatternsOp:
-
- def __init__(self,
- target: Union[Operation, Value, Type],
- *,
- loc=None,
- ip=None):
- root = _get_op_result_or_value(target) if not isinstance(target,
- Type) else None
- root_type = target if isinstance(target, Type) else root.type
- super().__init__(root=root, loc=loc, ip=ip)
- self.regions[0].blocks.append(root_type)
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
- @property
- def bodyTarget(self) -> Value:
- return self.body.arguments[0]
+@_ods_cext.register_operation(_Dialect, replace=True)
+class PDLMatchOp(PDLMatchOp):
+    """Specialization for PDLMatchOp class."""
+
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        pattern_name: Union[Attribute, str],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        """Builds the op from a result type, a target and a pattern name.
+
+        `target` is normalized with `_get_op_result_or_value` and forwarded,
+        with `result_type` and `pattern_name`, to the generated builder.
+        """
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            pattern_name,
+            loc=loc,
+            ip=ip,
+        )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class WithPDLPatternsOp(WithPDLPatternsOp):
+    """Specialization for WithPDLPatternsOp class."""
+
+    def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
+        """Builds the op and appends one block to its first region.
+
+        If `target` is a Type, `root=None` is passed to the generated builder
+        and the new block's argument gets that type. Otherwise `target` is
+        normalized to a Value, used as `root`, and the block argument takes
+        the root's type.
+        """
+        root = _get_op_result_or_value(target) if not isinstance(target, Type) else None
+        root_type = target if isinstance(target, Type) else root.type
+        super().__init__(root=root, loc=loc, ip=ip)
+        self.regions[0].blocks.append(root_type)
+
+    @property
+    def body(self) -> Block:
+        """The first block of the op's first region."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTarget(self) -> Value:
+        """The first argument of `body`."""
+        return self.body.arguments[0]
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 3757a3d3b4cce85..284c93823acbd34 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -2,9 +2,14 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from .._structured_transform_ops_gen import *
+from .._structured_transform_ops_gen import _Dialect
+from .._structured_transform_enum_gen import *
+
try:
- from ..ir import *
- from ..dialects import transform
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@@ -163,7 +168,8 @@ def _get_int_array_array_attr(
return ArrayAttr.get(values)
-class BufferizeToAllocationOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class BufferizeToAllocationOp(BufferizeToAllocationOp):
"""Specialization for BufferizeToAllocationOp class."""
def __init__(
@@ -199,7 +205,8 @@ def __init__(
)
-class DecomposeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class DecomposeOp(DecomposeOp):
"""Specialization for DecomposeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
@@ -207,7 +214,8 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(transformed_type, target, loc=loc, ip=ip)
-class FuseIntoContainingOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class FuseIntoContainingOp(FuseIntoContainingOp):
"""Specialization for FuseIntoContainingOp class."""
@overload
@@ -271,7 +279,8 @@ def __init__(
)
-class GeneralizeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class GeneralizeOp(GeneralizeOp):
"""Specialization for GeneralizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
@@ -279,7 +288,8 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(transformed_type, target, loc=loc, ip=ip)
-class InterchangeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class InterchangeOp(InterchangeOp):
"""Specialization for InterchangeOp class."""
def __init__(
@@ -300,7 +310,8 @@ def __init__(
)
-class MapCopyToThreadsOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MapCopyToThreadsOp(MapCopyToThreadsOp):
"""Specialization for MapCopyToThreadsOp class."""
@overload
@@ -360,7 +371,8 @@ def __init__(
)
-class VectorizeOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class VectorizeOp(VectorizeOp):
"""Specialization for VectorizeOp class."""
def __init__(
@@ -405,7 +417,8 @@ def __init__(
)
-class MatchOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MatchOp(MatchOp):
"""Specialization for MatchOp class."""
@overload
@@ -464,7 +477,8 @@ def match_op_names(
)
-class MultiTileSizesOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MultiTileSizesOp(MultiTileSizesOp):
"""Specialization for MultiTileSizesOp class."""
def __init__(
@@ -491,7 +505,8 @@ def __init__(
)
-class PadOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class PadOp(PadOp):
"""Specialization for PadOp class."""
def __init__(
@@ -528,7 +543,8 @@ def __init__(
)
-class ScalarizeOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ScalarizeOp(ScalarizeOp):
"""Specialization for ScalarizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
@@ -536,7 +552,8 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(result_type, target, loc=loc, ip=ip)
-class SplitOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SplitOp(SplitOp):
"""Specialization for SplitOp class."""
def __init__(
@@ -567,7 +584,8 @@ def __init__(
)
-class TileUsingForOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TileUsingForOp(TileUsingForOp):
"""Specialization for TileUsingForOp class."""
@overload
@@ -640,7 +658,8 @@ def __init__(
)
-class TileUsingForallOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TileUsingForallOp(TileUsingForallOp):
"""Specialization for TileUsingForallOp class."""
@overload
@@ -732,7 +751,8 @@ def __init__(
)
-class VectorizeChildrenAndApplyPatternsOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
"""Specialization for VectorizeChildrenAndApplyPatternsOp class."""
def __init__(
diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py
index 996093fbc913e8a..4eb30398f087212 100644
--- a/mlir/python/mlir/dialects/transform/tensor.py
+++ b/mlir/python/mlir/dialects/transform/tensor.py
@@ -2,16 +2,21 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from .._tensor_transform_ops_gen import *
+from .._tensor_transform_ops_gen import _Dialect
+
try:
- from ..ir import *
- from ..dialects import transform
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, overload, Union
-class MakeLoopIndependentOp:
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MakeLoopIndependentOp(MakeLoopIndependentOp):
"""Specialization for MakeLoopIndependentOp class."""
@overload
@@ -22,7 +27,7 @@ def __init__(
num_loops: Union[int, IntegerAttr],
*,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -33,7 +38,7 @@ def __init__(
num_loops: Union[int, IntegerAttr],
*,
loc=None,
- ip=None
+ ip=None,
):
...
@@ -44,7 +49,7 @@ def __init__(
num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
*,
loc=None,
- ip=None
+ ip=None,
):
if isinstance(transformed_type_or_target, Type):
transformed_type = transformed_type_or_target
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 49f3a951426d0ee..d8dcf936e768754 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,14 +30,9 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
_ods_ir = _ods_cext.ir
-try:
- from . import _{0}_ops_ext as _ods_ext_module
-except ImportError:
- _ods_ext_module = None
-
import builtins
from typing import Sequence as _Sequence, Union as _Union
@@ -62,7 +57,6 @@ from ._{0}_ops_gen import _Dialect
/// {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
-@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
OPERATION_NAME = "{1}"
)Py";
@@ -302,12 +296,8 @@ static bool isODSReserved(StringRef str) {
/// modified version.
static std::string sanitizeName(StringRef name) {
std::string processed_str = name.str();
- std::replace_if(
- processed_str.begin(), processed_str.end(),
- [](char c) { return !llvm::isAlnum(c); }, '_');
- if (llvm::isDigit(*processed_str.begin()))
- return "_" + processed_str;
+ std::replace(processed_str.begin(), processed_str.end(), '-', '_');
if (isPythonReserved(processed_str) || isODSReserved(processed_str))
return processed_str + "_";
@@ -854,9 +844,6 @@ populateBuilderRegions(const Operator &op,
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) {
// If we are asked to skip default builders, comply.
- if (op.skipDefaultBuilders())
- return {};
-
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
llvm::SmallVector<std::string> operandArgNames;
@@ -989,9 +976,8 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
static void emitValueBuilder(const Operator &op,
llvm::SmallVector<std::string> functionArgs,
raw_ostream &os) {
- // If we are asked to skip default builders, comply.
- if (op.skipDefaultBuilders())
- return;
+ auto name = sanitizeName(op.getOperationName());
+ iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
// Params with (possibly) default args.
auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -1010,16 +996,16 @@ static void emitValueBuilder(const Operator &op,
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
- std::string name_without_dialect =
- op.getOperationName().substr(op.getOperationName().find('.') + 1);
- os << llvm::formatv(valueBuilderTemplate, sanitizeName(name_without_dialect),
- op.getCppClassName(),
- llvm::join(valueBuilderParams, ", "),
- llvm::join(opBuilderArgs, ", "),
- (op.getNumResults() > 1
- ? "_Sequence[_ods_ir.OpResult]"
- : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
- : "_ods_ir.Operation")));
+ os << llvm::formatv(
+ valueBuilderTemplate,
+ // Drop dialect name and then sanitize again (to catch e.g. func.return).
+ sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
+ op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
+ llvm::join(opBuilderArgs, ", "),
+ (op.getNumResults() > 1
+ ? "_Sequence[_ods_ir.OpResult]"
+ : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+ : "_ods_ir.Operation")));
}
/// Emits bindings for a specific Op to the given output stream.
@@ -1051,11 +1037,8 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided");
- bool isExtension = !clDialectExtensionName.empty();
- os << llvm::formatv(fileHeader, isExtension
- ? clDialectExtensionName.getValue()
- : clDialectName.getValue());
- if (isExtension)
+ os << fileHeader;
+ if (!clDialectExtensionName.empty())
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
else
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
>From f969481e88aefa27f98277a8315ad400aa13745d Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 09:42:48 -0500
Subject: [PATCH 4/5] rebase
---
mlir/python/CMakeLists.txt | 1 -
mlir/python/mlir/dialects/_affine_ops_ext.py | 56 -------------------
mlir/python/mlir/dialects/affine.py | 51 ++++++++++++++++-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 35 ++++++------
4 files changed, 67 insertions(+), 76 deletions(-)
delete mode 100644 mlir/python/mlir/dialects/_affine_ops_ext.py
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 586d39fc332fac5..88e6e13602d291a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
- dialects/_affine_ops_ext.py
DIALECT_NAME affine
GEN_ENUM_BINDINGS)
diff --git a/mlir/python/mlir/dialects/_affine_ops_ext.py b/mlir/python/mlir/dialects/_affine_ops_ext.py
deleted file mode 100644
index dc465ce7aa1e5f9..000000000000000
--- a/mlir/python/mlir/dialects/_affine_ops_ext.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ._ods_common import get_op_results_or_values as _get_op_results_or_values
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-class AffineStoreOp:
- """Specialization for the Affine store operation."""
-
- def __init__(
- self,
- value: Union[Operation, OpView, Value],
- memref: Union[Operation, OpView, Value],
- map: AffineMap=None,
- *,
- map_operands=None,
- loc=None,
- ip=None
- ):
- """Creates an affine store operation.
-
- - `value`: the value to store into the memref.
- - `memref`: the buffer to store into.
- - `map`: the affine map that maps the map_operands to the index of the
- memref.
- - `map_operands`: the list of arguments to substitute the dimensions,
- then symbols in the affine map, in increasing order.
- """
- map = map if map is not None else []
- map_operands = map_operands if map_operands is not None else []
- operands = [
- _get_op_result_or_value(value),
- _get_op_result_or_value(memref),
- *[_get_op_result_or_value(op) for op in map_operands]
- ]
- results = []
- attributes = {"map": AffineMapAttr.get(map)}
- regions = None
- _ods_successors = None
- super().__init__(self.build_generic(
- attributes=attributes,
- results=results,
- operands=operands,
- successors=_ods_successors,
- regions=regions,
- loc=loc,
- ip=ip
- ))
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 8a2a64c7c40d190..1eaccfa73a85cbf 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -1,5 +1,50 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
+from ._affine_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineStoreOp(AffineStoreOp):
+ """Specialization for the Affine store operation."""
+
+ def __init__(
+ self,
+ value: Union[Operation, OpView, Value],
+ memref: Union[Operation, OpView, Value],
+ map: AffineMap = None,
+ *,
+ map_operands=None,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an affine store operation.
+
+ - `value`: the value to store into the memref.
+ - `memref`: the buffer to store into.
+ - `map`: the affine map that maps the map_operands to the index of the
+ memref.
+ - `map_operands`: the list of arguments to substitute the dimensions,
+ then symbols in the affine map, in increasing order.
+ """
+ map = map if map is not None else []
+ map_operands = map_operands if map_operands is not None else []
+ indicies = [_get_op_result_or_value(op) for op in map_operands]
+ _ods_successors = None
+ super().__init__(
+ value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
+ )
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index d8dcf936e768754..875678a2333789a 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -295,13 +295,17 @@ static bool isODSReserved(StringRef str) {
/// (does not change the `name` if it already is suitable) and returns the
/// modified version.
static std::string sanitizeName(StringRef name) {
- std::string processed_str = name.str();
+ std::string processedStr = name.str();
+ std::replace_if(
+ processedStr.begin(), processedStr.end(),
+ [](char c) { return !llvm::isAlnum(c); }, '_');
- std::replace(processed_str.begin(), processed_str.end(), '-', '_');
+ if (llvm::isDigit(*processedStr.begin()))
+ return "_" + processedStr;
- if (isPythonReserved(processed_str) || isODSReserved(processed_str))
- return processed_str + "_";
- return processed_str;
+ if (isPythonReserved(processedStr) || isODSReserved(processedStr))
+ return processedStr + "_";
+ return processedStr;
}
static std::string attrSizedTraitForKind(const char *kind) {
@@ -977,7 +981,6 @@ static void emitValueBuilder(const Operator &op,
llvm::SmallVector<std::string> functionArgs,
raw_ostream &os) {
auto name = sanitizeName(op.getOperationName());
- iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
// Params with (possibly) default args.
auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -996,16 +999,16 @@ static void emitValueBuilder(const Operator &op,
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
- os << llvm::formatv(
- valueBuilderTemplate,
- // Drop dialect name and then sanitize again (to catch e.g. func.return).
- sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
- op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
- llvm::join(opBuilderArgs, ", "),
- (op.getNumResults() > 1
- ? "_Sequence[_ods_ir.OpResult]"
- : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
- : "_ods_ir.Operation")));
+ std::string nameWithoutDialect =
+ op.getOperationName().substr(op.getOperationName().find('.') + 1);
+ os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
+ op.getCppClassName(),
+ llvm::join(valueBuilderParams, ", "),
+ llvm::join(opBuilderArgs, ", "),
+ (op.getNumResults() > 1
+ ? "_Sequence[_ods_ir.OpResult]"
+ : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+ : "_ods_ir.Operation")));
}
/// Emits bindings for a specific Op to the given output stream.
>From d5a8f42f03c3c7be620bc98be5cacff2db24e485 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 01:57:00 -0500
Subject: [PATCH 5/5] format
---
mlir/lib/Bindings/Python/Globals.h | 2 +-
mlir/python/mlir/dialects/arith.py | 6 +-
.../linalg/opdsl/ops/core_named_ops.py | 107 ++++++++++--------
mlir/python/mlir/dialects/python_test.py | 7 +-
mlir/python/mlir/runtime/np_to_memref.py | 8 +-
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 1 -
6 files changed, 75 insertions(+), 56 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index dea44bbd469dd3d..21899bdce22e810 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -77,7 +77,7 @@ class PyGlobals {
pybind11::object pyClass);
/// Adds a concrete implementation operation class.
- /// Raises an exception if the mapping already exists.
+ /// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass, bool replace = false);
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index e3b6a428c879de5..83aca0d58bf2cef 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -8,7 +8,10 @@
try:
from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context, _cext as _ods_cext
+ from ._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
from typing import Any, List, Union
except ImportError as e:
@@ -34,6 +37,7 @@ def _is_integer_like_type(type: Type):
def _is_float_type(type: Type):
return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index a8f8f8e0fbd68b4..19734a80a107bfe 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -296,35 +296,39 @@ def quantized_matmul(
@linalg_structured_op
-def matmul_transpose_a(A=TensorDef(T1, S.K, S.N),
- B=TensorDef(T2, S.K, S.M),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- """Performs a matrix multiplication of two 2D inputs with lhs operand
- transposed.
+def matmul_transpose_a(
+ A=TensorDef(T1, S.K, S.N),
+ B=TensorDef(T2, S.K, S.M),
+ C=TensorDef(U, S.M, S.N, output=True),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ """Performs a matrix multiplication of two 2D inputs with lhs operand
+ transposed.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
@linalg_structured_op
-def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.N, S.K),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- """Performs a matrix multiplication of two 2D inputs with rhs operand
- transposed.
+def matmul_transpose_b(
+ A=TensorDef(T1, S.M, S.K),
+ B=TensorDef(T2, S.N, S.K),
+ C=TensorDef(U, S.M, S.N, output=True),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ """Performs a matrix multiplication of two 2D inputs with rhs operand
+ transposed.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
@linalg_structured_op
@@ -390,36 +394,41 @@ def batch_matmul(
@linalg_structured_op
-def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True)):
- """Performs a batched matrix multiplication of two 3D inputs where lhs operand
- has its non-batch dimensions transposed.
+def batch_matmul_transpose_a(
+ A=TensorDef(T1, Batch, S.K, S.M),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+ """Performs a batched matrix multiplication of two 3D inputs where lhs operand
+ has its non-batch dimensions transposed.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \
- * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
+ U, B[D.b, D.k, D.n]
+ )
@linalg_structured_op
-def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.N, S.K),
- C=TensorDef(U, Batch, S.M, S.N, output=True)):
- """Performs a batched matrix multiplication of two 3D inputs where rhs operand
- has its non-batch dimensions transposed.
+def batch_matmul_transpose_b(
+ A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.N, S.K),
+ C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+ """Performs a batched matrix multiplication of two 3D inputs where rhs operand
+ has its non-batch dimensions transposed.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m,
- D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.n, D.k])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+ U, B[D.b, D.n, D.k]
+ )
@linalg_structured_op
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 8465af048a28056..6579e02d8549efa 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
+from .._mlir_libs._mlirPythonTest import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+)
def register_python_test_dialect(context, load=True):
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 0a3b411041b2f4d..f6b706f9bc8ae24 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -114,6 +114,7 @@ def get_unranked_memref_descriptor(nparray):
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
return d
+
def move_aligned_ptr_by_offset(aligned_ptr, offset):
"""Moves the supplied ctypes pointer ahead by `offset` elements."""
aligned_addr = ctypes.addressof(aligned_ptr.contents)
@@ -122,6 +123,7 @@ def move_aligned_ptr_by_offset(aligned_ptr, offset):
content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr))
return content_ptr
+
def unranked_memref_to_numpy(unranked_memref, np_dtype):
"""Converts unranked memrefs to numpy arrays."""
ctp = as_ctype(np_dtype)
@@ -139,10 +141,10 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
def ranked_memref_to_numpy(ranked_memref):
"""Converts ranked memrefs to numpy arrays."""
- content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset)
- np_arr = np.ctypeslib.as_array(
- content_ptr, shape=ranked_memref[0].shape
+ content_ptr = move_aligned_ptr_by_offset(
+ ranked_memref[0].aligned, ranked_memref[0].offset
)
+ np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(ranked_memref[0].shape),
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 875678a2333789a..aaf200cf414e425 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -847,7 +847,6 @@ populateBuilderRegions(const Operator &op,
/// rebuild anew).
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) {
- // If we are asked to skip default builders, comply.
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
llvm::SmallVector<std::string> operandArgNames;
More information about the Mlir-commits
mailing list