[Mlir-commits] [mlir] [mlir][python] Expose transform param types (PR #67421)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 26 05:18:51 PDT 2023
https://github.com/martin-luecke created https://github.com/llvm/llvm-project/pull/67421
This exposes the Transform dialect types `AnyParamType` and `ParamType` through the C API and the Python bindings.
>From 943a59972364fec02098f7b144220c58fbd8ff10 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martin=20L=C3=BCcke?= <martin.luecke at ed.ac.uk>
Date: Tue, 26 Sep 2023 12:15:46 +0000
Subject: [PATCH] [mlir][python] Expose transform param type
---
mlir/include/mlir-c/Dialect/Transform.h | 19 ++++++++++
mlir/lib/Bindings/Python/DialectTransform.cpp | 36 +++++++++++++++++++
mlir/lib/CAPI/Dialect/Transform.cpp | 28 +++++++++++++++
mlir/test/python/dialects/transform.py | 10 ++++++
4 files changed, 93 insertions(+)
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 954575925cc5c45..91c99b1f869f22c 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -27,6 +27,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+/// Checks whether the given type is a transform::AnyParamType.
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
+
+/// Creates a transform::AnyParamType in the given context.
+MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
+
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
@@ -49,6 +59,22 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
MLIR_CAPI_EXPORTED MlirStringRef
mlirTransformOperationTypeGetOperationName(MlirType type);
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+/// Checks whether the given type is a transform::ParamType.
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
+
+/// Creates a transform::ParamType associated with the given `type` in the
+/// given context, e.g. i32 yields !transform.param<i32>.
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
+ MlirType type);
+
+/// Returns the type the given transform::ParamType is associated with.
+/// `type` must be a transform::ParamType.
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 932e40220057c13..e7d73c12d3db3d5 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
"Get an instance of AnyOpType in the given context.", py::arg("cls"),
py::arg("context") = py::none());
+ //===-------------------------------------------------------------------===//
+ // AnyParamType
+ //===-------------------------------------------------------------------===//
+
+ auto anyParamType =
+ mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+ anyParamType.def_classmethod(
+ "get",
+ [](py::object cls, MlirContext ctx) {
+ return cls(mlirTransformAnyParamTypeGet(ctx));
+ },
+ "Get an instance of AnyParamType in the given context.", py::arg("cls"),
+ py::arg("context") = py::none());
+
//===-------------------------------------------------------------------===//
// AnyValueType
//===-------------------------------------------------------------------===//
@@ -71,6 +85,27 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
return py::str(operationName.data, operationName.length);
},
"Get the name of the payload operation accepted by the handle.");
+
+ //===-------------------------------------------------------------------===//
+ // ParamType
+ //===-------------------------------------------------------------------===//
+
+ auto paramType =
+ mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+ paramType.def_classmethod(
+ "get",
+ [](py::object cls, MlirType type, MlirContext ctx) {
+ return cls(mlirTransformParamTypeGet(ctx, type));
+ },
+ "Get an instance of ParamType for the given type in the given context.",
+ // `type` is required: None cannot be converted to an MlirType, so a None
+ // default would only advertise an argument that cannot be omitted.
+ py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
+ paramType.def_property_readonly(
+ "type",
+ // Unwraps the associated type, e.g. i32 for !transform.param<i32>.
+ [](MlirType type) { return mlirTransformParamTypeGetType(type); },
+ "Get the type this ParamType is associated with.");
}
PYBIND11_MODULE(_mlirDialectsTransform, m) {
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 5841f6783ad5f1d..3f7f8b8e2113fe4 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformAnyParamType(MlirType type) {
+ return isa<transform::AnyParamType>(unwrap(type));
+}
+
+MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
+ return wrap(transform::AnyParamType::get(unwrap(ctx)));
+}
+
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
@@ -62,3 +74,20 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
}
+
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformParamType(MlirType type) {
+ return isa<transform::ParamType>(unwrap(type));
+}
+
+MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
+ return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
+}
+
+MlirType mlirTransformParamTypeGetType(MlirType type) {
+ // Precondition: `type` is a ParamType; `cast` asserts on a mismatch.
+ return wrap(cast<transform::ParamType>(unwrap(type)).getType());
+}
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 5df125694256a4e..481d7745720101d 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -22,6 +22,10 @@ def testTypes():
any_op = transform.AnyOpType.get()
print(any_op)
+ # CHECK: !transform.any_param
+ any_param = transform.AnyParamType.get()
+ print(any_param)
+
# CHECK: !transform.any_value
any_value = transform.AnyValueType.get()
print(any_value)
@@ -32,6 +36,19 @@ def testTypes():
print(concrete_op)
print(concrete_op.operation_name)
+ # CHECK: !transform.param<i32>
+ # CHECK: i32
+ param = transform.ParamType.get(IntegerType.get_signless(32))
+ print(param)
+ print(param.type)
+
+ # Exercise the C API IsA predicates; ParamType and AnyParamType are
+ # distinct types, so only the former matches.
+ # CHECK: True
+ # CHECK: False
+ print(transform.ParamType.isinstance(param))
+ print(transform.AnyParamType.isinstance(param))
+
@run
def testSequenceOp():
More information about the Mlir-commits
mailing list