[Mlir-commits] [mlir] [mlir] NFC - refactor id builder and avoid leaking impl details (PR #146922)
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jul 7 06:20:31 PDT 2025
https://github.com/nicolasvasilache updated https://github.com/llvm/llvm-project/pull/146922
>From 30bfea8439d866e5bc8ec1c0eafb9161ae585a40 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Fri, 4 Jul 2025 10:32:39 +0200
Subject: [PATCH 1/3] [mlir][python] Add utils for more pythonic context
creation and registration management
Co-authored-by: Fabian Mora <fmora.dev at gmail.com>
Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>
Co-authored-by: Tres <tpopp at users.noreply.github.com>
---
mlir/include/mlir-c/IR.h | 4 +
mlir/lib/Bindings/Python/IRCore.cpp | 6 +
mlir/lib/CAPI/IR/IR.cpp | 4 +
mlir/python/CMakeLists.txt | 7 +
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 1 +
mlir/python/mlir/utils.py | 489 +++++++++++++++++++++++
mlir/test/python/utils.py | 58 +++
7 files changed, 569 insertions(+)
create mode 100644 mlir/python/mlir/utils.py
create mode 100644 mlir/test/python/utils.py
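
For context, a minimal usage sketch of the new `mlir.utils` helpers, mirroring the
test added below (illustrative only; it assumes the in-tree `mlir` Python bindings
are importable and that the `arith` dialect is registered, as in the test):

    from mlir import ir
    from mlir.dialects import arith
    from mlir.utils import using_mlir_context, with_toplevel_context_create_module

    @using_mlir_context(required_dialects=["arith"])
    def make_const() -> ir.Value:
        # Reuses the caller's context when one is active; otherwise creates one.
        return arith.ConstantOp(value=1.0, result=ir.F32Type.get()).result

    @with_toplevel_context_create_module
    def build(module: ir.Module) -> None:
        # Runs in a fresh context, with the insertion point set to `module`'s body.
        make_const()
        print(module)

    build()  # The context and module do not outlive this call.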
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 81299c7911d24..877aa73ca2cc0 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -143,6 +143,10 @@ MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context,
bool enable);
+/// Retrieve the current threading mode, as controlled by
+/// mlirContextEnableMultithreading.
+MLIR_CAPI_EXPORTED bool mlirContextIsMultithreadingEnabled(MlirContext context);
+
/// Eagerly loads all available dialects registered with a context, making
/// them available for use for IR construction.
MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d961482885300..002923becd23a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2939,6 +2939,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
ss << pool.ptr;
return ss.str();
})
+ .def_prop_ro(
+ "is_multithreading_enabled",
+ [](PyMlirContext &self) {
+ return mlirContextIsMultithreadingEnabled(self.get());
+ },
+ "Returns true if multithreading is enabled for this context.")
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index fbc66bcf5c2d0..1cc555ad41de1 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -101,6 +101,10 @@ bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
return unwrap(context)->isOperationRegistered(unwrap(name));
}
+bool mlirContextIsMultithreadingEnabled(MlirContext context) {
+ return unwrap(context)->isMultithreadingEnabled();
+}
+
void mlirContextEnableMultithreading(MlirContext context, bool enable) {
return unwrap(context)->enableMultithreading(enable);
}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b2daabb2a5957..b4e0ab2ec6dae 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -48,6 +48,13 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine
runtime/*.py
)
+declare_mlir_python_sources(MLIRPythonSources.Utils
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ ADD_TO_PARENT MLIRPythonSources
+ SOURCES
+ utils.py
+)
+
declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources
ROOT_DIR "${MLIR_SOURCE_DIR}/include"
SOURCES_GLOB "mlir-c/*.h"
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 70bca3c75d842..56b9f17d52ea7 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -986,6 +986,7 @@ class ComplexType(Type):
class Context:
current: ClassVar[Context] = ... # read-only
allow_unregistered_dialects: bool
+ is_multithreading_enabled: bool
@staticmethod
def _get_live_count() -> int: ...
def _CAPICreate(self) -> object: ...
diff --git a/mlir/python/mlir/utils.py b/mlir/python/mlir/utils.py
new file mode 100644
index 0000000000000..10c316bc8f086
--- /dev/null
+++ b/mlir/python/mlir/utils.py
@@ -0,0 +1,489 @@
+from contextlib import contextmanager, nullcontext
+from functools import wraps
+from typing import (
+ Any,
+ Callable,
+ Concatenate,
+ Iterator,
+ Optional,
+ ParamSpec,
+ Sequence,
+ TypeVar,
+)
+
+from mlir import ir
+from mlir._mlir_libs import get_dialect_registry
+from mlir.dialects import func
+from mlir.dialects.transform import interpreter
+from mlir.passmanager import PassManager
+
+RT = TypeVar("RT")
+Param = ParamSpec("Param")
+
+
+@contextmanager
+def using_mlir_context(
+ *,
+ required_dialects: Optional[Sequence[str]] = None,
+ required_extension_operations: Optional[Sequence[str]] = None,
+ registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None,
+) -> Iterator[None]:
+ """Ensure a valid context exists by creating one if necessary.
+
+ NOTE: If values that are attached to a Context should outlive this
+ contextmanager, use caller_mlir_context!
+
+ This can be used as a function decorator or as a context manager in a with statement.
+ The context will throw an error if the required dialects have not been registered,
+ and a context is guaranteed to exist in this scope.
+
+ This currently only works on dialects, not dialect extensions.
+
+ Parameters
+ ------------
+ required_dialects:
+ Dialects that need to be registered in the context
+ required_extension_operations:
+ Required operations by their fully specified name. These are a proxy for detecting needed dialect extensions.
+ registration_funcs:
+ Functions that should be called to register all missing dialects/operations if they have not been registered.
+ """
+ dialects = required_dialects or []
+ extension_operations = required_extension_operations or []
+ registrations = registration_funcs or []
+ new_context = nullcontext if ir.Context.current else ir.Context
+ with new_context(), ir.Location.unknown():
+ context = ir.Context.current
+ # Attempt to disable multithreading. This could fail if the context is
+ # currently being used from multiple threads. This must be done before
+ # checking for or registering dialects, as both will hit an assertion
+ # failure in a multithreaded situation.
+ multithreading = context.is_multithreading_enabled
+ if multithreading:
+ context.enable_multithreading(False)
+
+ def attempt_registration():
+ """Register everything from registration_funcs."""
+ nonlocal context, registrations
+
+ # Gather dialects and extensions then add them to the context.
+ registry = ir.DialectRegistry()
+ for rf in registrations:
+ rf(registry)
+
+ context.append_dialect_registry(registry)
+
+ # See if any dialects are missing, register if they are, and then assert they are all registered.
+ try:
+ for dialect in dialects:
+ # If the dialect is registered, continue checking
+ context.get_dialect_descriptor(dialect)
+ except Exception:
+ attempt_registration()
+
+ for dialect in dialects:
+ # Assert that every required dialect is now registered.
+ assert context.get_dialect_descriptor(
+ dialect
+ ), f"required dialect {dialect} not registered by registration_funcs"
+
+ # See if any operations are missing and register if they are. For some reason,
+ # we cannot directly assert afterwards that the operations exist in the registry.
+ #
+ # TODO: Make this work for dialect extensions specifically
+ for operation in extension_operations:
+ # If the operation is not registered, attempt registration and then strongly assert below that it was added
+ if not context.is_registered_operation(operation):
+ attempt_registration()
+ break
+ for operation in extension_operations:
+ # First get the dialect descriptor, which loads the dialect as a side effect
+ dialect = operation.split(".")[0]
+ assert context.get_dialect_descriptor(dialect), f"Never loaded {dialect}"
+ assert context.is_registered_operation(
+ operation
+ ), f"expected {operation} to be registered in its dialect"
+ context.enable_multithreading(multithreading)
+
+ # Context manager related yield
+ try:
+ yield
+ finally:
+ pass
+
+
+@contextmanager
+def caller_mlir_context(
+ *,
+ required_dialects: Optional[Sequence[str]] = None,
+ required_extension_operations: Optional[Sequence[str]] = None,
+ registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None,
+) -> Iterator[None]:
+ """Requires an enclosing context from the caller and ensures relevant operations are loaded.
+
+ NOTE: If the Context is only needed inside of this contextmanager and returned values
+ don't need the Context, use using_mlir_context!
+
+ A context must already exist before this frame is executed to ensure that any values
+ continue to live on exit. Conceptually, this prevents use-after-free issues and
+ makes the intention clear when one intends to return values tied to a Context.
+ """
+ assert (
+ ir.Context.current
+ ), "Caller must have a context so it outlives this function call."
+ with using_mlir_context(
+ required_dialects=required_dialects,
+ required_extension_operations=required_extension_operations,
+ registration_funcs=registration_funcs,
+ ):
+ # Context manager related yield
+ try:
+ yield
+ finally:
+ pass
+
+
+def with_toplevel_context(f: Callable[Param, RT]) -> Callable[Param, RT]:
+ """Decorate the function to be executed with a fresh MLIR context.
+
+ This decorator will ensure the function is executed inside a context manager for a
+ new MLIR context with upstream dialects registered. Note that each call to
+ such a function has a new context, meaning that context-owned objects from these
+ functions will not be equal to each other. All arguments and keyword arguments are
+ forwarded.
+
+ The context is destroyed before the function exits so any result from the function
+ must not depend on the context.
+ """
+
+ @wraps(f)
+ def decorator(*args: Param.args, **kwargs: Param.kwargs) -> RT:
+ # Appending dialect registry and loading all available dialects occur on
+ # context creation because of the "_site_initialize" call.
+ with ir.Context(), ir.Location.unknown():
+ results = f(*args, **kwargs)
+ return results
+
+ return decorator
+
+
+def with_toplevel_context_create_module(
+ f: Callable[Concatenate[ir.Module, Param], RT],
+) -> Callable[Param, RT]:
+ """Decorate function to be executed in a fresh MLIR context and give it a module.
+
+ The decorated function will receive, as its leading argument, a fresh MLIR module.
+ The context manager is set up to insert operations into this module. All other
+ arguments and keyword arguments are forwarded.
+
+ The module and context are destroyed before the function exits, so any result from
+ the function must not depend on either.
+ """
+
+ @with_toplevel_context
+ @wraps(f)
+ def internal(*args: Param.args, **kwargs: Param.kwargs) -> RT:
+ module = ir.Module.create()
+ with ir.InsertionPoint(module.body):
+ results = f(module, *args, **kwargs)
+ return results
+
+ return internal
+
+
+def call_with_toplevel_context(f: Callable[[], RT]) -> Callable[[], RT]:
+ """Immediately call the function in a fresh MLIR context."""
+ decorated = with_toplevel_context(f)
+ decorated()
+ return decorated
+
+
+def call_with_toplevel_context_create_module(
+ f: Callable[[ir.Module], RT],
+) -> Callable[[], RT]:
+ """Immediately call the function in a fresh MLIR context and give it a module.
+
+ The decorated function will receive, as its only argument, a fresh MLIR module. The
+ context manager is set up to insert operations into this module.
+ """
+ decorated = with_toplevel_context_create_module(f)
+ decorated()
+ return decorated
+
+
+# def _debug_types_impl(types: Sequence[str]) -> Iterator[None]:
+# from mlir.ir import _GlobalDebug
+
+# # Save the original debug state. The debug types will be popped rather than
+# # manually copied and saved for later.
+# original_flag = _GlobalDebug.flag
+# _GlobalDebug.flag = True
+# _GlobalDebug.append_types(types)
+
+# try:
+# yield
+# finally:
+# # Reset the global debug flag and remove the most recent types that were
+# # appended. This assumes that nothing else popped when it should not have.
+# _GlobalDebug.flag = original_flag
+# _GlobalDebug.pop_types()
+
+
+# @contextmanager
+# def debug_types_context(types: Sequence[str]):
+# """Temporarily create a context that enables debugging with specified filters.
+
+# These would be the same as running with -debug-only=*types; when such contexts are
+# nested, their type lists are joined together to form the full list.
+
+# This requires that the core MLIR units were compiled without NDEBUG.
+# """
+# return _debug_types_impl(types)
+
+
+# @contextmanager
+# def debug_td(types: Sequence[str] = [], *, full_debug: bool = False) -> Iterator[None]:
+# """Temporarily create a context that enables full transform dialect debugging,
+# potentially with additional specified filters.
+
+# These would be the same as running with -debug-only=*types; when such contexts are
+# nested, their type lists are joined together to form the full list.
+
+# This requires that the core MLIR units were compiled without NDEBUG.
+# """
+# return _debug_types_impl(
+# list(types)
+# + [
+# "transform-dialect",
+# "transform-dialect-print-top-level-after-all",
+# "async-transform-ops",
+# ]
+# + (["transform-dialect-full"] if full_debug else [])
+# )
+
+
+# @contextmanager
+# def debug_conversion(types: Sequence[str] = []) -> Iterator[None]:
+# """Temporarily create a context that enables full conversion debugging,
+# potentially with additional specified filters.
+
+# These would be the same as running with -debug-only=*types; when such contexts are
+# nested, their type lists are joined together to form the full list.
+
+# This requires that the core MLIR units were compiled without NDEBUG.
+# """
+# return _debug_types_impl(["dialect-conversion"])
+
+
+# # TODO: Allow cloning functions from one module to another.
+# # Atm we have to resort to string concatenation.
+# def _create_module_from_main(main):
+# module = ir.Module.create()
+# ops = module.operation.regions[0].blocks[0].operations
+# return ir.Module.parse("\n".join([str(op) for op in ops]) + main)
+
+
+# def _add_lowering_passes(
+# pm, linalg_lowering, vector_lowering, memref_lowering, loop_lowering
+# ):
+# if loop_lowering:
+# pm.add("func.func(expand-strided-metadata)")
+# pm.add("func.func(lower-affine)")
+# if linalg_lowering:
+# pm.add("func.func(convert-linalg-to-loops)")
+# pm.add("func.func(convert-scf-to-cf)")
+# pm.add("convert-cf-to-llvm")
+# if memref_lowering:
+# pm.add("finalize-memref-to-llvm")
+# if vector_lowering:
+# pm.add("convert-vector-to-llvm")
+# pm.add("convert-arith-to-llvm")
+# pm.add("convert-index-to-llvm")
+# pm.add("convert-func-to-llvm")
+# pm.add("reconcile-unrealized-casts")
+
+
+# def lowering_to_llvm_via_passmanager(
+# module=None,
+# main=None,
+# linalg_lowering=False,
+# vector_lowering=False,
+# memref_lowering=False,
+# loop_lowering=False,
+# ):
+# """Transforms a module using registered passes and pass manager.
+
+# Parameters
+# ----------
+# module: ir.Module
+# Module to be transformed.
+
+# main: str
+# Main function embedded into a string. Will be used to create the module to be transformed.
+
+# linalg_lowering: bool
+# Flag indicating whether linalg lowering should be included in the transformation.
+
+# vector_lowering: bool
+# Flag indicating whether vector lowering should be included in the transformation.
+
+# memref_lowering: bool
+# Flag indicating whether memref lowering should be included in the transformation.
+
+# loop_lowering: bool
+# Flag indicating whether loop lowering should be included in the transformation.
+
+# Returns
+# -------
+# The transformed module.
+# """
+
+# if module is None:
+# module = _create_module_from_main(main)
+# pm = PassManager("builtin.module")
+# _add_lowering_passes(
+# pm, linalg_lowering, vector_lowering, memref_lowering, loop_lowering
+# )
+# pm.run(module.operation)
+# return module
+
+
+# def apply_named_sequence_from_module(
+# *, payload: ir.Module, transform: ir.Module, entry_point: str = "__transform_main"
+# ):
+# interpreter.apply_named_sequence(
+# payload, ir.SymbolTable(transform.operation)[entry_point], transform.operation
+# )
+
+
+# def build_and_run_common_pass_pipeline(module: ir.Module, nvvm_attach_target_opts: str):
+# pm = PassManager("builtin.module")
+# pm.add("convert-nvgpu-to-nvvm")
+# pm.add("convert-scf-to-cf")
+# pm.add("gpu-kernel-outlining")
+# pm.add("convert-func-to-llvm")
+# pm.add("expand-strided-metadata")
+# pm.add(f"nvvm-attach-target{{{nvvm_attach_target_opts}}}")
+# pm.add("lower-affine")
+# pm.add("convert-arith-to-llvm")
+# pm.add("convert-index-to-llvm")
+# pm.run(module.operation)
+# return module
+
+
+# def build_and_run_gpu_pass_pipeline(module: ir.Module, convert_gpu_to_nvvm_opts: str):
+# pm = PassManager("builtin.module")
+# pm.add(f"gpu.module(convert-gpu-to-nvvm{{{convert_gpu_to_nvvm_opts}}})")
+# pm.run(module.operation)
+# return module
+
+
+# def build_and_run_host_post_pipeline(module: ir.Module):
+# pm = PassManager("builtin.module")
+# pm.add("gpu-to-llvm")
+# pm.add("convert-nvvm-to-llvm")
+# pm.add("reconcile-unrealized-casts")
+# pm.run(module.operation)
+# return module
+
+
+# def gpu_module_to_binary(module: ir.Module, cubin_format: str):
+# pm = PassManager("builtin.module")
+# pm.add(f"gpu-module-to-binary{{format={cubin_format}}}")
+# pm.run(module.operation)
+# return module
+
+
+# def lower_nvgpu(
+# module: ir.Module,
+# nvvm_attach_target_opts: str,
+# convert_gpu_to_nvvm_opts: str,
+# cubin_format: str,
+# ) -> ir.Module:
+# """Lowers `module` to LLVM IR, targeting GPU, using registered passes and pass manager.
+
+# Parameters
+# ----------
+# module: ir.Module
+# Module to be lowered.
+
+# nvvm_attach_target_opts: str
+# Options describing the NVVM target to be attached as an attribute to the GPU module.
+
+# convert_gpu_to_nvvm_opts: str
+# Options for generating NVVM operations from GPU ones.
+
+# cubin_format: str
+# Compilation format for serializing to cubin.
+
+# Returns
+# -------
+# The LLVM IR module.
+# """
+
+# module = build_and_run_common_pass_pipeline(module, nvvm_attach_target_opts)
+# module = build_and_run_gpu_pass_pipeline(module, convert_gpu_to_nvvm_opts)
+# module = build_and_run_host_post_pipeline(module)
+# return gpu_module_to_binary(module, cubin_format)
+
+
+# def rename_all_symbols(
+# module: ir.Module,
+# rename_fn: Callable[[str], Optional[str]],
+# target_op: str = "func.func",
+# ):
+# """Renames all the symbols in `module` using the `rename_fn` function.
+# If `rename_fn` returns `None`, then the symbol is not renamed.
+
+# Parameters
+# ----------
+# @param: module The MLIR module.
+# @param: rename_fn A callable that takes the symbol and returns the new name for the symbol.
+# @param: target_op The 'operation name' of the target.
+# """
+# symbols = ir.SymbolTable(module.operation)
+# for op in module.operation.regions[0].blocks[0]:
+# if op.operation.name != target_op:
+# continue
+# old_name = op.name.value
+# new_name = rename_fn(old_name)
+# if new_name is None:
+# continue
+# symbols.replace_all_symbol_uses(old_name, new_name, module.operation)
+# symbols.set_symbol_name(op.operation, new_name)
+
+
+# @with_toplevel_context
+# def run_module(
+# mod: str,
+# entry_point: str,
+# *args,
+# shared_libs: Sequence[str] = [],
+# invoke_wrapper: Callable[[Callable[[], None]], Any] = lambda x: x(),
+# ) -> Any:
+# """Runs an MLIR module using the execution engine.
+
+# Parameters
+# ----------
+# @param: mod The MLIR module as a string.
+# @param: entry_point The function to execute.
+# @param: args The args to passthrough to the MLIR function.
+# @param: shared_libs The libraries to be loaded by the execution engine.
+# @param: invoke_wrapper A wrapper function to invoke the execution engine, this is useful for performing additional actions.
+
+# Returns
+# -------
+# The results of the wrapper function.
+# """
+# from mlir.execution_engine import ExecutionEngine
+
+# module = ir.Module.parse(mod)
+# execution_engine = ExecutionEngine(
+# module,
+# shared_libs=shared_libs,
+# )
+
+# def invoke():
+# execution_engine.invoke(entry_point, *args)
+
+# return invoke_wrapper(invoke)
diff --git a/mlir/test/python/utils.py b/mlir/test/python/utils.py
new file mode 100644
index 0000000000000..8435fdd363ae3
--- /dev/null
+++ b/mlir/test/python/utils.py
@@ -0,0 +1,58 @@
+# RUN: %python %s | FileCheck %s
+
+import unittest
+
+from mlir import ir
+from mlir.dialects import arith, builtin
+from mlir.extras import types as T
+from mlir.utils import (
+ call_with_toplevel_context_create_module,
+ caller_mlir_context,
+ using_mlir_context,
+)
+
+
+class TestRequiredContext(unittest.TestCase):
+ def test_shared_context(self):
+ """Test that the context is reused, so values can be passed/returned between functions."""
+
+ @using_mlir_context()
+ def create_add(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+ return arith.AddFOp(
+ lhs, rhs, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
+ ).result
+
+ @using_mlir_context()
+ def multiple_adds(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+ return create_add(create_add(lhs, rhs), create_add(lhs, rhs))
+
+ @call_with_toplevel_context_create_module
+ def _(module) -> None:
+ c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result
+ multiple_adds(c, c)
+
+ # CHECK: constant
+ # CHECK-NEXT: arith.addf
+ # CHECK-NEXT: arith.addf
+ # CHECK-NEXT: arith.addf
+ print(module)
+
+ def test_unregistered_op_asserts(self):
+ """Confirm that with_mlir_context fails if an operation is still not registered."""
+ with self.assertRaises(AssertionError), using_mlir_context(
+ required_extension_operations=["func.fake_extension_op"],
+ registration_funcs=[],
+ ):
+ pass
+
+ def test_required_op_asserts(self):
+ """Confirm that with_mlir_context fails if an operation is still not registered."""
+ with self.assertRaises(AssertionError), caller_mlir_context(
+ required_extension_operations=["func.fake_extension_op"],
+ registration_funcs=[],
+ ):
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()
>From c277571c4199dfb661589a3f918f6e78791402aa Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Sun, 6 Jul 2025 11:55:26 +0200
Subject: [PATCH 2/3] [mlir][transform] NFC - Introduce TransformEachOp util
---
.../TransformOps/BufferizationTransformOps.td | 30 +-
.../GPU/TransformOps/GPUTransformOps.td | 44 +-
.../Linalg/TransformOps/LinalgTransformOps.td | 567 +++++++-----------
.../MemRef/TransformOps/MemRefTransformOps.td | 38 +-
.../NVGPU/TransformOps/NVGPUTransformOps.td | 72 +--
.../Dialect/Transform/Interfaces/Utils.td | 25 +
6 files changed, 331 insertions(+), 445 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Transform/Interfaces/Utils.td
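
For reviewers, a plausible shape of the new TransformEachOp helper in
Transform/Interfaces/Utils.td, reconstructed from the call sites updated below;
the exact definition in the patch may differ (only the class name, its parameters,
and the composed TransformOpInterface/TransformEachOpTrait traits are taken from
the diffs, the rest is an assumption):

    // Sketch only: bundles the common boilerplate of per-payload-op transform ops,
    // i.e. the TransformOpInterface/TransformEachOpTrait traits plus the applyToOne
    // declaration, parameterized on the op mnemonic and the concrete target type.
    class TransformEachOp<string mnemonic, string targetType,
                          list<Trait> extraTraits = []>
        : Op<Transform_Dialect, mnemonic,
             !listconcat(extraTraits,
                         [TransformOpInterface, TransformEachOpTrait])> {
      let extraClassDeclaration = [{
        ::mlir::DiagnosedSilenceableFailure applyToOne(
            ::mlir::transform::TransformRewriter &rewriter,
            }] # targetType # [{ target,
            ::mlir::transform::ApplyToEachResultList &results,
            ::mlir::transform::TransformState &state);
      }];
    }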
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
index 53b3b0505b399..36c0153108411 100644
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
@@ -24,9 +25,12 @@ def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tens
//===----------------------------------------------------------------------===//
def BufferLoopHoistingOp
- : Op<Transform_Dialect, "bufferization.buffer_loop_hoisting",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformEachOpTrait, TransformOpInterface]> {
+ : TransformEachOp<
+ "bufferization.buffer_loop_hoisting",
+ "::mlir::Operation*",
+ [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ]> {
let description = [{
Hoist buffer allocations ("memref.alloc" and "memref.alloca") from loops
within the targeted op. This transform assumes that there are no buffer
@@ -38,14 +42,6 @@ def BufferLoopHoistingOp
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs);
let assemblyFormat = "$target attr-dict `:` type($target)";
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -164,11 +160,13 @@ def EliminateEmptyTensorsOp
//===----------------------------------------------------------------------===//
def EmptyTensorToAllocTensorOp
- : Op<Transform_Dialect, "bufferization.empty_tensor_to_alloc_tensor",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait]> {
+ : TransformEachOp<
+ "bufferization.empty_tensor_to_alloc_tensor",
+ "::mlir::tensor::EmptyOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ]> {
let description = [{
Replace a tensor.empty with a bufferization.tensor_alloc.
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 36b579485fc04..bdfac8b71d5e9 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
@@ -125,12 +126,14 @@ def EliminateBarriersOp :
let assemblyFormat = [{ attr-dict }];
}
-def MapNestedForallToThreads :
- Op<Transform_Dialect, "gpu.map_nested_forall_to_threads",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformEachOpTrait,
- TransformOpInterface]> {
+def MapNestedForallToThreads
+ : TransformEachOp<
+ "gpu.map_nested_forall_to_threads",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ]> {
let description = [{
Target the `gpu.launch op` and rewrite all `scf.forall` nested in it to
distributed `gpu.thread_id` attribute.
@@ -235,21 +238,16 @@ def MapNestedForallToThreads :
attr-dict
`:` functional-type($target, $result)
}];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
-def MapForallToBlocks :
- Op<Transform_Dialect, "gpu.map_forall_to_blocks",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait]> {
+def MapForallToBlocks
+ : TransformEachOp<
+ "gpu.map_forall_to_blocks",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ]> {
let description = [{
Target the gpu_launch op and rewrite the top level `scf.forall`
to distributed gpu.block_id attribute. If `generate_gpu_launch` attribute
@@ -300,14 +298,6 @@ def MapForallToBlocks :
`:` functional-type($target, $result)
}];
let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4360055e78691..8a73f4c49499e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -14,6 +14,7 @@ include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -247,12 +248,15 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
// DecomposeOp
//===----------------------------------------------------------------------===//
-def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def DecomposeOp
+ : TransformEachOp<
+ "structured.decompose",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Decomposes named complex operations, such as higher-dimensional
(depthwise) convolutions, into combinations of lower-dimensional equivalents
@@ -271,14 +275,6 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -443,10 +439,15 @@ def FuseIntoContainingOp :
// GeneralizeOp
//===----------------------------------------------------------------------===//
-def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def GeneralizeOp
+ : TransformEachOp<
+ "structured.generalize",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Transforms a named structured operation into the generic form with the
explicit attached region.
@@ -467,24 +468,21 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// SpecializeOp
//===----------------------------------------------------------------------===//
-def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def SpecializeOp
+ : TransformEachOp<
+ "structured.specialize",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Transforms a generic operation into the equivalent named form.
@@ -505,24 +503,21 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
-def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def InterchangeOp
+ : TransformEachOp<
+ "structured.interchange",
+ "::mlir::linalg::GenericOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Interchanges the iterators of the operations pointed to by the target handle
using the iterator interchange attribute.
@@ -550,24 +545,20 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
`:` custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::GenericOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// LinalgCopyToMemrefOp
//===----------------------------------------------------------------------===//
-def LinalgCopyToMemrefOp :
- Op<Transform_Dialect, "structured.linalg_copy_to_memref",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface]> {
+def LinalgCopyToMemrefOp
+ : TransformEachOp<
+ "structured.linalg_copy_to_memref",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface
+ ]> {
let description = [{
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
This is useful when bufferizing copies to a linalg.copy, later applying some
@@ -585,24 +576,20 @@ def LinalgCopyToMemrefOp :
let builders = [
OpBuilder<(ins "Value":$target)>,
];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
-def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
- FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformEachOpTrait,
- TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def LowerPackOp
+ : TransformEachOp<
+ "structured.lower_pack",
+ "::mlir::linalg::PackOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Rewrite a linalg.pack into tensor.pad + tensor.expand_shape + linalg.transpose.
@@ -623,25 +610,20 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::PackOp target,
- ::mlir::transform::ApplyToEachResultList &transformResults,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// LowerUnPackOp
//===----------------------------------------------------------------------===//
-def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
- FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformEachOpTrait,
- TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def LowerUnPackOp
+ : TransformEachOp<
+ "structured.lower_unpack",
+ "::mlir::linalg::UnPackOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Lower a linalg.unpack into empty + linalg.transpose + tensor.collapse_shape +
tensor.extract_slice.
@@ -665,13 +647,6 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
$target attr-dict `:` functional-type(operands, results)
}];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::UnPackOp target,
- ::mlir::transform::ApplyToEachResultList &transformResults,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -745,10 +720,14 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
// MultiTileSizesOp
//===----------------------------------------------------------------------===//
-def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def MultiTileSizesOp
+ : TransformEachOp<
+ "structured.multitile_sizes",
+ "::mlir::linalg::LinalgOp",
+ [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Emits the IR computing the tile sizes `s1` and `s2` such that:
@@ -817,14 +796,6 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
let assemblyFormat =
"$target attr-dict `:` custom<MultitileSizesTypes>("
"type($target), type($low_size), type($high_size), type($split_point))";
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -1320,11 +1291,14 @@ def HoistPadBuildPackingLoopNestOp :
let hasVerifier = 1;
}
-def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait]> {
+def HoistPadOp
+ : TransformEachOp<
+ "structured.hoist_pad",
+ "::mlir::tensor::PadOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface
+ ]> {
let description = [{
Hoist the tensor.pad target operation by at most the given number of loops.
Optionally apply the transpose attribute to the inner dimensions.
@@ -1361,14 +1335,6 @@ def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
`:` functional-type(operands, results)
}];
let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::tensor::PadOp,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -1376,10 +1342,15 @@ def HoistPadOp : Op<Transform_Dialect, "structured.hoist_pad",
//===----------------------------------------------------------------------===//
-def PromoteOp : Op<Transform_Dialect, "structured.promote",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def PromoteOp
+ : TransformEachOp<
+ "structured.promote",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Promotes the specified operands of the target into a separate memory buffer.
@@ -1413,14 +1384,6 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -1457,10 +1420,15 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
// ScalarizeOp
//===----------------------------------------------------------------------===//
-def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def ScalarizeOp
+ : TransformEachOp<
+ "structured.scalarize",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Indicates that ops of a specific kind in the given function should be
scalarized (i.e. their dynamic dimensions tiled by 1).
@@ -1492,14 +1460,6 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($result), "false")
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -1529,12 +1489,15 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//
-def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def DecomposeInterfaceOp
+ : TransformEachOp<
+ "structured.decompose_interface",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
TODO
}];
@@ -1543,26 +1506,20 @@ def DecomposeInterfaceOp : Op<Transform_Dialect, "structured.decompose_interface
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// RewriteInDestinationPassingStyleOp.
//===----------------------------------------------------------------------===//
-def RewriteInDestinationPassingStyleOp : Op<
- Transform_Dialect, "structured.rewrite_in_destination_passing_style",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def RewriteInDestinationPassingStyleOp
+ : TransformEachOp<
+ "structured.rewrite_in_destination_passing_style",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Rewrite a supported tensor operation that is not in destination-passing style
into a form that is in destination-passing style.
@@ -1592,14 +1549,6 @@ def RewriteInDestinationPassingStyleOp : Op<
$target attr-dict
`:` functional-type($target, results)
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -1660,10 +1609,15 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
// SplitReductionOp
//===----------------------------------------------------------------------===//
-def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def SplitReductionOp
+ : TransformEachOp<
+ "structured.split_reduction",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Indicates that the given `target` op should be transformed with the
`splitReduction` transformation and split factor provided as attribute.
@@ -1822,24 +1776,21 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
CArg<"bool", "false">:$useScalingAlgorithm,
CArg<"bool", "false">:$useAlloc)>
];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// TileReductionUsingForOp
//===----------------------------------------------------------------------===//
-def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_using_for",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def TileReductionUsingForOp
+ : TransformEachOp<
+ "structured.tile_reduction_using_for",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Indicates that the given `target` op should be transformed with the
`tileReduction` transformation with the tile size provided as attribute.
@@ -1934,25 +1885,21 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
attr-dict
`:` functional-type(operands, results)
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// TileReductionUsingForallOp
//===----------------------------------------------------------------------===//
-def TileReductionUsingForallOp :
- Op<Transform_Dialect, "structured.tile_reduction_using_forall",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def TileReductionUsingForallOp
+ : TransformEachOp<
+ "structured.tile_reduction_using_forall",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Tile a PartialReductionOpInterface op to a tiled `scf.forall` doing
partial reduction.
@@ -2047,15 +1994,6 @@ def TileReductionUsingForallOp :
attr-dict
`:` functional-type(operands, results)
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
-
}
//===----------------------------------------------------------------------===//
@@ -2334,11 +2272,15 @@ def TileUsingForallOp :
// VectorizeChildrenAndApplyPatternsOp
//===----------------------------------------------------------------------===//
-def VectorizeChildrenAndApplyPatternsOp :
- Op<Transform_Dialect, "structured.vectorize_children_and_apply_patterns",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def VectorizeChildrenAndApplyPatternsOp
+ : TransformEachOp<
+ "structured.vectorize_children_and_apply_patterns",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Vectorizes all children contained in the given `target` using the
configuration specified by the attributes of this op. This only vectorizes
@@ -2394,13 +2336,6 @@ def VectorizeChildrenAndApplyPatternsOp :
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
@@ -2479,11 +2414,15 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
// HoistRedundantVectorTransfersOp
//===----------------------------------------------------------------------===//
-def HoistRedundantVectorTransfersOp :
- Op<Transform_Dialect, "structured.hoist_redundant_vector_transfers",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def HoistRedundantVectorTransfersOp
+ : TransformEachOp<
+ "structured.hoist_redundant_vector_transfers",
+ "::mlir::func::FuncOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Hoist vector.transfer_read / vector.transfer_write pairs out of immediately
enclosing scf::ForOp iteratively, if the following conditions are true:
@@ -2513,24 +2452,21 @@ def HoistRedundantVectorTransfersOp :
OpBuilder<(ins "Value":$target,
CArg<"bool", "false">:$verify_non_zero_trip)>,
];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::func::FuncOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// HoistRedundantVectorBroadcastsOp
//===----------------------------------------------------------------------===//
-def HoistRedundantVectorBroadcastsOp :
- Op<Transform_Dialect, "structured.hoist_redundant_vector_broadcasts",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def HoistRedundantVectorBroadcastsOp
+ : TransformEachOp<
+ "structured.hoist_redundant_vector_broadcasts",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Hoist vector.extract / vector.broadcasts pairs out of immediately
enclosing scf::ForOp iteratively.
@@ -2549,26 +2485,21 @@ def HoistRedundantVectorBroadcastsOp :
let builders = [
OpBuilder<(ins "Value":$target)>,
];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// ConvertConv2DToImg2ColOp
//===----------------------------------------------------------------------===//
-def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
- "structured.convert_conv2d_to_img2col",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def ConvertConv2DToImg2ColOp
+ : TransformEachOp<
+ "structured.convert_conv2d_to_img2col",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Convert linalg.conv_2d_xxx into linalg.generic (for img2col packing)
and linalg.matmul.
@@ -2626,27 +2557,21 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
let builders = [
OpBuilder<(ins "Value":$target)>
];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// FlattenElementwiseLinalgOp
//===----------------------------------------------------------------------===//
-def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
- "structured.flatten_elementwise",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def FlattenElementwiseLinalgOp
+ : TransformEachOp<
+ "structured.flatten_elementwise",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Flattens the iteration space and (applicable) operands of elementwise
linalg ops to a single dimension.
@@ -2669,27 +2594,21 @@ def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
let builders = [
OpBuilder<(ins "Value":$target)>
];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// Transpose Conv2D
//===----------------------------------------------------------------------===//
-def TransposeConv2DOp : Op<Transform_Dialect,
- "structured.transpose_conv2d",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformOpInterface,
- TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def TransposeConv2DOp
+ : TransformEachOp<
+ "structured.transpose_conv2d",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Convert linalg.conv_2d_nhwc_fhwc into linalg.conv_2d_nhwc_hwcf by introducing
a linalg.transpose on the filter tensor/memref.
@@ -2718,25 +2637,21 @@ def TransposeConv2DOp : Op<Transform_Dialect,
let builders = [
OpBuilder<(ins "Value":$target)>
];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// TransposeMatmulOp
//===----------------------------------------------------------------------===//
-def TransposeMatmulOp : Op<Transform_Dialect,
- "structured.transpose_matmul",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def TransposeMatmulOp
+ : TransformEachOp<
+ "structured.transpose_matmul",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Convert Linalg matmul ops to transposed variants.
@@ -2764,24 +2679,20 @@ def TransposeMatmulOp : Op<Transform_Dialect,
let builders = [
OpBuilder<(ins "Value":$target)>
];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
-def InsertSliceToCopyOp :
- Op<Transform_Dialect, "structured.insert_slice_to_copy",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformEachOpTrait, TransformOpInterface]> {
+def InsertSliceToCopyOp
+ : TransformEachOp<
+ "structured.insert_slice_to_copy",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface
+ ]> {
let description = [{
    Targeted rewrite of a tensor.insert_slice to linalg.copy.
This is useful to materialize copies explicitly before bufferization and
@@ -2804,13 +2715,6 @@ def InsertSliceToCopyOp :
let builders = [
OpBuilder<(ins "Value":$target)>,
];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -2862,6 +2766,7 @@ def MapCopyToThreadsOp :
let builders = [
OpBuilder<(ins "Value":$target)>,
];
+ // TODO: Use TransformEachOp< once `extraClassDeclaration`s compose.
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
@@ -2877,11 +2782,15 @@ def MapCopyToThreadsOp :
// Winograd Conv2D
//===----------------------------------------------------------------------===//
-def WinogradConv2DOp : Op<Transform_Dialect,
- "structured.winograd_conv2d",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def WinogradConv2DOp
+ : TransformEachOp<
+ "structured.winograd_conv2d",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operation into batched
matrix multiply. Before the matrix multiply, it will convert filter and
@@ -2913,21 +2822,17 @@ def WinogradConv2DOp : Op<Transform_Dialect,
let builders = [
OpBuilder<(ins "Value":$target)>
];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
-def DecomposeWinogradOp : Op<Transform_Dialect,
- "structured.decompose_winograd_op",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait,
- ReportTrackingListenerFailuresOpTrait]> {
+def DecomposeWinogradOp
+ : TransformEachOp<
+ "structured.decompose_winograd_op",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Decompose winograd operations. It will convert filter, input and output
transform operations into a combination of scf, tensor, and linalg
@@ -2950,14 +2855,6 @@ def DecomposeWinogradOp : Op<Transform_Dialect,
let builders = [
OpBuilder<(ins "Value":$target)>
];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
#endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
index f4694a30a8a12..66d87436f51ef 100644
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -11,6 +11,7 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
@@ -238,11 +239,13 @@ def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
}
def MemRefEraseDeadAllocAndStoresOp
- : Op<Transform_Dialect, "memref.erase_dead_alloc_and_stores", [
- TransformEachOpTrait, TransformOpInterface,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- ReportTrackingListenerFailuresOpTrait
- ]> {
+ : TransformEachOp<
+ "memref.erase_dead_alloc_and_stores",
+ "::mlir::Operation*",
+ [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
This applies memory optimization on memref. In particular it does store to
load forwarding, dead store elimination and dead alloc/alloca elimination.
@@ -266,19 +269,16 @@ def MemRefEraseDeadAllocAndStoresOp
let builders = [
OpBuilder<(ins "Value":$target)>
];
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
def MemRefMakeLoopIndependentOp
- : Op<Transform_Dialect, "memref.make_loop_independent",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait]> {
+ : TransformEachOp<
+ "memref.make_loop_independent",
+ "::mlir::Operation*",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface
+ ]> {
let description = [{
Rewrite the targeted ops such that their index-typed operands no longer
depend on any loop induction variable of the `num_loop` enclosing `scf.for`
@@ -307,14 +307,6 @@ def MemRefMakeLoopIndependentOp
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:` functional-type($target, $transformed)";
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
#endif // MEMREF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
index 0225562baa58c..17da5ab507279 100644
--- a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/Interfaces/Utils.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -34,12 +35,14 @@ def ApplyNVGPUToNVVMConversionPatternsOp : Op<Transform_Dialect,
// CreateAsyncGroupsOp
//===----------------------------------------------------------------------===//
-def CreateAsyncGroupsOp :
- Op<Transform_Dialect, "nvgpu.create_async_groups",
- [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- TransformEachOpTrait,
- TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def CreateAsyncGroupsOp
+ : TransformEachOp<
+ "nvgpu.create_async_groups",
+ "::mlir::Operation*",
+ [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Look for global to shared memory copies within the targeted op in the form
of vector transfer ops and convert them to async copies when possible.
@@ -65,27 +68,21 @@ def CreateAsyncGroupsOp :
let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//
-def PipelineSharedMemoryCopiesOp :
- Op<Transform_Dialect, "nvgpu.pipeline_shared_memory_copies",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformEachOpTrait,
- TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def PipelineSharedMemoryCopiesOp
+ : TransformEachOp<
+ "nvgpu.pipeline_shared_memory_copies",
+ "::mlir::scf::ForOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let summary =
"Applies software pipelining to a given loop with shared memory copies";
@@ -136,27 +133,21 @@ def PipelineSharedMemoryCopiesOp :
attr-dict
`:` functional-type(operands, results)
}];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::scf::ForOp forOp,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//
-def RewriteMatmulAsMmaSyncOp :
- Op<Transform_Dialect, "nvgpu.rewrite_matmul_as_mma_sync",
- [FunctionalStyleTransformOpTrait,
- MemoryEffectsOpInterface,
- TransformEachOpTrait,
- TransformOpInterface,
- ReportTrackingListenerFailuresOpTrait]> {
+def RewriteMatmulAsMmaSyncOp
+ : TransformEachOp<
+ "nvgpu.rewrite_matmul_as_mma_sync",
+ "::mlir::linalg::LinalgOp",
+ [
+ FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ ReportTrackingListenerFailuresOpTrait
+ ]> {
let description = [{
Rewrite a matmul operation on memref to an mma.sync operation on vectors.
@@ -169,14 +160,6 @@ def RewriteMatmulAsMmaSyncOp :
let results = (outs);
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::linalg::LinalgOp linalgOp,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
}
//===----------------------------------------------------------------------===//
@@ -200,6 +183,7 @@ def RewriteCopyAsTmaOp :
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
+ // TODO: use TransformEachOp when extended to op-less API.
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure apply(
::mlir::transform::TransformRewriter &rewriter,
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/Utils.td b/mlir/include/mlir/Dialect/Transform/Interfaces/Utils.td
new file mode 100644
index 0000000000000..68b46f25b756c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/Utils.td
@@ -0,0 +1,25 @@
+#ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_UTILS_TD
+#define MLIR_DIALECT_TRANSFORM_INTERFACES_UTILS_TD
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
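+// Helper op class for transform ops that follow the 1:1 "apply to each
+// payload op" pattern: `name` is the op mnemonic, `kind` is the C++ type of
+// the payload op passed to `applyToOne`, and `traits` is extended with
+// TransformOpInterface and TransformEachOpTrait. The class also declares the
+// matching `applyToOne` entry point so derived ops no longer spell it out.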
+class TransformEachOp<string name,
+ string kind = "::mlir::Operation *",
+ list<Trait> traits = []>
+ : Op<Transform_Dialect, name, !listconcat(traits, [
+ TransformOpInterface, TransformEachOpTrait
+ ])> {
+ defvar applyMethodDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ }] # kind # [{ target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+ let extraClassDeclaration = applyMethodDeclaration;
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_INTERFACES_UTILS_TD
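For illustration, here is a minimal sketch of the out-of-line `applyToOne` definition that an op built on `TransformEachOp` still provides. The signature mirrors the declaration injected above; the choice of FlattenElementwiseLinalgOp and the body are illustrative only and not part of this patch:

::mlir::DiagnosedSilenceableFailure
mlir::transform::FlattenElementwiseLinalgOp::applyToOne(
    ::mlir::transform::TransformRewriter &rewriter,
    ::mlir::linalg::LinalgOp target,
    ::mlir::transform::ApplyToEachResultList &results,
    ::mlir::transform::TransformState &state) {
  // Illustrative body only: record the payload op for the result handle and
  // report success; the real rewrite logic lives in LinalgTransformOps.cpp.
  results.push_back(target.getOperation());
  return ::mlir::DiagnosedSilenceableFailure::success();
}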
>From ed41e82abb76c84589bfd05fae2b35f77bb09663 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <nico.vasilache at amd.com>
Date: Thu, 3 Jul 2025 18:32:59 +0200
Subject: [PATCH 3/3] [mlir] NFC - refactor id builder and avoid leaking impl
details
---
.../mlir/Dialect/GPU/TransformOps/Utils.h | 31 ++--
.../GPU/TransformOps/GPUTransformOps.cpp | 32 +---
mlir/lib/Dialect/GPU/TransformOps/Utils.cpp | 165 ++++++++++++++----
3 files changed, 150 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index 52fc6f4d5c71b..0cd15835ad70f 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -28,27 +28,24 @@ namespace transform {
namespace gpu {
/// Helper type for functions that generate ids for the mapping of a scf.forall.
-/// Operates on both 1) an "original" basis that represents the individual
-/// thread and block ids and 2) a "scaled" basis that represents grouped ids
-/// (e.g. block clusters, warpgroups and warps).
-/// The mapping of ids is done in the "scaled" basis (i.e. when mapping to warps
-/// a division by 32 occurs).
-/// The predication is in the "original" basis using the "active" quantities
-/// (`activeMappingSizes`, `availableMappingSizes` and `activeIdOps`).
struct IdBuilderResult {
- // Ops used to replace the forall induction variables.
+  /// Error message; if not empty, building the ids failed.
+ std::string errorMsg;
+ /// Values used to replace the forall induction variables.
SmallVector<Value> mappingIdOps;
- // Available mapping sizes used to predicate the forall body when they are
- // larger than the predicate mapping sizes.
- SmallVector<int64_t> availableMappingSizes;
- // Actual mapping sizes used to predicate the forall body when they are
- // smaller than the available mapping sizes.
- SmallVector<int64_t> activeMappingSizes;
- // Ops used to predicate the forall body when activeMappingSizes is smaller
- // than the available mapping sizes.
- SmallVector<Value> activeIdOps;
+  /// Values used to predicate the forall body when the active mapping sizes
+  /// are smaller than the available mapping sizes.
+ SmallVector<Value> predicateOps;
};
+inline raw_ostream &operator<<(raw_ostream &os, const IdBuilderResult &res) {
+ llvm::interleaveComma(res.mappingIdOps, os << "----mappingIdOps: ");
+ os << "\n";
+ llvm::interleaveComma(res.predicateOps, os << "----predicateOps: ");
+ os << "\n";
+ return os;
+}
+
/// Common gpu id builder type, allows the configuration of lowering for various
/// mapping schemes. Takes:
/// - A rewriter with insertion point set before the forall op to rewrite.
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 6446235c06fb2..1ae923db149a5 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -480,6 +480,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
IdBuilderResult builderResult =
gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
+ if (!builderResult.errorMsg.empty())
+ return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);
+
+ LLVM_DEBUG(DBGS() << builderResult);
// Step 4. Map the induction variables to the mappingIdOps, this may involve
// a permutation.
@@ -490,6 +494,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
+ LDBG("----map: " << iv << " to " << peIdOp);
bvm.map(iv, peIdOp);
}
@@ -498,32 +503,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
// originalBasis and no predication occurs.
Value predicate;
if (originalBasisWasProvided) {
- SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
- SmallVector<int64_t> availableMappingSizes =
- builderResult.availableMappingSizes;
- SmallVector<Value> activeIdOps = builderResult.activeIdOps;
- LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
- LDBG("----availableMappingSizes: "
- << llvm::interleaved(availableMappingSizes));
- LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
- for (auto [activeId, activeMappingSize, availableMappingSize] :
- llvm::zip_equal(activeIdOps, activeMappingSizes,
- availableMappingSizes)) {
- if (activeMappingSize > availableMappingSize) {
- return definiteFailureHelper(
- transformOp, forallOp,
- "Trying to map to fewer GPU threads than loop iterations but "
- "overprovisioning is not yet supported. "
- "Try additional tiling of the before mapping or map to more "
- "threads.");
- }
- if (activeMappingSize == availableMappingSize)
- continue;
- Value idx =
- rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
- Value tmpPredicate = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::ult, activeId, idx);
- LDBG("----predicate: " << tmpPredicate);
+ for (Value tmpPredicate : builderResult.predicateOps) {
predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
tmpPredicate)
: tmpPredicate;
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 9853e80828390..00bb159e868c0 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -47,12 +47,57 @@ using namespace mlir::transform::gpu;
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
+/// Build predicates that restrict execution to only the activeIds. Along each
+/// dimension, 3 cases appear:
+///   1. activeMappingSize > availableMappingSize: this case is unsupported as
+///      it requires additional looping. An error message is produced to
+///      advise the user to tile more or to use more threads.
+/// 2. activeMappingSize == availableMappingSize: no predication is needed.
+/// 3. activeMappingSize < availableMappingSize: only a subset of threads
+/// should be active and we produce the boolean `id < activeMappingSize`
+/// for further use in building predicated execution.
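+/// For example (illustrative numbers): with activeMappingSizes = {4, 64} and
+/// availableMappingSizes = {8, 64}, only the first dimension needs predication
+/// and a single `id < 4` comparison is emitted; an active size of 16 in the
+/// first dimension would instead hit case 1 and fail.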
+static FailureOr<SmallVector<Value>>
+buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
+ ArrayRef<int64_t> activeMappingSizes,
+ ArrayRef<int64_t> availableMappingSizes,
+ std::string &errorMsg) {
+ // clang-format off
+ LLVM_DEBUG(
+ llvm::interleaveComma(
+ activeMappingSizes, DBGS() << "----activeMappingSizes: ");
+ DBGS() << "\n";
+ llvm::interleaveComma(
+ availableMappingSizes, DBGS() << "----availableMappingSizes: ");
+ DBGS() << "\n";);
+ // clang-format on
+
+ SmallVector<Value> predicateOps;
+ for (auto [activeId, activeMappingSize, availableMappingSize] :
+ llvm::zip_equal(activeIds, activeMappingSizes, availableMappingSizes)) {
+ if (activeMappingSize > availableMappingSize) {
+ errorMsg = "Trying to map to fewer GPU threads than loop iterations but "
+ "overprovisioning is not yet supported. Try additional tiling "
+ "before mapping or map to more threads.";
+ return failure();
+ }
+ if (activeMappingSize == availableMappingSize)
+ continue;
+ Value idx = rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
+ Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+ activeId, idx);
+ predicateOps.push_back(pred);
+ }
+ return predicateOps;
+}
+
/// Return a flattened thread id for the workgroup with given sizes.
template <typename ThreadOrBlockIdOp>
static Value buildLinearId(RewriterBase &rewriter, Location loc,
ArrayRef<OpFoldResult> originalBasisOfr) {
- LLVM_DEBUG(DBGS() << "----buildLinearId with originalBasisOfr: "
- << llvm::interleaved(originalBasisOfr) << "\n");
+ LLVM_DEBUG(llvm::interleaveComma(
+ originalBasisOfr,
+ DBGS() << "----buildLinearId with originalBasisOfr: ");
+ llvm::dbgs() << "\n");
assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
IndexType indexType = rewriter.getIndexType();
AffineExpr tx, ty, tz, bdx, bdy;
@@ -79,44 +124,43 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
auto res = [multiplicity](RewriterBase &rewriter, Location loc,
ArrayRef<int64_t> forallMappingSizes,
ArrayRef<int64_t> originalBasis) {
+ // 1. Compute linearId.
SmallVector<OpFoldResult> originalBasisOfr =
getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
- OpFoldResult linearId =
+ Value physicalLinearId =
buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
+
+ // 2. Compute scaledLinearId.
+ AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+ OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
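+    // E.g. when mapping to warps, `multiplicity` is the warp size and
+    // `scaledLinearId` is the warp index of the current thread.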
+
+ // 3. Compute remapped indices.
+ SmallVector<Value> ids;
// Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
// "row-major" order.
SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
- AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
- OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
- rewriter, loc, d0.floorDiv(multiplicity), {linearId});
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
- SmallVector<Value> ids;
// Reverse back to be in [0 .. n] order.
for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
ids.push_back(
affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
}
- LLVM_DEBUG(DBGS() << "--delinearization basis: "
- << llvm::interleaved(reverseBasisSizes) << "\n";
- DBGS() << "--delinearization strides: "
- << llvm::interleaved(strides) << "\n";
- DBGS() << "--delinearization exprs: "
- << llvm::interleaved(delinearizingExprs) << "\n";
- DBGS() << "--ids: " << llvm::interleaved(ids) << "\n");
-
- // Return n-D ids for indexing and 1-D size + id for predicate generation.
- return IdBuilderResult{
- /*mappingIdOps=*/ids,
- /*availableMappingSizes=*/
- SmallVector<int64_t>{computeProduct(originalBasis)},
- // `forallMappingSizes` iterate in the scaled basis, they need to be
- // scaled back into the original basis to provide tight
- // activeMappingSizes quantities for predication.
- /*activeMappingSizes=*/
- SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
- /*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
+ // 4. Handle predicates using physicalLinearId.
+ std::string errorMsg;
+ SmallVector<Value> predicateOps;
+ FailureOr<SmallVector<Value>> maybePredicateOps =
+ buildPredicates(rewriter, loc, physicalLinearId,
+ computeProduct(forallMappingSizes) * multiplicity,
+ computeProduct(originalBasis), errorMsg);
+ if (succeeded(maybePredicateOps))
+ predicateOps = *maybePredicateOps;
+
+ return IdBuilderResult{/*errorMsg=*/errorMsg,
+ /*mappingIdOps=*/ids,
+ /*predicateOps=*/predicateOps};
};
return res;
@@ -143,16 +187,67 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
// In the 3-D mapping case, unscale the first dimension by the multiplicity.
SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
forallMappingSizeInOriginalBasis[0] *= multiplicity;
- return IdBuilderResult{
- /*mappingIdOps=*/scaledIds,
- /*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
- // `forallMappingSizes` iterate in the scaled basis, they need to be
- // scaled back into the original basis to provide tight
- // activeMappingSizes quantities for predication.
- /*activeMappingSizes=*/
- SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
- /*activeIdOps=*/ids};
+
+ std::string errorMsg;
+ SmallVector<Value> predicateOps;
+ FailureOr<SmallVector<Value>> maybePredicateOps =
+ buildPredicates(rewriter, loc, ids, forallMappingSizeInOriginalBasis,
+ originalBasis, errorMsg);
+ if (succeeded(maybePredicateOps))
+ predicateOps = *maybePredicateOps;
+
+ return IdBuilderResult{/*errorMsg=*/errorMsg,
+ /*mappingIdOps=*/scaledIds,
+ /*predicateOps=*/predicateOps};
+ };
+ return res;
+}
+
+/// Create a lane id builder that takes the `originalBasis` and decomposes it
+/// in the basis of `forallMappingSizes`. The lane id builder returns an n-D
+/// vector of ids for indexing and the predicate values used to guard execution.
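+/// For example, with warpSize = 32, the thread whose flattened id is 37 gets
+/// lane id 5 (37 % 32), which is then delinearized in the basis of
+/// `forallMappingSizes`.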
+static GpuIdBuilderFnType laneIdBuilderFn(int64_t warpSize) {
+ auto res = [warpSize](RewriterBase &rewriter, Location loc,
+ ArrayRef<int64_t> forallMappingSizes,
+ ArrayRef<int64_t> originalBasis) {
+ // 1. Compute linearId.
+ SmallVector<OpFoldResult> originalBasisOfr =
+ getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
+ Value physicalLinearId =
+ buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
+
+ // 2. Compute laneId.
+ AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
+ OpFoldResult laneId = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, d0 % warpSize, {physicalLinearId});
+
+ // 3. Compute remapped indices.
+ SmallVector<Value> ids;
+ // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
+ // "row-major" order.
+ SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
+ SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
+ SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
+ // Reverse back to be in [0 .. n] order.
+ for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
+ ids.push_back(
+ affine::makeComposedAffineApply(rewriter, loc, e, {laneId}));
+ }
+
+ // 4. Handle predicates using laneId.
+ std::string errorMsg;
+ SmallVector<Value> predicateOps;
+ FailureOr<SmallVector<Value>> maybePredicateOps = buildPredicates(
+ rewriter, loc, cast<Value>(laneId), computeProduct(forallMappingSizes),
+ computeProduct(originalBasis), errorMsg);
+ if (succeeded(maybePredicateOps))
+ predicateOps = *maybePredicateOps;
+
+ return IdBuilderResult{/*errorMsg=*/errorMsg,
+ /*mappingIdOps=*/ids,
+ /*predicateOps=*/predicateOps};
};
+
return res;
}