[Mlir-commits] [mlir] 681eacc - [MLIR][transform][python] add sugared python abstractions for transform dialect (#75073)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 15 04:04:47 PST 2023


Author: martin-luecke
Date: 2023-12-15T13:04:43+01:00
New Revision: 681eacc1b670fd7137d8677fef6fc76c6e37dca9

URL: https://github.com/llvm/llvm-project/commit/681eacc1b670fd7137d8677fef6fc76c6e37dca9
DIFF: https://github.com/llvm/llvm-project/commit/681eacc1b670fd7137d8677fef6fc76c6e37dca9.diff

LOG: [MLIR][transform][python] add sugared python abstractions for transform dialect (#75073)

This adds Python abstractions for the different handle types of the
transform dialect.

The abstractions allow for straightforward chaining of transforms by
calling their member functions.
As an initial PR for this infrastructure, only a single transform is
included: `transform.structured.match`.
With a future `tile` transform abstraction, an example of the usage is:
```Python
def script(module: OpHandle):
    module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32])
```
to generate the following IR:
```mlir
%0 = transform.structured.match interface{TilingInterface} in %arg0
%tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32]
```

These abstractions are intended to enhance the usability and flexibility
of the transform dialect by providing an accessible interface that
allows for easy assembly of complex transformation chains.

Added: 
    mlir/python/mlir/extras/dialects/transform/__init__.py
    mlir/test/python/dialects/transform_extras.py

Modified: 
    mlir/include/mlir-c/Dialect/Transform.h
    mlir/include/mlir/Bindings/Python/PybindAdaptors.h
    mlir/lib/Bindings/Python/DialectTransform.cpp
    mlir/lib/CAPI/Dialect/Transform.cpp
    mlir/python/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 91c99b1f869f22..02c99b59218825 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -33,6 +35,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -41,6 +45,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx);
 
 //===---------------------------------------------------------------------===//
@@ -63,6 +69,8 @@ mlirTransformOperationTypeGetOperationName(MlirType type);
 
 MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
                                                       MlirType type);
 

diff  --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index 5e0e56fc00a673..66cf20e1c136f9 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -495,6 +495,13 @@ class mlir_type_subclass : public pure_subclass {
           .attr("replace")(superCls.attr("__name__"), captureTypeName);
     });
     if (getTypeIDFunction) {
+      // 'get_static_typeid' method.
+      // This is modeled as a static method instead of a static property because
+      // `def_property_readonly_static` is not available in `pure_subclass` and
+      // we do not want to introduce the complexity that pybind uses to
+      // implement it.
+      def_staticmethod("get_static_typeid",
+                       [getTypeIDFunction]() { return getTypeIDFunction(); });
       py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
           .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
               getTypeIDFunction())(pybind11::cpp_function(

diff  --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index c7764f4e7aecac..6b57e652aa9d8b 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -27,7 +27,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto anyOpType =
-      mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
+      mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
+                         mlirTransformAnyOpTypeGetTypeID);
   anyOpType.def_classmethod(
       "get",
       [](py::object cls, MlirContext ctx) {
@@ -41,7 +42,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto anyParamType =
-      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
+                         mlirTransformAnyParamTypeGetTypeID);
   anyParamType.def_classmethod(
       "get",
       [](py::object cls, MlirContext ctx) {
@@ -55,7 +57,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto anyValueType =
-      mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType);
+      mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
+                         mlirTransformAnyValueTypeGetTypeID);
   anyValueType.def_classmethod(
       "get",
       [](py::object cls, MlirContext ctx) {
@@ -96,7 +99,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
   //===-------------------------------------------------------------------===//
 
   auto paramType =
-      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
+                         mlirTransformParamTypeGetTypeID);
   paramType.def_classmethod(
       "get",
       [](py::object cls, MlirType type, MlirContext ctx) {

diff  --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 3f7f8b8e2113fe..5fd773572bd3c8 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -25,6 +25,10 @@ bool mlirTypeIsATransformAnyOpType(MlirType type) {
   return isa<transform::AnyOpType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
+  return wrap(transform::AnyOpType::getTypeID());
+}
+
 MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
   return wrap(transform::AnyOpType::get(unwrap(ctx)));
 }
@@ -37,6 +41,10 @@ bool mlirTypeIsATransformAnyParamType(MlirType type) {
   return isa<transform::AnyParamType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
+  return wrap(transform::AnyParamType::getTypeID());
+}
+
 MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
   return wrap(transform::AnyParamType::get(unwrap(ctx)));
 }
@@ -49,6 +57,10 @@ bool mlirTypeIsATransformAnyValueType(MlirType type) {
   return isa<transform::AnyValueType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
+  return wrap(transform::AnyValueType::getTypeID());
+}
+
 MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
   return wrap(transform::AnyValueType::get(unwrap(ctx)));
 }
@@ -76,13 +88,17 @@ MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
 }
 
 //===---------------------------------------------------------------------===//
-// AnyOpType
+// ParamType
 //===---------------------------------------------------------------------===//
 
 bool mlirTypeIsATransformParamType(MlirType type) {
   return isa<transform::ParamType>(unwrap(type));
 }
 
+MlirTypeID mlirTransformParamTypeGetTypeID(void) {
+  return wrap(transform::ParamType::getTypeID());
+}
+
 MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
   return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
 }

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 585918afc26335..41d91cf6778338 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -166,6 +166,14 @@ declare_mlir_dialect_python_bindings(
     "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td"
 )
 
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.extras
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  GEN_ENUM_BINDINGS
+  SOURCES
+    extras/dialects/transform/__init__.py)
+
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

diff  --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/extras/dialects/transform/__init__.py
new file mode 100644
index 00000000000000..9e313324318aa6
--- /dev/null
+++ b/mlir/python/mlir/extras/dialects/transform/__init__.py
@@ -0,0 +1,148 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from __future__ import annotations
+from typing import Callable, Optional, Sequence
+
+from .... import ir
+from ....dialects import transform
+from ....dialects.transform import structured
+
+
+class Handle(ir.Value):
+    """
+    Base class for wrappers around different types of transform handle with
+    methods to chain further transforms.
+
+    The fields `children` and `parent` are used to capture the relation of
+    handles statically in order to enable further analysis. The payload
+    operation of a child handle is nested into a region of the payload operation
+    of the corresponding parent handle. `children` is copied into a new list.
+    """
+
+    def __init__(
+        self,
+        v: ir.Value,
+        *,
+        parent: Optional[Handle] = None,
+        children: Optional[Sequence[Handle]] = None,
+    ):
+        super().__init__(v)
+        self.parent = parent
+        self.children = list(children) if children is not None else []
+
+
+@ir.register_value_caster(transform.AnyOpType.get_static_typeid())
+@ir.register_value_caster(transform.OperationType.get_static_typeid())
+class OpHandle(Handle):
+    """
+    Wrapper around a transform operation handle with methods to chain further
+    transforms.
+    """
+
+    def __init__(
+        self,
+        v: ir.Value,
+        *,
+        parent: Optional[Handle] = None,
+        children: Optional[Sequence[Handle]] = None,
+    ):
+        super().__init__(v, parent=parent, children=children)
+
+    def match_ops(
+        self,
+        ops: str
+        | type[ir.OpView]
+        | structured.MatchInterfaceEnum
+        | Sequence[str | type[ir.OpView]],
+    ) -> OpHandle:
+        """
+        Emits a `transform.structured.MatchOp`.
+        Returns a handle to payload ops matching the given op names, OpView
+        classes, or interface. A single op name or class yields a handle typed
+        as !transform.op<"...">; otherwise the handle is !transform.any_op.
+        """
+        # Handle interface.
+        if isinstance(ops, structured.MatchInterfaceEnum) or (
+            isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__
+        ):
+            if isinstance(ops, str):
+                ops = structured.MatchInterfaceEnum[ops]
+            match_op = structured.MatchOp(
+                transform.AnyOpType.get(),
+                self,
+                interface=ops,
+            )
+
+        # Handle op name(s), either given directly as string or given as op.
+        else:
+            if isinstance(ops, str):
+                op_type = transform.OperationType.get(ops)
+                op_names = [ops]
+            elif isinstance(ops, Sequence):
+                op_type = transform.AnyOpType.get()
+                op_names = [
+                    op if isinstance(op, str) else op.OPERATION_NAME for op in ops
+                ]
+            else:
+                op_type = transform.OperationType.get(ops.OPERATION_NAME)
+                op_names = [ops.OPERATION_NAME]
+            match_op = structured.MatchOp.match_op_names(
+                op_type,
+                self,
+                op_names,
+            )
+
+        handle = OpHandle(match_op.results_, parent=self)
+        self.children.append(handle)
+        return handle
+
+
+def insert_transform_script(
+    block_or_insertion_point: ir.Block | ir.InsertionPoint,
+    script: Callable[[OpHandle], None],
+    dump_script: bool = False,
+) -> None:
+    """
+    Inserts a `__transform_main` named_sequence op whose body is built by
+    calling `script` with the sequence's block argument, which the script
+    should accept as an OpHandle. Prints the op if `dump_script` is True.
+
+    Example:
+    This python code
+    ```
+    module = ir.Module.create()
+    def script(root: OpHandle):
+        root.match_ops(scf.ForOp)
+    insert_transform_script(module.body, script)
+    ```
+    generates the following IR:
+    ```
+    module {
+        transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+            %0 = transform.structured.match ops{["scf.for"]} in %arg0
+                 : (!transform.any_op) -> !transform.op<"scf.for">
+            transform.yield
+        }
+    }
+    ```
+    """
+    if isinstance(block_or_insertion_point, ir.Block):
+        context = block_or_insertion_point.owner.context
+        insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point)
+    else:
+        context = block_or_insertion_point.block.owner.context
+        insertion_point = block_or_insertion_point
+
+    with context, ir.Location.unknown(context):
+        with insertion_point:
+            named_sequence_op = transform.NamedSequenceOp(
+                "__transform_main", [transform.AnyOpType.get()], []
+            )
+        with ir.InsertionPoint(named_sequence_op.body):
+            script(named_sequence_op.bodyTarget)
+            transform.YieldOp([])
+
+    if dump_script:
+        print(named_sequence_op)

diff  --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py
new file mode 100644
index 00000000000000..dbfa8a2dc73c41
--- /dev/null
+++ b/mlir/test/python/dialects/transform_extras.py
@@ -0,0 +1,95 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Callable
+from mlir import ir
+from mlir.dialects import scf
+from mlir.dialects.transform import structured
+from mlir.extras.dialects.transform import OpHandle, insert_transform_script
+
+
+def build_transform_script(script: Callable[[OpHandle], None]):
+    print("\nTEST:", script.__name__)
+    with ir.Context(), ir.Location.unknown():
+        mod = ir.Module.create()
+        mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+        insert_transform_script(mod.body, dump_script=True, script=script)
+        mod.operation.verify()
+
+
+def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
+    print("\nTEST:", script.__name__)
+    with ir.Context(), ir.Location.unknown():
+        mod = ir.Module.create()
+        mod.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
+        insert_transform_script(
+            ir.InsertionPoint.at_block_begin(mod.body),
+            dump_script=True,
+            script=script,
+        )
+        mod.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_build_script_at_insertion_point
+@build_transform_script_at_insertion_point
+def test_build_script_at_insertion_point(op: OpHandle):
+    pass
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: transform.yield
+    # CHECK-NEXT: }
+
+
+# CHECK-LABEL: TEST: test_match_ops_single
+@build_transform_script
+def test_match_ops_single(op: OpHandle):
+    op.match_ops(scf.ForOp)
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
+    # CHECK-SAME:    in %[[VAL_0]]
+    # CHECK-SAME:      -> !transform.op<"scf.for">
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_name
+@build_transform_script
+def test_match_ops_string_name(op: OpHandle):
+    op.match_ops("linalg.matmul")
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_string_iface
+@build_transform_script
+def test_match_ops_string_iface(op: OpHandle):
+    op.match_ops("LinalgOp")
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_iface
+@build_transform_script
+def test_match_ops_iface(op: OpHandle):
+    op.match_ops(structured.MatchInterfaceEnum.LinalgOp)
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
+
+
+# CHECK-LABEL: TEST: test_match_ops_multiple
+@build_transform_script
+def test_match_ops_multiple(op: OpHandle):
+    op.match_ops([scf.ForOp, scf.ForallOp])
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op
+
+
+# CHECK-LABEL: TEST: test_match_ops_mixed
+@build_transform_script
+def test_match_ops_mixed(op: OpHandle):
+    op.match_ops([scf.ForOp, "linalg.matmul", scf.ForallOp])
+    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
+    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
+    # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
+    # CHECK-SAME:     -> !transform.any_op


        


More information about the Mlir-commits mailing list