[Mlir-commits] [mlir] [mlir][python] extend LLVM bindings (PR #89797)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 23 10:12:34 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/89797.diff
6 Files Affected:
- (modified) mlir/include/mlir-c/Dialect/LLVM.h (+7)
- (modified) mlir/lib/Bindings/Python/DialectLLVM.cpp (+35-8)
- (modified) mlir/lib/CAPI/Dialect/LLVM.cpp (+8)
- (modified) mlir/python/mlir/dialects/LLVMOps.td (+1)
- (modified) mlir/python/mlir/dialects/llvm.py (+8)
- (modified) mlir/test/python/dialects/llvm.py (+43)
``````````diff
diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index bd9b7dd26f5e9e..b3e64bd68f7b1c 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -23,6 +23,13 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
unsigned addressSpace);
+/// Returns `true` if the type is an LLVM dialect pointer type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);
+
+/// Returns address space of llvm.ptr
+MLIR_CAPI_EXPORTED unsigned
+mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType);
+
/// Creates an llmv.void type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 843707751dd849..42a4c8c0793ba8 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -19,6 +19,11 @@ using namespace mlir::python;
using namespace mlir::python::adaptors;
void populateDialectLLVMSubmodule(const pybind11::module &m) {
+
+ //===--------------------------------------------------------------------===//
+ // StructType
+ //===--------------------------------------------------------------------===//
+
auto llvmStructType =
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
@@ -35,8 +40,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
}
return cls(type);
},
- py::arg("cls"), py::arg("elements"), py::kw_only(),
- py::arg("packed") = false, py::arg("loc") = py::none());
+ "cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
+ "loc"_a = py::none());
llvmStructType.def_classmethod(
"get_identified",
@@ -44,8 +49,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
return cls(mlirLLVMStructTypeIdentifiedGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
- py::arg("cls"), py::arg("name"), py::kw_only(),
- py::arg("context") = py::none());
+ "cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
llvmStructType.def_classmethod(
"get_opaque",
@@ -53,7 +57,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
return cls(mlirLLVMStructTypeOpaqueGet(
context, mlirStringRefCreate(name.data(), name.size())));
},
- py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
+ "cls"_a, "name"_a, "context"_a = py::none());
llvmStructType.def(
"set_body",
@@ -65,7 +69,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
"Struct body already set to different content.");
}
},
- py::arg("elements"), py::kw_only(), py::arg("packed") = false);
+ "elements"_a, py::kw_only(), "packed"_a = false);
llvmStructType.def_classmethod(
"new_identified",
@@ -75,8 +79,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
ctx, mlirStringRefCreate(name.data(), name.length()),
elements.size(), elements.data(), packed));
},
- py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
- py::arg("packed") = false, py::arg("context") = py::none());
+ "cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
+ "context"_a = py::none());
llvmStructType.def_property_readonly(
"name", [](MlirType type) -> std::optional<std::string> {
@@ -105,6 +109,29 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
llvmStructType.def_property_readonly(
"opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
+
+ //===--------------------------------------------------------------------===//
+ // PointerType
+ //===--------------------------------------------------------------------===//
+
+ mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
+ .def_classmethod(
+ "get",
+ [](py::object cls, std::optional<unsigned> addressSpace,
+ MlirContext context) {
+ CollectDiagnosticsToStringScope scope(context);
+ MlirType type = mlirLLVMPointerTypeGet(
+ context, addressSpace.has_value() ? *addressSpace : 0);
+ if (mlirTypeIsNull(type)) {
+ throw py::value_error(scope.takeMessage());
+ }
+ return cls(type);
+ },
+ "cls"_a, "address_space"_a = py::none(), py::kw_only(),
+ "context"_a = py::none())
+ .def_property_readonly("address_space", [](MlirType type) {
+ return mlirLLVMPointerTypeGetAddressSpace(type);
+ });
}
PYBIND11_MODULE(_mlirDialectsLLVM, m) {
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index 4669c40f843d94..cd817539bb83a0 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -27,6 +27,14 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
}
+bool mlirTypeIsALLVMPointerType(MlirType type) {
+ return isa<LLVM::LLVMPointerType>(unwrap(type));
+}
+
+unsigned mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType) {
+ return cast<LLVM::LLVMPointerType>(unwrap(pointerType)).getAddressSpace();
+}
+
MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
return wrap(LLVMVoidType::get(unwrap(ctx)));
}
diff --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td
index dcf2f4245cf49f..30f047f21698e3 100644
--- a/mlir/python/mlir/dialects/LLVMOps.td
+++ b/mlir/python/mlir/dialects/LLVMOps.td
@@ -10,5 +10,6 @@
#define PYTHON_BINDINGS_LLVM_OPS
include "mlir/Dialect/LLVMIR/LLVMOps.td"
+include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td"
#endif
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 8aa16e4a256030..941a584966dcde 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -5,3 +5,11 @@
from ._llvm_ops_gen import *
from ._llvm_enum_gen import *
from .._mlir_libs._mlirDialectsLLVM import *
+from ..ir import Value
+from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
+
+
+def mlir_constant(value, *, loc=None, ip=None) -> Value:
+ return _get_op_result_or_op_results(
+ ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
+ )
diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index fb4b343b170bae..d9ffdeb65bfd40 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -107,3 +107,46 @@ def testSmoke():
)
result = llvm.UndefOp(mat64f32_t)
# CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+
+
+# CHECK-LABEL: testPointerType
+@constructAndPrintInModule
+def testPointerType():
+ ptr = llvm.PointerType.get()
+ # CHECK: !llvm.ptr
+ print(ptr)
+
+ ptr_with_addr = llvm.PointerType.get(1)
+ # CHECK: !llvm.ptr<1>
+ print(ptr_with_addr)
+
+
+# CHECK-LABEL: testConstant
+@constructAndPrintInModule
+def testConstant():
+ i32 = IntegerType.get_signless(32)
+ c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
+ # CHECK: %{{.*}} = llvm.mlir.constant(128 : i32) : i32
+ print(c_128.owner)
+
+
+# CHECK-LABEL: testIntrinsics
+@constructAndPrintInModule
+def testIntrinsics():
+ i32 = IntegerType.get_signless(32)
+ ptr = llvm.PointerType.get()
+ c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
+ # CHECK: %[[CST128:.*]] = llvm.mlir.constant(128 : i32) : i32
+ print(c_128.owner)
+
+ alloca = llvm.alloca(ptr, c_128, i32)
+ # CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CST128]] x i32 : (i32) -> !llvm.ptr
+ print(alloca.owner)
+
+ c_0 = llvm.mlir_constant(IntegerAttr.get(IntegerType.get_signless(8), 0))
+ # CHECK: %[[CST0:.+]] = llvm.mlir.constant(0 : i8) : i8
+ print(c_0.owner)
+
+ result = llvm.intr_memset(alloca, c_0, c_128, False)
+ # CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[CST0]], %[[CST128]]) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+ print(result)
``````````
</details>
https://github.com/llvm/llvm-project/pull/89797
More information about the Mlir-commits mailing list