[Mlir-commits] [mlir] [mlir][python] expose LLVMStructType API (PR #81672)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Feb 14 05:07:33 PST 2024


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/81672

>From 78d2433f94fc71412ee468a6ab03c927c15c42a6 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Wed, 14 Feb 2024 13:03:46 +0000
Subject: [PATCH] [mlir] expose LLVMStructType API

Expose the API for constructing and inspecting StructTypes from the LLVM
dialect. Separate constructor methods are used instead of overloads for
better readability, similarly to IntegerType.
---
 mlir/include/mlir-c/Dialect/LLVM.h       |  61 ++++++++-
 mlir/lib/Bindings/Python/DialectLLVM.cpp | 145 +++++++++++++++++++++
 mlir/lib/CAPI/Dialect/LLVM.cpp           |  68 +++++++++-
 mlir/python/CMakeLists.txt               |  13 ++
 mlir/python/mlir/dialects/llvm.py        |   1 +
 mlir/test/CAPI/llvm.c                    | 156 ++++++++++++++++++++++-
 mlir/test/python/dialects/llvm.py        |  84 ++++++++++++
 7 files changed, 525 insertions(+), 3 deletions(-)
 create mode 100644 mlir/lib/Bindings/Python/DialectLLVM.cpp

diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index 72701a82225436..ac216b01f364d4 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -34,11 +34,73 @@ MLIR_CAPI_EXPORTED MlirType
 mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
                         MlirType const *argumentTypes, bool isVarArg);
 
-/// Creates an LLVM literal (unnamed) struct type.
+/// Returns `true` if the type is an LLVM dialect struct type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
+
+/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);
+
+/// Returns the number of fields in the struct. Asserts if the struct is opaque
+/// or not yet initialized.
+MLIR_CAPI_EXPORTED intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type);
+
+/// Returns the `position`-th field of the struct. Asserts if the struct is
+/// opaque, not yet initialized or if the position is out of range.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeGetElementType(MlirType type,
+                                                             intptr_t position);
+
+/// Returns `true` if the struct is packed.
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsPacked(MlirType type);
+
+/// Returns the identifier of the identified struct. Asserts that the struct is
+/// identified, i.e., not literal.
+MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type);
+
+/// Returns `true` if the struct is explicitly opaque (will not have a body) or
+/// uninitialized (will eventually have a body).
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsOpaque(MlirType type);
+
+/// Creates an LLVM literal (unnamed) struct type. This may assert if the fields
+/// have types not compatible with the LLVM dialect. For a graceful failure, use
+/// the checked version.
 MLIR_CAPI_EXPORTED MlirType
 mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
                              MlirType const *fieldTypes, bool isPacked);
 
+/// Creates an LLVM literal (unnamed) struct type if possible. Emits a
+/// diagnostic at the given location and returns null otherwise.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, intptr_t nFieldTypes,
+                                    MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an LLVM identified struct type with no body. If a struct type with
+/// this name already exists in the context, returns that type. Use
+/// mlirLLVMStructTypeIdentifiedNewGet to create a fresh struct type,
+/// potentially renaming it. The body should be set separately by calling
+/// mlirLLVMStructTypeSetBody, if it isn't set already.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx,
+                                                            MlirStringRef name);
+
+/// Creates an LLVM identified struct type with the given body and a name
+/// starting with the given prefix. If a struct with the exact name as the
+/// given prefix already exists, appends an unspecified suffix to the name so
+/// that the name is unique in context.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedNewGet(
+    MlirContext ctx, MlirStringRef name, intptr_t nFieldTypes,
+    MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an explicitly opaque LLVM identified struct type with the given
+/// name, i.e., a struct that will not have a body.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx,
+                                                        MlirStringRef name);
+
+/// Sets the body of the identified struct if it hasn't been set yet. Setting
+/// a body identical to the existing one also succeeds. Returns whether the
+/// operation was successful.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes,
+                          MlirType const *fieldTypes, bool isPacked);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
new file mode 100644
index 00000000000000..780f5eacf0b8e5
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -0,0 +1,181 @@
+//===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/Dialect/LLVM.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include <cassert>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace py = pybind11;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::adaptors;
+
+namespace {
+
+/// RAII scope intercepting all diagnostics into a string. The message must be
+/// checked before this goes out of scope.
+///
+/// The address of `errorMessage` is registered as the diagnostic handler's
+/// user data, so the scope must not be copied or moved.
+class CollectDiagnosticsToStringScope {
+public:
+  explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) {
+    handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage,
+                                                   /*deleteUserData=*/nullptr);
+  }
+  ~CollectDiagnosticsToStringScope() {
+    assert(errorMessage.empty() && "unchecked error message");
+    mlirContextDetachDiagnosticHandler(context, handlerID);
+  }
+
+  CollectDiagnosticsToStringScope(const CollectDiagnosticsToStringScope &) =
+      delete;
+  CollectDiagnosticsToStringScope &
+  operator=(const CollectDiagnosticsToStringScope &) = delete;
+
+  /// Returns the collected message and empties the buffer, as the destructor
+  /// requires. `std::exchange` guarantees the emptiness that a moved-from
+  /// string does not.
+  [[nodiscard]] std::string takeMessage() {
+    return std::exchange(errorMessage, std::string());
+  }
+
+private:
+  /// Appends the text of `diag` to the string pointed to by `data`, putting a
+  /// newline between consecutive diagnostics.
+  static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
+    auto *message = static_cast<std::string *>(data);
+    if (!message->empty())
+      message->push_back('\n');
+    auto printer = +[](MlirStringRef chunk, void *userData) {
+      *static_cast<std::string *>(userData) +=
+          StringRef(chunk.data, chunk.length);
+    };
+    mlirDiagnosticPrint(diag, printer, data);
+    // Success marks the diagnostic as handled.
+    return mlirLogicalResultSuccess();
+  }
+
+  MlirContext context;
+  MlirDiagnosticHandlerID handlerID;
+  std::string errorMessage;
+};
+
+} // namespace
+
+/// Registers the `StructType` type subclass and its methods and properties on
+/// module `m`.
+static void populateDialectLLVMSubmodule(const pybind11::module &m) {
+  auto llvmStructType =
+      mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
+
+  // Invalid element types raise ValueError carrying the diagnostic text.
+  llvmStructType.def_classmethod(
+      "get_literal",
+      [](py::object cls, const std::vector<MlirType> &elements, bool packed,
+         MlirLocation loc) {
+        CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
+
+        MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+            loc, elements.size(), elements.data(), packed);
+        if (mlirTypeIsNull(type)) {
+          throw py::value_error(scope.takeMessage());
+        }
+        return cls(type);
+      },
+      py::arg("cls"), py::arg("elements"), py::kw_only(),
+      py::arg("packed") = false, py::arg("loc") = py::none());
+
+  // Returns the existing identified struct with this name, if any.
+  llvmStructType.def_classmethod(
+      "get_identified",
+      [](py::object cls, const std::string &name, MlirContext context) {
+        return cls(mlirLLVMStructTypeIdentifiedGet(
+            context, mlirStringRefCreate(name.data(), name.size())));
+      },
+      py::arg("cls"), py::arg("name"), py::kw_only(),
+      py::arg("context") = py::none());
+
+  // Unlike `get_identified`, `context` is not keyword-only here.
+  llvmStructType.def_classmethod(
+      "get_opaque",
+      [](py::object cls, const std::string &name, MlirContext context) {
+        return cls(mlirLLVMStructTypeOpaqueGet(
+            context, mlirStringRefCreate(name.data(), name.size())));
+      },
+      py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
+
+  // Raises ValueError if a different body has already been set.
+  llvmStructType.def(
+      "set_body",
+      [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+        MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+            self, elements.size(), elements.data(), packed);
+        if (!mlirLogicalResultIsSuccess(result)) {
+          throw py::value_error(
+              "Struct body already set to different content.");
+        }
+      },
+      py::arg("elements"), py::kw_only(), py::arg("packed") = false);
+
+  // Always creates a fresh struct; the name may get a suffix to stay unique.
+  llvmStructType.def_classmethod(
+      "new_identified",
+      [](py::object cls, const std::string &name,
+         const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
+        return cls(mlirLLVMStructTypeIdentifiedNewGet(
+            ctx, mlirStringRefCreate(name.data(), name.length()),
+            elements.size(), elements.data(), packed));
+      },
+      py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
+      py::arg("packed") = false, py::arg("context") = py::none());
+
+  // `None` for literal (unnamed) structs.
+  llvmStructType.def_property_readonly(
+      "name", [](MlirType type) -> std::optional<std::string> {
+        if (mlirLLVMStructTypeIsLiteral(type))
+          return std::nullopt;
+
+        MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type);
+        return StringRef(stringRef.data, stringRef.length).str();
+      });
+
+  // List of element types, or `None` for opaque/uninitialized structs.
+  llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
+    // Don't crash in absence of a body.
+    if (mlirLLVMStructTypeIsOpaque(type))
+      return py::none();
+
+    py::list body;
+    for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
+         ++i) {
+      body.append(mlirLLVMStructTypeGetElementType(type, i));
+    }
+    return body;
+  });
+
+  llvmStructType.def_property_readonly(
+      "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
+
+  llvmStructType.def_property_readonly(
+      "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
+}
+
+PYBIND11_MODULE(_mlirDialectsLLVM, m) {
+  m.doc() = "MLIR LLVM Dialect";
+
+  populateDialectLLVMSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index b4405f7aac8ab2..642018a814ca12 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -36,11 +36,84 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
       unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
 }
 
+/// Unwraps `type` as an LLVM dialect struct type. The type must be a struct
+/// type; this is only checked in builds with assertions enabled.
+static LLVM::LLVMStructType unwrapStructType(MlirType type) {
+  return cast<LLVM::LLVMStructType>(unwrap(type));
+}
+
+bool mlirTypeIsALLVMStructType(MlirType type) {
+  return isa<LLVM::LLVMStructType>(unwrap(type));
+}
+
+bool mlirLLVMStructTypeIsLiteral(MlirType type) {
+  return !unwrapStructType(type).isIdentified();
+}
+
+intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type) {
+  return unwrapStructType(type).getBody().size();
+}
+
+MlirType mlirLLVMStructTypeGetElementType(MlirType type, intptr_t position) {
+  return wrap(unwrapStructType(type).getBody()[position]);
+}
+
+bool mlirLLVMStructTypeIsPacked(MlirType type) {
+  return unwrapStructType(type).isPacked();
+}
+
+MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type) {
+  return wrap(unwrapStructType(type).getName());
+}
+
+bool mlirLLVMStructTypeIsOpaque(MlirType type) {
+  return unwrapStructType(type).isOpaque();
+}
+
 MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
                                       MlirType const *fieldTypes,
                                       bool isPacked) {
-  SmallVector<Type, 2> fieldStorage;
+  SmallVector<Type> fieldStorage;
   return wrap(LLVMStructType::getLiteral(
       unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage),
       isPacked));
 }
+
+MlirType mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc,
+                                             intptr_t nFieldTypes,
+                                             MlirType const *fieldTypes,
+                                             bool isPacked) {
+  Location location = unwrap(loc);
+  SmallVector<Type> fieldStorage;
+  auto fields = unwrapList(nFieldTypes, fieldTypes, fieldStorage);
+  return wrap(LLVMStructType::getLiteralChecked(
+      [location]() { return emitError(location); }, location->getContext(),
+      fields, isPacked));
+}
+
+MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, MlirStringRef name) {
+  return wrap(LLVMStructType::getOpaque(unwrap(name), unwrap(ctx)));
+}
+
+MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name) {
+  return wrap(LLVMStructType::getIdentified(unwrap(ctx), unwrap(name)));
+}
+
+MlirType mlirLLVMStructTypeIdentifiedNewGet(MlirContext ctx, MlirStringRef name,
+                                            intptr_t nFieldTypes,
+                                            MlirType const *fieldTypes,
+                                            bool isPacked) {
+  SmallVector<Type> fieldStorage;
+  auto fields = unwrapList(nFieldTypes, fieldTypes, fieldStorage);
+  return wrap(LLVMStructType::getNewIdentified(unwrap(ctx), unwrap(name),
+                                               fields, isPacked));
+}
+
+MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType,
+                                            intptr_t nFieldTypes,
+                                            MlirType const *fieldTypes,
+                                            bool isPacked) {
+  SmallVector<Type> fieldStorage;
+  auto fields = unwrapList(nFieldTypes, fieldTypes, fieldStorage);
+  return wrap(unwrapStructType(structType).setBody(fields, isPacked));
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 266b86090fe174..ed167afeb69a62 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -482,6 +482,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
     MLIRCAPILinalg
 )
 
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
+  MODULE_NAME _mlirDialectsLLVM  # Imported by mlir/dialects/llvm.py.
+  ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  SOURCES
+    DialectLLVM.cpp  # Defines the llvm.StructType Python class.
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPILLVM  # LLVM dialect C API, incl. the struct type functions.
+)
+
 declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
   MODULE_NAME _mlirDialectsQuant
   ADD_TO_PARENT MLIRPythonSources.Dialects.quant
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 77025438c37a4f..8aa16e4a256030 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -4,3 +4,4 @@
 
 from ._llvm_ops_gen import *
 from ._llvm_enum_gen import *
+from .._mlir_libs._mlirDialectsLLVM import *  # Pybind types, e.g. StructType.
diff --git a/mlir/test/CAPI/llvm.c b/mlir/test/CAPI/llvm.c
index aaec7b113f0a97..5a78fac91a5097 100644
--- a/mlir/test/CAPI/llvm.c
+++ b/mlir/test/CAPI/llvm.c
@@ -12,6 +12,7 @@
 #include "mlir-c/Dialect/LLVM.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
 
 #include <assert.h>
 #include <math.h>
@@ -73,11 +74,164 @@ static void testTypeCreation(MlirContext ctx) {
           mlirTypeEqual(i32_i64_s, i32_i64_s_ref));
 }
 
+// CHECK-LABEL: testStructTypeCreation
+static int testStructTypeCreation(MlirContext ctx) {
+  fprintf(stderr, "testStructTypeCreation\n");
+
+  // CHECK: !llvm.struct<()>
+  mlirTypeDump(mlirLLVMStructTypeLiteralGet(ctx, /*nFieldTypes=*/0,
+                                            /*fieldTypes=*/NULL,
+                                            /*isPacked=*/false));
+
+  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
+  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
+  MlirType i64 = mlirIntegerTypeGet(ctx, 64);
+  MlirType i8_i32_i64[] = {i8, i32, i64};
+  // CHECK: !llvm.struct<(i8, i32, i64)>
+  mlirTypeDump(
+      mlirLLVMStructTypeLiteralGet(ctx, sizeof(i8_i32_i64) / sizeof(MlirType),
+                                   i8_i32_i64, /*isPacked=*/false));
+  // CHECK: !llvm.struct<(i32)>
+  mlirTypeDump(mlirLLVMStructTypeLiteralGet(ctx, 1, &i32, /*isPacked=*/false));
+  MlirType i32_i32[] = {i32, i32};
+  // CHECK: !llvm.struct<packed (i32, i32)>
+  mlirTypeDump(mlirLLVMStructTypeLiteralGet(
+      ctx, sizeof(i32_i32) / sizeof(MlirType), i32_i32, /*isPacked=*/true));
+
+  MlirType literal =
+      mlirLLVMStructTypeLiteralGet(ctx, sizeof(i8_i32_i64) / sizeof(MlirType),
+                                   i8_i32_i64, /*isPacked=*/false);
+  // CHECK: num elements: 3
+  // CHECK: i8
+  // CHECK: i32
+  // CHECK: i64
+  intptr_t numElements = mlirLLVMStructTypeGetNumElementTypes(literal);
+  fprintf(stderr, "num elements: %ld\n", (long)numElements);
+  for (intptr_t i = 0; i < numElements; ++i) {
+    mlirTypeDump(mlirLLVMStructTypeGetElementType(literal, i));
+  }
+
+  if (!mlirTypeEqual(
+          mlirLLVMStructTypeLiteralGet(ctx, 1, &i32, /*isPacked=*/false),
+          mlirLLVMStructTypeLiteralGet(ctx, 1, &i32, /*isPacked=*/false))) {
+    return 1;
+  }
+  if (mlirTypeEqual(
+          mlirLLVMStructTypeLiteralGet(ctx, 1, &i32, /*isPacked=*/false),
+          mlirLLVMStructTypeLiteralGet(ctx, 1, &i64, /*isPacked=*/false))) {
+    return 2;
+  }
+
+  // CHECK: !llvm.struct<"foo", opaque>
+  // CHECK: !llvm.struct<"bar", opaque>
+  mlirTypeDump(mlirLLVMStructTypeIdentifiedGet(
+      ctx, mlirStringRefCreateFromCString("foo")));
+  mlirTypeDump(mlirLLVMStructTypeIdentifiedGet(
+      ctx, mlirStringRefCreateFromCString("bar")));
+
+  if (!mlirTypeEqual(mlirLLVMStructTypeIdentifiedGet(
+                         ctx, mlirStringRefCreateFromCString("foo")),
+                     mlirLLVMStructTypeIdentifiedGet(
+                         ctx, mlirStringRefCreateFromCString("foo")))) {
+    return 3;
+  }
+  if (mlirTypeEqual(mlirLLVMStructTypeIdentifiedGet(
+                        ctx, mlirStringRefCreateFromCString("foo")),
+                    mlirLLVMStructTypeIdentifiedGet(
+                        ctx, mlirStringRefCreateFromCString("bar")))) {
+    return 4;
+  }
+
+  MlirType fooStruct = mlirLLVMStructTypeIdentifiedGet(
+      ctx, mlirStringRefCreateFromCString("foo"));
+  MlirStringRef name = mlirLLVMStructTypeGetIdentifier(fooStruct);
+  if (name.length != 3 || memcmp(name.data, "foo", name.length))
+    return 5;
+  if (!mlirLLVMStructTypeIsOpaque(fooStruct))
+    return 6;
+
+  MlirType i32_i64[] = {i32, i64};
+  MlirLogicalResult result =
+      mlirLLVMStructTypeSetBody(fooStruct, sizeof(i32_i64) / sizeof(MlirType),
+                                i32_i64, /*isPacked=*/false);
+  if (!mlirLogicalResultIsSuccess(result))
+    return 7;
+
+  // CHECK: !llvm.struct<"foo", (i32, i64)>
+  mlirTypeDump(fooStruct);
+  if (mlirLLVMStructTypeIsOpaque(fooStruct))
+    return 8;
+  if (mlirLLVMStructTypeIsPacked(fooStruct))
+    return 9;
+  if (!mlirTypeEqual(mlirLLVMStructTypeIdentifiedGet(
+                         ctx, mlirStringRefCreateFromCString("foo")),
+                     fooStruct)) {
+    return 10;
+  }
+
+  MlirType barStruct = mlirLLVMStructTypeIdentifiedGet(
+      ctx, mlirStringRefCreateFromCString("bar"));
+  result = mlirLLVMStructTypeSetBody(barStruct, 1, &i32, /*isPacked=*/true);
+  if (!mlirLogicalResultIsSuccess(result))
+    return 11;
+
+  // CHECK: !llvm.struct<"bar", packed (i32)>
+  mlirTypeDump(barStruct);
+  if (!mlirLLVMStructTypeIsPacked(barStruct))
+    return 12;
+
+  // Same body, should succeed.
+  result =
+      mlirLLVMStructTypeSetBody(fooStruct, sizeof(i32_i64) / sizeof(MlirType),
+                                i32_i64, /*isPacked=*/false);
+  if (!mlirLogicalResultIsSuccess(result))
+    return 13;
+
+  // Different body, should fail.
+  result = mlirLLVMStructTypeSetBody(fooStruct, 1, &i32, /*isPacked=*/false);
+  if (mlirLogicalResultIsSuccess(result))
+    return 14;
+
+  // Packed flag differs, should fail.
+  result = mlirLLVMStructTypeSetBody(barStruct, 1, &i32, /*isPacked=*/false);
+  if (mlirLogicalResultIsSuccess(result))
+    return 15;
+
+  // Should have a different name.
+  // CHECK: !llvm.struct<"foo{{[^"]+}}
+  mlirTypeDump(mlirLLVMStructTypeIdentifiedNewGet(
+      ctx, mlirStringRefCreateFromCString("foo"), /*nFieldTypes=*/0,
+      /*fieldTypes=*/NULL, /*isPacked=*/false));
+
+  // Two freshly created "new" types must differ.
+  if (mlirTypeEqual(
+          mlirLLVMStructTypeIdentifiedNewGet(
+              ctx, mlirStringRefCreateFromCString("foo"), /*nFieldTypes=*/0,
+              /*fieldTypes=*/NULL, /*isPacked=*/false),
+          mlirLLVMStructTypeIdentifiedNewGet(
+              ctx, mlirStringRefCreateFromCString("foo"), /*nFieldTypes=*/0,
+              /*fieldTypes=*/NULL, /*isPacked=*/false))) {
+    return 16;
+  }
+
+  MlirType opaque = mlirLLVMStructTypeOpaqueGet(
+      ctx, mlirStringRefCreateFromCString("opaque"));
+  // CHECK: !llvm.struct<"opaque", opaque>
+  mlirTypeDump(opaque);
+  if (!mlirLLVMStructTypeIsOpaque(opaque))
+    return 17;
+
+  return 0;
+}
+
 int main(void) {
   MlirContext ctx = mlirContextCreate();
   mlirDialectHandleRegisterDialect(mlirGetDialectHandle__llvm__(), ctx);
   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("llvm"));
   testTypeCreation(ctx);
+  int result = testStructTypeCreation(ctx);
   mlirContextDestroy(ctx);
-  return 0;
+  if (result)
+    fprintf(stderr, "FAILED: code %d\n", result);
+  return result;
 }
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index 2d207ae14eecd2..fb4b343b170bae 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -15,6 +15,90 @@ def constructAndPrintInModule(f):
     return f
 
 
+# CHECK-LABEL: testStructType
+@constructAndPrintInModule
+def testStructType():
+    print(llvm.StructType.get_literal([]))
+    # CHECK: !llvm.struct<()>
+
+    i8, i32, i64 = map(IntegerType.get_signless, [8, 32, 64])
+    print(llvm.StructType.get_literal([i8, i32, i64]))
+    print(llvm.StructType.get_literal([i32]))
+    print(llvm.StructType.get_literal([i32, i32], packed=True))
+    literal = llvm.StructType.get_literal([i8, i32, i64])
+    assert len(literal.body) == 3
+    print(*literal.body)
+    assert literal.name is None
+    # CHECK: !llvm.struct<(i8, i32, i64)>
+    # CHECK: !llvm.struct<(i32)>
+    # CHECK: !llvm.struct<packed (i32, i32)>
+    # CHECK: i8 i32 i64
+
+    assert llvm.StructType.get_literal([i32]) == llvm.StructType.get_literal([i32])
+    assert llvm.StructType.get_literal([i32]) != llvm.StructType.get_literal([i64])
+
+    print(llvm.StructType.get_identified("foo"))
+    print(llvm.StructType.get_identified("bar"))
+    # CHECK: !llvm.struct<"foo", opaque>
+    # CHECK: !llvm.struct<"bar", opaque>
+
+    assert llvm.StructType.get_identified("foo") == llvm.StructType.get_identified(
+        "foo"
+    )
+    assert llvm.StructType.get_identified("foo") != llvm.StructType.get_identified(
+        "bar"
+    )
+
+    foo_struct = llvm.StructType.get_identified("foo")
+    print(foo_struct.name)
+    print(foo_struct.body)
+    assert foo_struct.opaque
+    foo_struct.set_body([i32, i64])
+    print(*foo_struct.body)
+    print(foo_struct)
+    assert not foo_struct.packed
+    assert not foo_struct.opaque
+    assert llvm.StructType.get_identified("foo") == foo_struct
+    # CHECK: foo
+    # CHECK: None
+    # CHECK: i32 i64
+    # CHECK: !llvm.struct<"foo", (i32, i64)>
+
+    bar_struct = llvm.StructType.get_identified("bar")
+    bar_struct.set_body([i32], packed=True)
+    print(bar_struct)
+    assert bar_struct.packed
+    # CHECK: !llvm.struct<"bar", packed (i32)>
+
+    # Same body, should not raise.
+    foo_struct.set_body([i32, i64])
+
+    try:
+        foo_struct.set_body([])
+    except ValueError:
+        pass
+    else:
+        assert False, "expected exception not raised"
+
+    try:
+        bar_struct.set_body([i32])
+    except ValueError:
+        pass
+    else:
+        assert False, "expected exception not raised"
+
+    print(llvm.StructType.new_identified("foo", []))
+    assert llvm.StructType.new_identified("foo", []) != llvm.StructType.new_identified(
+        "foo", []
+    )
+    # CHECK: !llvm.struct<"foo{{[^"]+}}
+
+    opaque = llvm.StructType.get_opaque("opaque")
+    print(opaque)
+    assert opaque.opaque
+    # CHECK: !llvm.struct<"opaque", opaque>
+
+
 # CHECK-LABEL: testSmoke
 @constructAndPrintInModule
 def testSmoke():



More information about the Mlir-commits mailing list