[Mlir-commits] [mlir] [MLIR][Python] enable ptr dialect bindings (PR #167270)

Maksim Levental llvmlistbot at llvm.org
Sun Nov 9 22:40:35 PST 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/167270

>From 4aa9ab260339d1ce455913392d7481731811f665 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Sun, 9 Nov 2025 22:14:12 -0800
Subject: [PATCH] [MLIR][Python] enable ptr dialect bindings

---
 mlir/include/mlir-c/Dialect/PtrDialect.h | 41 +++++++++++++++++++++
 mlir/lib/Bindings/Python/DialectPtr.cpp  | 41 +++++++++++++++++++++
 mlir/lib/CAPI/Dialect/CMakeLists.txt     |  9 +++++
 mlir/lib/CAPI/Dialect/PtrDialect.cpp     | 39 ++++++++++++++++++++
 mlir/python/CMakeLists.txt               | 45 ++++++++++++++++++------
 mlir/python/mlir/dialects/PtrOps.td      | 14 ++++++++
 mlir/python/mlir/dialects/ptr.py         |  6 ++++
 mlir/test/python/dialects/ptr.py         | 21 +++++++++++
 8 files changed, 206 insertions(+), 10 deletions(-)
 create mode 100644 mlir/include/mlir-c/Dialect/PtrDialect.h
 create mode 100644 mlir/lib/Bindings/Python/DialectPtr.cpp
 create mode 100644 mlir/lib/CAPI/Dialect/PtrDialect.cpp
 create mode 100644 mlir/python/mlir/dialects/PtrOps.td
 create mode 100644 mlir/python/mlir/dialects/ptr.py
 create mode 100644 mlir/test/python/dialects/ptr.py

diff --git a/mlir/include/mlir-c/Dialect/PtrDialect.h b/mlir/include/mlir-c/Dialect/PtrDialect.h
new file mode 100644
index 0000000000000..3df70940014d1
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/PtrDialect.h
@@ -0,0 +1,41 @@
+//===- PtrDialect.h - C interface for the Ptr dialect -------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_PTR_H
+#define MLIR_C_DIALECT_PTR_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// Dialect API.
+//===----------------------------------------------------------------------===//
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Ptr, ptr);
+
+//===----------------------------------------------------------------------===//
+// MemorySpaceAttrInterface API.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Type API.
+//===----------------------------------------------------------------------===//
+
+/// Checks if the given type is a Ptr type.
+MLIR_CAPI_EXPORTED bool mlirPtrTypeIsAPtrType(MlirType type);
+/// Gets a Ptr type in `memorySpace` (null if not a MemorySpaceAttrInterface).
+MLIR_CAPI_EXPORTED MlirType mlirPtrGetPtrType(MlirAttribute memorySpace);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_PTR_H
diff --git a/mlir/lib/Bindings/Python/DialectPtr.cpp b/mlir/lib/Bindings/Python/DialectPtr.cpp
new file mode 100644
index 0000000000000..8dc99c14c71a3
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectPtr.cpp
@@ -0,0 +1,51 @@
+//===- DialectPtr.cpp - Nanobind module for Ptr dialect API support -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "NanobindUtils.h"
+
+#include "mlir-c/Dialect/PtrDialect.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+
+using namespace nanobind::literals;
+
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+/// Registers the `PtrType` Python class (an `mlir.ir.Type` subclass) on `m`.
+static void populateDialectPtrSubmodule(nanobind::module_ &m) {
+  mlir_type_subclass(m, "PtrType", mlirPtrTypeIsAPtrType)
+      .def_classmethod(
+          "get",
+          [](const nb::object &cls, MlirAttribute memorySpace) {
+            MlirType type = mlirPtrGetPtrType(memorySpace);
+            // The C API reports an unusable memory space with a null type;
+            // raise instead of handing a null type to the Python subclass.
+            if (mlirTypeIsNull(type))
+              throw nb::value_error(
+                  "memory_space must implement MemorySpaceAttrInterface");
+            return cls(type);
+          },
+          "Gets a PtrType in `memory_space`, using that attribute's context.",
+          nb::arg("cls"), nb::arg("memory_space"));
+}
+
+// The module name must match MODULE_NAME in mlir/python/CMakeLists.txt
+// (`_mlirDialectsPtr`); otherwise importing the extension fails because the
+// exported `PyInit_*` symbol does not match the shared library's file name.
+NB_MODULE(_mlirDialectsPtr, m) {
+  m.doc() = "MLIR PTR Dialect";
+
+  populateDialectPtrSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
index bb1fdf8be3c8f..9462ad20f4fa6 100644
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -278,3 +278,12 @@ add_mlir_upstream_c_api_library(MLIRCAPISMT
   MLIRCAPIIR
   MLIRSMT
 )
+
+add_mlir_upstream_c_api_library(MLIRCAPIPtrDialect
+  PtrDialect.cpp
+
+  PARTIAL_SOURCES_INTENDED
+  LINK_LIBS PUBLIC
+  MLIRCAPIIR
+  MLIRPtrDialect
+)
diff --git a/mlir/lib/CAPI/Dialect/PtrDialect.cpp b/mlir/lib/CAPI/Dialect/PtrDialect.cpp
new file mode 100644
index 0000000000000..a8f06f4852f0c
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/PtrDialect.cpp
@@ -0,0 +1,46 @@
+//===- PtrDialect.cpp - C interface for the Ptr dialect -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/PtrDialect.h"
+#include "mlir/CAPI/Registration.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "ptr-dialect-capi"
+
+using namespace mlir;
+using namespace ptr;
+
+//===----------------------------------------------------------------------===//
+// Dialect API.
+//===----------------------------------------------------------------------===//
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Ptr, ptr, mlir::ptr::PtrDialect)
+
+//===----------------------------------------------------------------------===//
+// Type API.
+//===----------------------------------------------------------------------===//
+
+bool mlirPtrTypeIsAPtrType(MlirType type) {
+  return llvm::isa<ptr::PtrType>(unwrap(type));
+}
+
+MlirType mlirPtrGetPtrType(MlirAttribute memorySpace) {
+  // `dyn_cast_if_present` (unlike `dyn_cast`) tolerates a null attribute,
+  // which C API callers can pass; both cases fall through to the null return.
+  MemorySpaceAttrInterface memorySpaceAttr =
+      llvm::dyn_cast_if_present<MemorySpaceAttrInterface>(unwrap(memorySpace));
+  if (!memorySpaceAttr) {
+    // Report failure to the caller as a null type (test with mlirTypeIsNull).
+    LLVM_DEBUG(llvm::dbgs()
+               << "expected memory-space to be MemorySpaceAttrInterface\n");
+    return {nullptr};
+  }
+  return wrap(ptr::PtrType::get(memorySpaceAttr));
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 51c75764faf3c..112c8e970522c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -516,6 +516,15 @@ declare_mlir_dialect_python_bindings(
   GEN_ENUM_BINDINGS
 )
 
+declare_mlir_dialect_python_bindings(
+  DIALECT_NAME ptr
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/PtrOps.td
+  GEN_ENUM_BINDINGS
+  SOURCES dialects/ptr.py
+)
+
 ################################################################################
 # Python extensions.
 # The sources for these are all in lib/Bindings/Python, but since they have to
@@ -579,7 +588,7 @@ declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything
     MLIRCAPIRegisterEverything
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Nanobind
   MODULE_NAME _mlirDialectsLinalg
   ADD_TO_PARENT MLIRPythonSources.Dialects.linalg
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -593,7 +602,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind
     MLIRCAPILinalg
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Nanobind
   MODULE_NAME _mlirDialectsGPU
   ADD_TO_PARENT MLIRPythonSources.Dialects.gpu
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -607,7 +616,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind
     MLIRCAPIGPU
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Nanobind
   MODULE_NAME _mlirDialectsLLVM
   ADD_TO_PARENT MLIRPythonSources.Dialects.llvm
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -623,7 +632,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind
     MLIRCAPITarget
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Nanobind
   MODULE_NAME _mlirDialectsQuant
   ADD_TO_PARENT MLIRPythonSources.Dialects.quant
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -637,7 +646,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
     MLIRCAPIQuant
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Nanobind
   MODULE_NAME _mlirDialectsNVGPU
   ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -651,7 +660,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind
     MLIRCAPINVGPU
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Nanobind
   MODULE_NAME _mlirDialectsPDL
   ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -665,7 +674,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
     MLIRCAPIPDL
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Nanobind
   MODULE_NAME _mlirDialectsSparseTensor
   ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -679,7 +688,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
     MLIRCAPISparseTensor
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
   MODULE_NAME _mlirDialectsTransform
   ADD_TO_PARENT MLIRPythonSources.Dialects.transform
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -693,7 +702,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
     MLIRCAPITransformDialect
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Nanobind
   MODULE_NAME _mlirDialectsIRDL
   ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -761,7 +770,7 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses
     MLIRCAPILinalg
 )
 
-declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind
   MODULE_NAME _mlirDialectsSMT
   ADD_TO_PARENT MLIRPythonSources.Dialects.smt
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
@@ -778,6 +787,22 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind
     MLIRCAPIExportSMTLIB
 )
 
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Ptr.Nanobind
+  ADD_TO_PARENT MLIRPythonSources.Dialects.ptr
+  MODULE_NAME _mlirDialectsPtr
+  PYTHON_BINDINGS_LIBRARY nanobind
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPIPtrDialect
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  SOURCES
+    DialectPtr.cpp
+    # Installed headers have to be listed explicitly.
+    NanobindUtils.h
+)
+
 declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses
   MODULE_NAME _mlirSparseTensorPasses
   ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
diff --git a/mlir/python/mlir/dialects/PtrOps.td b/mlir/python/mlir/dialects/PtrOps.td
new file mode 100644
index 0000000000000..8bde942c10192
--- /dev/null
+++ b/mlir/python/mlir/dialects/PtrOps.td
@@ -0,0 +1,14 @@
+//===- PtrOps.td - Entry point for Ptr bindings ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BINDINGS_PYTHON_PTR_OPS
+#define BINDINGS_PYTHON_PTR_OPS
+
+include "mlir/Dialect/Ptr/IR/PtrOps.td"
+
+#endif // BINDINGS_PYTHON_PTR_OPS
diff --git a/mlir/python/mlir/dialects/ptr.py b/mlir/python/mlir/dialects/ptr.py
new file mode 100644
index 0000000000000..a837b5b894dc6
--- /dev/null
+++ b/mlir/python/mlir/dialects/ptr.py
@@ -0,0 +1,6 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._ptr_ops_gen import *
+from ._ptr_enum_gen import *
diff --git a/mlir/test/python/dialects/ptr.py b/mlir/test/python/dialects/ptr.py
new file mode 100644
index 0000000000000..ef0f26303afbf
--- /dev/null
+++ b/mlir/test/python/dialects/ptr.py
@@ -0,0 +1,21 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.dialects import ptr
+from mlir.ir import Context, Location, Module, InsertionPoint
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f(module)
+        print(module)
+        assert module.operation.verify()
+
+
+# CHECK-LABEL: TEST: test_smoke
+ at run
+def test_smoke(_module):
+    null = ptr.constant(True)
+    # CHECK: ptr.constant #ptr.null : !ptr.ptr<#ptr.generic_space>



More information about the Mlir-commits mailing list