[Mlir-commits] [mlir] 97f9f1a - [mlir][python] Expose transform param types (#67421)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 26 07:10:29 PDT 2023


Author: martin-luecke
Date: 2023-09-26T16:10:24+02:00
New Revision: 97f9f1a08ab1f5f91282cf95d13f306d03dc0888

URL: https://github.com/llvm/llvm-project/commit/97f9f1a08ab1f5f91282cf95d13f306d03dc0888
DIFF: https://github.com/llvm/llvm-project/commit/97f9f1a08ab1f5f91282cf95d13f306d03dc0888.diff

LOG: [mlir][python] Expose transform param types (#67421)

This exposes the Transform dialect types `AnyParamType` and `ParamType`
via the Python bindings.

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/Transform.h
    mlir/lib/Bindings/Python/DialectTransform.cpp
    mlir/lib/CAPI/Dialect/Transform.cpp
    mlir/test/python/dialects/transform.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 954575925cc5c45..91c99b1f869f22c 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
+
 //===---------------------------------------------------------------------===//
 // AnyValueType
 //===---------------------------------------------------------------------===//
@@ -49,6 +57,17 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirTransformOperationTypeGetOperationName(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
+                                                      MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 932e40220057c13..cbbf8332b14ff30 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
       py::arg("context") = py::none());
 
+  //===-------------------------------------------------------------------===//
+  // AnyParamType
+  //===-------------------------------------------------------------------===//
+
+  auto anyParamType =
+      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+  anyParamType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirTransformAnyParamTypeGet(ctx));
+      },
+      "Get an instance of AnyParamType in the given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+
   //===-------------------------------------------------------------------===//
   // AnyValueType
   //===-------------------------------------------------------------------===//
@@ -71,6 +85,27 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
         return py::str(operationName.data, operationName.length);
       },
       "Get the name of the payload operation accepted by the handle.");
+
+  //===-------------------------------------------------------------------===//
+  // ParamType
+  //===-------------------------------------------------------------------===//
+
+  auto paramType =
+      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+  paramType.def_classmethod(
+      "get",
+      [](py::object cls, MlirType type, MlirContext ctx) {
+        return cls(mlirTransformParamTypeGet(ctx, type));
+      },
+      "Get an instance of ParamType for the given type in the given context.",
+      py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
+  paramType.def_property_readonly(
+      "type",
+      [](MlirType type) {
+        MlirType paramType = mlirTransformParamTypeGetType(type);
+        return paramType;
+      },
+      "Get the type this ParamType is associated with.");
 }
 
 PYBIND11_MODULE(_mlirDialectsTransform, m) {

diff  --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 5841f6783ad5f1d..3f7f8b8e2113fe4 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
   return wrap(transform::AnyOpType::get(unwrap(ctx)));
 }
 
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformAnyParamType(MlirType type) {
+  return isa<transform::AnyParamType>(unwrap(type));
+}
+
+MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
+  return wrap(transform::AnyParamType::get(unwrap(ctx)));
+}
+
 //===---------------------------------------------------------------------===//
 // AnyValueType
 //===---------------------------------------------------------------------===//
@@ -62,3 +74,19 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
 MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
   return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
 }
+
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformParamType(MlirType type) {
+  return isa<transform::ParamType>(unwrap(type));
+}
+
+MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
+  return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
+}
+
+MlirType mlirTransformParamTypeGetType(MlirType type) {
+  return wrap(cast<transform::ParamType>(unwrap(type)).getType());
+}

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 5df125694256a4e..481d7745720101d 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -22,6 +22,10 @@ def testTypes():
     any_op = transform.AnyOpType.get()
     print(any_op)
 
+    # CHECK: !transform.any_param
+    any_param = transform.AnyParamType.get()
+    print(any_param)
+
     # CHECK: !transform.any_value
     any_value = transform.AnyValueType.get()
     print(any_value)
@@ -32,6 +36,12 @@ def testTypes():
     print(concrete_op)
     print(concrete_op.operation_name)
 
+    # CHECK: !transform.param<i32>
+    # CHECK: i32
+    param = transform.ParamType.get(IntegerType.get_signless(32))
+    print(param)
+    print(param.type)
+
 
 @run
 def testSequenceOp():


        


More information about the Mlir-commits mailing list