[Mlir-commits] [mlir] [mlir][python] Expose transform param types (PR #67421)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 26 06:06:47 PDT 2023


https://github.com/martin-luecke updated https://github.com/llvm/llvm-project/pull/67421

>From 19c558573162a30e01e09c311a071f2b3f2c3ccc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20L=C3=BCcke?= <martin.luecke at ed.ac.uk>
Date: Tue, 26 Sep 2023 13:06:08 +0000
Subject: [PATCH] [mlir][python] Expose transform param types

---
 mlir/include/mlir-c/Dialect/Transform.h       | 19 ++++++++++
 mlir/lib/Bindings/Python/DialectTransform.cpp | 35 +++++++++++++++++++
 mlir/lib/CAPI/Dialect/Transform.cpp           | 28 +++++++++++++++
 mlir/test/python/dialects/transform.py        | 10 ++++++
 4 files changed, 92 insertions(+)

diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 954575925cc5c45..91c99b1f869f22c 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
+
 //===---------------------------------------------------------------------===//
 // AnyValueType
 //===---------------------------------------------------------------------===//
@@ -49,6 +57,17 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirTransformOperationTypeGetOperationName(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
+                                                      MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 932e40220057c13..cbbf8332b14ff30 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
       py::arg("context") = py::none());
 
+  //===-------------------------------------------------------------------===//
+  // AnyParamType
+  //===-------------------------------------------------------------------===//
+
+  auto anyParamType =
+      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+  anyParamType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirTransformAnyParamTypeGet(ctx));
+      },
+      "Get an instance of AnyParamType in the given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+
   //===-------------------------------------------------------------------===//
   // AnyValueType
   //===-------------------------------------------------------------------===//
@@ -71,6 +85,27 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
         return py::str(operationName.data, operationName.length);
       },
       "Get the name of the payload operation accepted by the handle.");
+
+  //===-------------------------------------------------------------------===//
+  // ParamType
+  //===-------------------------------------------------------------------===//
+
+  auto paramType =
+      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+  paramType.def_classmethod(
+      "get",
+      [](py::object cls, MlirType type, MlirContext ctx) {
+        return cls(mlirTransformParamTypeGet(ctx, type));
+      },
+      "Get an instance of ParamType for the given type in the given context.",
+      py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
+  paramType.def_property_readonly(
+      "type",
+      [](MlirType type) {
+        MlirType paramType = mlirTransformParamTypeGetType(type);
+        return paramType;
+      },
+      "Get the type this ParamType is associated with.");
 }
 
 PYBIND11_MODULE(_mlirDialectsTransform, m) {
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 5841f6783ad5f1d..3f7f8b8e2113fe4 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
   return wrap(transform::AnyOpType::get(unwrap(ctx)));
 }
 
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformAnyParamType(MlirType type) {
+  return isa<transform::AnyParamType>(unwrap(type));
+}
+
+MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
+  return wrap(transform::AnyParamType::get(unwrap(ctx)));
+}
+
 //===---------------------------------------------------------------------===//
 // AnyValueType
 //===---------------------------------------------------------------------===//
@@ -62,3 +74,19 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
 MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
   return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
 }
+
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformParamType(MlirType type) {
+  return isa<transform::ParamType>(unwrap(type));
+}
+
+MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
+  return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
+}
+
+MlirType mlirTransformParamTypeGetType(MlirType type) {
+  return wrap(cast<transform::ParamType>(unwrap(type)).getType());
+}
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 5df125694256a4e..481d7745720101d 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -22,6 +22,10 @@ def testTypes():
     any_op = transform.AnyOpType.get()
     print(any_op)
 
+    # CHECK: !transform.any_param
+    any_param = transform.AnyParamType.get()
+    print(any_param)
+
     # CHECK: !transform.any_value
     any_value = transform.AnyValueType.get()
     print(any_value)
@@ -32,6 +36,12 @@ def testTypes():
     print(concrete_op)
     print(concrete_op.operation_name)
 
+    # CHECK: !transform.param<i32>
+    # CHECK: i32
+    param = transform.ParamType.get(IntegerType.get_signless(32))
+    print(param)
+    print(param.type)
+
 
 @run
 def testSequenceOp():



More information about the Mlir-commits mailing list