[Mlir-commits] [mlir] [mlir][python] expose LLVMStructType API (PR #81672)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 13 13:52:44 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Oleksandr "Alex" Zinenko (ftynse)
<details>
<summary>Changes</summary>
Expose the API for constructing and inspecting StructTypes from the LLVM dialect. Separate constructor methods are used instead of overloads for better readability, similarly to IntegerType.
---
Full diff: https://github.com/llvm/llvm-project/pull/81672.diff
6 Files Affected:
- (modified) mlir/include/mlir-c/Dialect/LLVM.h (+61-1)
- (added) mlir/lib/Bindings/Python/DialectLLVM.cpp (+145)
- (modified) mlir/lib/CAPI/Dialect/LLVM.cpp (+65-1)
- (modified) mlir/python/CMakeLists.txt (+13)
- (modified) mlir/python/mlir/dialects/llvm.py (+1)
- (modified) mlir/test/python/dialects/llvm.py (+84)
``````````diff
diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index 72701a82225436..80170b7e48b132 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -34,11 +34,71 @@ MLIR_CAPI_EXPORTED MlirType
mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
MlirType const *argumentTypes, bool isVarArg);
-/// Creates an LLVM literal (unnamed) struct type.
+/// Returns `true` if the type is an LLVM dialect struct type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
+
+/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);
+
+/// Returns the number of fields in the struct. Asserts if the struct is opaque
+/// or not yet initialized.
+MLIR_CAPI_EXPORTED intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type);
+
+/// Returns the `position`-th field of the struct. Asserts if the struct is
+/// opaque, not yet initialized or if the position is out of range.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeGetElementType(MlirType type,
+                                                             intptr_t position);
+
+/// Returns `true` if the struct is packed.
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsPacked(MlirType type);
+
+/// Returns the identifier of the identified struct. Asserts that the struct is
+/// identified, i.e., not literal.
+MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type);
+
+/// Returns `true` if the struct is explicitly opaque (will not have a body) or
+/// uninitialized (will eventually have a body).
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsOpaque(MlirType type);
+
+/// Creates an LLVM literal (unnamed) struct type. This may assert if the fields
+/// have types not compatible with the LLVM dialect. For a graceful failure, use
+/// the checked version.
MLIR_CAPI_EXPORTED MlirType
mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
MlirType const *fieldTypes, bool isPacked);
+/// Creates an LLVM literal (unnamed) struct type if possible. Emits a
+/// diagnostic at the given location and returns null otherwise.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, intptr_t nFieldTypes,
+                                    MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an LLVM identified struct type with no body. If a struct type with
+/// this name already exists in the context, returns that type. Use
+/// mlirLLVMStructTypeIdentifiedNewGet to create a fresh struct type,
+/// potentially renaming it. The body should be set separately by calling
+/// mlirLLVMStructTypeSetBody, if it isn't set already.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name);
+
+/// Creates a fresh LLVM identified struct type whose body is set to the given
+/// field types and packedness, with a name starting with the given prefix. If
+/// a struct with the exact name as the given prefix already exists, appends an
+/// unspecified suffix to the name so that the name is unique in context.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedNewGet(
+    MlirContext ctx, MlirStringRef name, intptr_t nFieldTypes,
+    MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an explicitly opaque LLVM identified struct type with the given
+/// name, i.e., a struct that will not have a body.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx,
+                                                        MlirStringRef name);
+
+/// Sets the body of the identified struct if it hasn't been set yet. Returns
+/// success if the body was set by this call or was already set to the same
+/// content, and failure otherwise (e.g., if a different body is already set).
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes,
+                          MlirType const *fieldTypes, bool isPacked);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
new file mode 100644
index 00000000000000..583375e1e2a13f
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -0,0 +1,145 @@
+//===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/Dialect/LLVM.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace py = pybind11;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::adaptors;
+
+namespace {
+/// Minimal standalone RAII scope guard: stores a callable and invokes it when
+/// the guard is destroyed. Copying is disabled so that the callable can never
+/// be invoked twice by two copies of the same guard.
+template <typename FnTy>
+class OnScopeExit {
+public:
+  explicit OnScopeExit(FnTy &&fn) : callback(std::forward<FnTy>(fn)) {}
+  ~OnScopeExit() { callback(); }
+
+  // A copy would run `callback` a second time on destruction; declaring the
+  // copy operations deleted also suppresses the implicit move operations.
+  OnScopeExit(const OnScopeExit &) = delete;
+  OnScopeExit &operator=(const OnScopeExit &) = delete;
+
+private:
+  FnTy callback;
+};
+
+/// Deduction guide so that `OnScopeExit guard([&] { ... });` deduces `FnTy`.
+template <typename FnTy>
+OnScopeExit(FnTy &&fn) -> OnScopeExit<FnTy>;
+} // namespace
+
+/// Registers the Python bindings for LLVM dialect types (currently only
+/// `StructType`) on the given extension module.
+static void populateDialectLLVMSubmodule(const pybind11::module &m) {
+  auto llvmStructType =
+      mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
+
+  // get_literal goes through the checked C API so that invalid element types
+  // surface as a Python ValueError carrying the diagnostic text.
+  llvmStructType.def_classmethod(
+      "get_literal",
+      [](py::object cls, const std::vector<MlirType> &elements, bool packed,
+         MlirLocation loc) {
+        // Text of every diagnostic emitted while the handler below is
+        // attached, one diagnostic per line.
+        std::string errorMessage;
+        auto handler = +[](MlirDiagnostic diag, void *data) {
+          auto *accumulated = static_cast<std::string *>(data);
+          // Separate consecutive diagnostics instead of gluing them together.
+          if (!accumulated->empty())
+            accumulated->push_back('\n');
+          auto printer = +[](MlirStringRef chunk, void *userData) {
+            static_cast<std::string *>(userData)->append(chunk.data,
+                                                         chunk.length);
+          };
+          mlirDiagnosticPrint(diag, printer, data);
+          // Report the diagnostic as handled; its text is collected above.
+          return mlirLogicalResultSuccess();
+        };
+
+        MlirContext context = mlirLocationGetContext(loc);
+        MlirDiagnosticHandlerID diagID = mlirContextAttachDiagnosticHandler(
+            context, handler, &errorMessage, /*deleteUserData=*/nullptr);
+        // Detach the handler on every exit path, including the throw below.
+        OnScopeExit scopeGuard(
+            [&]() { mlirContextDetachDiagnosticHandler(context, diagID); });
+
+        MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+            loc, elements.size(), elements.data(), packed);
+        if (mlirTypeIsNull(type))
+          throw py::value_error(errorMessage);
+        return cls(type);
+      },
+      py::arg("cls"), py::arg("elements"), py::kw_only(),
+      py::arg("packed") = false, py::arg("loc") = py::none());
+
+  llvmStructType.def_classmethod(
+      "get_identified",
+      [](py::object cls, const std::string &name, MlirContext context) {
+        return cls(mlirLLVMStructTypeIdentifiedGet(
+            context, mlirStringRefCreate(name.data(), name.size())));
+      },
+      py::arg("cls"), py::arg("name"), py::kw_only(),
+      py::arg("context") = py::none());
+
+  // NOTE(review): unlike get_identified, `context` is not keyword-only here;
+  // kept positional-capable to avoid breaking callers.
+  llvmStructType.def_classmethod(
+      "get_opaque",
+      [](py::object cls, const std::string &name, MlirContext context) {
+        return cls(mlirLLVMStructTypeOpaqueGet(
+            context, mlirStringRefCreate(name.data(), name.size())));
+      },
+      py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
+
+  llvmStructType.def(
+      "set_body",
+      [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+        MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+            self, elements.size(), elements.data(), packed);
+        if (!mlirLogicalResultIsSuccess(result)) {
+          throw py::value_error(
+              "Struct body already set to different content.");
+        }
+      },
+      py::arg("elements"), py::kw_only(), py::arg("packed") = false);
+
+  llvmStructType.def_classmethod(
+      "new_identified",
+      [](py::object cls, const std::string &name,
+         const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
+        return cls(mlirLLVMStructTypeIdentifiedNewGet(
+            ctx, mlirStringRefCreate(name.data(), name.length()),
+            elements.size(), elements.data(), packed));
+      },
+      py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
+      py::arg("packed") = false, py::arg("context") = py::none());
+
+  // Literal structs have no name; report None for them.
+  llvmStructType.def_property_readonly(
+      "name", [](MlirType type) -> std::optional<std::string> {
+        if (mlirLLVMStructTypeIsLiteral(type))
+          return std::nullopt;
+
+        MlirStringRef identifier = mlirLLVMStructTypeGetIdentifier(type);
+        return std::string(identifier.data, identifier.length);
+      });
+
+  llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
+    // Querying the fields of an opaque/uninitialized struct asserts in the C
+    // API, so report None instead.
+    if (mlirLLVMStructTypeIsOpaque(type))
+      return py::none();
+
+    py::list body;
+    intptr_t numFields = mlirLLVMStructTypeGetNumElementTypes(type);
+    for (intptr_t i = 0; i < numFields; ++i)
+      body.append(mlirLLVMStructTypeGetElementType(type, i));
+    return body;
+  });
+
+  llvmStructType.def_property_readonly(
+      "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
+
+  llvmStructType.def_property_readonly(
+      "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
+}
+
+PYBIND11_MODULE(_mlirDialectsLLVM, m) {
+  m.doc() = "MLIR LLVM Dialect";
+
+  populateDialectLLVMSubmodule(m);
+}
\ No newline at end of file
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index b4405f7aac8ab2..ada9fcb46be21f 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -36,11 +36,75 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
}
+/// Views `type` as an LLVM dialect struct type. `type` must be one; this is
+/// checked by `cast` in builds with assertions enabled.
+static LLVM::LLVMStructType unwrapStructType(MlirType type) {
+  return cast<LLVM::LLVMStructType>(unwrap(type));
+}
+
+bool mlirTypeIsALLVMStructType(MlirType type) {
+  return isa<LLVM::LLVMStructType>(unwrap(type));
+}
+
+bool mlirLLVMStructTypeIsLiteral(MlirType type) {
+  const bool identified = unwrapStructType(type).isIdentified();
+  return !identified;
+}
+
+intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type) {
+  auto fields = unwrapStructType(type).getBody();
+  return fields.size();
+}
+
+MlirType mlirLLVMStructTypeGetElementType(MlirType type, intptr_t position) {
+  auto fields = unwrapStructType(type).getBody();
+  return wrap(fields[position]);
+}
+
+bool mlirLLVMStructTypeIsPacked(MlirType type) {
+  return unwrapStructType(type).isPacked();
+}
+
+MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type) {
+  return wrap(unwrapStructType(type).getName());
+}
+
+bool mlirLLVMStructTypeIsOpaque(MlirType type) {
+  return unwrapStructType(type).isOpaque();
+}
+
MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
MlirType const *fieldTypes,
bool isPacked) {
- SmallVector<Type, 2> fieldStorage;
+ SmallVector<Type> fieldStorage;
return wrap(LLVMStructType::getLiteral(
unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage),
isPacked));
}
+
+MlirType mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc,
+                                             intptr_t nFieldTypes,
+                                             MlirType const *fieldTypes,
+                                             bool isPacked) {
+  auto location = unwrap(loc);
+  SmallVector<Type> storage;
+  auto elements = unwrapList(nFieldTypes, fieldTypes, storage);
+  // Invoked by the verifier only when the element types are rejected.
+  auto emitDiagnostic = [location]() { return emitError(location); };
+  return wrap(LLVMStructType::getLiteralChecked(
+      emitDiagnostic, location->getContext(), elements, isPacked));
+}
+
+MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, MlirStringRef name) {
+  auto *context = unwrap(ctx);
+  return wrap(LLVMStructType::getOpaque(unwrap(name), context));
+}
+
+MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name) {
+  auto *context = unwrap(ctx);
+  return wrap(LLVMStructType::getIdentified(context, unwrap(name)));
+}
+
+MlirType mlirLLVMStructTypeIdentifiedNewGet(MlirContext ctx, MlirStringRef name,
+                                            intptr_t nFieldTypes,
+                                            MlirType const *fieldTypes,
+                                            bool isPacked) {
+  SmallVector<Type> storage;
+  auto elements = unwrapList(nFieldTypes, fieldTypes, storage);
+  return wrap(LLVMStructType::getNewIdentified(unwrap(ctx), unwrap(name),
+                                               elements, isPacked));
+}
+
+MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType,
+                                            intptr_t nFieldTypes,
+                                            MlirType const *fieldTypes,
+                                            bool isPacked) {
+  SmallVector<Type> storage;
+  auto elements = unwrapList(nFieldTypes, fieldTypes, storage);
+  auto status = unwrapStructType(structType).setBody(elements, isPacked);
+  return wrap(status);
+}
\ No newline at end of file
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 266b86090fe174..ed167afeb69a62 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -482,6 +482,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
MLIRCAPILinalg
)
+# Pybind extension `_mlirDialectsLLVM`: Python bindings for LLVM dialect types
+# (such as StructType), built on the embedded MLIR C API libraries below.
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
+  MODULE_NAME _mlirDialectsLLVM
+  ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  SOURCES
+    DialectLLVM.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPILLVM
+)
+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
MODULE_NAME _mlirDialectsQuant
ADD_TO_PARENT MLIRPythonSources.Dialects.quant
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 77025438c37a4f..8aa16e4a256030 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -4,3 +4,4 @@
from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
+from .._mlir_libs._mlirDialectsLLVM import *
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index 2d207ae14eecd2..af919df723827f 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -15,6 +15,90 @@ def constructAndPrintInModule(f):
return f
+# CHECK-LABEL: testStructType
+ at constructAndPrintInModule
+def testStructType():
+ print(llvm.StructType.get_literal([]))
+ # CHECK: !llvm.struct<()>
+
+ i8, i32, i64 = tuple(map(lambda x: IntegerType.get_signless(x), [8, 32, 64]))
+ print(llvm.StructType.get_literal([i8, i32, i64]))
+ print(llvm.StructType.get_literal([i32]))
+ print(llvm.StructType.get_literal([i32, i32], packed=True))
+ literal = llvm.StructType.get_literal([i8, i32, i64])
+ assert len(literal.body) == 3
+ print(*tuple(literal.body))
+ assert literal.name is None
+ # CHECK: !llvm.struct<(i8, i32, i64)>
+ # CHECK: !llvm.struct<(i32)>
+ # CHECK: !llvm.struct<packed (i32, i32)>
+ # CHECK: i8 i32 i64
+
+ assert llvm.StructType.get_literal([i32]) == llvm.StructType.get_literal([i32])
+ assert llvm.StructType.get_literal([i32]) != llvm.StructType.get_literal([i64])
+
+ print(llvm.StructType.get_identified("foo"))
+ print(llvm.StructType.get_identified("bar"))
+ # CHECK: !llvm.struct<"foo", opaque>
+ # CHECK: !llvm.struct<"bar", opaque>
+
+ assert llvm.StructType.get_identified("foo") == llvm.StructType.get_identified(
+ "foo"
+ )
+ assert llvm.StructType.get_identified("foo") != llvm.StructType.get_identified(
+ "bar"
+ )
+
+ foo_struct = llvm.StructType.get_identified("foo")
+ print(foo_struct.name)
+ print(foo_struct.body)
+ assert foo_struct.opaque
+ foo_struct.set_body([i32, i64])
+ print(*tuple(foo_struct.body))
+ print(foo_struct)
+ assert not foo_struct.packed
+ assert not foo_struct.opaque
+ assert llvm.StructType.get_identified("foo") == foo_struct
+ # CHECK: foo
+ # CHECK: None
+ # CHECK: i32 i64
+ # CHECK: !llvm.struct<"foo", (i32, i64)>
+
+ bar_struct = llvm.StructType.get_identified("bar")
+ bar_struct.set_body([i32], packed=True)
+ print(bar_struct)
+ assert bar_struct.packed
+ # CHECK: !llvm.struct<"bar", packed (i32)>
+
+ # Same body, should not raise.
+ foo_struct.set_body([i32, i64])
+
+ try:
+ foo_struct.set_body([])
+ except ValueError as e:
+ pass
+ else:
+ assert False, "expected exception not raised"
+
+ try:
+ bar_struct.set_body([i32])
+ except ValueError as e:
+ pass
+ else:
+ assert False, "expected exception not raisr"
+
+ print(llvm.StructType.new_identified("foo", []))
+ assert llvm.StructType.new_identified("foo", []) != llvm.StructType.new_identified(
+ "foo", []
+ )
+ # CHECK: !llvm.struct<"foo{{[^"]+}}
+
+ opaque = llvm.StructType.get_opaque("opaque")
+ print(opaque)
+ assert opaque.opaque
+ # CHECK: !llvm.struct<"opaque", opaque>
+
+
# CHECK-LABEL: testSmoke
@constructAndPrintInModule
def testSmoke():
``````````
</details>
https://github.com/llvm/llvm-project/pull/81672
More information about the Mlir-commits mailing list