[Mlir-commits] [mlir] [NVGPU] Fix nvdsl examples - take 2 (PR #167321)

Giacomo Castiglioni llvmlistbot at llvm.org
Mon Nov 10 06:39:25 PST 2025


https://github.com/castigli created https://github.com/llvm/llvm-project/pull/167321

This PR re-lands https://github.com/llvm/llvm-project/pull/156830

This PR fixes the nvdsl examples, which had drifted out of sync because they are not exercised in CI.

The fixed bugs relate to the following PRs (a short illustrative sketch of two of the affected APIs follows the list):
- move to nanobind #118583
- split gpu module initialization #135478
- gpu dialect python API change #163883
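For reference, a minimal sketch (not part of the patch below) of two of the API changes the fixes adapt to: scf.for_ now also yields the loop results, so the examples unpack three values, and an ExecutionEngine must be explicitly initialized after the gpu module initialization split. The helper name jit_module and its arguments are illustrative placeholders, not taken from the examples.

    # Sketch only; mirrors the changes in the patch below, not part of it.
    from mlir import execution_engine, ir
    from mlir.dialects import scf

    # scf.for_ now yields (induction var, iter_args, results), so unpack three values:
    #   for iv, phase, _ in scf.for_(0, K // TILE_K, 1, [phase]):
    #       ...

    def jit_module(module: ir.Module, opt_level: int = 2, shared_libs=None):
        # Wrap the module in a JIT engine; initialize() is now required before invoking it.
        ee = execution_engine.ExecutionEngine(
            module, opt_level=opt_level, shared_libs=shared_libs or []
        )
        ee.initialize()  # needed since gpu module initialization was split out (#135478)
        return ee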

>From 3da7d3caab538a9350fceeb466cbc068ef4a3df7 Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Tue, 2 Sep 2025 14:56:12 +0200
Subject: [PATCH 1/6] fix nvdsl

---
 mlir/test/Examples/NVGPU/Ch5.py                 | 2 +-
 mlir/test/Examples/NVGPU/tools/nvdsl.py         | 7 +++----
 mlir/test/Examples/NVGPU/tools/nvgpucompiler.py | 4 +++-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
index f98cfd758a75f..91c346c837dda 100644
--- a/mlir/test/Examples/NVGPU/Ch5.py
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -156,7 +156,7 @@ def producer_loop(
 ):
     phase = const(True, ty=T.bool())
 
-    for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
+    for iv, phase, _ in scf.for_(0, (K // TILE_K), 1, [phase]):
         stage = iv % num_stages
         # Wait MMA to be done
         mbar_mma[stage].try_wait(phase)
diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
index 90dbb2355e1c8..d4c50fc9bc28d 100644
--- a/mlir/test/Examples/NVGPU/tools/nvdsl.py
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -84,8 +84,7 @@ def arrive(self, txcount: int = 0, predicate=None):
                 self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
             )
         else:
-            nvgpu.mbarrier_arrive(
-                ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
+            nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op
             )
 
     def try_wait(self, phase: bool = False, ticks: int = 10000000):
@@ -144,7 +143,7 @@ def create_descriptor(self, device_ptr):
             device_ptr,
         )
         self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
-            tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
+            tma_descriptor_ty, device_unranked_memref, list(map(const, self.tma_box_shape))
         )
         return self.tma_descriptor.result
 
@@ -156,7 +155,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
             dest,
             mbarrier.mbar_group_op,
             self.tma_descriptor,
-            coordinates=map(const, coords),
+            coordinates=list(map(const, coords)),
             mbarId=mbarrier.id_op,
             predicate=predicate,
         )
diff --git a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
index 1c9cc74fcd169..4b661f8df6a9f 100644
--- a/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
+++ b/mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
@@ -35,9 +35,11 @@ def compile(self, module: ir.Module):
 
     def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Wraps the module in a JIT execution engine."""
-        return execution_engine.ExecutionEngine(
+        ee = execution_engine.ExecutionEngine(
             module, opt_level=self.opt_level, shared_libs=self.shared_libs
         )
+        ee.initialize()
+        return ee
 
     def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
         """Compiles and jits the module."""

>From 1c755c5d24c81a3ee8adfafa924c4375800455d2 Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Wed, 8 Oct 2025 16:31:19 +0200
Subject: [PATCH 2/6] format

---
 mlir/test/Examples/NVGPU/tools/nvdsl.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
index d4c50fc9bc28d..ab4e37fdfa9b7 100644
--- a/mlir/test/Examples/NVGPU/tools/nvdsl.py
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -84,8 +84,7 @@ def arrive(self, txcount: int = 0, predicate=None):
                 self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
             )
         else:
-            nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op
-            )
+            nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op)
 
     def try_wait(self, phase: bool = False, ticks: int = 10000000):
         ticks_op = const(ticks)
@@ -143,7 +142,9 @@ def create_descriptor(self, device_ptr):
             device_ptr,
         )
         self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
-            tma_descriptor_ty, device_unranked_memref, list(map(const, self.tma_box_shape))
+            tma_descriptor_ty,
+            device_unranked_memref,
+            list(map(const, self.tma_box_shape)),
         )
         return self.tma_descriptor.result
 

>From 6598d8c18af6b47487923ecca5e6c9d55f7165e0 Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Mon, 13 Oct 2025 17:43:48 +0200
Subject: [PATCH 3/6] either check the computation if sm_90 is available, or dump
 and check the IR without running

---
 mlir/test/Examples/NVGPU/Ch0.py         |  28 ++-
 mlir/test/Examples/NVGPU/Ch1.py         |  50 ++++-
 mlir/test/Examples/NVGPU/Ch2.py         |  84 ++++++++-
 mlir/test/Examples/NVGPU/Ch3.py         |  96 +++++++++-
 mlir/test/Examples/NVGPU/Ch4.py         | 196 +++++++++++++++++++-
 mlir/test/Examples/NVGPU/Ch5.py         | 176 +++++++++++++++++-
 mlir/test/Examples/NVGPU/tools/nvdsl.py | 234 ++++++++++++------------
 7 files changed, 712 insertions(+), 152 deletions(-)

diff --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py
index 8f60088178d11..0caab36ee28fc 100644
--- a/mlir/test/Examples/NVGPU/Ch0.py
+++ b/mlir/test/Examples/NVGPU/Ch0.py
@@ -1,5 +1,8 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN:   %PYTHON %s | FileCheck %s
+# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
+# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 0 : Hello World
@@ -18,10 +21,12 @@
 from tools.nvdsl import *
 
 
+dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
+
 # 1. The decorator generates a MLIR func.func.
 # Everything inside the Python function becomes the body of the func.
 # The decorator also translates `alpha` to an `index` type.
-@NVDSL.mlir_func
+@NVDSL.mlir_func(dump_only)
 def main(alpha):
     # 2. The decorator generates a MLIR gpu.launch.
     # Everything inside the Python function becomes the body of the gpu.launch.
@@ -38,13 +43,28 @@ def kernel():
     # 3. Call the GPU kernel
     kernel()
 
-
 alpha = 100
 # 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
 main(alpha)
 
-
 # CHECK: GPU thread 0 has 100
 # CHECK: GPU thread 1 has 101
 # CHECK: GPU thread 2 has 102
 # CHECK: GPU thread 3 has 103
+
+# DUMPIR:   func.func @main(%arg0: index) attributes {llvm.emit_c_interface} {
+# DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_0:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_1:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C4:.*]] = arith.constant 4 : index
+# DUMPIR:     %[[C1_2:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_3:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C0_I32:.*]] = arith.constant 0 : i32
+# DUMPIR:     gpu.launch blocks(%arg1, %arg2, %arg3) in (%arg7 = %[[C1]], %arg8 = %[[C1_0]], %arg9 = %[[C1_1]]) threads(%arg4, %arg5, %arg6) in (%arg10 = %[[C4]], %arg11 = %[[C1_2]], %arg12 = %[[C1_3]]) dynamic_shared_memory_size %[[C0_I32]] {
+# DUMPIR:       %[[TIDX:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[MYVAL:.*]] = arith.addi %arg0, %[[TIDX]] : index
+# DUMPIR:       gpu.printf "GPU thread %llu has %llu\0A", %[[TIDX]], %[[MYVAL]] : index, index
+# DUMPIR:       gpu.terminator
+# DUMPIR:     }
+# DUMPIR:     return
+# DUMPIR:   }
diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py
index cfb48d56f8d49..9fa7d82ae6688 100644
--- a/mlir/test/Examples/NVGPU/Ch1.py
+++ b/mlir/test/Examples/NVGPU/Ch1.py
@@ -1,5 +1,8 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN:   %PYTHON %s | FileCheck %s
+# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
+# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 1 : 2D Saxpy
@@ -19,8 +22,9 @@
 from tools.nvdsl import *
 import numpy as np
 
+dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
 
-@NVDSL.mlir_func
+@NVDSL.mlir_func(dump_only)
 def saxpy(x, y, alpha):
     # 1. Use MLIR GPU dialect to allocate and copy memory
     token_ty = gpu.AsyncTokenType.get()
@@ -56,11 +60,43 @@ def saxpy_kernel():
 alpha = 2.0
 x = np.random.randn(M, N).astype(np.float32)
 y = np.ones((M, N), np.float32)
+
 saxpy(x, y, alpha)
 
-#  4. Verify MLIR with reference computation
-ref = np.ones((M, N), np.float32)
-ref += x * alpha
-np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
-print("PASS")
+if not dump_only:
+    # 4. Verify MLIR with reference computation
+    ref = np.ones((M, N), np.float32)
+    ref += x * alpha
+    np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+    print("PASS")
 # CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR:   func.func @saxpy(%arg0: memref<256x32xf32>, %arg1: memref<256x32xf32>, %arg2: f32) attributes {llvm.emit_c_interface} {
+# DUMPIR:     %[[WAIT0:.*]] = gpu.wait async
+# DUMPIR:     %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32>
+# DUMPIR:     %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32>
+# DUMPIR:     %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %arg0 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %arg1 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]]
+# DUMPIR:     %[[C256:.*]] = arith.constant 256 : index
+# DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_2:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C32:.*]] = arith.constant 32 : index
+# DUMPIR:     %[[C1_3:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_4:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C0_I32:.*]] = arith.constant 0 : i32
+# DUMPIR:     gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C256]], %arg10 = %[[C1]], %arg11 = %[[C1_2]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C32]], %arg13 = %[[C1_3]], %arg14 = %[[C1_4]]) dynamic_shared_memory_size %[[C0_I32]] {
+# DUMPIR:       %[[BLOCKID:.*]] = gpu.block_id  x
+# DUMPIR:       %[[THREADID:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[LD0:.*]] = memref.load %[[MEMREF]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR:       %[[LD1:.*]] = memref.load %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR:       %[[MUL:.*]] = arith.mulf %[[LD0]], %arg2 : f32
+# DUMPIR:       %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32
+# DUMPIR:       memref.store %[[ADD]], %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR:       gpu.terminator
+# DUMPIR:     }
+# DUMPIR:     %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %arg1, %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]]
+# DUMPIR:     return
+# DUMPIR:   }
diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py
index 729913c6d5c4f..9d35833027e9f 100644
--- a/mlir/test/Examples/NVGPU/Ch2.py
+++ b/mlir/test/Examples/NVGPU/Ch2.py
@@ -1,5 +1,8 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN:   %PYTHON %s | FileCheck %s
+# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
+# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 2 : 2D Saxpy with TMA
@@ -24,8 +27,9 @@
 from mlir.extras import types as T
 import numpy as np
 
+dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
 
- at NVDSL.mlir_func
+ at NVDSL.mlir_func(dump_only)
 def saxpy(x, y, alpha):
     token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
@@ -85,9 +89,75 @@ def saxpy_tma_kernel():
 y = np.ones((M, N), np.float32)
 saxpy(x, y, alpha)
 
-#  4. Verify MLIR with reference computation
-ref = np.ones((M, N), np.float32)
-ref += x * alpha
-np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
-print("PASS")
+if not dump_only:
+    #  4. Verify MLIR with reference computation
+    ref = np.ones((M, N), np.float32)
+    ref += x * alpha
+    np.testing.assert_allclose(y, ref, rtol=5e-03, atol=1e-01)
+    print("PASS")
 # CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR:   func.func @saxpy(%arg0: memref<256x32xf32>, %arg1: memref<256x32xf32>, %arg2: f32) attributes {llvm.emit_c_interface} {
+# DUMPIR:     %[[WAIT0:.*]] = gpu.wait async
+# DUMPIR:     %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32>
+# DUMPIR:     %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32>
+# DUMPIR:     %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %arg0 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %arg1 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]]
+# DUMPIR:     %[[CAST:.*]] = memref.cast %[[MEMREF]] : memref<256x32xf32> to memref<*xf32>
+# DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C32:.*]] = arith.constant 32 : index
+# DUMPIR:     %[[TMA0:.*]] = nvgpu.tma.create.descriptor %[[CAST]] box[%[[C1]], %[[C32]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:     %[[CAST2:.*]] = memref.cast %[[MEMREF0]] : memref<256x32xf32> to memref<*xf32>
+# DUMPIR:     %[[C1_3:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C32_4:.*]] = arith.constant 32 : index
+# DUMPIR:     %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST2]] box[%[[C1_3]], %[[C32_4]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:     %[[C256:.*]] = arith.constant 256 : index
+# DUMPIR:     %[[C1_5:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_6:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C32_7:.*]] = arith.constant 32 : index
+# DUMPIR:     %[[C1_8:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_9:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C256_I32:.*]] = arith.constant 256 : i32
+# DUMPIR:     gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C256]], %arg10 = %[[C1_5]], %arg11 = %[[C1_6]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C32_7]], %arg13 = %[[C1_8]], %arg14 = %[[C1_9]]) dynamic_shared_memory_size %[[C256_I32]] {
+# DUMPIR:       %[[BLOCKID:.*]] = gpu.block_id  x
+# DUMPIR:       %[[THREADID:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[EQ:.*]] = arith.cmpi eq, %[[THREADID]], %[[C0]] : index
+# DUMPIR:       %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_10:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C1_11:.*]] = arith.constant 1 : index
+# DUMPIR:       nvgpu.mbarrier.init %[[MB]][%[[C0_10]]], %[[C1_11]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_12:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_12]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR:       %[[VIEW_13:.*]] = memref.view %[[DSM1]][%[[C128]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_14:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C0_15:.*]] = arith.constant 0 : index
+# DUMPIR:       nvgpu.tma.async.load %[[TMA0]][%[[C0_15]], %[[BLOCKID]]], %[[MB]][%[[C0_14]]] to %[[VIEW]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_16:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C0_17:.*]] = arith.constant 0 : index
+# DUMPIR:       nvgpu.tma.async.load %[[TMA1]][%[[C0_17]], %[[BLOCKID]]], %[[MB]][%[[C0_16]]] to %[[VIEW_13]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_18:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C256_19:.*]] = arith.constant 256 : index
+# DUMPIR:       nvgpu.mbarrier.arrive.expect_tx %[[MB]][%[[C0_18]]], %[[C256_19]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_20:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C10000000:.*]] = arith.constant 10000000 : index
+# DUMPIR:       %[[FALSE:.*]] = arith.constant false
+# DUMPIR:       nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_20]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_21:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[LD0:.*]] = memref.load %[[VIEW]][%[[C0_21]], %[[THREADID]]] : memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_22:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[LD1:.*]] = memref.load %[[VIEW_13]][%[[C0_22]], %[[THREADID]]] : memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[MUL:.*]] = arith.mulf %[[LD0]], %arg2 : f32
+# DUMPIR:       %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32
+# DUMPIR:       memref.store %[[ADD]], %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR:       gpu.terminator
+# DUMPIR:     }
+# DUMPIR:     %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %arg1, %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]]
+# DUMPIR:     return
+# DUMPIR:   }
diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py
index eb96b11c63416..4b44c89eabf2e 100644
--- a/mlir/test/Examples/NVGPU/Ch3.py
+++ b/mlir/test/Examples/NVGPU/Ch3.py
@@ -1,5 +1,8 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN:   %PYTHON %s | FileCheck %s
+# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
+# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 3 : GEMM 128x128x64 with Tensor Core
@@ -21,6 +24,7 @@
 from mlir.extras import types as T
 import numpy as np
 
+dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
 
 def tma_load(
     mbar_group: Mbarriers,
@@ -57,7 +61,7 @@ def tma_load(
     b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)
 
 
-@NVDSL.mlir_func
+@NVDSL.mlir_func(dump_only)
 def gemm_128_128_64(a, b, d):
     token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
@@ -123,7 +127,89 @@ def gemm_tma_kernel():
 d = np.zeros((M, N), np.float32)
 gemm_128_128_64(a, b, d)
 
-ref_d = a.astype(np.float16) @ b.astype(np.float16)
-np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
-print("PASS")
+if not dump_only:
+    # Verify MLIR program with reference computation in python
+    ref_d = a.astype(np.float16) @ b.astype(np.float16)
+    np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
+    print("PASS")
 # CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR:   func.func @gemm_128_128_64(%arg0: memref<128x64xf16>, %arg1: memref<64x128xf16>, %arg2: memref<128x128xf32>) attributes {llvm.emit_c_interface} {
+# DUMPIR:     %[[WAIT0:.*]] = gpu.wait async
+# DUMPIR:     %[[MEM0:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<128x64xf16>
+# DUMPIR:     %[[MEM1:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<64x128xf16>
+# DUMPIR:     %[[MEM2:.*]], %[[ASYNC2:.*]] = gpu.alloc async [%[[ASYNC1]]] () : memref<128x128xf32>
+# DUMPIR:     %[[CPY1:.*]] = gpu.memcpy async [%[[ASYNC2]]] %[[MEM0]], %arg0 : memref<128x64xf16>, memref<128x64xf16>
+# DUMPIR:     %[[CPY2:.*]] = gpu.memcpy async [%[[CPY1]]] %[[MEM1]], %arg1 : memref<64x128xf16>, memref<64x128xf16>
+# DUMPIR:     %[[WAIT1:.*]] = gpu.wait async [%[[CPY2]]]
+# DUMPIR:     %[[CAST0:.*]] = memref.cast %[[MEM0]] : memref<128x64xf16> to memref<*xf16>
+# DUMPIR:     %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR:     %[[C64:.*]] = arith.constant 64 : index
+# DUMPIR:     %[[TMA0:.*]] = nvgpu.tma.create.descriptor %[[CAST0]] box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:     %[[CAST1:.*]] = memref.cast %[[MEM1]] : memref<64x128xf16> to memref<*xf16>
+# DUMPIR:     %[[C64_5:.*]] = arith.constant 64 : index
+# DUMPIR:     %[[C64_6:.*]] = arith.constant 64 : index
+# DUMPIR:     %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST1]] box[%[[C64_5]], %[[C64_6]]] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_7:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_8:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C128_9:.*]] = arith.constant 128 : index
+# DUMPIR:     %[[C1_10:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_11:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C32768_I32:.*]] = arith.constant 32768 : i32
+# DUMPIR:     gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C1]], %arg10 = %[[C1_7]], %arg11 = %[[C1_8]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C128_9]], %arg13 = %[[C1_10]], %arg14 = %[[C1_11]]) dynamic_shared_memory_size %[[C32768_I32]] {
+# DUMPIR:       %[[THREADID:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[EQ:.*]] = arith.cmpi eq, %[[THREADID]], %[[C0]] : index
+# DUMPIR:       %[[C0_12:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C1_13:.*]] = arith.constant 1 : index
+# DUMPIR:       nvgpu.mbarrier.init %[[MB]][%[[C0_12]]], %[[C1_13]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       nvgpu.tma.prefetch.descriptor %[[TMA0]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:       nvgpu.tma.prefetch.descriptor %[[TMA1]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:       %[[DSM0:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_14:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[VIEW:.*]] = memref.view %[[DSM0]][%[[C0_14]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C16384:.*]] = arith.constant 16384 : index
+# DUMPIR:       %[[VIEW_15:.*]] = memref.view %[[DSM1]][%[[C16384]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[DSM2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_16:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[VIEW_17:.*]] = memref.view %[[DSM2]][%[[C0_16]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[DSM3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C16384_18:.*]] = arith.constant 16384 : index
+# DUMPIR:       %[[VIEW_19:.*]] = memref.view %[[DSM3]][%[[C16384_18]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[DSM4:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C24576:.*]] = arith.constant 24576 : index
+# DUMPIR:       %[[VIEW_20:.*]] = memref.view %[[DSM4]][%[[C24576]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_21:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C32768:.*]] = arith.constant 32768 : index
+# DUMPIR:       nvgpu.mbarrier.arrive.expect_tx %[[MB]][%[[C0_21]]], %[[C32768]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_22:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C0_23:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C0_24:.*]] = arith.constant 0 : index
+# DUMPIR:       nvgpu.tma.async.load %[[TMA0]][%[[C0_23]], %[[C0_24]]], %[[MB]][%[[C0_22]]] to %[[VIEW_17]], predicate = %[[EQ]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_25:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C0_26:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C0_27:.*]] = arith.constant 0 : index
+# DUMPIR:       nvgpu.tma.async.load %[[TMA1]][%[[C0_26]], %[[C0_27]]], %[[MB]][%[[C0_25]]] to %[[VIEW_19]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_28:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C64_29:.*]] = arith.constant 64 : index
+# DUMPIR:       %[[C0_30:.*]] = arith.constant 0 : index
+# DUMPIR:       nvgpu.tma.async.load %[[TMA1]][%[[C64_29]], %[[C0_30]]], %[[MB]][%[[C0_28]]] to %[[VIEW_20]], predicate = %[[EQ]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_31:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C10000000:.*]] = arith.constant 10000000 : index
+# DUMPIR:       %[[FALSE:.*]] = arith.constant false
+# DUMPIR:       nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_31]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       %[[WG_ACC:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+# DUMPIR:       %[[GEN0:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW]], %[[TMA0]] : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+# DUMPIR:       %[[GEN1:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_15]], %[[TMA1]] : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+# DUMPIR:       %[[MMA:.*]] = nvgpu.warpgroup.mma %[[GEN0]], %[[GEN1]], %[[WG_ACC]] {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+# DUMPIR:       nvgpu.warpgroup.mma.store %[[MMA]], %[[MEM2]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32>
+# DUMPIR:       gpu.terminator
+# DUMPIR:     }
+# DUMPIR:     %[[CPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %arg2, %[[MEM2]] : memref<128x128xf32>, memref<128x128xf32>
+# DUMPIR:     gpu.wait [%[[CPY3]]]
+# DUMPIR:     return
+# DUMPIR:   }
diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py
index 0e3460ff8d63b..fd6b40203f839 100644
--- a/mlir/test/Examples/NVGPU/Ch4.py
+++ b/mlir/test/Examples/NVGPU/Ch4.py
@@ -1,5 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN:   %PYTHON %s | FileCheck %s
+# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
+# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 4 : Multistage GEMM with Tensor Core
@@ -47,6 +51,7 @@
 from tools.nvdsl import *
 import numpy as np
 
+dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
 
 def partition_shape():
     """
@@ -256,7 +261,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
 #   a -> memref<MxKxf16>
 #   b -> memref<NxKf16>
 #   d -> memref<MxNxf32>
-@NVDSL.mlir_func
+@NVDSL.mlir_func(dump_only)
 def gemm_multistage(a, b, d, num_stages):
     token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
@@ -314,10 +319,187 @@ def gemm_multistage_kernel():
 gemm_multistage(a, b, d, num_stages=7)
 
 
-# Verify MLIR with reference computation
-ref_d = a.astype(np.float16) @ b.astype(np.float16)
-np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
-
+if not dump_only:
+    # Verify MLIR with reference computation
+    ref_d = a.astype(np.float16) @ b.astype(np.float16)
+    np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
 
-print("PASS")
+    print("PASS")
 # CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR:   func.func @gemm_multistage(%arg0: memref<512x1024xf16>, %arg1: memref<1024x256xf16>, %arg2: memref<512x256xf32>) attributes {llvm.emit_c_interface} {
+# DUMPIR:     %[[WAIT:.*]] = gpu.wait async
+# DUMPIR:     %[[AMEM:.*]], %[[ATOK:.*]] = gpu.alloc async [%[[WAIT]]] () : memref<512x1024xf16>
+# DUMPIR:     %[[BMEM:.*]], %[[BTOK:.*]] = gpu.alloc async [%[[ATOK]]] () : memref<1024x256xf16>
+# DUMPIR:     %[[DMEM:.*]], %[[DTOK:.*]] = gpu.alloc async [%[[BTOK]]] () : memref<512x256xf32>
+# DUMPIR:     %[[CPYA:.*]] = gpu.memcpy async [%[[DTOK]]] %[[AMEM]], %arg0 : memref<512x1024xf16>, memref<512x1024xf16>
+# DUMPIR:     %[[CPYB:.*]] = gpu.memcpy async [%[[CPYA]]] %[[BMEM]], %arg1 : memref<1024x256xf16>, memref<1024x256xf16>
+# DUMPIR:     %[[WAIT2:.*]] = gpu.wait async [%[[CPYB]]]
+# DUMPIR:     %[[CASTA:.*]] = memref.cast %[[AMEM]] : memref<512x1024xf16> to memref<*xf16>
+# DUMPIR:     %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR:     %[[C64:.*]] = arith.constant 64 : index
+# DUMPIR:     %[[TMAA:.*]] = nvgpu.tma.create.descriptor %[[CASTA]] box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:     %[[CASTB:.*]] = memref.cast %[[BMEM]] : memref<1024x256xf16> to memref<*xf16>
+# DUMPIR:     %[[C64_B1:.*]] = arith.constant 64 : index
+# DUMPIR:     %[[C64_B2:.*]] = arith.constant 64 : index
+# DUMPIR:     %[[TMAB:.*]] = nvgpu.tma.create.descriptor %[[CASTB]] box[%[[C64_B1]], %[[C64_B2]]] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:     %[[C4:.*]] = arith.constant 4 : index
+# DUMPIR:     %[[C2:.*]] = arith.constant 2 : index
+# DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C128_GRID:.*]] = arith.constant 128 : index
+# DUMPIR:     %[[C1_T1:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[C1_T2:.*]] = arith.constant 1 : index
+# DUMPIR:     %[[SMEM_SIZE:.*]] = arith.constant 229376 : i32
+# DUMPIR:     gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C4]], %arg10 = %[[C2]], %arg11 = %[[C1]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C128_GRID]], %arg13 = %[[C1_T1]], %arg14 = %[[C1_T2]]) dynamic_shared_memory_size %[[SMEM_SIZE]] {
+# DUMPIR:       %[[TID_X:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[MBAR:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:       %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[IS_THREAD0:.*]] = arith.cmpi eq, %[[TID_X]], %[[C0]] : index
+# DUMPIR:       scf.if %[[IS_THREAD0]] {
+# DUMPIR:         %[[C0_INIT:.*]] = arith.constant 0 : index
+# DUMPIR:         %[[C7:.*]] = arith.constant 7 : index
+# DUMPIR:         %[[C1_INIT:.*]] = arith.constant 1 : index
+# DUMPIR:         scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] {
+# DUMPIR:           %[[C1_MBAR:.*]] = arith.constant 1 : index
+# DUMPIR:           nvgpu.mbarrier.init %[[MBAR]][%arg15], %[[C1_MBAR]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:         }
+# DUMPIR:         nvgpu.tma.prefetch.descriptor %[[TMAA]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:         nvgpu.tma.prefetch.descriptor %[[TMAB]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:       }
+# DUMPIR:       %[[C0_PROLOGUE:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C6:.*]] = arith.constant 6 : index
+# DUMPIR:       %[[C1_PROLOGUE:.*]] = arith.constant 1 : index
+# DUMPIR:       scf.for %arg15 = %[[C0_PROLOGUE]] to %[[C6]] step %[[C1_PROLOGUE]] {
+# DUMPIR:         %[[BID_X_P:.*]] = gpu.block_id  x
+# DUMPIR:         %[[BID_Y_P:.*]] = gpu.block_id  y
+# DUMPIR:         %[[C128_P1:.*]] = arith.constant 128 : index
+# DUMPIR:         %[[DIMX_P:.*]] = arith.muli %[[BID_X_P]], %[[C128_P1]] : index
+# DUMPIR:         %[[C128_P2:.*]] = arith.constant 128 : index
+# DUMPIR:         %[[DIMY_P:.*]] = arith.muli %[[BID_Y_P]], %[[C128_P2]] : index
+# DUMPIR:         %{{.*}} = gpu.thread_id  x
+# DUMPIR:         %[[TID_X_P:.*]] = gpu.thread_id  x
+# DUMPIR:         %[[C0_P:.*]] = arith.constant 0 : index
+# DUMPIR:         %[[PRED_P:.*]] = arith.cmpi eq, %[[TID_X_P]], %[[C0_P]] : index
+# DUMPIR:         %[[C16384_P1:.*]] = arith.constant 16384 : index
+# DUMPIR:         %[[OFF_A_P:.*]] = arith.muli %arg15, %[[C16384_P1]] : index
+# DUMPIR:         %[[C16384_P2:.*]] = arith.constant 16384 : index
+# DUMPIR:         %[[OFF_B_BASE_P:.*]] = arith.muli %arg15, %[[C16384_P2]] : index
+# DUMPIR:         %[[C114688:.*]] = arith.constant 114688 : index
+# DUMPIR:         %[[OFF_B1_P:.*]] = arith.addi %[[OFF_B_BASE_P]], %[[C114688]] : index
+# DUMPIR:         %[[C8192:.*]] = arith.constant 8192 : index
+# DUMPIR:         %[[OFF_B2_P:.*]] = arith.addi %[[OFF_B1_P]], %[[C8192]] : index
+# DUMPIR:         %[[SMEM_A_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[VIEW_A_P:.*]] = memref.view %[[SMEM_A_P]][%[[OFF_A_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[SMEM_B1_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[VIEW_B1_P:.*]] = memref.view %[[SMEM_B1_P]][%[[OFF_B1_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[SMEM_B2_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[VIEW_B2_P:.*]] = memref.view %[[SMEM_B2_P]][%[[OFF_B2_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[C32768:.*]] = arith.constant 32768 : index
+# DUMPIR:         nvgpu.mbarrier.arrive.expect_tx %[[MBAR]][%arg15], %[[C32768]], predicate = %[[PRED_P]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:         %[[C64_K_P:.*]] = arith.constant 64 : index
+# DUMPIR:         %[[K_COORD_P:.*]] = arith.muli %arg15, %[[C64_K_P]] : index
+# DUMPIR:         nvgpu.tma.async.load %[[TMAA]][%[[K_COORD_P]], %[[DIMX_P]]], %[[MBAR]][%arg15] to %[[VIEW_A_P]], predicate = %[[PRED_P]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         nvgpu.tma.async.load %[[TMAB]][%[[DIMY_P]], %[[K_COORD_P]]], %[[MBAR]][%arg15] to %[[VIEW_B1_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[C64_OFF:.*]] = arith.constant 64 : index
+# DUMPIR:         %[[DIMY_P_OFF:.*]] = arith.addi %[[DIMY_P]], %[[C64_OFF]] : index
+# DUMPIR:         nvgpu.tma.async.load %[[TMAB]][%[[DIMY_P_OFF]], %[[K_COORD_P]]], %[[MBAR]][%arg15] to %[[VIEW_B2_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:       }
+# DUMPIR:       %[[TID_X_LOOP:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+# DUMPIR:       %[[FALSE_LOOP:.*]] = arith.constant false
+# DUMPIR:       %[[C0_LOOP:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C16_LOOP:.*]] = arith.constant 16 : index
+# DUMPIR:       %[[C1_LOOP:.*]] = arith.constant 1 : index
+# DUMPIR:       %[[LOOP_RES:.*]]:2 = scf.for %arg15 = %[[C0_LOOP]] to %[[C16_LOOP]] step %[[C1_LOOP]] iter_args(%arg16 = %[[ACC_INIT]], %arg17 = %[[FALSE_LOOP]]) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) {
+# DUMPIR:         %[[C7_L:.*]] = arith.constant 7 : index
+# DUMPIR:         %[[STAGE_L:.*]] = arith.remui %arg15, %[[C7_L]] : index
+# DUMPIR:         %[[C10M:.*]] = arith.constant 10000000 : index
+# DUMPIR:         nvgpu.mbarrier.try_wait.parity %[[MBAR]][%[[STAGE_L]]], %arg17, %[[C10M]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:         %[[C16384_L:.*]] = arith.constant 16384 : index
+# DUMPIR:         %[[OFF_A_L:.*]] = arith.muli %[[STAGE_L]], %[[C16384_L]] : index
+# DUMPIR:         %[[C114688_L:.*]] = arith.constant 114688 : index
+# DUMPIR:         %[[OFF_B_L:.*]] = arith.addi %[[OFF_A_L]], %[[C114688_L]] : index
+# DUMPIR:         %[[SMEM_A_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[VIEW_A_L:.*]] = memref.view %[[SMEM_A_L]][%[[OFF_A_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[SMEM_B_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[VIEW_B_L:.*]] = memref.view %[[SMEM_B_L]][%[[OFF_B_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[DESC_A_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_L]], %[[TMAA]] : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+# DUMPIR:         %[[DESC_B_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_L]], %[[TMAB]] : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+# DUMPIR:         %[[ACC_L:.*]] = nvgpu.warpgroup.mma %[[DESC_A_L]], %[[DESC_B_L]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+# DUMPIR:         %[[C6_NEXT:.*]] = arith.constant 6 : index
+# DUMPIR:         %[[ITER_NEXT:.*]] = arith.addi %arg15, %[[C6_NEXT]] : index
+# DUMPIR:         %[[C16_CMP:.*]] = arith.constant 16 : index
+# DUMPIR:         %[[IN_RANGE:.*]] = arith.cmpi ult, %[[ITER_NEXT]], %[[C16_CMP]] : index
+# DUMPIR:         %[[C0_CMP:.*]] = arith.constant 0 : index
+# DUMPIR:         %[[IS_THREAD0_L:.*]] = arith.cmpi eq, %[[TID_X_LOOP]], %[[C0_CMP]] : index
+# DUMPIR:         %[[DO_LOAD:.*]] = arith.andi %[[IN_RANGE]], %[[IS_THREAD0_L]] : i1
+# DUMPIR:         %[[C6_STAGE:.*]] = arith.constant 6 : index
+# DUMPIR:         %[[STAGE_NEXT_L:.*]] = arith.addi %arg15, %[[C6_STAGE]] : index
+# DUMPIR:         %[[C7_MOD:.*]] = arith.constant 7 : index
+# DUMPIR:         %[[STAGE_LOAD:.*]] = arith.remui %[[STAGE_NEXT_L]], %[[C7_MOD]] : index
+# DUMPIR:         %[[BID_X_L:.*]] = gpu.block_id  x
+# DUMPIR:         %[[BID_Y_L:.*]] = gpu.block_id  y
+# DUMPIR:         %[[C128_L1:.*]] = arith.constant 128 : index
+# DUMPIR:         %[[DIMX_L:.*]] = arith.muli %[[BID_X_L]], %[[C128_L1]] : index
+# DUMPIR:         %[[C128_L2:.*]] = arith.constant 128 : index
+# DUMPIR:         %[[DIMY_L:.*]] = arith.muli %[[BID_Y_L]], %[[C128_L2]] : index
+# DUMPIR:         %[[TID_X_L1:.*]] = gpu.thread_id  x
+# DUMPIR:         %[[TID_X_L2:.*]] = gpu.thread_id  x
+# DUMPIR:         %[[C16384_LA1:.*]] = arith.constant 16384 : index
+# DUMPIR:         %[[OFF_A_LOAD:.*]] = arith.muli %[[STAGE_LOAD]], %[[C16384_LA1]] : index
+# DUMPIR:         %[[C16384_LA2:.*]] = arith.constant 16384 : index
+# DUMPIR:         %[[OFF_B_BASE_LOAD:.*]] = arith.muli %[[STAGE_LOAD]], %[[C16384_LA2]] : index
+# DUMPIR:         %[[C114688_LOAD:.*]] = arith.constant 114688 : index
+# DUMPIR:         %[[OFF_B1_LOAD:.*]] = arith.addi %[[OFF_B_BASE_LOAD]], %[[C114688_LOAD]] : index
+# DUMPIR:         %[[C8192_LOAD:.*]] = arith.constant 8192 : index
+# DUMPIR:         %[[OFF_B2_LOAD:.*]] = arith.addi %[[OFF_B1_LOAD]], %[[C8192_LOAD]] : index
+# DUMPIR:         %[[SMEM_A_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[VIEW_A_LOAD:.*]] = memref.view %[[SMEM_A_LOAD]][%[[OFF_A_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[SMEM_B1_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[VIEW_B1_LOAD:.*]] = memref.view %[[SMEM_B1_LOAD]][%[[OFF_B1_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[SMEM_B2_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[VIEW_B2_LOAD:.*]] = memref.view %[[SMEM_B2_LOAD]][%[[OFF_B2_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[C32768_LOAD:.*]] = arith.constant 32768 : index
+# DUMPIR:         nvgpu.mbarrier.arrive.expect_tx %[[MBAR]][%[[STAGE_LOAD]]], %[[C32768_LOAD]], predicate = %[[DO_LOAD]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:         %[[C64_K_LOAD:.*]] = arith.constant 64 : index
+# DUMPIR:         %[[K_COORD_LOAD:.*]] = arith.muli %[[STAGE_NEXT_L]], %[[C64_K_LOAD]] : index
+# DUMPIR:         nvgpu.tma.async.load %[[TMAA]][%[[K_COORD_LOAD]], %[[DIMX_L]]], %[[MBAR]][%[[STAGE_LOAD]]] to %[[VIEW_A_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         nvgpu.tma.async.load %[[TMAB]][%[[DIMY_L]], %[[K_COORD_LOAD]]], %[[MBAR]][%[[STAGE_LOAD]]] to %[[VIEW_B1_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[C64_OFF_LOAD:.*]] = arith.constant 64 : index
+# DUMPIR:         %[[DIMY_L_OFF:.*]] = arith.addi %[[DIMY_L]], %[[C64_OFF_LOAD]] : index
+# DUMPIR:         nvgpu.tma.async.load %[[TMAB]][%[[DIMY_L_OFF]], %[[K_COORD_LOAD]]], %[[MBAR]][%[[STAGE_LOAD]]] to %[[VIEW_B2_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[C6_FLIP:.*]] = arith.constant 6 : index
+# DUMPIR:         %[[IS_STAGE6:.*]] = arith.cmpi eq, %[[STAGE_L]], %[[C6_FLIP]] : index
+# DUMPIR:         %[[TRUE:.*]] = arith.constant true
+# DUMPIR:         %[[PARITY_FLIP:.*]] = arith.xori %arg17, %[[TRUE]] : i1
+# DUMPIR:         %[[NEW_PARITY:.*]] = arith.select %[[IS_STAGE6]], %[[PARITY_FLIP]], %arg17 : i1
+# DUMPIR:         scf.yield %[[ACC_L]], %[[NEW_PARITY]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1
+# DUMPIR:       }
+# DUMPIR:       nvvm.wgmma.wait.group.sync.aligned 0
+# DUMPIR:       %[[TID_X_EPI:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[BID_X_EPI:.*]] = gpu.block_id  x
+# DUMPIR:       %[[BID_Y_EPI:.*]] = gpu.block_id  y
+# DUMPIR:       %[[C128_EPI1:.*]] = arith.constant 128 : index
+# DUMPIR:       %[[DIMX_EPI:.*]] = arith.muli %[[BID_X_EPI]], %[[C128_EPI1]] : index
+# DUMPIR:       %[[C128_EPI2:.*]] = arith.constant 128 : index
+# DUMPIR:       %[[DIMY_EPI:.*]] = arith.muli %[[BID_Y_EPI]], %[[C128_EPI2]] : index
+# DUMPIR:       %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[C0_VIEW:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_VIEW]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[SUBVIEW_EPI:.*]] = memref.subview %[[DMEM]][%[[DIMX_EPI]], %[[DIMY_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
+# DUMPIR:       nvgpu.warpgroup.mma.store %[[LOOP_RES]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       gpu.barrier
+# DUMPIR:       %[[C0_STORE:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[C128_STORE:.*]] = arith.constant 128 : index
+# DUMPIR:       %[[C1_STORE:.*]] = arith.constant 1 : index
+# DUMPIR:       scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] {
+# DUMPIR:         %[[VAL_LOAD:.*]] = memref.load %[[VIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR:         memref.store %[[VAL_LOAD]], %[[SUBVIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>>
+# DUMPIR:       }
+# DUMPIR:       gpu.terminator
+# DUMPIR:     }
+# DUMPIR:     %[[CPYD:.*]] = gpu.memcpy async [%[[WAIT2]]] %arg2, %[[DMEM]] : memref<512x256xf32>, memref<512x256xf32>
+# DUMPIR:     gpu.wait [%[[CPYD]]]
+# DUMPIR:     return
+# DUMPIR:   }
diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
index 91c346c837dda..59e955e8a0f2e 100644
--- a/mlir/test/Examples/NVGPU/Ch5.py
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -1,5 +1,8 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN:   %PYTHON %s | FileCheck %s
+# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
+# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: then %PYTHON %s | FileCheck %s; \
+# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 5 : Warp Specialized GEMM with Tensor Core
@@ -47,6 +50,7 @@
 from tools.nvdsl import *
 import numpy as np
 
+dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
 
 def partition_shape():
     """
@@ -250,7 +254,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
         scf.yield_([])
 
 
-@NVDSL.mlir_func
+@NVDSL.mlir_func(dump_only)
 def gemm_warp_specialized(a, b, d, num_stages):
     token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
@@ -311,11 +315,167 @@ def gemm_warp_specialized_kernel():
 
 gemm_warp_specialized(a, b, d, num_stages=7)
 
+if not dump_only:
+    # Verify MLIR with reference computation
+    ref_d = a.astype(np.float16) @ b.astype(np.float16)
+    np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
 
-# Verify MLIR with reference computation
-ref_d = a.astype(np.float16) @ b.astype(np.float16)
-np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
-
-
-print("PASS")
+    print("PASS")
 # CHECK-NOT: Mismatched elements
+# CHECK: PASS
+
+# DUMPIR: gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c4, %arg10 = %c2, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c256, %arg13 = %c1_7, %arg14 = %c1_8) dynamic_shared_memory_size %c229376_i32 {
+# DUMPIR:       %[[TID_X:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[C128:.*]] = arith.constant 128 : index
+# DUMPIR:       %[[REM1:.*]] = arith.remui %[[TID_X]], %[[C128]] : index
+# DUMPIR:       %[[C0:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[IS_PRIMARY:.*]] = arith.cmpi eq, %[[REM1]], %[[C0]] : index
+# DUMPIR:       %[[C128_1:.*]] = arith.constant 128 : index
+# DUMPIR:       %[[DIV1:.*]] = arith.divui %[[TID_X]], %[[C128_1]] : index
+# DUMPIR:       %[[C1:.*]] = arith.constant 1 : index
+# DUMPIR:       %[[IS_PRODUCER:.*]] = arith.cmpi eq, %[[DIV1]], %[[C1]] : index
+# DUMPIR:       %[[TID_X_2:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[C128_2:.*]] = arith.constant 128 : index
+# DUMPIR:       %[[REM2:.*]] = arith.remui %[[TID_X_2]], %[[C128_2]] : index
+# DUMPIR:       %[[C0_2:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[IS_PRIMARY_2:.*]] = arith.cmpi eq, %[[REM2]], %[[C0_2]] : index
+# DUMPIR:       %[[C128_3:.*]] = arith.constant 128 : index
+# DUMPIR:       %[[DIV2:.*]] = arith.divui %[[TID_X_2]], %[[C128_3]] : index
+# DUMPIR:       %[[C0_3:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[IS_CONSUMER:.*]] = arith.cmpi eq, %[[DIV2]], %[[C0_3]] : index
+# DUMPIR:       %[[TID_X_3:.*]] = gpu.thread_id  x
+# DUMPIR:       %[[MBAR_MMA:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:       %[[MBAR_TMA:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:       %[[C0_4:.*]] = arith.constant 0 : index
+# DUMPIR:       %[[IS_THREAD0:.*]] = arith.cmpi eq, %[[TID_X_3]], %[[C0_4]] : index
+# DUMPIR:       scf.if %[[IS_THREAD0]] {
+# DUMPIR:         %[[C0_INIT:.*]] = arith.constant 0 : index
+# DUMPIR:         %[[C7:.*]] = arith.constant 7 : index
+# DUMPIR:         %[[C1_INIT:.*]] = arith.constant 1 : index
+# DUMPIR:         scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] {
+# DUMPIR:           %[[C1_INIT_VAL:.*]] = arith.constant 1 : index
+# DUMPIR:           nvgpu.mbarrier.init %[[MBAR_MMA]][%arg15], %[[C1_INIT_VAL]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:           %[[C1_INIT_VAL_2:.*]] = arith.constant 1 : index
+# DUMPIR:           nvgpu.mbarrier.init %[[MBAR_TMA]][%arg15], %[[C1_INIT_VAL_2]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:         }
+# DUMPIR:         nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:         nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:       }
+# DUMPIR:       scf.if %[[IS_PRODUCER]] {
+# DUMPIR:         nvvm.setmaxregister  decrease 40
+# DUMPIR:         %[[TRUE:.*]] = arith.constant true
+# DUMPIR:         %[[C0_PROD:.*]] = arith.constant 0 : index
+# DUMPIR:         %[[C16:.*]] = arith.constant 16 : index
+# DUMPIR:         %[[C1_PROD:.*]] = arith.constant 1 : index
+# DUMPIR:         %[[PROD_LOOP:.*]] = scf.for %arg15 = %[[C0_PROD]] to %[[C16]] step %[[C1_PROD]] iter_args(%arg16 = %[[TRUE]]) -> (i1) {
+# DUMPIR:           %[[C7_PROD:.*]] = arith.constant 7 : index
+# DUMPIR:           %[[SLOT:.*]] = arith.remui %arg15, %[[C7_PROD]] : index
+# DUMPIR:           %[[TIMEOUT:.*]] = arith.constant 10000000 : index
+# DUMPIR:           nvgpu.mbarrier.try_wait.parity %[[MBAR_MMA]][%[[SLOT]]], %arg16, %[[TIMEOUT]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:           %[[C6:.*]] = arith.constant 6 : index
+# DUMPIR:           %[[IS_LAST:.*]] = arith.cmpi eq, %[[SLOT]], %[[C6]] : index
+# DUMPIR:           %[[TRUE_2:.*]] = arith.constant true
+# DUMPIR:           %[[FLIP:.*]] = arith.xori %arg16, %[[TRUE_2]] : i1
+# DUMPIR:           %[[PHASE:.*]] = arith.select %[[IS_LAST]], %[[FLIP]], %arg16 : i1
+# DUMPIR:           %[[BID_X:.*]] = gpu.block_id  x
+# DUMPIR:           %[[BID_Y:.*]] = gpu.block_id  y
+# DUMPIR:           %[[C128_TILE:.*]] = arith.constant 128 : index
+# DUMPIR:           %[[DIM_X:.*]] = arith.muli %[[BID_X]], %[[C128_TILE]] : index
+# DUMPIR:           %[[C128_TILE_2:.*]] = arith.constant 128 : index
+# DUMPIR:           %[[DIM_Y:.*]] = arith.muli %[[BID_Y]], %[[C128_TILE_2]] : index
+# DUMPIR:           %[[TID_PROD:.*]] = gpu.thread_id  x
+# DUMPIR:           %[[C16384:.*]] = arith.constant 16384 : index
+# DUMPIR:           %[[OFF_A:.*]] = arith.muli %[[SLOT]], %[[C16384]] : index
+# DUMPIR:           %[[C16384_2:.*]] = arith.constant 16384 : index
+# DUMPIR:           %[[OFF_B_BASE:.*]] = arith.muli %[[SLOT]], %[[C16384_2]] : index
+# DUMPIR:           %[[C114688:.*]] = arith.constant 114688 : index
+# DUMPIR:           %[[OFF_B1:.*]] = arith.addi %[[OFF_B_BASE]], %[[C114688]] : index
+# DUMPIR:           %[[C8192:.*]] = arith.constant 8192 : index
+# DUMPIR:           %[[OFF_B2:.*]] = arith.addi %[[OFF_B1]], %[[C8192]] : index
+# DUMPIR:           %[[SMEM:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[VIEW_A:.*]] = memref.view %[[SMEM]][%[[OFF_A]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[SMEM_2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[VIEW_B1:.*]] = memref.view %[[SMEM_2]][%[[OFF_B1]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[SMEM_3:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[VIEW_B2:.*]] = memref.view %[[SMEM_3]][%[[OFF_B2]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[TX_COUNT:.*]] = arith.constant 32768 : index
+# DUMPIR:           nvgpu.mbarrier.arrive.expect_tx %[[MBAR_TMA]][%[[SLOT]]], %[[TX_COUNT]], predicate = %[[IS_PRIMARY]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:           %[[C128_WG:.*]] = arith.constant 128 : index
+# DUMPIR:           %[[TID_MOD:.*]] = arith.remui %[[TID_PROD]], %[[C128_WG]] : index
+# DUMPIR:           %[[C0_TMA:.*]] = arith.constant 0 : index
+# DUMPIR:           %[[IS_TMA_THREAD:.*]] = arith.cmpi eq, %[[TID_MOD]], %[[C0_TMA]] : index
+# DUMPIR:           %[[C64:.*]] = arith.constant 64 : index
+# DUMPIR:           %[[K_COORD:.*]] = arith.muli %arg15, %[[C64]] : index
+# DUMPIR:           nvgpu.tma.async.load %{{.*}}[%[[K_COORD]], %[[DIM_X]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_A]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:           nvgpu.tma.async.load %{{.*}}[%[[DIM_Y]], %[[K_COORD]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_B1]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[C64_OFF:.*]] = arith.constant 64 : index
+# DUMPIR:           %[[DIM_Y_OFF:.*]] = arith.addi %[[DIM_Y]], %[[C64_OFF]] : index
+# DUMPIR:           nvgpu.tma.async.load %{{.*}}[%[[DIM_Y_OFF]], %[[K_COORD]]], %[[MBAR_TMA]][%[[SLOT]]] to %[[VIEW_B2]], predicate = %[[IS_TMA_THREAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:           scf.yield %[[PHASE]] : i1
+# DUMPIR:         }
+# DUMPIR:       }
+# DUMPIR:       scf.if %[[IS_CONSUMER]] {
+# DUMPIR:         nvvm.setmaxregister  increase 232
+# DUMPIR:         %[[FALSE:.*]] = arith.constant false
+# DUMPIR:         %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
+# DUMPIR:         %[[C0_CONS:.*]] = arith.constant 0 : index
+# DUMPIR:         %[[C16_CONS:.*]] = arith.constant 16 : index
+# DUMPIR:         %[[C1_CONS:.*]] = arith.constant 1 : index
+# DUMPIR:         %[[CONS_LOOP:.*]]:2 = scf.for %arg15 = %[[C0_CONS]] to %[[C16_CONS]] step %[[C1_CONS]] iter_args(%arg16 = %[[ACC_INIT]], %arg17 = %[[FALSE]]) -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1) {
+# DUMPIR:           %[[C7_CONS:.*]] = arith.constant 7 : index
+# DUMPIR:           %[[SLOT_CONS:.*]] = arith.remui %arg15, %[[C7_CONS]] : index
+# DUMPIR:           %[[TIMEOUT_CONS:.*]] = arith.constant 10000000 : index
+# DUMPIR:           nvgpu.mbarrier.try_wait.parity %[[MBAR_TMA]][%[[SLOT_CONS]]], %arg17, %[[TIMEOUT_CONS]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:           %[[C16384_CONS:.*]] = arith.constant 16384 : index
+# DUMPIR:           %[[OFF_A_CONS:.*]] = arith.muli %[[SLOT_CONS]], %[[C16384_CONS]] : index
+# DUMPIR:           %[[C114688_CONS:.*]] = arith.constant 114688 : index
+# DUMPIR:           %[[OFF_B_CONS:.*]] = arith.addi %[[OFF_A_CONS]], %[[C114688_CONS]] : index
+# DUMPIR:           %[[SMEM_CONS:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[VIEW_A_CONS:.*]] = memref.view %[[SMEM_CONS]][%[[OFF_A_CONS]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[SMEM_CONS_2:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[VIEW_B_CONS:.*]] = memref.view %[[SMEM_CONS_2]][%[[OFF_B_CONS]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
+# DUMPIR:           %[[DESC_A:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_CONS]], %{{.*}} : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+# DUMPIR:           %[[DESC_B:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_CONS]], %{{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+# DUMPIR:           %[[ACC:.*]] = nvgpu.warpgroup.mma %[[DESC_A]], %[[DESC_B]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
+# DUMPIR:           %[[C0_CMP:.*]] = arith.constant 0 : index
+# DUMPIR:           %[[IS_NOT_FIRST:.*]] = arith.cmpi ugt, %arg15, %[[C0_CMP]] : index
+# DUMPIR:           %[[ARRIVE_PRED:.*]] = arith.andi %[[IS_NOT_FIRST]], %[[IS_PRIMARY_2]] : i1
+# DUMPIR:           scf.if %[[ARRIVE_PRED]] {
+# DUMPIR:             %[[C0_ARR:.*]] = arith.constant 0 : index
+# DUMPIR:             %[[IS_ZERO:.*]] = arith.cmpi eq, %[[SLOT_CONS]], %[[C0_ARR]] : index
+# DUMPIR:             %[[C6_WRAP:.*]] = arith.constant 6 : index
+# DUMPIR:             %[[C1_SUB:.*]] = arith.constant 1 : index
+# DUMPIR:             %[[PREV_SLOT:.*]] = arith.subi %[[SLOT_CONS]], %[[C1_SUB]] : index
+# DUMPIR:             %[[BARR_ID:.*]] = arith.select %[[IS_ZERO]], %[[C6_WRAP]], %[[PREV_SLOT]] : index
+# DUMPIR:             %{{.*}} = nvgpu.mbarrier.arrive %[[MBAR_MMA]][%[[BARR_ID]]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> !nvgpu.mbarrier.token
+# DUMPIR:           }
+# DUMPIR:           %[[C6_LAST:.*]] = arith.constant 6 : index
+# DUMPIR:           %[[IS_LAST_CONS:.*]] = arith.cmpi eq, %[[SLOT_CONS]], %[[C6_LAST]] : index
+# DUMPIR:           %[[TRUE_CONS:.*]] = arith.constant true
+# DUMPIR:           %[[FLIP_CONS:.*]] = arith.xori %arg17, %[[TRUE_CONS]] : i1
+# DUMPIR:           %[[PHASE_CONS:.*]] = arith.select %[[IS_LAST_CONS]], %[[FLIP_CONS]], %arg17 : i1
+# DUMPIR:           scf.yield %[[ACC]], %[[PHASE_CONS]] : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>, i1
+# DUMPIR:         }
+# DUMPIR:         nvvm.wgmma.wait.group.sync.aligned 0
+# DUMPIR:         %[[TID_EPI:.*]] = gpu.thread_id  x
+# DUMPIR:         %[[BID_X_EPI:.*]] = gpu.block_id  x
+# DUMPIR:         %[[BID_Y_EPI:.*]] = gpu.block_id  y
+# DUMPIR:         %[[C128_EPI:.*]] = arith.constant 128 : index
+# DUMPIR:         %[[DIM_X_EPI:.*]] = arith.muli %[[BID_X_EPI]], %[[C128_EPI]] : index
+# DUMPIR:         %[[C128_EPI_2:.*]] = arith.constant 128 : index
+# DUMPIR:         %[[DIM_Y_EPI:.*]] = arith.muli %[[BID_Y_EPI]], %[[C128_EPI_2]] : index
+# DUMPIR:         %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[C0_EPI:.*]] = arith.constant 0 : index
+# DUMPIR:         %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_EPI]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR:         %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%[[DIM_X_EPI]], %[[DIM_Y_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
+# DUMPIR:         nvgpu.warpgroup.mma.store %[[CONS_LOOP]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR:         gpu.barrier
+# DUMPIR:         %[[C0_STORE:.*]] = arith.constant 0 : index
+# DUMPIR:         %[[C128_STORE:.*]] = arith.constant 128 : index
+# DUMPIR:         %[[C1_STORE:.*]] = arith.constant 1 : index
+# DUMPIR:         scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] {
+# DUMPIR:           %{{.*}} = memref.load %[[VIEW_EPI]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>>
+# DUMPIR:           memref.store %{{.*}}, %[[SUBVIEW]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>>
+# DUMPIR:         }
+# DUMPIR:       }
+# DUMPIR:       gpu.terminator
\ No newline at end of file
diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
index ab4e37fdfa9b7..4e0a10308095f 100644
--- a/mlir/test/Examples/NVGPU/tools/nvdsl.py
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -327,130 +327,136 @@ def wrapper(*args, **kwargs):
         return decorator
 
     @staticmethod
-    def mlir_func(funcBody):
-        @functools.wraps(funcBody)
-        def wrapper(*args, **kwargs):
-            function_name = funcBody.__name__
-
-            def saveIR(module):
-                """Save generated IR"""
-                if True:  # self.saveIR:
-                    # print(mlir_nvgpu_module)
+    def mlir_func(dump_only=False):
+        def decorator(funcBody):
+            @functools.wraps(funcBody)
+            def wrapper(*args, **kwargs):
+                function_name = funcBody.__name__
+
+                def saveIR(module):
+                    """Save generated IR"""
                     original_stdout = sys.stdout
                     with open("nvdsl.mlir", "w") as f:
                         sys.stdout = f
                         print(module)
                         sys.stdout = original_stdout
 
-            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
-                """Generate MLIR's Arith dialects binary operations."""
-                rhs = const(rhs)
-                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
-                    op += "F"
-                    if op.startswith("Cmp"):
-                        predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
-                            predAtt
-                        ]
-                elif arith._is_integer_like_type(
-                    lhs.type
-                ) and arith._is_integer_like_type(lhs.type):
-                    if op == "Div" or op == "Rem":
-                        op += "U"
-                    op += "I"
-                    if op.startswith("Cmp"):
-                        predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
-                            predAtt
-                        ]
-                else:
-                    raise NotImplementedError(
-                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
-                    )
+                def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
+                    """Generate MLIR's Arith dialects binary operations."""
+                    rhs = const(rhs)
+                    if arith._is_float_type(lhs.type) and arith._is_float_type(
+                        rhs.type
+                    ):
+                        op += "F"
+                        if op.startswith("Cmp"):
+                            predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
+                                predAtt
+                            ]
+                    elif arith._is_integer_like_type(
+                        lhs.type
+                    ) and arith._is_integer_like_type(lhs.type):
+                        if op == "Div" or op == "Rem":
+                            op += "U"
+                        op += "I"
+                        if op.startswith("Cmp"):
+                            predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
+                                predAtt
+                            ]
+                    else:
+                        raise NotImplementedError(
+                            f"Unsupported '{op}' operands: {lhs}, {rhs}"
+                        )
 
-                if op.startswith("Cmp"):
-                    op = getattr(arith, f"{op}Op")
-
-                    return op(predicateAttr, lhs, rhs).result
-                else:
-                    op = getattr(arith, f"{op}Op")
-                    return op(lhs, rhs).result
-
-            @ir.register_value_caster(ir.IndexType.static_typeid)
-            @ir.register_value_caster(ir.F32Type.static_typeid)
-            @ir.register_value_caster(ir.F16Type.static_typeid)
-            @ir.register_value_caster(ir.F64Type.static_typeid)
-            @ir.register_value_caster(ir.IntegerType.static_typeid)
-            class ArithValue(ir.Value):
-                """Overloads operators for MLIR's Arith dialects binary operations."""
-
-                def __init__(self, v):
-                    super().__init__(v)
-
-                __add__ = partialmethod(_binary_op, op="Add")
-                __sub__ = partialmethod(_binary_op, op="Sub")
-                __mul__ = partialmethod(_binary_op, op="Mul")
-                __truediv__ = partialmethod(_binary_op, op="Div")
-                __mod__ = partialmethod(_binary_op, op="Rem")
-                __xor__ = partialmethod(_binary_op, op="XOr")
-                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
-                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
-                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
-                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
-                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
-                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
-                __and__ = partialmethod(_binary_op, op="And")
-                __or__ = partialmethod(_binary_op, op="Or")
-
-                def __str__(self):
-                    return (
-                        super()
-                        .__str__()
-                        .replace(ir.Value.__name__, ArithValue.__name__)
+                    if op.startswith("Cmp"):
+                        op = getattr(arith, f"{op}Op")
+
+                        return op(predicateAttr, lhs, rhs).result
+                    else:
+                        op = getattr(arith, f"{op}Op")
+                        return op(lhs, rhs).result
+
+                @ir.register_value_caster(ir.IndexType.static_typeid)
+                @ir.register_value_caster(ir.F32Type.static_typeid)
+                @ir.register_value_caster(ir.F16Type.static_typeid)
+                @ir.register_value_caster(ir.F64Type.static_typeid)
+                @ir.register_value_caster(ir.IntegerType.static_typeid)
+                class ArithValue(ir.Value):
+                    """Overloads operators for MLIR's Arith dialects binary operations."""
+
+                    def __init__(self, v):
+                        super().__init__(v)
+
+                    __add__ = partialmethod(_binary_op, op="Add")
+                    __sub__ = partialmethod(_binary_op, op="Sub")
+                    __mul__ = partialmethod(_binary_op, op="Mul")
+                    __truediv__ = partialmethod(_binary_op, op="Div")
+                    __mod__ = partialmethod(_binary_op, op="Rem")
+                    __xor__ = partialmethod(_binary_op, op="XOr")
+                    __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
+                    __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
+                    __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
+                    __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
+                    __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
+                    __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
+                    __and__ = partialmethod(_binary_op, op="And")
+                    __or__ = partialmethod(_binary_op, op="Or")
+
+                    def __str__(self):
+                        return (
+                            super()
+                            .__str__()
+                            .replace(ir.Value.__name__, ArithValue.__name__)
+                        )
+
+                # Generate MLIR Context and start generating IR
+                with ir.Context(), ir.Location.unknown():
+                    types = []
+                    for arg in args:
+                        types.append(get_mlir_ty(arg))
+
+                    # Build IR
+                    module = ir.Module.create()
+                    with ir.InsertionPoint(module.body):
+                        fop = func.FuncOp(function_name, (types, []))
+                        fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+                        with ir.InsertionPoint(fop.add_entry_block()):
+                            fargs = []
+                            for i, a in enumerate(types):
+                                fargs.append(fop.arguments[i])
+
+                            # Call user function body
+                            result = funcBody(*fargs, **kwargs)
+                            func.ReturnOp([])
+
+                    # Save IR in a file
+                    # saveIR(module)
+                    if dump_only:
+                        print(module)
+                        return 0
+
+                    # Verify the module
+                    module.operation.verify()
+
+                    # Compile and JIT MLIR module
+                    options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+                    support_lib = os.getenv("SUPPORT_LIB")
+                    if not os.path.exists(support_lib):
+                        raise FileNotFoundError(
+                            errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+                        )
+                    compiler = nvgpucompiler.NvgpuCompiler(
+                        options, opt_level=3, shared_libs=[support_lib]
                     )
+                    engine = compiler.compile_and_jit(module)
 
-            # Generate MLIR Context and start generating IR
-            with ir.Context(), ir.Location.unknown():
-                types = []
-                for arg in args:
-                    types.append(get_mlir_ty(arg))
-
-                # Build IR
-                module = ir.Module.create()
-                with ir.InsertionPoint(module.body):
-                    fop = func.FuncOp(function_name, (types, []))
-                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
-                    with ir.InsertionPoint(fop.add_entry_block()):
-                        fargs = []
-                        for i, a in enumerate(types):
-                            fargs.append(fop.arguments[i])
-
-                        # Call user function body
-                        result = funcBody(*fargs, **kwargs)
-                        func.ReturnOp([])
-
-                # Save IR in a file
-                # saveIR(module)
-
-                # Verify the module
-                module.operation.verify()
-
-                # Compile and JIT MLIR module
-                options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
-                support_lib = os.getenv("SUPPORT_LIB")
-                if not os.path.exists(support_lib):
-                    raise FileNotFoundError(
-                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
-                    )
-                compiler = nvgpucompiler.NvgpuCompiler(
-                    options, opt_level=3, shared_libs=[support_lib]
-                )
-                engine = compiler.compile_and_jit(module)
+                # Convert input arguments to MLIR arguments
+                newArgs = get_mlir_func_obj_ty(args)
 
-            # Convert input arguments to MLIR arguments
-            newArgs = get_mlir_func_obj_ty(args)
+                # Run the compiled program
+                engine.invoke(function_name, *newArgs)
 
-            # Run the compiled program
-            engine.invoke(function_name, *newArgs)
+                return result
 
-            return result
+            return wrapper
 
-        return wrapper
+        return decorator

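The hunk above turns `mlir_func` into a decorator factory: `mlir_func(dump_only=False)` now returns the real decorator, which is why the whole body shifts one indentation level to the right and the call sites use `@NVDSL.mlir_func(dump_only)`. A minimal standalone sketch of that pattern (not the nvdsl implementation; the dump-only branch is reduced to a print for illustration):

    import functools


    class NVDSL:
        @staticmethod
        def mlir_func(dump_only=False):
            # Factory: calling mlir_func(...) returns the actual decorator.
            def decorator(funcBody):
                @functools.wraps(funcBody)
                def wrapper(*args, **kwargs):
                    if dump_only:
                        # Stand-in for "print(module); return 0" in the real helper.
                        print(f"dump-only run of {funcBody.__name__}")
                        return 0
                    return funcBody(*args, **kwargs)

                return wrapper

            return decorator


    @NVDSL.mlir_func(dump_only=True)
    def main(alpha):
        return alpha + 1


    main(100)  # prints the dump-only message instead of executing the body
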
>From 78f92c822d8ccc40fed76a25e0b62547f0f3b982 Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Tue, 28 Oct 2025 15:33:32 +0100
Subject: [PATCH 4/6] address review, remove dump ir from decorator

---
 mlir/test/Examples/NVGPU/Ch0.py         |  11 +-
 mlir/test/Examples/NVGPU/Ch1.py         |  13 +-
 mlir/test/Examples/NVGPU/Ch2.py         |  13 +-
 mlir/test/Examples/NVGPU/Ch3.py         |  13 +-
 mlir/test/Examples/NVGPU/Ch4.py         |  13 +-
 mlir/test/Examples/NVGPU/Ch5.py         |  13 +-
 mlir/test/Examples/NVGPU/lit.local.cfg  |   2 +-
 mlir/test/Examples/NVGPU/tools/nvdsl.py | 249 ++++++++++++------------
 8 files changed, 160 insertions(+), 167 deletions(-)

diff --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py
index 0caab36ee28fc..d5d2dc65f7dfd 100644
--- a/mlir/test/Examples/NVGPU/Ch0.py
+++ b/mlir/test/Examples/NVGPU/Ch0.py
@@ -1,8 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
-# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
 # RUN: then %PYTHON %s | FileCheck %s; \
-# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 0 : Hello World
@@ -21,12 +22,10 @@
 from tools.nvdsl import *
 
 
-dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
-
 # 1. The decorator generates a MLIR func.func.
 # Everything inside the Python function becomes the body of the func.
 # The decorator also translates `alpha` to an `index` type.
-@NVDSL.mlir_func(dump_only)
+@NVDSL.mlir_func
 def main(alpha):
     # 2. The decorator generates a MLIR gpu.launch.
     # Everything inside the Python function becomes the body of the gpu.launch.
diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py
index 9fa7d82ae6688..8c162ec85d0d0 100644
--- a/mlir/test/Examples/NVGPU/Ch1.py
+++ b/mlir/test/Examples/NVGPU/Ch1.py
@@ -1,8 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
-# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
 # RUN: then %PYTHON %s | FileCheck %s; \
-# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 1 : 2D Saxpy
@@ -22,9 +23,9 @@
 from tools.nvdsl import *
 import numpy as np
 
-dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
 
-@NVDSL.mlir_func(dump_only)
+
+@NVDSL.mlir_func
 def saxpy(x, y, alpha):
     # 1. Use MLIR GPU dialect to allocate and copy memory
     token_ty = gpu.AsyncTokenType.get()
@@ -63,7 +64,7 @@ def saxpy_kernel():
 
 saxpy(x, y, alpha)
 
-if not dump_only:
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
     # 4. Verify MLIR with reference computation
     ref = np.ones((M, N), np.float32)
     ref += x * alpha
diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py
index 9d35833027e9f..d684db460f9d9 100644
--- a/mlir/test/Examples/NVGPU/Ch2.py
+++ b/mlir/test/Examples/NVGPU/Ch2.py
@@ -1,8 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
-# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
 # RUN: then %PYTHON %s | FileCheck %s; \
-# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 2 : 2D Saxpy with TMA
@@ -27,9 +28,7 @@
 from mlir.extras import types as T
 import numpy as np
 
-dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
-
-@NVDSL.mlir_func(dump_only)
+@NVDSL.mlir_func
 def saxpy(x, y, alpha):
     token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
@@ -89,7 +88,7 @@ def saxpy_tma_kernel():
 y = np.ones((M, N), np.float32)
 saxpy(x, y, alpha)
 
-if not dump_only:
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
     #  4. Verify MLIR with reference computation
     ref = np.ones((M, N), np.float32)
     ref += x * alpha
diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py
index 4b44c89eabf2e..8cdd63bb779c6 100644
--- a/mlir/test/Examples/NVGPU/Ch3.py
+++ b/mlir/test/Examples/NVGPU/Ch3.py
@@ -1,8 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
-# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
 # RUN: then %PYTHON %s | FileCheck %s; \
-# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 3 : GEMM 128x128x64 with Tensor Core
@@ -24,8 +25,6 @@
 from mlir.extras import types as T
 import numpy as np
 
-dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
-
 def tma_load(
     mbar_group: Mbarriers,
     a_tma: TMA,
@@ -61,7 +60,7 @@ def tma_load(
     b_tma.load(b2, mbar_group[0], coords=[64, 0], predicate=p)
 
 
-@NVDSL.mlir_func(dump_only)
+@NVDSL.mlir_func
 def gemm_128_128_64(a, b, d):
     token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
@@ -127,7 +126,7 @@ def gemm_tma_kernel():
 d = np.zeros((M, N), np.float32)
 gemm_128_128_64(a, b, d)
 
-if not dump_only:
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
     # Verify MLIR program with reference computation in python
     ref_d = a.astype(np.float16) @ b.astype(np.float16)
     np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py
index fd6b40203f839..cc5c02d16d906 100644
--- a/mlir/test/Examples/NVGPU/Ch4.py
+++ b/mlir/test/Examples/NVGPU/Ch4.py
@@ -1,8 +1,8 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
-# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
 # RUN: then %PYTHON %s | FileCheck %s; \
-# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
 
 
 # ===----------------------------------------------------------------------===//
@@ -51,7 +51,7 @@
 from tools.nvdsl import *
 import numpy as np
 
-dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
+
 
 def partition_shape():
     """
@@ -261,7 +261,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
 #   a -> memref<MxKxf16>
 #   b -> memref<NxKf16>
 #   d -> memref<MxNxf32>
-@NVDSL.mlir_func(dump_only)
+@NVDSL.mlir_func
 def gemm_multistage(a, b, d, num_stages):
     token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
@@ -318,8 +318,7 @@ def gemm_multistage_kernel():
 
 gemm_multistage(a, b, d, num_stages=7)
 
-
-if not dump_only:
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
     # Verify MLIR with reference computation
     ref_d = a.astype(np.float16) @ b.astype(np.float16)
     np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
index 59e955e8a0f2e..78a47f52a4cb8 100644
--- a/mlir/test/Examples/NVGPU/Ch5.py
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -1,8 +1,9 @@
 # RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
-# RUN: env MLIR_RUN_CUDA_SM90_TESTS=%mlir_run_cuda_sm90_tests \
-# RUN: sh -c 'if [[ "$MLIR_RUN_CUDA_SM90_TESTS" == "1" ]]; \
+# RUN: sh -c 'if [[ "%mlir_run_cuda_sm90_tests" == "1" ]]; \
 # RUN: then %PYTHON %s | FileCheck %s; \
-# RUN: else %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+# RUN: else export MLIR_NVDSL_PRINT_IR=1; \
+# RUN: %PYTHON %s | FileCheck %s --check-prefix=DUMPIR; fi'
+
 
 # ===----------------------------------------------------------------------===//
 #  Chapter 5 : Warp Specialized GEMM with Tensor Core
@@ -50,7 +51,7 @@
 from tools.nvdsl import *
 import numpy as np
 
-dump_only = os.getenv("MLIR_RUN_CUDA_SM90_TESTS") != "1"
+
 
 def partition_shape():
     """
@@ -254,7 +255,7 @@ def epilogue(D: WGMMAMatrix, d_dev):
         scf.yield_([])
 
 
-@NVDSL.mlir_func(dump_only)
+@NVDSL.mlir_func
 def gemm_warp_specialized(a, b, d, num_stages):
     token_ty = gpu.AsyncTokenType.get()
     t1 = gpu.wait(token_ty, [])
@@ -315,7 +316,7 @@ def gemm_warp_specialized_kernel():
 
 gemm_warp_specialized(a, b, d, num_stages=7)
 
-if not dump_only:
+if os.getenv("MLIR_NVDSL_PRINT_IR") != "1":
     # Verify MLIR with reference computation
     ref_d = a.astype(np.float16) @ b.astype(np.float16)
     np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
diff --git a/mlir/test/Examples/NVGPU/lit.local.cfg b/mlir/test/Examples/NVGPU/lit.local.cfg
index 689cd252e7a25..cf975dad53db3 100644
--- a/mlir/test/Examples/NVGPU/lit.local.cfg
+++ b/mlir/test/Examples/NVGPU/lit.local.cfg
@@ -1,4 +1,4 @@
 config.unsupported = False
-if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+if not config.enable_cuda_runner:
   config.unsupported = True
   
\ No newline at end of file
diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
index 4e0a10308095f..ed708c616f75f 100644
--- a/mlir/test/Examples/NVGPU/tools/nvdsl.py
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -9,7 +9,7 @@
 from tools import nvgpucompiler
 
 MLIR_DYNAMIC = -9223372036854775808
-
+DUMP_ONLY = os.getenv("MLIR_NVDSL_PRINT_IR") == "1"
 
 def const(value: int, ty=None):
     ty = T.index() if ty is None else ty
@@ -327,136 +327,131 @@ def wrapper(*args, **kwargs):
         return decorator
 
     @staticmethod
-    def mlir_func(dump_only=False):
-        def decorator(funcBody):
-            @functools.wraps(funcBody)
-            def wrapper(*args, **kwargs):
-                function_name = funcBody.__name__
-
-                def saveIR(module):
-                    """Save generated IR"""
-                    original_stdout = sys.stdout
-                    with open("nvdsl.mlir", "w") as f:
-                        sys.stdout = f
-                        print(module)
-                        sys.stdout = original_stdout
-
-                def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
-                    """Generate MLIR's Arith dialects binary operations."""
-                    rhs = const(rhs)
-                    if arith._is_float_type(lhs.type) and arith._is_float_type(
-                        rhs.type
-                    ):
-                        op += "F"
-                        if op.startswith("Cmp"):
-                            predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
-                                predAtt
-                            ]
-                    elif arith._is_integer_like_type(
-                        lhs.type
-                    ) and arith._is_integer_like_type(lhs.type):
-                        if op == "Div" or op == "Rem":
-                            op += "U"
-                        op += "I"
-                        if op.startswith("Cmp"):
-                            predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
-                                predAtt
-                            ]
-                    else:
-                        raise NotImplementedError(
-                            f"Unsupported '{op}' operands: {lhs}, {rhs}"
-                        )
-
+    def mlir_func(funcBody):
+        @functools.wraps(funcBody)
+        def wrapper(*args, **kwargs):
+            function_name = funcBody.__name__
+
+            def saveIR(module):
+                """Save generated IR"""
+                original_stdout = sys.stdout
+                with open("nvdsl.mlir", "w") as f:
+                    sys.stdout = f
+                    print(module)
+                    sys.stdout = original_stdout
+
+            def _binary_op(lhs, rhs, op: str, predAtt="") -> "ArithValue":
+                """Generate MLIR's Arith dialects binary operations."""
+                rhs = const(rhs)
+                if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
+                    op += "F"
                     if op.startswith("Cmp"):
-                        op = getattr(arith, f"{op}Op")
-
-                        return op(predicateAttr, lhs, rhs).result
-                    else:
-                        op = getattr(arith, f"{op}Op")
-                        return op(lhs, rhs).result
-
-                @ir.register_value_caster(ir.IndexType.static_typeid)
-                @ir.register_value_caster(ir.F32Type.static_typeid)
-                @ir.register_value_caster(ir.F16Type.static_typeid)
-                @ir.register_value_caster(ir.F64Type.static_typeid)
-                @ir.register_value_caster(ir.IntegerType.static_typeid)
-                class ArithValue(ir.Value):
-                    """Overloads operators for MLIR's Arith dialects binary operations."""
-
-                    def __init__(self, v):
-                        super().__init__(v)
-
-                    __add__ = partialmethod(_binary_op, op="Add")
-                    __sub__ = partialmethod(_binary_op, op="Sub")
-                    __mul__ = partialmethod(_binary_op, op="Mul")
-                    __truediv__ = partialmethod(_binary_op, op="Div")
-                    __mod__ = partialmethod(_binary_op, op="Rem")
-                    __xor__ = partialmethod(_binary_op, op="XOr")
-                    __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
-                    __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
-                    __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
-                    __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
-                    __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
-                    __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
-                    __and__ = partialmethod(_binary_op, op="And")
-                    __or__ = partialmethod(_binary_op, op="Or")
-
-                    def __str__(self):
-                        return (
-                            super()
-                            .__str__()
-                            .replace(ir.Value.__name__, ArithValue.__name__)
-                        )
-
-                # Generate MLIR Context and start generating IR
-                with ir.Context(), ir.Location.unknown():
-                    types = []
-                    for arg in args:
-                        types.append(get_mlir_ty(arg))
-
-                    # Build IR
-                    module = ir.Module.create()
-                    with ir.InsertionPoint(module.body):
-                        fop = func.FuncOp(function_name, (types, []))
-                        fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
-                        with ir.InsertionPoint(fop.add_entry_block()):
-                            fargs = []
-                            for i, a in enumerate(types):
-                                fargs.append(fop.arguments[i])
-
-                            # Call user function body
-                            result = funcBody(*fargs, **kwargs)
-                            func.ReturnOp([])
-
-                    # Save IR in a file
-                    # saveIR(module)
-                    if dump_only:
-                        print(module)
-                        return 0
-
-                    # Verify the module
-                    module.operation.verify()
-
-                    # Compile and JIT MLIR module
-                    options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
-                    support_lib = os.getenv("SUPPORT_LIB")
-                    if not os.path.exists(support_lib):
-                        raise FileNotFoundError(
-                            errno.ENOENT, os.strerror(errno.ENOENT), support_lib
-                        )
-                    compiler = nvgpucompiler.NvgpuCompiler(
-                        options, opt_level=3, shared_libs=[support_lib]
+                        predicateAttr = getattr(arith, f"CmpFPredicate").__dict__[
+                            predAtt
+                        ]
+                elif arith._is_integer_like_type(
+                    lhs.type
+                ) and arith._is_integer_like_type(lhs.type):
+                    if op == "Div" or op == "Rem":
+                        op += "U"
+                    op += "I"
+                    if op.startswith("Cmp"):
+                        predicateAttr = getattr(arith, f"CmpIPredicate").__dict__[
+                            predAtt
+                        ]
+                else:
+                    raise NotImplementedError(
+                        f"Unsupported '{op}' operands: {lhs}, {rhs}"
+                    )
+
+                if op.startswith("Cmp"):
+                    op = getattr(arith, f"{op}Op")
+
+                    return op(predicateAttr, lhs, rhs).result
+                else:
+                    op = getattr(arith, f"{op}Op")
+                    return op(lhs, rhs).result
+
+            @ir.register_value_caster(ir.IndexType.static_typeid)
+            @ir.register_value_caster(ir.F32Type.static_typeid)
+            @ir.register_value_caster(ir.F16Type.static_typeid)
+            @ir.register_value_caster(ir.F64Type.static_typeid)
+            @ir.register_value_caster(ir.IntegerType.static_typeid)
+            class ArithValue(ir.Value):
+                """Overloads operators for MLIR's Arith dialects binary operations."""
+
+                def __init__(self, v):
+                    super().__init__(v)
+
+                __add__ = partialmethod(_binary_op, op="Add")
+                __sub__ = partialmethod(_binary_op, op="Sub")
+                __mul__ = partialmethod(_binary_op, op="Mul")
+                __truediv__ = partialmethod(_binary_op, op="Div")
+                __mod__ = partialmethod(_binary_op, op="Rem")
+                __xor__ = partialmethod(_binary_op, op="XOr")
+                __lt__ = partialmethod(_binary_op, op="Cmp", predAtt="ult")
+                __le__ = partialmethod(_binary_op, op="Cmp", predAtt="ule")
+                __eq__ = partialmethod(_binary_op, op="Cmp", predAtt="eq")
+                __ne__ = partialmethod(_binary_op, op="Cmp", predAtt="ne")
+                __gt__ = partialmethod(_binary_op, op="Cmp", predAtt="ugt")
+                __ge__ = partialmethod(_binary_op, op="Cmp", predAtt="uge")
+                __and__ = partialmethod(_binary_op, op="And")
+                __or__ = partialmethod(_binary_op, op="Or")
+
+                def __str__(self):
+                    return (
+                        super()
+                        .__str__()
+                        .replace(ir.Value.__name__, ArithValue.__name__)
                     )
-                    engine = compiler.compile_and_jit(module)
 
-                # Convert input arguments to MLIR arguments
-                newArgs = get_mlir_func_obj_ty(args)
+            # Generate MLIR Context and start generating IR
+            with ir.Context(), ir.Location.unknown():
+                types = []
+                for arg in args:
+                    types.append(get_mlir_ty(arg))
+
+                # Build IR
+                module = ir.Module.create()
+                with ir.InsertionPoint(module.body):
+                    fop = func.FuncOp(function_name, (types, []))
+                    fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+                    with ir.InsertionPoint(fop.add_entry_block()):
+                        fargs = []
+                        for i, a in enumerate(types):
+                            fargs.append(fop.arguments[i])
+
+                        # Call user function body
+                        result = funcBody(*fargs, **kwargs)
+                        func.ReturnOp([])
+
+                # Save IR in a file
+                # saveIR(module)
+                if DUMP_ONLY:
+                    print(module)
+                    return 0
+
+                # Verify the module
+                module.operation.verify()
+
+                # Compile and JIT MLIR module
+                options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+                support_lib = os.getenv("SUPPORT_LIB")
+                if not os.path.exists(support_lib):
+                    raise FileNotFoundError(
+                        errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+                    )
+                compiler = nvgpucompiler.NvgpuCompiler(
+                    options, opt_level=3, shared_libs=[support_lib]
+                )
+                engine = compiler.compile_and_jit(module)
 
-                # Run the compiled program
-                engine.invoke(function_name, *newArgs)
+            # Convert input arguments to MLIR arguments
+            newArgs = get_mlir_func_obj_ty(args)
 
-                return result
+            # Run the compiled program
+            engine.invoke(function_name, *newArgs)
 
-            return wrapper
+            return result
 
-        return decorator
+        return wrapper

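With this patch the decorator argument is gone again and dump-only mode is selected through the MLIR_NVDSL_PRINT_IR environment variable, read once at import time into the module-level DUMP_ONLY flag added above, so the examples return to the plain `@NVDSL.mlir_func` spelling. A rough sketch of the resulting control flow, with the IR building, verification and JIT steps elided:

    import functools
    import os

    DUMP_ONLY = os.getenv("MLIR_NVDSL_PRINT_IR") == "1"


    class NVDSL:
        @staticmethod
        def mlir_func(funcBody):
            # Plain decorator again: applied as @NVDSL.mlir_func, no call needed.
            @functools.wraps(funcBody)
            def wrapper(*args, **kwargs):
                if DUMP_ONLY:
                    # Stand-in for printing the generated MLIR module for FileCheck.
                    print(f"// IR dump for {funcBody.__name__}")
                    return 0
                # Otherwise the real wrapper compiles the module and invokes it.
                return funcBody(*args, **kwargs)

            return wrapper


    @NVDSL.mlir_func
    def main(alpha):
        return alpha + 1


    main(100)

This mirrors the updated RUN lines in the chapter files: on an sm_90a machine the example executes and is checked with the default CHECK prefix, otherwise MLIR_NVDSL_PRINT_IR=1 makes it print the module and FileCheck runs with --check-prefix=DUMPIR.
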
>From b79019663cbbfe8c72ba785101b070e32a23a08a Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Mon, 3 Nov 2025 17:36:09 +0100
Subject: [PATCH 5/6] format python files

---
 mlir/test/Examples/NVGPU/Ch0.py         | 1 +
 mlir/test/Examples/NVGPU/Ch1.py         | 1 -
 mlir/test/Examples/NVGPU/Ch2.py         | 1 +
 mlir/test/Examples/NVGPU/Ch3.py         | 1 +
 mlir/test/Examples/NVGPU/Ch4.py         | 1 -
 mlir/test/Examples/NVGPU/Ch5.py         | 3 +--
 mlir/test/Examples/NVGPU/tools/nvdsl.py | 1 +
 7 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py
index d5d2dc65f7dfd..6d8ff738feac2 100644
--- a/mlir/test/Examples/NVGPU/Ch0.py
+++ b/mlir/test/Examples/NVGPU/Ch0.py
@@ -42,6 +42,7 @@ def kernel():
     # 3. Call the GPU kernel
     kernel()
 
+
 alpha = 100
 # 4. The `mlir_func` decorator JIT compiles the IR and executes the MLIR function.
 main(alpha)
diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py
index 8c162ec85d0d0..0ad740557e8da 100644
--- a/mlir/test/Examples/NVGPU/Ch1.py
+++ b/mlir/test/Examples/NVGPU/Ch1.py
@@ -24,7 +24,6 @@
 import numpy as np
 
 
-
 @NVDSL.mlir_func
 def saxpy(x, y, alpha):
     # 1. Use MLIR GPU dialect to allocate and copy memory
diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py
index d684db460f9d9..18e510221ac42 100644
--- a/mlir/test/Examples/NVGPU/Ch2.py
+++ b/mlir/test/Examples/NVGPU/Ch2.py
@@ -28,6 +28,7 @@
 from mlir.extras import types as T
 import numpy as np
 
+
 @NVDSL.mlir_func
 def saxpy(x, y, alpha):
     token_ty = gpu.AsyncTokenType.get()
diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py
index 8cdd63bb779c6..dbc73c3efc60c 100644
--- a/mlir/test/Examples/NVGPU/Ch3.py
+++ b/mlir/test/Examples/NVGPU/Ch3.py
@@ -25,6 +25,7 @@
 from mlir.extras import types as T
 import numpy as np
 
+
 def tma_load(
     mbar_group: Mbarriers,
     a_tma: TMA,
diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py
index cc5c02d16d906..5039af4714f76 100644
--- a/mlir/test/Examples/NVGPU/Ch4.py
+++ b/mlir/test/Examples/NVGPU/Ch4.py
@@ -52,7 +52,6 @@
 import numpy as np
 
 
-
 def partition_shape():
     """
     Calculate the partition shape based on the block IDs.
diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
index 78a47f52a4cb8..f81e06d13d435 100644
--- a/mlir/test/Examples/NVGPU/Ch5.py
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -52,7 +52,6 @@
 import numpy as np
 
 
-
 def partition_shape():
     """
     Calculate the partition shape based on the block IDs.
@@ -479,4 +478,4 @@ def gemm_warp_specialized_kernel():
 # DUMPIR:           memref.store %{{.*}}, %[[SUBVIEW]][%arg15, %[[TID_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>>
 # DUMPIR:         }
 # DUMPIR:       }
-# DUMPIR:       gpu.terminator
\ No newline at end of file
+# DUMPIR:       gpu.terminator
diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
index ed708c616f75f..82bc26594900c 100644
--- a/mlir/test/Examples/NVGPU/tools/nvdsl.py
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -11,6 +11,7 @@
 MLIR_DYNAMIC = -9223372036854775808
 DUMP_ONLY = os.getenv("MLIR_NVDSL_PRINT_IR") == "1"
 
+
 def const(value: int, ty=None):
     ty = T.index() if ty is None else ty
     if isinstance(value, ir.Value) and (

>From 487e41d99ca1f2b032020893ba590f8ef596981f Mon Sep 17 00:00:00 2001
From: Giacomo Castiglioni <giacastiglioni at gmail.com>
Date: Mon, 10 Nov 2025 15:31:15 +0100
Subject: [PATCH 6/6] fix breakage due to new bindings and reduce lit check

---
 mlir/test/Examples/NVGPU/Ch0.py         |  4 +-
 mlir/test/Examples/NVGPU/Ch1.py         | 33 ++++-------
 mlir/test/Examples/NVGPU/Ch2.py         | 53 ++++--------------
 mlir/test/Examples/NVGPU/Ch3.py         | 34 +++---------
 mlir/test/Examples/NVGPU/Ch4.py         | 74 +++++++------------------
 mlir/test/Examples/NVGPU/Ch5.py         |  7 +--
 mlir/test/Examples/NVGPU/tools/nvdsl.py |  9 +--
 7 files changed, 60 insertions(+), 154 deletions(-)

diff --git a/mlir/test/Examples/NVGPU/Ch0.py b/mlir/test/Examples/NVGPU/Ch0.py
index 6d8ff738feac2..e09720a0f3b75 100644
--- a/mlir/test/Examples/NVGPU/Ch0.py
+++ b/mlir/test/Examples/NVGPU/Ch0.py
@@ -37,7 +37,7 @@ def kernel():
         # + operator generates arith.addi
         myValue = alpha + tidx
         # Print from a GPU thread
-        gpu.printf("GPU thread %llu has %llu\n", [tidx, myValue])
+        gpu.printf("GPU thread %llu has %llu\n", tidx, myValue)
 
     # 3. Call the GPU kernel
     kernel()
@@ -53,13 +53,13 @@ def kernel():
 # CHECK: GPU thread 3 has 103
 
 # DUMPIR:   func.func @main(%arg0: index) attributes {llvm.emit_c_interface} {
+# DUMPIR:     %[[C0_I32:.*]] = arith.constant 0 : i32
 # DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
 # DUMPIR:     %[[C1_0:.*]] = arith.constant 1 : index
 # DUMPIR:     %[[C1_1:.*]] = arith.constant 1 : index
 # DUMPIR:     %[[C4:.*]] = arith.constant 4 : index
 # DUMPIR:     %[[C1_2:.*]] = arith.constant 1 : index
 # DUMPIR:     %[[C1_3:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C0_I32:.*]] = arith.constant 0 : i32
 # DUMPIR:     gpu.launch blocks(%arg1, %arg2, %arg3) in (%arg7 = %[[C1]], %arg8 = %[[C1_0]], %arg9 = %[[C1_1]]) threads(%arg4, %arg5, %arg6) in (%arg10 = %[[C4]], %arg11 = %[[C1_2]], %arg12 = %[[C1_3]]) dynamic_shared_memory_size %[[C0_I32]] {
 # DUMPIR:       %[[TIDX:.*]] = gpu.thread_id  x
 # DUMPIR:       %[[MYVAL:.*]] = arith.addi %arg0, %[[TIDX]] : index
diff --git a/mlir/test/Examples/NVGPU/Ch1.py b/mlir/test/Examples/NVGPU/Ch1.py
index 0ad740557e8da..6e44e4d04fa06 100644
--- a/mlir/test/Examples/NVGPU/Ch1.py
+++ b/mlir/test/Examples/NVGPU/Ch1.py
@@ -28,12 +28,12 @@
 def saxpy(x, y, alpha):
     # 1. Use MLIR GPU dialect to allocate and copy memory
     token_ty = gpu.AsyncTokenType.get()
-    t1 = gpu.wait(token_ty, [])
+    t1 = gpu.wait([])
     x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
     y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
     t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
     t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
-    t6 = gpu.wait(token_ty, [t5])
+    t6 = gpu.wait([t5])
 
     # 2. Compute 2D SAXPY kernel
     @NVDSL.mlir_gpu_launch(grid=(M, 1, 1), block=(N, 1, 1))
@@ -51,7 +51,7 @@ def saxpy_kernel():
     saxpy_kernel()
 
     t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
-    gpu.wait(token_ty, [t7])
+    gpu.wait([t7])
 
 
 # 3. Pass numpy arrays to MLIR
@@ -72,31 +72,20 @@ def saxpy_kernel():
 # CHECK-NOT: Mismatched elements
 # CHECK: PASS
 
-# DUMPIR:   func.func @saxpy(%arg0: memref<256x32xf32>, %arg1: memref<256x32xf32>, %arg2: f32) attributes {llvm.emit_c_interface} {
+# DUMPIR:   func.func @saxpy(%[[ARG0:.*]]: memref<256x32xf32>, %[[ARG1:.*]]: memref<256x32xf32>, %[[ARG2:.*]]: f32) attributes {llvm.emit_c_interface} {
 # DUMPIR:     %[[WAIT0:.*]] = gpu.wait async
 # DUMPIR:     %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32>
 # DUMPIR:     %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32>
-# DUMPIR:     %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %arg0 : memref<256x32xf32>, memref<256x32xf32>
-# DUMPIR:     %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %arg1 : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %[[ARG0]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %[[ARG1]] : memref<256x32xf32>, memref<256x32xf32>
 # DUMPIR:     %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]]
-# DUMPIR:     %[[C256:.*]] = arith.constant 256 : index
-# DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C1_2:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C32:.*]] = arith.constant 32 : index
-# DUMPIR:     %[[C1_3:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C1_4:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C0_I32:.*]] = arith.constant 0 : i32
-# DUMPIR:     gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C256]], %arg10 = %[[C1]], %arg11 = %[[C1_2]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C32]], %arg13 = %[[C1_3]], %arg14 = %[[C1_4]]) dynamic_shared_memory_size %[[C0_I32]] {
-# DUMPIR:       %[[BLOCKID:.*]] = gpu.block_id  x
-# DUMPIR:       %[[THREADID:.*]] = gpu.thread_id  x
-# DUMPIR:       %[[LD0:.*]] = memref.load %[[MEMREF]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
-# DUMPIR:       %[[LD1:.*]] = memref.load %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
-# DUMPIR:       %[[MUL:.*]] = arith.mulf %[[LD0]], %arg2 : f32
+# DUMPIR:       %[[LD0:.*]] = memref.load %[[MEMREF]][%{{.*}}, %{{.*}}] : memref<256x32xf32>
+# DUMPIR:       %[[LD1:.*]] = memref.load %[[MEMREF0]][%{{.*}}, %{{.*}}] : memref<256x32xf32>
+# DUMPIR:       %[[MUL:.*]] = arith.mulf %[[LD0]], %[[ARG2]] : f32
 # DUMPIR:       %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32
-# DUMPIR:       memref.store %[[ADD]], %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
+# DUMPIR:       memref.store %[[ADD]], %[[MEMREF0]][%{{.*}}, %{{.*}}] : memref<256x32xf32>
 # DUMPIR:       gpu.terminator
-# DUMPIR:     }
-# DUMPIR:     %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %arg1, %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %[[ARG1]], %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32>
 # DUMPIR:     %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]]
 # DUMPIR:     return
 # DUMPIR:   }
diff --git a/mlir/test/Examples/NVGPU/Ch2.py b/mlir/test/Examples/NVGPU/Ch2.py
index 18e510221ac42..aba610cee0b34 100644
--- a/mlir/test/Examples/NVGPU/Ch2.py
+++ b/mlir/test/Examples/NVGPU/Ch2.py
@@ -32,12 +32,12 @@
 @NVDSL.mlir_func
 def saxpy(x, y, alpha):
     token_ty = gpu.AsyncTokenType.get()
-    t1 = gpu.wait(token_ty, [])
+    t1 = gpu.wait([])
     x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
     y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
     t4 = gpu.memcpy(token_ty, [t3], x_dev, x)
     t5 = gpu.memcpy(token_ty, [t4], y_dev, y)
-    t6 = gpu.wait(token_ty, [t5])
+    t6 = gpu.wait([t5])
 
     x_tma = TMA([1, N], x.type)
     y_tma = TMA([1, N], y.type)
@@ -78,7 +78,7 @@ def saxpy_tma_kernel():
     saxpy_tma_kernel()
 
     t7 = gpu.memcpy(token_ty, [t6], y, y_dev)
-    gpu.wait(token_ty, [t7])
+    gpu.wait([t7])
 
 
 # 3. Pass numpy arrays to MLIR
@@ -98,33 +98,15 @@ def saxpy_tma_kernel():
 # CHECK-NOT: Mismatched elements
 # CHECK: PASS
 
-# DUMPIR:   func.func @saxpy(%arg0: memref<256x32xf32>, %arg1: memref<256x32xf32>, %arg2: f32) attributes {llvm.emit_c_interface} {
+# DUMPIR:   func.func @saxpy(%{{.*}}: memref<256x32xf32>, %[[ARG1:.*]]: memref<256x32xf32>, %[[ARG2:.*]]: f32) attributes {llvm.emit_c_interface} {
 # DUMPIR:     %[[WAIT0:.*]] = gpu.wait async
 # DUMPIR:     %[[MEMREF:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<256x32xf32>
-# DUMPIR:     %[[MEMREF0:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<256x32xf32>
-# DUMPIR:     %[[MEMCPY1:.*]] = gpu.memcpy async [%[[ASYNC1]]] %[[MEMREF]], %arg0 : memref<256x32xf32>, memref<256x32xf32>
-# DUMPIR:     %[[MEMCPY2:.*]] = gpu.memcpy async [%[[MEMCPY1]]] %[[MEMREF0]], %arg1 : memref<256x32xf32>, memref<256x32xf32>
-# DUMPIR:     %[[WAIT1:.*]] = gpu.wait async [%[[MEMCPY2]]]
 # DUMPIR:     %[[CAST:.*]] = memref.cast %[[MEMREF]] : memref<256x32xf32> to memref<*xf32>
 # DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
 # DUMPIR:     %[[C32:.*]] = arith.constant 32 : index
 # DUMPIR:     %[[TMA0:.*]] = nvgpu.tma.create.descriptor %[[CAST]] box[%[[C1]], %[[C32]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
-# DUMPIR:     %[[CAST2:.*]] = memref.cast %[[MEMREF0]] : memref<256x32xf32> to memref<*xf32>
-# DUMPIR:     %[[C1_3:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C32_4:.*]] = arith.constant 32 : index
-# DUMPIR:     %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST2]] box[%[[C1_3]], %[[C32_4]]] : memref<*xf32> -> <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
-# DUMPIR:     %[[C256:.*]] = arith.constant 256 : index
-# DUMPIR:     %[[C1_5:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C1_6:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C32_7:.*]] = arith.constant 32 : index
-# DUMPIR:     %[[C1_8:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C1_9:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C256_I32:.*]] = arith.constant 256 : i32
-# DUMPIR:     gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C256]], %arg10 = %[[C1_5]], %arg11 = %[[C1_6]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C32_7]], %arg13 = %[[C1_8]], %arg14 = %[[C1_9]]) dynamic_shared_memory_size %[[C256_I32]] {
-# DUMPIR:       %[[BLOCKID:.*]] = gpu.block_id  x
-# DUMPIR:       %[[THREADID:.*]] = gpu.thread_id  x
 # DUMPIR:       %[[C0:.*]] = arith.constant 0 : index
-# DUMPIR:       %[[EQ:.*]] = arith.cmpi eq, %[[THREADID]], %[[C0]] : index
+# DUMPIR:       %[[EQ:.*]] = arith.cmpi eq, %{{.*}}, %[[C0]] : index
 # DUMPIR:       %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
 # DUMPIR:       %[[C0_10:.*]] = arith.constant 0 : index
 # DUMPIR:       %[[C1_11:.*]] = arith.constant 1 : index
@@ -135,29 +117,18 @@ def saxpy_tma_kernel():
 # DUMPIR:       %[[DSM1:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
 # DUMPIR:       %[[C128:.*]] = arith.constant 128 : index
 # DUMPIR:       %[[VIEW_13:.*]] = memref.view %[[DSM1]][%[[C128]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<1x32xf32, #gpu.address_space<workgroup>>
-# DUMPIR:       %[[C0_14:.*]] = arith.constant 0 : index
-# DUMPIR:       %[[C0_15:.*]] = arith.constant 0 : index
-# DUMPIR:       nvgpu.tma.async.load %[[TMA0]][%[[C0_15]], %[[BLOCKID]]], %[[MB]][%[[C0_14]]] to %[[VIEW]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>>
-# DUMPIR:       %[[C0_16:.*]] = arith.constant 0 : index
-# DUMPIR:       %[[C0_17:.*]] = arith.constant 0 : index
-# DUMPIR:       nvgpu.tma.async.load %[[TMA1]][%[[C0_17]], %[[BLOCKID]]], %[[MB]][%[[C0_16]]] to %[[VIEW_13]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>>
-# DUMPIR:       %[[C0_18:.*]] = arith.constant 0 : index
-# DUMPIR:       %[[C256_19:.*]] = arith.constant 256 : index
-# DUMPIR:       nvgpu.mbarrier.arrive.expect_tx %[[MB]][%[[C0_18]]], %[[C256_19]], predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
+# DUMPIR:       nvgpu.tma.async.load %[[TMA0]][%{{.*}}, %{{.*}}], %[[MB]][%{{.*}}] to %[[VIEW]], predicate = %[[EQ]] : <tensor = memref<1x32xf32, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       nvgpu.mbarrier.arrive.expect_tx %[[MB]][%{{.*}}], %{{.*}}, predicate = %[[EQ]] : <memorySpace = #gpu.address_space<workgroup>>
 # DUMPIR:       %[[C0_20:.*]] = arith.constant 0 : index
 # DUMPIR:       %[[C10000000:.*]] = arith.constant 10000000 : index
 # DUMPIR:       %[[FALSE:.*]] = arith.constant false
 # DUMPIR:       nvgpu.mbarrier.try_wait.parity %[[MB]][%[[C0_20]]], %[[FALSE]], %[[C10000000]] : <memorySpace = #gpu.address_space<workgroup>>
 # DUMPIR:       %[[C0_21:.*]] = arith.constant 0 : index
-# DUMPIR:       %[[LD0:.*]] = memref.load %[[VIEW]][%[[C0_21]], %[[THREADID]]] : memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       %[[LD0:.*]] = memref.load %[[VIEW]][%[[C0_21]], %{{.*}}] : memref<1x32xf32, #gpu.address_space<workgroup>>
 # DUMPIR:       %[[C0_22:.*]] = arith.constant 0 : index
-# DUMPIR:       %[[LD1:.*]] = memref.load %[[VIEW_13]][%[[C0_22]], %[[THREADID]]] : memref<1x32xf32, #gpu.address_space<workgroup>>
-# DUMPIR:       %[[MUL:.*]] = arith.mulf %[[LD0]], %arg2 : f32
-# DUMPIR:       %[[ADD:.*]] = arith.addf %[[LD1]], %[[MUL]] : f32
-# DUMPIR:       memref.store %[[ADD]], %[[MEMREF0]][%[[BLOCKID]], %[[THREADID]]] : memref<256x32xf32>
-# DUMPIR:       gpu.terminator
-# DUMPIR:     }
-# DUMPIR:     %[[MEMCPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %arg1, %[[MEMREF0]] : memref<256x32xf32>, memref<256x32xf32>
-# DUMPIR:     %[[WAIT2:.*]] = gpu.wait async [%[[MEMCPY3]]]
+# DUMPIR:       %[[LD1:.*]] = memref.load %[[VIEW_13]][%[[C0_22]], %{{.*}}] : memref<1x32xf32, #gpu.address_space<workgroup>>
+# DUMPIR:       memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<256x32xf32>
+# DUMPIR:     %[[MEMCPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG1]], %{{.*}} : memref<256x32xf32>, memref<256x32xf32>
+# DUMPIR:     %{{.*}} = gpu.wait async [%[[MEMCPY3]]]
 # DUMPIR:     return
 # DUMPIR:   }
diff --git a/mlir/test/Examples/NVGPU/Ch3.py b/mlir/test/Examples/NVGPU/Ch3.py
index dbc73c3efc60c..fe11575416866 100644
--- a/mlir/test/Examples/NVGPU/Ch3.py
+++ b/mlir/test/Examples/NVGPU/Ch3.py
@@ -64,13 +64,13 @@ def tma_load(
 @NVDSL.mlir_func
 def gemm_128_128_64(a, b, d):
     token_ty = gpu.AsyncTokenType.get()
-    t1 = gpu.wait(token_ty, [])
+    t1 = gpu.wait([])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
     d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
     t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
     t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
-    t7 = gpu.wait(token_ty, [t6])
+    t7 = gpu.wait([t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
     a_tma = TMA([128, 64], a.type, swizzle=sw)
@@ -115,7 +115,7 @@ def gemm_tma_kernel():
     gemm_tma_kernel()
 
     t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
-    gpu.wait(None, [t8])
+    gpu.wait([t8])
 
 
 # Python pass arguments to MLIR
@@ -135,30 +135,14 @@ def gemm_tma_kernel():
 # CHECK-NOT: Mismatched elements
 # CHECK: PASS
 
-# DUMPIR:   func.func @gemm_128_128_64(%arg0: memref<128x64xf16>, %arg1: memref<64x128xf16>, %arg2: memref<128x128xf32>) attributes {llvm.emit_c_interface} {
-# DUMPIR:     %[[WAIT0:.*]] = gpu.wait async
-# DUMPIR:     %[[MEM0:.*]], %[[ASYNC0:.*]] = gpu.alloc async [%[[WAIT0]]] () : memref<128x64xf16>
-# DUMPIR:     %[[MEM1:.*]], %[[ASYNC1:.*]] = gpu.alloc async [%[[ASYNC0]]] () : memref<64x128xf16>
-# DUMPIR:     %[[MEM2:.*]], %[[ASYNC2:.*]] = gpu.alloc async [%[[ASYNC1]]] () : memref<128x128xf32>
-# DUMPIR:     %[[CPY1:.*]] = gpu.memcpy async [%[[ASYNC2]]] %[[MEM0]], %arg0 : memref<128x64xf16>, memref<128x64xf16>
-# DUMPIR:     %[[CPY2:.*]] = gpu.memcpy async [%[[CPY1]]] %[[MEM1]], %arg1 : memref<64x128xf16>, memref<64x128xf16>
-# DUMPIR:     %[[WAIT1:.*]] = gpu.wait async [%[[CPY2]]]
-# DUMPIR:     %[[CAST0:.*]] = memref.cast %[[MEM0]] : memref<128x64xf16> to memref<*xf16>
+# DUMPIR:   func.func @gemm_128_128_64(%{{.*}}: memref<128x64xf16>, %{{.*}}: memref<64x128xf16>, %[[ARG2:.*]]: memref<128x128xf32>) attributes {llvm.emit_c_interface} {
 # DUMPIR:     %[[C128:.*]] = arith.constant 128 : index
 # DUMPIR:     %[[C64:.*]] = arith.constant 64 : index
-# DUMPIR:     %[[TMA0:.*]] = nvgpu.tma.create.descriptor %[[CAST0]] box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-# DUMPIR:     %[[CAST1:.*]] = memref.cast %[[MEM1]] : memref<64x128xf16> to memref<*xf16>
+# DUMPIR:     %[[TMA0:.*]] = nvgpu.tma.create.descriptor %{{.*}} box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:     %[[CAST1:.*]] = memref.cast %{{.*}} : memref<64x128xf16> to memref<*xf16>
 # DUMPIR:     %[[C64_5:.*]] = arith.constant 64 : index
 # DUMPIR:     %[[C64_6:.*]] = arith.constant 64 : index
 # DUMPIR:     %[[TMA1:.*]] = nvgpu.tma.create.descriptor %[[CAST1]] box[%[[C64_5]], %[[C64_6]]] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-# DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C1_7:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C1_8:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C128_9:.*]] = arith.constant 128 : index
-# DUMPIR:     %[[C1_10:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C1_11:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C32768_I32:.*]] = arith.constant 32768 : i32
-# DUMPIR:     gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C1]], %arg10 = %[[C1_7]], %arg11 = %[[C1_8]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C128_9]], %arg13 = %[[C1_10]], %arg14 = %[[C1_11]]) dynamic_shared_memory_size %[[C32768_I32]] {
 # DUMPIR:       %[[THREADID:.*]] = gpu.thread_id  x
 # DUMPIR:       %[[MB:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
 # DUMPIR:       %[[C0:.*]] = arith.constant 0 : index
@@ -206,10 +190,10 @@ def gemm_tma_kernel():
 # DUMPIR:       %[[GEN0:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW]], %[[TMA0]] : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
 # DUMPIR:       %[[GEN1:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_15]], %[[TMA1]] : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
 # DUMPIR:       %[[MMA:.*]] = nvgpu.warpgroup.mma %[[GEN0]], %[[GEN1]], %[[WG_ACC]] {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
-# DUMPIR:       nvgpu.warpgroup.mma.store %[[MMA]], %[[MEM2]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32>
+# DUMPIR:       nvgpu.warpgroup.mma.store %[[MMA]], %{{.*}} : <fragmented = vector<128x128xf32>> to memref<128x128xf32>
 # DUMPIR:       gpu.terminator
 # DUMPIR:     }
-# DUMPIR:     %[[CPY3:.*]] = gpu.memcpy async [%[[WAIT1]]] %arg2, %[[MEM2]] : memref<128x128xf32>, memref<128x128xf32>
-# DUMPIR:     gpu.wait [%[[CPY3]]]
+# DUMPIR:     %[[CPY3:.*]] = gpu.memcpy async [%{{.*}}] %[[ARG2]], %{{.*}} : memref<128x128xf32>, memref<128x128xf32>
+# DUMPIR:     gpu.wait async [%[[CPY3]]]
 # DUMPIR:     return
 # DUMPIR:   }
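
As a quick reference, the host-side pattern that the Ch3-Ch5 examples share after the change above looks roughly like the following. This is only an illustrative sketch assembled from the calls shown in the hunk, not an extra file in the patch; a, a_dev, d, and d_dev stand in for whichever buffers a given example allocates.

    # Sketch only (not part of the patch): gpu.wait no longer takes the token
    # type (or None) as its first argument; the async token result is inferred.
    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait([])                      # before: gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    t3 = gpu.memcpy(token_ty, [t2], a_dev, a)
    t4 = gpu.wait([t3])                    # before: gpu.wait(token_ty, [t3])
    # ... launch the kernel, then copy the result back ...
    t5 = gpu.memcpy(token_ty, [t4], d, d_dev)
    gpu.wait([t5])                         # before: gpu.wait(None, [t5])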
diff --git a/mlir/test/Examples/NVGPU/Ch4.py b/mlir/test/Examples/NVGPU/Ch4.py
index 5039af4714f76..dffafda7f21c9 100644
--- a/mlir/test/Examples/NVGPU/Ch4.py
+++ b/mlir/test/Examples/NVGPU/Ch4.py
@@ -263,13 +263,13 @@ def epilogue(D: WGMMAMatrix, d_dev):
 @NVDSL.mlir_func
 def gemm_multistage(a, b, d, num_stages):
     token_ty = gpu.AsyncTokenType.get()
-    t1 = gpu.wait(token_ty, [])
+    t1 = gpu.wait([])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
     d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
     t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
     t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
-    t7 = gpu.wait(token_ty, [t6])
+    t7 = gpu.wait([t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
     a_tma = TMA([128, 64], a.type, swizzle=sw)
@@ -301,7 +301,7 @@ def gemm_multistage_kernel():
     gemm_multistage_kernel()
 
     t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
-    gpu.wait(None, [t8])
+    gpu.wait([t8])
 
 
 # Python pass arguments to MLIR
@@ -326,44 +326,17 @@ def gemm_multistage_kernel():
 # CHECK-NOT: Mismatched elements
 # CHECK: PASS
 
-# DUMPIR:   func.func @gemm_multistage(%arg0: memref<512x1024xf16>, %arg1: memref<1024x256xf16>, %arg2: memref<512x256xf32>) attributes {llvm.emit_c_interface} {
-# DUMPIR:     %[[WAIT:.*]] = gpu.wait async
-# DUMPIR:     %[[AMEM:.*]], %[[ATOK:.*]] = gpu.alloc async [%[[WAIT]]] () : memref<512x1024xf16>
-# DUMPIR:     %[[BMEM:.*]], %[[BTOK:.*]] = gpu.alloc async [%[[ATOK]]] () : memref<1024x256xf16>
-# DUMPIR:     %[[DMEM:.*]], %[[DTOK:.*]] = gpu.alloc async [%[[BTOK]]] () : memref<512x256xf32>
-# DUMPIR:     %[[CPYA:.*]] = gpu.memcpy async [%[[DTOK]]] %[[AMEM]], %arg0 : memref<512x1024xf16>, memref<512x1024xf16>
-# DUMPIR:     %[[CPYB:.*]] = gpu.memcpy async [%[[CPYA]]] %[[BMEM]], %arg1 : memref<1024x256xf16>, memref<1024x256xf16>
-# DUMPIR:     %[[WAIT2:.*]] = gpu.wait async [%[[CPYB]]]
-# DUMPIR:     %[[CASTA:.*]] = memref.cast %[[AMEM]] : memref<512x1024xf16> to memref<*xf16>
-# DUMPIR:     %[[C128:.*]] = arith.constant 128 : index
-# DUMPIR:     %[[C64:.*]] = arith.constant 64 : index
-# DUMPIR:     %[[TMAA:.*]] = nvgpu.tma.create.descriptor %[[CASTA]] box[%[[C128]], %[[C64]]] : memref<*xf16> -> <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-# DUMPIR:     %[[CASTB:.*]] = memref.cast %[[BMEM]] : memref<1024x256xf16> to memref<*xf16>
-# DUMPIR:     %[[C64_B1:.*]] = arith.constant 64 : index
-# DUMPIR:     %[[C64_B2:.*]] = arith.constant 64 : index
-# DUMPIR:     %[[TMAB:.*]] = nvgpu.tma.create.descriptor %[[CASTB]] box[%[[C64_B1]], %[[C64_B2]]] : memref<*xf16> -> <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-# DUMPIR:     %[[C4:.*]] = arith.constant 4 : index
-# DUMPIR:     %[[C2:.*]] = arith.constant 2 : index
-# DUMPIR:     %[[C1:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C128_GRID:.*]] = arith.constant 128 : index
-# DUMPIR:     %[[C1_T1:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[C1_T2:.*]] = arith.constant 1 : index
-# DUMPIR:     %[[SMEM_SIZE:.*]] = arith.constant 229376 : i32
-# DUMPIR:     gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %[[C4]], %arg10 = %[[C2]], %arg11 = %[[C1]]) threads(%arg6, %arg7, %arg8) in (%arg12 = %[[C128_GRID]], %arg13 = %[[C1_T1]], %arg14 = %[[C1_T2]]) dynamic_shared_memory_size %[[SMEM_SIZE]] {
-# DUMPIR:       %[[TID_X:.*]] = gpu.thread_id  x
-# DUMPIR:       %[[MBAR:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
-# DUMPIR:       %[[C0:.*]] = arith.constant 0 : index
-# DUMPIR:       %[[IS_THREAD0:.*]] = arith.cmpi eq, %[[TID_X]], %[[C0]] : index
-# DUMPIR:       scf.if %[[IS_THREAD0]] {
+# DUMPIR:   func.func @gemm_multistage(%{{.*}}: memref<512x1024xf16>, %{{.*}}: memref<1024x256xf16>, %{{.*}}: memref<512x256xf32>) attributes {llvm.emit_c_interface} {
+# DUMPIR:       scf.if %{{.*}} {
 # DUMPIR:         %[[C0_INIT:.*]] = arith.constant 0 : index
 # DUMPIR:         %[[C7:.*]] = arith.constant 7 : index
 # DUMPIR:         %[[C1_INIT:.*]] = arith.constant 1 : index
 # DUMPIR:         scf.for %arg15 = %[[C0_INIT]] to %[[C7]] step %[[C1_INIT]] {
 # DUMPIR:           %[[C1_MBAR:.*]] = arith.constant 1 : index
-# DUMPIR:           nvgpu.mbarrier.init %[[MBAR]][%arg15], %[[C1_MBAR]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:           nvgpu.mbarrier.init %{{.*}}[%arg15], %[[C1_MBAR]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
 # DUMPIR:         }
-# DUMPIR:         nvgpu.tma.prefetch.descriptor %[[TMAA]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-# DUMPIR:         nvgpu.tma.prefetch.descriptor %[[TMAB]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:         nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
+# DUMPIR:         nvgpu.tma.prefetch.descriptor %{{.*}} : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
 # DUMPIR:       }
 # DUMPIR:       %[[C0_PROLOGUE:.*]] = arith.constant 0 : index
 # DUMPIR:       %[[C6:.*]] = arith.constant 6 : index
@@ -394,14 +367,14 @@ def gemm_multistage_kernel():
 # DUMPIR:         %[[SMEM_B2_P:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[VIEW_B2_P:.*]] = memref.view %[[SMEM_B2_P]][%[[OFF_B2_P]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[C32768:.*]] = arith.constant 32768 : index
-# DUMPIR:         nvgpu.mbarrier.arrive.expect_tx %[[MBAR]][%arg15], %[[C32768]], predicate = %[[PRED_P]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:         nvgpu.mbarrier.arrive.expect_tx %{{.*}}[%arg15], %[[C32768]], predicate = %[[PRED_P]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
 # DUMPIR:         %[[C64_K_P:.*]] = arith.constant 64 : index
 # DUMPIR:         %[[K_COORD_P:.*]] = arith.muli %arg15, %[[C64_K_P]] : index
-# DUMPIR:         nvgpu.tma.async.load %[[TMAA]][%[[K_COORD_P]], %[[DIMX_P]]], %[[MBAR]][%arg15] to %[[VIEW_A_P]], predicate = %[[PRED_P]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
-# DUMPIR:         nvgpu.tma.async.load %[[TMAB]][%[[DIMY_P]], %[[K_COORD_P]]], %[[MBAR]][%arg15] to %[[VIEW_B1_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         nvgpu.tma.async.load %{{.*}}[%[[K_COORD_P]], %[[DIMX_P]]], %{{.*}}[%arg15] to %[[VIEW_A_P]], predicate = %[[PRED_P]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         nvgpu.tma.async.load %{{.*}}[%[[DIMY_P]], %[[K_COORD_P]]], %{{.*}}[%arg15] to %[[VIEW_B1_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[C64_OFF:.*]] = arith.constant 64 : index
 # DUMPIR:         %[[DIMY_P_OFF:.*]] = arith.addi %[[DIMY_P]], %[[C64_OFF]] : index
-# DUMPIR:         nvgpu.tma.async.load %[[TMAB]][%[[DIMY_P_OFF]], %[[K_COORD_P]]], %[[MBAR]][%arg15] to %[[VIEW_B2_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         nvgpu.tma.async.load %{{.*}}[%[[DIMY_P_OFF]], %[[K_COORD_P]]], %{{.*}}[%arg15] to %[[VIEW_B2_P]], predicate = %[[PRED_P]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
 # DUMPIR:       }
 # DUMPIR:       %[[TID_X_LOOP:.*]] = gpu.thread_id  x
 # DUMPIR:       %[[ACC_INIT:.*]] = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
@@ -413,7 +386,7 @@ def gemm_multistage_kernel():
 # DUMPIR:         %[[C7_L:.*]] = arith.constant 7 : index
 # DUMPIR:         %[[STAGE_L:.*]] = arith.remui %arg15, %[[C7_L]] : index
 # DUMPIR:         %[[C10M:.*]] = arith.constant 10000000 : index
-# DUMPIR:         nvgpu.mbarrier.try_wait.parity %[[MBAR]][%[[STAGE_L]]], %arg17, %[[C10M]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:         nvgpu.mbarrier.try_wait.parity %{{.*}}[%[[STAGE_L]]], %arg17, %[[C10M]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
 # DUMPIR:         %[[C16384_L:.*]] = arith.constant 16384 : index
 # DUMPIR:         %[[OFF_A_L:.*]] = arith.muli %[[STAGE_L]], %[[C16384_L]] : index
 # DUMPIR:         %[[C114688_L:.*]] = arith.constant 114688 : index
@@ -422,8 +395,8 @@ def gemm_multistage_kernel():
 # DUMPIR:         %[[VIEW_A_L:.*]] = memref.view %[[SMEM_A_L]][%[[OFF_A_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[SMEM_B_L:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[VIEW_B_L:.*]] = memref.view %[[SMEM_B_L]][%[[OFF_B_L]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x128xf16, #gpu.address_space<workgroup>>
-# DUMPIR:         %[[DESC_A_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_L]], %[[TMAA]] : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
-# DUMPIR:         %[[DESC_B_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_L]], %[[TMAB]] : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
+# DUMPIR:         %[[DESC_A_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_A_L]], %{{.*}} : memref<128x64xf16, #gpu.address_space<workgroup>>, <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>
+# DUMPIR:         %[[DESC_B_L:.*]] = nvgpu.warpgroup.generate.descriptor %[[VIEW_B_L]], %{{.*}} : memref<64x128xf16, #gpu.address_space<workgroup>>, <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none> -> <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>
 # DUMPIR:         %[[ACC_L:.*]] = nvgpu.warpgroup.mma %[[DESC_A_L]], %[[DESC_B_L]], %arg16 {transposeB} : <tensor = memref<128x64xf16, #gpu.address_space<workgroup>>>, <tensor = memref<64x128xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x128xf32>> -> <fragmented = vector<128x128xf32>>
 # DUMPIR:         %[[C6_NEXT:.*]] = arith.constant 6 : index
 # DUMPIR:         %[[ITER_NEXT:.*]] = arith.addi %arg15, %[[C6_NEXT]] : index
@@ -459,14 +432,14 @@ def gemm_multistage_kernel():
 # DUMPIR:         %[[SMEM_B2_LOAD:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[VIEW_B2_LOAD:.*]] = memref.view %[[SMEM_B2_LOAD]][%[[OFF_B2_LOAD]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[C32768_LOAD:.*]] = arith.constant 32768 : index
-# DUMPIR:         nvgpu.mbarrier.arrive.expect_tx %[[MBAR]][%[[STAGE_LOAD]]], %[[C32768_LOAD]], predicate = %[[DO_LOAD]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
+# DUMPIR:         nvgpu.mbarrier.arrive.expect_tx %{{.*}}[%[[STAGE_LOAD]]], %[[C32768_LOAD]], predicate = %[[DO_LOAD]] : <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7>
 # DUMPIR:         %[[C64_K_LOAD:.*]] = arith.constant 64 : index
 # DUMPIR:         %[[K_COORD_LOAD:.*]] = arith.muli %[[STAGE_NEXT_L]], %[[C64_K_LOAD]] : index
-# DUMPIR:         nvgpu.tma.async.load %[[TMAA]][%[[K_COORD_LOAD]], %[[DIMX_L]]], %[[MBAR]][%[[STAGE_LOAD]]] to %[[VIEW_A_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
-# DUMPIR:         nvgpu.tma.async.load %[[TMAB]][%[[DIMY_L]], %[[K_COORD_LOAD]]], %[[MBAR]][%[[STAGE_LOAD]]] to %[[VIEW_B1_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         nvgpu.tma.async.load %{{.*}}[%[[K_COORD_LOAD]], %[[DIMX_L]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_A_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<128x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         nvgpu.tma.async.load %{{.*}}[%[[DIMY_L]], %[[K_COORD_LOAD]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_B1_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[C64_OFF_LOAD:.*]] = arith.constant 64 : index
 # DUMPIR:         %[[DIMY_L_OFF:.*]] = arith.addi %[[DIMY_L]], %[[C64_OFF_LOAD]] : index
-# DUMPIR:         nvgpu.tma.async.load %[[TMAB]][%[[DIMY_L_OFF]], %[[K_COORD_LOAD]]], %[[MBAR]][%[[STAGE_LOAD]]] to %[[VIEW_B2_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
+# DUMPIR:         nvgpu.tma.async.load %{{.*}}[%[[DIMY_L_OFF]], %[[K_COORD_LOAD]]], %{{.*}}[%[[STAGE_LOAD]]] to %[[VIEW_B2_LOAD]], predicate = %[[DO_LOAD]] : <tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>, num_barriers = 7> -> memref<64x64xf16, #gpu.address_space<workgroup>>
 # DUMPIR:         %[[C6_FLIP:.*]] = arith.constant 6 : index
 # DUMPIR:         %[[IS_STAGE6:.*]] = arith.cmpi eq, %[[STAGE_L]], %[[C6_FLIP]] : index
 # DUMPIR:         %[[TRUE:.*]] = arith.constant true
@@ -485,7 +458,7 @@ def gemm_multistage_kernel():
 # DUMPIR:       %[[SMEM_EPI:.*]] = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
 # DUMPIR:       %[[C0_VIEW:.*]] = arith.constant 0 : index
 # DUMPIR:       %[[VIEW_EPI:.*]] = memref.view %[[SMEM_EPI]][%[[C0_VIEW]]][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x128xf32, #gpu.address_space<workgroup>>
-# DUMPIR:       %[[SUBVIEW_EPI:.*]] = memref.subview %[[DMEM]][%[[DIMX_EPI]], %[[DIMY_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
+# DUMPIR:       %[[SUBVIEW_EPI:.*]] = memref.subview %{{.*}}[%[[DIMX_EPI]], %[[DIMY_EPI]]] [128, 128] [1, 1] : memref<512x256xf32> to memref<128x128xf32, strided<[256, 1], offset: ?>>
 # DUMPIR:       nvgpu.warpgroup.mma.store %[[LOOP_RES]]#0, %[[VIEW_EPI]] : <fragmented = vector<128x128xf32>> to memref<128x128xf32, #gpu.address_space<workgroup>>
 # DUMPIR:       gpu.barrier
 # DUMPIR:       %[[C0_STORE:.*]] = arith.constant 0 : index
@@ -494,10 +467,3 @@ def gemm_multistage_kernel():
 # DUMPIR:       scf.for %arg15 = %[[C0_STORE]] to %[[C128_STORE]] step %[[C1_STORE]] {
 # DUMPIR:         %[[VAL_LOAD:.*]] = memref.load %[[VIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, #gpu.address_space<workgroup>>
 # DUMPIR:         memref.store %[[VAL_LOAD]], %[[SUBVIEW_EPI]][%arg15, %[[TID_X_EPI]]] : memref<128x128xf32, strided<[256, 1], offset: ?>>
-# DUMPIR:       }
-# DUMPIR:       gpu.terminator
-# DUMPIR:     }
-# DUMPIR:     %[[CPYD:.*]] = gpu.memcpy async [%[[WAIT2]]] %arg2, %[[DMEM]] : memref<512x256xf32>, memref<512x256xf32>
-# DUMPIR:     gpu.wait [%[[CPYD]]]
-# DUMPIR:     return
-# DUMPIR:   }
diff --git a/mlir/test/Examples/NVGPU/Ch5.py b/mlir/test/Examples/NVGPU/Ch5.py
index f81e06d13d435..b725e50d8f44b 100644
--- a/mlir/test/Examples/NVGPU/Ch5.py
+++ b/mlir/test/Examples/NVGPU/Ch5.py
@@ -257,13 +257,13 @@ def epilogue(D: WGMMAMatrix, d_dev):
 @NVDSL.mlir_func
 def gemm_warp_specialized(a, b, d, num_stages):
     token_ty = gpu.AsyncTokenType.get()
-    t1 = gpu.wait(token_ty, [])
+    t1 = gpu.wait([])
     a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
     b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
     d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
     t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
     t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
-    t7 = gpu.wait(token_ty, [t6])
+    t7 = gpu.wait([t6])
 
     sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
     a_tma = TMA([128, 64], a.type, swizzle=sw)
@@ -299,7 +299,7 @@ def gemm_warp_specialized_kernel():
     gemm_warp_specialized_kernel()
 
     t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
-    gpu.wait(None, [t8])
+    gpu.wait([t8])
 
 
 # Python pass arguments to MLIR
@@ -324,7 +324,6 @@ def gemm_warp_specialized_kernel():
 # CHECK-NOT: Mismatched elements
 # CHECK: PASS
 
-# DUMPIR: gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c4, %arg10 = %c2, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c256, %arg13 = %c1_7, %arg14 = %c1_8) dynamic_shared_memory_size %c229376_i32 {
 # DUMPIR:       %[[TID_X:.*]] = gpu.thread_id  x
 # DUMPIR:       %[[C128:.*]] = arith.constant 128 : index
 # DUMPIR:       %[[REM1:.*]] = arith.remui %[[TID_X]], %[[C128]] : index
diff --git a/mlir/test/Examples/NVGPU/tools/nvdsl.py b/mlir/test/Examples/NVGPU/tools/nvdsl.py
index 82bc26594900c..856107293470d 100644
--- a/mlir/test/Examples/NVGPU/tools/nvdsl.py
+++ b/mlir/test/Examples/NVGPU/tools/nvdsl.py
@@ -311,13 +311,10 @@ def decorator(func):
             @functools.wraps(func)
             def wrapper(*args, **kwargs):
                 launch_op = gpu.LaunchOp(
-                    None,
-                    [],
-                    *map(const, grid),
-                    *map(const, block),
-                    dynamicSharedMemorySize=arith.constant(T.i32(), smem),
+                    grid_size=grid,
+                    block_size=block,
+                    dynamic_shared_memory_size=arith.constant(T.i32(), smem),
                 )
-                launch_op.body.blocks.append(*([T.index()] * 12))
                 with ir.InsertionPoint(launch_op.body.blocks[0]):
                     result = func(*args, **kwargs)
                     gpu.terminator()

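For completeness, the updated launch wrapper in nvdsl.py now reads roughly as below. Again, this is only a sketch of the call shown in the hunk above, reusing the wrapper's own grid, block, smem, func, args, and kwargs names; it is not additional code in the patch.

    # Sketch only (not part of the patch): gpu.LaunchOp is now built with keyword
    # arguments, and the entry block no longer has to be appended by hand (the old
    # launch_op.body.blocks.append(*([T.index()] * 12)) call is gone).
    launch_op = gpu.LaunchOp(
        grid_size=grid,
        block_size=block,
        dynamic_shared_memory_size=arith.constant(T.i32(), smem),
    )
    with ir.InsertionPoint(launch_op.body.blocks[0]):
        result = func(*args, **kwargs)
        gpu.terminator()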

