[Mlir-commits] [mlir] [mlir][python] expose LLVMStructType API (PR #81672)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Tue Feb 13 13:52:09 PST 2024


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/81672

Expose the API for constructing and inspecting StructTypes from the LLVM dialect. Separate constructor methods are used instead of overloads for better readability, similarly to IntegerType.

>From f12fc66f9829f6de485bcfd7807470ec613daf84 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <ftynse at gmail.com>
Date: Tue, 13 Feb 2024 22:47:58 +0100
Subject: [PATCH] [mlir][python] expose LLVMStructType API

Expose the API for constructing and inspecting StructTypes from the LLVM
dialect. Separate constructor methods are used instead of overloads for better
readability, similarly to IntegerType.
---
 mlir/include/mlir-c/Dialect/LLVM.h       |  62 +++++++++-
 mlir/lib/Bindings/Python/DialectLLVM.cpp | 145 +++++++++++++++++++++++
 mlir/lib/CAPI/Dialect/LLVM.cpp           |  66 ++++++++++-
 mlir/python/CMakeLists.txt               |  13 ++
 mlir/python/mlir/dialects/llvm.py        |   1 +
 mlir/test/python/dialects/llvm.py        |  84 +++++++++++++
 6 files changed, 369 insertions(+), 2 deletions(-)
 create mode 100644 mlir/lib/Bindings/Python/DialectLLVM.cpp

diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index 72701a82225436..80170b7e48b132 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -34,11 +34,73 @@ MLIR_CAPI_EXPORTED MlirType
 mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
                         MlirType const *argumentTypes, bool isVarArg);
 
-/// Creates an LLVM literal (unnamed) struct type.
+/// Returns `true` if the type is an LLVM dialect struct type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
+
+/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);
+
+/// Returns the number of fields in the struct. Asserts if the struct is opaque
+/// or not yet initialized.
+MLIR_CAPI_EXPORTED intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type);
+
+/// Returns the `position`-th field of the struct. Asserts if the struct is
+/// opaque, not yet initialized or if the position is out of range.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeGetElementType(MlirType type,
+                                                             intptr_t position);
+
+/// Returns `true` if the struct is packed.
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsPacked(MlirType type);
+
+/// Returns the identifier of the identified struct. Asserts that the struct is
+/// identified, i.e., not literal.
+MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type);
+
+/// Returns `true` if the struct is explicitly opaque (will not have a body) or
+/// uninitialized (will eventually have a body).
+MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsOpaque(MlirType type);
+
+/// Creates an LLVM literal (unnamed) struct type. This may assert if the fields
+/// have types not compatible with the LLVM dialect. For a graceful failure, use
+/// the checked version.
 MLIR_CAPI_EXPORTED MlirType
 mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
                              MlirType const *fieldTypes, bool isPacked);
 
+/// Creates an LLVM literal (unnamed) struct type if possible. Emits a
+/// diagnostic at the given location and returns null otherwise.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, intptr_t nFieldTypes,
+                                    MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an LLVM identified struct type with no body. If a struct type with
+/// this name already exists in the context, returns that type. Use
+/// mlirLLVMStructTypeIdentifiedNewGet to create a fresh struct type,
+/// potentially renaming it. The body should be set separately by calling
+/// mlirLLVMStructTypeSetBody, if it isn't set already.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name);
+
+/// Creates an LLVM identified struct type with the given body and a name
+/// starting with the given prefix. If a struct with the exact name as the given
+/// prefix already exists, appends an unspecified suffix to the name so that the
+/// name is unique in context.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedNewGet(
+    MlirContext ctx, MlirStringRef name, intptr_t nFieldTypes,
+    MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an LLVM identified struct type with the given name that is
+/// explicitly opaque, i.e., it will not have a body.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx,
+                                                        MlirStringRef name);
+
+/// Sets the body of the identified struct if it hasn't been set yet. Returns
+/// success if the body was set or was already set to the same content, and
+/// failure otherwise (e.g., if the struct already has a different body).
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes,
+                          MlirType const *fieldTypes, bool isPacked);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
new file mode 100644
index 00000000000000..583375e1e2a13f
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -0,0 +1,193 @@
+//===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Python bindings for LLVM dialect types, compiled into the
+// `_mlirDialectsLLVM` extension module. Currently exposes `StructType`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/Dialect/LLVM.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::adaptors;
+
+namespace {
+/// Standalone RAII scope guard that invokes the stored callback when it is
+/// destroyed. We don't want to depend on the LLVM support library here to
+/// simplify build.
+template <typename FnTy>
+class OnScopeExit {
+public:
+  explicit OnScopeExit(FnTy &&fn) : callback(std::forward<FnTy>(fn)) {}
+  ~OnScopeExit() { callback(); }
+
+  // A copy would run the callback a second time, so copying is disallowed.
+  OnScopeExit(const OnScopeExit &) = delete;
+  OnScopeExit &operator=(const OnScopeExit &) = delete;
+
+private:
+  FnTy callback;
+};
+
+template <typename FnTy>
+OnScopeExit(FnTy &&fn) -> OnScopeExit<FnTy>;
+
+/// Copies the characters referenced by a C API string into a `std::string`.
+std::string toStdString(MlirStringRef ref) {
+  return std::string(ref.data, ref.length);
+}
+
+/// Diagnostic handler appending the text of each diagnostic to the
+/// `std::string` pointed to by `userData`, separating diagnostics with
+/// newlines.
+MlirLogicalResult collectDiagnosticToString(MlirDiagnostic diag,
+                                            void *userData) {
+  auto *message = static_cast<std::string *>(userData);
+  if (!message->empty())
+    message->push_back('\n');
+  mlirDiagnosticPrint(
+      diag,
+      [](MlirStringRef chunk, void *printData) {
+        static_cast<std::string *>(printData)->append(chunk.data, chunk.length);
+      },
+      userData);
+  // Mark the diagnostic as handled so it does not propagate further.
+  return mlirLogicalResultSuccess();
+}
+} // namespace
+
+/// Registers the LLVM dialect type bindings on module `m`.
+static void populateDialectLLVMSubmodule(const pybind11::module &m) {
+  auto llvmStructType =
+      mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
+
+  // get_literal(elements, *, packed=False, loc=None): literal (unnamed) struct.
+  // Invalid element types raise ValueError carrying the emitted diagnostics
+  // instead of hitting the assertion in the unchecked constructor.
+  llvmStructType.def_classmethod(
+      "get_literal",
+      [](py::object cls, const std::vector<MlirType> &elements, bool packed,
+         MlirLocation loc) {
+        std::string errorMessage;
+        MlirContext context = mlirLocationGetContext(loc);
+        MlirDiagnosticHandlerID diagID = mlirContextAttachDiagnosticHandler(
+            context, collectDiagnosticToString, &errorMessage,
+            /*deleteUserData=*/nullptr);
+        // Detach on every exit path: the handler must not outlive
+        // `errorMessage`.
+        OnScopeExit scopeGuard(
+            [&]() { mlirContextDetachDiagnosticHandler(context, diagID); });
+
+        MlirType type = mlirLLVMStructTypeLiteralGetChecked(
+            loc, elements.size(), elements.data(), packed);
+        if (mlirTypeIsNull(type))
+          throw py::value_error(errorMessage);
+        return cls(type);
+      },
+      py::arg("cls"), py::arg("elements"), py::kw_only(),
+      py::arg("packed") = false, py::arg("loc") = py::none());
+
+  // get_identified(name, *, context=None): the identified struct `name`,
+  // created without a body if it does not exist yet.
+  llvmStructType.def_classmethod(
+      "get_identified",
+      [](py::object cls, const std::string &name, MlirContext context) {
+        return cls(mlirLLVMStructTypeIdentifiedGet(
+            context, mlirStringRefCreate(name.data(), name.size())));
+      },
+      py::arg("cls"), py::arg("name"), py::kw_only(),
+      py::arg("context") = py::none());
+
+  // get_opaque(name, context=None): an explicitly opaque identified struct.
+  llvmStructType.def_classmethod(
+      "get_opaque",
+      [](py::object cls, const std::string &name, MlirContext context) {
+        return cls(mlirLLVMStructTypeOpaqueGet(
+            context, mlirStringRefCreate(name.data(), name.size())));
+      },
+      py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
+
+  // set_body(elements, *, packed=False): sets the body of an identified
+  // struct. Re-setting an identical body is accepted; anything else raises.
+  llvmStructType.def(
+      "set_body",
+      [](MlirType self, const std::vector<MlirType> &elements, bool packed) {
+        // Only identified structs have a settable body; reject literal ones
+        // up front instead of relying on the C++ implementation, which
+        // asserts on them.
+        if (mlirLLVMStructTypeIsLiteral(self))
+          throw py::value_error("Cannot set the body of a literal struct.");
+        MlirLogicalResult result = mlirLLVMStructTypeSetBody(
+            self, elements.size(), elements.data(), packed);
+        if (!mlirLogicalResultIsSuccess(result)) {
+          throw py::value_error(
+              "Struct body already set to different content.");
+        }
+      },
+      py::arg("elements"), py::kw_only(), py::arg("packed") = false);
+
+  // new_identified(name, elements, *, packed=False, context=None): a fresh
+  // identified struct with the given body; `name` gets a suffix if needed to
+  // be unique in the context.
+  llvmStructType.def_classmethod(
+      "new_identified",
+      [](py::object cls, const std::string &name,
+         const std::vector<MlirType> &elements, bool packed, MlirContext ctx) {
+        return cls(mlirLLVMStructTypeIdentifiedNewGet(
+            ctx, mlirStringRefCreate(name.data(), name.length()),
+            elements.size(), elements.data(), packed));
+      },
+      py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
+      py::arg("packed") = false, py::arg("context") = py::none());
+
+  // name: the identifier of an identified struct, None for literal structs.
+  llvmStructType.def_property_readonly(
+      "name", [](MlirType type) -> std::optional<std::string> {
+        if (mlirLLVMStructTypeIsLiteral(type))
+          return std::nullopt;
+        return toStdString(mlirLLVMStructTypeGetIdentifier(type));
+      });
+
+  // body: list of element types, or None when the struct is opaque or its
+  // body has not been set yet.
+  llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object {
+    // Element queries assert on body-less structs, so bail out first.
+    if (mlirLLVMStructTypeIsOpaque(type))
+      return py::none();
+
+    py::list body;
+    for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e;
+         ++i)
+      body.append(mlirLLVMStructTypeGetElementType(type, i));
+    return body;
+  });
+
+  // packed: whether the struct is packed.
+  llvmStructType.def_property_readonly(
+      "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); });
+
+  // opaque: whether the struct is explicitly opaque or has no body yet.
+  llvmStructType.def_property_readonly(
+      "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
+}
+
+PYBIND11_MODULE(_mlirDialectsLLVM, m) {
+  m.doc() = "MLIR LLVM Dialect";
+
+  populateDialectLLVMSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index b4405f7aac8ab2..ada9fcb46be21f 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -36,11 +36,78 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
       unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
 }
 
+bool mlirTypeIsALLVMStructType(MlirType type) {
+  return isa<LLVMStructType>(unwrap(type));
+}
+
+bool mlirLLVMStructTypeIsLiteral(MlirType type) {
+  return !cast<LLVMStructType>(unwrap(type)).isIdentified();
+}
+
+intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type) {
+  return cast<LLVMStructType>(unwrap(type)).getBody().size();
+}
+
+MlirType mlirLLVMStructTypeGetElementType(MlirType type, intptr_t position) {
+  return wrap(cast<LLVMStructType>(unwrap(type)).getBody()[position]);
+}
+
+bool mlirLLVMStructTypeIsPacked(MlirType type) {
+  return cast<LLVMStructType>(unwrap(type)).isPacked();
+}
+
+MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type) {
+  return wrap(cast<LLVMStructType>(unwrap(type)).getName());
+}
+
+bool mlirLLVMStructTypeIsOpaque(MlirType type) {
+  return cast<LLVMStructType>(unwrap(type)).isOpaque();
+}
+
 MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
                                       MlirType const *fieldTypes,
                                       bool isPacked) {
-  SmallVector<Type, 2> fieldStorage;
+  SmallVector<Type> fieldStorage;
   return wrap(LLVMStructType::getLiteral(
       unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage),
       isPacked));
 }
+
+MlirType mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc,
+                                             intptr_t nFieldTypes,
+                                             MlirType const *fieldTypes,
+                                             bool isPacked) {
+  SmallVector<Type> fieldStorage;
+  // Verification failures are reported at `loc`; a null type is returned then.
+  return wrap(LLVMStructType::getLiteralChecked(
+      [loc]() { return emitError(unwrap(loc)); }, unwrap(loc)->getContext(),
+      unwrapList(nFieldTypes, fieldTypes, fieldStorage), isPacked));
+}
+
+MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, MlirStringRef name) {
+  return wrap(LLVMStructType::getOpaque(unwrap(name), unwrap(ctx)));
+}
+
+MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name) {
+  return wrap(LLVMStructType::getIdentified(unwrap(ctx), unwrap(name)));
+}
+
+MlirType mlirLLVMStructTypeIdentifiedNewGet(MlirContext ctx, MlirStringRef name,
+                                            intptr_t nFieldTypes,
+                                            MlirType const *fieldTypes,
+                                            bool isPacked) {
+  SmallVector<Type> fields;
+  return wrap(LLVMStructType::getNewIdentified(
+      unwrap(ctx), unwrap(name), unwrapList(nFieldTypes, fieldTypes, fields),
+      isPacked));
+}
+
+MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType,
+                                            intptr_t nFieldTypes,
+                                            MlirType const *fieldTypes,
+                                            bool isPacked) {
+  SmallVector<Type> fields;
+  return wrap(
+      cast<LLVMStructType>(unwrap(structType))
+          .setBody(unwrapList(nFieldTypes, fieldTypes, fields), isPacked));
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 266b86090fe174..ed167afeb69a62 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -482,6 +482,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
     MLIRCAPILinalg
 )
 
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
+  MODULE_NAME _mlirDialectsLLVM  # Re-exported by mlir/dialects/llvm.py.
+  ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  SOURCES
+    DialectLLVM.cpp  # Pybind definitions of the LLVM dialect type bindings.
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPILLVM
+)
+
 declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
   MODULE_NAME _mlirDialectsQuant
   ADD_TO_PARENT MLIRPythonSources.Dialects.quant
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 77025438c37a4f..8aa16e4a256030 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -4,3 +4,4 @@
 
 from ._llvm_ops_gen import *
 from ._llvm_enum_gen import *
+from .._mlir_libs._mlirDialectsLLVM import *
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index 2d207ae14eecd2..af919df723827f 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -15,6 +15,90 @@ def constructAndPrintInModule(f):
     return f
 
 
+# CHECK-LABEL: testStructType
+@constructAndPrintInModule
+def testStructType():
+    print(llvm.StructType.get_literal([]))
+    # CHECK: !llvm.struct<()>
+
+    i8, i32, i64 = (IntegerType.get_signless(width) for width in (8, 32, 64))
+    print(llvm.StructType.get_literal([i8, i32, i64]))
+    print(llvm.StructType.get_literal([i32]))
+    print(llvm.StructType.get_literal([i32, i32], packed=True))
+    literal = llvm.StructType.get_literal([i8, i32, i64])
+    assert len(literal.body) == 3
+    print(*literal.body)
+    assert literal.name is None
+    # CHECK: !llvm.struct<(i8, i32, i64)>
+    # CHECK: !llvm.struct<(i32)>
+    # CHECK: !llvm.struct<packed (i32, i32)>
+    # CHECK: i8 i32 i64
+
+    assert llvm.StructType.get_literal([i32]) == llvm.StructType.get_literal([i32])
+    assert llvm.StructType.get_literal([i32]) != llvm.StructType.get_literal([i64])
+    # Identified structs are uniqued by name and start without a body.
+    print(llvm.StructType.get_identified("foo"))
+    print(llvm.StructType.get_identified("bar"))
+    # CHECK: !llvm.struct<"foo", opaque>
+    # CHECK: !llvm.struct<"bar", opaque>
+
+    assert llvm.StructType.get_identified("foo") == llvm.StructType.get_identified(
+        "foo"
+    )
+    assert llvm.StructType.get_identified("foo") != llvm.StructType.get_identified(
+        "bar"
+    )
+
+    foo_struct = llvm.StructType.get_identified("foo")
+    print(foo_struct.name)
+    print(foo_struct.body)
+    assert foo_struct.opaque
+    foo_struct.set_body([i32, i64])
+    print(*foo_struct.body)
+    print(foo_struct)
+    assert not foo_struct.packed
+    assert not foo_struct.opaque
+    assert llvm.StructType.get_identified("foo") == foo_struct
+    # CHECK: foo
+    # CHECK: None
+    # CHECK: i32 i64
+    # CHECK: !llvm.struct<"foo", (i32, i64)>
+
+    bar_struct = llvm.StructType.get_identified("bar")
+    bar_struct.set_body([i32], packed=True)
+    print(bar_struct)
+    assert bar_struct.packed
+    # CHECK: !llvm.struct<"bar", packed (i32)>
+
+    # Same body, should not raise.
+    foo_struct.set_body([i32, i64])
+
+    try:
+        foo_struct.set_body([])
+    except ValueError:
+        pass
+    else:
+        assert False, "expected exception not raised"
+
+    try:
+        bar_struct.set_body([i32])
+    except ValueError:
+        pass
+    else:
+        assert False, "expected exception not raised"
+    # new_identified always creates a distinct struct, renaming it if needed.
+    print(llvm.StructType.new_identified("foo", []))
+    assert llvm.StructType.new_identified("foo", []) != llvm.StructType.new_identified(
+        "foo", []
+    )
+    # CHECK: !llvm.struct<"foo{{[^"]+}}
+
+    opaque = llvm.StructType.get_opaque("opaque")
+    print(opaque)
+    assert opaque.opaque
+    # CHECK: !llvm.struct<"opaque", opaque>
+
+
 # CHECK-LABEL: testSmoke
 @constructAndPrintInModule
 def testSmoke():



More information about the Mlir-commits mailing list