[Mlir-commits] [mlir] [mlir][python] expose LLVMStructType API (PR #81672)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Wed Feb 14 05:05:53 PST 2024
https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/81672
>From 30b3321e04c21c4c806aeae715b9a831128aba13 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 14 Feb 2024 13:03:46 +0000
Subject: [PATCH] [mlir] expose LLVMStructType API
Expose the API for constructing and inspecting StructTypes from the LLVM
dialect. Separate constructor methods are used instead of overloads for
better readability, similarly to IntegerType.
---
mlir/include/mlir-c/Dialect/LLVM.h | 61 ++++++++-
mlir/lib/Bindings/Python/DialectLLVM.cpp | 145 +++++++++++++++++++++
mlir/lib/CAPI/Dialect/LLVM.cpp | 68 +++++++++-
mlir/python/CMakeLists.txt | 13 ++
mlir/python/mlir/dialects/llvm.py | 1 +
mlir/test/CAPI/llvm.c | 156 ++++++++++++++++++++++-
mlir/test/python/dialects/llvm.py | 84 ++++++++++++
7 files changed, 525 insertions(+), 3 deletions(-)
create mode 100644 mlir/lib/Bindings/Python/DialectLLVM.cpp
diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index 72701a82225436..a1bc6092b86eed 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -34,11 +34,70 @@ MLIR_CAPI_EXPORTED MlirType
mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
MlirType const *argumentTypes, bool isVarArg);
-/// Creates an LLVM literal (unnamed) struct type.
+/// Returns `true` if the type is an LLVM dialect struct type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
+
+/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);
+
+/// Returns the number of fields in the struct. Asserts if the struct is opaque
+/// or not yet initialized.
+MLIR_CAPI_EXPORTED intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type);
+
+/// Returns the `position`-th field of the struct. Asserts if the struct is
+/// opaque, not yet initialized or if the position is out of range.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeGetElementType(MlirType type,
+ intptr_t position);
+
+/// Returns `true` if the struct is packed.
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsPacked(MlirType type);
+
+/// Returns the identifier of the identified struct. Asserts that the struct is
+/// identified, i.e., not literal.
+MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type);
+
+/// Returns `true` if the struct is explicitly opaque (will not have a body) or
+/// uninitialized (will eventually have a body).
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsOpaque(MlirType type);
+
+/// Creates an LLVM literal (unnamed) struct type. This may assert if the fields
+/// have types not compatible with the LLVM dialect. For a graceful failure, use
+/// the checked version.
MLIR_CAPI_EXPORTED MlirType
mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
MlirType const *fieldTypes, bool isPacked);
+/// Creates an LLVM literal (unnamed) struct type if possible. Emits a
+/// diagnostic at the given location and returns null otherwise.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, intptr_t nFieldTypes,
+ MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an LLVM identified struct type with no body. If a struct type with
+/// this name already exists in the context, returns that type. Use
+/// mlirLLVMStructTypeIdentifiedNewGet to create a fresh struct type,
+/// potentially renaming it. The body should be set separately by calling
+/// mlirLLVMStructTypeSetBody, if it isn't set already.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx,
+ MlirStringRef name);
+
+/// Creates an LLVM identified struct type with the given fields and a name
+/// starting with the given prefix. If a struct with the exact name as the given
+/// prefix already exists, appends an unspecified suffix to the name so that the
+/// name is unique in context.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedNewGet(
+ MlirContext ctx, MlirStringRef name, intptr_t nFieldTypes,
+ MlirType const *fieldTypes, bool isPacked);
+
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx,
+ MlirStringRef name);
+
+/// Sets the body of the identified struct if it hasn't been set yet. Returns
+/// whether the operation was successful.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes,
+ MlirType const *fieldTypes, bool isPacked);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
new file mode 100644
index 00000000000000..780f5eacf0b8e5
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -0,0 +1,145 @@
+//===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/Dialect/LLVM.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include <string>
+
+namespace py = pybind11;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::adaptors;
+
+/// RAII scope intercepting all diagnostics into a string. The message must be
+/// checked before this goes out of scope.
+class CollectDiagnosticsToStringScope {
+public:
+ explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
+ handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
+ /*deleteUserData=*/nullptr);
+ }
+ ~CollectDiagnosticsToStringScope() {
+ assert(errorMessage.empty() && "unchecked error message");
+ mlirContextDetachDiagnosticHandler(context, handlerID);
+ }
+
+ [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
+
+private:
+ static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
+ auto printer = +[](MlirStringRef message, void *data) {
+ *static_cast<std::string *>(data) +=
+ StringRef(message.data, message.length);
+ };
+ mlirDiagnosticPrint(diag, printer, data);
+ return mlirLogicalResultSuccess();
+ }
+
+ MlirContext context;
+ MlirDiagnosticHandlerID handlerID;
+ std::string errorMessage = "";
+};
+
+void populateDialectLLVMSubmodule(const pybind11::module &m) {
+ auto llvmStructType =
+ mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
+
+ llvmStructType.def_classmethod(
+ "get_literal",
+ [](py::object cls, const std::vector<MlirType> &elements, bool packed,
+ MlirLocation loc) {
+ CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
+
+ MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+ loc, elements.size(), elements.data(), packed);
+ if (mlirTypeIsNull(type)) {
+ throw py::value_error(scope.takeMessage());
+ }
+ return cls(type);
+ },
+ py::arg("cls"), py::arg("elements"), py::kw_only(),
+ py::arg("packed") = false, py::arg("loc") = py::none());
+
+ llvmStructType.def_classmethod(
+ "get_identified",
+ [](py::object cls, const std::string &name, MlirContext context) {
+ return cls(mlirLLVMStructTypeIdentifiedGet(
+ context, mlirStringRefCreate(name.data(), name.size())));
+ },
+ py::arg("cls"), py::arg("name"), py::kw_only(),
+ py::arg("context") = py::none());
+
+ llvmStructType.def_classmethod(
+ "get_opaque",
+ [](py::object cls, const std::string &name, MlirContext context) {
+ return cls(mlirLLVMStructTypeOpaqueGet(
+ context, mlirStringRefCreate(name.data(), name.size())));
+ },
+ py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
+
+ llvmStructType.def(
+ "set_body",
+ [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+ MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+ self, elements.size(), elements.data(), packed);
+ if (!mlirLogicalResultIsSuccess(result)) {
+ throw py::value_error(
+ "Struct body already set to different content.");
+ }
+ },
+ py::arg("elements"), py::kw_only(), py::arg("packed") = false);
+
+ llvmStructType.def_classmethod(
+ "new_identified",
+ [](py::object cls, const std::string &name,
+ const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
+ return cls(mlirLLVMStructTypeIdentifiedNewGet(
+ ctx, mlirStringRefCreate(name.data(), name.length()),
+ elements.size(), elements.data(), packed));
+ },
+ py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
+ py::arg("packed") = false, py::arg("context") = py::none());
+
+ llvmStructType.def_property_readonly(
+ "name", [](MlirType type) -> std::optional<std::string> {
+ if (mlirLLVMStructTypeIsLiteral(type))
+ return std::nullopt;
+
+ MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+ return StringRef(stringRef.data, stringRef.length).str();
+ });
+
+ llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
+ // Don't crash in absence of a body.
+ if (mlirLLVMStructTypeIsOpaque(type))
+ return py::none();
+
+ py::list body;
+ for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
+ ++i) {
+ body.append(mlirLLVMStructTypeGetElementType(type, i));
+ }
+ return body;
+ });
+
+ llvmStructType.def_property_readonly(
+ "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
+
+ llvmStructType.def_property_readonly(
+ "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
+}
+
+PYBIND11_MODULE(_mlirDialectsLLVM, m) {
+ m.doc() = "MLIR LLVM Dialect";
+
+ populateDialectLLVMSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index b4405f7aac8ab2..642018a814ca12 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -36,11 +36,77 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
}
+bool mlirTypeIsALLVMStructType(MlirType type) {
+ return isa<LLVM::LLVMStructType>(unwrap(type));
+}
+
+bool mlirLLVMStructTypeIsLiteral(MlirType type) {
+ return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified();
+}
+
+intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type) {
+ return cast<LLVM::LLVMStructType>(unwrap(type)).getBody().size();
+}
+
+MlirType mlirLLVMStructTypeGetElementType(MlirType type, intptr_t position) {
+ return wrap(cast<LLVM::LLVMStructType>(unwrap(type)).getBody()[position]);
+}
+
+bool mlirLLVMStructTypeIsPacked(MlirType type) {
+ return cast<LLVM::LLVMStructType>(unwrap(type)).isPacked();
+}
+
+MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type) {
+ return wrap(cast<LLVM::LLVMStructType>(unwrap(type)).getName());
+}
+
+bool mlirLLVMStructTypeIsOpaque(MlirType type) {
+ return cast<LLVM::LLVMStructType>(unwrap(type)).isOpaque();
+}
+
MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
MlirType const *fieldTypes,
bool isPacked) {
- SmallVector<Type, 2> fieldStorage;
+ SmallVector<Type> fieldStorage;
return wrap(LLVMStructType::getLiteral(
unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage),
isPacked));
}
+
+MlirType mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc,
+ intptr_t nFieldTypes,
+ MlirType const *fieldTypes,
+ bool isPacked) {
+ SmallVector<Type> fieldStorage;
+ return wrap(LLVMStructType::getLiteralChecked(
+ [loc]() { return emitError(unwrap(loc)); }, unwrap(loc)->getContext(),
+ unwrapList(nFieldTypes, fieldTypes, fieldStorage), isPacked));
+}
+
+MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, MlirStringRef name) {
+ return wrap(LLVMStructType::getOpaque(unwrap(name), unwrap(ctx)));
+}
+
+MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name) {
+ return wrap(LLVMStructType::getIdentified(unwrap(ctx), unwrap(name)));
+}
+
+MlirType mlirLLVMStructTypeIdentifiedNewGet(MlirContext ctx, MlirStringRef name,
+ intptr_t nFieldTypes,
+ MlirType const *fieldTypes,
+ bool isPacked) {
+ SmallVector<Type> fields;
+ return wrap(LLVMStructType::getNewIdentified(
+ unwrap(ctx), unwrap(name), unwrapList(nFieldTypes, fieldTypes, fields),
+ isPacked));
+}
+
+MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType,
+ intptr_t nFieldTypes,
+ MlirType const *fieldTypes,
+ bool isPacked) {
+ SmallVector<Type> fields;
+ return wrap(
+ cast<LLVM::LLVMStructType>(unwrap(structType))
+ .setBody(unwrapList(nFieldTypes, fieldTypes, fields), isPacked));
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 266b86090fe174..ed167afeb69a62 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -482,6 +482,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
MLIRCAPILinalg
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
+ MODULE_NAME _mlirDialectsLLVM
+ ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ SOURCES
+ DialectLLVM.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPILLVM
+)
+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
MODULE_NAME _mlirDialectsQuant
ADD_TO_PARENT MLIRPythonSources.Dialects.quant
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 77025438c37a4f..8aa16e4a256030 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -4,3 +4,4 @@
from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
+from .._mlir_libs._mlirDialectsLLVM import *
diff --git a/mlir/test/CAPI/llvm.c b/mlir/test/CAPI/llvm.c
index aaec7b113f0a97..5a78fac91a5097 100644
--- a/mlir/test/CAPI/llvm.c
+++ b/mlir/test/CAPI/llvm.c
@@ -12,6 +12,7 @@
#include "mlir-c/Dialect/LLVM.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
#include <assert.h>
#include <math.h>
@@ -73,11 +74,164 @@ static void testTypeCreation(MlirContext ctx) {
mlirTypeEqual(i32_i64_s, i32_i64_s_ref));
}
+// CHECK-LABEL: testStructTypeCreation
+static int testStructTypeCreation(MlirContext ctx) {
+ fprintf(stderr, "testStructTypeCreation");
+
+ // CHECK: !llvm.struct<()>
+ mlirTypeDump(mlirLLVMStructTypeLiteralGet(ctx, /*nFieldTypes=*/0,
+ /*fieldTypes=*/NULL,
+ /*isPacked=*/false));
+
+ MlirType i8 = mlirIntegerTypeGet(ctx, 8);
+ MlirType i32 = mlirIntegerTypeGet(ctx, 32);
+ MlirType i64 = mlirIntegerTypeGet(ctx, 64);
+ MlirType i8_i32_i64[] = {i8, i32, i64};
+ // CHECK: !llvm.struct<(i8, i32, i64)>
+ mlirTypeDump(
+ mlirLLVMStructTypeLiteralGet(ctx, sizeof(i8_i32_i64) / sizeof(MlirType),
+ i8_i32_i64, /*isPacked=*/false));
+ // CHECK: !llvm.struct<(i32)>
+ mlirTypeDump(mlirLLVMStructTypeLiteralGet(ctx, 1, &i32, /*isPacked=*/false));
+ MlirType i32_i32[] = {i32, i32};
+ // CHECK: !llvm.struct<packed (i32, i32)>
+ mlirTypeDump(mlirLLVMStructTypeLiteralGet(
+ ctx, sizeof(i32_i32) / sizeof(MlirType), i32_i32, /*isPacked=*/true));
+
+ MlirType literal =
+ mlirLLVMStructTypeLiteralGet(ctx, sizeof(i8_i32_i64) / sizeof(MlirType),
+ i8_i32_i64, /*isPacked=*/false);
+ // CHECK: num elements: 3
+ // CHECK: i8
+ // CHECK: i32
+ // CHECK: i64
+ fprintf(stderr, "num elements: %ld\n",
+ mlirLLVMStructTypeGetNumElementTypes(literal));
+ for (intptr_t i = 0; i < 3; ++i) {
+ mlirTypeDump(mlirLLVMStructTypeGetElementType(literal, i));
+ }
+
+ if (!mlirTypeEqual(
+ mlirLLVMStructTypeLiteralGet(ctx, 1, &i32, /*isPacked=*/false),
+ mlirLLVMStructTypeLiteralGet(ctx, 1, &i32, /*isPacked=*/false))) {
+ return 1;
+ }
+ if (mlirTypeEqual(
+ mlirLLVMStructTypeLiteralGet(ctx, 1, &i32, /*isPacked=*/false),
+ mlirLLVMStructTypeLiteralGet(ctx, 1, &i64, /*isPacked=*/false))) {
+ return 2;
+ }
+
+ // CHECK: !llvm.struct<"foo", opaque>
+ // CHECK: !llvm.struct<"bar", opaque>
+ mlirTypeDump(mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("foo")));
+ mlirTypeDump(mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("bar")));
+
+ if (!mlirTypeEqual(mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("foo")),
+ mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("foo")))) {
+ return 3;
+ }
+ if (mlirTypeEqual(mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("foo")),
+ mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("bar")))) {
+ return 4;
+ }
+
+ MlirType fooStruct = mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("foo"));
+ MlirStringRef name = mlirLLVMStructTypeGetIdentifier(fooStruct);
+ if (memcmp(name.data, "foo", name.length))
+ return 5;
+ if (!mlirLLVMStructTypeIsOpaque(fooStruct))
+ return 6;
+
+ MlirType i32_i64[] = {i32, i64};
+ MlirLogicalResult result =
+ mlirLLVMStructTypeSetBody(fooStruct, sizeof(i32_i64) / sizeof(MlirType),
+ i32_i64, /*isPacked=*/false);
+ if (!mlirLogicalResultIsSuccess(result))
+ return 7;
+
+ // CHECK: !llvm.struct<"foo", (i32, i64)>
+ mlirTypeDump(fooStruct);
+ if (mlirLLVMStructTypeIsOpaque(fooStruct))
+ return 8;
+ if (mlirLLVMStructTypeIsPacked(fooStruct))
+ return 9;
+ if (!mlirTypeEqual(mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("foo")),
+ fooStruct)) {
+ return 10;
+ }
+
+ MlirType barStruct = mlirLLVMStructTypeIdentifiedGet(
+ ctx, mlirStringRefCreateFromCString("bar"));
+ result = mlirLLVMStructTypeSetBody(barStruct, 1, &i32, /*isPacked=*/true);
+ if (!mlirLogicalResultIsSuccess(result))
+ return 11;
+
+ // CHECK: !llvm.struct<"bar", packed (i32)>
+ mlirTypeDump(barStruct);
+ if (!mlirLLVMStructTypeIsPacked(barStruct))
+ return 12;
+
+ // Same body, should succeed.
+ result =
+ mlirLLVMStructTypeSetBody(fooStruct, sizeof(i32_i64) / sizeof(MlirType),
+ i32_i64, /*isPacked=*/false);
+ if (!mlirLogicalResultIsSuccess(result))
+ return 13;
+
+ // Different body, should fail.
+ result = mlirLLVMStructTypeSetBody(fooStruct, 1, &i32, /*isPacked=*/false);
+ if (mlirLogicalResultIsSuccess(result))
+ return 14;
+
+ // Packed flag differs, should fail.
+ result = mlirLLVMStructTypeSetBody(barStruct, 1, &i32, /*isPacked=*/false);
+ if (mlirLogicalResultIsSuccess(result))
+ return 15;
+
+ // Should have a different name.
+ // CHECK: !llvm.struct<"foo{{[^"]+}}
+ mlirTypeDump(mlirLLVMStructTypeIdentifiedNewGet(
+ ctx, mlirStringRefCreateFromCString("foo"), /*nFieldTypes=*/0,
+ /*fieldTypes=*/NULL, /*isPacked=*/false));
+
+ // Two freshly created "new" types must differ.
+ if (mlirTypeEqual(
+ mlirLLVMStructTypeIdentifiedNewGet(
+ ctx, mlirStringRefCreateFromCString("foo"), /*nFieldTypes=*/0,
+ /*fieldTypes=*/NULL, /*isPacked=*/false),
+ mlirLLVMStructTypeIdentifiedNewGet(
+ ctx, mlirStringRefCreateFromCString("foo"), /*nFieldTypes=*/0,
+ /*fieldTypes=*/NULL, /*isPacked=*/false))) {
+ return 16;
+ }
+
+ MlirType opaque = mlirLLVMStructTypeOpaqueGet(
+ ctx, mlirStringRefCreateFromCString("opaque"));
+ // CHECK: !llvm.struct<"opaque", opaque>
+ mlirTypeDump(opaque);
+ if (!mlirLLVMStructTypeIsOpaque(opaque))
+ return 17;
+
+ return 0;
+}
+
int main(void) {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__llvm__(), ctx);
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("llvm"));
testTypeCreation(ctx);
+ int result = testStructTypeCreation(ctx);
mlirContextDestroy(ctx);
- return 0;
+ if (result)
+ fprintf(stderr, "FAILED: code %d", result);
+ return result;
}
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index 2d207ae14eecd2..af919df723827f 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -15,6 +15,90 @@ def constructAndPrintInModule(f):
return f
+# CHECK-LABEL: testStructType
+@constructAndPrintInModule
+def testStructType():
+ print(llvm.StructType.get_literal([]))
+ # CHECK: !llvm.struct<()>
+
+ i8, i32, i64 = tuple(map(lambda x: IntegerType.get_signless(x), [8, 32, 64]))
+ print(llvm.StructType.get_literal([i8, i32, i64]))
+ print(llvm.StructType.get_literal([i32]))
+ print(llvm.StructType.get_literal([i32, i32], packed=True))
+ literal = llvm.StructType.get_literal([i8, i32, i64])
+ assert len(literal.body) == 3
+ print(*tuple(literal.body))
+ assert literal.name is None
+ # CHECK: !llvm.struct<(i8, i32, i64)>
+ # CHECK: !llvm.struct<(i32)>
+ # CHECK: !llvm.struct<packed (i32, i32)>
+ # CHECK: i8 i32 i64
+
+ assert llvm.StructType.get_literal([i32]) == llvm.StructType.get_literal([i32])
+ assert llvm.StructType.get_literal([i32]) != llvm.StructType.get_literal([i64])
+
+ print(llvm.StructType.get_identified("foo"))
+ print(llvm.StructType.get_identified("bar"))
+ # CHECK: !llvm.struct<"foo", opaque>
+ # CHECK: !llvm.struct<"bar", opaque>
+
+ assert llvm.StructType.get_identified("foo") == llvm.StructType.get_identified(
+ "foo"
+ )
+ assert llvm.StructType.get_identified("foo") != llvm.StructType.get_identified(
+ "bar"
+ )
+
+ foo_struct = llvm.StructType.get_identified("foo")
+ print(foo_struct.name)
+ print(foo_struct.body)
+ assert foo_struct.opaque
+ foo_struct.set_body([i32, i64])
+ print(*tuple(foo_struct.body))
+ print(foo_struct)
+ assert not foo_struct.packed
+ assert not foo_struct.opaque
+ assert llvm.StructType.get_identified("foo") == foo_struct
+ # CHECK: foo
+ # CHECK: None
+ # CHECK: i32 i64
+ # CHECK: !llvm.struct<"foo", (i32, i64)>
+
+ bar_struct = llvm.StructType.get_identified("bar")
+ bar_struct.set_body([i32], packed=True)
+ print(bar_struct)
+ assert bar_struct.packed
+ # CHECK: !llvm.struct<"bar", packed (i32)>
+
+ # Same body, should not raise.
+ foo_struct.set_body([i32, i64])
+
+ try:
+ foo_struct.set_body([])
+ except ValueError as e:
+ pass
+ else:
+ assert False, "expected exception not raised"
+
+ try:
+ bar_struct.set_body([i32])
+ except ValueError as e:
+ pass
+ else:
+ assert False, "expected exception not raisr"
+
+ print(llvm.StructType.new_identified("foo", []))
+ assert llvm.StructType.new_identified("foo", []) != llvm.StructType.new_identified(
+ "foo", []
+ )
+ # CHECK: !llvm.struct<"foo{{[^"]+}}
+
+ opaque = llvm.StructType.get_opaque("opaque")
+ print(opaque)
+ assert opaque.opaque
+ # CHECK: !llvm.struct<"opaque", opaque>
+
+
# CHECK-LABEL: testSmoke
@constructAndPrintInModule
def testSmoke():
More information about the Mlir-commits
mailing list