[Mlir-commits] [mlir] [mlir][llvm] Expose llvm array type to CAPI and Python bindings (PR #185475)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 9 10:50:35 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Asher Mancinelli (ashermancinelli)
<details>
<summary>Changes</summary>
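This PR mostly mirrors the existing `llvm.struct` support in the C API and Python bindings, search-and-replacing `struct` with `array`: it adds `mlirTypeIsALLVMArrayType`, `mlirLLVMArrayTypeGetTypeID`, and `mlirLLVMArrayTypeGetNumElements` to the C API, plus an `llvm.ArrayType` Python class with `get`, `element_type`, and `num_elements`.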
Assisted-by: claude opus 4.6
---
Full diff: https://github.com/llvm/llvm-project/pull/185475.diff
5 Files Affected:
- (modified) mlir/include/mlir-c/Dialect/LLVM.h (+8)
- (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (+30)
- (modified) mlir/lib/CAPI/Dialect/LLVM.cpp (+12)
- (modified) mlir/test/CAPI/llvm.c (+8)
- (modified) mlir/test/python/dialects/llvm.py (+23)
``````````diff
diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index 7381519881e03..a9ac9a363064c 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -39,6 +39,11 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMVoidTypeGetName(void);
+/// Returns `true` if the type is an LLVM dialect array type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMArrayType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMArrayTypeGetTypeID(void);
+
/// Creates an llvm.array type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGet(MlirType elementType,
unsigned numElements);
@@ -48,6 +53,9 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMArrayTypeGetName(void);
/// Returns the element type of the llvm.array type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGetElementType(MlirType type);
+/// Returns the number of elements in the llvm.array type.
+MLIR_CAPI_EXPORTED unsigned mlirLLVMArrayTypeGetNumElements(MlirType type);
+
/// Creates an llvm.func type.
MLIR_CAPI_EXPORTED MlirType
mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index dc06d0a3bf671..f17371a8bb385 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -160,6 +160,35 @@ struct StructType : PyConcreteType<StructType> {
}
};
+//===--------------------------------------------------------------------===//
+// ArrayType
+//===--------------------------------------------------------------------===//
+
+struct ArrayType : PyConcreteType<ArrayType> {
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMArrayType;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirLLVMArrayTypeGetTypeID;
+ static constexpr const char *pyClassName = "ArrayType";
+ static inline const MlirStringRef name = mlirLLVMArrayTypeGetName();
+ using Base::Base;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](PyType &elementType, unsigned numElements) {
+ return ArrayType(elementType.getContext(),
+ mlirLLVMArrayTypeGet(elementType, numElements));
+ },
+ "element_type"_a, "num_elements"_a);
+ c.def_prop_ro("element_type", [](const ArrayType &type) {
+ return mlirLLVMArrayTypeGetElementType(type);
+ });
+ c.def_prop_ro("num_elements", [](const ArrayType &type) {
+ return mlirLLVMArrayTypeGetNumElements(type);
+ });
+ }
+};
+
//===--------------------------------------------------------------------===//
// PointerType
//===--------------------------------------------------------------------===//
@@ -196,6 +225,7 @@ struct PointerType : PyConcreteType<PointerType> {
static void populateDialectLLVMSubmodule(nanobind::module_ &m) {
StructType::bind(m);
+ ArrayType::bind(m);
PointerType::bind(m);
m.def(
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index e6f58d010bda5..d91c32530fb03 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -49,6 +49,14 @@ MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
MlirStringRef mlirLLVMVoidTypeGetName(void) { return wrap(LLVMVoidType::name); }
+bool mlirTypeIsALLVMArrayType(MlirType type) {
+ return isa<LLVM::LLVMArrayType>(unwrap(type));
+}
+
+MlirTypeID mlirLLVMArrayTypeGetTypeID() {
+ return wrap(LLVM::LLVMArrayType::getTypeID());
+}
+
MlirType mlirLLVMArrayTypeGet(MlirType elementType, unsigned numElements) {
return wrap(LLVMArrayType::get(unwrap(elementType), numElements));
}
@@ -61,6 +69,10 @@ MlirType mlirLLVMArrayTypeGetElementType(MlirType type) {
return wrap(cast<LLVM::LLVMArrayType>(unwrap(type)).getElementType());
}
+unsigned mlirLLVMArrayTypeGetNumElements(MlirType type) {
+ return cast<LLVM::LLVMArrayType>(unwrap(type)).getNumElements();
+}
+
MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
MlirType const *argumentTypes, bool isVarArg) {
SmallVector<Type, 2> argumentStorage;
diff --git a/mlir/test/CAPI/llvm.c b/mlir/test/CAPI/llvm.c
index f3c4cbe036d7c..057a42dbb4b14 100644
--- a/mlir/test/CAPI/llvm.c
+++ b/mlir/test/CAPI/llvm.c
@@ -59,6 +59,14 @@ static void testTypeCreation(MlirContext ctx) {
mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(i32_4_text));
// CHECK: !llvm.array<4 x i32>: 1
fprintf(stderr, "%s: %d\n", i32_4_text, mlirTypeEqual(i32_4, i32_4_ref));
+ // CHECK: array_isa: 1
+ fprintf(stderr, "array_isa: %d\n", mlirTypeIsALLVMArrayType(i32_4));
+ // CHECK: array_element_type: 1
+ fprintf(stderr, "array_element_type: %d\n",
+ mlirTypeEqual(mlirLLVMArrayTypeGetElementType(i32_4), i32));
+ // CHECK: array_num_elements: 4
+ fprintf(stderr, "array_num_elements: %u\n",
+ mlirLLVMArrayTypeGetNumElements(i32_4));
const char *i8_i32_i64_text = "!llvm.func<i8 (i32, i64)>";
const MlirType i32_i64_arr[] = {i32, i64};
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index 305ed9aba940d..da8ccf223170f 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -102,6 +102,29 @@ def testStructType():
assert isinstance(typ, llvm.StructType)
+# CHECK-LABEL: testArrayType
+@constructAndPrintInModule
+def testArrayType():
+ i32 = IntegerType.get_signless(32)
+ i8 = IntegerType.get_signless(8)
+
+ arr = llvm.ArrayType.get(i32, 4)
+ # CHECK: !llvm.array<4 x i32>
+ print(arr)
+ assert arr.element_type == i32
+ assert arr.num_elements == 4
+
+ arr2 = llvm.ArrayType.get(i8, 12)
+ # CHECK: !llvm.array<12 x i8>
+ print(arr2)
+ assert arr2.element_type == i8
+ assert arr2.num_elements == 12
+
+ typ = Type.parse("!llvm.array<4 x i32>")
+ assert isinstance(typ, llvm.ArrayType)
+ assert typ == arr
+
+
# CHECK-LABEL: testSmoke
@constructAndPrintInModule
def testSmoke():
``````````
</details>
https://github.com/llvm/llvm-project/pull/185475
More information about the Mlir-commits mailing list