[Mlir-commits] [mlir] [MLIR][Python] enable ptr dialect bindings (PR #167270)
Maksim Levental
llvmlistbot at llvm.org
Sun Nov 9 22:40:35 PST 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/167270
>From 4aa9ab260339d1ce455913392d7481731811f665 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sun, 9 Nov 2025 22:14:12 -0800
Subject: [PATCH] [MLIR][Python] enable ptr dialect bindings
---
mlir/include/mlir-c/Dialect/PtrDialect.h | 41 +++++++++++++++++++++
mlir/lib/Bindings/Python/DialectPtr.cpp | 41 +++++++++++++++++++++
mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++++
mlir/lib/CAPI/Dialect/PtrDialect.cpp | 39 ++++++++++++++++++++
mlir/python/CMakeLists.txt | 45 ++++++++++++++++++------
mlir/python/mlir/dialects/PtrOps.td | 14 ++++++++
mlir/python/mlir/dialects/ptr.py | 6 ++++
mlir/test/python/dialects/ptr.py | 21 +++++++++++
8 files changed, 206 insertions(+), 10 deletions(-)
create mode 100644 mlir/include/mlir-c/Dialect/PtrDialect.h
create mode 100644 mlir/lib/Bindings/Python/DialectPtr.cpp
create mode 100644 mlir/lib/CAPI/Dialect/PtrDialect.cpp
create mode 100644 mlir/python/mlir/dialects/PtrOps.td
create mode 100644 mlir/python/mlir/dialects/ptr.py
create mode 100644 mlir/test/python/dialects/ptr.py
diff --git a/mlir/include/mlir-c/Dialect/PtrDialect.h b/mlir/include/mlir-c/Dialect/PtrDialect.h
new file mode 100644
index 0000000000000..3df70940014d1
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/PtrDialect.h
@@ -0,0 +1,41 @@
+//===- PtrDialect.h - C interface for the Ptr dialect -------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_PTR_H
+#define MLIR_C_DIALECT_PTR_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// Dialect API.
+//===----------------------------------------------------------------------===//
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Ptr, ptr);
+
+//===----------------------------------------------------------------------===//
+// Type API.
+//===----------------------------------------------------------------------===//
+
+/// Checks if the given type is a Ptr type.
+MLIR_CAPI_EXPORTED bool mlirPtrTypeIsAPtrType(MlirType type);
+
+/// Creates a `!ptr.ptr` type with the given memory space attribute.
+/// `memorySpace` must implement MemorySpaceAttrInterface. If it does not,
+/// a null type is returned; callers should check the result with
+/// mlirTypeIsNull.
+MLIR_CAPI_EXPORTED MlirType mlirPtrGetPtrType(MlirAttribute memorySpace);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_PTR_H
diff --git a/mlir/lib/Bindings/Python/DialectPtr.cpp b/mlir/lib/Bindings/Python/DialectPtr.cpp
new file mode 100644
index 0000000000000..8dc99c14c71a3
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectPtr.cpp
@@ -0,0 +1,41 @@
+//===- DialectPtr.cpp - Nanobind module for Ptr dialect API support -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "NanobindUtils.h"
+
+#include "mlir-c/Dialect/PtrDialect.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+
+using namespace nanobind::literals;
+
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+static void populateDialectPtrSubmodule(nanobind::module_ &m) {
+  mlir_type_subclass(m, "PtrType", mlirPtrTypeIsAPtrType)
+      .def_classmethod(
+          "get",
+          [](const nb::object &cls, MlirAttribute memorySpace) {
+            return cls(mlirPtrGetPtrType(memorySpace));
+          },
+          "Gets an instance of PtrType with memory_space in the same context",
+          nb::arg("cls"), nb::arg("memory_space"));
+}
+
+NB_MODULE(_mlirDialectsPtr, m) {
+  // Must match MODULE_NAME in mlir/python/CMakeLists.txt (PyInit_ symbol).
+  m.doc() = "MLIR PTR Dialect";
+  populateDialectPtrSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index bb1fdf8be3c8f..9462ad20f4fa6 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -278,3 +278,12 @@ add_mlir_upstream_c_api_library(MLIRCAPISMT
MLIRCAPIIR
MLIRSMT
)
+
+add_mlir_upstream_c_api_library(MLIRCAPIPtrDialect
+  PtrDialect.cpp
+
+  PARTIAL_SOURCES_INTENDED
+  LINK_LIBS PUBLIC
+  MLIRCAPIIR
+  MLIRPtrDialect
+)
diff --git a/mlir/lib/CAPI/Dialect/PtrDialect.cpp b/mlir/lib/CAPI/Dialect/PtrDialect.cpp
new file mode 100644
index 0000000000000..a8f06f4852f0c
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/PtrDialect.cpp
@@ -0,0 +1,39 @@
+//===- PtrDialect.cpp - C interface for the Ptr dialect -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/PtrDialect.h"
+#include "mlir/CAPI/Registration.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "ptr-dialect-capi"
+
+using namespace mlir;
+using namespace ptr;
+
+//===----------------------------------------------------------------------===//
+// Dialect API.
+//===----------------------------------------------------------------------===//
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Ptr, ptr, mlir::ptr::PtrDialect)
+
+bool mlirPtrTypeIsAPtrType(MlirType type) {
+  return llvm::isa<ptr::PtrType>(unwrap(type));
+}
+// Null or non-memory-space attributes yield a null type instead of asserting.
+MlirType mlirPtrGetPtrType(MlirAttribute memorySpace) {
+  auto memorySpaceAttr =
+      llvm::dyn_cast_if_present<MemorySpaceAttrInterface>(unwrap(memorySpace));
+  if (!memorySpaceAttr) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "expected memory-space to be MemorySpaceAttrInterface\n");
+    return {nullptr};
+  }
+  return wrap(ptr::PtrType::get(memorySpaceAttr));
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 51c75764faf3c..112c8e970522c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -516,6 +516,15 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS
)
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/PtrOps.td
+ SOURCES dialects/ptr.py
+ DIALECT_NAME ptr
+ GEN_ENUM_BINDINGS
+)
+
################################################################################
# Python extensions.
# The sources for these are all in lib/Bindings/Python, but since they have to
@@ -579,7 +588,7 @@ declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
MLIRCAPIRegisterEverything
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Nanobind
MODULE_NAME _mlirDialectsLinalg
ADD_TO_PARENT MLIRPythonSources.Dialects.linalg
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -593,7 +602,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
MLIRCAPILinalg
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Nanobind
MODULE_NAME _mlirDialectsGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -607,7 +616,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
MLIRCAPIGPU
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Nanobind
MODULE_NAME _mlirDialectsLLVM
ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -623,7 +632,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
MLIRCAPITarget
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Nanobind
MODULE_NAME _mlirDialectsQuant
ADD_TO_PARENT MLIRPythonSources.Dialects.quant
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -637,7 +646,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
MLIRCAPIQuant
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Nanobind
MODULE_NAME _mlirDialectsNVGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -651,7 +660,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
MLIRCAPINVGPU
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Nanobind
MODULE_NAME _mlirDialectsPDL
ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -665,7 +674,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
MLIRCAPIPDL
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Nanobind
MODULE_NAME _mlirDialectsSparseTensor
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -679,7 +688,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
MLIRCAPISparseTensor
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
MODULE_NAME _mlirDialectsTransform
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -693,7 +702,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MLIRCAPITransformDialect
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Nanobind
MODULE_NAME _mlirDialectsIRDL
ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -761,7 +770,7 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
MLIRCAPILinalg
)
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind
MODULE_NAME _mlirDialectsSMT
ADD_TO_PARENT MLIRPythonSources.Dialects.smt
ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -778,6 +787,22 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
MLIRCAPIExportSMTLIB
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Ptr.Nanobind
+ MODULE_NAME _mlirDialectsPtr
+ ADD_TO_PARENT MLIRPythonSources.Dialects.ptr
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
+ SOURCES
+ DialectPtr.cpp
+ # Headers must be included explicitly so they are installed.
+ NanobindUtils.h
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPIPtrDialect
+)
+
declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
MODULE_NAME _mlirSparseTensorPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
diff --git a/mlir/python/mlir/dialects/PtrOps.td b/mlir/python/mlir/dialects/PtrOps.td
new file mode 100644
index 0000000000000..8bde942c10192
--- /dev/null
+++ b/mlir/python/mlir/dialects/PtrOps.td
@@ -0,0 +1,14 @@
+//===- PtrOps.td - Entry point for Ptr bindings ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BINDINGS_PYTHON_PTR_OPS
+#define BINDINGS_PYTHON_PTR_OPS
+
+include "mlir/Dialect/Ptr/IR/PtrOps.td"
+
+#endif // BINDINGS_PYTHON_PTR_OPS
diff --git a/mlir/python/mlir/dialects/ptr.py b/mlir/python/mlir/dialects/ptr.py
new file mode 100644
index 0000000000000..a837b5b894dc6
--- /dev/null
+++ b/mlir/python/mlir/dialects/ptr.py
@@ -0,0 +1,6 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._ptr_ops_gen import *
+from ._ptr_enum_gen import *
diff --git a/mlir/test/python/dialects/ptr.py b/mlir/test/python/dialects/ptr.py
new file mode 100644
index 0000000000000..ef0f26303afbf
--- /dev/null
+++ b/mlir/test/python/dialects/ptr.py
@@ -0,0 +1,21 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.dialects import ptr
+from mlir.ir import Context, Location, Module, InsertionPoint
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f(module)
+        print(module)
+        assert module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_smoke
+@run
+def test_smoke(_module):
+    null = ptr.constant(True)
+    # CHECK: ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>
More information about the Mlir-commits
mailing list