[Mlir-commits] [mlir] a2288a8 - [mlir][python] remove mixins (#68853)
llvmlistbot at llvm.org
Thu Oct 19 14:20:18 PDT 2023
Author: Maksim Levental
Date: 2023-10-19T16:20:14-05:00
New Revision: a2288a8944c310fcad1196302f16513797e1fcbc
URL: https://github.com/llvm/llvm-project/commit/a2288a8944c310fcad1196302f16513797e1fcbc
DIFF: https://github.com/llvm/llvm-project/commit/a2288a8944c310fcad1196302f16513797e1fcbc.diff
LOG: [mlir][python] remove mixins (#68853)
This PR replaces the mixin `OpView` extension mechanism with the
standard inheritance mechanism.
Why? Firstly, mixins are not very pythonic (inheritance is usually used
for this), a little convoluted, and too "tight" (they can only be used in the
immediately adjacent `_ext.py`). Secondly, mixins are now blocking
a correct implementation of "value builders" (see
[here](https://github.com/llvm/llvm-project/pull/68764)), where the
problem becomes how to choose the correct base class that the value
builder should call.
This PR looks big/complicated but appearances are deceiving; 4 things
were needed to make this work:
1. Drop `skipDefaultBuilders` in
`OpPythonBindingGen::emitDefaultOpBuilders`
2. Former mixin extension classes are converted to inherit from the
generated `OpView` instead of being "mixins"
a. extension classes that were simply calling into an already generated
`super().__init__` continue to do so
b. (almost all) extension classes that were calling `self.build_generic`
because no default builder was generated can now also just
call `super().__init__`
3. To handle the [lone
use-case](https://sourcegraph.com/search?q=context%3Aglobal+select_opview_mixin&patternType=standard&sm=1&groupBy=repo)
of `select_opview_mixin`, namely
[linalg](https://github.com/llvm/llvm-project/blob/main/mlir/python/mlir/dialects/_linalg_ops_ext.py#L38),
only a small change was necessary in `opdsl/lang/emitter.py` (thanks to
the emission/generation of default builders/`__init__`s)
4. since the `extend_opview_class` decorator is removed, we need a way
to register extension classes as the desired `OpView` that `op.opview`
conjures into existence; so we do the standard thing and just enable
replacing the existing registered `OpView`, i.e.,
`register_operation(_Dialect, replace=True)` (see the sketch after this list).
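To illustrate the common pattern, here is a minimal sketch; `_foo_ops_gen` and
`FooOp` are hypothetical stand-ins for a generated dialect module and its
generated `OpView` class, not names from this PR:

```python
# Hypothetical generated module; real code would import an actual dialect's
# generated module, e.g. mlir.dialects._arith_ops_gen.
from mlir.dialects._foo_ops_gen import _Dialect, FooOp
from mlir.dialects._ods_common import _cext

# replace=True swaps this class in as the OpView that op.opview returns.
@_cext.register_operation(_Dialect, replace=True)
class FooOpExt(FooOp):
    def __init__(self, value, *, loc=None, ip=None):
        # Sugar/convert the arguments as needed, then defer to the
        # generated default builder.
        super().__init__(value, loc=loc, ip=ip)
```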
Note, the upgrade path for the common case is to change an extension to
inherit from the generated builder and decorate it with
`register_operation(_Dialect, replace=True)`. In the slightly more
complicated case where `super().__init__(self.build_generic(...))` is
called in the extension's `__init__`, this needs to be updated to call
`__init__` in `OpView`, i.e., the grandparent (see the updated docs and
the sketch below).
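As a sketch of that more complicated case (again with the hypothetical
`FooOp` from above), the grandparent call looks like:

```python
from mlir.dialects._foo_ops_gen import _Dialect, FooOp  # hypothetical, as above
from mlir.dialects._ods_common import _cext

@_cext.register_operation(_Dialect, replace=True)
class FooOpExt(FooOp):
    def __init__(self, results, operands, *, loc=None, ip=None):
        # Skip the generated FooOp.__init__ and call the grandparent
        # OpView.__init__ with an explicitly constructed operation.
        super(FooOp, self).__init__(
            self.build_generic(results=results, operands=operands, loc=loc, ip=ip)
        )
```

`super(FooOp, self)` starts the method lookup above `FooOp` in the MRO, which
is exactly how the updated docs below resolve the same issue for `scf.for`.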
Note, `<DIALECT>_ext.py` files/modules will also no longer be automatically loaded.
Note, the PR has 3 base commits that look funny, but this was done to
track the line history of moving the `<DIALECT>_ops_ext.py` classes into
`<DIALECT>.py` and updating them (commit labeled "fix").
Added:
Modified:
mlir/docs/Bindings/Python.md
mlir/lib/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/IRModule.cpp
mlir/lib/Bindings/Python/MainModule.cpp
mlir/python/CMakeLists.txt
mlir/python/mlir/dialects/_ods_common.py
mlir/python/mlir/dialects/affine.py
mlir/python/mlir/dialects/arith.py
mlir/python/mlir/dialects/bufferization.py
mlir/python/mlir/dialects/builtin.py
mlir/python/mlir/dialects/func.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/python/mlir/dialects/memref.py
mlir/python/mlir/dialects/ml_program.py
mlir/python/mlir/dialects/pdl.py
mlir/python/mlir/dialects/python_test.py
mlir/python/mlir/dialects/scf.py
mlir/python/mlir/dialects/tensor.py
mlir/python/mlir/dialects/transform/__init__.py
mlir/python/mlir/dialects/transform/bufferization.py
mlir/python/mlir/dialects/transform/gpu.py
mlir/python/mlir/dialects/transform/loop.py
mlir/python/mlir/dialects/transform/memref.py
mlir/python/mlir/dialects/transform/pdl.py
mlir/python/mlir/dialects/transform/structured.py
mlir/python/mlir/dialects/transform/tensor.py
mlir/python/mlir/runtime/np_to_memref.py
mlir/test/python/dialects/arith_dialect.py
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
mlir/python/mlir/dialects/_affine_ops_ext.py
mlir/python/mlir/dialects/_arith_ops_ext.py
mlir/python/mlir/dialects/_bufferization_ops_ext.py
mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
mlir/python/mlir/dialects/_builtin_ops_ext.py
mlir/python/mlir/dialects/_func_ops_ext.py
mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
mlir/python/mlir/dialects/_linalg_ops_ext.py
mlir/python/mlir/dialects/_loop_transform_ops_ext.py
mlir/python/mlir/dialects/_memref_ops_ext.py
mlir/python/mlir/dialects/_memref_transform_ops_ext.py
mlir/python/mlir/dialects/_ml_program_ops_ext.py
mlir/python/mlir/dialects/_pdl_ops_ext.py
mlir/python/mlir/dialects/_scf_ops_ext.py
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/python/mlir/dialects/_tensor_ops_ext.py
mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
mlir/python/mlir/dialects/_transform_ops_ext.py
mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
################################################################################
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bf54efee1f14e0c..bc2e676a878c0f4 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1017,90 +1017,79 @@ very generic signature.
#### Extending Generated Op Classes
-Note that this is a rather complex mechanism and this section errs on the side
-of explicitness. Users are encouraged to find an example and duplicate it if
-they don't feel the need to understand the subtlety. The `builtin` dialect
-provides some relatively simple examples.
-
As mentioned above, the build system generates Python sources like
`_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is
-often desirable to to use these generated classes as a starting point for
-further customization, so an extension mechanism is provided to make this easy
-(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file
-but we prefer a more standard mechanism that is applied uniformly).
+often desirable to use these generated classes as a starting point for
+further customization, so an extension mechanism is provided to make this easy.
+This mechanism uses conventional inheritance combined with `OpView` registration.
+For example, the default builder for `arith.constant`
+
+```python
+class ConstantOp(_ods_ir.OpView):
+ OPERATION_NAME = "arith.constant"
+
+ _ODS_REGIONS = (0, True)
+
+ def __init__(self, value, *, loc=None, ip=None):
+ ...
+```
-To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the
-`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and
-the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an
-example, the generated code will include an import like this:
+expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
+Thus, a natural extension is a builder that accepts an MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
```python
-try:
- from . import _builtin_ops_ext as _ods_ext_module
-except ImportError:
- _ods_ext_module = None
+from typing import Union
+
+from mlir.ir import Type, IntegerAttr, FloatAttr
+from mlir.dialects._arith_ops_gen import _Dialect, ConstantOp
+from mlir.dialects._ods_common import _cext
+
+@_cext.register_operation(_Dialect, replace=True)
+class ConstantOpExt(ConstantOp):
+ def __init__(
+ self, result: Type, value: Union[int, float], *, loc=None, ip=None
+ ):
+ if isinstance(value, int):
+ super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+ elif isinstance(value, float):
+ super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+ else:
+ raise NotImplementedError(f"Building `arith.constant` not supported for {result=} {value=}")
```
-Then for each generated concrete `OpView` subclass, it will apply a decorator
-like:
+which enables building an instance of `arith.constant` like so:
```python
-@_ods_cext.register_operation(_Dialect)
-@_ods_extend_opview_class(_ods_ext_module)
-class FuncOp(_ods_ir.OpView):
+from mlir.ir import F32Type, IntegerType
+
+a = ConstantOpExt(F32Type.get(), 42.42)
+b = ConstantOpExt(IntegerType.get_signless(32), 42)
```
-See the `_ods_common.py` `extend_opview_class` function for details of the
-mechanism. At a high level:
-
-* If the extension module exists, locate an extension class for the op (in
- this example, `FuncOp`):
- * First by looking for an attribute with the exact name in the extension
- module.
- * Falling back to calling a `select_opview_mixin(parent_opview_cls)`
- function defined in the extension module.
-* If a mixin class is found, a new subclass is dynamically created that
- multiply inherits from `({_builtin_ops_ext.FuncOp},
- _builtin_ops_gen.FuncOp)`.
-
-The mixin class should not inherit from anything (i.e. directly extends `object`
-only). The facility is typically used to define custom `__init__` methods,
-properties, instance methods and static methods. Due to the inheritance
-ordering, the mixin class can act as though it extends the generated `OpView`
-subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)`
-will return `False` but usage generally allows you treat it as duck typed as an
-`OpView`).
-
-There are a couple of recommendations, given how the class hierarchy is defined:
-
-* For static methods that need to instantiate the actual "leaf" op (which is
- dynamically generated and would result in circular dependencies to try to
- reference by name), prefer to use `@classmethod` and the concrete subclass
- will be provided as your first `cls` argument. See
- `_builtin_ops_ext.FuncOp.from_py_func` as an example.
-* If seeking to replace the generated `__init__` method entirely, you may
- actually want to invoke the super-super-class `mlir.ir.OpView` constructor
- directly, as it takes an `mlir.ir.Operation`, which is likely what you are
- constructing (i.e. the generated `__init__` method likely adds more API
- constraints than you want to expose in a custom builder).
-
-A pattern that comes up frequently is wanting to provide a sugared `__init__`
-method which has optional or type-polymorphism/implicit conversions but to
-otherwise want to invoke the default op building logic. For such cases, it is
-recommended to use an idiom such as:
+Note three key aspects of the extension mechanism in this example:
+
+1. `ConstantOpExt` directly inherits from the generated `ConstantOp`;
+2. in this simplest case, all that's required is a call to the superclass's initializer, i.e., `super().__init__(...)`;
+3. in order to register `ConstantOpExt` as the preferred `OpView` that is returned by `mlir.ir.Operation.opview` (see [Operations, Regions and Blocks](#operations-regions-and-blocks)),
+ we decorate the class with `@_cext.register_operation(_Dialect, replace=True)`, **where `replace=True` must be used**.
+
+In some more complex cases, it might be necessary to explicitly build the `OpView` through `OpView.build_generic` (see [Default Builder](#default-builder)), just as the generated builders do.
+That is, we must call `OpView.build_generic` **and pass the result to `OpView.__init__`**; the small wrinkle is that the latter is already overridden by the generated builder.
+Thus, we must call a method of the superclass's superclass (the "grandparent"); for example:
```python
- def __init__(self, sugar, spice, *, loc=None, ip=None):
- ... massage into result_type, operands, attributes ...
- OpView.__init__(self, self.build_generic(
- results=[result_type],
- operands=operands,
- attributes=attributes,
- loc=loc,
- ip=ip))
+from mlir.dialects._scf_ops_gen import _Dialect, ForOp
+from mlir.dialects._ods_common import _cext
+
+@_cext.register_operation(_Dialect, replace=True)
+class ForOpExt(ForOp):
+ def __init__(self, lower_bound, upper_bound, step, iter_args, *, loc=None, ip=None):
+ ...
+ super(ForOp, self).__init__(self.build_generic(...))
```
-Refer to the documentation for `build_generic` for more information.
+where `OpView.__init__` is called via `super(ForOp, self).__init__`.
+Note, there are alternative ways to implement this (e.g., explicitly calling `OpView.__init__`); see any discussion of Python inheritance.
## Providing Python bindings for a dialect
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 97cd70089a2e965..21899bdce22e810 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -77,10 +77,10 @@ class PyGlobals {
pybind11::object pyClass);
/// Adds a concrete implementation operation class.
- /// Raises an exception if the mapping already exists.
+ /// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
- pybind11::object pyClass);
+ pybind11::object pyClass, bool replace = false);
/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 2cc66277abee0f0..a1c8ab7a09ce155 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
- py::object pyClass) {
+ py::object pyClass, bool replace) {
py::object &found = operationClassMap[operationName];
- if (found) {
+ if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
"' is already registered.")
.str());
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index cdddfbe50606d05..a936becf67bea75 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- "operation_name"_a, "operation_class"_a,
+ "operation_name"_a, "operation_class"_a, "replace"_a = false,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
@@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) {
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
- [](const py::object &dialectClass) -> py::cpp_function {
+ [](const py::object &dialectClass, bool replace) -> py::cpp_function {
return py::cpp_function(
- [dialectClass](py::object opClass) -> py::object {
+ [dialectClass, replace](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
- PyGlobals::get().registerOperationImpl(operationName, opClass);
+ PyGlobals::get().registerOperationImpl(operationName, opClass,
+ replace);
// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
@@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
- "dialect_class"_a,
+ "dialect_class"_a, "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c7b3c283a6b6dc1..88e6e13602d291a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/AffineOps.td
SOURCES
dialects/affine.py
- dialects/_affine_ops_ext.py
DIALECT_NAME affine
GEN_ENUM_BINDINGS)
@@ -78,7 +77,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BufferizationOps.td
SOURCES
dialects/bufferization.py
- dialects/_bufferization_ops_ext.py
DIALECT_NAME bufferization
GEN_ENUM_BINDINGS_TD_FILE
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
@@ -90,7 +88,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BuiltinOps.td
SOURCES
dialects/builtin.py
- dialects/_builtin_ops_ext.py
DIALECT_NAME builtin)
declare_mlir_dialect_python_bindings(
@@ -115,7 +112,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/FuncOps.td
SOURCES
dialects/func.py
- dialects/_func_ops_ext.py
DIALECT_NAME func)
declare_mlir_dialect_python_bindings(
@@ -131,7 +127,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgOps.td
SOURCES
- dialects/_linalg_ops_ext.py
SOURCES_GLOB
dialects/linalg/*.py
DIALECT_NAME linalg
@@ -152,7 +147,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformPDLExtensionOps.td
SOURCES
- dialects/_transform_pdl_extension_ops_ext.py
dialects/transform/pdl.py
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
@@ -162,7 +156,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformOps.td
SOURCES
- dialects/_transform_ops_ext.py
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform
@@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/BufferizationTransformOps.td
SOURCES
- dialects/_bufferization_transform_ops_ext.py
dialects/transform/bufferization.py
DIALECT_NAME transform
EXTENSION_NAME bufferization_transform)
@@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/GPUTransformOps.td
SOURCES
- dialects/_gpu_transform_ops_ext.py
dialects/transform/gpu.py
DIALECT_NAME transform
EXTENSION_NAME gpu_transform)
@@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/SCFLoopTransformOps.td
SOURCES
- dialects/_loop_transform_ops_ext.py
dialects/transform/loop.py
DIALECT_NAME transform
EXTENSION_NAME loop_transform)
@@ -205,7 +195,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/MemRefTransformOps.td
SOURCES
- dialects/_memref_transform_ops_ext.py
dialects/transform/memref.py
DIALECT_NAME transform
EXTENSION_NAME memref_transform)
@@ -224,7 +213,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
- dialects/_structured_transform_ops_ext.py
dialects/transform/structured.py
DIALECT_NAME transform
EXTENSION_NAME structured_transform
@@ -246,7 +234,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TensorTransformOps.td
SOURCES
- dialects/_tensor_transform_ops_ext.py
dialects/transform/tensor.py
DIALECT_NAME transform
EXTENSION_NAME tensor_transform)
@@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/ArithOps.td
SOURCES
dialects/arith.py
- dialects/_arith_ops_ext.py
DIALECT_NAME arith
GEN_ENUM_BINDINGS)
@@ -286,7 +272,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MemRefOps.td
SOURCES
dialects/memref.py
- dialects/_memref_ops_ext.py
DIALECT_NAME memref)
declare_mlir_dialect_python_bindings(
@@ -295,7 +280,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MLProgramOps.td
SOURCES
dialects/ml_program.py
- dialects/_ml_program_ops_ext.py
DIALECT_NAME ml_program)
declare_mlir_dialect_python_bindings(
@@ -339,7 +323,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/PDLOps.td
SOURCES
dialects/pdl.py
- dialects/_pdl_ops_ext.py
_mlir_libs/_mlir/dialects/pdl.pyi
DIALECT_NAME pdl)
@@ -357,7 +340,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/SCFOps.td
SOURCES
dialects/scf.py
- dialects/_scf_ops_ext.py
DIALECT_NAME scf)
declare_mlir_dialect_python_bindings(
@@ -383,7 +365,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/TensorOps.td
SOURCES
dialects/tensor.py
- dialects/_tensor_ops_ext.py
DIALECT_NAME tensor)
declare_mlir_dialect_python_bindings(
diff --git a/mlir/python/mlir/dialects/_affine_ops_ext.py b/mlir/python/mlir/dialects/_affine_ops_ext.py
deleted file mode 100644
index dc465ce7aa1e5f9..000000000000000
--- a/mlir/python/mlir/dialects/_affine_ops_ext.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ._ods_common import get_op_results_or_values as _get_op_results_or_values
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-class AffineStoreOp:
- """Specialization for the Affine store operation."""
-
- def __init__(
- self,
- value: Union[Operation, OpView, Value],
- memref: Union[Operation, OpView, Value],
- map: AffineMap=None,
- *,
- map_operands=None,
- loc=None,
- ip=None
- ):
- """Creates an affine store operation.
-
- - `value`: the value to store into the memref.
- - `memref`: the buffer to store into.
- - `map`: the affine map that maps the map_operands to the index of the
- memref.
- - `map_operands`: the list of arguments to substitute the dimensions,
- then symbols in the affine map, in increasing order.
- """
- map = map if map is not None else []
- map_operands = map_operands if map_operands is not None else []
- operands = [
- _get_op_result_or_value(value),
- _get_op_result_or_value(memref),
- *[_get_op_result_or_value(op) for op in map_operands]
- ]
- results = []
- attributes = {"map": AffineMapAttr.get(map)}
- regions = None
- _ods_successors = None
- super().__init__(self.build_generic(
- attributes=attributes,
- results=results,
- operands=operands,
- successors=_ods_successors,
- regions=regions,
- loc=loc,
- ip=ip
- ))
diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py
deleted file mode 100644
index df38f871710fe8f..000000000000000
--- a/mlir/python/mlir/dialects/_arith_ops_ext.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
-
- from typing import Any, List, Union
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-
-def _isa(obj: Any, cls: type):
- try:
- cls(obj)
- except ValueError:
- return False
- return True
-
-
-def _is_any_of(obj: Any, classes: List[type]):
- return any(_isa(obj, cls) for cls in classes)
-
-
-def _is_integer_like_type(type: Type):
- return _is_any_of(type, [IntegerType, IndexType])
-
-
-def _is_float_type(type: Type):
- return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-
-class ConstantOp:
- """Specialization for the constant op class."""
-
- def __init__(
- self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
- ):
- if isinstance(value, int):
- super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
- elif isinstance(value, float):
- super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
- else:
- super().__init__(value, loc=loc, ip=ip)
-
- @classmethod
- def create_index(cls, value: int, *, loc=None, ip=None):
- """Create an index-typed constant."""
- return cls(
- IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
- )
-
- @property
- def type(self):
- return self.results[0].type
-
- @property
- def value(self):
- return Attribute(self.operation.attributes["value"])
-
- @property
- def literal_value(self) -> Union[int, float]:
- if _is_integer_like_type(self.type):
- return IntegerAttr(self.value).value
- elif _is_float_type(self.type):
- return FloatAttr(self.value).value
- else:
- raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
deleted file mode 100644
index 1066cb4c775cab9..000000000000000
--- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from typing import Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context
-
- from typing import Any, List, Union
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-
-class AllocTensorOp:
- """Extends the bufferization.alloc_tensor op."""
-
- def __init__(
- self,
- tensor_type: Type,
- dynamic_sizes: Sequence[Value],
- copy: Value,
- size_hint: Value,
- escape: BoolAttr,
- *,
- loc=None,
- ip=None
- ):
- """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
- context = get_default_loc_context(loc)
- attributes = {}
- if escape:
- attributes["escape"] = escape
- op = self.build_generic(
- results=[tensor_type],
- operands=[dynamic_sizes, copy, size_hint],
- attributes=attributes,
- loc=loc,
- ip=ip,
- )
- OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
deleted file mode 100644
index 7e6c1b81cb350b7..000000000000000
--- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
+++ /dev/null
@@ -1,128 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from enum import Enum
-from typing import Optional, overload, Union
-
-
-class EmptyTensorToAllocTensorOp:
- """Specialization for EmptyTensorToAllocTensorOp class."""
-
- @overload
- def __init__(
- self,
- transformed_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
- ...
-
- def __init__(
- self,
- transformed_type_or_target: Type,
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- loc=None,
- ip=None
- ):
- if isinstance(transformed_type_or_target, Type):
- transformed_type = transformed_type_or_target
- target = target_or_none
- else:
- transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
- target = transformed_type_or_target
-
- super().__init__(
- transformed_type,
- target,
- loc=loc,
- ip=ip,
- )
-
-
-class OneShotBufferizeOp:
- """Specialization for OneShotBufferizeOp class."""
-
- @overload
- def __init__(
- self,
- transformed_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- allow_return_allocs_from_loops: Optional[bool] = None,
- allow_unknown_ops: Optional[bool] = None,
- bufferize_function_boundaries: Optional[bool] = None,
- function_boundary_type_conversion: Optional[Enum] = None,
- memcpy_op: Optional[str] = None,
- print_conflicts: Optional[bool] = None,
- test_analysis_only: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- allow_return_allocs_from_loops: Optional[bool] = None,
- allow_unknown_ops: Optional[bool] = None,
- bufferize_function_boundaries: Optional[bool] = None,
- function_boundary_type_conversion: Optional[Enum] = None,
- memcpy_op: Optional[str] = None,
- print_conflicts: Optional[bool] = None,
- test_analysis_only: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- transformed_type_or_target: Type,
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- allow_return_allocs_from_loops: Optional[bool] = None,
- allow_unknown_ops: Optional[bool] = None,
- bufferize_function_boundaries: Optional[bool] = None,
- function_boundary_type_conversion: Optional[Enum] = None,
- memcpy_op: Optional[str] = None,
- print_conflicts: Optional[bool] = None,
- test_analysis_only: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- if isinstance(transformed_type_or_target, Type):
- transformed_type = transformed_type_or_target
- target = target_or_none
- else:
- transformed_type = transform.AnyOpType.get()
- target = transformed_type_or_target
-
- super().__init__(
- transformed_type,
- target,
- allow_return_allocs_from_loops=allow_return_allocs_from_loops,
- allow_unknown_ops=allow_unknown_ops,
- bufferize_function_boundaries=bufferize_function_boundaries,
- function_boundary_type_conversion=function_boundary_type_conversion,
- memcpy_op=memcpy_op,
- print_conflicts=print_conflicts,
- test_analysis_only=test_analysis_only,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
deleted file mode 100644
index 27a60123050acb4..000000000000000
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-
-class ModuleOp:
- """Specialization for the module op class."""
-
- def __init__(self, *, loc=None, ip=None):
- super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
- body = self.regions[0].blocks.append()
-
- @property
- def body(self):
- return self.regions[0].blocks[0]
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
deleted file mode 100644
index 6d264c33f1f9dae..000000000000000
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ /dev/null
@@ -1,319 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
-
- import inspect
-
- from typing import Any, List, Optional, Sequence, Union
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
-RESULT_ATTRIBUTE_NAME = "res_attrs"
-
-
-class ConstantOp:
- """Specialization for the constant op class."""
-
- def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
- super().__init__(result, value, loc=loc, ip=ip)
-
- @property
- def type(self):
- return self.results[0].type
-
-
-class FuncOp:
- """Specialization for the func op class."""
-
- def __init__(
- self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
- ):
- """
- Create a FuncOp with the provided `name`, `type`, and `visibility`.
- - `name` is a string representing the function name.
- - `type` is either a FunctionType or a pair of list describing inputs and
- results.
- - `visibility` is a string matching `public`, `private`, or `nested`. None
- implies private visibility.
- - `body_builder` is an optional callback, when provided a new entry block
- is created and the callback is invoked with the new op as argument within
- an InsertionPoint context already set for the block. The callback is
- expected to insert a terminator in the block.
- """
- sym_name = StringAttr.get(str(name))
-
- # If the type is passed as a tuple, build a FunctionType on the fly.
- if isinstance(type, tuple):
- type = FunctionType.get(inputs=type[0], results=type[1])
-
- type = TypeAttr.get(type)
- sym_visibility = (
- StringAttr.get(str(visibility)) if visibility is not None else None
- )
- super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
- if body_builder:
- entry_block = self.add_entry_block()
- with InsertionPoint(entry_block):
- body_builder(self)
-
- @property
- def is_external(self):
- return len(self.regions[0].blocks) == 0
-
- @property
- def body(self):
- return self.regions[0]
-
- @property
- def type(self):
- return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
- @property
- def visibility(self):
- return self.attributes["sym_visibility"]
-
- @property
- def name(self) -> StringAttr:
- return StringAttr(self.attributes["sym_name"])
-
- @property
- def entry_block(self):
- if self.is_external:
- raise IndexError("External function does not have a body")
- return self.regions[0].blocks[0]
-
- def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
- """
- Add an entry block to the function body using the function signature to
- infer block arguments.
- Returns the newly created block
- """
- if not self.is_external:
- raise IndexError("The function already has an entry block!")
- self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
- return self.body.blocks[0]
-
- @property
- def arg_attrs(self):
- return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
- @arg_attrs.setter
- def arg_attrs(self, attribute: Union[ArrayAttr, list]):
- if isinstance(attribute, ArrayAttr):
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
- else:
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
- attribute, context=self.context
- )
-
- @property
- def arguments(self):
- return self.entry_block.arguments
-
- @property
- def result_attrs(self):
- return self.attributes[RESULT_ATTRIBUTE_NAME]
-
- @result_attrs.setter
- def result_attrs(self, attribute: ArrayAttr):
- self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
-
- @classmethod
- def from_py_func(
- FuncOp,
- *inputs: Type,
- results: Optional[Sequence[Type]] = None,
- name: Optional[str] = None,
- ):
- """Decorator to define an MLIR FuncOp specified as a python function.
-
- Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
- active for the current thread (i.e. established in a `with` block).
-
- When applied as a decorator to a Python function, an entry block will
- be constructed for the FuncOp with types as specified in `*inputs`. The
- block arguments will be passed positionally to the Python function. In
- addition, if the Python function accepts keyword arguments generally or
- has a corresponding keyword argument, the following will be passed:
- * `func_op`: The `func` op being defined.
-
- By default, the function name will be the Python function `__name__`. This
- can be overriden by passing the `name` argument to the decorator.
-
- If `results` is not specified, then the decorator will implicitly
- insert a `ReturnOp` with the `Value`'s returned from the decorated
- function. It will also set the `FuncOp` type with the actual return
- value types. If `results` is specified, then the decorated function
- must return `None` and no implicit `ReturnOp` is added (nor are the result
- types updated). The implicit behavior is intended for simple, single-block
- cases, and users should specify result types explicitly for any complicated
- cases.
-
- The decorated function can further be called from Python and will insert
- a `CallOp` at the then-current insertion point, returning either None (
- if no return values), a unary Value (for one result), or a list of Values).
- This mechanism cannot be used to emit recursive calls (by construction).
- """
-
- def decorator(f):
- from . import func
-
- # Introspect the callable for optional features.
- sig = inspect.signature(f)
- has_arg_func_op = False
- for param in sig.parameters.values():
- if param.kind == param.VAR_KEYWORD:
- has_arg_func_op = True
- if param.name == "func_op" and (
- param.kind == param.POSITIONAL_OR_KEYWORD
- or param.kind == param.KEYWORD_ONLY
- ):
- has_arg_func_op = True
-
- # Emit the FuncOp.
- implicit_return = results is None
- symbol_name = name or f.__name__
- function_type = FunctionType.get(
- inputs=inputs, results=[] if implicit_return else results
- )
- func_op = FuncOp(name=symbol_name, type=function_type)
- with InsertionPoint(func_op.add_entry_block()):
- func_args = func_op.entry_block.arguments
- func_kwargs = {}
- if has_arg_func_op:
- func_kwargs["func_op"] = func_op
- return_values = f(*func_args, **func_kwargs)
- if not implicit_return:
- return_types = list(results)
- assert return_values is None, (
- "Capturing a python function with explicit `results=` "
- "requires that the wrapped function returns None."
- )
- else:
- # Coerce return values, add ReturnOp and rewrite func type.
- if return_values is None:
- return_values = []
- elif isinstance(return_values, tuple):
- return_values = list(return_values)
- elif isinstance(return_values, Value):
- # Returning a single value is fine, coerce it into a list.
- return_values = [return_values]
- elif isinstance(return_values, OpView):
- # Returning a single operation is fine, coerce its results a list.
- return_values = return_values.operation.results
- elif isinstance(return_values, Operation):
- # Returning a single operation is fine, coerce its results a list.
- return_values = return_values.results
- else:
- return_values = list(return_values)
- func.ReturnOp(return_values)
- # Recompute the function type.
- return_types = [v.type for v in return_values]
- function_type = FunctionType.get(
- inputs=inputs, results=return_types
- )
- func_op.attributes["function_type"] = TypeAttr.get(function_type)
-
- def emit_call_op(*call_args):
- call_op = func.CallOp(
- return_types, FlatSymbolRefAttr.get(symbol_name), call_args
- )
- if return_types is None:
- return None
- elif len(return_types) == 1:
- return call_op.result
- else:
- return call_op.results
-
- wrapped = emit_call_op
- wrapped.__name__ = f.__name__
- wrapped.func_op = func_op
- return wrapped
-
- return decorator
-
-
-class CallOp:
- """Specialization for the call op class."""
-
- def __init__(
- self,
- calleeOrResults: Union[FuncOp, List[Type]],
- argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
- arguments: Optional[List] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates an call operation.
-
-        The constructor accepts three different forms:
-
- 1. A function op to be called followed by a list of arguments.
- 2. A list of result types, followed by the name of the function to be
- called as string, following by a list of arguments.
- 3. A list of result types, followed by the name of the function to be
- called as symbol reference attribute, followed by a list of arguments.
-
- For example
-
- f = func.FuncOp("foo", ...)
- func.CallOp(f, [args])
- func.CallOp([result_types], "foo", [args])
-
- In all cases, the location and insertion point may be specified as keyword
- arguments if not provided by the surrounding context managers.
- """
-
- # TODO: consider supporting constructor "overloads", e.g., through a custom
- # or pybind-provided metaclass.
- if isinstance(calleeOrResults, FuncOp):
- if not isinstance(argumentsOrCallee, list):
- raise ValueError(
- "when constructing a call to a function, expected "
- + "the second argument to be a list of call arguments, "
- + f"got {type(argumentsOrCallee)}"
- )
- if arguments is not None:
- raise ValueError(
- "unexpected third argument when constructing a call"
- + "to a function"
- )
-
- super().__init__(
- calleeOrResults.type.results,
- FlatSymbolRefAttr.get(
- calleeOrResults.name.value, context=_get_default_loc_context(loc)
- ),
- argumentsOrCallee,
- loc=loc,
- ip=ip,
- )
- return
-
- if isinstance(argumentsOrCallee, list):
- raise ValueError(
- "when constructing a call to a function by name, "
- + "expected the second argument to be a string or a "
- + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
- )
-
- if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
- super().__init__(
- calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
- )
- elif isinstance(argumentsOrCallee, str):
- super().__init__(
- calleeOrResults,
- FlatSymbolRefAttr.get(
- argumentsOrCallee, context=_get_default_loc_context(loc)
- ),
- arguments,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
deleted file mode 100644
index ba72bac3a15264d..000000000000000
--- a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union, overload
-
-
-class MapForallToBlocks:
- """Specialization for MapForallToBlocks class."""
-
- @overload
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
- generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
- generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- result_type_or_target: Union[Operation, OpView, Type, Value],
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
- generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
- loc=None,
- ip=None
- ):
- if isinstance(result_type_or_target, Type):
- result_type = result_type_or_target
- target = target_or_none
- else:
- result_type = transform.AnyOpType.get()
- target = result_type_or_target
-
- super().__init__(
- result_type,
- target,
- grid_dims=grid_dims,
- generate_gpu_launch=generate_gpu_launch,
- loc=loc,
- ip=ip,
- )
-
-
-class MapNestedForallToThreads:
- """Specialization for MapNestedForallToThreads class."""
-
- @overload
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- block_dims: Optional[Sequence[int]] = None,
- warp_size: Optional[Sequence[int]] = None,
- sync_after_distribute: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- block_dims: Optional[Sequence[int]] = None,
- warp_size: Optional[Sequence[int]] = None,
- sync_after_distribute: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- result_type_or_target: Union[Operation, OpView, Value, Type],
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- block_dims: Optional[Union[Sequence[int], Attribute]] = None,
- warp_size: Optional[Union[Sequence[int], Attribute]] = None,
- sync_after_distribute: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- if isinstance(result_type_or_target, Type):
- result_type = result_type_or_target
- target = target_or_none
- else:
- result_type = result_type_or_target.type
- target = result_type_or_target
- super().__init__(
- result_type,
- target,
- block_dims=block_dims,
- warp_size=warp_size,
- sync_after_distribute=sync_after_distribute,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
deleted file mode 100644
index 3f6d854ca3e2b14..000000000000000
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from typing import Optional, Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context
- from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-
-
-def isa(cls: Type, ty: Type):
- try:
- cls(ty)
- return True
- except ValueError:
- return False
-
-
-class StructuredOpMixin:
- """All structured ops use the same mixin class."""
-
- def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
- super().__init__(
- self.build_generic(
- results=list(results),
- operands=[list(inputs), list(outputs)],
- loc=loc,
- ip=ip,
- )
- )
-
-
-def select_opview_mixin(parent_opview_cls):
- # TODO: This shouldn't be a heuristic: we should have a way to annotate
- # the OpView to note that it is a structured op.
- if (
- "__init__" not in parent_opview_cls.__dict__
- and hasattr(parent_opview_cls, "inputs")
- and hasattr(parent_opview_cls, "outputs")
- and hasattr(parent_opview_cls, "result_tensors")
- ):
- return StructuredOpMixin
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
deleted file mode 100644
index 1cdb2b9e77b5afe..000000000000000
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Union
-
-
-class GetParentForOp:
- """Extension for GetParentForOp."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- num_loops: Optional[int] = None,
- ip=None,
- loc=None,
- ):
- if num_loops is None:
- num_loops = 1
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- num_loops=num_loops,
- ip=ip,
- loc=loc,
- )
-
-
-class LoopOutlineOp:
- """Extension for LoopOutlineOp."""
-
- def __init__(
- self,
- function_type: Type,
- call_type: Type,
- target: Union[Operation, Value],
- *,
- func_name: Union[str, StringAttr],
- ip=None,
- loc=None,
- ):
- super().__init__(
- function_type,
- call_type,
- _get_op_result_or_value(target),
- func_name=(
- func_name
- if isinstance(func_name, StringAttr)
- else StringAttr.get(func_name)
- ),
- ip=ip,
- loc=loc,
- )
-
-
-class LoopPeelOp:
- """Extension for LoopPeelOp."""
-
- def __init__(
- self,
- main_loop_type: Type,
- remainder_loop_type: Type,
- target: Union[Operation, Value],
- *,
- fail_if_already_divisible: Union[bool, BoolAttr] = False,
- ip=None,
- loc=None,
- ):
- super().__init__(
- main_loop_type,
- remainder_loop_type,
- _get_op_result_or_value(target),
- fail_if_already_divisible=(
- fail_if_already_divisible
- if isinstance(fail_if_already_divisible, BoolAttr)
- else BoolAttr.get(fail_if_already_divisible)
- ),
- ip=ip,
- loc=loc,
- )
-
-
-class LoopPipelineOp:
- """Extension for LoopPipelineOp."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- iteration_interval: Optional[Union[int, IntegerAttr]] = None,
- read_latency: Optional[Union[int, IntegerAttr]] = None,
- ip=None,
- loc=None,
- ):
- if iteration_interval is None:
- iteration_interval = 1
- if read_latency is None:
- read_latency = 10
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- iteration_interval=iteration_interval,
- read_latency=read_latency,
- ip=ip,
- loc=loc,
- )
-
-
-class LoopUnrollOp:
- """Extension for LoopUnrollOp."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- factor: Union[int, IntegerAttr],
- ip=None,
- loc=None,
- ):
- super().__init__(
- _get_op_result_or_value(target),
- factor=factor,
- ip=ip,
- loc=loc,
- )
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
deleted file mode 100644
index 825f1a0a7a6faf4..000000000000000
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ._ods_common import get_op_results_or_values as _get_op_results_or_values
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-class LoadOp:
- """Specialization for the MemRef load operation."""
-
- def __init__(
- self,
- memref: Union[Operation, OpView, Value],
- indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None
- ):
- """Creates a memref load operation.
-
- Args:
- memref: the buffer to load from.
- indices: the list of subscripts, may be empty for zero-dimensional
- buffers.
- loc: user-visible location of the operation.
- ip: insertion point.
- """
- indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
- super().__init__(memref, indices_resolved, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
deleted file mode 100644
index 1cc00bdcbf381c9..000000000000000
--- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, overload, Union
-
-
-class MemRefAllocaToGlobalOp:
- """Specialization for MemRefAllocaToGlobalOp class."""
-
- @overload
- def __init__(
- self,
- get_global_type: Type,
- global_type: Type,
- alloca: Union[Operation, OpView, Value],
- *,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
- ...
-
- def __init__(
- self,
- get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
- global_type_or_none: Optional[Type] = None,
- alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- loc=None,
- ip=None
- ):
- if isinstance(get_global_type_or_alloca, Type):
- get_global_type = get_global_type_or_alloca
- global_type = global_type_or_none
- alloca = alloca_or_none
- else:
- get_global_type = transform.AnyOpType.get()
- global_type = transform.AnyOpType.get()
- alloca = get_global_type_or_alloca
-
- super().__init__(
- get_global_type,
- global_type,
- alloca,
- loc=loc,
- ip=ip,
- )
-
-
-class MemRefMultiBufferOp:
- """Specialization for MemRefMultiBufferOp class."""
-
- @overload
- def __init__(
- self,
- transformed_type: Type,
- target: Union[Operation, OpView, Value],
- factor: Union[int, IntegerAttr],
- *,
- skip_analysis: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- factor: Union[int, IntegerAttr],
- *,
- skip_analysis: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- transformed_type_or_target: Type,
- target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None,
- factor_or_none: Optional[Union[int, IntegerAttr]] = None,
- *,
- skip_analysis: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- if isinstance(transformed_type_or_target, Type):
- transformed_type = transformed_type_or_target
- target = target_or_factor
- factor = factor_or_none
- else:
- transformed_type = transform.AnyOpType.get()
- target = transformed_type_or_target
- factor = target_or_factor
-
- super().__init__(
- transformed_type,
- target,
- factor,
- skip_analysis=skip_analysis,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
deleted file mode 100644
index c84d23c16ef93ab..000000000000000
--- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from typing import Union
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from ._ml_program_ops_gen import *
-
-
-ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
-RESULT_ATTRIBUTE_NAME = "res_attrs"
-
-
-class FuncOp:
- """Specialization for the func op class."""
-
- def __init__(
- self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
- ):
- """
- Create a FuncOp with the provided `name`, `type`, and `visibility`.
- - `name` is a string representing the function name.
- - `type` is either a FunctionType or a pair of list describing inputs and
- results.
- - `visibility` is a string matching `public`, `private`, or `nested`. None
- implies private visibility.
- - `body_builder` is an optional callback, when provided a new entry block
- is created and the callback is invoked with the new op as argument within
- an InsertionPoint context already set for the block. The callback is
- expected to insert a terminator in the block.
- """
- sym_name = StringAttr.get(str(name))
-
- # If the type is passed as a tuple, build a FunctionType on the fly.
- if isinstance(type, tuple):
- type = FunctionType.get(inputs=type[0], results=type[1])
-
- type = TypeAttr.get(type)
- sym_visibility = (
- StringAttr.get(str(visibility)) if visibility is not None else None
- )
- super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
- if body_builder:
- entry_block = self.add_entry_block()
- with InsertionPoint(entry_block):
- body_builder(self)
-
- @property
- def is_external(self):
- return len(self.regions[0].blocks) == 0
-
- @property
- def body(self):
- return self.regions[0]
-
- @property
- def type(self):
- return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
- @property
- def visibility(self):
- return self.attributes["sym_visibility"]
-
- @property
- def name(self) -> StringAttr:
- return StringAttr(self.attributes["sym_name"])
-
- @property
- def entry_block(self):
- if self.is_external:
- raise IndexError("External function does not have a body")
- return self.regions[0].blocks[0]
-
- def add_entry_block(self):
- """
- Add an entry block to the function body using the function signature to
- infer block arguments.
- Returns the newly created block
- """
- if not self.is_external:
- raise IndexError("The function already has an entry block!")
- self.body.blocks.append(*self.type.inputs)
- return self.body.blocks[0]
-
- @property
- def arg_attrs(self):
- return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
- @arg_attrs.setter
- def arg_attrs(self, attribute: Union[ArrayAttr, list]):
- if isinstance(attribute, ArrayAttr):
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
- else:
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
- attribute, context=self.context
- )
-
- @property
- def arguments(self):
- return self.entry_block.arguments
-
- @property
- def result_attrs(self):
- return self.attributes[RESULT_ATTRIBUTE_NAME]
-
- @result_attrs.setter
- def result_attrs(self, attribute: ArrayAttr):
- self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 895c3228139b392..9cca7d659ec8cb3 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -9,7 +9,6 @@
__all__ = [
"equally_sized_accessor",
- "extend_opview_class",
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
@@ -18,64 +17,6 @@
]
-def extend_opview_class(ext_module):
- """Decorator to extend an OpView class from an extension module.
-
- Extension modules can expose various entry-points:
- Stand-alone class with the same name as a parent OpView class (i.e.
- "ReturnOp"). A name-based match is attempted first before falling back
- to a below mechanism.
-
- def select_opview_mixin(parent_opview_cls):
- If defined, allows an appropriate mixin class to be selected dynamically
- based on the parent OpView class. Should return NotImplemented if a
- decision is not made.
-
- Args:
- ext_module: A module from which to locate extensions. Can be None if not
- available.
-
- Returns:
- A decorator that takes an OpView subclass and further extends it as
- needed.
- """
-
- def class_decorator(parent_opview_cls: type):
- if ext_module is None:
- return parent_opview_cls
- mixin_cls = NotImplemented
- # First try to resolve by name.
- try:
- mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
- except AttributeError:
- # Fall back to a select_opview_mixin hook.
- try:
- select_mixin = getattr(ext_module, "select_opview_mixin")
- except AttributeError:
- pass
- else:
- mixin_cls = select_mixin(parent_opview_cls)
-
- if mixin_cls is NotImplemented or mixin_cls is None:
- return parent_opview_cls
-
- # Have a mixin_cls. Create an appropriate subclass.
- try:
-
- class LocalOpView(mixin_cls, parent_opview_cls):
- pass
-
- except TypeError as e:
- raise TypeError(
- f"Could not mixin {mixin_cls} into {parent_opview_cls}"
- ) from e
- LocalOpView.__name__ = parent_opview_cls.__name__
- LocalOpView.__qualname__ = parent_opview_cls.__qualname__
- return LocalOpView
-
- return class_decorator
-
-
def segmented_accessor(elements, raw_segments, idx):
"""
Returns a slice of elements corresponding to the idx-th segment.
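For reference, the removed decorator consumed extension modules shaped roughly
like the sketch below; the module and op names here are illustrative only, not
from the tree. The replacement pattern, visible in the dialect files further
down, is to subclass the generated `OpView` and re-register it with
`register_operation(_Dialect, replace=True)`.

    # Old style: a stand-alone class whose name matches the generated OpView.
    class ReturnOp:
        def __init__(self, operands, *, loc=None, ip=None):
            super().__init__(operands, loc=loc, ip=ip)

    # Optional fallback hook, consulted when no name-based match succeeds.
    def select_opview_mixin(parent_opview_cls):
        return NotImplemented
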
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
deleted file mode 100644
index fc9de0b7f7db69c..000000000000000
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import pdl
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Union, Optional, Sequence, Mapping
-from ._ods_common import (
- get_op_result_or_value as _get_value,
- get_op_results_or_values as _get_values,
-)
-
-
-class ApplyNativeConstraintOp:
- """Specialization for PDL apply native constraint op class."""
-
- def __init__(
- self,
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(name, args, loc=loc, ip=ip)
-
-
-class ApplyNativeRewriteOp:
- """Specialization for PDL apply native rewrite op class."""
-
- def __init__(
- self,
- results: Sequence[Type],
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(results, name, args, loc=loc, ip=ip)
-
-
-class AttributeOp:
- """Specialization for PDL attribute op class."""
-
- def __init__(
- self,
- valueType: Optional[Union[OpView, Operation, Value]] = None,
- value: Optional[Attribute] = None,
- *,
- loc=None,
- ip=None,
- ):
- valueType = valueType if valueType is None else _get_value(valueType)
- result = pdl.AttributeType.get()
- super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
-
-
-class EraseOp:
- """Specialization for PDL erase op class."""
-
- def __init__(
- self,
- operation: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- operation = _get_value(operation)
- super().__init__(operation, loc=loc, ip=ip)
-
-
-class OperandOp:
- """Specialization for PDL operand op class."""
-
- def __init__(
- self,
- type: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- type = type if type is None else _get_value(type)
- result = pdl.ValueType.get()
- super().__init__(result, valueType=type, loc=loc, ip=ip)
-
-
-class OperandsOp:
- """Specialization for PDL operands op class."""
-
- def __init__(
- self,
- types: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- types = types if types is None else _get_value(types)
- result = pdl.RangeType.get(pdl.ValueType.get())
- super().__init__(result, valueType=types, loc=loc, ip=ip)
-
-
-class OperationOp:
- """Specialization for PDL operand op class."""
-
- def __init__(
- self,
- name: Optional[Union[str, StringAttr]] = None,
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
- types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if types is None:
- types = []
- if attributes is None:
- attributes = {}
- if args is None:
- args = []
- args = _get_values(args)
- attrNames = []
- attrValues = []
- for attrName, attrValue in attributes.items():
- attrNames.append(StringAttr.get(attrName))
- attrValues.append(_get_value(attrValue))
- attrNames = ArrayAttr.get(attrNames)
- types = _get_values(types)
- result = pdl.OperationType.get()
- super().__init__(
- result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
- )
-
-
-class PatternOp:
- """Specialization for PDL pattern op class."""
-
- def __init__(
- self,
- benefit: Union[IntegerAttr, int],
- name: Optional[Union[StringAttr, str]] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates an PDL `pattern` operation."""
- super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
- self.regions[0].blocks.append()
-
- @property
- def body(self):
- """Return the body (block) of the pattern."""
- return self.regions[0].blocks[0]
-
-
-class ReplaceOp:
- """Specialization for PDL replace op class."""
-
- def __init__(
- self,
- op: Union[OpView, Operation, Value],
- *,
- with_op: Optional[Union[OpView, Operation, Value]] = None,
- with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- loc=None,
- ip=None,
- ):
- if with_values is None:
- with_values = []
- op = _get_value(op)
- with_op = with_op if with_op is None else _get_value(with_op)
- with_values = _get_values(with_values)
- super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
-
-
-class ResultOp:
- """Specialization for PDL result op class."""
-
- def __init__(
- self,
- parent: Union[OpView, Operation, Value],
- index: Union[IntegerAttr, int],
- *,
- loc=None,
- ip=None,
- ):
- parent = _get_value(parent)
- result = pdl.ValueType.get()
- super().__init__(result, parent, index, loc=loc, ip=ip)
-
-
-class ResultsOp:
- """Specialization for PDL results op class."""
-
- def __init__(
- self,
- result: Type,
- parent: Union[OpView, Operation, Value],
- index: Optional[Union[IntegerAttr, int]] = None,
- *,
- loc=None,
- ip=None,
- ):
- parent = _get_value(parent)
- super().__init__(result, parent, index=index, loc=loc, ip=ip)
-
-
-class RewriteOp:
- """Specialization for PDL rewrite op class."""
-
- def __init__(
- self,
- root: Optional[Union[OpView, Operation, Value]] = None,
- name: Optional[Union[StringAttr, str]] = None,
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- root = root if root is None else _get_value(root)
- args = _get_values(args)
- super().__init__(args, root=root, name=name, loc=loc, ip=ip)
-
- def add_body(self):
- """Add body (block) to the rewrite."""
- self.regions[0].blocks.append()
- return self.body
-
- @property
- def body(self):
- """Return the body (block) of the rewrite."""
- return self.regions[0].blocks[0]
-
-
-class TypeOp:
- """Specialization for PDL type op class."""
-
- def __init__(
- self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
- ):
- result = pdl.TypeType.get()
- super().__init__(result, constantType=constantType, loc=loc, ip=ip)
-
-
-class TypesOp:
- """Specialization for PDL types op class."""
-
- def __init__(
- self,
- constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if constantTypes is None:
- constantTypes = []
- result = pdl.RangeType.get(pdl.TypeType.get())
- super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
deleted file mode 100644
index 89cc8a19895c7b4..000000000000000
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
-)
-
-
-class ForOp:
- """Specialization for the SCF for op class."""
-
- def __init__(
- self,
- lower_bound,
- upper_bound,
- step,
- iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates an SCF `for` operation.
-
- - `lower_bound` is the value to use as lower bound of the loop.
- - `upper_bound` is the value to use as upper bound of the loop.
- - `step` is the value to use as loop step.
- - `iter_args` is a list of additional loop-carried arguments or an operation
- producing them as results.
- """
- if iter_args is None:
- iter_args = []
- iter_args = _get_op_results_or_values(iter_args)
-
- results = [arg.type for arg in iter_args]
- super().__init__(
- self.build_generic(
- regions=1,
- results=results,
- operands=[
- _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
- ]
- + list(iter_args),
- loc=loc,
- ip=ip,
- )
- )
- self.regions[0].blocks.append(self.operands[0].type, *results)
-
- @property
- def body(self):
- """Returns the body (block) of the loop."""
- return self.regions[0].blocks[0]
-
- @property
- def induction_variable(self):
- """Returns the induction variable of the loop."""
- return self.body.arguments[0]
-
- @property
- def inner_iter_args(self):
- """Returns the loop-carried arguments usable within the loop.
-
- To obtain the loop-carried operands, use `iter_args`.
- """
- return self.body.arguments[1:]
-
-
-class IfOp:
- """Specialization for the SCF if op class."""
-
- def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
- """Creates an SCF `if` operation.
-
- - `cond` is an MLIR value of 'i1' type to determine which regions of code will be executed.
- - `hasElse` determines whether the if operation has the else branch.
- """
- operands = []
- operands.append(cond)
- results = []
- results.extend(results_)
- super().__init__(
- self.build_generic(
- regions=2, results=results, operands=operands, loc=loc, ip=ip
- )
- )
- self.regions[0].blocks.append(*[])
- if hasElse:
- self.regions[1].blocks.append(*[])
-
- @property
- def then_block(self):
- """Returns the then block of the if operation."""
- return self.regions[0].blocks[0]
-
- @property
- def else_block(self):
- """Returns the else block of the if operation."""
- return self.regions[1].blocks[0]
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
deleted file mode 100644
index 3757a3d3b4cce85..000000000000000
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ /dev/null
@@ -1,759 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import List, Optional, Sequence, Tuple, Union, overload
-
-StaticIntLike = Union[int, IntegerAttr]
-ValueLike = Union[Operation, OpView, Value]
-MixedInt = Union[StaticIntLike, ValueLike]
-
-IntOrAttrList = Sequence[Union[IntegerAttr, int]]
-OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
-
-BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
-OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
-
-MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
-
-DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
-
-
-def _dispatch_dynamic_index_list(
- indices: Union[DynamicIndexList, ArrayAttr],
-) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
- """Dispatches a list of indices to the appropriate form.
-
- This is similar to the custom `DynamicIndexList` directive upstream:
- provided indices may be in the form of dynamic SSA values or static values,
- and they may be scalable (i.e., as a singleton list) or not. This function
- dispatches each index into its respective form. It also extracts the SSA
- values and static indices from various similar structures, respectively.
- """
- dynamic_indices = []
- static_indices = [ShapedType.get_dynamic_size()] * len(indices)
- scalable_indices = [False] * len(indices)
-
- # ArrayAttr: Extract index values.
- if isinstance(indices, ArrayAttr):
- indices = [idx for idx in indices]
-
- def process_nonscalable_index(i, index):
- """Processes any form of non-scalable index.
-
- Returns False if the given index was scalable and thus remains
- unprocessed; True otherwise.
- """
- if isinstance(index, int):
- static_indices[i] = index
- elif isinstance(index, IntegerAttr):
- static_indices[i] = index.value # pytype: disable=attribute-error
- elif isinstance(index, (Operation, Value, OpView)):
- dynamic_indices.append(index)
- else:
- return False
- return True
-
- # Process each index at a time.
- for i, index in enumerate(indices):
- if not process_nonscalable_index(i, index):
- # If it wasn't processed, it must be a scalable index, which is
- # provided as a Sequence of one value, so extract and process that.
- scalable_indices[i] = True
- assert len(index) == 1
- ret = process_nonscalable_index(i, index[0])
- assert ret
-
- return dynamic_indices, static_indices, scalable_indices
-
-
-# Dispatches `MixedValues` that all represent integers in various forms into
-# the following three categories:
-# - `dynamic_values`: a list of `Value`s, potentially from op results;
-# - `packed_values`: a value handle, potentially from an op result, associated
-# to one or more payload operations of integer type;
-# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
-# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
-# If the input is in the `packed_values` form, only that result is set and the
-# other two are empty. Otherwise, the input can be a mix of the other two forms,
-# and for each dynamic value, a special value is added to the `static_values`.
-def _dispatch_mixed_values(
- values: MixedValues,
-) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
- dynamic_values = []
- packed_values = None
- static_values = None
- if isinstance(values, ArrayAttr):
- static_values = values
- elif isinstance(values, (Operation, Value, OpView)):
- packed_values = values
- else:
- static_values = []
- for size in values or []:
- if isinstance(size, int):
- static_values.append(size)
- else:
- static_values.append(ShapedType.get_dynamic_size())
- dynamic_values.append(size)
- static_values = DenseI64ArrayAttr.get(static_values)
-
- return (dynamic_values, packed_values, static_values)
-
-
-def _get_value_or_attribute_value(
- value_or_attr: Union[any, Attribute, ArrayAttr]
-) -> any:
- if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
- return value_or_attr.value
- if isinstance(value_or_attr, ArrayAttr):
- return _get_value_list(value_or_attr)
- return value_or_attr
-
-
-def _get_value_list(
- sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
-) -> Sequence[any]:
- return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
-
-
-def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
- if values is None:
- return None
-
- # Turn into a Python list of Python ints.
- values = _get_value_list(values)
-
- # Make an ArrayAttr of IntegerAttrs out of it.
- return ArrayAttr.get(
- [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
- )
-
-
-def _get_int_array_array_attr(
- values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
-) -> ArrayAttr:
- """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
-
- The input has to be a collection of collection of integers, where any
- Python Sequence and ArrayAttr are admissible collections and Python ints and
- any IntegerAttr are admissible integers. Both levels of collections are
- turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
- If the input is None, an empty ArrayAttr is returned.
- """
- if values is None:
- return None
-
- # Make sure the outer level is a list.
- values = _get_value_list(values)
-
- # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
- # Sequences. Make sure the nested values are all lists.
- values = [_get_value_list(nested) for nested in values]
-
- # Turn each nested list into an ArrayAttr.
- values = [_get_int_array_attr(nested) for nested in values]
-
- # Turn the outer list into an ArrayAttr.
- return ArrayAttr.get(values)
-
-
-class BufferizeToAllocationOp:
- """Specialization for BufferizeToAllocationOp class."""
-
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- memory_space: Optional[Union[int, str, Attribute]] = None,
- memcpy_op: Optional[str] = None,
- alloc_op: Optional[str] = None,
- bufferize_destination_only: Optional[bool] = None,
- loc=None,
- ip=None,
- ):
- # No other types are allowed, so hard-code those here.
- allocated_buffer_type = transform.AnyValueType.get()
- new_ops_type = transform.AnyOpType.get()
-
- if isinstance(memory_space, int):
- memory_space = str(memory_space)
- if isinstance(memory_space, str):
- memory_space = Attribute.parse(memory_space)
-
- super().__init__(
- allocated_buffer_type,
- new_ops_type,
- target,
- memory_space=memory_space,
- memcpy_op=memcpy_op,
- alloc_op=alloc_op,
- bufferize_destination_only=bufferize_destination_only,
- loc=loc,
- ip=ip,
- )
-
-
-class DecomposeOp:
- """Specialization for DecomposeOp class."""
-
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- transformed_type = transform.AnyOpType.get()
- super().__init__(transformed_type, target, loc=loc, ip=ip)
-
-
-class FuseIntoContainingOp:
- """Specialization for FuseIntoContainingOp class."""
-
- @overload
- def __init__(
- self,
- fused_op_type: Type,
- new_containing_op_type: Type,
- producer_op: Union[Operation, OpView, Value],
- containing_op: Union[Operation, OpView, Value],
- *,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- producer_op: Union[Operation, OpView, Value],
- containing_op: Union[Operation, OpView, Value],
- *,
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
- new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
- producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
- containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if isinstance(fused_op_type_or_producer_op, Type):
- if not isinstance(new_containing_op_type_or_containing_op, Type):
- raise TypeError(
- "If 'fused_op_type_or_producer_op' is a type, then "
- "'new_containing_op_type_or_containing_op' is expected "
- "to be one as well."
- )
- fused_op_type = fused_op_type_or_producer_op
- new_containing_op_type = new_containing_op_type_or_containing_op
- producer_op = producer_op_or_none
- containing_op = containing_op_or_none
- else:
- fused_op_type = transform.AnyOpType.get()
- new_containing_op_type = transform.AnyOpType.get()
- producer_op = fused_op_type_or_producer_op
- containing_op = new_containing_op_type_or_containing_op
-
- super().__init__(
- fused_op_type,
- new_containing_op_type,
- producer_op,
- containing_op,
- loc=loc,
- ip=ip,
- )
-
-
-class GeneralizeOp:
- """Specialization for GeneralizeOp class."""
-
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- transformed_type = transform.AnyOpType.get()
- super().__init__(transformed_type, target, loc=loc, ip=ip)
-
-
-class InterchangeOp:
- """Specialization for InterchangeOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- iterator_interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- transformed_type = transform.AnyOpType.get()
- super().__init__(
- transformed_type,
- target,
- iterator_interchange=iterator_interchange,
- loc=loc,
- ip=ip,
- )
-
-
-class MapCopyToThreadsOp:
- """Specialization for MapCopyToThreadsOp class."""
-
- @overload
- def __init__(
- self,
- forall_op_type: Type,
- tiled_op_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- total_num_threads: Union[int, IntegerAttr],
- desired_bit_alignment: Union[int, IntegerAttr],
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- total_num_threads: Union[int, IntegerAttr],
- desired_bit_alignment: Union[int, IntegerAttr],
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- forall_op_type_or_target: Union[Operation, OpView, Type, Value],
- tiled_op_type_or_none: Optional[Type] = None,
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- total_num_threads: Union[int, IntegerAttr],
- desired_bit_alignment: Union[int, IntegerAttr],
- loc=None,
- ip=None,
- ):
- if isinstance(forall_op_type_or_target, Type):
- forall_op_type = forall_op_type_or_target
- tiled_op_type = tiled_op_type_or_none
- target = target_or_none
- else:
- forall_op_type = transform.AnyOpType.get()
- tiled_op_type = transform.AnyOpType.get()
- target = forall_op_type_or_target
-
- super().__init__(
- forall_op_type,
- tiled_op_type,
- target,
- total_num_threads=total_num_threads,
- desired_bit_alignment=desired_bit_alignment,
- loc=loc,
- ip=ip,
- )
-
-
-class VectorizeOp:
- """Specialization for VectorizeOp class."""
-
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- *,
- vectorize_nd_extract: Optional[bool] = None,
- scalable_sizes: OptionalBoolList = None,
- static_vector_sizes: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- if (
- scalable_sizes is None
- and static_vector_sizes is None
- and vector_sizes is None
- ):
- dynamic_vector_sizes = []
- elif scalable_sizes is None and static_vector_sizes is None:
- (
- dynamic_vector_sizes,
- static_vector_sizes,
- scalable_sizes,
- ) = _dispatch_dynamic_index_list(vector_sizes)
- elif scalable_sizes is None or static_vector_sizes is None:
- raise TypeError(
- "'scalable_sizes' and 'static_vector_sizes' must either both "
- "be given explicitly or both be given as part of 'vector_sizes'."
- )
- else:
- dynamic_vector_sizes = vector_sizes
-
- super().__init__(
- target,
- vector_sizes=dynamic_vector_sizes,
- static_vector_sizes=static_vector_sizes,
- scalable_sizes=scalable_sizes,
- vectorize_nd_extract=vectorize_nd_extract,
- loc=loc,
- ip=ip,
- )
-
-
-class MatchOp:
- """Specialization for MatchOp class."""
-
- @overload
- @classmethod
- def match_op_names(
- cls,
- target: Union[Operation, Value],
- names: Union[str, Sequence[str]],
- *,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- @classmethod
- def match_op_names(
- cls,
- result_type: Type,
- target: Union[Operation, Value],
- names: Union[str, Sequence[str]],
- *,
- loc=None,
- ip=None,
- ):
- ...
-
- @classmethod
- def match_op_names(
- cls,
- result_type_or_target: Union[Type, Operation, Value],
- target_or_names: Union[Operation, Value, Sequence[str], str],
- names_or_none: Optional[Union[Sequence[str], str]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if isinstance(result_type_or_target, Type):
- result_type = result_type_or_target
- target = target_or_names
- names = names_or_none
- else:
- result_type = transform.AnyOpType.get()
- target = result_type_or_target
- names = target_or_names
-
- if isinstance(names, str):
- names = [names]
-
- return cls(
- result_type,
- target,
- ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
- loc=loc,
- ip=ip,
- )
-
-
-class MultiTileSizesOp:
- """Specialization for MultiTileSizesOp class."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- dimension: Union[int, IntegerAttr],
- target_size: Union[int, IntegerAttr],
- divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
- loc=None,
- ip=None,
- ):
- super().__init__(
- result_type,
- result_type,
- result_type,
- target,
- dimension=dimension,
- target_size=target_size,
- divisor=divisor,
- loc=loc,
- ip=ip,
- )
-
-
-class PadOp:
- """Specialization for PadOp class."""
-
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
- padding_dimensions: OptionalIntList = None,
- pad_to_multiple_of: OptionalIntList = None,
- pack_paddings: OptionalIntList = None,
- transpose_paddings: Optional[
- Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
- ] = None,
- copy_back_op: Optional[Union[str, StringAttr]] = None,
- loc=None,
- ip=None,
- ):
- transpose_paddings = _get_int_array_array_attr(transpose_paddings)
-
- any_op_type = transform.AnyOpType.get()
- super().__init__(
- any_op_type,
- any_op_type,
- any_op_type,
- target,
- padding_values=padding_values,
- padding_dimensions=padding_dimensions,
- pad_to_multiple_of=pad_to_multiple_of,
- pack_paddings=pack_paddings,
- transpose_paddings=transpose_paddings,
- copy_back_op=copy_back_op,
- loc=loc,
- ip=ip,
- )
-
-
-class ScalarizeOp:
- """Specialization for ScalarizeOp class."""
-
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- result_type = transform.AnyOpType.get()
- super().__init__(result_type, target, loc=loc, ip=ip)
-
-
-class SplitOp:
- """Specialization for SplitOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- dimension: Union[int, Attribute],
- split_point: Union[int, Operation, Value, Attribute],
- *,
- loc=None,
- ip=None,
- ):
- if isinstance(split_point, int):
- static_split_point = split_point
- dynamic_split_point = None
- else:
- static_split_point = ShapedType.get_dynamic_size()
- dynamic_split_point = split_point
-
- super().__init__(
- target.type,
- target.type,
- target,
- dimension=dimension,
- static_split_point=static_split_point,
- dynamic_split_point=dynamic_split_point,
- loc=loc,
- ip=ip,
- )
-
-
-class TileUsingForOp:
- """Specialization for TileUsingForOp class."""
-
- @overload
- def __init__(
- self,
- loop_types: Union[Type, List[Type]],
- target: Union[Operation, Value],
- *,
- sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, Value, OpView],
- *,
- sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- loop_types_or_target: Union[Type, List[Type], Operation, Value],
- target_or_none: Optional[Union[Operation, Value, OpView]] = None,
- *,
- sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- (
- dynamic_sizes,
- static_sizes,
- scalable_sizes,
- ) = _dispatch_dynamic_index_list(sizes)
-
- num_loops = sum(v if v == 0 else 1 for v in static_sizes)
-
- if isinstance(loop_types_or_target, (Operation, Value, OpView)):
- loop_types = [transform.AnyOpType.get()] * num_loops
- target = loop_types_or_target
- assert (
- target_or_none is None
- ), "Cannot construct TileUsingForOp with two targets."
- else:
- loop_types = (
- ([loop_types_or_target] * num_loops)
- if isinstance(loop_types_or_target, Type)
- else loop_types_or_target
- )
- target = target_or_none
-
- super().__init__(
- target.type,
- loop_types,
- target,
- dynamic_sizes=dynamic_sizes,
- static_sizes=static_sizes,
- interchange=interchange,
- scalable_sizes=scalable_sizes,
- loc=loc,
- ip=ip,
- )
-
-
-class TileUsingForallOp:
- """Specialization for TileUsingForallOp class."""
-
- @overload
- def __init__(
- self,
- loops_type: Type,
- tiled_op_type: Type,
- target: Union[Operation, Value, OpView],
- *,
- num_threads: Optional[MixedValues] = None,
- tile_sizes: MixedValues = None,
- mapping=None,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, Value, OpView],
- *,
- num_threads: Optional[MixedValues] = None,
- tile_sizes: MixedValues = None,
- mapping=None,
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- loops_type_or_target: Union[
- Type, Union[Operation, Value, OpView] # loops_type
- ], # target
- tiled_op_type_or_none: Optional[Type] = None,
- target_or_none: Optional[Union[Operation, Value, OpView]] = None,
- *,
- num_threads: MixedValues = None,
- tile_sizes: MixedValues = None,
- mapping=None,
- loc=None,
- ip=None,
- ):
- # `Type` arguments in the front are optional: add default values to front.
- if isinstance(loops_type_or_target, Type):
- # First overload: type arguments provided.
- if not isinstance(tiled_op_type_or_none, Type):
- raise TypeError(
- "If 'loops_type_or_target' is a type, then "
- "'tiled_op_type_or_none' is expected to be one as well."
- )
- loops_type = loops_type_or_target
- tiled_op_type = tiled_op_type_or_none
- target = target_or_none
- else:
- # Last overload: type arguments missing.
- loops_type = transform.AnyOpType.get()
- tiled_op_type = transform.AnyOpType.get()
- target = loops_type_or_target
-
- # Unpack mixed num_threads.
- (
- dynamic_num_threads,
- packed_num_threads,
- num_threads_attr,
- ) = _dispatch_mixed_values(num_threads)
-
- # Unpack mixed tile_sizes.
- (
- dynamic_tile_sizes,
- packed_tile_sizes,
- tile_sizes_attr,
- ) = _dispatch_mixed_values(tile_sizes)
-
- super().__init__(
- loops_type,
- tiled_op_type,
- target=target,
- tile_sizes=dynamic_tile_sizes,
- packed_tile_sizes=packed_tile_sizes,
- static_tile_sizes=tile_sizes_attr,
- num_threads=dynamic_num_threads,
- packed_num_threads=packed_num_threads,
- static_num_threads=num_threads_attr,
- mapping=mapping,
- loc=loc,
- ip=ip,
- )
-
-
-class VectorizeChildrenAndApplyPatternsOp:
- """Specialization for VectorizeChildrenAndApplyPatternsOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- disable_multi_reduction_to_contract_patterns: bool = False,
- disable_transfer_permutation_map_lowering_patterns: bool = False,
- vectorize_nd_extract: bool = False,
- vectorize_padding: bool = False,
- loc=None,
- ip=None,
- ):
- transformed_type = transform.AnyOpType.get()
- super().__init__(
- transformed_type,
- target,
- disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
- disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
- vectorize_nd_extract=vectorize_nd_extract,
- vectorize_padding=vectorize_padding,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py
deleted file mode 100644
index 09b9ec68db7d9c7..000000000000000
--- a/mlir/python/mlir/dialects/_tensor_ops_ext.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Any, Optional, Sequence, Union
-from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
-)
-
-
-class EmptyOp:
- """Extends the tensor.empty op."""
-
- def __init__(
- self,
- sizes: Sequence[Union[int, Value]],
- element_type: Type,
- *,
- loc=None,
- ip=None
- ):
- """Constructs an `empty` with mixed static/dynamic sizes."""
- # TODO: Refactor the EmptyOp to take an element type attribute and
- # then use normal result type inference, unifying the Python and C++ side
- # with a standard mechanism (versus stashing that in builders).
- dynamic_sizes = []
- static_sizes = []
- for s in sizes:
- if isinstance(s, int):
- static_sizes.append(s)
- else:
- static_sizes.append(ShapedType.get_dynamic_size())
- dynamic_sizes.append(s)
- result_type = RankedTensorType.get(static_sizes, element_type)
- op = self.build_generic(
- results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
- )
- OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py b/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
deleted file mode 100644
index 996093fbc913e8a..000000000000000
--- a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, overload, Union
-
-
-class MakeLoopIndependentOp:
- """Specialization for MakeLoopIndependentOp class."""
-
- @overload
- def __init__(
- self,
- transformed_type: Type,
- target: Union[Operation, OpView, Value],
- num_loops: Union[int, IntegerAttr],
- *,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- num_loops: Union[int, IntegerAttr],
- *,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- transformed_type_or_target: Type,
- target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None,
- num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
- *,
- loc=None,
- ip=None
- ):
- if isinstance(transformed_type_or_target, Type):
- transformed_type = transformed_type_or_target
- target = target_or_num_loops
- num_loops = num_loops_or_none
- else:
- transformed_type = transform.AnyOpType.get()
- target = transformed_type_or_target
- num_loops = target_or_num_loops
-
- super().__init__(
- transformed_type,
- target,
- num_loops,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
deleted file mode 100644
index b1e7b892536f4a1..000000000000000
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- )
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-class CastOp:
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- loc=None,
- ip=None,
- ):
- super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
-
-
-class ApplyPatternsOp:
- def __init__(
- self,
- target: Union[Operation, Value, OpView],
- *,
- loc=None,
- ip=None,
- ):
- operands = []
- operands.append(_get_op_result_or_value(target))
- super().__init__(
- self.build_generic(
- attributes={},
- results=[],
- operands=operands,
- successors=None,
- regions=None,
- loc=loc,
- ip=ip,
- )
- )
- self.regions[0].blocks.append()
-
- @property
- def patterns(self) -> Block:
- return self.regions[0].blocks[0]
-
-
-class GetParentOp:
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- isolated_from_above: bool = False,
- op_name: Optional[str] = None,
- deduplicate: bool = False,
- loc=None,
- ip=None,
- ):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- isolated_from_above=isolated_from_above,
- op_name=op_name,
- deduplicate=deduplicate,
- loc=loc,
- ip=ip,
- )
-
-
-class MergeHandlesOp:
- def __init__(
- self,
- handles: Sequence[Union[Operation, Value]],
- *,
- deduplicate: bool = False,
- loc=None,
- ip=None,
- ):
- super().__init__(
- [_get_op_result_or_value(h) for h in handles],
- deduplicate=deduplicate,
- loc=loc,
- ip=ip,
- )
-
-
-class ReplicateOp:
- def __init__(
- self,
- pattern: Union[Operation, Value],
- handles: Sequence[Union[Operation, Value]],
- *,
- loc=None,
- ip=None,
- ):
- super().__init__(
- [_get_op_result_or_value(h).type for h in handles],
- _get_op_result_or_value(pattern),
- [_get_op_result_or_value(h) for h in handles],
- loc=loc,
- ip=ip,
- )
-
-
-class SequenceOp:
- def __init__(
- self,
- failure_propagation_mode,
- results: Sequence[Type],
- target: Union[Operation, Value, Type],
- extra_bindings: Optional[
- Union[Sequence[Value], Sequence[Type], Operation, OpView]
- ] = None,
- ):
- root = (
- _get_op_result_or_value(target)
- if isinstance(target, (Operation, Value))
- else None
- )
- root_type = root.type if not isinstance(target, Type) else target
-
- if extra_bindings is None:
- extra_bindings = []
- if isinstance(extra_bindings, (Operation, OpView)):
- extra_bindings = _get_op_results_or_values(extra_bindings)
-
- extra_binding_types = []
- if len(extra_bindings) != 0:
- if isinstance(extra_bindings[0], Type):
- extra_binding_types = extra_bindings
- extra_bindings = []
- else:
- extra_binding_types = [v.type for v in extra_bindings]
-
- super().__init__(
- results_=results,
- failure_propagation_mode=failure_propagation_mode,
- root=root,
- extra_bindings=extra_bindings,
- )
- self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTarget(self) -> Value:
- return self.body.arguments[0]
-
- @property
- def bodyExtraArgs(self) -> BlockArgumentList:
- return self.body.arguments[1:]
-
-
-class YieldOp:
- def __init__(
- self,
- operands: Optional[Union[Operation, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if operands is None:
- operands = []
- super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
deleted file mode 100644
index c4e4b4b4254b038..000000000000000
--- a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- )
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Union
-
-class PDLMatchOp:
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- pattern_name: Union[Attribute, str],
- *,
- loc=None,
- ip=None,
- ):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- pattern_name,
- loc=loc,
- ip=ip,
- )
-
-
-class WithPDLPatternsOp:
-
- def __init__(self,
- target: Union[Operation, Value, Type],
- *,
- loc=None,
- ip=None):
- root = _get_op_result_or_value(target) if not isinstance(target,
- Type) else None
- root_type = target if isinstance(target, Type) else root.type
- super().__init__(root=root, loc=loc, ip=ip)
- self.regions[0].blocks.append(root_type)
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTarget(self) -> Value:
- return self.body.arguments[0]
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 8a2a64c7c40d190..1eaccfa73a85cbf 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -1,5 +1,50 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._affine_ops_gen import *
+from ._affine_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AffineStoreOp(AffineStoreOp):
+ """Specialization for the Affine store operation."""
+
+ def __init__(
+ self,
+ value: Union[Operation, OpView, Value],
+ memref: Union[Operation, OpView, Value],
+ map: AffineMap = None,
+ *,
+ map_operands=None,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an affine store operation.
+
+ - `value`: the value to store into the memref.
+ - `memref`: the buffer to store into.
+ - `map`: the affine map that maps the map_operands to the index of the
+ memref.
+ - `map_operands`: the list of arguments to substitute the dimensions,
+ then symbols in the affine map, in increasing order.
+ """
+ map = map if map is not None else []
+ map_operands = map_operands if map_operands is not None else []
+ indices = [_get_op_result_or_value(op) for op in map_operands]
+ _ods_successors = None
+ super().__init__(
+ value, memref, indices, AffineMapAttr.get(map), loc=loc, ip=ip
+ )
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index fb13beb63ca66c3..83aca0d58bf2cef 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -3,4 +3,75 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._arith_ops_gen import *
+from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
+
+ from typing import Any, List, Union
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+def _isa(obj: Any, cls: type):
+ try:
+ cls(obj)
+ except ValueError:
+ return False
+ return True
+
+
+def _is_any_of(obj: Any, classes: List[type]):
+ return any(_isa(obj, cls) for cls in classes)
+
+
+def _is_integer_like_type(type: Type):
+ return _is_any_of(type, [IntegerType, IndexType])
+
+
+def _is_float_type(type: Type):
+ return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
+ """Specialization for the constant op class."""
+
+ def __init__(
+ self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+ ):
+ if isinstance(value, int):
+ super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+ elif isinstance(value, float):
+ super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+ else:
+ super().__init__(value, loc=loc, ip=ip)
+
+ @classmethod
+ def create_index(cls, value: int, *, loc=None, ip=None):
+ """Create an index-typed constant."""
+ return cls(
+ IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
+ )
+
+ @property
+ def type(self):
+ return self.results[0].type
+
+ @property
+ def value(self):
+ return Attribute(self.operation.attributes["value"])
+
+ @property
+ def literal_value(self) -> Union[int, float]:
+ if _is_integer_like_type(self.type):
+ return IntegerAttr(self.value).value
+ elif _is_float_type(self.type):
+ return FloatAttr(self.value).value
+ else:
+ raise ValueError("only integer and float constants have literal values")
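A small sketch of the specialized constant builder in use (assuming the
in-tree Python bindings):

    from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
    from mlir.dialects import arith

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # Python ints route to IntegerAttr, floats to FloatAttr.
            i0 = arith.ConstantOp.create_index(0)
            cf = arith.ConstantOp(F32Type.get(), 42.0)
            assert cf.literal_value == 42.0
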
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 759b6aa24a9ff73..0ce5448ace4b14c 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -3,4 +3,40 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._bufferization_ops_gen import *
+from ._bufferization_ops_gen import _Dialect
from ._bufferization_enum_gen import *
+
+try:
+ from typing import Sequence, Union
+ from ..ir import *
+ from ._ods_common import get_default_loc_context, _cext as _ods_cext
+
+ from typing import Any, List, Union
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AllocTensorOp(AllocTensorOp):
+ """Extends the bufferization.alloc_tensor op."""
+
+ def __init__(
+ self,
+ tensor_type: Type,
+ dynamic_sizes: Sequence[Value],
+ copy: Value,
+ size_hint: Value,
+ escape: BoolAttr,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
+ super().__init__(
+ tensor_type,
+ dynamic_sizes,
+ copy=copy,
+ size_hint=size_hint,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
index 30279e1611f99aa..b71cc2466d464b3 100644
--- a/mlir/python/mlir/dialects/builtin.py
+++ b/mlir/python/mlir/dialects/builtin.py
@@ -3,3 +3,23 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._builtin_ops_gen import *
+from ._builtin_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ModuleOp(ModuleOp):
+ """Specialization for the module op class."""
+
+ def __init__(self, *, loc=None, ip=None):
+ super().__init__(loc=loc, ip=ip)
+ body = self.regions[0].blocks.append()
+
+ @property
+ def body(self):
+ return self.regions[0].blocks[0]
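A minimal sketch: the specialized `ModuleOp` appends its body block on
construction, so it can serve directly as an insertion target:

    from mlir.ir import Context, Location, InsertionPoint
    from mlir.dialects import builtin, func

    with Context(), Location.unknown():
        m = builtin.ModuleOp()
        with InsertionPoint(m.body):
            # A function declaration with no inputs or results.
            func.FuncOp("f", ([], []))
        print(m)
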
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index dc554c22173bc60..9c6c4c9092c7a88 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -3,3 +3,326 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._func_ops_gen import *
+from ._func_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
+
+ import inspect
+
+ from typing import Any, List, Optional, Sequence, Union
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
+ """Specialization for the constant op class."""
+
+ def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
+ super().__init__(result, value, loc=loc, ip=ip)
+
+ @property
+ def type(self):
+ return self.results[0].type
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuncOp(FuncOp):
+ """Specialization for the func op class."""
+
+ def __init__(
+ self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+ ):
+ """
+ Create a FuncOp with the provided `name`, `type`, and `visibility`.
+ - `name` is a string representing the function name.
+ - `type` is either a FunctionType or a pair of lists describing inputs and
+ results.
+ - `visibility` is a string matching `public`, `private`, or `nested`. None
+ implies private visibility.
+ - `body_builder` is an optional callback; when provided, a new entry block
+ is created and the callback is invoked with the new op as argument within
+ an InsertionPoint context already set for the block. The callback is
+ expected to insert a terminator in the block.
+ """
+ sym_name = StringAttr.get(str(name))
+
+ # If the type is passed as a tuple, build a FunctionType on the fly.
+ if isinstance(type, tuple):
+ type = FunctionType.get(inputs=type[0], results=type[1])
+
+ type = TypeAttr.get(type)
+ sym_visibility = (
+ StringAttr.get(str(visibility)) if visibility is not None else None
+ )
+ super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+ if body_builder:
+ entry_block = self.add_entry_block()
+ with InsertionPoint(entry_block):
+ body_builder(self)
+
+ @property
+ def is_external(self):
+ return len(self.regions[0].blocks) == 0
+
+ @property
+ def body(self):
+ return self.regions[0]
+
+ @property
+ def type(self):
+ return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+ @property
+ def visibility(self):
+ return self.attributes["sym_visibility"]
+
+ @property
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes["sym_name"])
+
+ @property
+ def entry_block(self):
+ if self.is_external:
+ raise IndexError("External function does not have a body")
+ return self.regions[0].blocks[0]
+
+ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
+ """
+ Add an entry block to the function body using the function signature to
+ infer block arguments.
+ Returns the newly created block.
+ """
+ if not self.is_external:
+ raise IndexError("The function already has an entry block!")
+ self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
+ return self.body.blocks[0]
+
+ @property
+ def arg_attrs(self):
+ return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+ @arg_attrs.setter
+ def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+ if isinstance(attribute, ArrayAttr):
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+ else:
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+ attribute, context=self.context
+ )
+
+ @property
+ def arguments(self):
+ return self.entry_block.arguments
+
+ @property
+ def result_attrs(self):
+ return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+ @result_attrs.setter
+ def result_attrs(self, attribute: ArrayAttr):
+ self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+
+ @classmethod
+ def from_py_func(
+ FuncOp,
+ *inputs: Type,
+ results: Optional[Sequence[Type]] = None,
+ name: Optional[str] = None,
+ ):
+ """Decorator to define an MLIR FuncOp specified as a python function.
+
+ Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+ active for the current thread (i.e. established in a `with` block).
+
+ When applied as a decorator to a Python function, an entry block will
+ be constructed for the FuncOp with types as specified in `*inputs`. The
+ block arguments will be passed positionally to the Python function. In
+ addition, if the Python function accepts keyword arguments generally or
+ has a corresponding keyword argument, the following will be passed:
+ * `func_op`: The `func` op being defined.
+
+ By default, the function name will be the Python function `__name__`. This
+ can be overridden by passing the `name` argument to the decorator.
+
+ If `results` is not specified, then the decorator will implicitly
+ insert a `ReturnOp` with the `Value`'s returned from the decorated
+ function. It will also set the `FuncOp` type with the actual return
+ value types. If `results` is specified, then the decorated function
+ must return `None` and no implicit `ReturnOp` is added (nor are the result
+ types updated). The implicit behavior is intended for simple, single-block
+ cases, and users should specify result types explicitly for any complicated
+ cases.
+
+ The decorated function can further be called from Python and will insert
+ a `CallOp` at the then-current insertion point, returning either None
+ (if no return values), a single Value (for one result), or a list of Values.
+ This mechanism cannot be used to emit recursive calls (by construction).
+ """
+
+ def decorator(f):
+ from . import func
+
+ # Introspect the callable for optional features.
+ sig = inspect.signature(f)
+ has_arg_func_op = False
+ for param in sig.parameters.values():
+ if param.kind == param.VAR_KEYWORD:
+ has_arg_func_op = True
+ if param.name == "func_op" and (
+ param.kind == param.POSITIONAL_OR_KEYWORD
+ or param.kind == param.KEYWORD_ONLY
+ ):
+ has_arg_func_op = True
+
+ # Emit the FuncOp.
+ implicit_return = results is None
+ symbol_name = name or f.__name__
+ function_type = FunctionType.get(
+ inputs=inputs, results=[] if implicit_return else results
+ )
+ func_op = FuncOp(name=symbol_name, type=function_type)
+ with InsertionPoint(func_op.add_entry_block()):
+ func_args = func_op.entry_block.arguments
+ func_kwargs = {}
+ if has_arg_func_op:
+ func_kwargs["func_op"] = func_op
+ return_values = f(*func_args, **func_kwargs)
+ if not implicit_return:
+ return_types = list(results)
+ assert return_values is None, (
+ "Capturing a python function with explicit `results=` "
+ "requires that the wrapped function returns None."
+ )
+ else:
+ # Coerce return values, add ReturnOp and rewrite func type.
+ if return_values is None:
+ return_values = []
+ elif isinstance(return_values, tuple):
+ return_values = list(return_values)
+ elif isinstance(return_values, Value):
+ # Returning a single value is fine, coerce it into a list.
+ return_values = [return_values]
+ elif isinstance(return_values, OpView):
+ # Returning a single operation is fine, coerce its results into a list.
+ return_values = return_values.operation.results
+ elif isinstance(return_values, Operation):
+ # Returning a single operation is fine, coerce its results into a list.
+ return_values = return_values.results
+ else:
+ return_values = list(return_values)
+ func.ReturnOp(return_values)
+ # Recompute the function type.
+ return_types = [v.type for v in return_values]
+ function_type = FunctionType.get(
+ inputs=inputs, results=return_types
+ )
+ func_op.attributes["function_type"] = TypeAttr.get(function_type)
+
+ def emit_call_op(*call_args):
+ call_op = func.CallOp(
+ return_types, FlatSymbolRefAttr.get(symbol_name), call_args
+ )
+ if return_types is None:
+ return None
+ elif len(return_types) == 1:
+ return call_op.result
+ else:
+ return call_op.results
+
+ wrapped = emit_call_op
+ wrapped.__name__ = f.__name__
+ wrapped.func_op = func_op
+ return wrapped
+
+ return decorator
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class CallOp(CallOp):
+ """Specialization for the call op class."""
+
+ def __init__(
+ self,
+ calleeOrResults: Union[FuncOp, List[Type]],
+ argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
+ arguments: Optional[List] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an call operation.
+
+ The constructor accepts three different forms:
+
+ 1. A function op to be called followed by a list of arguments.
+ 2. A list of result types, followed by the name of the function to be
+ called as a string, followed by a list of arguments.
+ 3. A list of result types, followed by the name of the function to be
+ called as a symbol reference attribute, followed by a list of arguments.
+
+ For example
+
+ f = func.FuncOp("foo", ...)
+ func.CallOp(f, [args])
+ func.CallOp([result_types], "foo", [args])
+
+ In all cases, the location and insertion point may be specified as keyword
+ arguments if not provided by the surrounding context managers.
+ """
+
+ # TODO: consider supporting constructor "overloads", e.g., through a custom
+ # or pybind-provided metaclass.
+ if isinstance(calleeOrResults, FuncOp):
+ if not isinstance(argumentsOrCallee, list):
+ raise ValueError(
+ "when constructing a call to a function, expected "
+ + "the second argument to be a list of call arguments, "
+ + f"got {type(argumentsOrCallee)}"
+ )
+ if arguments is not None:
+ raise ValueError(
+ "unexpected third argument when constructing a call"
+ + "to a function"
+ )
+
+ super().__init__(
+ calleeOrResults.type.results,
+ FlatSymbolRefAttr.get(
+ calleeOrResults.name.value, context=_get_default_loc_context(loc)
+ ),
+ argumentsOrCallee,
+ loc=loc,
+ ip=ip,
+ )
+ return
+
+ if isinstance(argumentsOrCallee, list):
+ raise ValueError(
+ "when constructing a call to a function by name, "
+ + "expected the second argument to be a string or a "
+ + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
+ )
+
+ if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
+ super().__init__(
+ calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
+ )
+ elif isinstance(argumentsOrCallee, str):
+ super().__init__(
+ calleeOrResults,
+ FlatSymbolRefAttr.get(
+ argumentsOrCallee, context=_get_default_loc_context(loc)
+ ),
+ arguments,
+ loc=loc,
+ ip=ip,
+ )
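
For illustration, here is a minimal sketch of the three constructor forms
described in the docstring above. It assumes the mlir Python bindings are
importable; the names "foo" and "caller" and the f32 types are placeholders.

    from mlir.ir import (
        Context, Location, Module, InsertionPoint, F32Type, FlatSymbolRefAttr,
    )
    from mlir.dialects import func

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()
            foo = func.FuncOp("foo", ([f32], [f32]))
            with InsertionPoint(foo.add_entry_block()):
                func.ReturnOp([foo.entry_block.arguments[0]])
            caller = func.FuncOp("caller", ([f32], [f32]))
            with InsertionPoint(caller.add_entry_block()):
                (arg,) = caller.entry_block.arguments
                # Form 1: a FuncOp followed by the call arguments.
                c1 = func.CallOp(foo, [arg])
                # Form 2: result types, the callee name as a string, arguments.
                c2 = func.CallOp([f32], "foo", [arg])
                # Form 3: result types, a FlatSymbolRefAttr, arguments.
                c3 = func.CallOp([f32], FlatSymbolRefAttr.get("foo"), [arg])
                func.ReturnOp([c3.result])
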
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 6f9d72164429eea..f91fc8b7160089b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -310,7 +310,7 @@ def emit_named_structured_op(
)
# Set the index attributes used to compute the indexing maps.
- named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
+ named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
for name, value in index_attrs.items():
named_op.operation.attributes[name] = value
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index a8f8f8e0fbd68b4..19734a80a107bfe 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -296,35 +296,39 @@ def quantized_matmul(
@linalg_structured_op
-def matmul_transpose_a(A=TensorDef(T1, S.K, S.N),
- B=TensorDef(T2, S.K, S.M),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- """Performs a matrix multiplication of two 2D inputs with lhs operand
- transposed.
+def matmul_transpose_a(
+ A=TensorDef(T1, S.K, S.N),
+ B=TensorDef(T2, S.K, S.M),
+ C=TensorDef(U, S.M, S.N, output=True),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ """Performs a matrix multiplication of two 2D inputs with lhs operand
+ transposed.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
@linalg_structured_op
-def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
- B=TensorDef(T2, S.N, S.K),
- C=TensorDef(U, S.M, S.N, output=True),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
- """Performs a matrix multiplication of two 2D inputs with rhs operand
- transposed.
+def matmul_transpose_b(
+ A=TensorDef(T1, S.M, S.K),
+ B=TensorDef(T2, S.N, S.K),
+ C=TensorDef(U, S.M, S.N, output=True),
+ cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+ """Performs a matrix multiplication of two 2D inputs with rhs operand
+ transposed.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
@linalg_structured_op
@@ -390,36 +394,41 @@ def batch_matmul(
@linalg_structured_op
-def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M),
- B=TensorDef(T2, Batch, S.K, S.N),
- C=TensorDef(U, Batch, S.M, S.N, output=True)):
- """Performs a batched matrix multiplication of two 3D inputs where lhs operand
- has its non-batch dimensions transposed.
+def batch_matmul_transpose_a(
+ A=TensorDef(T1, Batch, S.K, S.M),
+ B=TensorDef(T2, Batch, S.K, S.N),
+ C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+ """Performs a batched matrix multiplication of two 3D inputs where lhs operand
+ has its non-batch dimensions transposed.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \
- * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
+ U, B[D.b, D.k, D.n]
+ )
@linalg_structured_op
-def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
- B=TensorDef(T2, Batch, S.N, S.K),
- C=TensorDef(U, Batch, S.M, S.N, output=True)):
- """Performs a batched matrix multiplication of two 3D inputs where rhs operand
- has its non-batch dimensions transposed.
+def batch_matmul_transpose_b(
+ A=TensorDef(T1, Batch, S.M, S.K),
+ B=TensorDef(T2, Batch, S.N, S.K),
+ C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+ """Performs a batched matrix multiplication of two 3D inputs where rhs operand
+ has its non-batch dimensions transposed.
- Numeric casting is performed on the operands to the inner multiply, promoting
- them to the same data type as the accumulator/output.
- """
- domain(D.b, D.m, D.n, D.k)
- implements(ContractionOpInterface)
- C[D.b, D.m,
- D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
- U, B[D.b, D.n, D.k])
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ domain(D.b, D.m, D.n, D.k)
+ implements(ContractionOpInterface)
+ C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+ U, B[D.b, D.n, D.k]
+ )
@linalg_structured_op
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0db..111ad2178703d28 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -3,3 +3,41 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._memref_ops_gen import *
+from ._memref_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoadOp(LoadOp):
+ """Specialization for the MemRef load operation."""
+
+ def __init__(
+ self,
+ memref: Union[Operation, OpView, Value],
+ indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates a memref load operation.
+
+ Args:
+ memref: the buffer to load from.
+ indices: the list of subscripts, may be empty for zero-dimensional
+ buffers.
+ loc: user-visible location of the operation.
+ ip: insertion point.
+ """
+ indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
+ super().__init__(memref, indices_resolved, loc=loc, ip=ip)
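
As a sketch of the replaced LoadOp in use (hypothetical function and buffer
names; assumes the usual Context/Location managers):

    from mlir.ir import (
        Context, Location, Module, InsertionPoint, F32Type, IndexType, MemRefType,
    )
    from mlir.dialects import arith, func, memref

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            memref_ty = MemRefType.get([4], F32Type.get())
            f = func.FuncOp("read_first", ([memref_ty], []))
            with InsertionPoint(f.add_entry_block()):
                (buf,) = f.entry_block.arguments
                zero = arith.ConstantOp(IndexType.get(), 0)
                # `indices` accepts ops, op views, or values; they are coerced
                # through get_op_results_or_values.
                value = memref.LoadOp(buf, [zero])
                func.ReturnOp([])
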
diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py
index a654529b4bb8843..dfb6d7f2c03b1cf 100644
--- a/mlir/python/mlir/dialects/ml_program.py
+++ b/mlir/python/mlir/dialects/ml_program.py
@@ -2,4 +2,118 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Union
+
from ._ml_program_ops_gen import *
+from ._ml_program_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuncOp(FuncOp):
+ """Specialization for the func op class."""
+
+ def __init__(
+ self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+ ):
+ """
+ Create a FuncOp with the provided `name`, `type`, and `visibility`.
+ - `name` is a string representing the function name.
+ - `type` is either a FunctionType or a pair of lists describing inputs and
+ results.
+ - `visibility` is a string matching `public`, `private`, or `nested`. None
+ implies private visibility.
+ - `body_builder` is an optional callback, when provided a new entry block
+ is created and the callback is invoked with the new op as argument within
+ an InsertionPoint context already set for the block. The callback is
+ expected to insert a terminator in the block.
+ """
+ sym_name = StringAttr.get(str(name))
+
+ # If the type is passed as a tuple, build a FunctionType on the fly.
+ if isinstance(type, tuple):
+ type = FunctionType.get(inputs=type[0], results=type[1])
+
+ type = TypeAttr.get(type)
+ sym_visibility = (
+ StringAttr.get(str(visibility)) if visibility is not None else None
+ )
+ super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+ if body_builder:
+ entry_block = self.add_entry_block()
+ with InsertionPoint(entry_block):
+ body_builder(self)
+
+ @property
+ def is_external(self):
+ return len(self.regions[0].blocks) == 0
+
+ @property
+ def body(self):
+ return self.regions[0]
+
+ @property
+ def type(self):
+ return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+ @property
+ def visibility(self):
+ return self.attributes["sym_visibility"]
+
+ @property
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes["sym_name"])
+
+ @property
+ def entry_block(self):
+ if self.is_external:
+ raise IndexError("External function does not have a body")
+ return self.regions[0].blocks[0]
+
+ def add_entry_block(self):
+ """
+ Add an entry block to the function body using the function signature to
+ infer block arguments.
+ Returns the newly created block.
+ """
+ if not self.is_external:
+ raise IndexError("The function already has an entry block!")
+ self.body.blocks.append(*self.type.inputs)
+ return self.body.blocks[0]
+
+ @property
+ def arg_attrs(self):
+ return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+ @arg_attrs.setter
+ def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+ if isinstance(attribute, ArrayAttr):
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+ else:
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+ attribute, context=self.context
+ )
+
+ @property
+ def arguments(self):
+ return self.entry_block.arguments
+
+ @property
+ def result_attrs(self):
+ return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+ @result_attrs.setter
+ def result_attrs(self, attribute: ArrayAttr):
+ self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
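
A minimal sketch of this FuncOp specialization with a body_builder callback
(the symbol name "compute" is a placeholder):

    from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
    from mlir.dialects import ml_program

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()

            def body(op):
                # Runs under an InsertionPoint on the fresh entry block and
                # must insert a terminator.
                ml_program.ReturnOp(list(op.arguments))

            f = ml_program.FuncOp(
                "compute", ([f32], [f32]), visibility="private", body_builder=body
            )
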
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index dda2b7d6521965f..a8d9c56f4233d9e 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -3,4 +3,289 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._pdl_ops_gen import *
+from ._pdl_ops_gen import _Dialect
from .._mlir_libs._mlirDialectsPDL import *
+
+
+try:
+ from ..ir import *
+ from ..dialects import pdl
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union, Optional, Sequence, Mapping
+from ._ods_common import (
+ get_op_result_or_value as _get_value,
+ get_op_results_or_values as _get_values,
+ _cext as _ods_cext,
+)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
+ """Specialization for PDL apply native constraint op class."""
+
+ def __init__(
+ self,
+ name: Union[str, StringAttr],
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ args = _get_values(args)
+ super().__init__(name, args, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
+ """Specialization for PDL apply native rewrite op class."""
+
+ def __init__(
+ self,
+ results: Sequence[Type],
+ name: Union[str, StringAttr],
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ args = _get_values(args)
+ super().__init__(results, name, args, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AttributeOp(AttributeOp):
+ """Specialization for PDL attribute op class."""
+
+ def __init__(
+ self,
+ valueType: Optional[Union[OpView, Operation, Value]] = None,
+ value: Optional[Attribute] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ valueType = valueType if valueType is None else _get_value(valueType)
+ result = pdl.AttributeType.get()
+ super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class EraseOp(EraseOp):
+ """Specialization for PDL erase op class."""
+
+ def __init__(
+ self,
+ operation: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ operation = _get_value(operation)
+ super().__init__(operation, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperandOp(OperandOp):
+ """Specialization for PDL operand op class."""
+
+ def __init__(
+ self,
+ type: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ type = type if type is None else _get_value(type)
+ result = pdl.ValueType.get()
+ super().__init__(result, valueType=type, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperandsOp(OperandsOp):
+ """Specialization for PDL operands op class."""
+
+ def __init__(
+ self,
+ types: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ types = types if types is None else _get_value(types)
+ result = pdl.RangeType.get(pdl.ValueType.get())
+ super().__init__(result, valueType=types, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
+ """Specialization for PDL operand op class."""
+
+ def __init__(
+ self,
+ name: Optional[Union[str, StringAttr]] = None,
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
+ types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if types is None:
+ types = []
+ if attributes is None:
+ attributes = {}
+ if args is None:
+ args = []
+ args = _get_values(args)
+ attrNames = []
+ attrValues = []
+ for attrName, attrValue in attributes.items():
+ attrNames.append(StringAttr.get(attrName))
+ attrValues.append(_get_value(attrValue))
+ attrNames = ArrayAttr.get(attrNames)
+ types = _get_values(types)
+ result = pdl.OperationType.get()
+ super().__init__(
+ result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class PatternOp(PatternOp):
+ """Specialization for PDL pattern op class."""
+
+ def __init__(
+ self,
+ benefit: Union[IntegerAttr, int],
+ name: Optional[Union[StringAttr, str]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an PDL `pattern` operation."""
+ super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
+ self.regions[0].blocks.append()
+
+ @property
+ def body(self):
+ """Return the body (block) of the pattern."""
+ return self.regions[0].blocks[0]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ReplaceOp(ReplaceOp):
+ """Specialization for PDL replace op class."""
+
+ def __init__(
+ self,
+ op: Union[OpView, Operation, Value],
+ *,
+ with_op: Optional[Union[OpView, Operation, Value]] = None,
+ with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if with_values is None:
+ with_values = []
+ op = _get_value(op)
+ with_op = with_op if with_op is None else _get_value(with_op)
+ with_values = _get_values(with_values)
+ super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ResultOp(ResultOp):
+ """Specialization for PDL result op class."""
+
+ def __init__(
+ self,
+ parent: Union[OpView, Operation, Value],
+ index: Union[IntegerAttr, int],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ parent = _get_value(parent)
+ result = pdl.ValueType.get()
+ super().__init__(result, parent, index, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ResultsOp(ResultsOp):
+ """Specialization for PDL results op class."""
+
+ def __init__(
+ self,
+ result: Type,
+ parent: Union[OpView, Operation, Value],
+ index: Optional[Union[IntegerAttr, int]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ parent = _get_value(parent)
+ super().__init__(result, parent, index=index, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class RewriteOp(RewriteOp):
+ """Specialization for PDL rewrite op class."""
+
+ def __init__(
+ self,
+ root: Optional[Union[OpView, Operation, Value]] = None,
+ name: Optional[Union[StringAttr, str]] = None,
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ root = root if root is None else _get_value(root)
+ args = _get_values(args)
+ super().__init__(args, root=root, name=name, loc=loc, ip=ip)
+
+ def add_body(self):
+ """Add body (block) to the rewrite."""
+ self.regions[0].blocks.append()
+ return self.body
+
+ @property
+ def body(self):
+ """Return the body (block) of the rewrite."""
+ return self.regions[0].blocks[0]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TypeOp(TypeOp):
+ """Specialization for PDL type op class."""
+
+ def __init__(
+ self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
+ ):
+ result = pdl.TypeType.get()
+ super().__init__(result, constantType=constantType, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TypesOp(TypesOp):
+ """Specialization for PDL types op class."""
+
+ def __init__(
+ self,
+ constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if constantTypes is None:
+ constantTypes = []
+ result = pdl.RangeType.get(pdl.TypeType.get())
+ super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
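
For illustration, a small PDL pattern assembled from these specializations
(the pattern and native-rewriter names are hypothetical):

    from mlir.ir import Context, Location, Module, InsertionPoint
    from mlir.dialects import pdl

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            pattern = pdl.PatternOp(1, "rewrite_addi")
            with InsertionPoint(pattern.body):
                ty = pdl.TypeOp()
                lhs = pdl.OperandOp()
                rhs = pdl.OperandOp()
                op = pdl.OperationOp(name="arith.addi", args=[lhs, rhs], types=[ty])
                # Defer the actual rewriting to a native rewriter.
                pdl.RewriteOp(op, name="my_native_rewriter")
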
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 8465af048a28056..6579e02d8549efa 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,12 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
+from .._mlir_libs._mlirPythonTest import (
+ TestAttr,
+ TestType,
+ TestTensorValue,
+ TestIntegerRankedTensorType,
+)
def register_python_test_dialect(context, load=True):
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 49685ca2271fc61..43ad9f4e2d65f51 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -2,11 +2,122 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Optional, Sequence
from ._scf_ops_gen import *
+from ._scf_ops_gen import _Dialect
from .arith import constant
-from ..ir import *
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+_ForOp = ForOp
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForOp(_ForOp):
+ """Specialization for the SCF for op class."""
+
+ def __init__(
+ self,
+ lower_bound,
+ upper_bound,
+ step,
+ iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an SCF `for` operation.
+
+ - `lower_bound` is the value to use as lower bound of the loop.
+ - `upper_bound` is the value to use as upper bound of the loop.
+ - `step` is the value to use as loop step.
+ - `iter_args` is a list of additional loop-carried arguments or an operation
+ producing them as results.
+ """
+ if iter_args is None:
+ iter_args = []
+ iter_args = _get_op_results_or_values(iter_args)
+
+ results = [arg.type for arg in iter_args]
+ super(_ForOp, self).__init__(
+ self.build_generic(
+ regions=1,
+ results=results,
+ operands=[
+ _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
+ ]
+ + list(iter_args),
+ loc=loc,
+ ip=ip,
+ )
+ )
+ self.regions[0].blocks.append(self.operands[0].type, *results)
+
+ @property
+ def body(self):
+ """Returns the body (block) of the loop."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def induction_variable(self):
+ """Returns the induction variable of the loop."""
+ return self.body.arguments[0]
+
+ @property
+ def inner_iter_args(self):
+ """Returns the loop-carried arguments usable within the loop.
+
+ To obtain the loop-carried operands, use `iter_args`.
+ """
+ return self.body.arguments[1:]
+
+
+_IfOp = IfOp
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class IfOp(_IfOp):
+ """Specialization for the SCF if op class."""
+
+ def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
+ """Creates an SCF `if` operation.
+
+ - `cond` is an MLIR value of 'i1' type that determines which region of code will be executed.
+ - `hasElse` determines whether the if operation has the else branch.
+ """
+ operands = []
+ operands.append(cond)
+ results = []
+ results.extend(results_)
+ super(_IfOp, self).__init__(
+ self.build_generic(
+ regions=2, results=results, operands=operands, loc=loc, ip=ip
+ )
+ )
+ self.regions[0].blocks.append(*[])
+ if hasElse:
+ self.regions[1].blocks.append(*[])
+
+ @property
+ def then_block(self):
+ """Returns the then block of the if operation."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def else_block(self):
+ """Returns the else block of the if operation."""
+ return self.regions[1].blocks[0]
def for_(
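
A sketch of the reworked ForOp builder with one loop-carried value (bounds
and names are placeholders):

    from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, IndexType
    from mlir.dialects import arith, func, scf

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()
            f = func.FuncOp("repeat", ([f32], [f32]))
            with InsertionPoint(f.add_entry_block()):
                (init,) = f.entry_block.arguments
                lb = arith.ConstantOp(IndexType.get(), 0)
                ub = arith.ConstantOp(IndexType.get(), 8)
                step = arith.ConstantOp(IndexType.get(), 1)
                loop = scf.ForOp(lb, ub, step, iter_args=[init])
                with InsertionPoint(loop.body):
                    # Loop-carried block arguments are exposed as inner_iter_args.
                    scf.YieldOp(list(loop.inner_iter_args))
                func.ReturnOp(list(loop.results))
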
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 26edf6b6436dad5..67248748eaf3ada 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -3,3 +3,40 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._tensor_ops_gen import *
+from ._tensor_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Sequence, Union
+from ._ods_common import _cext as _ods_cext
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class EmptyOp(EmptyOp):
+ """Extends the tensor.empty op."""
+
+ def __init__(
+ self,
+ sizes: Sequence[Union[int, Value]],
+ element_type: Type,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Constructs an `empty` with mixed static/dynamic sizes."""
+ # TODO: Refactor the EmptyOp to take an element type attribute and
+ # then use normal result type inference, unifying the Python and C++ side
+ # with a standard mechanism (versus stashing that in builders).
+ dynamic_sizes = []
+ static_sizes = []
+ for s in sizes:
+ if isinstance(s, int):
+ static_sizes.append(s)
+ else:
+ static_sizes.append(ShapedType.get_dynamic_size())
+ dynamic_sizes.append(s)
+ result_type = RankedTensorType.get(static_sizes, element_type)
+ super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
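
A sketch of mixed static/dynamic sizes with this EmptyOp ("make" and the
shapes are placeholders):

    from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, IndexType
    from mlir.dialects import func, tensor

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f = func.FuncOp("make", ([IndexType.get()], []))
            with InsertionPoint(f.add_entry_block()):
                (n,) = f.entry_block.arguments
                # 4 is static, `n` is dynamic: produces a tensor<4x?xf32> value.
                t = tensor.EmptyOp([4, n], F32Type.get())
                func.ReturnOp([])
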
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b020ad35fcf062f..f7a2026e800aeb0 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -4,4 +4,174 @@
from .._transform_enum_gen import *
from .._transform_ops_gen import *
+from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
+
+try:
+ from ...ir import *
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class CastOp(CastOp):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyPatternsOp(ApplyPatternsOp):
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(target, loc=loc, ip=ip)
+ self.regions[0].blocks.append()
+
+ @property
+ def patterns(self) -> Block:
+ return self.regions[0].blocks[0]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GetParentOp(GetParentOp):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ isolated_from_above: bool = False,
+ op_name: Optional[str] = None,
+ deduplicate: bool = False,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ isolated_from_above=isolated_from_above,
+ op_name=op_name,
+ deduplicate=deduplicate,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MergeHandlesOp(MergeHandlesOp):
+ def __init__(
+ self,
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ deduplicate: bool = False,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ [_get_op_result_or_value(h) for h in handles],
+ deduplicate=deduplicate,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ReplicateOp(ReplicateOp):
+ def __init__(
+ self,
+ pattern: Union[Operation, Value],
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ [_get_op_result_or_value(h).type for h in handles],
+ _get_op_result_or_value(pattern),
+ [_get_op_result_or_value(h) for h in handles],
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SequenceOp(SequenceOp):
+ def __init__(
+ self,
+ failure_propagation_mode,
+ results: Sequence[Type],
+ target: Union[Operation, Value, Type],
+ extra_bindings: Optional[
+ Union[Sequence[Value], Sequence[Type], Operation, OpView]
+ ] = None,
+ ):
+ root = (
+ _get_op_result_or_value(target)
+ if isinstance(target, (Operation, Value))
+ else None
+ )
+ root_type = root.type if not isinstance(target, Type) else target
+
+ if extra_bindings is None:
+ extra_bindings = []
+ if isinstance(extra_bindings, (Operation, OpView)):
+ extra_bindings = _get_op_results_or_values(extra_bindings)
+
+ extra_binding_types = []
+ if len(extra_bindings) != 0:
+ if isinstance(extra_bindings[0], Type):
+ extra_binding_types = extra_bindings
+ extra_bindings = []
+ else:
+ extra_binding_types = [v.type for v in extra_bindings]
+
+ super().__init__(
+ results_=results,
+ failure_propagation_mode=failure_propagation_mode,
+ root=root,
+ extra_bindings=extra_bindings,
+ )
+ self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
+
+ @property
+ def body(self) -> Block:
+ return self.regions[0].blocks[0]
+
+ @property
+ def bodyTarget(self) -> Value:
+ return self.body.arguments[0]
+
+ @property
+ def bodyExtraArgs(self) -> BlockArgumentList:
+ return self.body.arguments[1:]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class YieldOp(YieldOp):
+ def __init__(
+ self,
+ operands: Optional[Union[Operation, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if operands is None:
+ operands = []
+ super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
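
A minimal sketch of SequenceOp together with two of the extensions above
(the FailurePropagationMode spelling follows the generated enum):

    from mlir.ir import Context, Location, Module, InsertionPoint
    from mlir.dialects import transform

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            seq = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq.body):
                # seq.bodyTarget is the block argument bound to the payload root.
                casted = transform.CastOp(transform.AnyOpType.get(), seq.bodyTarget)
                transform.YieldOp()
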
diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py
index eb77b746cf864fa..485a8a36b6305e9 100644
--- a/mlir/python/mlir/dialects/transform/bufferization.py
+++ b/mlir/python/mlir/dialects/transform/bufferization.py
@@ -3,3 +3,132 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._bufferization_transform_ops_gen import *
+from .._bufferization_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from enum import Enum
+from typing import Optional, overload, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp):
+ """Specialization for EmptyTensorToAllocTensorOp class."""
+
+ @overload
+ def __init__(
+ self,
+ transformed_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
+ ...
+
+ def __init__(
+ self,
+ transformed_type_or_target: Type,
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(transformed_type_or_target, Type):
+ transformed_type = transformed_type_or_target
+ target = target_or_none
+ else:
+ transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
+ target = transformed_type_or_target
+
+ super().__init__(
+ transformed_type,
+ target,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OneShotBufferizeOp(OneShotBufferizeOp):
+ """Specialization for OneShotBufferizeOp class."""
+
+ @overload
+ def __init__(
+ self,
+ transformed_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ allow_return_allocs_from_loops: Optional[bool] = None,
+ allow_unknown_ops: Optional[bool] = None,
+ bufferize_function_boundaries: Optional[bool] = None,
+ function_boundary_type_conversion: Optional[Enum] = None,
+ memcpy_op: Optional[str] = None,
+ print_conflicts: Optional[bool] = None,
+ test_analysis_only: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ allow_return_allocs_from_loops: Optional[bool] = None,
+ allow_unknown_ops: Optional[bool] = None,
+ bufferize_function_boundaries: Optional[bool] = None,
+ function_boundary_type_conversion: Optional[Enum] = None,
+ memcpy_op: Optional[str] = None,
+ print_conflicts: Optional[bool] = None,
+ test_analysis_only: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ transformed_type_or_target: Type,
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ allow_return_allocs_from_loops: Optional[bool] = None,
+ allow_unknown_ops: Optional[bool] = None,
+ bufferize_function_boundaries: Optional[bool] = None,
+ function_boundary_type_conversion: Optional[Enum] = None,
+ memcpy_op: Optional[str] = None,
+ print_conflicts: Optional[bool] = None,
+ test_analysis_only: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(transformed_type_or_target, Type):
+ transformed_type = transformed_type_or_target
+ target = target_or_none
+ else:
+ transformed_type = transform.AnyOpType.get()
+ target = transformed_type_or_target
+
+ super().__init__(
+ transformed_type,
+ target,
+ allow_return_allocs_from_loops=allow_return_allocs_from_loops,
+ allow_unknown_ops=allow_unknown_ops,
+ bufferize_function_boundaries=bufferize_function_boundaries,
+ function_boundary_type_conversion=function_boundary_type_conversion,
+ memcpy_op=memcpy_op,
+ print_conflicts=print_conflicts,
+ test_analysis_only=test_analysis_only,
+ loc=loc,
+ ip=ip,
+ )
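
The overload pairs above let callers omit the transformed type; a sketch of
both spellings inside a transform sequence (option values are placeholders):

    from mlir.ir import Context, Location, Module, InsertionPoint
    from mlir.dialects import transform
    from mlir.dialects.transform import bufferization

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            seq = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq.body):
                # Short form: the transformed type defaults to !transform.any_op.
                b = bufferization.OneShotBufferizeOp(seq.bodyTarget)
                # Long form: the transformed type is passed explicitly first.
                bufferization.OneShotBufferizeOp(
                    transform.AnyOpType.get(), b, bufferize_function_boundaries=True
                )
                transform.YieldOp()
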
diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py
index 8c3de0de7ea3f19..00cf0840eeae9e1 100644
--- a/mlir/python/mlir/dialects/transform/gpu.py
+++ b/mlir/python/mlir/dialects/transform/gpu.py
@@ -3,3 +3,128 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._gpu_transform_ops_gen import *
+from .._gpu_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union, overload
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MapForallToBlocks(MapForallToBlocks):
+ """Specialization for MapForallToBlocks class."""
+
+ @overload
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+ generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+ generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ result_type_or_target: Union[Operation, OpView, Type, Value],
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+ generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(result_type_or_target, Type):
+ result_type = result_type_or_target
+ target = target_or_none
+ else:
+ result_type = transform.AnyOpType.get()
+ target = result_type_or_target
+
+ super().__init__(
+ result_type,
+ target,
+ grid_dims=grid_dims,
+ generate_gpu_launch=generate_gpu_launch,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MapNestedForallToThreads(MapNestedForallToThreads):
+ """Specialization for MapNestedForallToThreads class."""
+
+ @overload
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ block_dims: Optional[Sequence[int]] = None,
+ warp_size: Optional[Sequence[int]] = None,
+ sync_after_distribute: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ block_dims: Optional[Sequence[int]] = None,
+ warp_size: Optional[Sequence[int]] = None,
+ sync_after_distribute: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ result_type_or_target: Union[Operation, OpView, Value, Type],
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ block_dims: Optional[Union[Sequence[int], Attribute]] = None,
+ warp_size: Optional[Union[Sequence[int], Attribute]] = None,
+ sync_after_distribute: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(result_type_or_target, Type):
+ result_type = result_type_or_target
+ target = target_or_none
+ else:
+ result_type = result_type_or_target.type
+ target = result_type_or_target
+ super().__init__(
+ result_type,
+ target,
+ block_dims=block_dims,
+ warp_size=warp_size,
+ sync_after_distribute=sync_after_distribute,
+ loc=loc,
+ ip=ip,
+ )
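
A sketch of the short MapForallToBlocks overload (the grid dimensions are
placeholders):

    from mlir.ir import Context, Location, Module, InsertionPoint
    from mlir.dialects import transform
    from mlir.dialects.transform import gpu

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            seq = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq.body):
                # The result type defaults to !transform.any_op.
                gpu.MapForallToBlocks(
                    seq.bodyTarget, grid_dims=[4, 2, 1], generate_gpu_launch=True
                )
                transform.YieldOp()
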
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
index 86f72788d86c369..6c89025f413839e 100644
--- a/mlir/python/mlir/dialects/transform/loop.py
+++ b/mlir/python/mlir/dialects/transform/loop.py
@@ -3,3 +3,143 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._loop_transform_ops_gen import *
+from .._loop_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GetParentForOp(GetParentForOp):
+ """Extension for GetParentForOp."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ num_loops: Optional[int] = None,
+ ip=None,
+ loc=None,
+ ):
+ if num_loops is None:
+ num_loops = 1
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ num_loops=num_loops,
+ ip=ip,
+ loc=loc,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoopOutlineOp(LoopOutlineOp):
+ """Extension for LoopOutlineOp."""
+
+ def __init__(
+ self,
+ function_type: Type,
+ call_type: Type,
+ target: Union[Operation, Value],
+ *,
+ func_name: Union[str, StringAttr],
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ function_type,
+ call_type,
+ _get_op_result_or_value(target),
+ func_name=(
+ func_name
+ if isinstance(func_name, StringAttr)
+ else StringAttr.get(func_name)
+ ),
+ ip=ip,
+ loc=loc,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoopPeelOp(LoopPeelOp):
+ """Extension for LoopPeelOp."""
+
+ def __init__(
+ self,
+ main_loop_type: Type,
+ remainder_loop_type: Type,
+ target: Union[Operation, Value],
+ *,
+ fail_if_already_divisible: Union[bool, BoolAttr] = False,
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ main_loop_type,
+ remainder_loop_type,
+ _get_op_result_or_value(target),
+ fail_if_already_divisible=(
+ fail_if_already_divisible
+ if isinstance(fail_if_already_divisible, BoolAttr)
+ else BoolAttr.get(fail_if_already_divisible)
+ ),
+ ip=ip,
+ loc=loc,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoopPipelineOp(LoopPipelineOp):
+ """Extension for LoopPipelineOp."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+ read_latency: Optional[Union[int, IntegerAttr]] = None,
+ ip=None,
+ loc=None,
+ ):
+ if iteration_interval is None:
+ iteration_interval = 1
+ if read_latency is None:
+ read_latency = 10
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ iteration_interval=iteration_interval,
+ read_latency=read_latency,
+ ip=ip,
+ loc=loc,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoopUnrollOp(LoopUnrollOp):
+ """Extension for LoopUnrollOp."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ factor: Union[int, IntegerAttr],
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ _get_op_result_or_value(target),
+ factor=factor,
+ ip=ip,
+ loc=loc,
+ )
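
A sketch chaining two of these extensions: fetch the enclosing scf.for, then
unroll it (the unroll factor is a placeholder):

    from mlir.ir import Context, Location, Module, InsertionPoint
    from mlir.dialects import transform
    from mlir.dialects.transform import loop

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            seq = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq.body):
                for_handle = loop.GetParentForOp(
                    transform.OperationType.get("scf.for"),
                    seq.bodyTarget,
                    num_loops=1,
                )
                loop.LoopUnrollOp(for_handle, factor=4)
                transform.YieldOp()
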
diff --git a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py
index 1ff04ef6a60a180..56ea61eb817f89c 100644
--- a/mlir/python/mlir/dialects/transform/memref.py
+++ b/mlir/python/mlir/dialects/transform/memref.py
@@ -3,3 +3,118 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._memref_transform_ops_gen import *
+from .._memref_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
+ """Specialization for MemRefAllocaToGlobalOp class."""
+
+ @overload
+ def __init__(
+ self,
+ get_global_type: Type,
+ global_type: Type,
+ alloca: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
+ ...
+
+ def __init__(
+ self,
+ get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
+ global_type_or_none: Optional[Type] = None,
+ alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(get_global_type_or_alloca, Type):
+ get_global_type = get_global_type_or_alloca
+ global_type = global_type_or_none
+ alloca = alloca_or_none
+ else:
+ get_global_type = transform.AnyOpType.get()
+ global_type = transform.AnyOpType.get()
+ alloca = get_global_type_or_alloca
+
+ super().__init__(
+ get_global_type,
+ global_type,
+ alloca,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MemRefMultiBufferOp(MemRefMultiBufferOp):
+ """Specialization for MemRefMultiBufferOp class."""
+
+ @overload
+ def __init__(
+ self,
+ transformed_type: Type,
+ target: Union[Operation, OpView, Value],
+ factor: Union[int, IntegerAttr],
+ *,
+ skip_analysis: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ factor: Union[int, IntegerAttr],
+ *,
+ skip_analysis: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ transformed_type_or_target: Type,
+ target_or_factor: Optional[Union[int, IntegerAttr, Operation, OpView, Value]] = None,
+ factor_or_none: Optional[Union[int, IntegerAttr]] = None,
+ *,
+ skip_analysis: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(transformed_type_or_target, Type):
+ transformed_type = transformed_type_or_target
+ target = target_or_factor
+ factor = factor_or_none
+ else:
+ transformed_type = transform.AnyOpType.get()
+ target = transformed_type_or_target
+ factor = target_or_factor
+
+ super().__init__(
+ transformed_type,
+ target,
+ factor,
+ skip_analysis=skip_analysis,
+ loc=loc,
+ ip=ip,
+ )
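
A sketch of MemRefMultiBufferOp's short form (the factor is a placeholder;
the transform module is aliased to avoid shadowing the memref dialect):

    from mlir.ir import Context, Location, Module, InsertionPoint
    from mlir.dialects import transform
    from mlir.dialects.transform import memref as memref_transform

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            seq = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq.body):
                # Short form: the transformed type defaults to !transform.any_op.
                memref_transform.MemRefMultiBufferOp(
                    seq.bodyTarget, 4, skip_analysis=True
                )
                transform.YieldOp()
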
diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py
index b1515287a3f1ff0..bb5fa7ffd306583 100644
--- a/mlir/python/mlir/dialects/transform/pdl.py
+++ b/mlir/python/mlir/dialects/transform/pdl.py
@@ -3,3 +3,53 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._transform_pdl_extension_ops_gen import *
+from .._transform_pdl_extension_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class PDLMatchOp(PDLMatchOp):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ pattern_name: Union[Attribute, str],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ pattern_name,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class WithPDLPatternsOp(WithPDLPatternsOp):
+ def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
+ root = _get_op_result_or_value(target) if not isinstance(target, Type) else None
+ root_type = target if isinstance(target, Type) else root.type
+ super().__init__(root=root, loc=loc, ip=ip)
+ self.regions[0].blocks.append(root_type)
+
+ @property
+ def body(self) -> Block:
+ return self.regions[0].blocks[0]
+
+ @property
+ def bodyTarget(self) -> Value:
+ return self.body.arguments[0]
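
A sketch wiring WithPDLPatternsOp to a nested sequence and a PDLMatchOp (the
pattern name "my_pattern" is hypothetical and would have to be defined as a
pdl.pattern in the same region):

    from mlir.ir import Context, Location, Module, InsertionPoint, FlatSymbolRefAttr
    from mlir.dialects import transform
    from mlir.dialects.transform import pdl as transform_pdl

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            withpdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
            with InsertionPoint(withpdl.body):
                seq = transform.SequenceOp(
                    transform.FailurePropagationMode.Propagate,
                    [],
                    withpdl.bodyTarget,
                )
                with InsertionPoint(seq.body):
                    transform_pdl.PDLMatchOp(
                        transform.AnyOpType.get(),
                        seq.bodyTarget,
                        FlatSymbolRefAttr.get("my_pattern"),
                    )
                    transform.YieldOp()
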
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index cb3812301dbd4b5..284c93823acbd34 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -3,4 +3,777 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._structured_transform_ops_gen import *
+from .._structured_transform_ops_gen import _Dialect
from .._structured_transform_enum_gen import *
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import List, Optional, Sequence, Tuple, Union, overload
+
+StaticIntLike = Union[int, IntegerAttr]
+ValueLike = Union[Operation, OpView, Value]
+MixedInt = Union[StaticIntLike, ValueLike]
+
+IntOrAttrList = Sequence[Union[IntegerAttr, int]]
+OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
+
+BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
+OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
+
+MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
+
+DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
+
+
+def _dispatch_dynamic_index_list(
+ indices: Union[DynamicIndexList, ArrayAttr],
+) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
+ """Dispatches a list of indices to the appropriate form.
+
+ This is similar to the custom `DynamicIndexList` directive upstream:
+ provided indices may be in the form of dynamic SSA values or static values,
+ and they may be scalable (i.e., as a singleton list) or not. This function
+ dispatches each index into its respective form. It also extracts the SSA
+ values and static indices from the various wrapper structures they may come in.
+ """
+ dynamic_indices = []
+ static_indices = [ShapedType.get_dynamic_size()] * len(indices)
+ scalable_indices = [False] * len(indices)
+
+ # ArrayAttr: Extract index values.
+ if isinstance(indices, ArrayAttr):
+ indices = [idx for idx in indices]
+
+ def process_nonscalable_index(i, index):
+ """Processes any form of non-scalable index.
+
+ Returns False if the given index was scalable and thus remains
+ unprocessed; True otherwise.
+ """
+ if isinstance(index, int):
+ static_indices[i] = index
+ elif isinstance(index, IntegerAttr):
+ static_indices[i] = index.value # pytype: disable=attribute-error
+ elif isinstance(index, (Operation, Value, OpView)):
+ dynamic_indices.append(index)
+ else:
+ return False
+ return True
+
+ # Process each index at a time.
+ for i, index in enumerate(indices):
+ if not process_nonscalable_index(i, index):
+ # If it wasn't processed, it must be a scalable index, which is
+ # provided as a Sequence of one value, so extract and process that.
+ scalable_indices[i] = True
+ assert len(index) == 1
+ ret = process_nonscalable_index(i, index[0])
+ assert ret
+
+ return dynamic_indices, static_indices, scalable_indices
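
For instance (a sketch; `v` stands for some previously created SSA Value),
the helper splits a mixed list like this:

    # dynamic == [v]; static == [2, ShapedType.get_dynamic_size(), 4];
    # scalable == [False, False, True]
    dynamic, static, scalable = _dispatch_dynamic_index_list([2, v, [4]])
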
+
+
+# Dispatches `MixedValues` that all represent integers in various forms into
+# the following three categories:
+# - `dynamic_values`: a list of `Value`s, potentially from op results;
+# - `packed_values`: a value handle, potentially from an op result, associated
+# to one or more payload operations of integer type;
+# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
+# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
+# If the input is in the `packed_values` form, only that result is set and the
+# other two are empty. Otherwise, the input can be a mix of the other two forms,
+# and for each dynamic value, a dynamic-size sentinel is added to `static_values`.
+def _dispatch_mixed_values(
+ values: MixedValues,
+) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
+ dynamic_values = []
+ packed_values = None
+ static_values = None
+ if isinstance(values, ArrayAttr):
+ static_values = values
+ elif isinstance(values, (Operation, Value, OpView)):
+ packed_values = values
+ else:
+ static_values = []
+ for size in values or []:
+ if isinstance(size, int):
+ static_values.append(size)
+ else:
+ static_values.append(ShapedType.get_dynamic_size())
+ dynamic_values.append(size)
+ static_values = DenseI64ArrayAttr.get(static_values)
+
+ return (dynamic_values, packed_values, static_values)
+
+
+def _get_value_or_attribute_value(
+ value_or_attr: Union[any, Attribute, ArrayAttr]
+) -> any:
+ if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
+ return value_or_attr.value
+ if isinstance(value_or_attr, ArrayAttr):
+ return _get_value_list(value_or_attr)
+ return value_or_attr
+
+
+def _get_value_list(
+ sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
+) -> Sequence[any]:
+ return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
+
+
+def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
+ if values is None:
+ return None
+
+ # Turn into a Python list of Python ints.
+ values = _get_value_list(values)
+
+ # Make an ArrayAttr of IntegerAttrs out of it.
+ return ArrayAttr.get(
+ [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
+ )
+
+
+def _get_int_array_array_attr(
+ values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
+) -> ArrayAttr:
+ """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
+
+ The input has to be a collection of collections of integers, where any
+ Python Sequence and any ArrayAttr are admissible collections, and Python
+ ints and any IntegerAttr are admissible integers. Both levels of collections
+ are turned into ArrayAttrs; the inner level is turned into IntegerAttrs of
+ i64s. If the input is None, None is returned.
+ """
+ if values is None:
+ return None
+
+ # Make sure the outer level is a list.
+ values = _get_value_list(values)
+
+ # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
+ # Sequences. Make sure the nested values are all lists.
+ values = [_get_value_list(nested) for nested in values]
+
+ # Turn each nested list into an ArrayAttr.
+ values = [_get_int_array_attr(nested) for nested in values]
+
+ # Turn the outer list into an ArrayAttr.
+ return ArrayAttr.get(values)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class BufferizeToAllocationOp(BufferizeToAllocationOp):
+ """Specialization for BufferizeToAllocationOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ memory_space: Optional[Union[int, str, Attribute]] = None,
+ memcpy_op: Optional[str] = None,
+ alloc_op: Optional[str] = None,
+ bufferize_destination_only: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ # No other types are allowed, so hard-code those here.
+ allocated_buffer_type = transform.AnyValueType.get()
+ new_ops_type = transform.AnyOpType.get()
+
+ if isinstance(memory_space, int):
+ memory_space = str(memory_space)
+ if isinstance(memory_space, str):
+ memory_space = Attribute.parse(memory_space)
+
+ super().__init__(
+ allocated_buffer_type,
+ new_ops_type,
+ target,
+ memory_space=memory_space,
+ memcpy_op=memcpy_op,
+ alloc_op=alloc_op,
+ bufferize_destination_only=bufferize_destination_only,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DecomposeOp(DecomposeOp):
+ """Specialization for DecomposeOp class."""
+
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ transformed_type = transform.AnyOpType.get()
+ super().__init__(transformed_type, target, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuseIntoContainingOp(FuseIntoContainingOp):
+ """Specialization for FuseIntoContainingOp class."""
+
+ @overload
+ def __init__(
+ self,
+ fused_op_type: Type,
+ new_containing_op_type: Type,
+ producer_op: Union[Operation, OpView, Value],
+ containing_op: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ producer_op: Union[Operation, OpView, Value],
+ containing_op: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
+ new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
+ producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(fused_op_type_or_producer_op, Type):
+ if not isinstance(new_containing_op_type_or_containing_op, Type):
+ raise TypeError(
+ "If 'fused_op_type_or_producer_op' is a type, then "
+ "'new_containing_op_type_or_containing_op' is expected "
+ "to be one as well."
+ )
+ fused_op_type = fused_op_type_or_producer_op
+ new_containing_op_type = new_containing_op_type_or_containing_op
+ producer_op = producer_op_or_none
+ containing_op = containing_op_or_none
+ else:
+ fused_op_type = transform.AnyOpType.get()
+ new_containing_op_type = transform.AnyOpType.get()
+ producer_op = fused_op_type_or_producer_op
+ containing_op = new_containing_op_type_or_containing_op
+
+ super().__init__(
+ fused_op_type,
+ new_containing_op_type,
+ producer_op,
+ containing_op,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GeneralizeOp(GeneralizeOp):
+ """Specialization for GeneralizeOp class."""
+
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ transformed_type = transform.AnyOpType.get()
+ super().__init__(transformed_type, target, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InterchangeOp(InterchangeOp):
+ """Specialization for InterchangeOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ iterator_interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ transformed_type = transform.AnyOpType.get()
+ super().__init__(
+ transformed_type,
+ target,
+ iterator_interchange=iterator_interchange,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MapCopyToThreadsOp(MapCopyToThreadsOp):
+ """Specialization for MapCopyToThreadsOp class."""
+
+ @overload
+ def __init__(
+ self,
+ forall_op_type: Type,
+ tiled_op_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ forall_op_type_or_target: Union[Operation, OpView, Type, Value],
+ tiled_op_type_or_none: Optional[Type] = None,
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(forall_op_type_or_target, Type):
+ forall_op_type = forall_op_type_or_target
+ tiled_op_type = tiled_op_type_or_none
+ target = target_or_none
+ else:
+ forall_op_type = transform.AnyOpType.get()
+ tiled_op_type = transform.AnyOpType.get()
+ target = forall_op_type_or_target
+
+ super().__init__(
+ forall_op_type,
+ tiled_op_type,
+ target,
+ total_num_threads=total_num_threads,
+ desired_bit_alignment=desired_bit_alignment,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class VectorizeOp(VectorizeOp):
+ """Specialization for VectorizeOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ *,
+ vectorize_nd_extract: Optional[bool] = None,
+ scalable_sizes: OptionalBoolList = None,
+ static_vector_sizes: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ if (
+ scalable_sizes is None
+ and static_vector_sizes is None
+ and vector_sizes is None
+ ):
+ dynamic_vector_sizes = []
+ elif scalable_sizes is None and static_vector_sizes is None:
+ (
+ dynamic_vector_sizes,
+ static_vector_sizes,
+ scalable_sizes,
+ ) = _dispatch_dynamic_index_list(vector_sizes)
+ elif scalable_sizes is None or static_vector_sizes is None:
+ raise TypeError(
+ "'scalable_sizes' and 'static_vector_sizes' must either both "
+ "be given explicitly or both be given as part of 'vector_sizes'."
+ )
+ else:
+ dynamic_vector_sizes = vector_sizes
+
+ super().__init__(
+ target,
+ vector_sizes=dynamic_vector_sizes,
+ static_vector_sizes=static_vector_sizes,
+ scalable_sizes=scalable_sizes,
+ vectorize_nd_extract=vectorize_nd_extract,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MatchOp(MatchOp):
+ """Specialization for MatchOp class."""
+
+ @overload
+ @classmethod
+ def match_op_names(
+ cls,
+ target: Union[Operation, Value],
+ names: Union[str, Sequence[str]],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ @classmethod
+ def match_op_names(
+ cls,
+ result_type: Type,
+ target: Union[Operation, Value],
+ names: Union[str, Sequence[str]],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @classmethod
+ def match_op_names(
+ cls,
+ result_type_or_target: Union[Type, Operation, Value],
+ target_or_names: Union[Operation, Value, Sequence[str], str],
+ names_or_none: Optional[Union[Sequence[str], str]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(result_type_or_target, Type):
+ result_type = result_type_or_target
+ target = target_or_names
+ names = names_or_none
+ else:
+ result_type = transform.AnyOpType.get()
+ target = result_type_or_target
+ names = target_or_names
+
+ if isinstance(names, str):
+ names = [names]
+
+ return cls(
+ result_type,
+ target,
+            ops=ArrayAttr.get([StringAttr.get(s) for s in names]),
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MultiTileSizesOp(MultiTileSizesOp):
+ """Specialization for MultiTileSizesOp class."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ dimension: Union[int, IntegerAttr],
+ target_size: Union[int, IntegerAttr],
+        divisor: Optional[Union[int, IntegerAttr]] = None,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ result_type,
+ result_type,
+ result_type,
+ target,
+ dimension=dimension,
+ target_size=target_size,
+ divisor=divisor,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class PadOp(PadOp):
+ """Specialization for PadOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
+ padding_dimensions: OptionalIntList = None,
+ pad_to_multiple_of: OptionalIntList = None,
+ pack_paddings: OptionalIntList = None,
+ transpose_paddings: Optional[
+ Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
+ ] = None,
+ copy_back_op: Optional[Union[str, StringAttr]] = None,
+ loc=None,
+ ip=None,
+ ):
+ transpose_paddings = _get_int_array_array_attr(transpose_paddings)
+
+ any_op_type = transform.AnyOpType.get()
+ super().__init__(
+ any_op_type,
+ any_op_type,
+ any_op_type,
+ target,
+ padding_values=padding_values,
+ padding_dimensions=padding_dimensions,
+ pad_to_multiple_of=pad_to_multiple_of,
+ pack_paddings=pack_paddings,
+ transpose_paddings=transpose_paddings,
+ copy_back_op=copy_back_op,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ScalarizeOp(ScalarizeOp):
+ """Specialization for ScalarizeOp class."""
+
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ result_type = transform.AnyOpType.get()
+ super().__init__(result_type, target, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SplitOp(SplitOp):
+ """Specialization for SplitOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ dimension: Union[int, Attribute],
+ split_point: Union[int, Operation, Value, Attribute],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(split_point, int):
+ static_split_point = split_point
+ dynamic_split_point = None
+ else:
+ static_split_point = ShapedType.get_dynamic_size()
+ dynamic_split_point = split_point
+
+ super().__init__(
+ target.type,
+ target.type,
+ target,
+ dimension=dimension,
+ static_split_point=static_split_point,
+ dynamic_split_point=dynamic_split_point,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TileUsingForOp(TileUsingForOp):
+ """Specialization for TileUsingForOp class."""
+
+ @overload
+ def __init__(
+ self,
+ loop_types: Union[Type, List[Type]],
+        target: Union[Operation, Value, OpView],
+ *,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+        loop_types_or_target: Union[Type, List[Type], Operation, Value, OpView],
+ target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+ *,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ (
+ dynamic_sizes,
+ static_sizes,
+ scalable_sizes,
+ ) = _dispatch_dynamic_index_list(sizes)
+
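+        # Each non-zero entry in static_sizes tiles one dimension and thus
+        # produces one loop; a static size of 0 means "do not tile".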
+ num_loops = sum(v if v == 0 else 1 for v in static_sizes)
+
+ if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+ loop_types = [transform.AnyOpType.get()] * num_loops
+ target = loop_types_or_target
+ assert (
+ target_or_none is None
+ ), "Cannot construct TileUsingForOp with two targets."
+ else:
+ loop_types = (
+ ([loop_types_or_target] * num_loops)
+ if isinstance(loop_types_or_target, Type)
+ else loop_types_or_target
+ )
+ target = target_or_none
+
+ super().__init__(
+ target.type,
+ loop_types,
+ target,
+ dynamic_sizes=dynamic_sizes,
+ static_sizes=static_sizes,
+ interchange=interchange,
+ scalable_sizes=scalable_sizes,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TileUsingForallOp(TileUsingForallOp):
+ """Specialization for TileUsingForallOp class."""
+
+ @overload
+ def __init__(
+ self,
+ loops_type: Type,
+ tiled_op_type: Type,
+ target: Union[Operation, Value, OpView],
+ *,
+ num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+ mapping=None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+ mapping=None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+        loops_type_or_target: Union[
+            Type,  # loops_type
+            Operation, Value, OpView,  # target
+        ],
+ tiled_op_type_or_none: Optional[Type] = None,
+ target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+ *,
+        num_threads: Optional[MixedValues] = None,
+        tile_sizes: Optional[MixedValues] = None,
+ mapping=None,
+ loc=None,
+ ip=None,
+ ):
+        # The leading `Type` arguments are optional; supply defaults when omitted.
+ if isinstance(loops_type_or_target, Type):
+ # First overload: type arguments provided.
+ if not isinstance(tiled_op_type_or_none, Type):
+ raise TypeError(
+ "If 'loops_type_or_target' is a type, then "
+ "'tiled_op_type_or_none' is expected to be one as well."
+ )
+ loops_type = loops_type_or_target
+ tiled_op_type = tiled_op_type_or_none
+ target = target_or_none
+ else:
+ # Last overload: type arguments missing.
+ loops_type = transform.AnyOpType.get()
+ tiled_op_type = transform.AnyOpType.get()
+ target = loops_type_or_target
+
+ # Unpack mixed num_threads.
+ (
+ dynamic_num_threads,
+ packed_num_threads,
+ num_threads_attr,
+ ) = _dispatch_mixed_values(num_threads)
+
+ # Unpack mixed tile_sizes.
+ (
+ dynamic_tile_sizes,
+ packed_tile_sizes,
+ tile_sizes_attr,
+ ) = _dispatch_mixed_values(tile_sizes)
+
+ super().__init__(
+ loops_type,
+ tiled_op_type,
+ target=target,
+ tile_sizes=dynamic_tile_sizes,
+ packed_tile_sizes=packed_tile_sizes,
+ static_tile_sizes=tile_sizes_attr,
+ num_threads=dynamic_num_threads,
+ packed_num_threads=packed_num_threads,
+ static_num_threads=num_threads_attr,
+ mapping=mapping,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
+ """Specialization for VectorizeChildrenAndApplyPatternsOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ disable_multi_reduction_to_contract_patterns: bool = False,
+ disable_transfer_permutation_map_lowering_patterns: bool = False,
+ vectorize_nd_extract: bool = False,
+ vectorize_padding: bool = False,
+ loc=None,
+ ip=None,
+ ):
+ transformed_type = transform.AnyOpType.get()
+ super().__init__(
+ transformed_type,
+ target,
+ disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
+ disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
+ vectorize_nd_extract=vectorize_nd_extract,
+ vectorize_padding=vectorize_padding,
+ loc=loc,
+ ip=ip,
+ )
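
As a point of reference for the specializations above, they compose as in the
following sketch. The `SequenceOp` scaffolding, the `bodyTarget` handle, the
"linalg.matmul" op name, and the tile sizes are illustrative assumptions, not
part of this diff:

    from mlir.ir import Context, InsertionPoint, Location, Module
    from mlir.dialects import transform
    from mlir.dialects.transform import structured

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            seq = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq.body):
                # The wrapper defaults the result type to !transform.any_op.
                matched = structured.MatchOp.match_op_names(
                    seq.bodyTarget, ["linalg.matmul"]
                )
                # Loop result types are derived from the static tile sizes.
                structured.TileUsingForOp(matched.results[0], sizes=[4, 8])
                transform.YieldOp()
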
diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py
index bf52255b3df7145..4eb30398f087212 100644
--- a/mlir/python/mlir/dialects/transform/tensor.py
+++ b/mlir/python/mlir/dialects/transform/tensor.py
@@ -3,3 +3,67 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._tensor_transform_ops_gen import *
+from .._tensor_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MakeLoopIndependentOp(MakeLoopIndependentOp):
+ """Specialization for MakeLoopIndependentOp class."""
+
+ @overload
+ def __init__(
+ self,
+ transformed_type: Type,
+ target: Union[Operation, OpView, Value],
+ num_loops: Union[int, IntegerAttr],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ num_loops: Union[int, IntegerAttr],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+        transformed_type_or_target: Union[Type, Operation, OpView, Value],
+        target_or_num_loops: Optional[
+            Union[int, IntegerAttr, Operation, OpView, Value]
+        ] = None,
+ num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(transformed_type_or_target, Type):
+ transformed_type = transformed_type_or_target
+ target = target_or_num_loops
+ num_loops = num_loops_or_none
+ else:
+ transformed_type = transform.AnyOpType.get()
+ target = transformed_type_or_target
+ num_loops = target_or_num_loops
+
+ super().__init__(
+ transformed_type,
+ target,
+ num_loops,
+ loc=loc,
+ ip=ip,
+ )
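
The wrapper above follows the same overload-dispatch idiom used throughout
this PR: `typing.overload` advertises the public signatures, and the single
real `__init__` dispatches on whether the leading argument is a `Type`. A
minimal self-contained sketch of the idiom, with all names hypothetical:

    from typing import overload

    class FakeType:
        """Stand-in for ir.Type."""

    class FakeHandle:
        """Stand-in for Operation/OpView/Value."""

    class WrapperOp:
        @overload
        def __init__(self, result_type: FakeType, target: FakeHandle): ...

        @overload
        def __init__(self, target: FakeHandle): ...

        def __init__(self, result_type_or_target, target_or_none=None):
            # A leading type selects the explicit-result overload; otherwise
            # a default result type is supplied and the sole argument is the
            # target (mirroring transform.AnyOpType.get() above).
            if isinstance(result_type_or_target, FakeType):
                result_type, target = result_type_or_target, target_or_none
            else:
                result_type, target = FakeType(), result_type_or_target
            self.result_type, self.target = result_type, target

    WrapperOp(FakeHandle())              # result type defaulted
    WrapperOp(FakeType(), FakeHandle())  # result type explicit
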
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 0a3b411041b2f4d..f6b706f9bc8ae24 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -114,6 +114,7 @@ def get_unranked_memref_descriptor(nparray):
d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
return d
+
def move_aligned_ptr_by_offset(aligned_ptr, offset):
"""Moves the supplied ctypes pointer ahead by `offset` elements."""
aligned_addr = ctypes.addressof(aligned_ptr.contents)
@@ -122,6 +123,7 @@ def move_aligned_ptr_by_offset(aligned_ptr, offset):
content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr))
return content_ptr
+
def unranked_memref_to_numpy(unranked_memref, np_dtype):
"""Converts unranked memrefs to numpy arrays."""
ctp = as_ctype(np_dtype)
@@ -139,10 +141,10 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
def ranked_memref_to_numpy(ranked_memref):
"""Converts ranked memrefs to numpy arrays."""
- content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset)
- np_arr = np.ctypeslib.as_array(
- content_ptr, shape=ranked_memref[0].shape
+ content_ptr = move_aligned_ptr_by_offset(
+ ranked_memref[0].aligned, ranked_memref[0].offset
)
+ np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape)
strided_arr = np.lib.stride_tricks.as_strided(
np_arr,
np.ctypeslib.as_array(ranked_memref[0].shape),
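
The reformatted helper expects a ctypes pointer to a descriptor struct and is
typically paired with `get_ranked_memref_descriptor` from the same module; a
small round-trip sketch (array contents arbitrary):

    import ctypes

    import numpy as np

    from mlir.runtime.np_to_memref import (
        get_ranked_memref_descriptor,
        ranked_memref_to_numpy,
    )

    arr = np.arange(6, dtype=np.float32).reshape(2, 3)
    desc = get_ranked_memref_descriptor(arr)
    # ranked_memref_to_numpy dereferences the pointer (ranked_memref[0]).
    restored = ranked_memref_to_numpy(ctypes.pointer(desc))
    assert np.array_equal(arr, restored)
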
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index f4a793aee4aa14c..6d1c5eab7589847 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -33,3 +33,16 @@ def testFastMathFlags():
)
# CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
print(r)
+
+
+# CHECK-LABEL: TEST: testArithValueBuilder
+@run
+def testArithValueBuilder():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32_t = F32Type.get()
+
+ with InsertionPoint(module.body):
+ a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
+ # CHECK: %cst = arith.constant 4.242000e+01 : f32
+ print(a)
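
The new test exercises the value builders whose emission is adjusted in
OpPythonBindingGen.cpp below: `arith.constant` is a module-level function that
builds a `ConstantOp` and unwraps its result to a `Value`. Inside the same
context and insertion point, the `OpView` route should be equivalent (a
sketch; the `ConstantOp(result, value)` extension signature is assumed from
the arith specialization in this PR):

    # Value builder: returns the result Value directly.
    a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
    # OpView builder: returns the op; take its single result explicitly.
    b = arith.ConstantOp(f32_t, 42.42).result
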
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 49f3a951426d0ee..c8ef84721090ab9 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,14 +30,9 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
_ods_ir = _ods_cext.ir
-try:
- from . import _{0}_ops_ext as _ods_ext_module
-except ImportError:
- _ods_ext_module = None
-
import builtins
from typing import Sequence as _Sequence, Union as _Union
@@ -62,7 +57,6 @@ from ._{0}_ops_gen import _Dialect
/// {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
-@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
OPERATION_NAME = "{1}"
)Py";
@@ -301,17 +295,17 @@ static bool isODSReserved(StringRef str) {
/// (does not change the `name` if it already is suitable) and returns the
/// modified version.
static std::string sanitizeName(StringRef name) {
- std::string processed_str = name.str();
+ std::string processedStr = name.str();
std::replace_if(
- processed_str.begin(), processed_str.end(),
+ processedStr.begin(), processedStr.end(),
[](char c) { return !llvm::isAlnum(c); }, '_');
- if (llvm::isDigit(*processed_str.begin()))
- return "_" + processed_str;
+ if (llvm::isDigit(*processedStr.begin()))
+ return "_" + processedStr;
- if (isPythonReserved(processed_str) || isODSReserved(processed_str))
- return processed_str + "_";
- return processed_str;
+ if (isPythonReserved(processedStr) || isODSReserved(processedStr))
+ return processedStr + "_";
+ return processedStr;
}
static std::string attrSizedTraitForKind(const char *kind) {
@@ -853,10 +847,6 @@ populateBuilderRegions(const Operator &op,
/// rebuild anew).
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) {
- // If we are asked to skip default builders, comply.
- if (op.skipDefaultBuilders())
- return {};
-
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
llvm::SmallVector<std::string> operandArgNames;
@@ -989,9 +979,6 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
static void emitValueBuilder(const Operator &op,
llvm::SmallVector<std::string> functionArgs,
raw_ostream &os) {
- // If we are asked to skip default builders, comply.
- if (op.skipDefaultBuilders())
- return;
// Params with (possibly) default args.
auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -1010,9 +997,9 @@ static void emitValueBuilder(const Operator &op,
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
- std::string name_without_dialect =
+ std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1);
- os << llvm::formatv(valueBuilderTemplate, sanitizeName(name_without_dialect),
+ os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
op.getCppClassName(),
llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
@@ -1051,11 +1038,8 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided");
- bool isExtension = !clDialectExtensionName.empty();
- os << llvm::formatv(fileHeader, isExtension
- ? clDialectExtensionName.getValue()
- : clDialectName.getValue());
- if (isExtension)
+ os << fileHeader;
+ if (!clDialectExtensionName.empty())
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
else
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
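
For reference, `emitValueBuilder` instantiates `valueBuilderTemplate` with the
snake-cased op name (via `convertToSnakeFromCamelCase` above), so for an op
such as `arith.addf` the emitted helper is roughly of the following shape
(reconstructed for illustration, not verbatim generator output):

    def addf(lhs, rhs, *, loc=None, ip=None):
        # Builds the OpView and unwraps single results to a Value.
        return _get_op_result_or_op_results(
            AddFOp(lhs=lhs, rhs=rhs, loc=loc, ip=ip)
        )
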