[Mlir-commits] [mlir] [mlir][python] remove mixins (PR #68847)
Maksim Levental
llvmlistbot at llvm.org
Wed Oct 11 21:36:36 PDT 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/68847
From 9857375418487ead7153db8fd8156837ff58feff Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 11 Oct 2023 23:32:14 -0500
Subject: [PATCH] [mlir][python] remove mixins
---
mlir/lib/Bindings/Python/Globals.h | 2 +-
mlir/lib/Bindings/Python/IRModule.cpp | 4 +-
mlir/lib/Bindings/Python/MainModule.cpp | 11 +-
mlir/python/CMakeLists.txt | 18 -
mlir/python/mlir/dialects/_arith_ops_ext.py | 69 --
.../mlir/dialects/_bufferization_ops_ext.py | 41 -
.../_bufferization_transform_ops_ext.py | 128 ---
mlir/python/mlir/dialects/_builtin_ops_ext.py | 20 -
mlir/python/mlir/dialects/_func_ops_ext.py | 319 --------
.../mlir/dialects/_gpu_transform_ops_ext.py | 124 ---
mlir/python/mlir/dialects/_linalg_ops_ext.py | 47 --
.../mlir/dialects/_loop_transform_ops_ext.py | 134 ---
mlir/python/mlir/dialects/_memref_ops_ext.py | 36 -
.../dialects/_memref_transform_ops_ext.py | 114 ---
.../mlir/dialects/_ml_program_ops_ext.py | 113 ---
mlir/python/mlir/dialects/_ods_common.py | 59 --
mlir/python/mlir/dialects/_pdl_ops_ext.py | 271 ------
mlir/python/mlir/dialects/_scf_ops_ext.py | 107 ---
.../dialects/_structured_transform_ops_ext.py | 759 -----------------
mlir/python/mlir/dialects/_tensor_ops_ext.py | 44 -
.../dialects/_tensor_transform_ops_ext.py | 64 --
.../mlir/dialects/_transform_ops_ext.py | 176 ----
.../_transform_pdl_extension_ops_ext.py | 55 --
mlir/python/mlir/dialects/arith.py | 67 ++
mlir/python/mlir/dialects/bufferization.py | 43 +
mlir/python/mlir/dialects/builtin.py | 24 +
mlir/python/mlir/dialects/func.py | 323 ++++++++
.../dialects/linalg/opdsl/lang/emitter.py | 2 +-
mlir/python/mlir/dialects/memref.py | 38 +
mlir/python/mlir/dialects/ml_program.py | 114 +++
mlir/python/mlir/dialects/pdl.py | 285 +++++++
mlir/python/mlir/dialects/scf.py | 115 ++-
mlir/python/mlir/dialects/tensor.py | 47 ++
.../mlir/dialects/transform/__init__.py | 185 +++++
.../mlir/dialects/transform/bufferization.py | 129 +++
mlir/python/mlir/dialects/transform/gpu.py | 125 +++
mlir/python/mlir/dialects/transform/loop.py | 140 ++++
mlir/python/mlir/dialects/transform/memref.py | 115 +++
mlir/python/mlir/dialects/transform/pdl.py | 50 ++
.../mlir/dialects/transform/structured.py | 773 ++++++++++++++++++
mlir/python/mlir/dialects/transform/tensor.py | 64 ++
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 9 +-
42 files changed, 2646 insertions(+), 2717 deletions(-)
delete mode 100644 mlir/python/mlir/dialects/_arith_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_bufferization_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_builtin_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_func_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_linalg_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_loop_transform_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_memref_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_memref_transform_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_ml_program_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_pdl_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_scf_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_structured_transform_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_tensor_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_transform_ops_ext.py
delete mode 100644 mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 97cd70089a2e965..dea44bbd469dd3d 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -80,7 +80,7 @@ class PyGlobals {
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
- pybind11::object pyClass);
+ pybind11::object pyClass, bool replace = false);
/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 2cc66277abee0f0..a1c8ab7a09ce155 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
- py::object pyClass) {
+ py::object pyClass, bool replace) {
py::object &found = operationClassMap[operationName];
- if (found) {
+ if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
"' is already registered.")
.str());
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index cdddfbe50606d05..a936becf67bea75 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
- "operation_name"_a, "operation_class"_a,
+ "operation_name"_a, "operation_class"_a, "replace"_a = false,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
@@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) {
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
- [](const py::object &dialectClass) -> py::cpp_function {
+ [](const py::object &dialectClass, bool replace) -> py::cpp_function {
return py::cpp_function(
- [dialectClass](py::object opClass) -> py::object {
+ [dialectClass, replace](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
- PyGlobals::get().registerOperationImpl(operationName, opClass);
+ PyGlobals::get().registerOperationImpl(operationName, opClass,
+ replace);
// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
@@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
- "dialect_class"_a,
+ "dialect_class"_a, "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
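The added `replace` flag is what lets hand-written dialect modules shadow the tablegen-generated op classes instead of mixing into them. A sketch of the resulting idiom, modeled on the updated dialect files in this PR (e.g. mlir/python/mlir/dialects/func.py; exact import names follow the in-tree convention and should be treated as an assumption):

    # Inside a dialect module such as mlir/python/mlir/dialects/func.py:
    from ._func_ops_gen import *
    from ._func_ops_gen import _Dialect
    from ._ods_common import _cext as _ods_cext

    # Subclass the generated op and re-register it under the same name.
    @_ods_cext.register_operation(_Dialect, replace=True)
    class ConstantOp(ConstantOp):
        """Specialization for the constant op class."""

        @property
        def type(self):
            return self.results[0].type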
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 088d9a765b97730..2eff1cc7c588d8a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -68,7 +68,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BufferizationOps.td
SOURCES
dialects/bufferization.py
- dialects/_bufferization_ops_ext.py
DIALECT_NAME bufferization
GEN_ENUM_BINDINGS_TD_FILE
"../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
@@ -80,7 +79,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/BuiltinOps.td
SOURCES
dialects/builtin.py
- dialects/_builtin_ops_ext.py
DIALECT_NAME builtin)
declare_mlir_dialect_python_bindings(
@@ -105,7 +103,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/FuncOps.td
SOURCES
dialects/func.py
- dialects/_func_ops_ext.py
DIALECT_NAME func)
declare_mlir_dialect_python_bindings(
@@ -121,7 +118,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgOps.td
SOURCES
- dialects/_linalg_ops_ext.py
SOURCES_GLOB
dialects/linalg/*.py
DIALECT_NAME linalg
@@ -142,7 +138,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformPDLExtensionOps.td
SOURCES
- dialects/_transform_pdl_extension_ops_ext.py
dialects/transform/pdl.py
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
@@ -152,7 +147,6 @@ declare_mlir_dialect_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformOps.td
SOURCES
- dialects/_transform_ops_ext.py
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform
@@ -165,7 +159,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/BufferizationTransformOps.td
SOURCES
- dialects/_bufferization_transform_ops_ext.py
dialects/transform/bufferization.py
DIALECT_NAME transform
EXTENSION_NAME bufferization_transform)
@@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/GPUTransformOps.td
SOURCES
- dialects/_gpu_transform_ops_ext.py
dialects/transform/gpu.py
DIALECT_NAME transform
EXTENSION_NAME gpu_transform)
@@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/SCFLoopTransformOps.td
SOURCES
- dialects/_loop_transform_ops_ext.py
dialects/transform/loop.py
DIALECT_NAME transform
EXTENSION_NAME loop_transform)
@@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/MemRefTransformOps.td
SOURCES
- dialects/_memref_transform_ops_ext.py
dialects/transform/memref.py
DIALECT_NAME transform
EXTENSION_NAME memref_transform)
@@ -214,7 +204,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/LinalgStructuredTransformOps.td
SOURCES
- dialects/_structured_transform_ops_ext.py
dialects/transform/structured.py
DIALECT_NAME transform
EXTENSION_NAME structured_transform
@@ -236,7 +225,6 @@ declare_mlir_dialect_extension_python_bindings(
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TensorTransformOps.td
SOURCES
- dialects/_tensor_transform_ops_ext.py
dialects/transform/tensor.py
DIALECT_NAME transform
EXTENSION_NAME tensor_transform)
@@ -266,7 +254,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/ArithOps.td
SOURCES
dialects/arith.py
- dialects/_arith_ops_ext.py
DIALECT_NAME arith
GEN_ENUM_BINDINGS)
@@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MemRefOps.td
SOURCES
dialects/memref.py
- dialects/_memref_ops_ext.py
DIALECT_NAME memref)
declare_mlir_dialect_python_bindings(
@@ -285,7 +271,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/MLProgramOps.td
SOURCES
dialects/ml_program.py
- dialects/_ml_program_ops_ext.py
DIALECT_NAME ml_program)
declare_mlir_dialect_python_bindings(
@@ -329,7 +314,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/PDLOps.td
SOURCES
dialects/pdl.py
- dialects/_pdl_ops_ext.py
_mlir_libs/_mlir/dialects/pdl.pyi
DIALECT_NAME pdl)
@@ -347,7 +331,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/SCFOps.td
SOURCES
dialects/scf.py
- dialects/_scf_ops_ext.py
DIALECT_NAME scf)
declare_mlir_dialect_python_bindings(
@@ -373,7 +356,6 @@ declare_mlir_dialect_python_bindings(
TD_FILE dialects/TensorOps.td
SOURCES
dialects/tensor.py
- dialects/_tensor_ops_ext.py
DIALECT_NAME tensor)
declare_mlir_dialect_python_bindings(
diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py
deleted file mode 100644
index df38f871710fe8f..000000000000000
--- a/mlir/python/mlir/dialects/_arith_ops_ext.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
-
- from typing import Any, List, Union
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-
-def _isa(obj: Any, cls: type):
- try:
- cls(obj)
- except ValueError:
- return False
- return True
-
-
-def _is_any_of(obj: Any, classes: List[type]):
- return any(_isa(obj, cls) for cls in classes)
-
-
-def _is_integer_like_type(type: Type):
- return _is_any_of(type, [IntegerType, IndexType])
-
-
-def _is_float_type(type: Type):
- return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
-
-
-class ConstantOp:
- """Specialization for the constant op class."""
-
- def __init__(
- self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
- ):
- if isinstance(value, int):
- super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
- elif isinstance(value, float):
- super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
- else:
- super().__init__(value, loc=loc, ip=ip)
-
- @classmethod
- def create_index(cls, value: int, *, loc=None, ip=None):
- """Create an index-typed constant."""
- return cls(
- IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
- )
-
- @property
- def type(self):
- return self.results[0].type
-
- @property
- def value(self):
- return Attribute(self.operation.attributes["value"])
-
- @property
- def literal_value(self) -> Union[int, float]:
- if _is_integer_like_type(self.type):
- return IntegerAttr(self.value).value
- elif _is_float_type(self.type):
- return FloatAttr(self.value).value
- else:
- raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py
deleted file mode 100644
index 1066cb4c775cab9..000000000000000
--- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from typing import Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context
-
- from typing import Any, List, Union
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-
-class AllocTensorOp:
- """Extends the bufferization.alloc_tensor op."""
-
- def __init__(
- self,
- tensor_type: Type,
- dynamic_sizes: Sequence[Value],
- copy: Value,
- size_hint: Value,
- escape: BoolAttr,
- *,
- loc=None,
- ip=None
- ):
- """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
- context = get_default_loc_context(loc)
- attributes = {}
- if escape:
- attributes["escape"] = escape
- op = self.build_generic(
- results=[tensor_type],
- operands=[dynamic_sizes, copy, size_hint],
- attributes=attributes,
- loc=loc,
- ip=ip,
- )
- OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
deleted file mode 100644
index 7e6c1b81cb350b7..000000000000000
--- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
+++ /dev/null
@@ -1,128 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from enum import Enum
-from typing import Optional, overload, Union
-
-
-class EmptyTensorToAllocTensorOp:
- """Specialization for EmptyTensorToAllocTensorOp class."""
-
- @overload
- def __init__(
- self,
- transformed_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
- ...
-
- def __init__(
- self,
- transformed_type_or_target: Type,
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- loc=None,
- ip=None
- ):
- if isinstance(transformed_type_or_target, Type):
- transformed_type = transformed_type_or_target
- target = target_or_none
- else:
- transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
- target = transformed_type_or_target
-
- super().__init__(
- transformed_type,
- target,
- loc=loc,
- ip=ip,
- )
-
-
-class OneShotBufferizeOp:
- """Specialization for OneShotBufferizeOp class."""
-
- @overload
- def __init__(
- self,
- transformed_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- allow_return_allocs_from_loops: Optional[bool] = None,
- allow_unknown_ops: Optional[bool] = None,
- bufferize_function_boundaries: Optional[bool] = None,
- function_boundary_type_conversion: Optional[Enum] = None,
- memcpy_op: Optional[str] = None,
- print_conflicts: Optional[bool] = None,
- test_analysis_only: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- allow_return_allocs_from_loops: Optional[bool] = None,
- allow_unknown_ops: Optional[bool] = None,
- bufferize_function_boundaries: Optional[bool] = None,
- function_boundary_type_conversion: Optional[Enum] = None,
- memcpy_op: Optional[str] = None,
- print_conflicts: Optional[bool] = None,
- test_analysis_only: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- transformed_type_or_target: Type,
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- allow_return_allocs_from_loops: Optional[bool] = None,
- allow_unknown_ops: Optional[bool] = None,
- bufferize_function_boundaries: Optional[bool] = None,
- function_boundary_type_conversion: Optional[Enum] = None,
- memcpy_op: Optional[str] = None,
- print_conflicts: Optional[bool] = None,
- test_analysis_only: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- if isinstance(transformed_type_or_target, Type):
- transformed_type = transformed_type_or_target
- target = target_or_none
- else:
- transformed_type = transform.AnyOpType.get()
- target = transformed_type_or_target
-
- super().__init__(
- transformed_type,
- target,
- allow_return_allocs_from_loops=allow_return_allocs_from_loops,
- allow_unknown_ops=allow_unknown_ops,
- bufferize_function_boundaries=bufferize_function_boundaries,
- function_boundary_type_conversion=function_boundary_type_conversion,
- memcpy_op=memcpy_op,
- print_conflicts=print_conflicts,
- test_analysis_only=test_analysis_only,
- loc=loc,
- ip=ip,
- )
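Both ops above accept either an explicit result type or none, in which case a default is used; this overload pattern recurs across the transform extensions below. A usage sketch, assuming the transform-dialect enum bindings and the `SequenceOp`/`YieldOp` helpers as exercised by the in-tree Python tests:

    from mlir.ir import Context, Location, InsertionPoint, Module
    from mlir.dialects import transform
    from mlir.dialects.transform import bufferization as bufferization_t

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            seq = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )
            with InsertionPoint(seq.body):
                # Short form: the result type defaults to
                # !transform.op<"bufferization.alloc_tensor">.
                alloc = bufferization_t.EmptyTensorToAllocTensorOp(seq.bodyTarget)
                transform.YieldOp()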
diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py
deleted file mode 100644
index 27a60123050acb4..000000000000000
--- a/mlir/python/mlir/dialects/_builtin_ops_ext.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-
-class ModuleOp:
- """Specialization for the module op class."""
-
- def __init__(self, *, loc=None, ip=None):
- super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
- body = self.regions[0].blocks.append()
-
- @property
- def body(self):
- return self.regions[0].blocks[0]
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py
deleted file mode 100644
index 6d264c33f1f9dae..000000000000000
--- a/mlir/python/mlir/dialects/_func_ops_ext.py
+++ /dev/null
@@ -1,319 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
-
- import inspect
-
- from typing import Any, List, Optional, Sequence, Union
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
-RESULT_ATTRIBUTE_NAME = "res_attrs"
-
-
-class ConstantOp:
- """Specialization for the constant op class."""
-
- def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
- super().__init__(result, value, loc=loc, ip=ip)
-
- @property
- def type(self):
- return self.results[0].type
-
-
-class FuncOp:
- """Specialization for the func op class."""
-
- def __init__(
- self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
- ):
- """
- Create a FuncOp with the provided `name`, `type`, and `visibility`.
- - `name` is a string representing the function name.
- - `type` is either a FunctionType or a pair of lists describing inputs
- and results.
- - `visibility` is a string matching `public`, `private`, or `nested`. None
- implies private visibility.
- - `body_builder` is an optional callback; when provided, a new entry
- block is created and the callback is invoked with the new op as its
- argument, within an InsertionPoint context already set for the block.
- The callback is expected to insert a terminator in the block.
- """
- sym_name = StringAttr.get(str(name))
-
- # If the type is passed as a tuple, build a FunctionType on the fly.
- if isinstance(type, tuple):
- type = FunctionType.get(inputs=type[0], results=type[1])
-
- type = TypeAttr.get(type)
- sym_visibility = (
- StringAttr.get(str(visibility)) if visibility is not None else None
- )
- super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
- if body_builder:
- entry_block = self.add_entry_block()
- with InsertionPoint(entry_block):
- body_builder(self)
-
- @property
- def is_external(self):
- return len(self.regions[0].blocks) == 0
-
- @property
- def body(self):
- return self.regions[0]
-
- @property
- def type(self):
- return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
- @property
- def visibility(self):
- return self.attributes["sym_visibility"]
-
- @property
- def name(self) -> StringAttr:
- return StringAttr(self.attributes["sym_name"])
-
- @property
- def entry_block(self):
- if self.is_external:
- raise IndexError("External function does not have a body")
- return self.regions[0].blocks[0]
-
- def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
- """
- Add an entry block to the function body using the function signature to
- infer block arguments.
- Returns the newly created block.
- """
- if not self.is_external:
- raise IndexError("The function already has an entry block!")
- self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
- return self.body.blocks[0]
-
- @property
- def arg_attrs(self):
- return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
- @arg_attrs.setter
- def arg_attrs(self, attribute: Union[ArrayAttr, list]):
- if isinstance(attribute, ArrayAttr):
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
- else:
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
- attribute, context=self.context
- )
-
- @property
- def arguments(self):
- return self.entry_block.arguments
-
- @property
- def result_attrs(self):
- return self.attributes[RESULT_ATTRIBUTE_NAME]
-
- @result_attrs.setter
- def result_attrs(self, attribute: ArrayAttr):
- self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
-
- @classmethod
- def from_py_func(
- FuncOp,
- *inputs: Type,
- results: Optional[Sequence[Type]] = None,
- name: Optional[str] = None,
- ):
- """Decorator to define an MLIR FuncOp specified as a python function.
-
- Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
- active for the current thread (i.e. established in a `with` block).
-
- When applied as a decorator to a Python function, an entry block will
- be constructed for the FuncOp with types as specified in `*inputs`. The
- block arguments will be passed positionally to the Python function. In
- addition, if the Python function accepts keyword arguments generally or
- has a corresponding keyword argument, the following will be passed:
- * `func_op`: The `func` op being defined.
-
- By default, the function name will be the Python function `__name__`. This
- can be overridden by passing the `name` argument to the decorator.
-
- If `results` is not specified, then the decorator will implicitly
- insert a `ReturnOp` with the `Value`'s returned from the decorated
- function. It will also set the `FuncOp` type with the actual return
- value types. If `results` is specified, then the decorated function
- must return `None` and no implicit `ReturnOp` is added (nor are the result
- types updated). The implicit behavior is intended for simple, single-block
- cases, and users should specify result types explicitly for any complicated
- cases.
-
- The decorated function can further be called from Python and will insert
- a `CallOp` at the then-current insertion point, returning either None
- (if there are no return values), a single Value (for one result), or a
- list of Values (for multiple results).
- This mechanism cannot be used to emit recursive calls (by construction).
- """
-
- def decorator(f):
- from . import func
-
- # Introspect the callable for optional features.
- sig = inspect.signature(f)
- has_arg_func_op = False
- for param in sig.parameters.values():
- if param.kind == param.VAR_KEYWORD:
- has_arg_func_op = True
- if param.name == "func_op" and (
- param.kind == param.POSITIONAL_OR_KEYWORD
- or param.kind == param.KEYWORD_ONLY
- ):
- has_arg_func_op = True
-
- # Emit the FuncOp.
- implicit_return = results is None
- symbol_name = name or f.__name__
- function_type = FunctionType.get(
- inputs=inputs, results=[] if implicit_return else results
- )
- func_op = FuncOp(name=symbol_name, type=function_type)
- with InsertionPoint(func_op.add_entry_block()):
- func_args = func_op.entry_block.arguments
- func_kwargs = {}
- if has_arg_func_op:
- func_kwargs["func_op"] = func_op
- return_values = f(*func_args, **func_kwargs)
- if not implicit_return:
- return_types = list(results)
- assert return_values is None, (
- "Capturing a python function with explicit `results=` "
- "requires that the wrapped function returns None."
- )
- else:
- # Coerce return values, add ReturnOp and rewrite func type.
- if return_values is None:
- return_values = []
- elif isinstance(return_values, tuple):
- return_values = list(return_values)
- elif isinstance(return_values, Value):
- # Returning a single value is fine, coerce it into a list.
- return_values = [return_values]
- elif isinstance(return_values, OpView):
- # Returning a single operation is fine, coerce its results into a list.
- return_values = return_values.operation.results
- elif isinstance(return_values, Operation):
- # Returning a single operation is fine, coerce its results into a list.
- return_values = return_values.results
- else:
- return_values = list(return_values)
- func.ReturnOp(return_values)
- # Recompute the function type.
- return_types = [v.type for v in return_values]
- function_type = FunctionType.get(
- inputs=inputs, results=return_types
- )
- func_op.attributes["function_type"] = TypeAttr.get(function_type)
-
- def emit_call_op(*call_args):
- call_op = func.CallOp(
- return_types, FlatSymbolRefAttr.get(symbol_name), call_args
- )
- if return_types is None:
- return None
- elif len(return_types) == 1:
- return call_op.result
- else:
- return call_op.results
-
- wrapped = emit_call_op
- wrapped.__name__ = f.__name__
- wrapped.func_op = func_op
- return wrapped
-
- return decorator
-
-
-class CallOp:
- """Specialization for the call op class."""
-
- def __init__(
- self,
- calleeOrResults: Union[FuncOp, List[Type]],
- argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
- arguments: Optional[List] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates an call operation.
-
- The constructor accepts three different forms:
-
- 1. A function op to be called, followed by a list of arguments.
- 2. A list of result types, followed by the name of the function to be
- called as a string, followed by a list of arguments.
- 3. A list of result types, followed by the name of the function to be
- called as a symbol reference attribute, followed by a list of arguments.
-
- For example:
-
- f = func.FuncOp("foo", ...)
- func.CallOp(f, [args])
- func.CallOp([result_types], "foo", [args])
-
- In all cases, the location and insertion point may be specified as keyword
- arguments if not provided by the surrounding context managers.
- """
-
- # TODO: consider supporting constructor "overloads", e.g., through a custom
- # or pybind-provided metaclass.
- if isinstance(calleeOrResults, FuncOp):
- if not isinstance(argumentsOrCallee, list):
- raise ValueError(
- "when constructing a call to a function, expected "
- + "the second argument to be a list of call arguments, "
- + f"got {type(argumentsOrCallee)}"
- )
- if arguments is not None:
- raise ValueError(
- "unexpected third argument when constructing a call"
- + "to a function"
- )
-
- super().__init__(
- calleeOrResults.type.results,
- FlatSymbolRefAttr.get(
- calleeOrResults.name.value, context=_get_default_loc_context(loc)
- ),
- argumentsOrCallee,
- loc=loc,
- ip=ip,
- )
- return
-
- if isinstance(argumentsOrCallee, list):
- raise ValueError(
- "when constructing a call to a function by name, "
- + "expected the second argument to be a string or a "
- + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
- )
-
- if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
- super().__init__(
- calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
- )
- elif isinstance(argumentsOrCallee, str):
- super().__init__(
- calleeOrResults,
- FlatSymbolRefAttr.get(
- argumentsOrCallee, context=_get_default_loc_context(loc)
- ),
- arguments,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
deleted file mode 100644
index ba72bac3a15264d..000000000000000
--- a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union, overload
-
-
-class MapForallToBlocks:
- """Specialization for MapForallToBlocks class."""
-
- @overload
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
- generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
- generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- result_type_or_target: Union[Operation, OpView, Type, Value],
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
- generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
- loc=None,
- ip=None
- ):
- if isinstance(result_type_or_target, Type):
- result_type = result_type_or_target
- target = target_or_none
- else:
- result_type = transform.AnyOpType.get()
- target = result_type_or_target
-
- super().__init__(
- result_type,
- target,
- grid_dims=grid_dims,
- generate_gpu_launch=generate_gpu_launch,
- loc=loc,
- ip=ip,
- )
-
-
-class MapNestedForallToThreads:
- """Specialization for MapNestedForallToThreads class."""
-
- @overload
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- block_dims: Optional[Sequence[int]] = None,
- warp_size: Optional[Sequence[int]] = None,
- sync_after_distribute: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- block_dims: Optional[Sequence[int]] = None,
- warp_size: Optional[Sequence[int]] = None,
- sync_after_distribute: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- result_type_or_target: Union[Operation, OpView, Value, Type],
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- block_dims: Optional[Union[Sequence[int], Attribute]] = None,
- warp_size: Optional[Union[Sequence[int], Attribute]] = None,
- sync_after_distribute: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- if isinstance(result_type_or_target, Type):
- result_type = result_type_or_target
- target = target_or_none
- else:
- result_type = result_type_or_target.type
- target = result_type_or_target
- super().__init__(
- result_type,
- target,
- block_dims=block_dims,
- warp_size=warp_size,
- sync_after_distribute=sync_after_distribute,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
deleted file mode 100644
index 3f6d854ca3e2b14..000000000000000
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from typing import Optional, Sequence, Union
- from ..ir import *
- from ._ods_common import get_default_loc_context
- from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-
-
-def isa(cls: Type, ty: Type):
- try:
- cls(ty)
- return True
- except ValueError:
- return False
-
-
-class StructuredOpMixin:
- """All structured ops use the same mixin class."""
-
- def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
- super().__init__(
- self.build_generic(
- results=list(results),
- operands=[list(inputs), list(outputs)],
- loc=loc,
- ip=ip,
- )
- )
-
-
-def select_opview_mixin(parent_opview_cls):
- # TODO: This shouldn't be a heuristic: we should have a way to annotate
- # the OpView to note that it is a structured op.
- if (
- "__init__" not in parent_opview_cls.__dict__
- and hasattr(parent_opview_cls, "inputs")
- and hasattr(parent_opview_cls, "outputs")
- and hasattr(parent_opview_cls, "result_tensors")
- ):
- return StructuredOpMixin
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
deleted file mode 100644
index 1cdb2b9e77b5afe..000000000000000
--- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Union
-
-
-class GetParentForOp:
- """Extension for GetParentForOp."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- num_loops: Optional[int] = None,
- ip=None,
- loc=None,
- ):
- if num_loops is None:
- num_loops = 1
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- num_loops=num_loops,
- ip=ip,
- loc=loc,
- )
-
-
-class LoopOutlineOp:
- """Extension for LoopOutlineOp."""
-
- def __init__(
- self,
- function_type: Type,
- call_type: Type,
- target: Union[Operation, Value],
- *,
- func_name: Union[str, StringAttr],
- ip=None,
- loc=None,
- ):
- super().__init__(
- function_type,
- call_type,
- _get_op_result_or_value(target),
- func_name=(
- func_name
- if isinstance(func_name, StringAttr)
- else StringAttr.get(func_name)
- ),
- ip=ip,
- loc=loc,
- )
-
-
-class LoopPeelOp:
- """Extension for LoopPeelOp."""
-
- def __init__(
- self,
- main_loop_type: Type,
- remainder_loop_type: Type,
- target: Union[Operation, Value],
- *,
- fail_if_already_divisible: Union[bool, BoolAttr] = False,
- ip=None,
- loc=None,
- ):
- super().__init__(
- main_loop_type,
- remainder_loop_type,
- _get_op_result_or_value(target),
- fail_if_already_divisible=(
- fail_if_already_divisible
- if isinstance(fail_if_already_divisible, BoolAttr)
- else BoolAttr.get(fail_if_already_divisible)
- ),
- ip=ip,
- loc=loc,
- )
-
-
-class LoopPipelineOp:
- """Extension for LoopPipelineOp."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- iteration_interval: Optional[Union[int, IntegerAttr]] = None,
- read_latency: Optional[Union[int, IntegerAttr]] = None,
- ip=None,
- loc=None,
- ):
- if iteration_interval is None:
- iteration_interval = 1
- if read_latency is None:
- read_latency = 10
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- iteration_interval=iteration_interval,
- read_latency=read_latency,
- ip=ip,
- loc=loc,
- )
-
-
-class LoopUnrollOp:
- """Extension for LoopUnrollOp."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- factor: Union[int, IntegerAttr],
- ip=None,
- loc=None,
- ):
- super().__init__(
- _get_op_result_or_value(target),
- factor=factor,
- ip=ip,
- loc=loc,
- )
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py
deleted file mode 100644
index 825f1a0a7a6faf4..000000000000000
--- a/mlir/python/mlir/dialects/_memref_ops_ext.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ._ods_common import get_op_results_or_values as _get_op_results_or_values
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-class LoadOp:
- """Specialization for the MemRef load operation."""
-
- def __init__(
- self,
- memref: Union[Operation, OpView, Value],
- indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None
- ):
- """Creates a memref load operation.
-
- Args:
- memref: the buffer to load from.
- indices: the list of subscripts, may be empty for zero-dimensional
- buffers.
- loc: user-visible location of the operation.
- ip: insertion point.
- """
- indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
- super().__init__(memref, indices_resolved, loc=loc, ip=ip)
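A usage sketch for the `LoadOp` specialization (indices may be omitted for zero-dimensional buffers); the `load_elem` wrapper is only illustrative:

    from mlir.ir import Context, Location, InsertionPoint, Module, F32Type, MemRefType
    from mlir.dialects import arith, func, memref

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            memref_ty = MemRefType.get([2, 2], F32Type.get())

            @func.FuncOp.from_py_func(memref_ty)
            def load_elem(buf):
                i = arith.ConstantOp.create_index(0)
                return memref.LoadOp(buf, [i, i])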
diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
deleted file mode 100644
index 1cc00bdcbf381c9..000000000000000
--- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, overload, Union
-
-
-class MemRefAllocaToGlobalOp:
- """Specialization for MemRefAllocaToGlobalOp class."""
-
- @overload
- def __init__(
- self,
- get_global_type: Type,
- global_type: Type,
- alloca: Union[Operation, OpView, Value],
- *,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
- ...
-
- def __init__(
- self,
- get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
- global_type_or_none: Optional[Type] = None,
- alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- loc=None,
- ip=None
- ):
- if isinstance(get_global_type_or_alloca, Type):
- get_global_type = get_global_type_or_alloca
- global_type = global_type_or_none
- alloca = alloca_or_none
- else:
- get_global_type = transform.AnyOpType.get()
- global_type = transform.AnyOpType.get()
- alloca = get_global_type_or_alloca
-
- super().__init__(
- get_global_type,
- global_type,
- alloca,
- loc=loc,
- ip=ip,
- )
-
-
-class MemRefMultiBufferOp:
- """Specialization for MemRefMultiBufferOp class."""
-
- @overload
- def __init__(
- self,
- transformed_type: Type,
- target: Union[Operation, OpView, Value],
- factor: Union[int, IntegerAttr],
- *,
- skip_analysis: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- factor: Union[int, IntegerAttr],
- *,
- skip_analysis: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- transformed_type_or_target: Type,
- target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None,
- factor_or_none: Optional[Union[int, IntegerAttr]] = None,
- *,
- skip_analysis: Optional[bool] = None,
- loc=None,
- ip=None
- ):
- if isinstance(transformed_type_or_target, Type):
- transformed_type = transformed_type_or_target
- target = target_or_factor
- factor = factor_or_none
- else:
- transformed_type = transform.AnyOpType.get()
- target = transformed_type_or_target
- factor = target_or_factor
-
- super().__init__(
- transformed_type,
- target,
- factor,
- skip_analysis=skip_analysis,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py
deleted file mode 100644
index c84d23c16ef93ab..000000000000000
--- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from typing import Union
- from ..ir import *
- from ._ods_common import get_default_loc_context as _get_default_loc_context
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from ._ml_program_ops_gen import *
-
-
-ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
-RESULT_ATTRIBUTE_NAME = "res_attrs"
-
-
-class FuncOp:
- """Specialization for the func op class."""
-
- def __init__(
- self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
- ):
- """
- Create a FuncOp with the provided `name`, `type`, and `visibility`.
- - `name` is a string representing the function name.
- - `type` is either a FunctionType or a pair of lists describing inputs
- and results.
- - `visibility` is a string matching `public`, `private`, or `nested`. None
- implies private visibility.
- - `body_builder` is an optional callback; when provided, a new entry
- block is created and the callback is invoked with the new op as its
- argument, within an InsertionPoint context already set for the block.
- The callback is expected to insert a terminator in the block.
- """
- sym_name = StringAttr.get(str(name))
-
- # If the type is passed as a tuple, build a FunctionType on the fly.
- if isinstance(type, tuple):
- type = FunctionType.get(inputs=type[0], results=type[1])
-
- type = TypeAttr.get(type)
- sym_visibility = (
- StringAttr.get(str(visibility)) if visibility is not None else None
- )
- super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
- if body_builder:
- entry_block = self.add_entry_block()
- with InsertionPoint(entry_block):
- body_builder(self)
-
- @property
- def is_external(self):
- return len(self.regions[0].blocks) == 0
-
- @property
- def body(self):
- return self.regions[0]
-
- @property
- def type(self):
- return FunctionType(TypeAttr(self.attributes["function_type"]).value)
-
- @property
- def visibility(self):
- return self.attributes["sym_visibility"]
-
- @property
- def name(self) -> StringAttr:
- return StringAttr(self.attributes["sym_name"])
-
- @property
- def entry_block(self):
- if self.is_external:
- raise IndexError("External function does not have a body")
- return self.regions[0].blocks[0]
-
- def add_entry_block(self):
- """
- Add an entry block to the function body using the function signature to
- infer block arguments.
- Returns the newly created block.
- """
- if not self.is_external:
- raise IndexError("The function already has an entry block!")
- self.body.blocks.append(*self.type.inputs)
- return self.body.blocks[0]
-
- @property
- def arg_attrs(self):
- return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
-
- @arg_attrs.setter
- def arg_attrs(self, attribute: Union[ArrayAttr, list]):
- if isinstance(attribute, ArrayAttr):
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
- else:
- self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
- attribute, context=self.context
- )
-
- @property
- def arguments(self):
- return self.entry_block.arguments
-
- @property
- def result_attrs(self):
- return self.attributes[RESULT_ATTRIBUTE_NAME]
-
- @result_attrs.setter
- def result_attrs(self, attribute: ArrayAttr):
- self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 895c3228139b392..9cca7d659ec8cb3 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -9,7 +9,6 @@
__all__ = [
"equally_sized_accessor",
- "extend_opview_class",
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
@@ -18,64 +17,6 @@
]
-def extend_opview_class(ext_module):
- """Decorator to extend an OpView class from an extension module.
-
- Extension modules can expose various entry-points:
- A stand-alone class with the same name as a parent OpView class (e.g.
- "ReturnOp"). A name-based match is attempted first before falling back
- to the mechanism below.
-
- def select_opview_mixin(parent_opview_cls):
- If defined, allows an appropriate mixin class to be selected dynamically
- based on the parent OpView class. Should return NotImplemented if a
- decision is not made.
-
- Args:
- ext_module: A module from which to locate extensions. Can be None if not
- available.
-
- Returns:
- A decorator that takes an OpView subclass and further extends it as
- needed.
- """
-
- def class_decorator(parent_opview_cls: type):
- if ext_module is None:
- return parent_opview_cls
- mixin_cls = NotImplemented
- # First try to resolve by name.
- try:
- mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
- except AttributeError:
- # Fall back to a select_opview_mixin hook.
- try:
- select_mixin = getattr(ext_module, "select_opview_mixin")
- except AttributeError:
- pass
- else:
- mixin_cls = select_mixin(parent_opview_cls)
-
- if mixin_cls is NotImplemented or mixin_cls is None:
- return parent_opview_cls
-
- # Have a mixin_cls. Create an appropriate subclass.
- try:
-
- class LocalOpView(mixin_cls, parent_opview_cls):
- pass
-
- except TypeError as e:
- raise TypeError(
- f"Could not mixin {mixin_cls} into {parent_opview_cls}"
- ) from e
- LocalOpView.__name__ = parent_opview_cls.__name__
- LocalOpView.__qualname__ = parent_opview_cls.__qualname__
- return LocalOpView
-
- return class_decorator
-
-
def segmented_accessor(elements, raw_segments, idx):
"""
Returns a slice of elements corresponding to the idx-th segment.
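For context, here is a self-contained toy version of the name-based splice that `extend_opview_class` performed (demo classes only, no MLIR imports). With op classes now subclassing their generated counterparts and re-registering via `register_operation(..., replace=True)`, this indirection is no longer needed:

    import types

    def extend_opview_class(ext_module):
        # Name-based splice, as in the helper removed above.
        def class_decorator(parent_cls):
            mixin = getattr(ext_module, parent_cls.__name__, None)
            if mixin is None:
                return parent_cls
            local = type(parent_cls.__name__, (mixin, parent_cls), {})
            local.__qualname__ = parent_cls.__qualname__
            return local
        return class_decorator

    class _AddOpMixin:
        def describe(self):
            return f"extended {self.OPERATION_NAME}"

    ext = types.SimpleNamespace(AddOp=_AddOpMixin)

    @extend_opview_class(ext)
    class AddOp:  # stand-in for a tablegen-generated OpView subclass
        OPERATION_NAME = "demo.add"

    print(AddOp().describe())  # -> extended demo.add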
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py
deleted file mode 100644
index fc9de0b7f7db69c..000000000000000
--- a/mlir/python/mlir/dialects/_pdl_ops_ext.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import pdl
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Union, Optional, Sequence, Mapping
-from ._ods_common import (
- get_op_result_or_value as _get_value,
- get_op_results_or_values as _get_values,
-)
-
-
-class ApplyNativeConstraintOp:
- """Specialization for PDL apply native constraint op class."""
-
- def __init__(
- self,
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(name, args, loc=loc, ip=ip)
-
-
-class ApplyNativeRewriteOp:
- """Specialization for PDL apply native rewrite op class."""
-
- def __init__(
- self,
- results: Sequence[Type],
- name: Union[str, StringAttr],
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- args = _get_values(args)
- super().__init__(results, name, args, loc=loc, ip=ip)
-
-
-class AttributeOp:
- """Specialization for PDL attribute op class."""
-
- def __init__(
- self,
- valueType: Optional[Union[OpView, Operation, Value]] = None,
- value: Optional[Attribute] = None,
- *,
- loc=None,
- ip=None,
- ):
- valueType = valueType if valueType is None else _get_value(valueType)
- result = pdl.AttributeType.get()
- super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
-
-
-class EraseOp:
- """Specialization for PDL erase op class."""
-
- def __init__(
- self,
- operation: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- operation = _get_value(operation)
- super().__init__(operation, loc=loc, ip=ip)
-
-
-class OperandOp:
- """Specialization for PDL operand op class."""
-
- def __init__(
- self,
- type: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- type = type if type is None else _get_value(type)
- result = pdl.ValueType.get()
- super().__init__(result, valueType=type, loc=loc, ip=ip)
-
-
-class OperandsOp:
- """Specialization for PDL operands op class."""
-
- def __init__(
- self,
- types: Optional[Union[OpView, Operation, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- types = types if types is None else _get_value(types)
- result = pdl.RangeType.get(pdl.ValueType.get())
- super().__init__(result, valueType=types, loc=loc, ip=ip)
-
-
-class OperationOp:
- """Specialization for PDL operand op class."""
-
- def __init__(
- self,
- name: Optional[Union[str, StringAttr]] = None,
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
- types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if types is None:
- types = []
- if attributes is None:
- attributes = {}
- if args is None:
- args = []
- args = _get_values(args)
- attrNames = []
- attrValues = []
- for attrName, attrValue in attributes.items():
- attrNames.append(StringAttr.get(attrName))
- attrValues.append(_get_value(attrValue))
- attrNames = ArrayAttr.get(attrNames)
- types = _get_values(types)
- result = pdl.OperationType.get()
- super().__init__(
- result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
- )
-
-
-class PatternOp:
- """Specialization for PDL pattern op class."""
-
- def __init__(
- self,
- benefit: Union[IntegerAttr, int],
- name: Optional[Union[StringAttr, str]] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates an PDL `pattern` operation."""
- super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
- self.regions[0].blocks.append()
-
- @property
- def body(self):
- """Return the body (block) of the pattern."""
- return self.regions[0].blocks[0]
-
-
-class ReplaceOp:
- """Specialization for PDL replace op class."""
-
- def __init__(
- self,
- op: Union[OpView, Operation, Value],
- *,
- with_op: Optional[Union[OpView, Operation, Value]] = None,
- with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- loc=None,
- ip=None,
- ):
- if with_values is None:
- with_values = []
- op = _get_value(op)
- with_op = with_op if with_op is None else _get_value(with_op)
- with_values = _get_values(with_values)
- super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
-
-
-class ResultOp:
- """Specialization for PDL result op class."""
-
- def __init__(
- self,
- parent: Union[OpView, Operation, Value],
- index: Union[IntegerAttr, int],
- *,
- loc=None,
- ip=None,
- ):
- parent = _get_value(parent)
- result = pdl.ValueType.get()
- super().__init__(result, parent, index, loc=loc, ip=ip)
-
-
-class ResultsOp:
- """Specialization for PDL results op class."""
-
- def __init__(
- self,
- result: Type,
- parent: Union[OpView, Operation, Value],
- index: Optional[Union[IntegerAttr, int]] = None,
- *,
- loc=None,
- ip=None,
- ):
- parent = _get_value(parent)
- super().__init__(result, parent, index=index, loc=loc, ip=ip)
-
-
-class RewriteOp:
- """Specialization for PDL rewrite op class."""
-
- def __init__(
- self,
- root: Optional[Union[OpView, Operation, Value]] = None,
- name: Optional[Union[StringAttr, str]] = None,
- args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if args is None:
- args = []
- root = root if root is None else _get_value(root)
- args = _get_values(args)
- super().__init__(args, root=root, name=name, loc=loc, ip=ip)
-
- def add_body(self):
- """Add body (block) to the rewrite."""
- self.regions[0].blocks.append()
- return self.body
-
- @property
- def body(self):
- """Return the body (block) of the rewrite."""
- return self.regions[0].blocks[0]
-
-
-class TypeOp:
- """Specialization for PDL type op class."""
-
- def __init__(
- self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
- ):
- result = pdl.TypeType.get()
- super().__init__(result, constantType=constantType, loc=loc, ip=ip)
-
-
-class TypesOp:
- """Specialization for PDL types op class."""
-
- def __init__(
- self,
- constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if constantTypes is None:
- constantTypes = []
- result = pdl.RangeType.get(pdl.TypeType.get())
- super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
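These wrappers move to mlir/python/mlir/dialects/pdl.py. A usage sketch building a tiny pattern (the "my_rewriter" name stands for a hypothetical externally registered rewriter):

    from mlir.ir import Context, Location, InsertionPoint, Module
    from mlir.dialects import pdl

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            pattern = pdl.PatternOp(1, "replace_addi")
            with InsertionPoint(pattern.body):
                ty = pdl.TypeOp()
                root = pdl.OperationOp(name="arith.addi", types=[ty])
                pdl.RewriteOp(root, name="my_rewriter")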
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
deleted file mode 100644
index 89cc8a19895c7b4..000000000000000
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
-)
-
-
-class ForOp:
- """Specialization for the SCF for op class."""
-
- def __init__(
- self,
- lower_bound,
- upper_bound,
- step,
- iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- """Creates an SCF `for` operation.
-
- - `lower_bound` is the value to use as lower bound of the loop.
- - `upper_bound` is the value to use as upper bound of the loop.
- - `step` is the value to use as loop step.
- - `iter_args` is a list of additional loop-carried arguments or an operation
- producing them as results.
- """
- if iter_args is None:
- iter_args = []
- iter_args = _get_op_results_or_values(iter_args)
-
- results = [arg.type for arg in iter_args]
- super().__init__(
- self.build_generic(
- regions=1,
- results=results,
- operands=[
- _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
- ]
- + list(iter_args),
- loc=loc,
- ip=ip,
- )
- )
- self.regions[0].blocks.append(self.operands[0].type, *results)
-
- @property
- def body(self):
- """Returns the body (block) of the loop."""
- return self.regions[0].blocks[0]
-
- @property
- def induction_variable(self):
- """Returns the induction variable of the loop."""
- return self.body.arguments[0]
-
- @property
- def inner_iter_args(self):
- """Returns the loop-carried arguments usable within the loop.
-
- To obtain the loop-carried operands, use `iter_args`.
- """
- return self.body.arguments[1:]
-
-
-class IfOp:
- """Specialization for the SCF if op class."""
-
- def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
- """Creates an SCF `if` operation.
-
- - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
- - `hasElse` determines whether the if operation has the else branch.
- """
- operands = []
- operands.append(cond)
- results = []
- results.extend(results_)
- super().__init__(
- self.build_generic(
- regions=2, results=results, operands=operands, loc=loc, ip=ip
- )
- )
- self.regions[0].blocks.append(*[])
- if hasElse:
- self.regions[1].blocks.append(*[])
-
- @property
- def then_block(self):
- """Returns the then block of the if operation."""
- return self.regions[0].blocks[0]
-
- @property
- def else_block(self):
- """Returns the else block of the if operation."""
- return self.regions[1].blocks[0]
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
deleted file mode 100644
index 3757a3d3b4cce85..000000000000000
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ /dev/null
@@ -1,759 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import List, Optional, Sequence, Tuple, Union, overload
-
-StaticIntLike = Union[int, IntegerAttr]
-ValueLike = Union[Operation, OpView, Value]
-MixedInt = Union[StaticIntLike, ValueLike]
-
-IntOrAttrList = Sequence[Union[IntegerAttr, int]]
-OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
-
-BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
-OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
-
-MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
-
-DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
-
-
-def _dispatch_dynamic_index_list(
- indices: Union[DynamicIndexList, ArrayAttr],
-) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
- """Dispatches a list of indices to the appropriate form.
-
- This is similar to the custom `DynamicIndexList` directive upstream:
- provided indices may be in the form of dynamic SSA values or static values,
- and they may be scalable (i.e., as a singleton list) or not. This function
- dispatches each index into its respective form. It also extracts the SSA
- values and static indices from various similar structures, respectively.
- """
- dynamic_indices = []
- static_indices = [ShapedType.get_dynamic_size()] * len(indices)
- scalable_indices = [False] * len(indices)
-
- # ArrayAttr: Extract index values.
- if isinstance(indices, ArrayAttr):
- indices = [idx for idx in indices]
-
- def process_nonscalable_index(i, index):
- """Processes any form of non-scalable index.
-
- Returns False if the given index was scalable and thus remains
- unprocessed; True otherwise.
- """
- if isinstance(index, int):
- static_indices[i] = index
- elif isinstance(index, IntegerAttr):
- static_indices[i] = index.value # pytype: disable=attribute-error
- elif isinstance(index, (Operation, Value, OpView)):
- dynamic_indices.append(index)
- else:
- return False
- return True
-
- # Process each index at a time.
- for i, index in enumerate(indices):
- if not process_nonscalable_index(i, index):
- # If it wasn't processed, it must be a scalable index, which is
- # provided as a Sequence of one value, so extract and process that.
- scalable_indices[i] = True
- assert len(index) == 1
- ret = process_nonscalable_index(i, index[0])
- assert ret
-
- return dynamic_indices, static_indices, scalable_indices
-
-
-# Dispatches `MixedValues` that all represents integers in various forms into
-# the following three categories:
-# - `dynamic_values`: a list of `Value`s, potentially from op results;
-# - `packed_values`: a value handle, potentially from an op result, associated
-# to one or more payload operations of integer type;
-# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
-# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
-# The input is in the form for `packed_values`, only that result is set and the
-# other two are empty. Otherwise, the input can be a mix of the other two forms,
-# and for each dynamic value, a special value is added to the `static_values`.
-def _dispatch_mixed_values(
- values: MixedValues,
-) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
- dynamic_values = []
- packed_values = None
- static_values = None
- if isinstance(values, ArrayAttr):
- static_values = values
- elif isinstance(values, (Operation, Value, OpView)):
- packed_values = values
- else:
- static_values = []
- for size in values or []:
- if isinstance(size, int):
- static_values.append(size)
- else:
- static_values.append(ShapedType.get_dynamic_size())
- dynamic_values.append(size)
- static_values = DenseI64ArrayAttr.get(static_values)
-
- return (dynamic_values, packed_values, static_values)
-
-
-def _get_value_or_attribute_value(
- value_or_attr: Union[any, Attribute, ArrayAttr]
-) -> any:
- if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
- return value_or_attr.value
- if isinstance(value_or_attr, ArrayAttr):
- return _get_value_list(value_or_attr)
- return value_or_attr
-
-
-def _get_value_list(
- sequence_or_array_attr: Union[Sequence[any], ArrayAttr]
-) -> Sequence[any]:
- return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
-
-
-def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
- if values is None:
- return None
-
- # Turn into a Python list of Python ints.
- values = _get_value_list(values)
-
- # Make an ArrayAttr of IntegerAttrs out of it.
- return ArrayAttr.get(
- [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
- )
-
-
-def _get_int_array_array_attr(
- values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
-) -> ArrayAttr:
- """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
-
- The input has to be a collection of collection of integers, where any
- Python Sequence and ArrayAttr are admissible collections and Python ints and
- any IntegerAttr are admissible integers. Both levels of collections are
- turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
- If the input is None, an empty ArrayAttr is returned.
- """
- if values is None:
- return None
-
- # Make sure the outer level is a list.
- values = _get_value_list(values)
-
- # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
- # Sequences. Make sure the nested values are all lists.
- values = [_get_value_list(nested) for nested in values]
-
- # Turn each nested list into an ArrayAttr.
- values = [_get_int_array_attr(nested) for nested in values]
-
- # Turn the outer list into an ArrayAttr.
- return ArrayAttr.get(values)
-
-
-class BufferizeToAllocationOp:
- """Specialization for BufferizeToAllocationOp class."""
-
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- memory_space: Optional[Union[int, str, Attribute]] = None,
- memcpy_op: Optional[str] = None,
- alloc_op: Optional[str] = None,
- bufferize_destination_only: Optional[bool] = None,
- loc=None,
- ip=None,
- ):
- # No other types are allowed, so hard-code those here.
- allocated_buffer_type = transform.AnyValueType.get()
- new_ops_type = transform.AnyOpType.get()
-
- if isinstance(memory_space, int):
- memory_space = str(memory_space)
- if isinstance(memory_space, str):
- memory_space = Attribute.parse(memory_space)
-
- super().__init__(
- allocated_buffer_type,
- new_ops_type,
- target,
- memory_space=memory_space,
- memcpy_op=memcpy_op,
- alloc_op=alloc_op,
- bufferize_destination_only=bufferize_destination_only,
- loc=loc,
- ip=ip,
- )
-
-
-class DecomposeOp:
- """Specialization for DecomposeOp class."""
-
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- transformed_type = transform.AnyOpType.get()
- super().__init__(transformed_type, target, loc=loc, ip=ip)
-
-
-class FuseIntoContainingOp:
- """Specialization for FuseIntoContainingOp class."""
-
- @overload
- def __init__(
- self,
- fused_op_type: Type,
- new_containing_op_type: Type,
- producer_op: Union[Operation, OpView, Value],
- containing_op: Union[Operation, OpView, Value],
- *,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- producer_op: Union[Operation, OpView, Value],
- containing_op: Union[Operation, OpView, Value],
- *,
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
- new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
- producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
- containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if isinstance(fused_op_type_or_producer_op, Type):
- if not isinstance(new_containing_op_type_or_containing_op, Type):
- raise TypeError(
- "If 'fused_op_type_or_producer_op' is a type, then "
- "'new_containing_op_type_or_containing_op' is expected "
- "to be one as well."
- )
- fused_op_type = fused_op_type_or_producer_op
- new_containing_op_type = new_containing_op_type_or_containing_op
- producer_op = producer_op_or_none
- containing_op = containing_op_or_none
- else:
- fused_op_type = transform.AnyOpType.get()
- new_containing_op_type = transform.AnyOpType.get()
- producer_op = fused_op_type_or_producer_op
- containing_op = new_containing_op_type_or_containing_op
-
- super().__init__(
- fused_op_type,
- new_containing_op_type,
- producer_op,
- containing_op,
- loc=loc,
- ip=ip,
- )
-
-
-class GeneralizeOp:
- """Specialization for GeneralizeOp class."""
-
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- transformed_type = transform.AnyOpType.get()
- super().__init__(transformed_type, target, loc=loc, ip=ip)
-
-
-class InterchangeOp:
- """Specialization for InterchangeOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- iterator_interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- transformed_type = transform.AnyOpType.get()
- super().__init__(
- transformed_type,
- target,
- iterator_interchange=iterator_interchange,
- loc=loc,
- ip=ip,
- )
-
-
-class MapCopyToThreadsOp:
- """Specialization for MapCopyToThreadsOp class."""
-
- @overload
- def __init__(
- self,
- forall_op_type: Type,
- tiled_op_type: Type,
- target: Union[Operation, OpView, Value],
- *,
- total_num_threads: Union[int, IntegerAttr],
- desired_bit_alignment: Union[int, IntegerAttr],
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- total_num_threads: Union[int, IntegerAttr],
- desired_bit_alignment: Union[int, IntegerAttr],
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- forall_op_type_or_target: Union[Operation, OpView, Type, Value],
- tiled_op_type_or_none: Optional[Type] = None,
- target_or_none: Optional[Union[Operation, OpView, Value]] = None,
- *,
- total_num_threads: Union[int, IntegerAttr],
- desired_bit_alignment: Union[int, IntegerAttr],
- loc=None,
- ip=None,
- ):
- if isinstance(forall_op_type_or_target, Type):
- forall_op_type = forall_op_type_or_target
- tiled_op_type = tiled_op_type_or_none
- target = target_or_none
- else:
- forall_op_type = transform.AnyOpType.get()
- tiled_op_type = transform.AnyOpType.get()
- target = forall_op_type_or_target
-
- super().__init__(
- forall_op_type,
- tiled_op_type,
- target,
- total_num_threads=total_num_threads,
- desired_bit_alignment=desired_bit_alignment,
- loc=loc,
- ip=ip,
- )
-
-
-class VectorizeOp:
- """Specialization for VectorizeOp class."""
-
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- *,
- vectorize_nd_extract: Optional[bool] = None,
- scalable_sizes: OptionalBoolList = None,
- static_vector_sizes: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- if (
- scalable_sizes is None
- and static_vector_sizes is None
- and vector_sizes is None
- ):
- dynamic_vector_sizes = []
- elif scalable_sizes is None and static_vector_sizes is None:
- (
- dynamic_vector_sizes,
- static_vector_sizes,
- scalable_sizes,
- ) = _dispatch_dynamic_index_list(vector_sizes)
- elif scalable_sizes is None or static_vector_sizes is None:
- raise TypeError(
- "'scalable_sizes' and 'static_vector_sizes' must either both "
- "be given explicitly or both be given as part of 'vector_sizes'."
- )
- else:
- dynamic_vector_sizes = vector_sizes
-
- super().__init__(
- target,
- vector_sizes=dynamic_vector_sizes,
- static_vector_sizes=static_vector_sizes,
- scalable_sizes=scalable_sizes,
- vectorize_nd_extract=vectorize_nd_extract,
- loc=loc,
- ip=ip,
- )
-
-
-class MatchOp:
- """Specialization for MatchOp class."""
-
- @overload
- @classmethod
- def match_op_names(
- cls,
- target: Union[Operation, Value],
- names: Union[str, Sequence[str]],
- *,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- @classmethod
- def match_op_names(
- cls,
- result_type: Type,
- target: Union[Operation, Value],
- names: Union[str, Sequence[str]],
- *,
- loc=None,
- ip=None,
- ):
- ...
-
- @classmethod
- def match_op_names(
- cls,
- result_type_or_target: Union[Type, Operation, Value],
- target_or_names: Union[Operation, Value, Sequence[str], str],
- names_or_none: Optional[Union[Sequence[str], str]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if isinstance(result_type_or_target, Type):
- result_type = result_type_or_target
- target = target_or_names
- names = names_or_none
- else:
- result_type = transform.AnyOpType.get()
- target = result_type_or_target
- names = target_or_names
-
- if isinstance(names, str):
- names = [names]
-
- return cls(
- result_type,
- target,
- ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))),
- loc=loc,
- ip=ip,
- )
-
-
-class MultiTileSizesOp:
- """Specialization for MultiTileSizesOp class."""
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- dimension: Union[int, IntegerAttr],
- target_size: Union[int, IntegerAttr],
- divisor: Optional[Optional[Union[int, IntegerAttr]]] = None,
- loc=None,
- ip=None,
- ):
- super().__init__(
- result_type,
- result_type,
- result_type,
- target,
- dimension=dimension,
- target_size=target_size,
- divisor=divisor,
- loc=loc,
- ip=ip,
- )
-
-
-class PadOp:
- """Specialization for PadOp class."""
-
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- *,
- padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
- padding_dimensions: OptionalIntList = None,
- pad_to_multiple_of: OptionalIntList = None,
- pack_paddings: OptionalIntList = None,
- transpose_paddings: Optional[
- Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
- ] = None,
- copy_back_op: Optional[Union[str, StringAttr]] = None,
- loc=None,
- ip=None,
- ):
- transpose_paddings = _get_int_array_array_attr(transpose_paddings)
-
- any_op_type = transform.AnyOpType.get()
- super().__init__(
- any_op_type,
- any_op_type,
- any_op_type,
- target,
- padding_values=padding_values,
- padding_dimensions=padding_dimensions,
- pad_to_multiple_of=pad_to_multiple_of,
- pack_paddings=pack_paddings,
- transpose_paddings=transpose_paddings,
- copy_back_op=copy_back_op,
- loc=loc,
- ip=ip,
- )
-
-
-class ScalarizeOp:
- """Specialization for ScalarizeOp class."""
-
- def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
- result_type = transform.AnyOpType.get()
- super().__init__(result_type, target, loc=loc, ip=ip)
-
-
-class SplitOp:
- """Specialization for SplitOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- dimension: Union[int, Attribute],
- split_point: Union[int, Operation, Value, Attribute],
- *,
- loc=None,
- ip=None,
- ):
- if isinstance(split_point, int):
- static_split_point = split_point
- dynamic_split_point = None
- else:
- static_split_point = ShapedType.get_dynamic_size()
- dynamic_split_point = split_point
-
- super().__init__(
- target.type,
- target.type,
- target,
- dimension=dimension,
- static_split_point=static_split_point,
- dynamic_split_point=dynamic_split_point,
- loc=loc,
- ip=ip,
- )
-
-
-class TileUsingForOp:
- """Specialization for TileUsingForOp class."""
-
- @overload
- def __init__(
- self,
- loop_types: Union[Type, List[Type]],
- target: Union[Operation, Value],
- *,
- sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, Value, OpView],
- *,
- sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- loop_types_or_target: Union[Type, List[Type], Operation, Value],
- target_or_none: Optional[Union[Operation, Value, OpView]] = None,
- *,
- sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
- interchange: OptionalIntList = None,
- loc=None,
- ip=None,
- ):
- (
- dynamic_sizes,
- static_sizes,
- scalable_sizes,
- ) = _dispatch_dynamic_index_list(sizes)
-
- num_loops = sum(v if v == 0 else 1 for v in static_sizes)
-
- if isinstance(loop_types_or_target, (Operation, Value, OpView)):
- loop_types = [transform.AnyOpType.get()] * num_loops
- target = loop_types_or_target
- assert (
- target_or_none is None
- ), "Cannot construct TileUsingForOp with two targets."
- else:
- loop_types = (
- ([loop_types_or_target] * num_loops)
- if isinstance(loop_types_or_target, Type)
- else loop_types_or_target
- )
- target = target_or_none
-
- super().__init__(
- target.type,
- loop_types,
- target,
- dynamic_sizes=dynamic_sizes,
- static_sizes=static_sizes,
- interchange=interchange,
- scalable_sizes=scalable_sizes,
- loc=loc,
- ip=ip,
- )
-
-
-class TileUsingForallOp:
- """Specialization for TileUsingForallOp class."""
-
- @overload
- def __init__(
- self,
- loops_type: Type,
- tiled_op_type: Type,
- target: Union[Operation, Value, OpView],
- *,
- num_threads: Optional[MixedValues] = None,
- tile_sizes: MixedValues = None,
- mapping=None,
- loc=None,
- ip=None,
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, Value, OpView],
- *,
- num_threads: Optional[MixedValues] = None,
- tile_sizes: MixedValues = None,
- mapping=None,
- loc=None,
- ip=None,
- ):
- ...
-
- def __init__(
- self,
- loops_type_or_target: Union[
- Type, Union[Operation, Value, OpView] # loops_type
- ], # target
- tiled_op_type_or_none: Optional[Type] = None,
- target_or_none: Optional[Union[Operation, Value, OpView]] = None,
- *,
- num_threads: MixedValues = None,
- tile_sizes: MixedValues = None,
- mapping=None,
- loc=None,
- ip=None,
- ):
- # `Type` arguments in the front are optional: add default values to front.
- if isinstance(loops_type_or_target, Type):
- # First overload: type arguments provided.
- if not isinstance(tiled_op_type_or_none, Type):
- raise TypeError(
- "If 'loops_type_or_target' is a type, then "
- "'tiled_op_type_or_none' is expected to be one as well."
- )
- loops_type = loops_type_or_target
- tiled_op_type = tiled_op_type_or_none
- target = target_or_none
- else:
- # Last overload: type arguments missing.
- loops_type = transform.AnyOpType.get()
- tiled_op_type = transform.AnyOpType.get()
- target = loops_type_or_target
-
- # Unpack mixed num_threads.
- (
- dynamic_num_threads,
- packed_num_threads,
- num_threads_attr,
- ) = _dispatch_mixed_values(num_threads)
-
- # Unpack mixed tile_sizes.
- (
- dynamic_tile_sizes,
- packed_tile_sizes,
- tile_sizes_attr,
- ) = _dispatch_mixed_values(tile_sizes)
-
- super().__init__(
- loops_type,
- tiled_op_type,
- target=target,
- tile_sizes=dynamic_tile_sizes,
- packed_tile_sizes=packed_tile_sizes,
- static_tile_sizes=tile_sizes_attr,
- num_threads=dynamic_num_threads,
- packed_num_threads=packed_num_threads,
- static_num_threads=num_threads_attr,
- mapping=mapping,
- loc=loc,
- ip=ip,
- )
-
-
-class VectorizeChildrenAndApplyPatternsOp:
- """Specialization for VectorizeChildrenAndApplyPatternsOp class."""
-
- def __init__(
- self,
- target: Union[Operation, Value],
- *,
- disable_multi_reduction_to_contract_patterns: bool = False,
- disable_transfer_permutation_map_lowering_patterns: bool = False,
- vectorize_nd_extract: bool = False,
- vectorize_padding: bool = False,
- loc=None,
- ip=None,
- ):
- transformed_type = transform.AnyOpType.get()
- super().__init__(
- transformed_type,
- target,
- disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
- disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
- vectorize_nd_extract=vectorize_nd_extract,
- vectorize_padding=vectorize_padding,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py
deleted file mode 100644
index 09b9ec68db7d9c7..000000000000000
--- a/mlir/python/mlir/dialects/_tensor_ops_ext.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Any, Optional, Sequence, Union
-from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
-)
-
-
-class EmptyOp:
- """Extends the tensor.empty op."""
-
- def __init__(
- self,
- sizes: Sequence[Union[int, Value]],
- element_type: Type,
- *,
- loc=None,
- ip=None
- ):
- """Constructs an `empty` with mixed static/dynamic sizes."""
- # TODO: Refactor the EmptyOp to take an element type attribute and
- # then use normal result type inference, unifying the Python and C++ side
- # with a standard mechanism (versus stashing that in builders).
- dynamic_sizes = []
- static_sizes = []
- for s in sizes:
- if isinstance(s, int):
- static_sizes.append(s)
- else:
- static_sizes.append(ShapedType.get_dynamic_size())
- dynamic_sizes.append(s)
- result_type = RankedTensorType.get(static_sizes, element_type)
- op = self.build_generic(
- results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
- )
- OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py b/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
deleted file mode 100644
index 996093fbc913e8a..000000000000000
--- a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ..dialects import transform
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, overload, Union
-
-
-class MakeLoopIndependentOp:
- """Specialization for MakeLoopIndependentOp class."""
-
- @overload
- def __init__(
- self,
- transformed_type: Type,
- target: Union[Operation, OpView, Value],
- num_loops: Union[int, IntegerAttr],
- *,
- loc=None,
- ip=None
- ):
- ...
-
- @overload
- def __init__(
- self,
- target: Union[Operation, OpView, Value],
- num_loops: Union[int, IntegerAttr],
- *,
- loc=None,
- ip=None
- ):
- ...
-
- def __init__(
- self,
- transformed_type_or_target: Type,
- target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None,
- num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
- *,
- loc=None,
- ip=None
- ):
- if isinstance(transformed_type_or_target, Type):
- transformed_type = transformed_type_or_target
- target = target_or_num_loops
- num_loops = num_loops_or_none
- else:
- transformed_type = transform.AnyOpType.get()
- target = transformed_type_or_target
- num_loops = target_or_num_loops
-
- super().__init__(
- transformed_type,
- target,
- num_loops,
- loc=loc,
- ip=ip,
- )
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
deleted file mode 100644
index b1e7b892536f4a1..000000000000000
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- )
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-class CastOp:
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- loc=None,
- ip=None,
- ):
- super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
-
-
-class ApplyPatternsOp:
- def __init__(
- self,
- target: Union[Operation, Value, OpView],
- *,
- loc=None,
- ip=None,
- ):
- operands = []
- operands.append(_get_op_result_or_value(target))
- super().__init__(
- self.build_generic(
- attributes={},
- results=[],
- operands=operands,
- successors=None,
- regions=None,
- loc=loc,
- ip=ip,
- )
- )
- self.regions[0].blocks.append()
-
- @property
- def patterns(self) -> Block:
- return self.regions[0].blocks[0]
-
-
-class testGetParentOp:
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- *,
- isolated_from_above: bool = False,
- op_name: Optional[str] = None,
- deduplicate: bool = False,
- loc=None,
- ip=None,
- ):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- isolated_from_above=isolated_from_above,
- op_name=op_name,
- deduplicate=deduplicate,
- loc=loc,
- ip=ip,
- )
-
-
-class MergeHandlesOp:
- def __init__(
- self,
- handles: Sequence[Union[Operation, Value]],
- *,
- deduplicate: bool = False,
- loc=None,
- ip=None,
- ):
- super().__init__(
- [_get_op_result_or_value(h) for h in handles],
- deduplicate=deduplicate,
- loc=loc,
- ip=ip,
- )
-
-
-class ReplicateOp:
- def __init__(
- self,
- pattern: Union[Operation, Value],
- handles: Sequence[Union[Operation, Value]],
- *,
- loc=None,
- ip=None,
- ):
- super().__init__(
- [_get_op_result_or_value(h).type for h in handles],
- _get_op_result_or_value(pattern),
- [_get_op_result_or_value(h) for h in handles],
- loc=loc,
- ip=ip,
- )
-
-
-class SequenceOp:
- def __init__(
- self,
- failure_propagation_mode,
- results: Sequence[Type],
- target: Union[Operation, Value, Type],
- extra_bindings: Optional[
- Union[Sequence[Value], Sequence[Type], Operation, OpView]
- ] = None,
- ):
- root = (
- _get_op_result_or_value(target)
- if isinstance(target, (Operation, Value))
- else None
- )
- root_type = root.type if not isinstance(target, Type) else target
-
- if extra_bindings is None:
- extra_bindings = []
- if isinstance(extra_bindings, (Operation, OpView)):
- extra_bindings = _get_op_results_or_values(extra_bindings)
-
- extra_binding_types = []
- if len(extra_bindings) != 0:
- if isinstance(extra_bindings[0], Type):
- extra_binding_types = extra_bindings
- extra_bindings = []
- else:
- extra_binding_types = [v.type for v in extra_bindings]
-
- super().__init__(
- results_=results,
- failure_propagation_mode=failure_propagation_mode,
- root=root,
- extra_bindings=extra_bindings,
- )
- self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types))
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTarget(self) -> Value:
- return self.body.arguments[0]
-
- @property
- def bodyExtraArgs(self) -> BlockArgumentList:
- return self.body.arguments[1:]
-
-
-class YieldOp:
- def __init__(
- self,
- operands: Optional[Union[Operation, Sequence[Value]]] = None,
- *,
- loc=None,
- ip=None,
- ):
- if operands is None:
- operands = []
- super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
deleted file mode 100644
index c4e4b4b4254b038..000000000000000
--- a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
- from ..ir import *
- from ._ods_common import (
- get_op_result_or_value as _get_op_result_or_value,
- get_op_results_or_values as _get_op_results_or_values,
- )
-except ImportError as e:
- raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Union
-
-class PDLMatchOp:
-
- def __init__(
- self,
- result_type: Type,
- target: Union[Operation, Value],
- pattern_name: Union[Attribute, str],
- *,
- loc=None,
- ip=None,
- ):
- super().__init__(
- result_type,
- _get_op_result_or_value(target),
- pattern_name,
- loc=loc,
- ip=ip,
- )
-
-
-class WithPDLPatternsOp:
-
- def __init__(self,
- target: Union[Operation, Value, Type],
- *,
- loc=None,
- ip=None):
- root = _get_op_result_or_value(target) if not isinstance(target,
- Type) else None
- root_type = target if isinstance(target, Type) else root.type
- super().__init__(root=root, loc=loc, ip=ip)
- self.regions[0].blocks.append(root_type)
-
- @property
- def body(self) -> Block:
- return self.regions[0].blocks[0]
-
- @property
- def bodyTarget(self) -> Value:
- return self.body.arguments[0]
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index fb13beb63ca66c3..e3b6a428c879de5 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -3,4 +3,71 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._arith_ops_gen import *
+from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
+
+try:
+ from ..ir import *
+ from ._ods_common import get_default_loc_context as _get_default_loc_context, _cext as _ods_cext
+
+ from typing import Any, List, Union
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+def _isa(obj: Any, cls: type):
+ try:
+ cls(obj)
+ except ValueError:
+ return False
+ return True
+
+
+def _is_any_of(obj: Any, classes: List[type]):
+ return any(_isa(obj, cls) for cls in classes)
+
+
+def _is_integer_like_type(type: Type):
+ return _is_any_of(type, [IntegerType, IndexType])
+
+
+def _is_float_type(type: Type):
+ return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
+ """Specialization for the constant op class."""
+
+ def __init__(
+ self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
+ ):
+ if isinstance(value, int):
+ super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+ elif isinstance(value, float):
+ super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+ else:
+ super().__init__(value, loc=loc, ip=ip)
+
+ @classmethod
+ def create_index(cls, value: int, *, loc=None, ip=None):
+ """Create an index-typed constant."""
+ return cls(
+ IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
+ )
+
+ @property
+ def type(self):
+ return self.results[0].type
+
+ @property
+ def value(self):
+ return Attribute(self.operation.attributes["value"])
+
+ @property
+ def literal_value(self) -> Union[int, float]:
+ if _is_integer_like_type(self.type):
+ return IntegerAttr(self.value).value
+ elif _is_float_type(self.type):
+ return FloatAttr(self.value).value
+ else:
+ raise ValueError("only integer and float constants have literal values")
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 759b6aa24a9ff73..78139c8f5cc3b0a 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -3,4 +3,47 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._bufferization_ops_gen import *
+from ._bufferization_ops_gen import _Dialect
from ._bufferization_enum_gen import *
+
+try:
+ from typing import Sequence, Union
+ from ..ir import *
+ from ._ods_common import get_default_loc_context, _cext as _ods_cext
+
+ from typing import Any, List, Union
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+_AllocTensorOp = AllocTensorOp
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AllocTensorOp(_AllocTensorOp):
+ """Extends the bufferization.alloc_tensor op."""
+
+ def __init__(
+ self,
+ tensor_type: Type,
+ dynamic_sizes: Sequence[Value],
+ copy: Value,
+ size_hint: Value,
+ escape: BoolAttr,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
+ context = get_default_loc_context(loc)
+ attributes = {}
+ if escape:
+ attributes["escape"] = escape
+ super(_AllocTensorOp, self).__init__(
+ self.build_generic(
+ results=[tensor_type],
+ operands=[dynamic_sizes, copy, size_hint],
+ attributes=attributes,
+ loc=loc,
+ ip=ip,
+ )
+ )
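
A minimal usage sketch (not part of the patch): allocating a tensor with one
dynamic size, leaving `copy` and `size_hint` unset and omitting `escape`:

    from mlir.ir import Context, F32Type, InsertionPoint, Location, Module, RankedTensorType, ShapedType
    from mlir.dialects import arith, bufferization

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            dyn = ShapedType.get_dynamic_size()
            t = RankedTensorType.get([dyn, 4], F32Type.get())
            d0 = arith.ConstantOp.create_index(8)
            # One dynamic size; optional copy/size_hint operands stay None.
            bufferization.AllocTensorOp(t, [d0.result], None, None, None)
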
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
index 30279e1611f99aa..1c8215ba4b67532 100644
--- a/mlir/python/mlir/dialects/builtin.py
+++ b/mlir/python/mlir/dialects/builtin.py
@@ -3,3 +3,27 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._builtin_ops_gen import *
+from ._builtin_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+_ModuleOp = ModuleOp
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ModuleOp(_ModuleOp):
+ """Specialization for the module op class."""
+
+ def __init__(self, *, loc=None, ip=None):
+ super(_ModuleOp, self).__init__(
+ self.build_generic(results=[], operands=[], loc=loc, ip=ip)
+ )
+ self.regions[0].blocks.append()
+
+ @property
+ def body(self):
+ return self.regions[0].blocks[0]
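
A minimal usage sketch (not part of the patch): the replaced ModuleOp keeps
the familiar builder API:

    from mlir.ir import Context, InsertionPoint, Location
    from mlir.dialects import builtin, func

    with Context(), Location.unknown():
        m = builtin.ModuleOp()
        with InsertionPoint(m.body):
            # A bodyless declaration; name is illustrative.
            func.FuncOp("callee", ([], []), visibility="private")
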
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index dc554c22173bc60..9c6c4c9092c7a88 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -3,3 +3,326 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._func_ops_gen import *
+from ._func_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
+
+ import inspect
+
+ from typing import Any, List, Optional, Sequence, Union
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
+ """Specialization for the constant op class."""
+
+ def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
+ super().__init__(result, value, loc=loc, ip=ip)
+
+ @property
+ def type(self):
+ return self.results[0].type
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuncOp(FuncOp):
+ """Specialization for the func op class."""
+
+ def __init__(
+ self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+ ):
+ """
+ Create a FuncOp with the provided `name`, `type`, and `visibility`.
+ - `name` is a string representing the function name.
+ - `type` is either a FunctionType or a pair of lists describing inputs and
+ results.
+ - `visibility` is a string matching `public`, `private`, or `nested`. None
+ implies private visibility.
+ - `body_builder` is an optional callback; when provided, a new entry block
+ is created and the callback is invoked with the new op as argument within
+ an InsertionPoint context already set for the block. The callback is
+ expected to insert a terminator in the block.
+ """
+ sym_name = StringAttr.get(str(name))
+
+ # If the type is passed as a tuple, build a FunctionType on the fly.
+ if isinstance(type, tuple):
+ type = FunctionType.get(inputs=type[0], results=type[1])
+
+ type = TypeAttr.get(type)
+ sym_visibility = (
+ StringAttr.get(str(visibility)) if visibility is not None else None
+ )
+ super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+ if body_builder:
+ entry_block = self.add_entry_block()
+ with InsertionPoint(entry_block):
+ body_builder(self)
+
+ @property
+ def is_external(self):
+ return len(self.regions[0].blocks) == 0
+
+ @property
+ def body(self):
+ return self.regions[0]
+
+ @property
+ def type(self):
+ return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+ @property
+ def visibility(self):
+ return self.attributes["sym_visibility"]
+
+ @property
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes["sym_name"])
+
+ @property
+ def entry_block(self):
+ if self.is_external:
+ raise IndexError("External function does not have a body")
+ return self.regions[0].blocks[0]
+
+ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
+ """
+ Add an entry block to the function body using the function signature to
+ infer block arguments.
+ Returns the newly created block.
+ """
+ if not self.is_external:
+ raise IndexError("The function already has an entry block!")
+ self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
+ return self.body.blocks[0]
+
+ @property
+ def arg_attrs(self):
+ return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+ @arg_attrs.setter
+ def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+ if isinstance(attribute, ArrayAttr):
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+ else:
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+ attribute, context=self.context
+ )
+
+ @property
+ def arguments(self):
+ return self.entry_block.arguments
+
+ @property
+ def result_attrs(self):
+ return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+ @result_attrs.setter
+ def result_attrs(self, attribute: ArrayAttr):
+ self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+
+ @classmethod
+ def from_py_func(
+ FuncOp,
+ *inputs: Type,
+ results: Optional[Sequence[Type]] = None,
+ name: Optional[str] = None,
+ ):
+ """Decorator to define an MLIR FuncOp specified as a python function.
+
+ Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
+ active for the current thread (i.e. established in a `with` block).
+
+ When applied as a decorator to a Python function, an entry block will
+ be constructed for the FuncOp with types as specified in `*inputs`. The
+ block arguments will be passed positionally to the Python function. In
+ addition, if the Python function accepts `**kwargs` or declares a keyword
+ argument named `func_op`, the following will be passed:
+ * `func_op`: The `func` op being defined.
+
+ By default, the function name will be the Python function `__name__`. This
+ can be overridden by passing the `name` argument to the decorator.
+
+ If `results` is not specified, then the decorator will implicitly
+ insert a `ReturnOp` with the `Value`'s returned from the decorated
+ function. It will also set the `FuncOp` type with the actual return
+ value types. If `results` is specified, then the decorated function
+ must return `None` and no implicit `ReturnOp` is added (nor are the result
+ types updated). The implicit behavior is intended for simple, single-block
+ cases, and users should specify result types explicitly for any complicated
+ cases.
+
+ The decorated function can further be called from Python and will insert
+ a `CallOp` at the then-current insertion point, returning either None (if
+ there are no return values), a single Value (for one result), or a list of Values.
+ This mechanism cannot be used to emit recursive calls (by construction).
+ """
+
+ def decorator(f):
+ from . import func
+
+ # Introspect the callable for optional features.
+ sig = inspect.signature(f)
+ has_arg_func_op = False
+ for param in sig.parameters.values():
+ if param.kind == param.VAR_KEYWORD:
+ has_arg_func_op = True
+ if param.name == "func_op" and (
+ param.kind == param.POSITIONAL_OR_KEYWORD
+ or param.kind == param.KEYWORD_ONLY
+ ):
+ has_arg_func_op = True
+
+ # Emit the FuncOp.
+ implicit_return = results is None
+ symbol_name = name or f.__name__
+ function_type = FunctionType.get(
+ inputs=inputs, results=[] if implicit_return else results
+ )
+ func_op = FuncOp(name=symbol_name, type=function_type)
+ with InsertionPoint(func_op.add_entry_block()):
+ func_args = func_op.entry_block.arguments
+ func_kwargs = {}
+ if has_arg_func_op:
+ func_kwargs["func_op"] = func_op
+ return_values = f(*func_args, **func_kwargs)
+ if not implicit_return:
+ return_types = list(results)
+ assert return_values is None, (
+ "Capturing a python function with explicit `results=` "
+ "requires that the wrapped function returns None."
+ )
+ else:
+ # Coerce return values, add ReturnOp and rewrite func type.
+ if return_values is None:
+ return_values = []
+ elif isinstance(return_values, tuple):
+ return_values = list(return_values)
+ elif isinstance(return_values, Value):
+ # Returning a single value is fine, coerce it into a list.
+ return_values = [return_values]
+ elif isinstance(return_values, OpView):
+ # Returning a single operation is fine; coerce its results into a list.
+ return_values = return_values.operation.results
+ elif isinstance(return_values, Operation):
+ # Returning a single operation is fine; coerce its results into a list.
+ return_values = return_values.results
+ else:
+ return_values = list(return_values)
+ func.ReturnOp(return_values)
+ # Recompute the function type.
+ return_types = [v.type for v in return_values]
+ function_type = FunctionType.get(
+ inputs=inputs, results=return_types
+ )
+ func_op.attributes["function_type"] = TypeAttr.get(function_type)
+
+ def emit_call_op(*call_args):
+ call_op = func.CallOp(
+ return_types, FlatSymbolRefAttr.get(symbol_name), call_args
+ )
+ if return_types is None:
+ return None
+ elif len(return_types) == 1:
+ return call_op.result
+ else:
+ return call_op.results
+
+ wrapped = emit_call_op
+ wrapped.__name__ = f.__name__
+ wrapped.func_op = func_op
+ return wrapped
+
+ return decorator
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class CallOp(CallOp):
+ """Specialization for the call op class."""
+
+ def __init__(
+ self,
+ calleeOrResults: Union[FuncOp, List[Type]],
+ argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
+ arguments: Optional[List] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an call operation.
+
+ The constructor accepts three different forms:
+
+ 1. A function op to be called, followed by a list of arguments.
+ 2. A list of result types, followed by the name of the function to be
+ called as a string, followed by a list of arguments.
+ 3. A list of result types, followed by the name of the function to be
+ called as a symbol reference attribute, followed by a list of arguments.
+
+ For example
+
+ f = func.FuncOp("foo", ...)
+ func.CallOp(f, [args])
+ func.CallOp([result_types], "foo", [args])
+
+ In all cases, the location and insertion point may be specified as keyword
+ arguments if not provided by the surrounding context managers.
+ """
+
+ # TODO: consider supporting constructor "overloads", e.g., through a custom
+ # or pybind-provided metaclass.
+ if isinstance(calleeOrResults, FuncOp):
+ if not isinstance(argumentsOrCallee, list):
+ raise ValueError(
+ "when constructing a call to a function, expected "
+ + "the second argument to be a list of call arguments, "
+ + f"got {type(argumentsOrCallee)}"
+ )
+ if arguments is not None:
+ raise ValueError(
+ "unexpected third argument when constructing a call"
+ + "to a function"
+ )
+
+ super().__init__(
+ calleeOrResults.type.results,
+ FlatSymbolRefAttr.get(
+ calleeOrResults.name.value, context=_get_default_loc_context(loc)
+ ),
+ argumentsOrCallee,
+ loc=loc,
+ ip=ip,
+ )
+ return
+
+ if isinstance(argumentsOrCallee, list):
+ raise ValueError(
+ "when constructing a call to a function by name, "
+ + "expected the second argument to be a string or a "
+ + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
+ )
+
+ if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
+ super().__init__(
+ calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
+ )
+ elif isinstance(argumentsOrCallee, str):
+ super().__init__(
+ calleeOrResults,
+ FlatSymbolRefAttr.get(
+ argumentsOrCallee, context=_get_default_loc_context(loc)
+ ),
+ arguments,
+ loc=loc,
+ ip=ip,
+ )
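
A minimal usage sketch (not part of the patch) of the decorator and call
forms described above; function names are illustrative:

    from mlir.ir import Context, F32Type, InsertionPoint, Location, Module
    from mlir.dialects import arith, func

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()

            @func.FuncOp.from_py_func(f32, f32)
            def add(a, b):
                return arith.AddFOp(a, b)  # coerced to the op's results

            @func.FuncOp.from_py_func(f32, f32)
            def call_add(a, b):
                return add(a, b)  # emits func.CallOp via emit_call_op

        print(module)
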
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 6f9d72164429eea..f91fc8b7160089b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -310,7 +310,7 @@ def emit_named_structured_op(
)
# Set the index attributes used to compute the indexing maps.
- named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
+ named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
for name, value in index_attrs.items():
named_op.operation.attributes[name] = value
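
For reference (not part of the patch), the call-site convention changes as
sketched below; `MatmulOp` stands in for any generated named-op builder:

    # before: named_op = linalg.MatmulOp(ins, outs, result_types)
    # after:  named_op = linalg.MatmulOp(result_types, ins, outs)
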
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 3afb6a70cb9e0db..111ad2178703d28 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -3,3 +3,41 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._memref_ops_gen import *
+from ._memref_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoadOp(LoadOp):
+ """Specialization for the MemRef load operation."""
+
+ def __init__(
+ self,
+ memref: Union[Operation, OpView, Value],
+ indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates a memref load operation.
+
+ Args:
+ memref: the buffer to load from.
+ indices: the list of subscripts, may be empty for zero-dimensional
+ buffers.
+ loc: user-visible location of the operation.
+ ip: insertion point.
+ """
+ indices_resolved = [] if indices is None else _get_op_results_or_values(indices)
+ super().__init__(memref, indices_resolved, loc=loc, ip=ip)
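
A minimal usage sketch (not part of the patch): indexing a 1-d buffer with a
constant index:

    from mlir.ir import Context, F32Type, InsertionPoint, Location, MemRefType, Module
    from mlir.dialects import arith, func, memref

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            buf_ty = MemRefType.get([4], F32Type.get())

            @func.FuncOp.from_py_func(buf_ty)
            def load_first(buf):
                i0 = arith.ConstantOp.create_index(0)
                return memref.LoadOp(buf, [i0])
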
diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py
index a654529b4bb8843..dfb6d7f2c03b1cf 100644
--- a/mlir/python/mlir/dialects/ml_program.py
+++ b/mlir/python/mlir/dialects/ml_program.py
@@ -2,4 +2,118 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from typing import Union
+
from ._ml_program_ops_gen import *
+from ._ml_program_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_default_loc_context as _get_default_loc_context,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuncOp(FuncOp):
+ """Specialization for the func op class."""
+
+ def __init__(
+ self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
+ ):
+ """
+ Create a FuncOp with the provided `name`, `type`, and `visibility`.
+ - `name` is a string representing the function name.
+ - `type` is either a FunctionType or a pair of lists describing inputs and
+ results.
+ - `visibility` is a string matching `public`, `private`, or `nested`. None
+ implies private visibility.
+ - `body_builder` is an optional callback; when provided, a new entry block
+ is created and the callback is invoked with the new op as argument within
+ an InsertionPoint context already set for the block. The callback is
+ expected to insert a terminator in the block.
+ """
+ sym_name = StringAttr.get(str(name))
+
+ # If the type is passed as a tuple, build a FunctionType on the fly.
+ if isinstance(type, tuple):
+ type = FunctionType.get(inputs=type[0], results=type[1])
+
+ type = TypeAttr.get(type)
+ sym_visibility = (
+ StringAttr.get(str(visibility)) if visibility is not None else None
+ )
+ super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
+ if body_builder:
+ entry_block = self.add_entry_block()
+ with InsertionPoint(entry_block):
+ body_builder(self)
+
+ @property
+ def is_external(self):
+ return len(self.regions[0].blocks) == 0
+
+ @property
+ def body(self):
+ return self.regions[0]
+
+ @property
+ def type(self):
+ return FunctionType(TypeAttr(self.attributes["function_type"]).value)
+
+ @property
+ def visibility(self):
+ return self.attributes["sym_visibility"]
+
+ @property
+ def name(self) -> StringAttr:
+ return StringAttr(self.attributes["sym_name"])
+
+ @property
+ def entry_block(self):
+ if self.is_external:
+ raise IndexError("External function does not have a body")
+ return self.regions[0].blocks[0]
+
+ def add_entry_block(self):
+ """
+ Add an entry block to the function body using the function signature to
+ infer block arguments.
+ Returns the newly created block.
+ """
+ if not self.is_external:
+ raise IndexError("The function already has an entry block!")
+ self.body.blocks.append(*self.type.inputs)
+ return self.body.blocks[0]
+
+ @property
+ def arg_attrs(self):
+ return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+ @arg_attrs.setter
+ def arg_attrs(self, attribute: Union[ArrayAttr, list]):
+ if isinstance(attribute, ArrayAttr):
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+ else:
+ self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+ attribute, context=self.context
+ )
+
+ @property
+ def arguments(self):
+ return self.entry_block.arguments
+
+ @property
+ def result_attrs(self):
+ return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+ @result_attrs.setter
+ def result_attrs(self, attribute: ArrayAttr):
+ self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
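
A minimal usage sketch (not part of the patch): an ml_program function built
with `body_builder`; the function name is illustrative:

    from mlir.ir import Context, F32Type, InsertionPoint, Location, Module
    from mlir.dialects import ml_program

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()
            # The callback runs with the insertion point set to the entry block.
            ml_program.FuncOp(
                "identity",
                ([f32], [f32]),
                body_builder=lambda f: ml_program.ReturnOp(list(f.arguments)),
            )
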
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index dda2b7d6521965f..a8d9c56f4233d9e 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -3,4 +3,289 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._pdl_ops_gen import *
+from ._pdl_ops_gen import _Dialect
from .._mlir_libs._mlirDialectsPDL import *
+
+
+try:
+ from ..ir import *
+ from ..dialects import pdl
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union, Optional, Sequence, Mapping
+from ._ods_common import (
+ get_op_result_or_value as _get_value,
+ get_op_results_or_values as _get_values,
+ _cext as _ods_cext,
+)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
+ """Specialization for PDL apply native constraint op class."""
+
+ def __init__(
+ self,
+ name: Union[str, StringAttr],
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ args = _get_values(args)
+ super().__init__(name, args, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
+ """Specialization for PDL apply native rewrite op class."""
+
+ def __init__(
+ self,
+ results: Sequence[Type],
+ name: Union[str, StringAttr],
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ args = _get_values(args)
+ super().__init__(results, name, args, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class AttributeOp(AttributeOp):
+ """Specialization for PDL attribute op class."""
+
+ def __init__(
+ self,
+ valueType: Optional[Union[OpView, Operation, Value]] = None,
+ value: Optional[Attribute] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ valueType = valueType if valueType is None else _get_value(valueType)
+ result = pdl.AttributeType.get()
+ super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class EraseOp(EraseOp):
+ """Specialization for PDL erase op class."""
+
+ def __init__(
+ self,
+ operation: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ operation = _get_value(operation)
+ super().__init__(operation, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperandOp(OperandOp):
+ """Specialization for PDL operand op class."""
+
+ def __init__(
+ self,
+ type: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ type = type if type is None else _get_value(type)
+ result = pdl.ValueType.get()
+ super().__init__(result, valueType=type, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperandsOp(OperandsOp):
+ """Specialization for PDL operands op class."""
+
+ def __init__(
+ self,
+ types: Optional[Union[OpView, Operation, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ types = types if types is None else _get_value(types)
+ result = pdl.RangeType.get(pdl.ValueType.get())
+ super().__init__(result, valueType=types, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
+ """Specialization for PDL operand op class."""
+
+ def __init__(
+ self,
+ name: Optional[Union[str, StringAttr]] = None,
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None,
+ types: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if types is None:
+ types = []
+ if attributes is None:
+ attributes = {}
+ if args is None:
+ args = []
+ args = _get_values(args)
+ attrNames = []
+ attrValues = []
+ for attrName, attrValue in attributes.items():
+ attrNames.append(StringAttr.get(attrName))
+ attrValues.append(_get_value(attrValue))
+ attrNames = ArrayAttr.get(attrNames)
+ types = _get_values(types)
+ result = pdl.OperationType.get()
+ super().__init__(
+ result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class PatternOp(PatternOp):
+ """Specialization for PDL pattern op class."""
+
+ def __init__(
+ self,
+ benefit: Union[IntegerAttr, int],
+ name: Optional[Union[StringAttr, str]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an PDL `pattern` operation."""
+ super().__init__(benefit, sym_name=name, loc=loc, ip=ip)
+ self.regions[0].blocks.append()
+
+ @property
+ def body(self):
+ """Return the body (block) of the pattern."""
+ return self.regions[0].blocks[0]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ReplaceOp(ReplaceOp):
+ """Specialization for PDL replace op class."""
+
+ def __init__(
+ self,
+ op: Union[OpView, Operation, Value],
+ *,
+ with_op: Optional[Union[OpView, Operation, Value]] = None,
+ with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if with_values is None:
+ with_values = []
+ op = _get_value(op)
+ with_op = with_op if with_op is None else _get_value(with_op)
+ with_values = _get_values(with_values)
+ super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ResultOp(ResultOp):
+ """Specialization for PDL result op class."""
+
+ def __init__(
+ self,
+ parent: Union[OpView, Operation, Value],
+ index: Union[IntegerAttr, int],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ parent = _get_value(parent)
+ result = pdl.ValueType.get()
+ super().__init__(result, parent, index, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ResultsOp(ResultsOp):
+ """Specialization for PDL results op class."""
+
+ def __init__(
+ self,
+ result: Type,
+ parent: Union[OpView, Operation, Value],
+ index: Optional[Union[IntegerAttr, int]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ parent = _get_value(parent)
+ super().__init__(result, parent, index=index, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class RewriteOp(RewriteOp):
+ """Specialization for PDL rewrite op class."""
+
+ def __init__(
+ self,
+ root: Optional[Union[OpView, Operation, Value]] = None,
+ name: Optional[Union[StringAttr, str]] = None,
+ args: Optional[Sequence[Union[OpView, Operation, Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if args is None:
+ args = []
+ root = root if root is None else _get_value(root)
+ args = _get_values(args)
+ super().__init__(args, root=root, name=name, loc=loc, ip=ip)
+
+ def add_body(self):
+ """Add body (block) to the rewrite."""
+ self.regions[0].blocks.append()
+ return self.body
+
+ @property
+ def body(self):
+ """Return the body (block) of the rewrite."""
+ return self.regions[0].blocks[0]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TypeOp(TypeOp):
+ """Specialization for PDL type op class."""
+
+ def __init__(
+ self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None
+ ):
+ result = pdl.TypeType.get()
+ super().__init__(result, constantType=constantType, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TypesOp(TypesOp):
+ """Specialization for PDL types op class."""
+
+ def __init__(
+ self,
+ constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if constantTypes is None:
+ constantTypes = []
+ result = pdl.RangeType.get(pdl.TypeType.get())
+ super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip)
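+
+
+# A small end-to-end sketch of these builders (a hedged illustration, not part
+# of this patch; assumes an enclosing module and insertion point, and uses a
+# hypothetical pattern name):
+#   pattern = PatternOp(benefit=1, name="my.pattern")
+#   with InsertionPoint(pattern.body):
+#       op = OperationOp(types=[TypeOp()])
+#       RewriteOp(op, name="my.rewriter")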
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 49685ca2271fc61..43ad9f4e2d65f51 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -2,11 +2,122 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Optional, Sequence
from ._scf_ops_gen import *
+from ._scf_ops_gen import _Dialect
from .arith import constant
-from ..ir import *
+
+try:
+ from ..ir import *
+ from ._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+_ForOp = ForOp
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForOp(_ForOp):
+ """Specialization for the SCF for op class."""
+
+ def __init__(
+ self,
+ lower_bound,
+ upper_bound,
+ step,
+ iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an SCF `for` operation.
+
+ - `lower_bound` is the value to use as lower bound of the loop.
+ - `upper_bound` is the value to use as upper bound of the loop.
+ - `step` is the value to use as loop step.
+ - `iter_args` is a list of additional loop-carried arguments or an operation
+ producing them as results.
+ """
+ if iter_args is None:
+ iter_args = []
+ iter_args = _get_op_results_or_values(iter_args)
+
+ results = [arg.type for arg in iter_args]
+ super(_ForOp, self).__init__(
+ self.build_generic(
+ regions=1,
+ results=results,
+ operands=[
+ _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step]
+ ]
+ + list(iter_args),
+ loc=loc,
+ ip=ip,
+ )
+ )
+ self.regions[0].blocks.append(self.operands[0].type, *results)
+
+ @property
+ def body(self):
+ """Returns the body (block) of the loop."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def induction_variable(self):
+ """Returns the induction variable of the loop."""
+ return self.body.arguments[0]
+
+ @property
+ def inner_iter_args(self):
+ """Returns the loop-carried arguments usable within the loop.
+
+ To obtain the loop-carried operands, use `iter_args`.
+ """
+ return self.body.arguments[1:]
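+
+# A minimal usage sketch (a hedged illustration, not part of this patch;
+# assumes an enclosing context/insertion point and index-typed bounds, e.g.
+# built with arith.ConstantOp.create_index):
+#   loop = ForOp(lb, ub, step)
+#   with InsertionPoint(loop.body):
+#       ...  # body code uses loop.induction_variable
+#       YieldOp([])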
+
+
+_IfOp = IfOp
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class IfOp(_IfOp):
+ """Specialization for the SCF if op class."""
+
+    def __init__(self, cond, results_=(), *, hasElse=False, loc=None, ip=None):
+ """Creates an SCF `if` operation.
+
+        - `cond` is an MLIR value of 'i1' type that determines which region of code will be executed.
+ - `hasElse` determines whether the if operation has the else branch.
+ """
+        operands = [cond]
+        results = list(results_)
+ super(_IfOp, self).__init__(
+ self.build_generic(
+ regions=2, results=results, operands=operands, loc=loc, ip=ip
+ )
+ )
+        self.regions[0].blocks.append()
+        if hasElse:
+            self.regions[1].blocks.append()
+
+ @property
+ def then_block(self):
+ """Returns the then block of the if operation."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def else_block(self):
+ """Returns the else block of the if operation."""
+ return self.regions[1].blocks[0]
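+
+# A minimal usage sketch (a hedged illustration; `cond` is an existing
+# i1-typed Value at the current insertion point):
+#   if_op = IfOp(cond, hasElse=True)
+#   with InsertionPoint(if_op.then_block):
+#       YieldOp([])
+#   with InsertionPoint(if_op.else_block):
+#       YieldOp([])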
def for_(
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 26edf6b6436dad5..a007d683ca7be32 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -3,3 +3,50 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._tensor_ops_gen import *
+from ._tensor_ops_gen import _Dialect
+
+try:
+ from ..ir import *
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Sequence, Union
+from ._ods_common import _cext as _ods_cext
+
+_EmptyOp = EmptyOp
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class EmptyOp(_EmptyOp):
+ """Extends the tensor.empty op."""
+
+ def __init__(
+ self,
+ sizes: Sequence[Union[int, Value]],
+ element_type: Type,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ """Constructs an `empty` with mixed static/dynamic sizes."""
+ # TODO: Refactor the EmptyOp to take an element type attribute and
+ # then use normal result type inference, unifying the Python and C++ side
+ # with a standard mechanism (versus stashing that in builders).
+ dynamic_sizes = []
+ static_sizes = []
+ for s in sizes:
+ if isinstance(s, int):
+ static_sizes.append(s)
+ else:
+ static_sizes.append(ShapedType.get_dynamic_size())
+ dynamic_sizes.append(s)
+ result_type = RankedTensorType.get(static_sizes, element_type)
+ super(_EmptyOp, self).__init__(
+ self.build_generic(
+ results=[result_type],
+ operands=dynamic_sizes,
+ attributes={},
+ loc=loc,
+ ip=ip,
+ )
+ )
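+
+# A short sketch of the mixed static/dynamic form (a hedged illustration;
+# `dyn` is an index-typed Value in scope):
+#   t = EmptyOp([2, dyn, 4], F32Type.get())
+# builds a `tensor.empty` with result type tensor<2x?x4xf32> and a single
+# dynamic-size operand.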
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b020ad35fcf062f..a33b096675f1a7b 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -4,4 +4,189 @@
from .._transform_enum_gen import *
from .._transform_ops_gen import *
+from .._transform_ops_gen import _Dialect
from ..._mlir_libs._mlirDialectsTransform import *
+
+try:
+ from ...ir import *
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class CastOp(CastOp):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
+
+
+_ApplyPatternsOp = ApplyPatternsOp
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyPatternsOp(_ApplyPatternsOp):
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ loc=None,
+ ip=None,
+ ):
+        operands = [_get_op_result_or_value(target)]
+ super(_ApplyPatternsOp, self).__init__(
+ self.build_generic(
+ attributes={},
+ results=[],
+ operands=operands,
+ successors=None,
+ regions=None,
+ loc=loc,
+ ip=ip,
+ )
+ )
+ self.regions[0].blocks.append()
+
+ @property
+ def patterns(self) -> Block:
+ return self.regions[0].blocks[0]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GetParentOp(GetParentOp):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ isolated_from_above: bool = False,
+ op_name: Optional[str] = None,
+ deduplicate: bool = False,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ isolated_from_above=isolated_from_above,
+ op_name=op_name,
+ deduplicate=deduplicate,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MergeHandlesOp(MergeHandlesOp):
+ def __init__(
+ self,
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ deduplicate: bool = False,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ [_get_op_result_or_value(h) for h in handles],
+ deduplicate=deduplicate,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ReplicateOp(ReplicateOp):
+ def __init__(
+ self,
+ pattern: Union[Operation, Value],
+ handles: Sequence[Union[Operation, Value]],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ [_get_op_result_or_value(h).type for h in handles],
+ _get_op_result_or_value(pattern),
+ [_get_op_result_or_value(h) for h in handles],
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SequenceOp(SequenceOp):
+ def __init__(
+ self,
+ failure_propagation_mode,
+ results: Sequence[Type],
+ target: Union[Operation, Value, Type],
+ extra_bindings: Optional[
+ Union[Sequence[Value], Sequence[Type], Operation, OpView]
+ ] = None,
+ ):
+ root = (
+ _get_op_result_or_value(target)
+ if isinstance(target, (Operation, Value))
+ else None
+ )
+ root_type = root.type if not isinstance(target, Type) else target
+
+ if extra_bindings is None:
+ extra_bindings = []
+ if isinstance(extra_bindings, (Operation, OpView)):
+ extra_bindings = _get_op_results_or_values(extra_bindings)
+
+ extra_binding_types = []
+ if len(extra_bindings) != 0:
+ if isinstance(extra_bindings[0], Type):
+ extra_binding_types = extra_bindings
+ extra_bindings = []
+ else:
+ extra_binding_types = [v.type for v in extra_bindings]
+
+ super().__init__(
+ results_=results,
+ failure_propagation_mode=failure_propagation_mode,
+ root=root,
+ extra_bindings=extra_bindings,
+ )
+        self.regions[0].blocks.append(root_type, *extra_binding_types)
+
+ @property
+ def body(self) -> Block:
+ return self.regions[0].blocks[0]
+
+ @property
+ def bodyTarget(self) -> Value:
+ return self.body.arguments[0]
+
+ @property
+ def bodyExtraArgs(self) -> BlockArgumentList:
+ return self.body.arguments[1:]
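+
+# A typical construction sketch (a hedged illustration; the enum value comes
+# from the generated FailurePropagationMode bindings):
+#   seq = SequenceOp(FailurePropagationMode.Propagate, [], AnyOpType.get())
+#   with InsertionPoint(seq.body):
+#       ...  # transform ops consuming seq.bodyTarget
+#       YieldOp()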
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class YieldOp(YieldOp):
+ def __init__(
+ self,
+ operands: Optional[Union[Operation, Sequence[Value]]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if operands is None:
+ operands = []
+ super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py
index eb77b746cf864fa..485a8a36b6305e9 100644
--- a/mlir/python/mlir/dialects/transform/bufferization.py
+++ b/mlir/python/mlir/dialects/transform/bufferization.py
@@ -3,3 +3,132 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._bufferization_transform_ops_gen import *
+from .._bufferization_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from enum import Enum
+from typing import Optional, overload, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp):
+ """Specialization for EmptyTensorToAllocTensorOp class."""
+
+ @overload
+ def __init__(
+ self,
+ transformed_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None):
+ ...
+
+ def __init__(
+ self,
+        transformed_type_or_target: Union[Operation, OpView, Type, Value],
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(transformed_type_or_target, Type):
+ transformed_type = transformed_type_or_target
+ target = target_or_none
+ else:
+ transformed_type = transform.OperationType.get("bufferization.alloc_tensor")
+ target = transformed_type_or_target
+
+ super().__init__(
+ transformed_type,
+ target,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OneShotBufferizeOp(OneShotBufferizeOp):
+ """Specialization for OneShotBufferizeOp class."""
+
+ @overload
+ def __init__(
+ self,
+ transformed_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ allow_return_allocs_from_loops: Optional[bool] = None,
+ allow_unknown_ops: Optional[bool] = None,
+ bufferize_function_boundaries: Optional[bool] = None,
+ function_boundary_type_conversion: Optional[Enum] = None,
+ memcpy_op: Optional[str] = None,
+ print_conflicts: Optional[bool] = None,
+ test_analysis_only: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ allow_return_allocs_from_loops: Optional[bool] = None,
+ allow_unknown_ops: Optional[bool] = None,
+ bufferize_function_boundaries: Optional[bool] = None,
+ function_boundary_type_conversion: Optional[Enum] = None,
+ memcpy_op: Optional[str] = None,
+ print_conflicts: Optional[bool] = None,
+ test_analysis_only: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+        transformed_type_or_target: Union[Operation, OpView, Type, Value],
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ allow_return_allocs_from_loops: Optional[bool] = None,
+ allow_unknown_ops: Optional[bool] = None,
+ bufferize_function_boundaries: Optional[bool] = None,
+ function_boundary_type_conversion: Optional[Enum] = None,
+ memcpy_op: Optional[str] = None,
+ print_conflicts: Optional[bool] = None,
+ test_analysis_only: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(transformed_type_or_target, Type):
+ transformed_type = transformed_type_or_target
+ target = target_or_none
+ else:
+ transformed_type = transform.AnyOpType.get()
+ target = transformed_type_or_target
+
+ super().__init__(
+ transformed_type,
+ target,
+ allow_return_allocs_from_loops=allow_return_allocs_from_loops,
+ allow_unknown_ops=allow_unknown_ops,
+ bufferize_function_boundaries=bufferize_function_boundaries,
+ function_boundary_type_conversion=function_boundary_type_conversion,
+ memcpy_op=memcpy_op,
+ print_conflicts=print_conflicts,
+ test_analysis_only=test_analysis_only,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py
index 8c3de0de7ea3f19..00cf0840eeae9e1 100644
--- a/mlir/python/mlir/dialects/transform/gpu.py
+++ b/mlir/python/mlir/dialects/transform/gpu.py
@@ -3,3 +3,128 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._gpu_transform_ops_gen import *
+from .._gpu_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union, overload
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MapForallToBlocks(MapForallToBlocks):
+ """Specialization for MapForallToBlocks class."""
+
+ @overload
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+ generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+ generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ result_type_or_target: Union[Operation, OpView, Type, Value],
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
+ generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(result_type_or_target, Type):
+ result_type = result_type_or_target
+ target = target_or_none
+ else:
+ result_type = transform.AnyOpType.get()
+ target = result_type_or_target
+
+ super().__init__(
+ result_type,
+ target,
+ grid_dims=grid_dims,
+ generate_gpu_launch=generate_gpu_launch,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MapNestedForallToThreads(MapNestedForallToThreads):
+ """Specialization for MapNestedForallToThreads class."""
+
+ @overload
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ block_dims: Optional[Sequence[int]] = None,
+ warp_size: Optional[Sequence[int]] = None,
+ sync_after_distribute: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ block_dims: Optional[Sequence[int]] = None,
+ warp_size: Optional[Sequence[int]] = None,
+ sync_after_distribute: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ result_type_or_target: Union[Operation, OpView, Value, Type],
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ block_dims: Optional[Union[Sequence[int], Attribute]] = None,
+ warp_size: Optional[Union[Sequence[int], Attribute]] = None,
+ sync_after_distribute: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(result_type_or_target, Type):
+ result_type = result_type_or_target
+ target = target_or_none
+ else:
+ result_type = result_type_or_target.type
+ target = result_type_or_target
+ super().__init__(
+ result_type,
+ target,
+ block_dims=block_dims,
+ warp_size=warp_size,
+ sync_after_distribute=sync_after_distribute,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
index 86f72788d86c369..6c89025f413839e 100644
--- a/mlir/python/mlir/dialects/transform/loop.py
+++ b/mlir/python/mlir/dialects/transform/loop.py
@@ -3,3 +3,143 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._loop_transform_ops_gen import *
+from .._loop_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GetParentForOp(GetParentForOp):
+ """Extension for GetParentForOp."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ num_loops: Optional[int] = None,
+ ip=None,
+ loc=None,
+ ):
+ if num_loops is None:
+ num_loops = 1
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ num_loops=num_loops,
+ ip=ip,
+ loc=loc,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoopOutlineOp(LoopOutlineOp):
+ """Extension for LoopOutlineOp."""
+
+ def __init__(
+ self,
+ function_type: Type,
+ call_type: Type,
+ target: Union[Operation, Value],
+ *,
+ func_name: Union[str, StringAttr],
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ function_type,
+ call_type,
+ _get_op_result_or_value(target),
+ func_name=(
+ func_name
+ if isinstance(func_name, StringAttr)
+ else StringAttr.get(func_name)
+ ),
+ ip=ip,
+ loc=loc,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoopPeelOp(LoopPeelOp):
+ """Extension for LoopPeelOp."""
+
+ def __init__(
+ self,
+ main_loop_type: Type,
+ remainder_loop_type: Type,
+ target: Union[Operation, Value],
+ *,
+ fail_if_already_divisible: Union[bool, BoolAttr] = False,
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ main_loop_type,
+ remainder_loop_type,
+ _get_op_result_or_value(target),
+ fail_if_already_divisible=(
+ fail_if_already_divisible
+ if isinstance(fail_if_already_divisible, BoolAttr)
+ else BoolAttr.get(fail_if_already_divisible)
+ ),
+ ip=ip,
+ loc=loc,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoopPipelineOp(LoopPipelineOp):
+ """Extension for LoopPipelineOp."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ iteration_interval: Optional[Union[int, IntegerAttr]] = None,
+ read_latency: Optional[Union[int, IntegerAttr]] = None,
+ ip=None,
+ loc=None,
+ ):
+ if iteration_interval is None:
+ iteration_interval = 1
+ if read_latency is None:
+ read_latency = 10
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ iteration_interval=iteration_interval,
+ read_latency=read_latency,
+ ip=ip,
+ loc=loc,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LoopUnrollOp(LoopUnrollOp):
+ """Extension for LoopUnrollOp."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ factor: Union[int, IntegerAttr],
+ ip=None,
+ loc=None,
+ ):
+ super().__init__(
+ _get_op_result_or_value(target),
+ factor=factor,
+ ip=ip,
+ loc=loc,
+ )
diff --git a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py
index 1ff04ef6a60a180..56ea61eb817f89c 100644
--- a/mlir/python/mlir/dialects/transform/memref.py
+++ b/mlir/python/mlir/dialects/transform/memref.py
@@ -3,3 +3,118 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._memref_transform_ops_gen import *
+from .._memref_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
+ """Specialization for MemRefAllocaToGlobalOp class."""
+
+ @overload
+ def __init__(
+ self,
+ get_global_type: Type,
+ global_type: Type,
+ alloca: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
+ ...
+
+ def __init__(
+ self,
+ get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
+ global_type_or_none: Optional[Type] = None,
+ alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(get_global_type_or_alloca, Type):
+ get_global_type = get_global_type_or_alloca
+ global_type = global_type_or_none
+ alloca = alloca_or_none
+ else:
+ get_global_type = transform.AnyOpType.get()
+ global_type = transform.AnyOpType.get()
+ alloca = get_global_type_or_alloca
+
+ super().__init__(
+ get_global_type,
+ global_type,
+ alloca,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MemRefMultiBufferOp(MemRefMultiBufferOp):
+ """Specialization for MemRefMultiBufferOp class."""
+
+ @overload
+ def __init__(
+ self,
+ transformed_type: Type,
+ target: Union[Operation, OpView, Value],
+ factor: Union[int, IntegerAttr],
+ *,
+ skip_analysis: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ factor: Union[int, IntegerAttr],
+ *,
+ skip_analysis: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+        transformed_type_or_target: Union[Operation, OpView, Type, Value],
+        target_or_factor: Optional[Union[int, IntegerAttr, Operation, OpView, Value]] = None,
+ factor_or_none: Optional[Union[int, IntegerAttr]] = None,
+ *,
+ skip_analysis: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(transformed_type_or_target, Type):
+ transformed_type = transformed_type_or_target
+ target = target_or_factor
+ factor = factor_or_none
+ else:
+ transformed_type = transform.AnyOpType.get()
+ target = transformed_type_or_target
+ factor = target_or_factor
+
+ super().__init__(
+ transformed_type,
+ target,
+ factor,
+ skip_analysis=skip_analysis,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py
index b1515287a3f1ff0..bb5fa7ffd306583 100644
--- a/mlir/python/mlir/dialects/transform/pdl.py
+++ b/mlir/python/mlir/dialects/transform/pdl.py
@@ -3,3 +3,53 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._transform_pdl_extension_ops_gen import *
+from .._transform_pdl_extension_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from .._ods_common import (
+ get_op_result_or_value as _get_op_result_or_value,
+ get_op_results_or_values as _get_op_results_or_values,
+ _cext as _ods_cext,
+ )
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class PDLMatchOp(PDLMatchOp):
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ pattern_name: Union[Attribute, str],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ result_type,
+ _get_op_result_or_value(target),
+ pattern_name,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class WithPDLPatternsOp(WithPDLPatternsOp):
+ def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
+ root = _get_op_result_or_value(target) if not isinstance(target, Type) else None
+ root_type = target if isinstance(target, Type) else root.type
+ super().__init__(root=root, loc=loc, ip=ip)
+ self.regions[0].blocks.append(root_type)
+
+ @property
+ def body(self) -> Block:
+ return self.regions[0].blocks[0]
+
+ @property
+ def bodyTarget(self) -> Value:
+ return self.body.arguments[0]
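+
+# A usage sketch (a hedged illustration; "my_pattern" is a hypothetical
+# pdl.pattern declared inside the body):
+#   with_pdl = WithPDLPatternsOp(transform.AnyOpType.get())
+#   with InsertionPoint(with_pdl.body):
+#       ...  # pdl.pattern "my_pattern" and a nested transform sequence
+#       matched = PDLMatchOp(transform.AnyOpType.get(),
+#                            with_pdl.bodyTarget, "my_pattern")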
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index cb3812301dbd4b5..284c93823acbd34 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -3,4 +3,777 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._structured_transform_ops_gen import *
+from .._structured_transform_ops_gen import _Dialect
from .._structured_transform_enum_gen import *
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Any, List, Optional, Sequence, Tuple, Union, overload
+
+StaticIntLike = Union[int, IntegerAttr]
+ValueLike = Union[Operation, OpView, Value]
+MixedInt = Union[StaticIntLike, ValueLike]
+
+IntOrAttrList = Sequence[Union[IntegerAttr, int]]
+OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
+
+BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
+OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
+
+MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
+
+DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]]
+
+
+def _dispatch_dynamic_index_list(
+ indices: Union[DynamicIndexList, ArrayAttr],
+) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]:
+ """Dispatches a list of indices to the appropriate form.
+
+ This is similar to the custom `DynamicIndexList` directive upstream:
+ provided indices may be in the form of dynamic SSA values or static values,
+    and they may be scalable (i.e., given as a singleton list) or not. This
+    function dispatches each index into its respective form and collects the
+    SSA values and the static indices into separate lists.
+ """
+ dynamic_indices = []
+ static_indices = [ShapedType.get_dynamic_size()] * len(indices)
+ scalable_indices = [False] * len(indices)
+
+ # ArrayAttr: Extract index values.
+ if isinstance(indices, ArrayAttr):
+        indices = list(indices)
+
+ def process_nonscalable_index(i, index):
+ """Processes any form of non-scalable index.
+
+ Returns False if the given index was scalable and thus remains
+ unprocessed; True otherwise.
+ """
+ if isinstance(index, int):
+ static_indices[i] = index
+ elif isinstance(index, IntegerAttr):
+ static_indices[i] = index.value # pytype: disable=attribute-error
+ elif isinstance(index, (Operation, Value, OpView)):
+ dynamic_indices.append(index)
+ else:
+ return False
+ return True
+
+ # Process each index at a time.
+ for i, index in enumerate(indices):
+ if not process_nonscalable_index(i, index):
+ # If it wasn't processed, it must be a scalable index, which is
+ # provided as a Sequence of one value, so extract and process that.
+ scalable_indices[i] = True
+ assert len(index) == 1
+ ret = process_nonscalable_index(i, index[0])
+ assert ret
+
+ return dynamic_indices, static_indices, scalable_indices
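+
+# For example (a hedged illustration; `v` is an index-typed Value):
+#   _dispatch_dynamic_index_list([4, v, [8]]) returns
+#   ([v], [4, kDynamic, 8], [False, False, True]),
+# where kDynamic stands for ShapedType.get_dynamic_size() and the singleton
+# list [8] marks a scalable index.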
+
+
+# Dispatches `MixedValues` that all represent integers in various forms into
+# the following three categories:
+#   - `dynamic_values`: a list of `Value`s, potentially from op results;
+#   - `packed_values`: a value handle, potentially from an op result, associated
+#     with one or more payload operations of integer type;
+#   - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
+#     `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
+# If the input is in the `packed_values` form, only that result is set and the
+# other two are empty. Otherwise, the input can be a mix of the other two forms,
+# and for each dynamic value, a sentinel is added to `static_values` in its
+# place.
+def _dispatch_mixed_values(
+ values: MixedValues,
+) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]:
+ dynamic_values = []
+ packed_values = None
+ static_values = None
+ if isinstance(values, ArrayAttr):
+ static_values = values
+ elif isinstance(values, (Operation, Value, OpView)):
+ packed_values = values
+ else:
+ static_values = []
+ for size in values or []:
+ if isinstance(size, int):
+ static_values.append(size)
+ else:
+ static_values.append(ShapedType.get_dynamic_size())
+ dynamic_values.append(size)
+ static_values = DenseI64ArrayAttr.get(static_values)
+
+ return (dynamic_values, packed_values, static_values)
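+
+# For example (a hedged illustration; `v` is an index-typed Value and `h` is
+# an op result handle):
+#   _dispatch_mixed_values([2, v, 4]) returns ([v], None, attr), with `attr` a
+#   DenseI64ArrayAttr holding [2, kDynamic, 4];
+#   _dispatch_mixed_values(h) returns ([], h, None).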
+
+
+def _get_value_or_attribute_value(
+    value_or_attr: Union[Any, Attribute, ArrayAttr]
+) -> Any:
+ if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
+ return value_or_attr.value
+ if isinstance(value_or_attr, ArrayAttr):
+ return _get_value_list(value_or_attr)
+ return value_or_attr
+
+
+def _get_value_list(
+    sequence_or_array_attr: Union[Sequence[Any], ArrayAttr]
+) -> Sequence[Any]:
+ return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
+
+
+def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
+ if values is None:
+ return None
+
+ # Turn into a Python list of Python ints.
+ values = _get_value_list(values)
+
+ # Make an ArrayAttr of IntegerAttrs out of it.
+ return ArrayAttr.get(
+ [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
+ )
+
+
+def _get_int_array_array_attr(
+ values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
+) -> ArrayAttr:
+ """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
+
+ The input has to be a collection of collection of integers, where any
+ Python Sequence and ArrayAttr are admissible collections and Python ints and
+ any IntegerAttr are admissible integers. Both levels of collections are
+ turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
+ If the input is None, an empty ArrayAttr is returned.
+ """
+ if values is None:
+ return None
+
+ # Make sure the outer level is a list.
+ values = _get_value_list(values)
+
+ # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
+ # Sequences. Make sure the nested values are all lists.
+ values = [_get_value_list(nested) for nested in values]
+
+ # Turn each nested list into an ArrayAttr.
+ values = [_get_int_array_attr(nested) for nested in values]
+
+ # Turn the outer list into an ArrayAttr.
+ return ArrayAttr.get(values)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class BufferizeToAllocationOp(BufferizeToAllocationOp):
+ """Specialization for BufferizeToAllocationOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ memory_space: Optional[Union[int, str, Attribute]] = None,
+ memcpy_op: Optional[str] = None,
+ alloc_op: Optional[str] = None,
+ bufferize_destination_only: Optional[bool] = None,
+ loc=None,
+ ip=None,
+ ):
+ # No other types are allowed, so hard-code those here.
+ allocated_buffer_type = transform.AnyValueType.get()
+ new_ops_type = transform.AnyOpType.get()
+
+ if isinstance(memory_space, int):
+ memory_space = str(memory_space)
+ if isinstance(memory_space, str):
+ memory_space = Attribute.parse(memory_space)
+
+ super().__init__(
+ allocated_buffer_type,
+ new_ops_type,
+ target,
+ memory_space=memory_space,
+ memcpy_op=memcpy_op,
+ alloc_op=alloc_op,
+ bufferize_destination_only=bufferize_destination_only,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DecomposeOp(DecomposeOp):
+ """Specialization for DecomposeOp class."""
+
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ transformed_type = transform.AnyOpType.get()
+ super().__init__(transformed_type, target, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class FuseIntoContainingOp(FuseIntoContainingOp):
+ """Specialization for FuseIntoContainingOp class."""
+
+ @overload
+ def __init__(
+ self,
+ fused_op_type: Type,
+ new_containing_op_type: Type,
+ producer_op: Union[Operation, OpView, Value],
+ containing_op: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ producer_op: Union[Operation, OpView, Value],
+ containing_op: Union[Operation, OpView, Value],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
+ new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
+ producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(fused_op_type_or_producer_op, Type):
+ if not isinstance(new_containing_op_type_or_containing_op, Type):
+ raise TypeError(
+ "If 'fused_op_type_or_producer_op' is a type, then "
+ "'new_containing_op_type_or_containing_op' is expected "
+ "to be one as well."
+ )
+ fused_op_type = fused_op_type_or_producer_op
+ new_containing_op_type = new_containing_op_type_or_containing_op
+ producer_op = producer_op_or_none
+ containing_op = containing_op_or_none
+ else:
+ fused_op_type = transform.AnyOpType.get()
+ new_containing_op_type = transform.AnyOpType.get()
+ producer_op = fused_op_type_or_producer_op
+ containing_op = new_containing_op_type_or_containing_op
+
+ super().__init__(
+ fused_op_type,
+ new_containing_op_type,
+ producer_op,
+ containing_op,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GeneralizeOp(GeneralizeOp):
+ """Specialization for GeneralizeOp class."""
+
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ transformed_type = transform.AnyOpType.get()
+ super().__init__(transformed_type, target, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InterchangeOp(InterchangeOp):
+ """Specialization for InterchangeOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ iterator_interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ transformed_type = transform.AnyOpType.get()
+ super().__init__(
+ transformed_type,
+ target,
+ iterator_interchange=iterator_interchange,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MapCopyToThreadsOp(MapCopyToThreadsOp):
+ """Specialization for MapCopyToThreadsOp class."""
+
+ @overload
+ def __init__(
+ self,
+ forall_op_type: Type,
+ tiled_op_type: Type,
+ target: Union[Operation, OpView, Value],
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ forall_op_type_or_target: Union[Operation, OpView, Type, Value],
+ tiled_op_type_or_none: Optional[Type] = None,
+ target_or_none: Optional[Union[Operation, OpView, Value]] = None,
+ *,
+ total_num_threads: Union[int, IntegerAttr],
+ desired_bit_alignment: Union[int, IntegerAttr],
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(forall_op_type_or_target, Type):
+ forall_op_type = forall_op_type_or_target
+ tiled_op_type = tiled_op_type_or_none
+ target = target_or_none
+ else:
+ forall_op_type = transform.AnyOpType.get()
+ tiled_op_type = transform.AnyOpType.get()
+ target = forall_op_type_or_target
+
+ super().__init__(
+ forall_op_type,
+ tiled_op_type,
+ target,
+ total_num_threads=total_num_threads,
+ desired_bit_alignment=desired_bit_alignment,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class VectorizeOp(VectorizeOp):
+ """Specialization for VectorizeOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ *,
+ vectorize_nd_extract: Optional[bool] = None,
+ scalable_sizes: OptionalBoolList = None,
+ static_vector_sizes: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ if (
+ scalable_sizes is None
+ and static_vector_sizes is None
+ and vector_sizes is None
+ ):
+ dynamic_vector_sizes = []
+ elif scalable_sizes is None and static_vector_sizes is None:
+ (
+ dynamic_vector_sizes,
+ static_vector_sizes,
+ scalable_sizes,
+ ) = _dispatch_dynamic_index_list(vector_sizes)
+ elif scalable_sizes is None or static_vector_sizes is None:
+ raise TypeError(
+ "'scalable_sizes' and 'static_vector_sizes' must either both "
+ "be given explicitly or both be given as part of 'vector_sizes'."
+ )
+ else:
+ dynamic_vector_sizes = vector_sizes
+
+ super().__init__(
+ target,
+ vector_sizes=dynamic_vector_sizes,
+ static_vector_sizes=static_vector_sizes,
+ scalable_sizes=scalable_sizes,
+ vectorize_nd_extract=vectorize_nd_extract,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MatchOp(MatchOp):
+ """Specialization for MatchOp class."""
+
+ @overload
+ @classmethod
+ def match_op_names(
+ cls,
+ target: Union[Operation, Value],
+ names: Union[str, Sequence[str]],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ @classmethod
+ def match_op_names(
+ cls,
+ result_type: Type,
+ target: Union[Operation, Value],
+ names: Union[str, Sequence[str]],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @classmethod
+ def match_op_names(
+ cls,
+ result_type_or_target: Union[Type, Operation, Value],
+ target_or_names: Union[Operation, Value, Sequence[str], str],
+ names_or_none: Optional[Union[Sequence[str], str]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(result_type_or_target, Type):
+ result_type = result_type_or_target
+ target = target_or_names
+ names = names_or_none
+ else:
+ result_type = transform.AnyOpType.get()
+ target = result_type_or_target
+ names = target_or_names
+
+ if isinstance(names, str):
+ names = [names]
+
+ return cls(
+ result_type,
+ target,
+            ops=ArrayAttr.get([StringAttr.get(s) for s in names]),
+ loc=loc,
+ ip=ip,
+ )
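+
+# Overload sketch (a hedged illustration): the two calls below are equivalent,
+# the first defaulting the result type to transform.AnyOpType:
+#   MatchOp.match_op_names(target, ["linalg.matmul"])
+#   MatchOp.match_op_names(transform.AnyOpType.get(), target, ["linalg.matmul"])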
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MultiTileSizesOp(MultiTileSizesOp):
+ """Specialization for MultiTileSizesOp class."""
+
+ def __init__(
+ self,
+ result_type: Type,
+ target: Union[Operation, Value],
+ *,
+ dimension: Union[int, IntegerAttr],
+ target_size: Union[int, IntegerAttr],
+        divisor: Optional[Union[int, IntegerAttr]] = None,
+ loc=None,
+ ip=None,
+ ):
+ super().__init__(
+ result_type,
+ result_type,
+ result_type,
+ target,
+ dimension=dimension,
+ target_size=target_size,
+ divisor=divisor,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class PadOp(PadOp):
+ """Specialization for PadOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ *,
+ padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
+ padding_dimensions: OptionalIntList = None,
+ pad_to_multiple_of: OptionalIntList = None,
+ pack_paddings: OptionalIntList = None,
+ transpose_paddings: Optional[
+ Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
+ ] = None,
+ copy_back_op: Optional[Union[str, StringAttr]] = None,
+ loc=None,
+ ip=None,
+ ):
+ transpose_paddings = _get_int_array_array_attr(transpose_paddings)
+
+ any_op_type = transform.AnyOpType.get()
+ super().__init__(
+ any_op_type,
+ any_op_type,
+ any_op_type,
+ target,
+ padding_values=padding_values,
+ padding_dimensions=padding_dimensions,
+ pad_to_multiple_of=pad_to_multiple_of,
+ pack_paddings=pack_paddings,
+ transpose_paddings=transpose_paddings,
+ copy_back_op=copy_back_op,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ScalarizeOp(ScalarizeOp):
+ """Specialization for ScalarizeOp class."""
+
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ result_type = transform.AnyOpType.get()
+ super().__init__(result_type, target, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SplitOp(SplitOp):
+ """Specialization for SplitOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ dimension: Union[int, Attribute],
+ split_point: Union[int, Operation, Value, Attribute],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(split_point, int):
+ static_split_point = split_point
+ dynamic_split_point = None
+ else:
+ static_split_point = ShapedType.get_dynamic_size()
+ dynamic_split_point = split_point
+
+ super().__init__(
+ target.type,
+ target.type,
+ target,
+ dimension=dimension,
+ static_split_point=static_split_point,
+ dynamic_split_point=dynamic_split_point,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TileUsingForOp(TileUsingForOp):
+ """Specialization for TileUsingForOp class."""
+
+ @overload
+ def __init__(
+ self,
+ loop_types: Union[Type, List[Type]],
+ target: Union[Operation, Value],
+ *,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+ loop_types_or_target: Union[Type, List[Type], Operation, Value],
+ target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+ *,
+ sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
+ interchange: OptionalIntList = None,
+ loc=None,
+ ip=None,
+ ):
+ (
+ dynamic_sizes,
+ static_sizes,
+ scalable_sizes,
+ ) = _dispatch_dynamic_index_list(sizes)
+
+        # A static size of 0 means the dimension is not tiled and produces no
+        # loop, so count only the non-zero entries.
+        num_loops = sum(v if v == 0 else 1 for v in static_sizes)
+
+ if isinstance(loop_types_or_target, (Operation, Value, OpView)):
+ loop_types = [transform.AnyOpType.get()] * num_loops
+ target = loop_types_or_target
+ assert (
+ target_or_none is None
+ ), "Cannot construct TileUsingForOp with two targets."
+ else:
+ loop_types = (
+ ([loop_types_or_target] * num_loops)
+ if isinstance(loop_types_or_target, Type)
+ else loop_types_or_target
+ )
+ target = target_or_none
+
+ super().__init__(
+ target.type,
+ loop_types,
+ target,
+ dynamic_sizes=dynamic_sizes,
+ static_sizes=static_sizes,
+ interchange=interchange,
+ scalable_sizes=scalable_sizes,
+ loc=loc,
+ ip=ip,
+ )
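+
+# A tiling sketch (a hedged illustration; `target` is a handle to a payload
+# op such as a linalg.matmul):
+#   tiled = TileUsingForOp(target, sizes=[32, 32, 0])
+# tiles the first two dimensions by 32; the 0 leaves the third dimension
+# untiled, so only two loop handles are produced.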
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class TileUsingForallOp(TileUsingForallOp):
+ """Specialization for TileUsingForallOp class."""
+
+ @overload
+ def __init__(
+ self,
+ loops_type: Type,
+ tiled_op_type: Type,
+ target: Union[Operation, Value, OpView],
+ *,
+ num_threads: Optional[MixedValues] = None,
+ tile_sizes: MixedValues = None,
+ mapping=None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, Value, OpView],
+ *,
+ num_threads: Optional[MixedValues] = None,
+ tile_sizes: MixedValues = None,
+ mapping=None,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+        loops_type_or_target: Union[Type, Operation, Value, OpView],
+ tiled_op_type_or_none: Optional[Type] = None,
+ target_or_none: Optional[Union[Operation, Value, OpView]] = None,
+ *,
+ num_threads: MixedValues = None,
+ tile_sizes: MixedValues = None,
+ mapping=None,
+ loc=None,
+ ip=None,
+ ):
+        # Leading `Type` arguments are optional; supply default types when omitted.
+ if isinstance(loops_type_or_target, Type):
+ # First overload: type arguments provided.
+ if not isinstance(tiled_op_type_or_none, Type):
+ raise TypeError(
+ "If 'loops_type_or_target' is a type, then "
+ "'tiled_op_type_or_none' is expected to be one as well."
+ )
+ loops_type = loops_type_or_target
+ tiled_op_type = tiled_op_type_or_none
+ target = target_or_none
+ else:
+ # Last overload: type arguments missing.
+ loops_type = transform.AnyOpType.get()
+ tiled_op_type = transform.AnyOpType.get()
+ target = loops_type_or_target
+
+ # Unpack mixed num_threads.
+ (
+ dynamic_num_threads,
+ packed_num_threads,
+ num_threads_attr,
+ ) = _dispatch_mixed_values(num_threads)
+
+ # Unpack mixed tile_sizes.
+ (
+ dynamic_tile_sizes,
+ packed_tile_sizes,
+ tile_sizes_attr,
+ ) = _dispatch_mixed_values(tile_sizes)
+
+ super().__init__(
+ loops_type,
+ tiled_op_type,
+ target=target,
+ tile_sizes=dynamic_tile_sizes,
+ packed_tile_sizes=packed_tile_sizes,
+ static_tile_sizes=tile_sizes_attr,
+ num_threads=dynamic_num_threads,
+ packed_num_threads=packed_num_threads,
+ static_num_threads=num_threads_attr,
+ mapping=mapping,
+ loc=loc,
+ ip=ip,
+ )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
+ """Specialization for VectorizeChildrenAndApplyPatternsOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ *,
+ disable_multi_reduction_to_contract_patterns: bool = False,
+ disable_transfer_permutation_map_lowering_patterns: bool = False,
+ vectorize_nd_extract: bool = False,
+ vectorize_padding: bool = False,
+ loc=None,
+ ip=None,
+ ):
+ transformed_type = transform.AnyOpType.get()
+ super().__init__(
+ transformed_type,
+ target,
+ disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
+ disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
+ vectorize_nd_extract=vectorize_nd_extract,
+ vectorize_padding=vectorize_padding,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py
index bf52255b3df7145..4eb30398f087212 100644
--- a/mlir/python/mlir/dialects/transform/tensor.py
+++ b/mlir/python/mlir/dialects/transform/tensor.py
@@ -3,3 +3,67 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._tensor_transform_ops_gen import *
+from .._tensor_transform_ops_gen import _Dialect
+
+try:
+ from ...ir import *
+ from ...dialects import transform
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, overload, Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class MakeLoopIndependentOp(MakeLoopIndependentOp):
+ """Specialization for MakeLoopIndependentOp class."""
+
+ @overload
+ def __init__(
+ self,
+ transformed_type: Type,
+ target: Union[Operation, OpView, Value],
+ num_loops: Union[int, IntegerAttr],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ @overload
+ def __init__(
+ self,
+ target: Union[Operation, OpView, Value],
+ num_loops: Union[int, IntegerAttr],
+ *,
+ loc=None,
+ ip=None,
+ ):
+ ...
+
+ def __init__(
+ self,
+        transformed_type_or_target: Union[Operation, OpView, Type, Value],
+        target_or_num_loops: Optional[Union[int, IntegerAttr, Operation, OpView, Value]] = None,
+ num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(transformed_type_or_target, Type):
+ transformed_type = transformed_type_or_target
+ target = target_or_num_loops
+ num_loops = num_loops_or_none
+ else:
+ transformed_type = transform.AnyOpType.get()
+ target = transformed_type_or_target
+ num_loops = target_or_num_loops
+
+ super().__init__(
+ transformed_type,
+ target,
+ num_loops,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 2c81538b7b40433..9eaf2f89530c8f1 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,7 @@ constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
_ods_ir = _ods_cext.ir
try:
@@ -62,7 +62,6 @@ from ._{0}_ops_gen import _Dialect
/// {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
-@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
OPERATION_NAME = "{1}"
)Py";
@@ -850,9 +849,6 @@ populateBuilderRegions(const Operator &op,
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) {
// If we are asked to skip default builders, comply.
- if (op.skipDefaultBuilders())
- return {};
-
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
llvm::SmallVector<std::string> operandArgNames;
@@ -985,9 +981,6 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
static void emitValueBuilder(const Operator &op,
llvm::SmallVector<std::string> functionArgs,
raw_ostream &os) {
- // If we are asked to skip default builders, comply.
- if (op.skipDefaultBuilders())
- return;
auto name = sanitizeName(op.getOperationName());
iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
// Params with (possibly) default args.