[Mlir-commits] [mlir] [mlir] NFC - refactor id builder and avoid leaking impl details (PR #146922)

Nicolas Vasilache llvmlistbot at llvm.org
Mon Jul 7 06:20:31 PDT 2025


https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/146922

>From 30bfea8439d866e5bc8ec1c0eafb9161ae585a40 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Fri, 4 Jul 2025 10:32:39 +0200
Subject: [PATCH 1/3] [mlir][python] Add utils for more pythonic context
 creation and registration management

Co-authored-by: Fabian Mora <fmora.dev at gmail.com>
Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>
Co-authored-by: Tres <tpopp at users.noreply.github.com>
---
 mlir/include/mlir-c/IR.h                 |   4 +
 mlir/lib/Bindings/Python/IRCore.cpp      |   6 +
 mlir/lib/CAPI/IR/IR.cpp                  |   4 +
 mlir/python/CMakeLists.txt               |   7 +
 mlir/python/mlir/_mlir_libs/_mlir/ir.pyi |   1 +
 mlir/python/mlir/utils.py                | 489 +++++++++++++++++++++++
 mlir/test/python/utils.py                |  58 +++
 7 files changed, 569 insertions(+)
 create mode 100644 mlir/python/mlir/utils.py
 create mode 100644 mlir/test/python/utils.py

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 81299c7911d24..877aa73ca2cc0 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -143,6 +143,10 @@ MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
 MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context,
                                                         bool enable);
 
+/// Returns whether multithreading is currently enabled, as controlled by
+/// mlirContextEnableMultithreading.
+MLIR_CAPI_EXPORTED bool mlirContextIsMultithreadingEnabled(MlirContext context);
+
 /// Eagerly loads all available dialects registered with a context, making
 /// them available for use for IR construction.
 MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d961482885300..002923becd23a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2939,6 +2939,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              ss << pool.ptr;
              return ss.str();
            })
+      .def_prop_ro(
+          "is_multithreading_enabled",
+          [](PyMlirContext &self) {
+            return mlirContextIsMultithreadingEnabled(self.get());
+          },
+          "Returns true if multithreading is enabled for this context.")
       .def(
           "is_registered_operation",
           [](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index fbc66bcf5c2d0..1cc555ad41de1 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -101,6 +101,10 @@ bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
   return unwrap(context)->isOperationRegistered(unwrap(name));
 }
 
+bool mlirContextIsMultithreadingEnabled(MlirContext context) {
+  return unwrap(context)->isMultithreadingEnabled();
+}
+
 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
   return unwrap(context)->enableMultithreading(enable);
 }
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b2daabb2a5957..b4e0ab2ec6dae 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -48,6 +48,13 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
     runtime/*.py
 )
 
+declare_mlir_python_sources(MLIRPythonSources.Utils
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  ADD_TO_PARENT MLIRPythonSources
+  SOURCES
+    utils.py
+)
+
 declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources
   ROOT_DIR "${MLIR_SOURCE_DIR}/include"
   SOURCES_GLOB "mlir-c/*.h"
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 70bca3c75d842..56b9f17d52ea7 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -986,6 +986,7 @@ class ComplexType(Type):
 class Context:
     current: ClassVar[Context] = ...  # read-only
     allow_unregistered_dialects: bool
+    is_multithreading_enabled: bool
     @staticmethod
     def _get_live_count() -> int: ...
     def _CAPICreate(self) -> object: ...
diff --git a/mlir/python/mlir/utils.py b/mlir/python/mlir/utils.py
new file mode 100644
index 0000000000000..10c316bc8f086
--- /dev/null
+++ b/mlir/python/mlir/utils.py
@@ -0,0 +1,489 @@
+from contextlib import contextmanager, nullcontext
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    Concatenate,
+    Iterator,
+    Optional,
+    ParamSpec,
+    Sequence,
+    TypeVar,
+)
+
+from mlir import ir
+from mlir._mlir_libs import get_dialect_registry
+from mlir.dialects import func
+from mlir.dialects.transform import interpreter
+from mlir.passmanager import PassManager
+
+RT = TypeVar("RT")
+Param = ParamSpec("Param")
+
+
+@contextmanager
+def using_mlir_context(
+    *,
+    required_dialects: Optional[Sequence[str]] = None,
+    required_extension_operations: Optional[Sequence[str]] = None,
+    registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None,
+) -> Iterator[None]:
+    """Ensure a valid context exists by creating one if necessary.
+
+    NOTE: If values attached to a Context must outlive this context
+          manager, use caller_mlir_context instead!
+
+    This can be used as a function decorator or as a managed context in a
+    with statement. An error is raised if the required dialects have not
+    been registered, and a context is guaranteed to exist in this scope.
+
+    Currently this works for dialects only, not for dialect extensions.
+
+    Parameters
+    ----------
+        required_dialects:
+            Dialects that must be registered in the context.
+        required_extension_operations:
+            Required operations, by fully qualified name. These act as a
+            proxy for detecting needed dialect extensions.
+        registration_funcs:
+            Functions invoked to register any dialects/operations found
+            missing from the context.
+    """
+    dialects = required_dialects or []
+    extension_operations = required_extension_operations or []
+    registrations = registration_funcs or []
+    new_context = nullcontext if ir.Context.current else ir.Context
+    with new_context(), ir.Location.unknown():
+        context = ir.Context.current
+        # Attempt to disable multithreading. This could fail if the context is
+        # currently being used from multiple threads. It must be done before
+        # checking for or registering dialects, as both trigger assertion
+        # failures in a multithreaded situation.
+        multithreading = context.is_multithreading_enabled
+        if multithreading:
+            context.enable_multithreading(False)
+
+        def attempt_registration():
+            """Register everything from registration_funcs."""
+            nonlocal context, registrations
+
+            # Gather dialects and extensions then add them to the context.
+            registry = ir.DialectRegistry()
+            for rf in registrations:
+                rf(registry)
+
+            context.append_dialect_registry(registry)
+
+        # See if any dialects are missing, register if they are, and then
+        # assert they are all registered.
+        try:
+            for dialect in dialects:
+                # Raises if the dialect is missing, triggering registration.
+                context.get_dialect_descriptor(dialect)
+        except Exception:
+            attempt_registration()
+
+        for dialect in dialects:
+            # Strongly assert that every required dialect is now registered.
+            assert context.get_dialect_descriptor(
+                dialect
+            ), f"required dialect {dialect} not registered by registration_funcs"
+
+        # See if any operations are missing and register if they are. For some
+        # reason, we cannot assert afterwards that the operations exist in the
+        # registry, so the checks below go through the context instead.
+        #
+        # TODO: Make this work for dialect extensions specifically
+        for operation in extension_operations:
+            # If the operation is not registered, attempt registration once;
+            # it is strongly asserted below that it was added.
+            if not context.is_registered_operation(operation):
+                attempt_registration()
+                break
+        for operation in extension_operations:
+            # First get the dialect descriptor, which loads the dialect as a
+            # side effect.
+            dialect = operation.split(".")[0]
+            assert context.get_dialect_descriptor(dialect), f"Never loaded {dialect}"
+            assert context.is_registered_operation(
+                operation
+            ), f"expected {operation} to be registered in its dialect"
+        context.enable_multithreading(multithreading)
+
+        # Context-manager yield: the (possibly newly created) context stays
+        # active for the body of the caller's with statement.
+        yield
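+
+# A minimal usage sketch for using_mlir_context. The dialect names and the
+# function below are hypothetical and for illustration only:
+#
+#   @using_mlir_context(required_dialects=["func", "arith"])
+#   def build_ir() -> None:
+#       # A context is guaranteed to exist here. If it was created on entry,
+#       # values tied to it must not escape this scope.
+#       ...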
+
+
+@contextmanager
+def caller_mlir_context(
+    *,
+    required_dialects: Optional[Sequence[str]] = None,
+    required_extension_operations: Optional[Sequence[str]] = None,
+    registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None,
+) -> Iterator[None]:
+    """Requires an enclosing context from the caller and ensures relevant operations are loaded.
+
+    NOTE: If the Context is only needed inside this context manager and
+          returned values do not need the Context, use using_mlir_context!
+
+    A context must already exist before this frame is executed so that any
+    values continue to live on exit. Conceptually, this prevents use-after-free
+    issues and makes the intent explicit when returning values tied to a Context.
+    """
+    assert (
+        ir.Context.current
+    ), "Caller must have a context so it outlives this function call."
+    with using_mlir_context(
+        required_dialects=required_dialects,
+        required_extension_operations=required_extension_operations,
+        registration_funcs=registration_funcs,
+    ):
+        # Context-manager yield; scoping is delegated to using_mlir_context.
+        yield
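+
+# A minimal usage sketch for caller_mlir_context; the function below is
+# hypothetical and for illustration only:
+#
+#   with ir.Context(), ir.Location.unknown():
+#       @caller_mlir_context(required_dialects=["func"])
+#       def make_module() -> ir.Module:
+#           # Returned values may safely reference the caller's context.
+#           return ir.Module.create()
+#
+#       module = make_module()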
+
+
+def with_toplevel_context(f: Callable[Param, RT]) -> Callable[Param, RT]:
+    """Decorate the function to be executed with a fresh MLIR context.
+
+    This decorator ensures the function is executed inside a context manager for
+    a new MLIR context with all available dialects registered. Note that each
+    call to such a function gets a new context, so context-owned objects from
+    separate calls will not compare equal. All arguments and keyword arguments
+    are forwarded.
+
+    The context is destroyed before the decorated function returns to its
+    caller, so any result from the function must not depend on the context.
+    """
+
+    @wraps(f)
+    def decorator(*args: Param.args, **kwargs: Param.kwargs) -> RT:
+        # Appending dialect registry and loading all available dialects occur on
+        # context creation because of the "_site_initialize" call.
+        with ir.Context(), ir.Location.unknown():
+            results = f(*args, **kwargs)
+        return results
+
+    return decorator
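+
+# A minimal usage sketch for with_toplevel_context; the function below is
+# hypothetical. Returning a plain string is fine because it does not
+# reference the (now destroyed) context:
+#
+#   @with_toplevel_context
+#   def i32_as_string() -> str:
+#       return str(ir.IntegerType.get_signless(32))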
+
+
+def with_toplevel_context_create_module(
+    f: Callable[Concatenate[ir.Module, Param], RT],
+) -> Callable[Param, RT]:
+    """Decorate the function to be executed in a fresh MLIR context and give it a module.
+
+    The decorated function will receive, as its leading argument, a fresh MLIR module.
+    The context manager is set up to insert operations into this module. All other
+    arguments and keyword arguments are forwarded.
+
+    The module and context are destroyed before the decorated function returns
+    to its caller, so any result from the function must not depend on either.
+    """
+
+    @with_toplevel_context
+    @wraps(f)
+    def internal(*args: Param.args, **kwargs: Param.kwargs) -> RT:
+        module = ir.Module.create()
+        with ir.InsertionPoint(module.body):
+            results = f(module, *args, **kwargs)
+        return results
+
+    return internal
+
+
+def call_with_toplevel_context(f: Callable[[], RT]) -> Callable[[], RT]:
+    """Immediately call the function in a fresh MLIR context."""
+    decorated = with_toplevel_context(f)
+    decorated()
+    return decorated
+
+
+def call_with_toplevel_context_create_module(
+    f: Callable[[ir.Module], RT],
+) -> Callable[[], RT]:
+    """Immediately call the function in a fresh MLIR context and give it a module.
+
+    The decorated function will receive, as its only argument, a fresh MLIR module. The
+    context manager is set up to insert operations into this module.
+    """
+    decorated = with_toplevel_context_create_module(f)
+    decorated()
+    return decorated
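+
+# A minimal usage sketch of the immediate-call variant, mirroring the pattern
+# exercised in mlir/test/python/utils.py; the body is hypothetical:
+#
+#   @call_with_toplevel_context_create_module
+#   def _(module: ir.Module) -> None:
+#       # Runs immediately; the insertion point targets module.body.
+#       print(module)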
+
+
+# def _debug_types_impl(types: Sequence[str]) -> Iterator[None]:
+#     from mlir.ir import _GlobalDebug
+
+#     # Save the original debug state. The debug types will be popped rather than
+#     # manually copied and saved for later.
+#     original_flag = _GlobalDebug.flag
+#     _GlobalDebug.flag = True
+#     _GlobalDebug.append_types(types)
+
+#     try:
+#         yield
+#     finally:
+#         # Reset the global debug flag and remove the most recent types that were
+#         # appended. This assumes that nothing else popped when it should not have.
+#         _GlobalDebug.flag = original_flag
+#         _GlobalDebug.pop_types()
+
+
+# @contextmanager
+# def debug_types_context(types: Sequence[str]):
+#     """Temporarily create a context that enables debugging with specified filters.
+
+#     This is equivalent to running with -debug-only=<types>; when such
+#     contexts are nested, their type lists are joined to form the full list.
+
+#     This requires that the core MLIR units were compiled without NDEBUG.
+#     """
+#     return _debug_types_impl(types)
+
+
+# @contextmanager
+# def debug_td(types: Sequence[str] = [], *, full_debug: bool = False) -> Iterator[None]:
+#     """Temporarily create a context that enables full transform dialect debugging,
+#     potentially with additional specified filters.
+
+#     This is equivalent to running with -debug-only=<types>; when such
+#     contexts are nested, their type lists are joined to form the full list.
+
+#     This requires that the core MLIR units were compiled without NDEBUG.
+#     """
+#     return _debug_types_impl(
+#         list(types)
+#         + [
+#             "transform-dialect",
+#             "transform-dialect-print-top-level-after-all",
+#             "async-transform-ops",
+#         ]
+#         + (["transform-dialect-full"] if full_debug else [])
+#     )
+
+
+# @contextmanager
+# def debug_conversion(types: Sequence[str] = []) -> Iterator[None]:
+#     """Temporarily create a context that enables full conversion debugging,
+#     potentially with additional specified filters.
+
+#     This is equivalent to running with -debug-only=<types>; when such
+#     contexts are nested, their type lists are joined to form the full list.
+
+#     This requires that the core MLIR units were compiled without NDEBUG.
+#     """
+#     return _debug_types_impl(list(types) + ["dialect-conversion"])
+
+
+# # TODO: Allow cloning functions from one module to another.
+# # At the moment we have to resort to string concatenation.
+# def _create_module_from_main(main):
+#     module = ir.Module.create()
+#     ops = module.operation.regions[0].blocks[0].operations
+#     return ir.Module.parse("\n".join([str(op) for op in ops]) + main)
+
+
+# def _add_lowering_passes(
+#     pm, linalg_lowering, vector_lowering, memref_lowering, loop_lowering
+# ):
+#     if loop_lowering:
+#         pm.add("func.func(expand-strided-metadata)")
+#         pm.add("func.func(lower-affine)")
+#     if linalg_lowering:
+#         pm.add("func.func(convert-linalg-to-loops)")
+#         pm.add("func.func(convert-scf-to-cf)")
+#         pm.add("convert-cf-to-llvm")
+#     if memref_lowering:
+#         pm.add("finalize-memref-to-llvm")
+#     if vector_lowering:
+#         pm.add("convert-vector-to-llvm")
+#     pm.add("convert-arith-to-llvm")
+#     pm.add("convert-index-to-llvm")
+#     pm.add("convert-func-to-llvm")
+#     pm.add("reconcile-unrealized-casts")
+
+
+# def lowering_to_llvm_via_passmanager(
+#     module=None,
+#     main=None,
+#     linalg_lowering=False,
+#     vector_lowering=False,
+#     memref_lowering=False,
+#     loop_lowering=False,
+# ):
+#     """Transforms a module using registered passes and pass manager.
+
+#     Parameters
+#     ----------
+#     module: ir.Module
+#         Module to be transformed.
+
+#     main: str
+#         Main function embedded into a string. Will be used to create the module to be transformed.
+
+#     linalg_lowering: bool
+#         Flag indicating whether linalg lowering should be included in the transformation.
+
+#     vector_lowering: bool
+#         Flag indicating whether vector lowering should be included in the transformation.
+
+#     memref_lowering: bool
+#         Flag indicating whether memref lowering should be included in the transformation.
+
+#     loop_lowering: bool
+#         Flag indicating whether loop lowering should be included in the transformation.
+
+#     Returns
+#     -------
+#     The transformed module.
+#     """
+
+#     if module is None:
+#         module = _create_module_from_main(main)
+#     pm = PassManager("builtin.module")
+#     _add_lowering_passes(
+#         pm, linalg_lowering, vector_lowering, memref_lowering, loop_lowering
+#     )
+#     pm.run(module.operation)
+#     return module
+
+
+# def apply_named_sequence_from_module(
+#     *, payload: ir.Module, transform: ir.Module, entry_point: str = "__transform_main"
+# ):
+#     interpreter.apply_named_sequence(
+#         payload, ir.SymbolTable(transform.operation)[entry_point], transform.operation
+#     )
+
+
+# def build_and_run_common_pass_pipeline(module: ir.Module, nvvm_attach_target_opts: str):
+#     pm = PassManager("builtin.module")
+#     pm.add("convert-nvgpu-to-nvvm")
+#     pm.add("convert-scf-to-cf")
+#     pm.add("gpu-kernel-outlining")
+#     pm.add("convert-func-to-llvm")
+#     pm.add("expand-strided-metadata")
+#     pm.add(f"nvvm-attach-target{{{nvvm_attach_target_opts}}}")
+#     pm.add("lower-affine")
+#     pm.add("convert-arith-to-llvm")
+#     pm.add("convert-index-to-llvm")
+#     pm.run(module.operation)
+#     return module
+
+
+# def build_and_run_gpu_pass_pipeline(module: ir.Module, convert_gpu_to_nvvm_opts: str):
+#     pm = PassManager("builtin.module")
+#     pm.add(f"gpu.module(convert-gpu-to-nvvm{{{convert_gpu_to_nvvm_opts}}})")
+#     pm.run(module.operation)
+#     return module
+
+
+# def build_and_run_host_post_pipeline(module: ir.Module):
+#     pm = PassManager("builtin.module")
+#     pm.add("gpu-to-llvm")
+#     pm.add("convert-nvvm-to-llvm")
+#     pm.add("reconcile-unrealized-casts")
+#     pm.run(module.operation)
+#     return module
+
+
+# def gpu_module_to_binary(module: ir.Module, cubin_format: str):
+#     pm = PassManager("builtin.module")
+#     pm.add(f"gpu-module-to-binary{{format={cubin_format}}}")
+#     pm.run(module.operation)
+#     return module
+
+
+# def lower_nvgpu(
+#     module: ir.Module,
+#     nvvm_attach_target_opts: str,
+#     convert_gpu_to_nvvm_opts: str,
+#     cubin_format: str,
+# ) -> ir.Module:
+#     """Lowers `module` to LLVM IR, targeting GPU, using registered passes and pass manager.
+
+#     Parameters
+#     ----------
+#     module: ir.Module
+#         Module to be lowered.
+
+#     nvvm_attach_target_opts: str
+#         Options describing the NVVM target to be attached as an attribute to the GPU module.
+
+#     convert_gpu_to_nvvm_opts: str
+#         Options for generating NVVM operations from GPU ones.
+
+#     cubin_format: str
+#         Compilation format for serializing to cubin.
+
+#     Returns
+#     -------
+#     The LLVM IR module.
+#     """
+
+#     module = build_and_run_common_pass_pipeline(module, nvvm_attach_target_opts)
+#     module = build_and_run_gpu_pass_pipeline(module, convert_gpu_to_nvvm_opts)
+#     module = build_and_run_host_post_pipeline(module)
+#     return gpu_module_to_binary(module, cubin_format)
+
+
+# def rename_all_symbols(
+#     module: ir.Module,
+#     rename_fn: Callable[[str], Optional[str]],
+#     target_op: str = "func.func",
+# ):
+#     """Renames all the symbols in `module` using the `rename_fn` function.
+#     If `rename_fn` returns `None`, then the symbol is not renamed.
+
+#     Parameters
+#     ----------
+#     @param: module The MLIR module.
+#     @param: rename_fn A callable that takes the symbol and returns its new name.
+#     @param: target_op The 'operation name' of the target.
+#     """
+#     symbols = ir.SymbolTable(module.operation)
+#     for op in module.operation.regions[0].blocks[0]:
+#         if op.operation.name != target_op:
+#             continue
+#         old_name = op.name.value
+#         new_name = rename_fn(old_name)
+#         if new_name is None:
+#             continue
+#         symbols.replace_all_symbol_uses(old_name, new_name, module.operation)
+#         symbols.set_symbol_name(op.operation, new_name)
+
+
+# @with_toplevel_context
+# def run_module(
+#     mod: str,
+#     entry_point: str,
+#     *args,
+#     shared_libs: Sequence[str] = [],
+#     invoke_wrapper: Callable[[Callable[[], None]], Any] = lambda x: x(),
+# ) -> Any:
+#     """Runs an MLIR module using the execution engine.
+
+#     Parameters
+#     ----------
+#     @param: mod The MLIR module as a string.
+#     @param: entry_point The function to execute.
+#     @param: args The args to passthrough to the MLIR function.
+#     @param: shared_libs The libraries to be loaded by the execution engine.
+#     @param: invoke_wrapper A wrapper function to invoke the execution engine; this is useful for performing additional actions.
+
+#     Returns
+#     -------
+#     The results of the wrapper function.
+#     """
+#     from mlir.execution_engine import ExecutionEngine
+
+#     module = ir.Module.parse(mod)
+#     execution_engine = ExecutionEngine(
+#         module,
+#         shared_libs=shared_libs,
+#     )
+
+#     def invoke():
+#         execution_engine.invoke(entry_point, *args)
+
+#     return invoke_wrapper(invoke)
diff --git a/mlir/test/python/utils.py b/mlir/test/python/utils.py
new file mode 100644
index 0000000000000..8435fdd363ae3
--- /dev/null
+++ b/mlir/test/python/utils.py
@@ -0,0 +1,58 @@
+# RUN: %python %s | FileCheck %s
+
+import unittest
+
+from mlir import ir
+from mlir.dialects import arith
+from mlir.utils import (
+    call_with_toplevel_context_create_module,
+    caller_mlir_context,
+    using_mlir_context,
+)
+
+
+class TestRequiredContext(unittest.TestCase):
+    def test_shared_context(self):
+        """Test that the context is reused, so values can be passed/returned between functions."""
+
+        @using_mlir_context()
+        def create_add(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+            return arith.AddFOp(
+                lhs, rhs, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
+            ).result
+
+        @using_mlir_context()
+        def multiple_adds(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+            return create_add(create_add(lhs, rhs), create_add(lhs, rhs))
+
+        @call_with_toplevel_context_create_module
+        def _(module) -> None:
+            c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result
+            multiple_adds(c, c)
+
+            # CHECK: constant
+            # CHECK-NEXT: arith.addf
+            # CHECK-NEXT: arith.addf
+            # CHECK-NEXT: arith.addf
+            print(module)
+
+    def test_unregistered_op_asserts(self):
+        """Confirm that using_mlir_context fails if an operation is still not registered."""
+        with self.assertRaises(AssertionError), using_mlir_context(
+            required_extension_operations=["func.fake_extension_op"],
+            registration_funcs=[],
+        ):
+            pass
+
+    def test_required_op_asserts(self):
+        """Confirm that caller_mlir_context fails if an operation is still not registered."""
+        with self.assertRaises(AssertionError), caller_mlir_context(
+            required_extension_operations=["func.fake_extension_op"],
+            registration_funcs=[],
+        ):
+            pass
+
+
+if __name__ == "__main__":
+    unittest.main()

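For reference, a minimal sketch (not part of the patch) of how the new
is_multithreading_enabled property, backed by mlirContextIsMultithreadingEnabled
in the C API, can be used from Python; the save/disable/restore pattern mirrors
the one in using_mlir_context above:

    from mlir import ir

    with ir.Context() as ctx:
        # Save the current threading mode and disable it around registration
        # work, which would otherwise assert in a multithreaded context.
        was_enabled = ctx.is_multithreading_enabled
        ctx.enable_multithreading(False)
        try:
            ctx.append_dialect_registry(ir.DialectRegistry())
        finally:
            ctx.enable_multithreading(was_enabled)
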
>From c277571c4199dfb661589a3f918f6e78791402aa Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Sun, 6 Jul 2025 11:55:26 +0200
Subject: [PATCH 2/3] [mlir][transform] NFC - Introduce TransformEachOp util

---
 .../TransformOps/BufferizationTransformOps.td |  30 +-
 .../GPU/TransformOps/GPUTransformOps.td       |  44 +-
 .../Linalg/TransformOps/LinalgTransformOps.td | 567 +++++++-----------
 .../MemRef/TransformOps/MemRefTransformOps.td |  38 +-
 .../NVGPU/TransformOps/NVGPUTransformOps.td   |  72 +--
 .../Dialect/Transform/Interfaces/Utils.td     |  25 +
 6 files changed, 331 insertions(+), 445 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Transform/Interfaces/Utils.td

diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
index 53b3b0505b399..36c0153108411 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -24,9 +25,12 @@ def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tens
 //===----------------------------------------------------------------------===//
 
 def BufferLoopHoistingOp
-    : Op<Transform_Dialect, "bufferization.buffer_loop_hoisting",
-        [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-         TransformEachOpTrait, TransformOpInterface]> {
+    : TransformEachOp<
+      "bufferization.buffer_loop_hoisting",
+      "::mlir::Operation*",
+      [
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      ]> {
   let description = [{
     Hoist buffer allocations ("memref.alloc" and "memref.alloca") from loops
     within the targeted op. This transform assumes that there are no buffer
@@ -38,14 +42,6 @@ def BufferLoopHoistingOp
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs);
   let assemblyFormat = "$target attr-dict `:` type($target)";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -164,11 +160,13 @@ def EliminateEmptyTensorsOp
 //===----------------------------------------------------------------------===//
 
 def EmptyTensorToAllocTensorOp
-    : Op<Transform_Dialect, "bufferization.empty_tensor_to_alloc_tensor",
-        [FunctionalStyleTransformOpTrait,
-         MemoryEffectsOpInterface,
-         TransformOpInterface,
-         TransformEachOpTrait]> {
+    : TransformEachOp<
+      "bufferization.empty_tensor_to_alloc_tensor",
+      "::mlir::tensor::EmptyOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+      ]> {
   let description = [{
     Replace a tensor.empty with a bufferization.tensor_alloc.
 
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 36b579485fc04..bdfac8b71d5e9 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
@@ -125,12 +126,14 @@ def EliminateBarriersOp :
   let assemblyFormat = [{ attr-dict }];
 }
 
-def MapNestedForallToThreads :
-  Op<Transform_Dialect, "gpu.map_nested_forall_to_threads",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformEachOpTrait,
-     TransformOpInterface]> {
+def MapNestedForallToThreads
+    : TransformEachOp<
+      "gpu.map_nested_forall_to_threads",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+      ]> {
   let description = [{
       Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to
       distributed `gpu.thread_id` attribute.
@@ -235,21 +238,16 @@ def MapNestedForallToThreads :
     attr-dict
     `:` functional-type($target, $result)
   }];
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
-def MapForallToBlocks :
-  Op<Transform_Dialect, "gpu.map_forall_to_blocks",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait]> {
+def MapForallToBlocks
+    : TransformEachOp<
+      "gpu.map_forall_to_blocks",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+      ]> {
   let description = [{
     Target the gpu_launch op and rewrite the top level `scf.forall`
     to distributed gpu.block_id attribute. If `generate_gpu_launch` attribute
@@ -300,14 +298,6 @@ def MapForallToBlocks :
     `:` functional-type($target, $result)
   }];
   let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4360055e78691..8a73f4c49499e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -247,12 +248,15 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
 // DecomposeOp
 //===----------------------------------------------------------------------===//
 
-def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def DecomposeOp
+    : TransformEachOp<
+      "structured.decompose",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Decomposes named complex operations, such as higher-dimensional
     (depthwise) convolutions, into combinations of lower-dimensional equivalents
@@ -271,14 +275,6 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
   let results = (outs TransformHandleTypeInterface:$transformed);
   let assemblyFormat =
       "$target attr-dict `:` functional-type(operands, results)";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -443,10 +439,15 @@ def FuseIntoContainingOp :
 // GeneralizeOp
 //===----------------------------------------------------------------------===//
 
-def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def GeneralizeOp
+    : TransformEachOp<
+      "structured.generalize",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Transforms a named structured operation into the generic form with the
     explicit attached region.
@@ -467,24 +468,21 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
       $target attr-dict `:` 
       custom<SemiFunctionType>(type($target), type($transformed), "false")
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // SpecializeOp
 //===----------------------------------------------------------------------===//
 
-def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def SpecializeOp
+    : TransformEachOp<
+      "structured.specialize",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Transforms a generic operation into the equivalent named form.
 
@@ -505,24 +503,21 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
       $target attr-dict `:` 
       custom<SemiFunctionType>(type($target), type($transformed), "false")
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // InterchangeOp
 //===----------------------------------------------------------------------===//
 
-def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-    TransformOpInterface, TransformEachOpTrait,
-    ReportTrackingListenerFailuresOpTrait]> {
+def InterchangeOp
+    : TransformEachOp<
+      "structured.interchange",
+      "::mlir::linalg::GenericOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Interchanges the iterators of the operations pointed to by the target handle
     using the iterator interchange attribute.
@@ -550,24 +545,20 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
     `:` custom<SemiFunctionType>(type($target), type($transformed), "false")
   }];
   let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::GenericOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // LinalgCopyToMemrefOp
 //===----------------------------------------------------------------------===//
 
-def LinalgCopyToMemrefOp :
-    Op<Transform_Dialect, "structured.linalg_copy_to_memref",
-      [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-       TransformEachOpTrait, TransformOpInterface]> {
+def LinalgCopyToMemrefOp
+    : TransformEachOp<
+      "structured.linalg_copy_to_memref",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface
+      ]> {
   let description = [{
     Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
     This is useful when bufferizing copies to a linalg.copy, later applying some
@@ -585,24 +576,20 @@ def LinalgCopyToMemrefOp :
   let builders = [
     OpBuilder<(ins "Value":$target)>,
   ];
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // LowerPackOp
 //===----------------------------------------------------------------------===//
-def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
-                         FunctionalStyleTransformOpTrait,
-                         MemoryEffectsOpInterface,
-                         TransformEachOpTrait,
-                         TransformOpInterface,
-                         ReportTrackingListenerFailuresOpTrait]> {
+def LowerPackOp
+    : TransformEachOp<
+      "structured.lower_pack",
+      "::mlir::linalg::PackOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Rewrite a linalg.pack into tensor.pad + tensor.expand_shape + linalg.transpose.
 
@@ -623,25 +610,20 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
   let assemblyFormat = [{
     $target attr-dict `:` functional-type(operands, results)
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::PackOp target,
-        ::mlir::transform::ApplyToEachResultList &transformResults,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // LowerUnPackOp
 //===----------------------------------------------------------------------===//
-def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
-                         FunctionalStyleTransformOpTrait,
-                         MemoryEffectsOpInterface,
-                         TransformEachOpTrait,
-                         TransformOpInterface,
-                         ReportTrackingListenerFailuresOpTrait]> {
+def LowerUnPackOp
+    : TransformEachOp<
+      "structured.lower_unpack",
+      "::mlir::linalg::UnPackOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Lower a linalg.unpack into empty + linalg.transpose + tensor.collapse_shape +
     tensor.extract_slice.
@@ -665,13 +647,6 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
     $target attr-dict `:` functional-type(operands, results)
   }];
 
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::UnPackOp target,
-        ::mlir::transform::ApplyToEachResultList &transformResults,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -745,10 +720,14 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
 // MultiTileSizesOp
 //===----------------------------------------------------------------------===//
 
-def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def MultiTileSizesOp
+    : TransformEachOp<
+      "structured.multitile_sizes",
+      "::mlir::linalg::LinalgOp",
+      [
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Emits the IR computing the tile sizes `s1` and `s2` such that:
 
@@ -817,14 +796,6 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
   let assemblyFormat =
     "$target attr-dict `:` custom<MultitileSizesTypes>("
     "type($target), type($low_size), type($high_size), type($split_point))";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1320,11 +1291,14 @@ def HoistPadBuildPackingLoopNestOp :
   let hasVerifier = 1;
 }
 
-def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait]> {
+def HoistPadOp
+    : TransformEachOp<
+      "structured.hoist_pad",
+      "::mlir::tensor::PadOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface
+      ]> {
   let description = [{
     Hoist the tensor.pad target operation by at most the given number of loops.
     Optionally apply the transpose attribute to the inner dimensions.
@@ -1361,14 +1335,6 @@ def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
     `:` functional-type(operands, results)
   }];
   let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::tensor::PadOp,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1376,10 +1342,15 @@ def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
 //===----------------------------------------------------------------------===//
 
 
-def PromoteOp : Op<Transform_Dialect, "structured.promote",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-    TransformOpInterface, TransformEachOpTrait,
-    ReportTrackingListenerFailuresOpTrait]> {
+def PromoteOp
+    : TransformEachOp<
+      "structured.promote",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Promotes the specified operands of the target into a separate memory buffer.
 
@@ -1413,14 +1384,6 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
     $target attr-dict `:`
     custom<SemiFunctionType>(type($target), type($transformed), "false")
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1457,10 +1420,15 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
 // ScalarizeOp
 //===----------------------------------------------------------------------===//
 
-def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def ScalarizeOp
+    : TransformEachOp<
+      "structured.scalarize",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Indicates that ops of a specific kind in the given function should be
     scalarized (i.e. their dynamic dimensions tiled by 1).
@@ -1492,14 +1460,6 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
     $target attr-dict `:`
     custom<SemiFunctionType>(type($target), type($result), "false")
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1529,12 +1489,15 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
 // DecomposeInterfaceOp
 //===----------------------------------------------------------------------===//
 
-def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def DecomposeInterfaceOp
+    : TransformEachOp<
+      "structured.decompose_interface",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     TODO
   }];
@@ -1543,26 +1506,20 @@ def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface
   let results = (outs TransformHandleTypeInterface:$transformed);
   let assemblyFormat =
       "$target attr-dict `:` functional-type(operands, results)";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp.
 //===----------------------------------------------------------------------===//
 
-def RewriteInDestinationPassingStyleOp : Op<
-    Transform_Dialect, "structured.rewrite_in_destination_passing_style",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def RewriteInDestinationPassingStyleOp
+    : TransformEachOp<
+      "structured.rewrite_in_destination_passing_style",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Rewrite a supported tensor operation that is not in destination-passing style
     into a form that is in destination-passing style.
@@ -1592,14 +1549,6 @@ def RewriteInDestinationPassingStyleOp : Op<
     $target attr-dict
     `:` functional-type($target, results)
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -1660,10 +1609,15 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
 // SplitReductionOp
 //===----------------------------------------------------------------------===//
 
-def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
-       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-        TransformEachOpTrait, TransformOpInterface,
-        ReportTrackingListenerFailuresOpTrait]> {
+def SplitReductionOp
+    : TransformEachOp<
+      "structured.split_reduction",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Indicates that the given `target` op should be transformed with the
     `splitReduction` transformation and split factor provided as attribute.
@@ -1822,24 +1776,21 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
                    CArg<"bool", "false">:$useScalingAlgorithm,
                    CArg<"bool", "false">:$useAlloc)>
   ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // TileReductionUsingForOp
 //===----------------------------------------------------------------------===//
 
-def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_using_for",
-       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-        TransformEachOpTrait, TransformOpInterface,
-        ReportTrackingListenerFailuresOpTrait]> {
+def TileReductionUsingForOp
+    : TransformEachOp<
+      "structured.tile_reduction_using_for",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Indicates that the given `target` op should be transformed with the
     `tileReduction` transformation with the tile size provided as attribute.
@@ -1934,25 +1885,21 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
     attr-dict
     `:` functional-type(operands, results)
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // TileReductionUsingForallOp
 //===----------------------------------------------------------------------===//
 
-def TileReductionUsingForallOp :
-  Op<Transform_Dialect, "structured.tile_reduction_using_forall",
-       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-        TransformEachOpTrait, TransformOpInterface,
-        ReportTrackingListenerFailuresOpTrait]> {
+def TileReductionUsingForallOp
+    : TransformEachOp<
+      "structured.tile_reduction_using_forall",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Tile a PartialReductionOpInterface op to a tiled `scf.forall` doing
     partial reduction.
@@ -2047,15 +1994,6 @@ def TileReductionUsingForallOp :
     attr-dict
     `:` functional-type(operands, results)
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
-
 }
 
 //===----------------------------------------------------------------------===//
@@ -2334,11 +2272,15 @@ def TileUsingForallOp :
 // VectorizeChildrenAndApplyPatternsOp
 //===----------------------------------------------------------------------===//
 
-def VectorizeChildrenAndApplyPatternsOp :
-  Op<Transform_Dialect, "structured.vectorize_children_and_apply_patterns",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformEachOpTrait, TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
+def VectorizeChildrenAndApplyPatternsOp
+    : TransformEachOp<
+      "structured.vectorize_children_and_apply_patterns",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Vectorizes all children contained in the given `target` using the
     configuration specified by the attributes of this op. This only vectorizes
@@ -2394,13 +2336,6 @@ def VectorizeChildrenAndApplyPatternsOp :
                CArg<"bool", "false">:$vectorizeNDExtract,
                CArg<"bool", "false">:$flatten1DDepthwise)>
   ];
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
@@ -2479,11 +2414,15 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
 // HoistRedundantVectorTransfersOp
 //===----------------------------------------------------------------------===//
 
-def HoistRedundantVectorTransfersOp :
-  Op<Transform_Dialect, "structured.hoist_redundant_vector_transfers",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformEachOpTrait, TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
+def HoistRedundantVectorTransfersOp
+    : TransformEachOp<
+      "structured.hoist_redundant_vector_transfers",
+      "::mlir::func::FuncOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Hoist vector.transfer_read / vector.transfer_write pairs out of immediately
     enclosing scf::ForOp iteratively, if the following conditions are true:
@@ -2513,24 +2452,21 @@ def HoistRedundantVectorTransfersOp :
     OpBuilder<(ins "Value":$target,
                CArg<"bool", "false">:$verify_non_zero_trip)>,
   ];
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-         ::mlir::transform::TransformRewriter &rewriter,
-         ::mlir::func::FuncOp target,
-         ::mlir::transform::ApplyToEachResultList &results,
-         ::mlir::transform::TransformState &state);
-   }];
 }
 
 //===----------------------------------------------------------------------===//
 // HoistRedundantVectorBroadcastsOp
 //===----------------------------------------------------------------------===//
 
-def HoistRedundantVectorBroadcastsOp :
-  Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformEachOpTrait, TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
+def HoistRedundantVectorBroadcastsOp
+    : TransformEachOp<
+      "structured.hoist_redundant_vector_broadcasts",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Hoist vector.extract / vector.broadcasts pairs out of immediately
     enclosing scf::ForOp iteratively.
@@ -2549,26 +2485,21 @@ def HoistRedundantVectorBroadcastsOp :
   let builders = [
     OpBuilder<(ins "Value":$target)>,
   ];
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-         ::mlir::transform::TransformRewriter &rewriter,
-         ::mlir::Operation *target,
-         ::mlir::transform::ApplyToEachResultList &results,
-         ::mlir::transform::TransformState &state);
-   }];
 }
 
 //===----------------------------------------------------------------------===//
 // ConvertConv2DToImg2ColOp
 //===----------------------------------------------------------------------===//
 
-def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
-    "structured.convert_conv2d_to_img2col",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def ConvertConv2DToImg2ColOp
+    : TransformEachOp<
+      "structured.convert_conv2d_to_img2col",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Convert linalg.conv_2d_xxx into linalg.generic (for img2col packing)
     and linalg.matmul.
@@ -2626,27 +2557,21 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
   let builders = [
     OpBuilder<(ins "Value":$target)>
   ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // FlattenElementwiseLinalgOp
 //===----------------------------------------------------------------------===//
 
-def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
-    "structured.flatten_elementwise",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def FlattenElementwiseLinalgOp
+    : TransformEachOp<
+      "structured.flatten_elementwise",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Flattens the iteration space and (applicable) operands of elementwise
     linalg ops to a single dimension.
@@ -2669,27 +2594,21 @@ def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
   let builders = [
     OpBuilder<(ins "Value":$target)>
   ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // Transpose Conv2D
 //===----------------------------------------------------------------------===//
 
-def TransposeConv2DOp : Op<Transform_Dialect,
-    "structured.transpose_conv2d",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformOpInterface,
-     TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def TransposeConv2DOp
+    : TransformEachOp<
+      "structured.transpose_conv2d",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Convert linalg.conv_2d_nhwc_fhwc into linalg.conv_2d_nhwc_hwcf by introducing
     a linalg.transpose on the filter tensor/memref.
@@ -2718,25 +2637,21 @@ def TransposeConv2DOp : Op<Transform_Dialect,
   let builders = [
     OpBuilder<(ins "Value":$target)>
   ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // TransposeMatmulOp
 //===----------------------------------------------------------------------===//
 
-def TransposeMatmulOp : Op<Transform_Dialect,
-    "structured.transpose_matmul",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def TransposeMatmulOp
+    : TransformEachOp<
+      "structured.transpose_matmul",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Convert Linalg matmul ops to transposed variants.
 
@@ -2764,24 +2679,20 @@ def TransposeMatmulOp : Op<Transform_Dialect,
   let builders = [
     OpBuilder<(ins "Value":$target)>
   ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
 
-def InsertSliceToCopyOp :
-  Op<Transform_Dialect, "structured.insert_slice_to_copy",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformEachOpTrait, TransformOpInterface]> {
+def InsertSliceToCopyOp
+    : TransformEachOp<
+      "structured.insert_slice_to_copy",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface
+      ]> {
   let description = [{
    Targeted rewrite of a tensor.insert_slice to linalg.copy.
     This is useful to materialize copies explicitly before bufferization and
@@ -2804,13 +2715,6 @@ def InsertSliceToCopyOp :
   let builders = [
     OpBuilder<(ins "Value":$target)>,
   ];
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -2862,6 +2766,7 @@ def MapCopyToThreadsOp :
   let builders = [
     OpBuilder<(ins "Value":$target)>,
   ];
+  // TODO: Use TransformEachOp<...> once `extraClassDeclaration`s compose.
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::transform::TransformRewriter &rewriter,
@@ -2877,11 +2782,15 @@ def MapCopyToThreadsOp :
 // Winograd Conv2D
 //===----------------------------------------------------------------------===//
 
-def WinogradConv2DOp : Op<Transform_Dialect,
-    "structured.winograd_conv2d",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def WinogradConv2DOp
+    : TransformEachOp<
+      "structured.winograd_conv2d",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Winograd Conv2D algorithm will convert linalg Conv2D operation into batched
     matrix multiply. Before the matrix multiply, it will convert filter and
@@ -2913,21 +2822,17 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   let builders = [
     OpBuilder<(ins "Value":$target)>
   ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
-def DecomposeWinogradOp : Op<Transform_Dialect,
-    "structured.decompose_winograd_op",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait,
-     ReportTrackingListenerFailuresOpTrait]> {
+def DecomposeWinogradOp
+    : TransformEachOp<
+      "structured.decompose_winograd_op",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Decompose winograd operations. It will convert filter, input and output
     transform operations into a combination of scf, tensor, and linalg
@@ -2950,14 +2855,6 @@ def DecomposeWinogradOp : Op<Transform_Dialect,
   let builders = [
     OpBuilder<(ins "Value":$target)>
   ];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 #endif // LINALG_TRANSFORM_OPS
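
To make the refactoring concrete: each op above now only inherits the
applyToOne *declaration* from TransformEachOp; the corresponding C++
*definition* in LinalgTransformOps.cpp is untouched by this NFC change.
A minimal sketch of the definition shape that pairs with the injected
declaration, using TransposeConv2DOp (kind = ::mlir::linalg::LinalgOp)
as the example -- the body below is a placeholder, not the actual
rewrite logic:

    #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

    using namespace mlir;

    // Placeholder body; the real implementation is unchanged by this patch.
    // The signature must match the declaration injected by TransformEachOp.
    DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
        transform::TransformRewriter &rewriter, linalg::LinalgOp target,
        transform::ApplyToEachResultList &results,
        transform::TransformState &state) {
      results.push_back(target);
      return DiagnosedSilenceableFailure::success();
    }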
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index f4694a30a8a12..66d87436f51ef 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -238,11 +239,13 @@ def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
 }
 
 def MemRefEraseDeadAllocAndStoresOp
-    : Op<Transform_Dialect, "memref.erase_dead_alloc_and_stores", [
-      TransformEachOpTrait, TransformOpInterface,
-      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-      ReportTrackingListenerFailuresOpTrait
-    ]> {
+    : TransformEachOp<
+      "memref.erase_dead_alloc_and_stores",
+      "::mlir::Operation*",
+      [
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     This applies memory optimization on memref. In particular it does store to
     load forwarding, dead store elimination and dead alloc/alloca elimination.
@@ -266,19 +269,16 @@ def MemRefEraseDeadAllocAndStoresOp
   let builders = [
     OpBuilder<(ins "Value":$target)>
   ];
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 def MemRefMakeLoopIndependentOp
-    : Op<Transform_Dialect, "memref.make_loop_independent",
-         [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-          TransformOpInterface, TransformEachOpTrait]> {
+    : TransformEachOp<
+      "memref.make_loop_independent",
+      "::mlir::Operation*",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface
+      ]> {
   let description = [{
     Rewrite the targeted ops such that their index-typed operands no longer
     depend on any loop induction variable of the `num_loop` enclosing `scf.for`
@@ -307,14 +307,6 @@ def MemRefMakeLoopIndependentOp
   let results = (outs TransformHandleTypeInterface:$transformed);
   let assemblyFormat =
       "$target attr-dict `:` functional-type($target, $transformed)";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 #endif // MEMREF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
index 0225562baa58c..17da5ab507279 100644
--- a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -34,12 +35,14 @@ def ApplyNVGPUToNVVMConversionPatternsOp : Op<Transform_Dialect,
 // CreateAsyncGroupsOp
 //===----------------------------------------------------------------------===//
 
-def CreateAsyncGroupsOp :
-  Op<Transform_Dialect, "nvgpu.create_async_groups",
-    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-     TransformEachOpTrait,
-     TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
+def CreateAsyncGroupsOp
+    : TransformEachOp<
+      "nvgpu.create_async_groups",
+      "::mlir::Operation*",
+      [
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Look for global to shared memory copies within the targeted op in the form
     of vector transfer ops and convert them to async copies when possible.
@@ -65,27 +68,21 @@ def CreateAsyncGroupsOp :
   let assemblyFormat = [{
     $target attr-dict `:` functional-type(operands, results)
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::Operation *target,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // PipelineSharedMemoryCopiesOp
 //===----------------------------------------------------------------------===//
 
-def PipelineSharedMemoryCopiesOp :
-  Op<Transform_Dialect, "nvgpu.pipeline_shared_memory_copies",
-    [FunctionalStyleTransformOpTrait,
-     MemoryEffectsOpInterface,
-     TransformEachOpTrait,
-     TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
+def PipelineSharedMemoryCopiesOp
+    : TransformEachOp<
+      "nvgpu.pipeline_shared_memory_copies",
+      "::mlir::scf::ForOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let summary =
     "Applies software pipelining to a given loop with shared memory copies";
 
@@ -136,27 +133,21 @@ def PipelineSharedMemoryCopiesOp :
     attr-dict 
     `:` functional-type(operands, results)
   }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::scf::ForOp forOp,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
 // RewriteMatmulAsMmaSyncOp
 //===----------------------------------------------------------------------===//
 
-def RewriteMatmulAsMmaSyncOp :
-  Op<Transform_Dialect, "nvgpu.rewrite_matmul_as_mma_sync",
-    [FunctionalStyleTransformOpTrait, 
-     MemoryEffectsOpInterface,
-     TransformEachOpTrait, 
-     TransformOpInterface,
-     ReportTrackingListenerFailuresOpTrait]> {
+def RewriteMatmulAsMmaSyncOp
+    : TransformEachOp<
+      "nvgpu.rewrite_matmul_as_mma_sync",
+      "::mlir::linalg::LinalgOp",
+      [
+        FunctionalStyleTransformOpTrait,
+        MemoryEffectsOpInterface,
+        ReportTrackingListenerFailuresOpTrait
+      ]> {
   let description = [{
     Rewrite a matmul operation on memref to an mma.sync operation on vectors.
 
@@ -169,14 +160,6 @@ def RewriteMatmulAsMmaSyncOp :
   let results = (outs);
 
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::transform::TransformRewriter &rewriter,
-        ::mlir::linalg::LinalgOp linalgOp,
-        ::mlir::transform::ApplyToEachResultList &results,
-        ::mlir::transform::TransformState &state);
-  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -200,6 +183,7 @@ def RewriteCopyAsTmaOp :
 
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
 
+  // TODO: use TransformEachOp once it is extended to an op-less API.
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure apply(
         ::mlir::transform::TransformRewriter &rewriter,
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/Utils.td b/mlir/include/mlir/Dialect/Transform/Interfaces/Utils.td
new file mode 100644
index 0000000000000..68b46f25b756c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/Utils.td
@@ -0,0 +1,25 @@
+#ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_UTILS_TD
+#define MLIR_DIALECT_TRANSFORM_INTERFACES_UTILS_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+class TransformEachOp<string name,
+                      string kind = "::mlir::Operation *",
+                      list<Trait> traits = []>
+  : Op<Transform_Dialect, name, !listconcat(traits, [
+      TransformOpInterface, TransformEachOpTrait
+    ])> {
+  defvar applyMethodDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::transform::TransformRewriter &rewriter,
+      }] # kind # [{ target,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+  let extraClassDeclaration = applyMethodDeclaration;
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_INTERFACES_UTILS_TD
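
For reference, the `#` operator above splices `kind` between the two
code-literal fragments, so an op instantiated with
kind = "::mlir::scf::ForOp" (e.g. PipelineSharedMemoryCopiesOp earlier
in this patch) gets the following extraClassDeclaration injected:

    ::mlir::DiagnosedSilenceableFailure applyToOne(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::scf::ForOp target,
        ::mlir::transform::ApplyToEachResultList &results,
        ::mlir::transform::TransformState &state);

This matches the declarations previously spelled out per op; parameter
names in a declaration (e.g. the old `forOp`) are immaterial, so the
existing definitions continue to match.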

>From ed41e82abb76c84589bfd05fae2b35f77bb09663 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 3 Jul 2025 18:32:59 +0200
Subject: [PATCH 3/3] [mlir] NFC - refactor id builder and avoid leaking impl
 details

---
 .../mlir/Dialect/GPU/TransformOps/Utils.h     |  31 ++--
 .../GPU/TransformOps/GPUTransformOps.cpp      |  32 +---
 mlir/lib/Dialect/GPU/TransformOps/Utils.cpp   | 165 ++++++++++++++----
 3 files changed, 150 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index 52fc6f4d5c71b..0cd15835ad70f 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -28,27 +28,24 @@ namespace transform {
 namespace gpu {
 
 /// Helper type for functions that generate ids for the mapping of a scf.forall.
-/// Operates on both 1) an "original" basis that represents the individual
-/// thread and block ids and 2) a "scaled" basis that represents grouped ids
-/// (e.g. block clusters, warpgroups and warps).
-/// The mapping of ids is done in the "scaled" basis (i.e. when mapping to warps
-/// a division by 32 occurs).
-/// The predication is in the "original" basis using the "active" quantities
-/// (`activeMappingSizes`, `availableMappingSizes` and `activeIdOps`).
 struct IdBuilderResult {
-  // Ops used to replace the forall induction variables.
+  /// Error message; if non-empty, building the ids failed.
+  std::string errorMsg;
+  /// Values used to replace the forall induction variables.
   SmallVector<Value> mappingIdOps;
-  // Available mapping sizes used to predicate the forall body when they are
-  // larger than the predicate mapping sizes.
-  SmallVector<int64_t> availableMappingSizes;
-  // Actual mapping sizes used to predicate the forall body when they are
-  // smaller than the available mapping sizes.
-  SmallVector<int64_t> activeMappingSizes;
-  // Ops used to predicate the forall body when activeMappingSizes is smaller
-  // than the available mapping sizes.
-  SmallVector<Value> activeIdOps;
+  /// Values used to predicate the forall body when the active mapping sizes
+  /// are smaller than the available mapping sizes.
+  SmallVector<Value> predicateOps;
 };
 
+inline raw_ostream &operator<<(raw_ostream &os, const IdBuilderResult &res) {
+  llvm::interleaveComma(res.mappingIdOps, os << "----mappingIdOps: ");
+  os << "\n";
+  llvm::interleaveComma(res.predicateOps, os << "----predicateOps: ");
+  os << "\n";
+  return os;
+}
+
 /// Common gpu id builder type, allows the configuration of lowering for various
 /// mapping schemes. Takes:
 ///   - A rewriter with insertion point set before the forall op to rewrite.
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 6446235c06fb2..1ae923db149a5 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -480,6 +480,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
 
   IdBuilderResult builderResult =
       gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
+  if (!builderResult.errorMsg.empty())
+    return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);
+
+  LLVM_DEBUG(DBGS() << builderResult);
 
   // Step 4. Map the induction variables to the mappingIdOps, this may involve
   // a permutation.
@@ -490,6 +494,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
            forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
     auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
     Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
+    LDBG("----map: " << iv << " to " << peIdOp);
     bvm.map(iv, peIdOp);
   }
 
@@ -498,32 +503,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
   // originalBasis and no predication occurs.
   Value predicate;
   if (originalBasisWasProvided) {
-    SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
-    SmallVector<int64_t> availableMappingSizes =
-        builderResult.availableMappingSizes;
-    SmallVector<Value> activeIdOps = builderResult.activeIdOps;
-    LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
-    LDBG("----availableMappingSizes: "
-         << llvm::interleaved(availableMappingSizes));
-    LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
-    for (auto [activeId, activeMappingSize, availableMappingSize] :
-         llvm::zip_equal(activeIdOps, activeMappingSizes,
-                         availableMappingSizes)) {
-      if (activeMappingSize > availableMappingSize) {
-        return definiteFailureHelper(
-            transformOp, forallOp,
-            "Trying to map to fewer GPU threads than loop iterations but "
-            "overprovisioning is not yet supported. "
-            "Try additional tiling of the before mapping or map to more "
-            "threads.");
-      }
-      if (activeMappingSize == availableMappingSize)
-        continue;
-      Value idx =
-          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
-      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::ult, activeId, idx);
-      LDBG("----predicate: " << tmpPredicate);
+    for (Value tmpPredicate : builderResult.predicateOps) {
       predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                              tmpPredicate)
                             : tmpPredicate;
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 9853e80828390..00bb159e868c0 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -47,12 +47,57 @@ using namespace mlir::transform::gpu;
 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
+/// Build predicates that restrict execution to only the active ids. Along
+/// each dimension, 3 cases can occur:
+///   1. activeMappingSize > availableMappingSize: this case is unsupported,
+///      as it would require additional looping; an error message is produced
+///      advising the user to tile more or to use more threads.
+///   2. activeMappingSize == availableMappingSize: no predication is needed.
+///   3. activeMappingSize < availableMappingSize: only a subset of threads
+///      should be active, and we produce the boolean `id < activeMappingSize`
+///      for further use in building predicated execution.
+static FailureOr<SmallVector<Value>>
+buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
+                ArrayRef<int64_t> activeMappingSizes,
+                ArrayRef<int64_t> availableMappingSizes,
+                std::string &errorMsg) {
+  // clang-format off
+  LLVM_DEBUG(
+    llvm::interleaveComma(
+      activeMappingSizes, DBGS() << "----activeMappingSizes: ");
+    DBGS() << "\n";
+    llvm::interleaveComma(
+      availableMappingSizes, DBGS() << "----availableMappingSizes: ");
+    DBGS() << "\n";);
+  // clang-format on
+
+  SmallVector<Value> predicateOps;
+  for (auto [activeId, activeMappingSize, availableMappingSize] :
+       llvm::zip_equal(activeIds, activeMappingSizes, availableMappingSizes)) {
+    if (activeMappingSize > availableMappingSize) {
+      errorMsg = "Trying to map to fewer GPU threads than loop iterations but "
+                 "overprovisioning is not yet supported. Try additional tiling "
+                 "before mapping or map to more threads.";
+      return failure();
+    }
+    if (activeMappingSize == availableMappingSize)
+      continue;
+    Value idx = rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
+    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                                activeId, idx);
+    predicateOps.push_back(pred);
+  }
+  return predicateOps;
+}
+
 /// Return a flattened thread id for the workgroup with given sizes.
 template <typename ThreadOrBlockIdOp>
 static Value buildLinearId(RewriterBase &rewriter, Location loc,
                            ArrayRef<OpFoldResult> originalBasisOfr) {
-  LLVM_DEBUG(DBGS() << "----buildLinearId with originalBasisOfr:  "
-                    << llvm::interleaved(originalBasisOfr) << "\n");
+  LLVM_DEBUG(llvm::interleaveComma(
+                 originalBasisOfr,
+                 DBGS() << "----buildLinearId with originalBasisOfr:  ");
+             llvm::dbgs() << "\n");
   assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
   IndexType indexType = rewriter.getIndexType();
   AffineExpr tx, ty, tz, bdx, bdy;
@@ -79,44 +124,43 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
   auto res = [multiplicity](RewriterBase &rewriter, Location loc,
                             ArrayRef<int64_t> forallMappingSizes,
                             ArrayRef<int64_t> originalBasis) {
+    // 1. Compute linearId.
     SmallVector<OpFoldResult> originalBasisOfr =
         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
-    OpFoldResult linearId =
+    Value physicalLinearId =
         buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
+
+    // 2. Compute scaledLinearId.
+    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
+
+    // 3. Compute remapped indices.
+    SmallVector<Value> ids;
     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
     // "row-major" order.
     SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
     SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
-    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, d0.floorDiv(multiplicity), {linearId});
     SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
-    SmallVector<Value> ids;
     // Reverse back to be in [0 .. n] order.
     for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
       ids.push_back(
           affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
     }
 
-    LLVM_DEBUG(DBGS() << "--delinearization basis: "
-                      << llvm::interleaved(reverseBasisSizes) << "\n";
-               DBGS() << "--delinearization strides: "
-                      << llvm::interleaved(strides) << "\n";
-               DBGS() << "--delinearization exprs: "
-                      << llvm::interleaved(delinearizingExprs) << "\n";
-               DBGS() << "--ids: " << llvm::interleaved(ids) << "\n");
-
-    // Return n-D ids for indexing and 1-D size + id for predicate generation.
-    return IdBuilderResult{
-        /*mappingIdOps=*/ids,
-        /*availableMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(originalBasis)},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
-        /*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
+    // 4. Handle predicates using physicalLinearId.
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps =
+        buildPredicates(rewriter, loc, physicalLinearId,
+                        computeProduct(forallMappingSizes) * multiplicity,
+                        computeProduct(originalBasis), errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/ids,
+                           /*predicateOps=*/predicateOps};
   };
 
   return res;
@@ -143,16 +187,67 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
     // In the 3-D mapping case, unscale the first dimension by the multiplicity.
     SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
     forallMappingSizeInOriginalBasis[0] *= multiplicity;
-    return IdBuilderResult{
-        /*mappingIdOps=*/scaledIds,
-        /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
-        // `forallMappingSizes` iterate in the scaled basis, they need to be
-        // scaled back into the original basis to provide tight
-        // activeMappingSizes quantities for predication.
-        /*activeMappingSizes=*/
-        SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
-        /*activeIdOps=*/ids};
+
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps =
+        buildPredicates(rewriter, loc, ids, forallMappingSizeInOriginalBasis,
+                        originalBasis, errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/scaledIds,
+                           /*predicateOps=*/predicateOps};
+  };
+  return res;
+}
+
+/// Create a lane id builder that takes the `originalBasis` and decomposes
+/// it into the basis of `forallMappingSizes`. The returned builder produces
+/// an n-D vector of ids for indexing and the predicates guarding execution.
+static GpuIdBuilderFnType laneIdBuilderFn(int64_t warpSize) {
+  auto res = [warpSize](RewriterBase &rewriter, Location loc,
+                        ArrayRef<int64_t> forallMappingSizes,
+                        ArrayRef<int64_t> originalBasis) {
+    // 1. Compute linearId.
+    SmallVector<OpFoldResult> originalBasisOfr =
+        getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
+    Value physicalLinearId =
+        buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
+
+    // 2. Compute laneId.
+    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+    OpFoldResult laneId = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, d0 % warpSize, {physicalLinearId});
+
+    // 3. Compute remapped indices.
+    SmallVector<Value> ids;
+    // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
+    // "row-major" order.
+    SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
+    SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
+    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
+    // Reverse back to be in [0 .. n] order.
+    for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
+      ids.push_back(
+          affine::makeComposedAffineApply(rewriter, loc, e, {laneId}));
+    }
+
+    // 4. Handle predicates using laneId.
+    std::string errorMsg;
+    SmallVector<Value> predicateOps;
+    FailureOr<SmallVector<Value>> maybePredicateOps = buildPredicates(
+        rewriter, loc, cast<Value>(laneId), computeProduct(forallMappingSizes),
+        computeProduct(originalBasis), errorMsg);
+    if (succeeded(maybePredicateOps))
+      predicateOps = *maybePredicateOps;
+
+    return IdBuilderResult{/*errorMsg=*/errorMsg,
+                           /*mappingIdOps=*/ids,
+                           /*predicateOps=*/predicateOps};
   };
+
   return res;
 }
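
To illustrate the new predicate contract end to end, here is a
hypothetical call to buildPredicates (a static helper in this file, so
this is a sketch rather than usable API), assuming a rewriter, a
location, and a thread id value `tid` are in scope:

    // With activeMappingSizes = {4} and availableMappingSizes = {8},
    // buildPredicates emits
    //   %c4   = arith.constant 4 : index
    //   %pred = arith.cmpi ult, %tid, %c4 : index
    // and returns {%pred}. Equal sizes contribute no predicate; an active
    // size of 16 would instead set `errorMsg` and return failure().
    std::string errorMsg;
    FailureOr<SmallVector<Value>> preds =
        buildPredicates(rewriter, loc, /*activeIds=*/{tid},
                        /*activeMappingSizes=*/{4},
                        /*availableMappingSizes=*/{8}, errorMsg);

The caller (rewriteOneForallCommonImpl) then simply folds whatever comes
back with arith.andi, as shown in the GPUTransformOps.cpp hunk above.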
 


