[Mlir-commits] [mlir] 97f9f1a - [mlir][python] Expose transform param types (#67421)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 26 07:10:29 PDT 2023
Author: martin-luecke
Date: 2023-09-26T16:10:24+02:00
New Revision: 97f9f1a08ab1f5f91282cf95d13f306d03dc0888
URL: https://github.com/llvm/llvm-project/commit/97f9f1a08ab1f5f91282cf95d13f306d03dc0888
DIFF: https://github.com/llvm/llvm-project/commit/97f9f1a08ab1f5f91282cf95d13f306d03dc0888.diff
LOG: [mlir][python] Expose transform param types (#67421)
This exposes the Transform dialect types `AnyParamType` and `ParamType`
via the Python bindings.
Added:
Modified:
mlir/include/mlir-c/Dialect/Transform.h
mlir/lib/Bindings/Python/DialectTransform.cpp
mlir/lib/CAPI/Dialect/Transform.cpp
mlir/test/python/dialects/transform.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 954575925cc5c45..91c99b1f869f22c 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
+
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
@@ -49,6 +57,17 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
MLIR_CAPI_EXPORTED MlirStringRef
mlirTransformOperationTypeGetOperationName(MlirType type);
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
+ MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 932e40220057c13..cbbf8332b14ff30 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
"Get an instance of AnyOpType in the given context.", py::arg("cls"),
py::arg("context") = py::none());
+ //===-------------------------------------------------------------------===//
+ // AnyParamType
+ //===-------------------------------------------------------------------===//
+
+ auto anyParamType =
+ mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+ anyParamType.def_classmethod(
+ "get",
+ [](py::object cls, MlirContext ctx) {
+ return cls(mlirTransformAnyParamTypeGet(ctx));
+ },
+ "Get an instance of AnyParamType in the given context.", py::arg("cls"),
+ py::arg("context") = py::none());
+
//===-------------------------------------------------------------------===//
// AnyValueType
//===-------------------------------------------------------------------===//
@@ -71,6 +85,27 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
return py::str(operationName.data, operationName.length);
},
"Get the name of the payload operation accepted by the handle.");
+
+ //===-------------------------------------------------------------------===//
+ // ParamType
+ //===-------------------------------------------------------------------===//
+
+ auto paramType =
+ mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+ paramType.def_classmethod(
+ "get",
+ [](py::object cls, MlirType type, MlirContext ctx) {
+ return cls(mlirTransformParamTypeGet(ctx, type));
+ },
+ "Get an instance of ParamType for the given type in the given context.",
+ py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
+ paramType.def_property_readonly(
+ "type",
+ [](MlirType type) {
+ MlirType paramType = mlirTransformParamTypeGetType(type);
+ return paramType;
+ },
+ "Get the type this ParamType is associated with.");
}
PYBIND11_MODULE(_mlirDialectsTransform, m) {
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 5841f6783ad5f1d..3f7f8b8e2113fe4 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformAnyParamType(MlirType type) {
+ return isa<transform::AnyParamType>(unwrap(type));
+}
+
+MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
+ return wrap(transform::AnyParamType::get(unwrap(ctx)));
+}
+
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
@@ -62,3 +74,19 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
}
+
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformParamType(MlirType type) {
+ return isa<transform::ParamType>(unwrap(type));
+}
+
+MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
+ return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
+}
+
+MlirType mlirTransformParamTypeGetType(MlirType type) {
+ return wrap(cast<transform::ParamType>(unwrap(type)).getType());
+}
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 5df125694256a4e..481d7745720101d 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -22,6 +22,10 @@ def testTypes():
any_op = transform.AnyOpType.get()
print(any_op)
+ # CHECK: !transform.any_param
+ any_param = transform.AnyParamType.get()
+ print(any_param)
+
# CHECK: !transform.any_value
any_value = transform.AnyValueType.get()
print(any_value)
@@ -32,6 +36,12 @@ def testTypes():
print(concrete_op)
print(concrete_op.operation_name)
+ # CHECK: !transform.param<i32>
+ # CHECK: i32
+ param = transform.ParamType.get(IntegerType.get_signless(32))
+ print(param)
+ print(param.type)
+
@run
def testSequenceOp():
More information about the Mlir-commits
mailing list