[Mlir-commits] [mlir] 79d4d16 - [mlir][python] extend LLVM bindings (#89797)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 24 05:43:09 PDT 2024


Author: Maksim Levental
Date: 2024-04-24T07:43:05-05:00
New Revision: 79d4d165638b7587937fc60431e0865fd73c9334

URL: https://github.com/llvm/llvm-project/commit/79d4d165638b7587937fc60431e0865fd73c9334
DIFF: https://github.com/llvm/llvm-project/commit/79d4d165638b7587937fc60431e0865fd73c9334.diff

LOG: [mlir][python] extend LLVM bindings (#89797)

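Add Python bindings for the LLVM dialect pointer type (`llvm.PointerType`),
expose the LLVM intrinsic ops (e.g. `llvm.intr_memset`) to Python, and add an
`llvm.mlir_constant` helper.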

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/LLVM.h
    mlir/lib/Bindings/Python/DialectLLVM.cpp
    mlir/lib/CAPI/Dialect/LLVM.cpp
    mlir/python/mlir/dialects/LLVMOps.td
    mlir/python/mlir/dialects/llvm.py
    mlir/test/python/dialects/llvm.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
index bd9b7dd26f5e9e..b3e64bd68f7b1c 100644
--- a/mlir/include/mlir-c/Dialect/LLVM.h
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -23,6 +23,13 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
 MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
                                                    unsigned addressSpace);
 
+/// Returns `true` if the type is an LLVM dialect pointer type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);
+
+/// Returns the address space of the given `llvm.ptr` type. `pointerType` must
+/// be an LLVM dialect pointer type.
+MLIR_CAPI_EXPORTED unsigned
+mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType);
+
 /// Creates an llvm.void type.
 MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
 

diff  --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 843707751dd849..42a4c8c0793ba8 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -19,6 +19,11 @@ using namespace mlir::python;
 using namespace mlir::python::adaptors;
 
 void populateDialectLLVMSubmodule(const pybind11::module &m) {
+
+  //===--------------------------------------------------------------------===//
+  // StructType
+  //===--------------------------------------------------------------------===//
+
   auto llvmStructType =
       mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
 
@@ -35,8 +40,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
         }
         return cls(type);
       },
-      py::arg("cls"), py::arg("elements"), py::kw_only(),
-      py::arg("packed") = false, py::arg("loc") = py::none());
+      "cls"_a, "elements"_a, py::kw_only(), "packed"_a = false,
+      "loc"_a = py::none());
 
   llvmStructType.def_classmethod(
       "get_identified",
@@ -44,8 +49,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
         return cls(mlirLLVMStructTypeIdentifiedGet(
             context, mlirStringRefCreate(name.data(), name.size())));
       },
-      py::arg("cls"), py::arg("name"), py::kw_only(),
-      py::arg("context") = py::none());
+      "cls"_a, "name"_a, py::kw_only(), "context"_a = py::none());
 
   llvmStructType.def_classmethod(
       "get_opaque",
@@ -53,7 +57,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
         return cls(mlirLLVMStructTypeOpaqueGet(
             context, mlirStringRefCreate(name.data(), name.size())));
       },
-      py::arg("cls"), py::arg("name"), py::arg("context") = py::none());
+      "cls"_a, "name"_a, "context"_a = py::none());
 
   llvmStructType.def(
       "set_body",
@@ -65,7 +69,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
              "Struct body already set to different content.");
         }
       },
-      py::arg("elements"), py::kw_only(), py::arg("packed") = false);
+      "elements"_a, py::kw_only(), "packed"_a = false);
 
   llvmStructType.def_classmethod(
       "new_identified",
@@ -75,8 +79,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
             ctx, mlirStringRefCreate(name.data(), name.length()),
             elements.size(), elements.data(), packed));
       },
-      py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(),
-      py::arg("packed") = false, py::arg("context") = py::none());
+      "cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false,
+      "context"_a = py::none());
 
   llvmStructType.def_property_readonly(
       "name", [](MlirType type) -> std::optional<std::string> {
@@ -105,6 +109,29 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
 
   llvmStructType.def_property_readonly(
       "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); });
+
+  //===--------------------------------------------------------------------===//
+  // PointerType
+  //===--------------------------------------------------------------------===//
+
+  mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType)
+      .def_classmethod(
+          "get",
+          [](py::object cls, std::optional<unsigned> addressSpace,
+             MlirContext context) {
+            CollectDiagnosticsToStringScope scope(context);
+            MlirType type = mlirLLVMPointerTypeGet(
+                context, addressSpace.has_value() ? *addressSpace : 0);
+            if (mlirTypeIsNull(type)) {
+              throw py::value_error(scope.takeMessage());
+            }
+            return cls(type);
+          },
+          "cls"_a, "address_space"_a = py::none(), py::kw_only(),
+          "context"_a = py::none())
+      .def_property_readonly("address_space", [](MlirType type) {
+        return mlirLLVMPointerTypeGetAddressSpace(type);
+      });
 }
 
 PYBIND11_MODULE(_mlirDialectsLLVM, m) {

diff  --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index 21c66f38a8af03..108ebe5367d567 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -27,6 +27,14 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
   return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
 }
 
+/// Predicate for the LLVM dialect pointer type (`!llvm.ptr`, any address space).
+bool mlirTypeIsALLVMPointerType(MlirType type) {
+  auto unwrapped = unwrap(type);
+  return isa<LLVMPointerType>(unwrapped);
+}
+
+/// Address space of an `!llvm.ptr`. The argument must be an LLVM pointer type:
+/// `cast` does not check it in release builds and asserts in debug builds.
+unsigned mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType) {
+  auto ptrType = cast<LLVMPointerType>(unwrap(pointerType));
+  return ptrType.getAddressSpace();
+}
+
 /// Returns the `!llvm.void` type for context `ctx`.
 MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
   auto voidType = LLVMVoidType::get(unwrap(ctx));
   return wrap(voidType);
 }

diff  --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td
index dcf2f4245cf49f..30f047f21698e3 100644
--- a/mlir/python/mlir/dialects/LLVMOps.td
+++ b/mlir/python/mlir/dialects/LLVMOps.td
@@ -10,5 +10,6 @@
 #define PYTHON_BINDINGS_LLVM_OPS
 
 include "mlir/Dialect/LLVMIR/LLVMOps.td"
+include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td"
 
 #endif

diff  --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py
index 8aa16e4a256030..941a584966dcde 100644
--- a/mlir/python/mlir/dialects/llvm.py
+++ b/mlir/python/mlir/dialects/llvm.py
@@ -5,3 +5,11 @@
 from ._llvm_ops_gen import *
 from ._llvm_enum_gen import *
 from .._mlir_libs._mlirDialectsLLVM import *
+from ..ir import Value
+from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results
+
+
+def mlir_constant(value, *, loc=None, ip=None) -> Value:
+    return _get_op_result_or_op_results(
+        ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
+    )

diff  --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py
index fb4b343b170bae..d9ffdeb65bfd40 100644
--- a/mlir/test/python/dialects/llvm.py
+++ b/mlir/test/python/dialects/llvm.py
@@ -107,3 +107,46 @@ def testSmoke():
     )
     result = llvm.UndefOp(mat64f32_t)
     # CHECK: %0 = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
+
+
+# CHECK-LABEL: testPointerType
+ at constructAndPrintInModule
+def testPointerType():
+    ptr = llvm.PointerType.get()
+    # CHECK: !llvm.ptr
+    print(ptr)
+
+    ptr_with_addr = llvm.PointerType.get(1)
+    # CHECK: !llvm.ptr<1>
+    print(ptr_with_addr)
+
+
+# CHECK-LABEL: testConstant
+ at constructAndPrintInModule
+def testConstant():
+    i32 = IntegerType.get_signless(32)
+    c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
+    # CHECK: %{{.*}} = llvm.mlir.constant(128 : i32) : i32
+    print(c_128.owner)
+
+
+# CHECK-LABEL: testIntrinsics
+ at constructAndPrintInModule
+def testIntrinsics():
+    i32 = IntegerType.get_signless(32)
+    ptr = llvm.PointerType.get()
+    c_128 = llvm.mlir_constant(IntegerAttr.get(i32, 128))
+    # CHECK: %[[CST128:.*]] = llvm.mlir.constant(128 : i32) : i32
+    print(c_128.owner)
+
+    alloca = llvm.alloca(ptr, c_128, i32)
+    # CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CST128]] x i32 : (i32) -> !llvm.ptr
+    print(alloca.owner)
+
+    c_0 = llvm.mlir_constant(IntegerAttr.get(IntegerType.get_signless(8), 0))
+    # CHECK: %[[CST0:.+]] = llvm.mlir.constant(0 : i8) : i8
+    print(c_0.owner)
+
+    result = llvm.intr_memset(alloca, c_0, c_128, False)
+    # CHECK: "llvm.intr.memset"(%[[ALLOCA]], %[[CST0]], %[[CST128]]) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
+    print(result)


        


More information about the Mlir-commits mailing list