[Mlir-commits] [mlir] [amdgpu] Add Python bindings for TDM types (PR #172309)
Tim Gymnich
llvmlistbot at llvm.org
Mon Dec 15 06:45:25 PST 2025
https://github.com/tgymnich updated https://github.com/llvm/llvm-project/pull/172309
>From 9884d84926708b452aa629b54c7ee5ec57b405b9 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Mon, 15 Dec 2025 14:38:59 +0000
Subject: [PATCH 1/2] [amdgpu] Add Python bindings for TDM types
---
mlir/include/mlir-c/Dialect/AMDGPU.h | 27 +++++++++
mlir/lib/Bindings/Python/DialectAMDGPU.cpp | 64 ++++++++++++++++++++++
mlir/lib/CAPI/Dialect/AMDGPU.cpp | 43 +++++++++++++++
mlir/python/CMakeLists.txt | 15 +++++
mlir/python/mlir/dialects/amdgpu.py | 1 +
mlir/test/python/dialects/amdgpu.py | 21 ++++++-
6 files changed, 170 insertions(+), 1 deletion(-)
create mode 100644 mlir/lib/Bindings/Python/DialectAMDGPU.cpp
diff --git a/mlir/include/mlir-c/Dialect/AMDGPU.h b/mlir/include/mlir-c/Dialect/AMDGPU.h
index 142044f7f3afe..950dca3f2fa1c 100644
--- a/mlir/include/mlir-c/Dialect/AMDGPU.h
+++ b/mlir/include/mlir-c/Dialect/AMDGPU.h
@@ -18,6 +18,33 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu);
+//===---------------------------------------------------------------------===//
+// TDMBaseType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx,
+ MlirType elementType);
+
+//===---------------------------------------------------------------------===//
+// TDMDescriptorType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx);
+
+//===---------------------------------------------------------------------===//
+// TDMGatherBaseType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx,
+ MlirType elementType,
+ MlirType indexType);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
new file mode 100644
index 0000000000000..99e6b65640973
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp
@@ -0,0 +1,64 @@
+//===- DialectAMDGPU.cpp - Nanobind module for AMDGPU dialect API support -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/AMDGPU.h"
+#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
+
+namespace nb = nanobind;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
+ auto amdgpuTDMBaseType =
+ mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType);
+
+ amdgpuTDMBaseType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirType elementType, MlirContext ctx) {
+ return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
+ },
+ "Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
+ nb::arg("element_type"), nb::arg("ctx") = nb::none());
+
+
+ auto amdgpuTDMDescriptorType = mlir_type_subclass(
+ m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType);
+
+ amdgpuTDMDescriptorType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirContext ctx) {
+ return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
+ },
+ "Gets an instance of TDMDescriptorType in the same context",
+ nb::arg("cls"), nb::arg("ctx") = nb::none());
+
+
+ auto amdgpuTDMGatherBaseType = mlir_type_subclass(
+ m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType);
+
+ amdgpuTDMGatherBaseType.def_classmethod(
+ "get",
+ [](const nb::object &cls, MlirType elementType, MlirType indexType,
+ MlirContext ctx) {
+ return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
+ },
+ "Gets an instance of TDMGatherBaseType in the same context",
+ nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
+ nb::arg("ctx") = nb::none());
+}
+
+NB_MODULE(_mlirDialectsAMDGPU, m) {
+ m.doc() = "MLIR AMDGPU dialect.";
+
+ populateDialectAMDGPUSubmodule(m);
+}
diff --git a/mlir/lib/CAPI/Dialect/AMDGPU.cpp b/mlir/lib/CAPI/Dialect/AMDGPU.cpp
index d877ca2dff375..26dfb27a56879 100644
--- a/mlir/lib/CAPI/Dialect/AMDGPU.cpp
+++ b/mlir/lib/CAPI/Dialect/AMDGPU.cpp
@@ -12,3 +12,46 @@
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu,
mlir::amdgpu::AMDGPUDialect)
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+//===---------------------------------------------------------------------===//
+// TDMBaseType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type) {
+ return isa<amdgpu::TDMBaseType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, MlirType elementType) {
+ return wrap(amdgpu::TDMBaseType::get(unwrap(ctx),
+ cast<Type>(unwrap(elementType))));
+}
+
+//===---------------------------------------------------------------------===//
+// TDMDescriptorType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type) {
+ return isa<amdgpu::TDMDescriptorType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx) {
+ return wrap(amdgpu::TDMDescriptorType::get(unwrap(ctx)));
+}
+
+//===---------------------------------------------------------------------===//
+// TDMGatherBaseType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type) {
+ return isa<amdgpu::TDMGatherBaseType>(unwrap(type));
+}
+
+MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, MlirType elementType,
+ MlirType indexType) {
+ return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx),
+ cast<Type>(unwrap(elementType)),
+ cast<Type>(unwrap(indexType))));
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 2acb6ee6cfda5..6e449e275f782 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -804,6 +804,21 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
MLIRCAPITransformDialectTransforms
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Pybind
+ MODULE_NAME _mlirDialectsAMDGPU
+ ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu
+ ROOT_DIR "${PYTHON_SOURCE_DIR}"
+ PYTHON_BINDINGS_LIBRARY nanobind
+ SOURCES
+ DialectAMDGPU.cpp
+ PRIVATE_LINK_LIBS
+ LLVMSupport
+ EMBED_CAPI_LINK_LIBS
+ MLIRCAPIIR
+ MLIRCAPIAMDGPU
+)
+
+
# TODO: Figure out how to put this in the test tree.
# This should not be included in the main Python extension. However,
# putting it into MLIRPythonTestSources along with the dialect declaration
diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py
index 43d905d0c481c..1c4d274bc31af 100644
--- a/mlir/python/mlir/dialects/amdgpu.py
+++ b/mlir/python/mlir/dialects/amdgpu.py
@@ -4,3 +4,4 @@
from ._amdgpu_ops_gen import *
from ._amdgpu_enum_gen import *
+from .._mlir_libs._mlirDialectsAMDGPU import *
diff --git a/mlir/test/python/dialects/amdgpu.py b/mlir/test/python/dialects/amdgpu.py
index b479576dac093..c126a6d201eb0 100644
--- a/mlir/test/python/dialects/amdgpu.py
+++ b/mlir/test/python/dialects/amdgpu.py
@@ -2,7 +2,7 @@
# This is just a smoke test that the dialect is functional.
from mlir.ir import *
-from mlir.dialects import amdgpu, func
+from mlir.dialects import amdgpu, func, memref
def constructAndPrintInModule(f):
@@ -43,3 +43,22 @@ def testFatRawBufferCastOpParams():
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] resetOffset : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) resetOffset : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+
+
+# CHECK-LABEL: testTDMTypes
+@constructAndPrintInModule
+def testTDMTypes():
+ f32 = F32Type.get()
+ i32 = IntegerType.get_signless(32)
+
+ # CHECK: !amdgpu.tdm_base<f32>
+ tdm_base = amdgpu.TDMBaseType.get(f32)
+ print(tdm_base)
+
+ # CHECK: !amdgpu.tdm_descriptor
+ tdm_descriptor = amdgpu.TDMDescriptorType.get()
+ print(tdm_descriptor)
+
+ # CHECK: !amdgpu.tdm_gather_base<f32, i32>
+ tdm_gather_base = amdgpu.TDMGatherBaseType.get(f32, i32)
+ print(tdm_gather_base)
>From 917f55f7ab0e6b287df07df2a9f933df1e7043f0 Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Mon, 15 Dec 2025 14:45:12 +0000
Subject: [PATCH 2/2] drop cast
---
mlir/lib/CAPI/Dialect/AMDGPU.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/CAPI/Dialect/AMDGPU.cpp b/mlir/lib/CAPI/Dialect/AMDGPU.cpp
index 26dfb27a56879..77536e822c0ac 100644
--- a/mlir/lib/CAPI/Dialect/AMDGPU.cpp
+++ b/mlir/lib/CAPI/Dialect/AMDGPU.cpp
@@ -25,8 +25,7 @@ bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type) {
}
MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, MlirType elementType) {
- return wrap(amdgpu::TDMBaseType::get(unwrap(ctx),
- cast<Type>(unwrap(elementType))));
+ return wrap(amdgpu::TDMBaseType::get(unwrap(ctx), unwrap(elementType)));
}
//===---------------------------------------------------------------------===//
@@ -51,7 +50,6 @@ bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type) {
MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, MlirType elementType,
MlirType indexType) {
- return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx),
- cast<Type>(unwrap(elementType)),
- cast<Type>(unwrap(indexType))));
+ return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx), unwrap(elementType),
+ unwrap(indexType)));
}
More information about the Mlir-commits
mailing list