[Mlir-commits] [mlir] [mlir][gpu] Add py binding for AsyncTokenType (PR #96466)

Guray Ozen llvmlistbot at llvm.org
Mon Jun 24 01:53:47 PDT 2024


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/96466

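This PR adds a Python binding for the GPU dialect's `AsyncTokenType` (together with the C API functions it relies on), so the type can be created with `gpu.AsyncTokenType.get()` instead of parsing it via `ir.Type.parse("!gpu.async.token")`. The NVGPU examples and the sm90 matmul builder are updated to use the new binding.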

>From 3df6cc1112054a299ec39a94fcaf1d1f3fc395af Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 24 Jun 2024 10:52:52 +0200
Subject: [PATCH] [mlir][gpu] Add py binding for AsyncTokenType

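Add a Python binding for the GPU dialect's `AsyncTokenType`, backed by the new C API functions `mlirTypeIsAGPUAsyncTokenType` and `mlirGPUAsyncTokenTypeGet`. Existing examples and tests now use `gpu.AsyncTokenType.get()` instead of `ir.Type.parse("!gpu.async.token")`.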
---
 mlir/include/mlir-c/Dialect/GPU.h                  |  8 ++++++++
 mlir/lib/Bindings/Python/DialectGPU.cpp            | 14 ++++++++++++++
 mlir/lib/CAPI/Dialect/GPU.cpp                      | 12 ++++++++++++
 mlir/test/Examples/NVGPU/Ch1.py                    |  2 +-
 mlir/test/Examples/NVGPU/Ch2.py                    |  2 +-
 mlir/test/Examples/NVGPU/Ch3.py                    |  2 +-
 mlir/test/Examples/NVGPU/Ch4.py                    |  2 +-
 mlir/test/Examples/NVGPU/Ch5.py                    |  2 +-
 .../GPU/CUDA/sm90/python/tools/matmulBuilder.py    |  4 ++--
 9 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h
index 2adf73ddff6ea..c42ff61f9592c 100644
--- a/mlir/include/mlir-c/Dialect/GPU.h
+++ b/mlir/include/mlir-c/Dialect/GPU.h
@@ -19,6 +19,14 @@ extern "C" {
 
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu);
 
+//===-------------------------------------------------------------------===//
+// AsyncTokenType
+//===-------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAGPUAsyncTokenType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx);
+
 //===---------------------------------------------------------------------===//
 // ObjectAttr
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 1f68bfc6ff154..a9e339b50dabc 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -25,6 +25,20 @@ using namespace mlir::python::adaptors;
 
 PYBIND11_MODULE(_mlirDialectsGPU, m) {
   m.doc() = "MLIR GPU Dialect";
+  //===-------------------------------------------------------------------===//
+  // AsyncTokenType
+  //===-------------------------------------------------------------------===//
+
+  auto mlirGPUAsyncTokenType =
+      mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType);
+
+  mlirGPUAsyncTokenType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirGPUAsyncTokenTypeGet(ctx));
+      },
+      "Gets an instance of AsyncTokenType in the same context", py::arg("cls"),
+      py::arg("ctx") = py::none());
 
   //===-------------------------------------------------------------------===//
   // ObjectAttr
diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp
index e471e8cd9588e..0acebb2300429 100644
--- a/mlir/lib/CAPI/Dialect/GPU.cpp
+++ b/mlir/lib/CAPI/Dialect/GPU.cpp
@@ -15,6 +15,18 @@ using namespace mlir;
 
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect)
 
+//===-------------------------------------------------------------------===//
+// AsyncTokenType
+//===-------------------------------------------------------------------===//
+
+bool mlirTypeIsAGPUAsyncTokenType(MlirType type) {
+  return isa<gpu::AsyncTokenType>(unwrap(type));
+}
+
+MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx) {
+  return wrap(gpu::AsyncTokenType::get(unwrap(ctx)));
+}
+
 //===---------------------------------------------------------------------===//
 // ObjectAttr
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py
index da65aa2ef6a17..cfb48d56f8d49 100644
--- a/mlir/test/Examples/NVGPU/Ch1.py
+++ b/mlir/test/Examples/NVGPU/Ch1.py
@@ -23,7 +23,7 @@
 @NVDSL.mlir_func
 def saxpy(x, y, alpha):
     # 1. Use MLIR GPU dialect to allocate and copy memory
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
     y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py
index 78c14cb2c7ad8..729913c6d5c4f 100644
--- a/mlir/test/Examples/NVGPU/Ch2.py
+++ b/mlir/test/Examples/NVGPU/Ch2.py
@@ -27,7 +27,7 @@
 
 @NVDSL.mlir_func
 def saxpy(x, y, alpha):
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
     y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py
index a417014de8b49..eb96b11c63416 100644
--- a/mlir/test/Examples/NVGPU/Ch3.py
+++ b/mlir/test/Examples/NVGPU/Ch3.py
@@ -59,7 +59,7 @@ def tma_load(
 
 @NVDSL.mlir_func
 def gemm_128_128_64(a, b, d):
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py
index 8f38d8a90add3..0e3460ff8d63b 100644
--- a/mlir/test/Examples/NVGPU/Ch4.py
+++ b/mlir/test/Examples/NVGPU/Ch4.py
@@ -258,7 +258,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
 #   d -> memref<MxNxf32>
 @NVDSL.mlir_func
 def gemm_multistage(a, b, d, num_stages):
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
index 92e9314e1b812..f98cfd758a75f 100644
--- a/mlir/test/Examples/NVGPU/Ch5.py
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -252,7 +252,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
 
 @NVDSL.mlir_func
 def gemm_warp_specialized(a, b, d, num_stages):
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 5269522000f13..75f0dc947e068 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -182,7 +182,7 @@ def generate_matmul_ws(
     assert K % BLOCK_K == 0
 
     module = ir.Module.create()
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     a_elem_ty = get_mlir_ty(input_type)
     b_elem_ty = get_mlir_ty(input_type)
     c_elem_ty = get_mlir_ty(output_type)
@@ -682,7 +682,7 @@ def generate_matmul_multistage(
     assert K % BLOCK_K == 0
 
     module = ir.Module.create()
-    token_ty = ir.Type.parse("!gpu.async.token")
+    token_ty = gpu.AsyncTokenType.get()
     a_elem_ty = get_mlir_ty(input_type)
     b_elem_ty = get_mlir_ty(input_type)
     c_elem_ty = get_mlir_ty(output_type)



More information about the Mlir-commits mailing list