[Mlir-commits] [mlir] [mlir][python] add binding to `#gpu.object` (PR #88992)

Arthur Eubanks llvmlistbot at llvm.org
Wed Apr 17 12:04:28 PDT 2024


https://github.com/aeubanks updated https://github.com/llvm/llvm-project/pull/88992

>From 191a32d24e91a4108cc4e2ca6a9421fb71692379 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 16 Apr 2024 16:40:04 -0500
Subject: [PATCH] [mlir][python] add binding to gpu.object

---
 mlir/include/mlir-c/Dialect/GPU.h             | 25 +++++++
 mlir/lib/Bindings/Python/DialectGPU.cpp       | 65 +++++++++++++++++++
 mlir/lib/CAPI/Dialect/GPU.cpp                 | 59 ++++++++++++++++-
 mlir/python/CMakeLists.txt                    | 11 ++++
 mlir/python/mlir/dialects/gpu/__init__.py     |  1 +
 mlir/test/python/dialects/gpu/dialect.py      | 25 +++++++
 .../dialects/gpu/module-to-binary-nvvm.py     | 30 +++++++--
 7 files changed, 210 insertions(+), 6 deletions(-)
 create mode 100644 mlir/lib/Bindings/Python/DialectGPU.cpp

diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h
index 1a18d82c01d53e..2adf73ddff6eae 100644
--- a/mlir/include/mlir-c/Dialect/GPU.h
+++ b/mlir/include/mlir-c/Dialect/GPU.h
@@ -19,6 +19,31 @@ extern "C" {
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu);
 
+//===---------------------------------------------------------------------===//
+// ObjectAttr
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format,
+                     MlirStringRef objectStrRef, MlirAttribute mlirObjectProps);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr);
+
+MLIR_CAPI_EXPORTED uint32_t
+mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr);
+
+MLIR_CAPI_EXPORTED MlirStringRef
+mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr);
+
+MLIR_CAPI_EXPORTED bool
+mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr);
+
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
new file mode 100644
index 00000000000000..1f68bfc6ff1541
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -0,0 +1,65 @@
+//===- DialectGPU.cpp - Pybind module for the GPU dialect -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/GPU.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+#include <pybind11/detail/common.h>
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::adaptors;
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+PYBIND11_MODULE(_mlirDialectsGPU, m) {
+  m.doc() = "MLIR GPU Dialect";
+
+  //===-------------------------------------------------------------------===//
+  // ObjectAttr
+  //===-------------------------------------------------------------------===//
+
+  mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr)
+      .def_classmethod(
+          "get",
+          [](py::object cls, MlirAttribute target, uint32_t format,
+             py::bytes object, std::optional<MlirAttribute> mlirObjectProps) {
+            py::buffer_info info(py::buffer(object).request());
+            MlirStringRef objectStrRef =
+                mlirStringRefCreate(static_cast<char *>(info.ptr), info.size);
+            return cls(mlirGPUObjectAttrGet(
+                mlirAttributeGetContext(target), target, format, objectStrRef,
+                mlirObjectProps.has_value() ? *mlirObjectProps
+                                            : MlirAttribute{nullptr}));
+          },
+          "cls"_a, "target"_a, "format"_a, "object"_a,
+          "properties"_a = py::none(), "Gets a gpu.object from parameters.")
+      .def_property_readonly(
+          "target",
+          [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); })
+      .def_property_readonly(
+          "format",
+          [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); })
+      .def_property_readonly(
+          "object",
+          [](MlirAttribute self) {
+            MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self);
+            return py::bytes(stringRef.data, stringRef.length);
+          })
+      .def_property_readonly("properties", [](MlirAttribute self) {
+        if (mlirGPUObjectAttrHasProperties(self))
+          return py::cast(mlirGPUObjectAttrGetProperties(self));
+        return py::none().cast<py::object>();
+      });
+}
diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp
index cd58f0e249a9e2..e471e8cd9588e1 100644
--- a/mlir/lib/CAPI/Dialect/GPU.cpp
+++ b/mlir/lib/CAPI/Dialect/GPU.cpp
@@ -1,4 +1,4 @@
-//===- GPUc.cpp - C Interface for GPU dialect ----------------------------===//
+//===- GPU.cpp - C Interface for GPU dialect ------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -9,5 +9,60 @@
 #include "mlir-c/Dialect/GPU.h"
 #include "mlir/CAPI/Registration.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "llvm/Support/Casting.h"
 
-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, mlir::gpu::GPUDialect)
+using namespace mlir;
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect)
+
+//===---------------------------------------------------------------------===//
+// ObjectAttr
+//===---------------------------------------------------------------------===//
+
+bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) {
+  return llvm::isa<gpu::ObjectAttr>(unwrap(attr));
+}
+
+MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target,
+                                   uint32_t format, MlirStringRef objectStrRef,
+                                   MlirAttribute mlirObjectProps) {
+  MLIRContext *ctx = unwrap(mlirCtx);
+  llvm::StringRef object = unwrap(objectStrRef);
+  DictionaryAttr objectProps;
+  if (mlirObjectProps.ptr != nullptr)
+    objectProps = llvm::cast<DictionaryAttr>(unwrap(mlirObjectProps));
+  return wrap(gpu::ObjectAttr::get(ctx, unwrap(target),
+                                   static_cast<gpu::CompilationTarget>(format),
+                                   StringAttr::get(ctx, object), objectProps));
+}
+
+MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) {
+  gpu::ObjectAttr objectAttr =
+      llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
+  return wrap(objectAttr.getTarget());
+}
+
+uint32_t mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr) {
+  gpu::ObjectAttr objectAttr =
+      llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
+  return static_cast<uint32_t>(objectAttr.getFormat());
+}
+
+MlirStringRef mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr) {
+  gpu::ObjectAttr objectAttr =
+      llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
+  llvm::StringRef object = objectAttr.getObject();
+  return mlirStringRefCreate(object.data(), object.size());
+}
+
+bool mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr) {
+  gpu::ObjectAttr objectAttr =
+      llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
+  return objectAttr.getProperties() != nullptr;
+}
+
+MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr) {
+  gpu::ObjectAttr objectAttr =
+      llvm::cast<gpu::ObjectAttr>(unwrap(mlirObjectAttr));
+  return wrap(objectAttr.getProperties());
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c27ee688a04087..d31bad34afa82c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -498,6 +498,17 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
     MLIRCAPILinalg
 )
 
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
+  MODULE_NAME _mlirDialectsGPU
+  ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  SOURCES
+    DialectGPU.cpp
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPIGPU
+)
+
 declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
   MODULE_NAME _mlirDialectsLLVM
   ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 033386b0f803b2..4cd80aa8b7ca85 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -4,3 +4,4 @@
 
 from .._gpu_ops_gen import *
 from .._gpu_enum_gen import *
+from ..._mlir_libs._mlirDialectsGPU import *
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 2f49e2e053999b..aded35b04aa1ea 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -30,3 +30,28 @@ def testMMAElementWiseAttr():
     # CHECK: %block_dim_y = gpu.block_dim  y
     print(module)
     pass
+
+
+# CHECK-LABEL: testObjectAttr
+@run
+def testObjectAttr():
+    target = Attribute.parse("#nvvm.target")
+    format = gpu.CompilationTarget.Fatbin
+    object = b"BC\xc0\xde5\x14\x00\x00\x05\x00\x00\x00b\x0c0$MY\xbef"
+    properties = DictAttr.get({"O": IntegerAttr.get(IntegerType.get_signless(32), 2)})
+    o = gpu.ObjectAttr.get(target, format, object, properties)
+    # CHECK: #gpu.object<#nvvm.target, properties = {O = 2 : i32}, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
+    print(o)
+    assert o.object == object
+
+    o = gpu.ObjectAttr.get(target, format, object)
+    # CHECK: #gpu.object<#nvvm.target, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
+    print(o)
+
+    object = (
+        b"//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 6.0\n.target sm_50"
+    )
+    o = gpu.ObjectAttr.get(target, format, object)
+    # CHECK: #gpu.object<#nvvm.target, "//\0A// Generated by LLVM NVPTX Back-End\0A//\0A\0A.version 6.0\0A.target sm_50">
+    print(o)
+    assert o.object == object
diff --git a/mlir/test/python/dialects/gpu/module-to-binary-nvvm.py b/mlir/test/python/dialects/gpu/module-to-binary-nvvm.py
index 1c2eb652e71f91..e0225911b21443 100644
--- a/mlir/test/python/dialects/gpu/module-to-binary-nvvm.py
+++ b/mlir/test/python/dialects/gpu/module-to-binary-nvvm.py
@@ -34,9 +34,20 @@ def testGPUToLLVMBin():
     pm = PassManager("any")
     pm.add("gpu-module-to-binary{format=llvm}")
     pm.run(module.operation)
+    # CHECK-LABEL: gpu.binary @kernel_module1
     print(module)
-    # CHECK-LABEL:gpu.binary @kernel_module1
-    # CHECK:[#gpu.object<#nvvm.target<chip = "sm_70">, offload = "{{.*}}">]
+
+    o = gpu.ObjectAttr(module.body.operations[0].objects[0])
+    # CHECK: #gpu.object<#nvvm.target<chip = "sm_70">, offload = "{{.*}}">
+    print(o)
+    # CHECK: #nvvm.target<chip = "sm_70">
+    print(o.target)
+    # CHECK: offload
+    print(gpu.CompilationTarget(o.format))
+    # CHECK: b'{{.*}}'
+    print(o.object)
+    # CHECK: None
+    print(o.properties)
 
 
 # CHECK-LABEL: testGPUToASMBin
@@ -59,6 +70,17 @@ def testGPUToASMBin():
     pm = PassManager("any")
     pm.add("gpu-module-to-binary{format=isa}")
     pm.run(module.operation)
-    print(module)
     # CHECK-LABEL:gpu.binary @kernel_module2
-    # CHECK:[#gpu.object<#nvvm.target<flags = {fast}>, properties = {O = 2 : i32}, assembly = "{{.*}}">, #gpu.object<#nvvm.target, properties = {O = 2 : i32}, assembly = "{{.*}}">]
+    print(module)
+
+    o = gpu.ObjectAttr(module.body.operations[0].objects[0])
+    # CHECK: #gpu.object<#nvvm.target<flags = {fast}>
+    print(o)
+    # CHECK: #nvvm.target<flags = {fast}>
+    print(o.target)
+    # CHECK: assembly
+    print(gpu.CompilationTarget(o.format))
+    # CHECK: b'//\n// Generated by LLVM NVPTX Back-End{{.*}}'
+    print(o.object)
+    # CHECK: {O = 2 : i32}
+    print(o.properties)



More information about the Mlir-commits mailing list