[Mlir-commits] [mlir] [mlir][python] Expose transform param types (PR #67421)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 26 05:20:10 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

This exposes the Transform dialect types `AnyParamType` and `ParamType` via the Python bindings.

---
Full diff: https://github.com/llvm/llvm-project/pull/67421.diff


4 Files Affected:

- (modified) mlir/include/mlir-c/Dialect/Transform.h (+19) 
- (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+36) 
- (modified) mlir/lib/CAPI/Dialect/Transform.cpp (+28) 
- (modified) mlir/test/python/dialects/transform.py (+10) 


``````````diff
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 954575925cc5c45..91c99b1f869f22c 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
+
 //===---------------------------------------------------------------------===//
 // AnyValueType
 //===---------------------------------------------------------------------===//
@@ -49,6 +57,17 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirTransformOperationTypeGetOperationName(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
+                                                      MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 932e40220057c13..e7d73c12d3db3d5 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
       py::arg("context") = py::none());
 
+  //===-------------------------------------------------------------------===//
+  // AnyParamType
+  //===-------------------------------------------------------------------===//
+
+  auto anyParamType =
+      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+  anyParamType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirTransformAnyParamTypeGet(ctx));
+      },
+      "Get an instance of AnyParamType in the given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+
   //===-------------------------------------------------------------------===//
   // AnyValueType
   //===-------------------------------------------------------------------===//
@@ -71,6 +85,28 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
         return py::str(operationName.data, operationName.length);
       },
       "Get the name of the payload operation accepted by the handle.");
+
+  //===-------------------------------------------------------------------===//
+  // ParamType
+  //===-------------------------------------------------------------------===//
+
+  auto paramType =
+      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+  // The parameter type is mandatory; only the context may be left implicit.
+  paramType.def_classmethod(
+      "get",
+      [](py::object cls, MlirType type, MlirContext ctx) {
+        return cls(mlirTransformParamTypeGet(ctx, type));
+      },
+      "Get an instance of ParamType for the given type in the given context.",
+      py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
+  paramType.def_property_readonly(
+      "type",
+      [](MlirType type) {
+        MlirType paramType = mlirTransformParamTypeGetType(type);
+        return paramType;
+      },
+      "Get the type this ParamType is associated with.");
 }
 
 PYBIND11_MODULE(_mlirDialectsTransform, m) {
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 5841f6783ad5f1d..3f7f8b8e2113fe4 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
   return wrap(transform::AnyOpType::get(unwrap(ctx)));
 }
 
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformAnyParamType(MlirType type) {
+  return isa<transform::AnyParamType>(unwrap(type));
+}
+
+MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
+  return wrap(transform::AnyParamType::get(unwrap(ctx)));
+}
+
 //===---------------------------------------------------------------------===//
 // AnyValueType
 //===---------------------------------------------------------------------===//
@@ -62,3 +74,19 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
 MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
   return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
 }
+
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformParamType(MlirType type) {
+  return isa<transform::ParamType>(unwrap(type));
+}
+
+MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
+  return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
+}
+
+MlirType mlirTransformParamTypeGetType(MlirType type) {
+  return wrap(cast<transform::ParamType>(unwrap(type)).getType());
+}
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 5df125694256a4e..481d7745720101d 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -22,6 +22,10 @@ def testTypes():
     any_op = transform.AnyOpType.get()
     print(any_op)
 
+    # CHECK: !transform.any_param
+    any_param = transform.AnyParamType.get()
+    print(any_param)
+
     # CHECK: !transform.any_value
     any_value = transform.AnyValueType.get()
     print(any_value)
@@ -32,6 +36,12 @@ def testTypes():
     print(concrete_op)
     print(concrete_op.operation_name)
 
+    # CHECK: !transform.param<i32>
+    # CHECK: i32
+    param = transform.ParamType.get(IntegerType.get_signless(32))
+    print(param)
+    print(param.type)
+
 
 @run
 def testSequenceOp():

``````````

</details>


https://github.com/llvm/llvm-project/pull/67421


More information about the Mlir-commits mailing list