[Mlir-commits] [mlir] [mlir][python] value casting (PR #68763)

Maksim Levental llvmlistbot at llvm.org
Thu Oct 19 13:49:14 PDT 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/68763

From dba43135ce1b43e1ef274be4686d853b5c24ab2c Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 01:16:19 -0500
Subject: [PATCH 1/8] delete stubs

---
 mlir/python/mlir/dialects/_linalg_ops_ext.py  | 47 ---------------
 mlir/python/mlir/dialects/_ods_common.py      | 59 -------------------
 mlir/python/mlir/dialects/arith.py            |  6 --
 mlir/python/mlir/dialects/bufferization.py    |  6 --
 mlir/python/mlir/dialects/builtin.py          |  5 --
 mlir/python/mlir/dialects/func.py             |  5 --
 mlir/python/mlir/dialects/memref.py           |  5 --
 mlir/python/mlir/dialects/ml_program.py       |  5 --
 mlir/python/mlir/dialects/pdl.py              |  6 --
 mlir/python/mlir/dialects/scf.py              | 43 --------------
 mlir/python/mlir/dialects/tensor.py           |  5 --
 .../mlir/dialects/transform/__init__.py       |  7 ---
 .../mlir/dialects/transform/bufferization.py  |  5 --
 mlir/python/mlir/dialects/transform/gpu.py    |  5 --
 mlir/python/mlir/dialects/transform/loop.py   |  5 --
 mlir/python/mlir/dialects/transform/memref.py |  5 --
 mlir/python/mlir/dialects/transform/pdl.py    |  5 --
 .../mlir/dialects/transform/structured.py     |  6 --
 mlir/python/mlir/dialects/transform/tensor.py |  5 --
 19 files changed, 235 deletions(-)
 delete mode 100644 mlir/python/mlir/dialects/_linalg_ops_ext.py
 delete mode 100644 mlir/python/mlir/dialects/arith.py
 delete mode 100644 mlir/python/mlir/dialects/bufferization.py
 delete mode 100644 mlir/python/mlir/dialects/builtin.py
 delete mode 100644 mlir/python/mlir/dialects/func.py
 delete mode 100644 mlir/python/mlir/dialects/memref.py
 delete mode 100644 mlir/python/mlir/dialects/ml_program.py
 delete mode 100644 mlir/python/mlir/dialects/pdl.py
 delete mode 100644 mlir/python/mlir/dialects/scf.py
 delete mode 100644 mlir/python/mlir/dialects/tensor.py
 delete mode 100644 mlir/python/mlir/dialects/transform/__init__.py
 delete mode 100644 mlir/python/mlir/dialects/transform/bufferization.py
 delete mode 100644 mlir/python/mlir/dialects/transform/gpu.py
 delete mode 100644 mlir/python/mlir/dialects/transform/loop.py
 delete mode 100644 mlir/python/mlir/dialects/transform/memref.py
 delete mode 100644 mlir/python/mlir/dialects/transform/pdl.py
 delete mode 100644 mlir/python/mlir/dialects/transform/structured.py
 delete mode 100644 mlir/python/mlir/dialects/transform/tensor.py

diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
deleted file mode 100644
index 3f6d854ca3e2b14..000000000000000
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ /dev/null
@@ -1,47 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
-    from typing import Optional, Sequence, Union
-    from ..ir import *
-    from ._ods_common import get_default_loc_context
-    from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-
-
-def isa(cls: Type, ty: Type):
-    try:
-        cls(ty)
-        return True
-    except ValueError:
-        return False
-
-
-class StructuredOpMixin:
-    """All structured ops use the same mixin class."""
-
-    def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
-        super().__init__(
-            self.build_generic(
-                results=list(results),
-                operands=[list(inputs), list(outputs)],
-                loc=loc,
-                ip=ip,
-            )
-        )
-
-
-def select_opview_mixin(parent_opview_cls):
-    # TODO: This shouldn't be a heuristic: we should have a way to annotate
-    # the OpView to note that it is a structured op.
-    if (
-        "__init__" not in parent_opview_cls.__dict__
-        and hasattr(parent_opview_cls, "inputs")
-        and hasattr(parent_opview_cls, "outputs")
-        and hasattr(parent_opview_cls, "result_tensors")
-    ):
-        return StructuredOpMixin
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 895c3228139b392..9cca7d659ec8cb3 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -9,7 +9,6 @@
 
 __all__ = [
     "equally_sized_accessor",
-    "extend_opview_class",
     "get_default_loc_context",
     "get_op_result_or_value",
     "get_op_results_or_values",
@@ -18,64 +17,6 @@
 ]
 
 
-def extend_opview_class(ext_module):
-    """Decorator to extend an OpView class from an extension module.
-
-    Extension modules can expose various entry-points:
-      Stand-alone class with the same name as a parent OpView class (i.e.
-      "ReturnOp"). A name-based match is attempted first before falling back
-      to a below mechanism.
-
-      def select_opview_mixin(parent_opview_cls):
-        If defined, allows an appropriate mixin class to be selected dynamically
-        based on the parent OpView class. Should return NotImplemented if a
-        decision is not made.
-
-    Args:
-      ext_module: A module from which to locate extensions. Can be None if not
-        available.
-
-    Returns:
-      A decorator that takes an OpView subclass and further extends it as
-      needed.
-    """
-
-    def class_decorator(parent_opview_cls: type):
-        if ext_module is None:
-            return parent_opview_cls
-        mixin_cls = NotImplemented
-        # First try to resolve by name.
-        try:
-            mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
-        except AttributeError:
-            # Fall back to a select_opview_mixin hook.
-            try:
-                select_mixin = getattr(ext_module, "select_opview_mixin")
-            except AttributeError:
-                pass
-            else:
-                mixin_cls = select_mixin(parent_opview_cls)
-
-        if mixin_cls is NotImplemented or mixin_cls is None:
-            return parent_opview_cls
-
-        # Have a mixin_cls. Create an appropriate subclass.
-        try:
-
-            class LocalOpView(mixin_cls, parent_opview_cls):
-                pass
-
-        except TypeError as e:
-            raise TypeError(
-                f"Could not mixin {mixin_cls} into {parent_opview_cls}"
-            ) from e
-        LocalOpView.__name__ = parent_opview_cls.__name__
-        LocalOpView.__qualname__ = parent_opview_cls.__qualname__
-        return LocalOpView
-
-    return class_decorator
-
-
 def segmented_accessor(elements, raw_segments, idx):
     """
     Returns a slice of elements corresponding to the idx-th segment.
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
deleted file mode 100644
index fb13beb63ca66c3..000000000000000
--- a/mlir/python/mlir/dialects/arith.py
+++ /dev/null
@@ -1,6 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._arith_ops_gen import *
-from ._arith_enum_gen import *
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
deleted file mode 100644
index 759b6aa24a9ff73..000000000000000
--- a/mlir/python/mlir/dialects/bufferization.py
+++ /dev/null
@@ -1,6 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._bufferization_ops_gen import *
-from ._bufferization_enum_gen import *
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
deleted file mode 100644
index 30279e1611f99aa..000000000000000
--- a/mlir/python/mlir/dialects/builtin.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._builtin_ops_gen import *
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
deleted file mode 100644
index dc554c22173bc60..000000000000000
--- a/mlir/python/mlir/dialects/func.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._func_ops_gen import *
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
deleted file mode 100644
index 3afb6a70cb9e0db..000000000000000
--- a/mlir/python/mlir/dialects/memref.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._memref_ops_gen import *
diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py
deleted file mode 100644
index a654529b4bb8843..000000000000000
--- a/mlir/python/mlir/dialects/ml_program.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._ml_program_ops_gen import *
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
deleted file mode 100644
index dda2b7d6521965f..000000000000000
--- a/mlir/python/mlir/dialects/pdl.py
+++ /dev/null
@@ -1,6 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._pdl_ops_gen import *
-from .._mlir_libs._mlirDialectsPDL import *
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
deleted file mode 100644
index 49685ca2271fc61..000000000000000
--- a/mlir/python/mlir/dialects/scf.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from typing import Optional, Sequence
-
-from ._scf_ops_gen import *
-from .arith import constant
-from ..ir import *
-
-
-def for_(
-    start,
-    stop=None,
-    step=None,
-    iter_args: Optional[Sequence[Value]] = None,
-    *,
-    loc=None,
-    ip=None,
-):
-    if step is None:
-        step = 1
-    if stop is None:
-        stop = start
-        start = 0
-    params = [start, stop, step]
-    for i, p in enumerate(params):
-        if isinstance(p, int):
-            p = constant(p)
-        elif isinstance(p, float):
-            raise ValueError(f"{p=} must be int.")
-        params[i] = p
-
-    for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
-    iv = for_op.induction_variable
-    iter_args = tuple(for_op.inner_iter_args)
-    with InsertionPoint(for_op.body):
-        if len(iter_args) > 1:
-            yield iv, iter_args
-        elif len(iter_args) == 1:
-            yield iv, iter_args[0]
-        else:
-            yield iv
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
deleted file mode 100644
index 26edf6b6436dad5..000000000000000
--- a/mlir/python/mlir/dialects/tensor.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from ._tensor_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
deleted file mode 100644
index b020ad35fcf062f..000000000000000
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._transform_enum_gen import *
-from .._transform_ops_gen import *
-from ..._mlir_libs._mlirDialectsTransform import *
diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py
deleted file mode 100644
index eb77b746cf864fa..000000000000000
--- a/mlir/python/mlir/dialects/transform/bufferization.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._bufferization_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py
deleted file mode 100644
index 8c3de0de7ea3f19..000000000000000
--- a/mlir/python/mlir/dialects/transform/gpu.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._gpu_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
deleted file mode 100644
index 86f72788d86c369..000000000000000
--- a/mlir/python/mlir/dialects/transform/loop.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._loop_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py
deleted file mode 100644
index 1ff04ef6a60a180..000000000000000
--- a/mlir/python/mlir/dialects/transform/memref.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._memref_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py
deleted file mode 100644
index b1515287a3f1ff0..000000000000000
--- a/mlir/python/mlir/dialects/transform/pdl.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._transform_pdl_extension_ops_gen import *
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
deleted file mode 100644
index cb3812301dbd4b5..000000000000000
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ /dev/null
@@ -1,6 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._structured_transform_ops_gen import *
-from .._structured_transform_enum_gen import *
diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py
deleted file mode 100644
index bf52255b3df7145..000000000000000
--- a/mlir/python/mlir/dialects/transform/tensor.py
+++ /dev/null
@@ -1,5 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-from .._tensor_transform_ops_gen import *

From 9a32dc6a09759084bcd3443c2dcd873c49004592 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 01:21:39 -0500
Subject: [PATCH 2/8] rename

---
 mlir/python/mlir/dialects/{_arith_ops_ext.py => arith.py}         | 0
 .../mlir/dialects/{_bufferization_ops_ext.py => bufferization.py} | 0
 mlir/python/mlir/dialects/{_builtin_ops_ext.py => builtin.py}     | 0
 mlir/python/mlir/dialects/{_func_ops_ext.py => func.py}           | 0
 mlir/python/mlir/dialects/{_memref_ops_ext.py => memref.py}       | 0
 .../mlir/dialects/{_ml_program_ops_ext.py => ml_program.py}       | 0
 mlir/python/mlir/dialects/{_pdl_ops_ext.py => pdl.py}             | 0
 mlir/python/mlir/dialects/{_scf_ops_ext.py => scf.py}             | 0
 mlir/python/mlir/dialects/{_tensor_ops_ext.py => tensor.py}       | 0
 .../dialects/{_transform_ops_ext.py => transform/__init__.py}     | 0
 .../bufferization.py}                                             | 0
 .../mlir/dialects/{_gpu_transform_ops_ext.py => transform/gpu.py} | 0
 .../dialects/{_loop_transform_ops_ext.py => transform/loop.py}    | 0
 .../{_memref_transform_ops_ext.py => transform/memref.py}         | 0
 .../{_transform_pdl_extension_ops_ext.py => transform/pdl.py}     | 0
 .../{_structured_transform_ops_ext.py => transform/structured.py} | 0
 .../{_tensor_transform_ops_ext.py => transform/tensor.py}         | 0
 17 files changed, 0 insertions(+), 0 deletions(-)
 rename mlir/python/mlir/dialects/{_arith_ops_ext.py => arith.py} (100%)
 rename mlir/python/mlir/dialects/{_bufferization_ops_ext.py => bufferization.py} (100%)
 rename mlir/python/mlir/dialects/{_builtin_ops_ext.py => builtin.py} (100%)
 rename mlir/python/mlir/dialects/{_func_ops_ext.py => func.py} (100%)
 rename mlir/python/mlir/dialects/{_memref_ops_ext.py => memref.py} (100%)
 rename mlir/python/mlir/dialects/{_ml_program_ops_ext.py => ml_program.py} (100%)
 rename mlir/python/mlir/dialects/{_pdl_ops_ext.py => pdl.py} (100%)
 rename mlir/python/mlir/dialects/{_scf_ops_ext.py => scf.py} (100%)
 rename mlir/python/mlir/dialects/{_tensor_ops_ext.py => tensor.py} (100%)
 rename mlir/python/mlir/dialects/{_transform_ops_ext.py => transform/__init__.py} (100%)
 rename mlir/python/mlir/dialects/{_bufferization_transform_ops_ext.py => transform/bufferization.py} (100%)
 rename mlir/python/mlir/dialects/{_gpu_transform_ops_ext.py => transform/gpu.py} (100%)
 rename mlir/python/mlir/dialects/{_loop_transform_ops_ext.py => transform/loop.py} (100%)
 rename mlir/python/mlir/dialects/{_memref_transform_ops_ext.py => transform/memref.py} (100%)
 rename mlir/python/mlir/dialects/{_transform_pdl_extension_ops_ext.py => transform/pdl.py} (100%)
 rename mlir/python/mlir/dialects/{_structured_transform_ops_ext.py => transform/structured.py} (100%)
 rename mlir/python/mlir/dialects/{_tensor_transform_ops_ext.py => transform/tensor.py} (100%)

diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/arith.py
similarity index 100%
rename from mlir/python/mlir/dialects/_arith_ops_ext.py
rename to mlir/python/mlir/dialects/arith.py
diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/bufferization.py
similarity index 100%
rename from mlir/python/mlir/dialects/_bufferization_ops_ext.py
rename to mlir/python/mlir/dialects/bufferization.py
diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/builtin.py
similarity index 100%
rename from mlir/python/mlir/dialects/_builtin_ops_ext.py
rename to mlir/python/mlir/dialects/builtin.py
diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/func.py
similarity index 100%
rename from mlir/python/mlir/dialects/_func_ops_ext.py
rename to mlir/python/mlir/dialects/func.py
diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/memref.py
similarity index 100%
rename from mlir/python/mlir/dialects/_memref_ops_ext.py
rename to mlir/python/mlir/dialects/memref.py
diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/ml_program.py
similarity index 100%
rename from mlir/python/mlir/dialects/_ml_program_ops_ext.py
rename to mlir/python/mlir/dialects/ml_program.py
diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/pdl.py
similarity index 100%
rename from mlir/python/mlir/dialects/_pdl_ops_ext.py
rename to mlir/python/mlir/dialects/pdl.py
diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/scf.py
similarity index 100%
rename from mlir/python/mlir/dialects/_scf_ops_ext.py
rename to mlir/python/mlir/dialects/scf.py
diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/tensor.py
similarity index 100%
rename from mlir/python/mlir/dialects/_tensor_ops_ext.py
rename to mlir/python/mlir/dialects/tensor.py
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/__init__.py
similarity index 100%
rename from mlir/python/mlir/dialects/_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/__init__.py
diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/bufferization.py
similarity index 100%
rename from mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/bufferization.py
diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/gpu.py
similarity index 100%
rename from mlir/python/mlir/dialects/_gpu_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/gpu.py
diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/loop.py
similarity index 100%
rename from mlir/python/mlir/dialects/_loop_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/loop.py
diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/memref.py
similarity index 100%
rename from mlir/python/mlir/dialects/_memref_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/memref.py
diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/transform/pdl.py
similarity index 100%
rename from mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
rename to mlir/python/mlir/dialects/transform/pdl.py
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/structured.py
similarity index 100%
rename from mlir/python/mlir/dialects/_structured_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/structured.py
diff --git a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py b/mlir/python/mlir/dialects/transform/tensor.py
similarity index 100%
rename from mlir/python/mlir/dialects/_tensor_transform_ops_ext.py
rename to mlir/python/mlir/dialects/transform/tensor.py

From 969b91a26ca3518f6e9f2952eb4feb4fc90f3267 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 01:23:57 -0500
Subject: [PATCH 3/8] fix

---
 mlir/lib/Bindings/Python/Globals.h            |  2 +-
 mlir/lib/Bindings/Python/IRModule.cpp         |  4 +-
 mlir/lib/Bindings/Python/MainModule.cpp       | 11 +--
 mlir/python/CMakeLists.txt                    | 18 ----
 mlir/python/mlir/dialects/arith.py            | 10 ++-
 mlir/python/mlir/dialects/bufferization.py    | 25 +++---
 mlir/python/mlir/dialects/builtin.py          |  9 +-
 mlir/python/mlir/dialects/func.py             | 17 +++-
 .../dialects/linalg/opdsl/lang/emitter.py     |  2 +-
 mlir/python/mlir/dialects/memref.py           | 15 +++-
 mlir/python/mlir/dialects/ml_program.py       | 16 ++--
 mlir/python/mlir/dialects/pdl.py              | 48 +++++++---
 mlir/python/mlir/dialects/scf.py              | 63 +++++++++++--
 mlir/python/mlir/dialects/tensor.py           | 20 ++---
 .../mlir/dialects/transform/__init__.py       | 45 +++++-----
 .../mlir/dialects/transform/bufferization.py  | 24 +++--
 mlir/python/mlir/dialects/transform/gpu.py    | 26 +++---
 mlir/python/mlir/dialects/transform/loop.py   | 25 ++++--
 mlir/python/mlir/dialects/transform/memref.py | 24 +++--
 mlir/python/mlir/dialects/transform/pdl.py    | 90 +++++++++----------
 .../mlir/dialects/transform/structured.py     | 54 +++++++----
 mlir/python/mlir/dialects/transform/tensor.py | 17 ++--
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 49 ++++------
 23 files changed, 366 insertions(+), 248 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 97cd70089a2e965..dea44bbd469dd3d 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -80,7 +80,7 @@ class PyGlobals {
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
   void registerOperationImpl(const std::string &operationName,
-                             pybind11::object pyClass);
+                             pybind11::object pyClass, bool replace = false);
 
   /// Returns the custom Attribute builder for Attribute kind.
   std::optional<pybind11::function>
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index 2cc66277abee0f0..a1c8ab7a09ce155 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
 }
 
 void PyGlobals::registerOperationImpl(const std::string &operationName,
-                                      py::object pyClass) {
+                                      py::object pyClass, bool replace) {
   py::object &found = operationClassMap[operationName];
-  if (found) {
+  if (found && !replace) {
     throw std::runtime_error((llvm::Twine("Operation '") + operationName +
                               "' is already registered.")
                                  .str());
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index cdddfbe50606d05..a936becf67bea75 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) {
            "dialect_namespace"_a, "dialect_class"_a,
            "Testing hook for directly registering a dialect")
       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
-           "operation_name"_a, "operation_class"_a,
+           "operation_name"_a, "operation_class"_a, "replace"_a = false,
            "Testing hook for directly registering an operation");
 
   // Aside from making the globals accessible to python, having python manage
@@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) {
       "Class decorator for registering a custom Dialect wrapper");
   m.def(
       "register_operation",
-      [](const py::object &dialectClass) -> py::cpp_function {
+      [](const py::object &dialectClass, bool replace) -> py::cpp_function {
         return py::cpp_function(
-            [dialectClass](py::object opClass) -> py::object {
+            [dialectClass, replace](py::object opClass) -> py::object {
               std::string operationName =
                   opClass.attr("OPERATION_NAME").cast<std::string>();
-              PyGlobals::get().registerOperationImpl(operationName, opClass);
+              PyGlobals::get().registerOperationImpl(operationName, opClass,
+                                                     replace);
 
               // Dict-stuff the new opClass by name onto the dialect class.
               py::object opClassName = opClass.attr("__name__");
@@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) {
               return opClass;
             });
       },
-      "dialect_class"_a,
+      "dialect_class"_a, "replace"_a = false,
       "Produce a class decorator for registering an Operation class as part of "
       "a dialect");
   m.def(
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c7b3c283a6b6dc1..586d39fc332fac5 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -78,7 +78,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/BufferizationOps.td
   SOURCES
     dialects/bufferization.py
-    dialects/_bufferization_ops_ext.py
   DIALECT_NAME bufferization
   GEN_ENUM_BINDINGS_TD_FILE
     "../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
@@ -90,7 +89,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/BuiltinOps.td
   SOURCES
     dialects/builtin.py
-    dialects/_builtin_ops_ext.py
   DIALECT_NAME builtin)
 
 declare_mlir_dialect_python_bindings(
@@ -115,7 +113,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/FuncOps.td
   SOURCES
     dialects/func.py
-    dialects/_func_ops_ext.py
   DIALECT_NAME func)
 
 declare_mlir_dialect_python_bindings(
@@ -131,7 +128,6 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/LinalgOps.td
   SOURCES
-    dialects/_linalg_ops_ext.py
   SOURCES_GLOB
     dialects/linalg/*.py
   DIALECT_NAME linalg
@@ -152,7 +148,6 @@ ADD_TO_PARENT MLIRPythonSources.Dialects
 ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/TransformPDLExtensionOps.td
   SOURCES
-    dialects/_transform_pdl_extension_ops_ext.py
     dialects/transform/pdl.py
   DIALECT_NAME transform
   EXTENSION_NAME transform_pdl_extension)
@@ -162,7 +157,6 @@ declare_mlir_dialect_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/TransformOps.td
   SOURCES
-    dialects/_transform_ops_ext.py
     dialects/transform/__init__.py
     _mlir_libs/_mlir/dialects/transform/__init__.pyi
   DIALECT_NAME transform
@@ -175,7 +169,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/BufferizationTransformOps.td
   SOURCES
-    dialects/_bufferization_transform_ops_ext.py
     dialects/transform/bufferization.py
   DIALECT_NAME transform
   EXTENSION_NAME bufferization_transform)
@@ -185,7 +178,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/GPUTransformOps.td
   SOURCES
-    dialects/_gpu_transform_ops_ext.py
     dialects/transform/gpu.py
   DIALECT_NAME transform
   EXTENSION_NAME gpu_transform)
@@ -195,7 +187,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/SCFLoopTransformOps.td
   SOURCES
-    dialects/_loop_transform_ops_ext.py
     dialects/transform/loop.py
   DIALECT_NAME transform
   EXTENSION_NAME loop_transform)
@@ -205,7 +196,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/MemRefTransformOps.td
   SOURCES
-    dialects/_memref_transform_ops_ext.py
     dialects/transform/memref.py
   DIALECT_NAME transform
   EXTENSION_NAME memref_transform)
@@ -224,7 +214,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/LinalgStructuredTransformOps.td
   SOURCES
-    dialects/_structured_transform_ops_ext.py
     dialects/transform/structured.py
   DIALECT_NAME transform
   EXTENSION_NAME structured_transform
@@ -246,7 +235,6 @@ declare_mlir_dialect_extension_python_bindings(
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/TensorTransformOps.td
   SOURCES
-    dialects/_tensor_transform_ops_ext.py
     dialects/transform/tensor.py
   DIALECT_NAME transform
   EXTENSION_NAME tensor_transform)
@@ -276,7 +264,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/ArithOps.td
   SOURCES
     dialects/arith.py
-    dialects/_arith_ops_ext.py
   DIALECT_NAME arith
   GEN_ENUM_BINDINGS)
 
@@ -286,7 +273,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/MemRefOps.td
   SOURCES
     dialects/memref.py
-    dialects/_memref_ops_ext.py
   DIALECT_NAME memref)
 
 declare_mlir_dialect_python_bindings(
@@ -295,7 +281,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/MLProgramOps.td
   SOURCES
     dialects/ml_program.py
-    dialects/_ml_program_ops_ext.py
   DIALECT_NAME ml_program)
 
 declare_mlir_dialect_python_bindings(
@@ -339,7 +324,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/PDLOps.td
   SOURCES
     dialects/pdl.py
-    dialects/_pdl_ops_ext.py
     _mlir_libs/_mlir/dialects/pdl.pyi
   DIALECT_NAME pdl)
 
@@ -357,7 +341,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/SCFOps.td
   SOURCES
     dialects/scf.py
-    dialects/_scf_ops_ext.py
   DIALECT_NAME scf)
 
 declare_mlir_dialect_python_bindings(
@@ -383,7 +366,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/TensorOps.td
   SOURCES
     dialects/tensor.py
-    dialects/_tensor_ops_ext.py
   DIALECT_NAME tensor)
 
 declare_mlir_dialect_python_bindings(
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index df38f871710fe8f..e3b6a428c879de5 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -2,9 +2,13 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ._arith_ops_gen import *
+from ._arith_ops_gen import _Dialect
+from ._arith_enum_gen import *
+
 try:
     from ..ir import *
-    from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from ._ods_common import get_default_loc_context as _get_default_loc_context, _cext as _ods_cext
 
     from typing import Any, List, Union
 except ImportError as e:
@@ -30,8 +34,8 @@ def _is_integer_like_type(type: Type):
 def _is_float_type(type: Type):
     return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
 
-
-class ConstantOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
 
     def __init__(
diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py
index 1066cb4c775cab9..0ce5448ace4b14c 100644
--- a/mlir/python/mlir/dialects/bufferization.py
+++ b/mlir/python/mlir/dialects/bufferization.py
@@ -2,17 +2,22 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ._bufferization_ops_gen import *
+from ._bufferization_ops_gen import _Dialect
+from ._bufferization_enum_gen import *
+
 try:
     from typing import Sequence, Union
     from ..ir import *
-    from ._ods_common import get_default_loc_context
+    from ._ods_common import get_default_loc_context, _cext as _ods_cext
 
     from typing import Any, List, Union
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
 
-class AllocTensorOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class AllocTensorOp(AllocTensorOp):
     """Extends the bufferization.alloc_tensor op."""
 
     def __init__(
@@ -24,18 +29,14 @@ def __init__(
         escape: BoolAttr,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         """Constructs an `alloc_tensor` with static and/or dynamic sizes."""
-        context = get_default_loc_context(loc)
-        attributes = {}
-        if escape:
-            attributes["escape"] = escape
-        op = self.build_generic(
-            results=[tensor_type],
-            operands=[dynamic_sizes, copy, size_hint],
-            attributes=attributes,
+        super().__init__(
+            tensor_type,
+            dynamic_sizes,
+            copy=copy,
+            size_hint=size_hint,
             loc=loc,
             ip=ip,
         )
-        OpView.__init__(self, op)
diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py
index 27a60123050acb4..b71cc2466d464b3 100644
--- a/mlir/python/mlir/dialects/builtin.py
+++ b/mlir/python/mlir/dialects/builtin.py
@@ -2,17 +2,22 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ._builtin_ops_gen import *
+from ._builtin_ops_gen import _Dialect
+
 try:
     from ..ir import *
+    from ._ods_common import _cext as _ods_cext
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
 
-class ModuleOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ModuleOp(ModuleOp):
     """Specialization for the module op class."""
 
     def __init__(self, *, loc=None, ip=None):
-        super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip))
+        super().__init__(loc=loc, ip=ip)
         body = self.regions[0].blocks.append()
 
     @property
diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py
index 6d264c33f1f9dae..9c6c4c9092c7a88 100644
--- a/mlir/python/mlir/dialects/func.py
+++ b/mlir/python/mlir/dialects/func.py
@@ -2,9 +2,15 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ._func_ops_gen import *
+from ._func_ops_gen import _Dialect
+
 try:
     from ..ir import *
-    from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from ._ods_common import (
+        get_default_loc_context as _get_default_loc_context,
+        _cext as _ods_cext,
+    )
 
     import inspect
 
@@ -16,7 +22,8 @@
 RESULT_ATTRIBUTE_NAME = "res_attrs"
 
 
-class ConstantOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
 
     def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
@@ -27,7 +34,8 @@ def type(self):
         return self.results[0].type
 
 
-class FuncOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class FuncOp(FuncOp):
     """Specialization for the func op class."""
 
     def __init__(
@@ -238,7 +246,8 @@ def emit_call_op(*call_args):
         return decorator
 
 
-class CallOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class CallOp(CallOp):
     """Specialization for the call op class."""
 
     def __init__(
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 6f9d72164429eea..f91fc8b7160089b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -310,7 +310,7 @@ def emit_named_structured_op(
         )
 
     # Set the index attributes used to compute the indexing maps.
-    named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
+    named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
     for name, value in index_attrs.items():
         named_op.operation.attributes[name] = value
 
diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py
index 825f1a0a7a6faf4..111ad2178703d28 100644
--- a/mlir/python/mlir/dialects/memref.py
+++ b/mlir/python/mlir/dialects/memref.py
@@ -2,17 +2,24 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ._memref_ops_gen import *
+from ._memref_ops_gen import _Dialect
+
 try:
     from ..ir import *
-    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-    from ._ods_common import get_op_results_or_values as _get_op_results_or_values
+    from ._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        get_op_results_or_values as _get_op_results_or_values,
+        _cext as _ods_cext,
+    )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Sequence, Union
 
 
-class LoadOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoadOp(LoadOp):
     """Specialization for the MemRef load operation."""
 
     def __init__(
@@ -21,7 +28,7 @@ def __init__(
         indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         """Creates a memref load operation.
 
diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py
index c84d23c16ef93ab..dfb6d7f2c03b1cf 100644
--- a/mlir/python/mlir/dialects/ml_program.py
+++ b/mlir/python/mlir/dialects/ml_program.py
@@ -2,21 +2,27 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from typing import Union
+
+from ._ml_program_ops_gen import *
+from ._ml_program_ops_gen import _Dialect
+
 try:
-    from typing import Union
     from ..ir import *
-    from ._ods_common import get_default_loc_context as _get_default_loc_context
+    from ._ods_common import (
+        get_default_loc_context as _get_default_loc_context,
+        _cext as _ods_cext,
+    )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from ._ml_program_ops_gen import *
-
 
 ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
 RESULT_ATTRIBUTE_NAME = "res_attrs"
 
 
-class FuncOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class FuncOp(FuncOp):
     """Specialization for the func op class."""
 
     def __init__(
diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py
index fc9de0b7f7db69c..a8d9c56f4233d9e 100644
--- a/mlir/python/mlir/dialects/pdl.py
+++ b/mlir/python/mlir/dialects/pdl.py
@@ -2,6 +2,11 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ._pdl_ops_gen import *
+from ._pdl_ops_gen import _Dialect
+from .._mlir_libs._mlirDialectsPDL import *
+
+
 try:
     from ..ir import *
     from ..dialects import pdl
@@ -12,10 +17,12 @@
 from ._ods_common import (
     get_op_result_or_value as _get_value,
     get_op_results_or_values as _get_values,
+    _cext as _ods_cext,
 )
 
 
-class ApplyNativeConstraintOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ApplyNativeConstraintOp(ApplyNativeConstraintOp):
     """Specialization for PDL apply native constraint op class."""
 
     def __init__(
@@ -32,7 +39,8 @@ def __init__(
         super().__init__(name, args, loc=loc, ip=ip)
 
 
-class ApplyNativeRewriteOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ApplyNativeRewriteOp(ApplyNativeRewriteOp):
     """Specialization for PDL apply native rewrite op class."""
 
     def __init__(
@@ -50,7 +58,8 @@ def __init__(
         super().__init__(results, name, args, loc=loc, ip=ip)
 
 
-class AttributeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class AttributeOp(AttributeOp):
     """Specialization for PDL attribute op class."""
 
     def __init__(
@@ -66,7 +75,8 @@ def __init__(
         super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip)
 
 
-class EraseOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class EraseOp(EraseOp):
     """Specialization for PDL erase op class."""
 
     def __init__(
@@ -80,7 +90,8 @@ def __init__(
         super().__init__(operation, loc=loc, ip=ip)
 
 
-class OperandOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OperandOp(OperandOp):
     """Specialization for PDL operand op class."""
 
     def __init__(
@@ -95,7 +106,8 @@ def __init__(
         super().__init__(result, valueType=type, loc=loc, ip=ip)
 
 
-class OperandsOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OperandsOp(OperandsOp):
     """Specialization for PDL operands op class."""
 
     def __init__(
@@ -110,7 +122,8 @@ def __init__(
         super().__init__(result, valueType=types, loc=loc, ip=ip)
 
 
-class OperationOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
     """Specialization for PDL operand op class."""
 
     def __init__(
@@ -143,7 +156,8 @@ def __init__(
         )
 
 
-class PatternOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class PatternOp(PatternOp):
     """Specialization for PDL pattern op class."""
 
     def __init__(
@@ -164,7 +178,8 @@ def body(self):
         return self.regions[0].blocks[0]
 
 
-class ReplaceOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ReplaceOp(ReplaceOp):
     """Specialization for PDL replace op class."""
 
     def __init__(
@@ -184,7 +199,8 @@ def __init__(
         super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip)
 
 
-class ResultOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ResultOp(ResultOp):
     """Specialization for PDL result op class."""
 
     def __init__(
@@ -200,7 +216,8 @@ def __init__(
         super().__init__(result, parent, index, loc=loc, ip=ip)
 
 
-class ResultsOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ResultsOp(ResultsOp):
     """Specialization for PDL results op class."""
 
     def __init__(
@@ -216,7 +233,8 @@ def __init__(
         super().__init__(result, parent, index=index, loc=loc, ip=ip)
 
 
-class RewriteOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class RewriteOp(RewriteOp):
     """Specialization for PDL rewrite op class."""
 
     def __init__(
@@ -245,7 +263,8 @@ def body(self):
         return self.regions[0].blocks[0]
 
 
-class TypeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class TypeOp(TypeOp):
     """Specialization for PDL type op class."""
 
     def __init__(
@@ -255,7 +274,8 @@ def __init__(
         super().__init__(result, constantType=constantType, loc=loc, ip=ip)
 
 
-class TypesOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class TypesOp(TypesOp):
     """Specialization for PDL types op class."""
 
     def __init__(
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 89cc8a19895c7b4..43ad9f4e2d65f51 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -2,20 +2,29 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+
+from ._scf_ops_gen import *
+from ._scf_ops_gen import _Dialect
+from .arith import constant
+
 try:
     from ..ir import *
+    from ._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        get_op_results_or_values as _get_op_results_or_values,
+        _cext as _ods_cext,
+    )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Sequence, Union
 
-from ._ods_common import (
-    get_op_result_or_value as _get_op_result_or_value,
-    get_op_results_or_values as _get_op_results_or_values,
-)
 
+_ForOp = ForOp
 
-class ForOp:
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ForOp(_ForOp):
     """Specialization for the SCF for op class."""
 
     def __init__(
@@ -41,7 +50,7 @@ def __init__(
         iter_args = _get_op_results_or_values(iter_args)
 
         results = [arg.type for arg in iter_args]
-        super().__init__(
+        super(_ForOp, self).__init__(
             self.build_generic(
                 regions=1,
                 results=results,
@@ -74,7 +83,11 @@ def inner_iter_args(self):
         return self.body.arguments[1:]
 
 
-class IfOp:
+_IfOp = IfOp
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class IfOp(_IfOp):
     """Specialization for the SCF if op class."""
 
     def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
@@ -87,7 +100,7 @@ def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None):
         operands.append(cond)
         results = []
         results.extend(results_)
-        super().__init__(
+        super(_IfOp, self).__init__(
             self.build_generic(
                 regions=2, results=results, operands=operands, loc=loc, ip=ip
             )
@@ -105,3 +118,66 @@ def then_block(self):
     def else_block(self):
         """Returns the else block of the if operation."""
         return self.regions[1].blocks[0]
+
+
+def for_(
+    start,
+    stop=None,
+    step=None,
+    iter_args: Optional[Sequence[Value]] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    """Builds an `scf.for` with range-like arguments and yields its body values.
+
+    `for_(n)` uses 0 as the lower bound and `n` as the upper bound;
+    `for_(a, b)` and `for_(a, b, s)` pass explicit bounds (and step).
+    Python `int` bounds are materialized with arith `constant`; `float`
+    bounds are rejected.
+
+    Args:
+      start: lower bound, or the upper bound when `stop` is None.
+      stop: upper bound; when None, `start` becomes the upper bound and
+        the lower bound is 0.
+      step: loop step; defaults to 1 when None.
+      iter_args: optional loop-carried values forwarded to `ForOp`.
+      loc, ip: location and insertion point forwarded to `ForOp`.
+
+    Yields (exactly once, with the loop body as the insertion point):
+      `iv` when there are no iter args, `(iv, iter_arg)` when there is
+      exactly one, and `(iv, iter_args_tuple)` when there are several.
+
+    Raises:
+      ValueError: if start, stop or step is a Python float.
+    """
+    if step is None:
+        step = 1
+    if stop is None:
+        stop = start
+        start = 0
+    params = [start, stop, step]
+    for i, p in enumerate(params):
+        if isinstance(p, int):
+            # NOTE(review): confirm the arith `constant` builder accepts a bare
+            # Python int here; it may also require a result type (e.g. index).
+            p = constant(p)
+        elif isinstance(p, float):
+            raise ValueError(f"{p=} must be int.")
+        params[i] = p
+    # Rebind to the converted values so ForOp receives the materialized
+    # constants instead of the raw Python ints.
+    start, stop, step = params
+
+    for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
+    iv = for_op.induction_variable
+    iter_args = tuple(for_op.inner_iter_args)
+    # The yield happens inside the `with`, so the loop body stays the active
+    # insertion point while the caller's loop body runs.
+    with InsertionPoint(for_op.body):
+        if len(iter_args) > 1:
+            yield iv, iter_args
+        elif len(iter_args) == 1:
+            yield iv, iter_args[0]
+        else:
+            yield iv
diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py
index 09b9ec68db7d9c7..67248748eaf3ada 100644
--- a/mlir/python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/tensor.py
@@ -2,19 +2,20 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from ._tensor_ops_gen import *
+from ._tensor_ops_gen import _Dialect
+
 try:
     from ..ir import *
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Any, Optional, Sequence, Union
-from ._ods_common import (
-    get_op_result_or_value as _get_op_result_or_value,
-    get_op_results_or_values as _get_op_results_or_values,
-)
+from typing import Sequence, Union
+from ._ods_common import _cext as _ods_cext
 
 
-class EmptyOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class EmptyOp(EmptyOp):
     """Extends the tensor.empty op."""
 
     def __init__(
@@ -23,7 +24,7 @@ def __init__(
         element_type: Type,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         """Constructs an `empty` with mixed static/dynamic sizes."""
         # TODO: Refactor the EmptyOp to take an element type attribute and
@@ -38,7 +39,4 @@ def __init__(
                 static_sizes.append(ShapedType.get_dynamic_size())
                 dynamic_sizes.append(s)
         result_type = RankedTensorType.get(static_sizes, element_type)
-        op = self.build_generic(
-            results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip
-        )
-        OpView.__init__(self, op)
+        super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b1e7b892536f4a1..f7a2026e800aeb0 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -2,11 +2,17 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from .._transform_enum_gen import *
+from .._transform_ops_gen import *
+from .._transform_ops_gen import _Dialect
+from ..._mlir_libs._mlirDialectsTransform import *
+
 try:
-    from ..ir import *
-    from ._ods_common import (
+    from ...ir import *
+    from .._ods_common import (
         get_op_result_or_value as _get_op_result_or_value,
         get_op_results_or_values as _get_op_results_or_values,
+        _cext as _ods_cext,
     )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
@@ -14,7 +20,8 @@
 from typing import Optional, Sequence, Union
 
 
-class CastOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class CastOp(CastOp):
     def __init__(
         self,
         result_type: Type,
@@ -26,7 +33,8 @@ def __init__(
         super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
-class ApplyPatternsOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ApplyPatternsOp(ApplyPatternsOp):
     def __init__(
         self,
         target: Union[Operation, Value, OpView],
@@ -34,19 +42,7 @@ def __init__(
         loc=None,
         ip=None,
     ):
-        operands = []
-        operands.append(_get_op_result_or_value(target))
-        super().__init__(
-            self.build_generic(
-                attributes={},
-                results=[],
-                operands=operands,
-                successors=None,
-                regions=None,
-                loc=loc,
-                ip=ip,
-            )
-        )
+        super().__init__(target, loc=loc, ip=ip)
         self.regions[0].blocks.append()
 
     @property
@@ -54,7 +50,8 @@ def patterns(self) -> Block:
         return self.regions[0].blocks[0]
 
 
-class testGetParentOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class GetParentOp(GetParentOp):
     def __init__(
         self,
         result_type: Type,
@@ -77,7 +74,8 @@ def __init__(
         )
 
 
-class MergeHandlesOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MergeHandlesOp(MergeHandlesOp):
     def __init__(
         self,
         handles: Sequence[Union[Operation, Value]],
@@ -94,7 +92,8 @@ def __init__(
         )
 
 
-class ReplicateOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ReplicateOp(ReplicateOp):
     def __init__(
         self,
         pattern: Union[Operation, Value],
@@ -112,7 +111,8 @@ def __init__(
         )
 
 
-class SequenceOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class SequenceOp(SequenceOp):
     def __init__(
         self,
         failure_propagation_mode,
@@ -163,7 +163,8 @@ def bodyExtraArgs(self) -> BlockArgumentList:
         return self.body.arguments[1:]
 
 
-class YieldOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class YieldOp(YieldOp):
     def __init__(
         self,
         operands: Optional[Union[Operation, Sequence[Value]]] = None,
diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py
index 7e6c1b81cb350b7..485a8a36b6305e9 100644
--- a/mlir/python/mlir/dialects/transform/bufferization.py
+++ b/mlir/python/mlir/dialects/transform/bufferization.py
@@ -2,9 +2,13 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from .._bufferization_transform_ops_gen import *
+from .._bufferization_transform_ops_gen import _Dialect
+
 try:
-    from ..ir import *
-    from ..dialects import transform
+    from ...ir import *
+    from ...dialects import transform
+    from .._ods_common import _cext as _ods_cext
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
@@ -12,7 +16,8 @@
 from typing import Optional, overload, Union
 
 
-class EmptyTensorToAllocTensorOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp):
     """Specialization for EmptyTensorToAllocTensorOp class."""
 
     @overload
@@ -22,7 +27,7 @@ def __init__(
         target: Union[Operation, OpView, Value],
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -36,7 +41,7 @@ def __init__(
         target_or_none: Optional[Union[Operation, OpView, Value]] = None,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         if isinstance(transformed_type_or_target, Type):
             transformed_type = transformed_type_or_target
@@ -53,7 +58,8 @@ def __init__(
         )
 
 
-class OneShotBufferizeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class OneShotBufferizeOp(OneShotBufferizeOp):
     """Specialization for OneShotBufferizeOp class."""
 
     @overload
@@ -70,7 +76,7 @@ def __init__(
         print_conflicts: Optional[bool] = None,
         test_analysis_only: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -87,7 +93,7 @@ def __init__(
         print_conflicts: Optional[bool] = None,
         test_analysis_only: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -104,7 +110,7 @@ def __init__(
         print_conflicts: Optional[bool] = None,
         test_analysis_only: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         if isinstance(transformed_type_or_target, Type):
             transformed_type = transformed_type_or_target
diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py
index ba72bac3a15264d..00cf0840eeae9e1 100644
--- a/mlir/python/mlir/dialects/transform/gpu.py
+++ b/mlir/python/mlir/dialects/transform/gpu.py
@@ -2,16 +2,21 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from .._gpu_transform_ops_gen import *
+from .._gpu_transform_ops_gen import _Dialect
+
 try:
-    from ..ir import *
-    from ..dialects import transform
+    from ...ir import *
+    from ...dialects import transform
+    from .._ods_common import _cext as _ods_cext
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Sequence, Union, overload
 
 
-class MapForallToBlocks:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MapForallToBlocks(MapForallToBlocks):
     """Specialization for MapForallToBlocks class."""
 
     @overload
@@ -23,7 +28,7 @@ def __init__(
         grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
         generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -35,7 +40,7 @@ def __init__(
         grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
         generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -47,7 +52,7 @@ def __init__(
         grid_dims: Optional[Union[Sequence[int], Attribute]] = None,
         generate_gpu_launch: Optional[Union[bool, Attribute]] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         if isinstance(result_type_or_target, Type):
             result_type = result_type_or_target
@@ -66,7 +71,8 @@ def __init__(
         )
 
 
-class MapNestedForallToThreads:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MapNestedForallToThreads(MapNestedForallToThreads):
     """Specialization for MapNestedForallToThreads class."""
 
     @overload
@@ -79,7 +85,7 @@ def __init__(
         warp_size: Optional[Sequence[int]] = None,
         sync_after_distribute: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -92,7 +98,7 @@ def __init__(
         warp_size: Optional[Sequence[int]] = None,
         sync_after_distribute: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -105,7 +111,7 @@ def __init__(
         warp_size: Optional[Union[Sequence[int], Attribute]] = None,
         sync_after_distribute: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         if isinstance(result_type_or_target, Type):
             result_type = result_type_or_target
diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py
index 1cdb2b9e77b5afe..6c89025f413839e 100644
--- a/mlir/python/mlir/dialects/transform/loop.py
+++ b/mlir/python/mlir/dialects/transform/loop.py
@@ -2,16 +2,23 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from .._loop_transform_ops_gen import *
+from .._loop_transform_ops_gen import _Dialect
+
 try:
-    from ..ir import *
-    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
+    from ...ir import *
+    from .._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        _cext as _ods_cext,
+    )
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, Union
 
 
-class GetParentForOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class GetParentForOp(GetParentForOp):
     """Extension for GetParentForOp."""
 
     def __init__(
@@ -34,7 +41,8 @@ def __init__(
         )
 
 
-class LoopOutlineOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoopOutlineOp(LoopOutlineOp):
     """Extension for LoopOutlineOp."""
 
     def __init__(
@@ -61,7 +69,8 @@ def __init__(
         )
 
 
-class LoopPeelOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoopPeelOp(LoopPeelOp):
     """Extension for LoopPeelOp."""
 
     def __init__(
@@ -88,7 +97,8 @@ def __init__(
         )
 
 
-class LoopPipelineOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoopPipelineOp(LoopPipelineOp):
     """Extension for LoopPipelineOp."""
 
     def __init__(
@@ -115,7 +125,8 @@ def __init__(
         )
 
 
-class LoopUnrollOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class LoopUnrollOp(LoopUnrollOp):
     """Extension for LoopUnrollOp."""
 
     def __init__(
diff --git a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py
index 1cc00bdcbf381c9..56ea61eb817f89c 100644
--- a/mlir/python/mlir/dialects/transform/memref.py
+++ b/mlir/python/mlir/dialects/transform/memref.py
@@ -2,16 +2,21 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from .._memref_transform_ops_gen import *
+from .._memref_transform_ops_gen import _Dialect
+
 try:
-    from ..ir import *
-    from ..dialects import transform
+    from ...ir import *
+    from ...dialects import transform
+    from .._ods_common import _cext as _ods_cext
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, overload, Union
 
 
-class MemRefAllocaToGlobalOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
     """Specialization for MemRefAllocaToGlobalOp class."""
 
     @overload
@@ -22,7 +27,7 @@ def __init__(
         alloca: Union[Operation, OpView, Value],
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -37,7 +42,7 @@ def __init__(
         alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         if isinstance(get_global_type_or_alloca, Type):
             get_global_type = get_global_type_or_alloca
@@ -57,7 +62,8 @@ def __init__(
         )
 
 
-class MemRefMultiBufferOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MemRefMultiBufferOp(MemRefMultiBufferOp):
     """Specialization for MemRefMultiBufferOp class."""
 
     @overload
@@ -69,7 +75,7 @@ def __init__(
         *,
         skip_analysis: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -81,7 +87,7 @@ def __init__(
         *,
         skip_analysis: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -93,7 +99,7 @@ def __init__(
         *,
         skip_analysis: Optional[bool] = None,
         loc=None,
-        ip=None
+        ip=None,
     ):
         if isinstance(transformed_type_or_target, Type):
             transformed_type = transformed_type_or_target
diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py
index c4e4b4b4254b038..bb5fa7ffd306583 100644
--- a/mlir/python/mlir/dialects/transform/pdl.py
+++ b/mlir/python/mlir/dialects/transform/pdl.py
@@ -2,54 +2,54 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from .._transform_pdl_extension_ops_gen import *
+from .._transform_pdl_extension_ops_gen import _Dialect
+
 try:
-  from ..ir import *
-  from ._ods_common import (
-      get_op_result_or_value as _get_op_result_or_value,
-      get_op_results_or_values as _get_op_results_or_values,
-  )
+    from ...ir import *
+    from .._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        get_op_results_or_values as _get_op_results_or_values,
+        _cext as _ods_cext,
+    )
 except ImportError as e:
-  raise RuntimeError("Error loading imports from extension module") from e
+    raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Union
 
-class PDLMatchOp:
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      pattern_name: Union[Attribute, str],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        pattern_name,
-        loc=loc,
-        ip=ip,
-    )
-
-
-class WithPDLPatternsOp:
-
-  def __init__(self,
-               target: Union[Operation, Value, Type],
-               *,
-               loc=None,
-               ip=None):
-    root = _get_op_result_or_value(target) if not isinstance(target,
-                                                             Type) else None
-    root_type = target if isinstance(target, Type) else root.type
-    super().__init__(root=root, loc=loc, ip=ip)
-    self.regions[0].blocks.append(root_type)
-
-  @property
-  def body(self) -> Block:
-    return self.regions[0].blocks[0]
 
-  @property
-  def bodyTarget(self) -> Value:
-    return self.body.arguments[0]
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class PDLMatchOp(PDLMatchOp):
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        pattern_name: Union[Attribute, str],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            pattern_name,
+            loc=loc,
+            ip=ip,
+        )
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class WithPDLPatternsOp(WithPDLPatternsOp):
+    def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
+        root = _get_op_result_or_value(target) if not isinstance(target, Type) else None
+        root_type = target if isinstance(target, Type) else root.type
+        super().__init__(root=root, loc=loc, ip=ip)
+        self.regions[0].blocks.append(root_type)
+
+    @property
+    def body(self) -> Block:
+        return self.regions[0].blocks[0]
+
+    @property
+    def bodyTarget(self) -> Value:
+        return self.body.arguments[0]
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 3757a3d3b4cce85..284c93823acbd34 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -2,9 +2,14 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from .._structured_transform_ops_gen import *
+from .._structured_transform_ops_gen import _Dialect
+from .._structured_transform_enum_gen import *
+
 try:
-    from ..ir import *
-    from ..dialects import transform
+    from ...ir import *
+    from ...dialects import transform
+    from .._ods_common import _cext as _ods_cext
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
@@ -163,7 +168,8 @@ def _get_int_array_array_attr(
     return ArrayAttr.get(values)
 
 
-class BufferizeToAllocationOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class BufferizeToAllocationOp(BufferizeToAllocationOp):
     """Specialization for BufferizeToAllocationOp class."""
 
     def __init__(
@@ -199,7 +205,8 @@ def __init__(
         )
 
 
-class DecomposeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class DecomposeOp(DecomposeOp):
     """Specialization for DecomposeOp class."""
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
@@ -207,7 +214,8 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
         super().__init__(transformed_type, target, loc=loc, ip=ip)
 
 
-class FuseIntoContainingOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class FuseIntoContainingOp(FuseIntoContainingOp):
     """Specialization for FuseIntoContainingOp class."""
 
     @overload
@@ -271,7 +279,8 @@ def __init__(
         )
 
 
-class GeneralizeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class GeneralizeOp(GeneralizeOp):
     """Specialization for GeneralizeOp class."""
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
@@ -279,7 +288,8 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
         super().__init__(transformed_type, target, loc=loc, ip=ip)
 
 
-class InterchangeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class InterchangeOp(InterchangeOp):
     """Specialization for InterchangeOp class."""
 
     def __init__(
@@ -300,7 +310,8 @@ def __init__(
         )
 
 
-class MapCopyToThreadsOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MapCopyToThreadsOp(MapCopyToThreadsOp):
     """Specialization for MapCopyToThreadsOp class."""
 
     @overload
@@ -360,7 +371,8 @@ def __init__(
         )
 
 
-class VectorizeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class VectorizeOp(VectorizeOp):
     """Specialization for VectorizeOp class."""
 
     def __init__(
@@ -405,7 +417,8 @@ def __init__(
         )
 
 
-class MatchOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MatchOp(MatchOp):
     """Specialization for MatchOp class."""
 
     @overload
@@ -464,7 +477,8 @@ def match_op_names(
         )
 
 
-class MultiTileSizesOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MultiTileSizesOp(MultiTileSizesOp):
     """Specialization for MultiTileSizesOp class."""
 
     def __init__(
@@ -491,7 +505,8 @@ def __init__(
         )
 
 
-class PadOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class PadOp(PadOp):
     """Specialization for PadOp class."""
 
     def __init__(
@@ -528,7 +543,8 @@ def __init__(
         )
 
 
-class ScalarizeOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class ScalarizeOp(ScalarizeOp):
     """Specialization for ScalarizeOp class."""
 
     def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
@@ -536,7 +552,8 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
         super().__init__(result_type, target, loc=loc, ip=ip)
 
 
-class SplitOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class SplitOp(SplitOp):
     """Specialization for SplitOp class."""
 
     def __init__(
@@ -567,7 +584,8 @@ def __init__(
         )
 
 
-class TileUsingForOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class TileUsingForOp(TileUsingForOp):
     """Specialization for TileUsingForOp class."""
 
     @overload
@@ -640,7 +658,8 @@ def __init__(
         )
 
 
-class TileUsingForallOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class TileUsingForallOp(TileUsingForallOp):
     """Specialization for TileUsingForallOp class."""
 
     @overload
@@ -732,7 +751,8 @@ def __init__(
         )
 
 
-class VectorizeChildrenAndApplyPatternsOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
     """Specialization for VectorizeChildrenAndApplyPatternsOp class."""
 
     def __init__(
diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py
index 996093fbc913e8a..4eb30398f087212 100644
--- a/mlir/python/mlir/dialects/transform/tensor.py
+++ b/mlir/python/mlir/dialects/transform/tensor.py
@@ -2,16 +2,21 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+from .._tensor_transform_ops_gen import *
+from .._tensor_transform_ops_gen import _Dialect
+
 try:
-    from ..ir import *
-    from ..dialects import transform
+    from ...ir import *
+    from ...dialects import transform
+    from .._ods_common import _cext as _ods_cext
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
 from typing import Optional, overload, Union
 
 
-class MakeLoopIndependentOp:
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class MakeLoopIndependentOp(MakeLoopIndependentOp):
     """Specialization for MakeLoopIndependentOp class."""
 
     @overload
@@ -22,7 +27,7 @@ def __init__(
         num_loops: Union[int, IntegerAttr],
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -33,7 +38,7 @@ def __init__(
         num_loops: Union[int, IntegerAttr],
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         ...
 
@@ -44,7 +49,7 @@ def __init__(
         num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
         *,
         loc=None,
-        ip=None
+        ip=None,
     ):
         if isinstance(transformed_type_or_target, Type):
             transformed_type = transformed_type_or_target
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 49f3a951426d0ee..d8dcf936e768754 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,14 +30,9 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
 _ods_ir = _ods_cext.ir
 
-try:
-  from . import _{0}_ops_ext as _ods_ext_module
-except ImportError:
-  _ods_ext_module = None
-
 import builtins
 from typing import Sequence as _Sequence, Union as _Union
 
@@ -62,7 +57,6 @@ from ._{0}_ops_gen import _Dialect
 ///   {1} is the operation name.
 constexpr const char *opClassTemplate = R"Py(
 @_ods_cext.register_operation(_Dialect)
- at _ods_extend_opview_class(_ods_ext_module)
 class {0}(_ods_ir.OpView):
   OPERATION_NAME = "{1}"
 )Py";
@@ -302,12 +296,8 @@ static bool isODSReserved(StringRef str) {
 /// modified version.
 static std::string sanitizeName(StringRef name) {
   std::string processed_str = name.str();
-  std::replace_if(
-      processed_str.begin(), processed_str.end(),
-      [](char c) { return !llvm::isAlnum(c); }, '_');
 
-  if (llvm::isDigit(*processed_str.begin()))
-    return "_" + processed_str;
+  std::replace(processed_str.begin(), processed_str.end(), '-', '_');
 
   if (isPythonReserved(processed_str) || isODSReserved(processed_str))
     return processed_str + "_";
@@ -854,9 +844,6 @@ populateBuilderRegions(const Operator &op,
 static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
                                                            raw_ostream &os) {
   // If we are asked to skip default builders, comply.
-  if (op.skipDefaultBuilders())
-    return {};
-
   llvm::SmallVector<std::string> builderArgs;
   llvm::SmallVector<std::string> builderLines;
   llvm::SmallVector<std::string> operandArgNames;
@@ -989,9 +976,8 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
 static void emitValueBuilder(const Operator &op,
                              llvm::SmallVector<std::string> functionArgs,
                              raw_ostream &os) {
-  // If we are asked to skip default builders, comply.
-  if (op.skipDefaultBuilders())
-    return;
+  auto name = sanitizeName(op.getOperationName());
+  iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
   // Params with (possibly) default args.
   auto valueBuilderParams =
       llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -1010,16 +996,16 @@ static void emitValueBuilder(const Operator &op,
         auto lhs = *llvm::split(arg, "=").begin();
         return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
       });
-  std::string name_without_dialect =
-      op.getOperationName().substr(op.getOperationName().find('.') + 1);
-  os << llvm::formatv(valueBuilderTemplate, sanitizeName(name_without_dialect),
-                      op.getCppClassName(),
-                      llvm::join(valueBuilderParams, ", "),
-                      llvm::join(opBuilderArgs, ", "),
-                      (op.getNumResults() > 1
-                           ? "_Sequence[_ods_ir.OpResult]"
-                           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
-                                                     : "_ods_ir.Operation")));
+  os << llvm::formatv(
+      valueBuilderTemplate,
+      // Drop dialect name and then sanitize again (to catch e.g. func.return).
+      sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
+      op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
+      llvm::join(opBuilderArgs, ", "),
+      (op.getNumResults() > 1
+           ? "_Sequence[_ods_ir.OpResult]"
+           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+                                     : "_ods_ir.Operation")));
 }
 
 /// Emits bindings for a specific Op to the given output stream.
@@ -1051,11 +1037,8 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
   if (clDialectName.empty())
     llvm::PrintFatalError("dialect name not provided");
 
-  bool isExtension = !clDialectExtensionName.empty();
-  os << llvm::formatv(fileHeader, isExtension
-                                      ? clDialectExtensionName.getValue()
-                                      : clDialectName.getValue());
-  if (isExtension)
+  os << fileHeader;
+  if (!clDialectExtensionName.empty())
     os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
   else
     os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());

>From 15e92d0cad0bfd511618f68cd497a4fef2b27f93 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 09:42:48 -0500
Subject: [PATCH 4/8] rebase

---
 mlir/python/CMakeLists.txt                    |  1 -
 mlir/python/mlir/dialects/_affine_ops_ext.py  | 56 -------------------
 mlir/python/mlir/dialects/affine.py           | 51 ++++++++++++++++-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 35 ++++++------
 4 files changed, 67 insertions(+), 76 deletions(-)
 delete mode 100644 mlir/python/mlir/dialects/_affine_ops_ext.py

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 586d39fc332fac5..88e6e13602d291a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings(
   TD_FILE dialects/AffineOps.td
   SOURCES
     dialects/affine.py
-    dialects/_affine_ops_ext.py
   DIALECT_NAME affine
   GEN_ENUM_BINDINGS)
 
diff --git a/mlir/python/mlir/dialects/_affine_ops_ext.py b/mlir/python/mlir/dialects/_affine_ops_ext.py
deleted file mode 100644
index dc465ce7aa1e5f9..000000000000000
--- a/mlir/python/mlir/dialects/_affine_ops_ext.py
+++ /dev/null
@@ -1,56 +0,0 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-#  See https://llvm.org/LICENSE.txt for license information.
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-try:
-    from ..ir import *
-    from ._ods_common import get_op_result_or_value as _get_op_result_or_value
-    from ._ods_common import get_op_results_or_values as _get_op_results_or_values
-except ImportError as e:
-    raise RuntimeError("Error loading imports from extension module") from e
-
-from typing import Optional, Sequence, Union
-
-
-class AffineStoreOp:
-    """Specialization for the Affine store operation."""
-
-    def __init__(
-        self,
-        value: Union[Operation, OpView, Value],
-        memref: Union[Operation, OpView, Value],
-        map: AffineMap=None,
-        *,
-        map_operands=None,
-        loc=None,
-        ip=None
-    ):
-        """Creates an affine store operation.
-
-        - `value`: the value to store into the memref.
-        - `memref`: the buffer to store into.
-        - `map`: the affine map that maps the map_operands to the index of the 
-          memref.
-        - `map_operands`: the list of arguments to substitute the dimensions, 
-          then symbols in the affine map, in increasing order.
-        """
-        map = map if map is not None else []
-        map_operands = map_operands if map_operands is not None else []
-        operands = [
-            _get_op_result_or_value(value),
-            _get_op_result_or_value(memref),
-            *[_get_op_result_or_value(op) for op in map_operands]
-        ]
-        results = []
-        attributes = {"map": AffineMapAttr.get(map)}
-        regions = None
-        _ods_successors = None
-        super().__init__(self.build_generic(
-            attributes=attributes,
-            results=results,
-            operands=operands,
-            successors=_ods_successors,
-            regions=regions,
-            loc=loc,
-            ip=ip
-        ))
diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py
index 8a2a64c7c40d190..1eaccfa73a85cbf 100644
--- a/mlir/python/mlir/dialects/affine.py
+++ b/mlir/python/mlir/dialects/affine.py
@@ -1,5 +1,50 @@
-#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.                                        
-#  See https://llvm.org/LICENSE.txt for license information.                                                            
-#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception    
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._affine_ops_gen import *
+from ._affine_ops_gen import _Dialect
+
+try:
+    from ..ir import *
+    from ._ods_common import (
+        get_op_result_or_value as _get_op_result_or_value,
+        get_op_results_or_values as _get_op_results_or_values,
+        _cext as _ods_cext,
+    )
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Optional, Sequence, Union
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class AffineStoreOp(AffineStoreOp):
+    """Specialization for the Affine store operation."""
+
+    def __init__(
+        self,
+        value: Union[Operation, OpView, Value],
+        memref: Union[Operation, OpView, Value],
+        map: AffineMap = None,
+        *,
+        map_operands=None,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an affine store operation.
+
+        - `value`: the value to store into the memref.
+        - `memref`: the buffer to store into.
+        - `map`: the affine map that maps the map_operands to the index of the
+          memref.
+        - `map_operands`: the list of arguments to substitute the dimensions,
+          then symbols in the affine map, in increasing order.
+        """
+        map = map if map is not None else []
+        map_operands = map_operands if map_operands is not None else []
+        indicies = [_get_op_result_or_value(op) for op in map_operands]
+        _ods_successors = None
+        super().__init__(
+            value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip
+        )
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index d8dcf936e768754..875678a2333789a 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -295,13 +295,17 @@ static bool isODSReserved(StringRef str) {
 /// (does not change the `name` if it already is suitable) and returns the
 /// modified version.
 static std::string sanitizeName(StringRef name) {
-  std::string processed_str = name.str();
+  std::string processedStr = name.str();
+  std::replace_if(
+      processedStr.begin(), processedStr.end(),
+      [](char c) { return !llvm::isAlnum(c); }, '_');
 
-  std::replace(processed_str.begin(), processed_str.end(), '-', '_');
+  if (llvm::isDigit(*processedStr.begin()))
+    return "_" + processedStr;
 
-  if (isPythonReserved(processed_str) || isODSReserved(processed_str))
-    return processed_str + "_";
-  return processed_str;
+  if (isPythonReserved(processedStr) || isODSReserved(processedStr))
+    return processedStr + "_";
+  return processedStr;
 }
 
 static std::string attrSizedTraitForKind(const char *kind) {
@@ -977,7 +981,6 @@ static void emitValueBuilder(const Operator &op,
                              llvm::SmallVector<std::string> functionArgs,
                              raw_ostream &os) {
   auto name = sanitizeName(op.getOperationName());
-  iterator_range<llvm::SplittingIterator> splitName = llvm::split(name, ".");
   // Params with (possibly) default args.
   auto valueBuilderParams =
       llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
@@ -996,16 +999,16 @@ static void emitValueBuilder(const Operator &op,
         auto lhs = *llvm::split(arg, "=").begin();
         return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
       });
-  os << llvm::formatv(
-      valueBuilderTemplate,
-      // Drop dialect name and then sanitize again (to catch e.g. func.return).
-      sanitizeName(llvm::join(++splitName.begin(), splitName.end(), "_")),
-      op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
-      llvm::join(opBuilderArgs, ", "),
-      (op.getNumResults() > 1
-           ? "_Sequence[_ods_ir.OpResult]"
-           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
-                                     : "_ods_ir.Operation")));
+  std::string nameWithoutDialect =
+      op.getOperationName().substr(op.getOperationName().find('.') + 1);
+  os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
+                      op.getCppClassName(),
+                      llvm::join(valueBuilderParams, ", "),
+                      llvm::join(opBuilderArgs, ", "),
+                      (op.getNumResults() > 1
+                           ? "_Sequence[_ods_ir.OpResult]"
+                           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+                                                     : "_ods_ir.Operation")));
 }
 
 /// Emits bindings for a specific Op to the given output stream.

>From 0545186ba15f3103e4a558c0df806610c4c32cdf Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 12 Oct 2023 01:57:00 -0500
Subject: [PATCH 5/8] format

---
 mlir/lib/Bindings/Python/Globals.h            |   2 +-
 mlir/python/mlir/dialects/arith.py            |   6 +-
 .../linalg/opdsl/ops/core_named_ops.py        | 107 ++++++++++--------
 mlir/python/mlir/dialects/python_test.py      |   7 +-
 mlir/python/mlir/runtime/np_to_memref.py      |   8 +-
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp |   2 -
 6 files changed, 75 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index dea44bbd469dd3d..21899bdce22e810 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -77,7 +77,7 @@ class PyGlobals {
                            pybind11::object pyClass);
 
   /// Adds a concrete implementation operation class.
-  /// Raises an exception if the mapping already exists.
+  /// Raises an exception if the mapping already exists and replace == false.
   /// This is intended to be called by implementation code.
   void registerOperationImpl(const std::string &operationName,
                              pybind11::object pyClass, bool replace = false);
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index e3b6a428c879de5..83aca0d58bf2cef 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -8,7 +8,10 @@
 
 try:
     from ..ir import *
-    from ._ods_common import get_default_loc_context as _get_default_loc_context, _cext as _ods_cext
+    from ._ods_common import (
+        get_default_loc_context as _get_default_loc_context,
+        _cext as _ods_cext,
+    )
 
     from typing import Any, List, Union
 except ImportError as e:
@@ -34,6 +37,7 @@ def _is_integer_like_type(type: Type):
 def _is_float_type(type: Type):
     return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
 
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ConstantOp(ConstantOp):
     """Specialization for the constant op class."""
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index a8f8f8e0fbd68b4..19734a80a107bfe 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -296,35 +296,39 @@ def quantized_matmul(
 
 
 @linalg_structured_op
-def matmul_transpose_a(A=TensorDef(T1, S.K, S.N),
-                       B=TensorDef(T2, S.K, S.M),
-                       C=TensorDef(U, S.M, S.N, output=True),
-                       cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Performs a matrix multiplication of two 2D inputs with lhs operand
-  transposed.
+def matmul_transpose_a(
+    A=TensorDef(T1, S.K, S.N),
+    B=TensorDef(T2, S.K, S.M),
+    C=TensorDef(U, S.M, S.N, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Performs a matrix multiplication of two 2D inputs with lhs operand
+    transposed.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n])
 
 
 @linalg_structured_op
-def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
-                       B=TensorDef(T2, S.N, S.K),
-                       C=TensorDef(U, S.M, S.N, output=True),
-                       cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
-  """Performs a matrix multiplication of two 2D inputs with rhs operand
-  transposed.
+def matmul_transpose_b(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.N, S.K),
+    C=TensorDef(U, S.M, S.N, output=True),
+    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
+):
+    """Performs a matrix multiplication of two 2D inputs with rhs operand
+    transposed.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
 
 
 @linalg_structured_op
@@ -390,36 +394,41 @@ def batch_matmul(
 
 
 @linalg_structured_op
-def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M),
-                             B=TensorDef(T2, Batch, S.K, S.N),
-                             C=TensorDef(U, Batch, S.M, S.N, output=True)):
-  """Performs a batched matrix multiplication of two 3D inputs where lhs operand
-  has its non-batch dimensions transposed.
+def batch_matmul_transpose_a(
+    A=TensorDef(T1, Batch, S.K, S.M),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+    """Performs a batched matrix multiplication of two 3D inputs where lhs operand
+    has its non-batch dimensions transposed.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.b, D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \
-                    * TypeFn.cast_signed(U, B[D.b, D.k, D.n])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed(
+        U, B[D.b, D.k, D.n]
+    )
 
 
 @linalg_structured_op
-def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
-                             B=TensorDef(T2, Batch, S.N, S.K),
-                             C=TensorDef(U, Batch, S.M, S.N, output=True)):
-  """Performs a batched matrix multiplication of two 3D inputs where rhs operand
-  has its non-batch dimensions transposed.
+def batch_matmul_transpose_b(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.N, S.K),
+    C=TensorDef(U, Batch, S.M, S.N, output=True),
+):
+    """Performs a batched matrix multiplication of two 3D inputs where rhs operand
+    has its non-batch dimensions transposed.
 
-  Numeric casting is performed on the operands to the inner multiply, promoting
-  them to the same data type as the accumulator/output.
-  """
-  domain(D.b, D.m, D.n, D.k)
-  implements(ContractionOpInterface)
-  C[D.b, D.m,
-    D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
-        U, B[D.b, D.n, D.k])
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    domain(D.b, D.m, D.n, D.k)
+    implements(ContractionOpInterface)
+    C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.n, D.k]
+    )
 
 
 @linalg_structured_op
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 8465af048a28056..6579e02d8549efa 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,7 +3,12 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
-from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
+from .._mlir_libs._mlirPythonTest import (
+    TestAttr,
+    TestType,
+    TestTensorValue,
+    TestIntegerRankedTensorType,
+)
 
 
 def register_python_test_dialect(context, load=True):
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py
index 0a3b411041b2f4d..f6b706f9bc8ae24 100644
--- a/mlir/python/mlir/runtime/np_to_memref.py
+++ b/mlir/python/mlir/runtime/np_to_memref.py
@@ -114,6 +114,7 @@ def get_unranked_memref_descriptor(nparray):
     d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
     return d
 
+
 def move_aligned_ptr_by_offset(aligned_ptr, offset):
     """Moves the supplied ctypes pointer ahead by `offset` elements."""
     aligned_addr = ctypes.addressof(aligned_ptr.contents)
@@ -122,6 +123,7 @@ def move_aligned_ptr_by_offset(aligned_ptr, offset):
     content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr))
     return content_ptr
 
+
 def unranked_memref_to_numpy(unranked_memref, np_dtype):
     """Converts unranked memrefs to numpy arrays."""
     ctp = as_ctype(np_dtype)
@@ -139,10 +141,10 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype):
 
 def ranked_memref_to_numpy(ranked_memref):
     """Converts ranked memrefs to numpy arrays."""
-    content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset)
-    np_arr = np.ctypeslib.as_array(
-        content_ptr, shape=ranked_memref[0].shape
+    content_ptr = move_aligned_ptr_by_offset(
+        ranked_memref[0].aligned, ranked_memref[0].offset
     )
+    np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape)
     strided_arr = np.lib.stride_tricks.as_strided(
         np_arr,
         np.ctypeslib.as_array(ranked_memref[0].shape),
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 875678a2333789a..c8ef84721090ab9 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -847,7 +847,6 @@ populateBuilderRegions(const Operator &op,
 /// rebuild anew).
 static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
                                                            raw_ostream &os) {
-  // If we are asked to skip default builders, comply.
   llvm::SmallVector<std::string> builderArgs;
   llvm::SmallVector<std::string> builderLines;
   llvm::SmallVector<std::string> operandArgNames;
@@ -980,7 +979,6 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
 static void emitValueBuilder(const Operator &op,
                              llvm::SmallVector<std::string> functionArgs,
                              raw_ostream &os) {
-  auto name = sanitizeName(op.getOperationName());
   // Params with (possibly) default args.
   auto valueBuilderParams =
       llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {

>From 73adb885927fdff8dbe8f4156516eefd7351bb23 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 19 Oct 2023 13:22:38 -0500
Subject: [PATCH 6/8] add value builder test

---
 mlir/test/python/dialects/arith_dialect.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index f4a793aee4aa14c..6d1c5eab7589847 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -33,3 +33,16 @@ def testFastMathFlags():
             )
             # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
             print(r)
+
+
+# CHECK-LABEL: TEST: testArithValueBuilder
+@run
+def testArithValueBuilder():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        f32_t = F32Type.get()
+
+        with InsertionPoint(module.body):
+            a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
+            # CHECK: %cst = arith.constant 4.242000e+01 : f32
+            print(a)

>From 86ce32ac6ba276028bd09fe732a5e46eabe56520 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 19 Oct 2023 14:28:39 -0500
Subject: [PATCH 7/8] update docs

---
 mlir/docs/Bindings/Python.md | 127 ++++++++++++++++-------------------
 1 file changed, 58 insertions(+), 69 deletions(-)

diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bf54efee1f14e0c..bc2e676a878c0f4 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1017,90 +1017,79 @@ very generic signature.
 
 #### Extending Generated Op Classes
 
-Note that this is a rather complex mechanism and this section errs on the side
-of explicitness. Users are encouraged to find an example and duplicate it if
-they don't feel the need to understand the subtlety. The `builtin` dialect
-provides some relatively simple examples.
-
 As mentioned above, the build system generates Python sources like
 `_{DIALECT_NAMESPACE}_ops_gen.py` for each dialect with Python bindings. It is
-often desirable to to use these generated classes as a starting point for
-further customization, so an extension mechanism is provided to make this easy
-(you are always free to do ad-hoc patching in your `{DIALECT_NAMESPACE}.py` file
-but we prefer a more standard mechanism that is applied uniformly).
+often desirable to use these generated classes as a starting point for
+further customization, so an extension mechanism is provided to make this easy.
+This mechanism uses conventional inheritance combined with `OpView` registration.
+For example, the default builder for `arith.constant`
+
+```python
+class ConstantOp(_ods_ir.OpView):
+  OPERATION_NAME = "arith.constant"
+
+  _ODS_REGIONS = (0, True)
+
+  def __init__(self, value, *, loc=None, ip=None):
+    ...
+```
 
-To provide extensions, add a `_{DIALECT_NAMESPACE}_ops_ext.py` file to the
-`dialects` module (i.e. adjacent to your `{DIALECT_NAMESPACE}.py` top-level and
-the `*_ops_gen.py` file). Using the `builtin` dialect and `FuncOp` as an
-example, the generated code will include an import like this:
+expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
+Thus, a natural extension is a builder that accepts an MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
 
 ```python
-try:
-  from . import _builtin_ops_ext as _ods_ext_module
-except ImportError:
-  _ods_ext_module = None
+from typing import Union
+
+from mlir.ir import Type, IntegerAttr, FloatAttr
+from mlir.dialects._arith_ops_gen import _Dialect, ConstantOp
+from mlir.dialects._ods_common import _cext
+
+@_cext.register_operation(_Dialect, replace=True)
+class ConstantOpExt(ConstantOp):
+    def __init__(
+        self, result: Type, value: Union[int, float], *, loc=None, ip=None
+    ):
+        if isinstance(value, int):
+            super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+        elif isinstance(value, float):
+            super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+        else:
+            raise NotImplementedError(f"Building `arith.constant` not supported for {result=} {value=}")
 ```
 
-Then for each generated concrete `OpView` subclass, it will apply a decorator
-like:
+which enables building an instance of `arith.constant` like so:
 
 ```python
-@_ods_cext.register_operation(_Dialect)
-@_ods_extend_opview_class(_ods_ext_module)
-class FuncOp(_ods_ir.OpView):
+from mlir.ir import F32Type, IntegerType
+
+a = ConstantOpExt(F32Type.get(), 42.42)
+b = ConstantOpExt(IntegerType.get_signless(32), 42)
 ```
 
-See the `_ods_common.py` `extend_opview_class` function for details of the
-mechanism. At a high level:
-
-*   If the extension module exists, locate an extension class for the op (in
-    this example, `FuncOp`):
-    *   First by looking for an attribute with the exact name in the extension
-        module.
-    *   Falling back to calling a `select_opview_mixin(parent_opview_cls)`
-        function defined in the extension module.
-*   If a mixin class is found, a new subclass is dynamically created that
-    multiply inherits from `({_builtin_ops_ext.FuncOp},
-    _builtin_ops_gen.FuncOp)`.
-
-The mixin class should not inherit from anything (i.e. directly extends `object`
-only). The facility is typically used to define custom `__init__` methods,
-properties, instance methods and static methods. Due to the inheritance
-ordering, the mixin class can act as though it extends the generated `OpView`
-subclass in most contexts (i.e. `issubclass(_builtin_ops_ext.FuncOp, OpView)`
-will return `False` but usage generally allows you treat it as duck typed as an
-`OpView`).
-
-There are a couple of recommendations, given how the class hierarchy is defined:
-
-*   For static methods that need to instantiate the actual "leaf" op (which is
-    dynamically generated and would result in circular dependencies to try to
-    reference by name), prefer to use `@classmethod` and the concrete subclass
-    will be provided as your first `cls` argument. See
-    `_builtin_ops_ext.FuncOp.from_py_func` as an example.
-*   If seeking to replace the generated `__init__` method entirely, you may
-    actually want to invoke the super-super-class `mlir.ir.OpView` constructor
-    directly, as it takes an `mlir.ir.Operation`, which is likely what you are
-    constructing (i.e. the generated `__init__` method likely adds more API
-    constraints than you want to expose in a custom builder).
-
-A pattern that comes up frequently is wanting to provide a sugared `__init__`
-method which has optional or type-polymorphism/implicit conversions but to
-otherwise want to invoke the default op building logic. For such cases, it is
-recommended to use an idiom such as:
+Note three key aspects of the extension mechanism in this example:
+
+1. `ConstantOpExt` directly inherits from the generated `ConstantOp`;
+2. in this, the simplest case, all that's required is a call to the superclass initializer, i.e., `super().__init__(...)`;
+3. in order to register `ConstantOpExt` as the preferred `OpView` that is returned by `mlir.ir.Operation.opview` (see [Operations, Regions and Blocks](#operations-regions-and-blocks))
+   we decorate the class with `@_cext.register_operation(_Dialect, replace=True)`, **where the `replace=True` must be used**.
+
+In some more complex cases it might be necessary to explicitly build the `OpView` through `OpView.build_generic` (see [Default Builder](#default-builder)), just as is performed by the generated builders.
+That is, we must call `OpView.build_generic` **and pass the result to `OpView.__init__`**; the small complication is that the latter is already overridden by the generated builder.
+Thus, we must call a method of the superclass's superclass (the "grandparent"); for example:
 
 ```python
-  def __init__(self, sugar, spice, *, loc=None, ip=None):
-    ... massage into result_type, operands, attributes ...
-    OpView.__init__(self, self.build_generic(
-        results=[result_type],
-        operands=operands,
-        attributes=attributes,
-        loc=loc,
-        ip=ip))
+from mlir.dialects._scf_ops_gen import _Dialect, ForOp
+from mlir.dialects._ods_common import _cext
+
+@_cext.register_operation(_Dialect, replace=True)
+class ForOpExt(ForOp):
+    def __init__(self, lower_bound, upper_bound, step, iter_args, *, loc=None, ip=None):
+        ...
+        super(ForOp, self).__init__(self.build_generic(...))
 ```
 
-Refer to the documentation for `build_generic` for more information.
+where `OpView.__init__` is called via `super(ForOp, self).__init__`.
+Note that there are alternative ways to implement this (e.g., explicitly calling `OpView.__init__`); see any discussion of Python inheritance.
 
 ## Providing Python bindings for a dialect
 

>From 6725dea7e254e1bbc15e33cabc4391c564241662 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 11 Oct 2023 00:28:04 -0500
Subject: [PATCH 8/8] [mlir][python] value casting

---
 mlir/python/mlir/dialects/_ods_common.py      | 58 +++++++++++++++-
 mlir/python/mlir/ir.py                        | 14 ++++
 mlir/test/mlir-tblgen/op-python-bindings.td   | 16 ++---
 mlir/test/python/dialects/arith_dialect.py    | 69 +++++++++++++++++--
 mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 17 +++--
 5 files changed, 156 insertions(+), 18 deletions(-)

diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 9cca7d659ec8cb3..cb85990bf4240e7 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -1,11 +1,18 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from collections import defaultdict
 
 # Provide a convenient name for sub-packages to resolve the main C-extension
 # with a relative import.
 from .._mlir_libs import _mlir as _cext
-from typing import Sequence as _Sequence, Union as _Union
+from typing import (
+    Callable as _Callable,
+    Sequence as _Sequence,
+    Type as _Type,
+    TypeVar as _TypeVar,
+    Union as _Union,
+)
 
 __all__ = [
     "equally_sized_accessor",
@@ -123,3 +130,52 @@ def get_op_result_or_op_results(
         if len(op.results) > 0
         else op
     )
+
+
+U = _TypeVar("U", bound=_cext.ir.Value)
+SubClassValueT = _Type[U]
+
+TypeCasterT = _Callable[
+    [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None]
+]
+
+_VALUE_CASTERS: defaultdict[
+    _cext.ir.TypeID,
+    _Sequence[TypeCasterT],
+] = defaultdict(list)
+
+
+def has_value_caster(typeid: _cext.ir.TypeID):
+    if not isinstance(typeid, _cext.ir.TypeID):
+        raise ValueError(f"{typeid=} is not a TypeID")
+    if typeid in _VALUE_CASTERS:
+        return True
+    return False
+
+
+def get_value_caster(typeid: _cext.ir.TypeID):
+    if not has_value_caster(typeid):
+        raise ValueError(f"no registered caster for {typeid=}")
+    return _VALUE_CASTERS[typeid]
+
+
+def maybe_cast(
+    val: _Union[
+        _cext.ir.Value,
+        _cext.ir.OpResult,
+        _Sequence[_cext.ir.Value],
+        _Sequence[_cext.ir.OpResult],
+        _cext.ir.Operation,
+    ]
+) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]:
+    if isinstance(val, (tuple, list)):
+        return tuple(map(maybe_cast, val))
+
+    if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult):
+        return val
+
+    if has_value_caster(val.type.typeid):
+        for caster in get_value_caster(val.type.typeid):
+            if casted := caster(val):
+                return casted
+    return val
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 43553f3118a51fc..019da8cd677fd45 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,20 @@
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
 from ._mlir_libs._mlir import register_type_caster
+from .dialects._ods_common import TypeCasterT, _VALUE_CASTERS
+
+
+def register_value_caster(typeid: TypeID, priority: int = None):
+    def wrapper(caster: TypeCasterT):
+        if not isinstance(typeid, TypeID):
+            raise ValueError(f"{typeid=} is not a TypeID")
+        if priority is None:
+            _VALUE_CASTERS[typeid].append(caster)
+        else:
+            _VALUE_CASTERS[typeid].insert(priority, caster)
+        return caster
+
+    return wrapper
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 63dad1cc901fe2b..fa7000c02873b5e 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 }
 
 // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
 }
 
 // CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
@@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
 }
 
 // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 }
 
 // CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
 }
 
 // CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 }
 
 // CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
 }
 
 // CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SimpleOp(_ods_ir.OpView):
@@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> {
 }
 
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)))
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 6d1c5eab7589847..1a4f635b05b4aad 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -1,8 +1,10 @@
 # RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
 
 from mlir.ir import *
-import mlir.dialects.func as func
 import mlir.dialects.arith as arith
+import mlir.dialects._arith_ops_ext as arith_ext
+from mlir.dialects._ods_common import maybe_cast
 
 
 def run(f):
@@ -35,14 +37,71 @@ def testFastMathFlags():
             print(r)
 
 
-# CHECK-LABEL: TEST: testArithValueBuilder
+# CHECK-LABEL: TEST: testArithValue
 @run
-def testArithValueBuilder():
+def testArithValue():
+    def _binary_op(lhs, rhs, op: str):
+        op = op.capitalize()
+        if arith_ext._is_float_type(lhs.type):
+            op += "F"
+        elif arith_ext._is_integer_like_type(lhs.type):
+            op += "I"
+        else:
+            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
+
+        op = getattr(arith, f"{op}Op")
+        return maybe_cast(op(lhs, rhs).result)
+
+    @register_value_caster(F16Type.static_typeid)
+    @register_value_caster(F32Type.static_typeid)
+    @register_value_caster(F64Type.static_typeid)
+    @register_value_caster(IntegerType.static_typeid)
+    class ArithValue(Value):
+        __add__ = partialmethod(_binary_op, op="add")
+        __sub__ = partialmethod(_binary_op, op="sub")
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace("Value", "ArithValue")
+
+    @register_value_caster(IntegerType.static_typeid, priority=0)
+    class ArithValue1(Value):
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace("Value", "ArithValue1")
+
+    @register_value_caster(IntegerType.static_typeid, priority=0)
+    def no_op_caster(val):
+        print("no_op_caster", val)
+        return None
+
     with Context() as ctx, Location.unknown():
         module = Module.create()
+        f16_t = F16Type.get()
         f32_t = F32Type.get()
+        f64_t = F64Type.get()
+        i32 = IntegerType.get_signless(32)
 
         with InsertionPoint(module.body):
+            a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+            b = a + a
+            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
+            print(b)
+
             a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
-            # CHECK: %cst = arith.constant 4.242000e+01 : f32
-            print(a)
+            b = a - a
+            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
+            print(b)
+
+            a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+            b = a * a
+            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
+            print(b)
+
+            # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32)
+            a = arith.constant(value=IntegerAttr.get(i32, 1))
+            b = a * a
+            # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+            # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+            print(b)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index c8ef84721090ab9..170ac6b87c693d7 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,16 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
+from ._ods_common import (
+    SubClassValueT as _SubClassValueT,
+    equally_sized_accessor as _ods_equally_sized_accessor,
+    get_default_loc_context as _ods_get_default_loc_context,
+    get_op_result_or_op_results as _get_op_result_or_op_results,
+    get_op_result_or_value as _get_op_result_or_value,
+    get_op_results_or_values as _get_op_results_or_values,
+    maybe_cast as _maybe_cast,
+    segmented_accessor as _ods_segmented_accessor,
+)
 _ods_ir = _ods_cext.ir
 
 import builtins
@@ -263,7 +272,7 @@ constexpr const char *regionAccessorTemplate = R"Py(
 
 constexpr const char *valueBuilderTemplate = R"Py(
 def {0}({2}) -> {4}:
-  return _get_op_result_or_op_results({1}({3}))
+  return _maybe_cast(_get_op_result_or_op_results({1}({3})))
 )Py";
 
 static llvm::cl::OptionCategory
@@ -1004,8 +1013,8 @@ static void emitValueBuilder(const Operator &op,
                       llvm::join(valueBuilderParams, ", "),
                       llvm::join(opBuilderArgs, ", "),
                       (op.getNumResults() > 1
-                           ? "_Sequence[_ods_ir.OpResult]"
-                           : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
+                           ? "_Sequence[_SubClassValueT]"
+                           : (op.getNumResults() > 0 ? "_SubClassValueT"
                                                      : "_ods_ir.Operation")));
 }
 



More information about the Mlir-commits mailing list