[Mlir-commits] [mlir] [mlir][nvgpu] Simplify TMA IR generation (PR #87153)

Guray Ozen llvmlistbot at llvm.org
Sat Mar 30 05:15:13 PDT 2024


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/87153

This PR adds a `TmaDescriptorBuilder` class that simplifies TMA generation in the test and makes the code more readable.

>From 52030c8597ac92fce1bbe726ccf74ae8fe334a3f Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 26 Mar 2024 10:42:05 +0000
Subject: [PATCH] [mlir][nvgpu] Simplify TMA IR generation

This PR simplifies TMA generation in the test and makes the code more readable.

Co-authored-by: Manish Gupta <manigupta at google.com>
---
 .../CUDA/sm90/python/tools/matmulBuilder.py   | 189 +++++++++---------
 1 file changed, 96 insertions(+), 93 deletions(-)

diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index fac138dce605a7..6823587801a7b0 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -28,6 +28,43 @@
 DEBUG = False
 
 
+class TmaDescriptorBuilder:
+    """A class that builds a TMA descriptor."""
+
+    def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
+        self.swizzle = swizzle  # mlir.nvgpu.TensorMapSwizzleKind
+        self.l2promo = l2promo  # mlir.nvgpu.TensorMapL2PromoKind
+        self.oob = oob  # mlir.nvgpu.TensorMapOOBKind
+        self.interleave = interleave  # mlir.nvgpu.TensorMapInterleaveKind
+        self.tma_box_shape = tma_box_shape
+        self.memref_ty = memref_ty  # MemRefType
+
+    @property
+    def tensormap_descriptor_ty(self):
+        """Returns a tensormap descriptor type."""
+        memref_str = f"memref<{self.tma_box_shape[0]}x{self.tma_box_shape[1]}x{self.memref_ty.element_type}, 3>"
+        parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
+                                              swizzle = {self.swizzle},\
+                                              l2promo = {self.l2promo},\
+                                              oob = {self.oob},\
+                                              interleave = {self.interleave}>"
+        return ir.Type.parse(parse_str)
+
+    def tma_descriptor_op(self, device_ptr):
+        """Returns a tensormap descriptor op."""
+        tma_descriptor_ty = self.tensormap_descriptor_ty
+        device_unranked_memref = memref.CastOp(
+            ir.UnrankedMemRefType.get(
+                self.memref_ty.element_type, self.memref_ty.memory_space
+            ),
+            device_ptr,
+        )
+        tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
+            tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
+        )
+        return tma_descriptor_op.result
+
+
 def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
     if not DEBUG and not forcePrint:
         return
@@ -162,28 +199,6 @@ def generate_matmul_ws(
         + str(num_stages)
         + ">"
     )
-    a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_M)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(a_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
-    b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_K)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(b_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
     acc_ty = ir.Type.parse(
         "!nvgpu.warpgroup.accumulator<fragmented=vector<"
         + str(BLOCK_M)
@@ -240,21 +255,26 @@ def generate_matmul_ws(
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [
-                (a_device, a_tma_desc_ty, a_tma_shape),
-                (b_device, b_tma_desc_ty, b_tma_shape),
-            ]
-            tma_descs = []
-            for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
-                )
-                tma_descs.append(
-                    nvgpu.TmaCreateDescriptorOp(
-                        tensor_map_ty, x_unranked, map(c, tile_shape)
-                    ).result
-                )
-            a_tma_desc, b_tma_desc = tma_descs
+            a_tma_desc = TmaDescriptorBuilder(
+                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+                nvgpu.TensorMapOOBKind.OOB_ZERO,
+                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+                a_tma_shape,
+                a_ty,
+            )
+
+            b_tma_desc = TmaDescriptorBuilder(
+                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+                nvgpu.TensorMapOOBKind.OOB_ZERO,
+                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+                b_tma_shape,
+                b_ty,
+            )
+
+            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
 
             # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
             cta_m = M // BLOCK_M
@@ -267,7 +287,7 @@ def generate_matmul_ws(
                 [t7],
                 *map(c, grid),
                 *map(c, block),
-                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
             )
             launch_op.body.blocks.append(*([T.index()] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -315,8 +335,8 @@ def generate_matmul_ws(
                 gpu.barrier()
 
                 # GPU Step 3. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
 
                 ns = num_stages if num_stages == 1 else num_stages - 1
                 # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,7 +425,7 @@ def generate_matmul_ws(
                         nvgpu.TmaAsyncLoadOp(
                             a_tma_slice,
                             mbarTMA,
-                            a_tma_desc,
+                            a_tma_desc_op,
                             coordinates=[coord, dimX],
                             mbarId=stage,
                             predicate=producerPrimaryThread,
@@ -413,7 +433,7 @@ def generate_matmul_ws(
                         nvgpu.TmaAsyncLoadOp(
                             b_tma_slice_1,
                             mbarTMA,
-                            b_tma_desc,
+                            b_tma_desc_op,
                             coordinates=[dimY, coord],
                             mbarId=stage,
                             predicate=producerPrimaryThread,
@@ -422,7 +442,7 @@ def generate_matmul_ws(
                         nvgpu.TmaAsyncLoadOp(
                             b_tma_slice_2,
                             mbarTMA,
-                            b_tma_desc,
+                            b_tma_desc_op,
                             coordinates=[dimY2, coord],
                             mbarId=stage,
                             predicate=producerPrimaryThread,
@@ -514,10 +534,10 @@ def generate_matmul_ws(
                             predicate=consumerPrimaryThread,
                         )
                         da = nvgpu.WarpgroupGenerateDescriptorOp(
-                            a_wgmma_ty, a_tile_slice, a_tma_desc
+                            a_wgmma_ty, a_tile_slice, a_tma_desc_op
                         )
                         db = nvgpu.WarpgroupGenerateDescriptorOp(
-                            b_wgmma_ty, b_tile_slice, b_tma_desc
+                            b_wgmma_ty, b_tile_slice, b_tma_desc_op
                         )
 
                         # Step 6.3.3. MMA
@@ -679,28 +699,6 @@ def generate_matmul_multistage(
         + str(num_stages)
         + ">"
     )
-    a_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_M)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(a_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
-    b_tma_desc_ty = ir.Type.parse(
-        "!nvgpu.tensormap.descriptor<tensor = memref<"
-        + str(BLOCK_K)
-        + "x"
-        + str(TMA_LAST_DIM_F16)
-        + "x"
-        + str(b_elem_ty)
-        + ", "
-        + str(smem_space)
-        + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
-    )
     acc_ty = ir.Type.parse(
         "!nvgpu.warpgroup.accumulator<fragmented=vector<"
         + str(BLOCK_M)
@@ -767,21 +765,26 @@ def generate_matmul_multistage(
             t7 = gpu.wait(token_ty, [t6])
 
             # Step 2. Create TMA Descriptors
-            tma_specs = [
-                (a_device, a_tma_desc_ty, a_tma_shape),
-                (b_device, b_tma_desc_ty, b_tma_shape),
-            ]
-            tma_descs = []
-            for x_device, tensor_map_ty, tile_shape in tma_specs:
-                x_unranked = memref.cast(
-                    ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
-                )
-                tma_descs.append(
-                    nvgpu.TmaCreateDescriptorOp(
-                        tensor_map_ty, x_unranked, map(c, tile_shape)
-                    ).result
-                )
-            a_tma_desc, b_tma_desc = tma_descs
+            a_tma_desc = TmaDescriptorBuilder(
+                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+                nvgpu.TensorMapOOBKind.OOB_ZERO,
+                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+                a_tma_shape,
+                a_ty,
+            )
+
+            b_tma_desc = TmaDescriptorBuilder(
+                nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+                nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+                nvgpu.TensorMapOOBKind.OOB_ZERO,
+                nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+                b_tma_shape,
+                b_ty,
+            )
+
+            a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+            b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
 
             # Step 3. Launch Kernel with 1 Warpgroup
             cta_m = M // BLOCK_M
@@ -794,7 +797,7 @@ def generate_matmul_multistage(
                 [t7],
                 *map(c, grid),
                 *map(c, block),
-                dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+                dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
             )
             launch_op.body.blocks.append(*([T.index()] * 12))
             with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -819,8 +822,8 @@ def generate_matmul_multistage(
                 gpu.barrier()
 
                 # GPU Step 2. Prefetch TMA descriptors
-                nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
-                nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
+                nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)
 
                 # GPU Step 3. Prologue (global memory --> shared memory)
                 ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,7 +883,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         a_tma_slice,
                         mbarTMA,
-                        a_tma_desc,
+                        a_tma_desc_op,
                         coordinates=[coord, dimX],
                         mbarId=iv,
                         predicate=primaryThread,
@@ -888,7 +891,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         b_tma_slice_1,
                         mbarTMA,
-                        b_tma_desc,
+                        b_tma_desc_op,
                         coordinates=[dimY, coord],
                         mbarId=iv,
                         predicate=primaryThread,
@@ -896,7 +899,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         b_tma_slice_2,
                         mbarTMA,
-                        b_tma_desc,
+                        b_tma_desc_op,
                         coordinates=[dimY2, coord],
                         mbarId=iv,
                         predicate=primaryThread,
@@ -972,10 +975,10 @@ def generate_matmul_multistage(
                         predicate=primaryThread,
                     )
                     da = nvgpu.WarpgroupGenerateDescriptorOp(
-                        a_wgmma_ty, a_tile_slice, a_tma_desc
+                        a_wgmma_ty, a_tile_slice, a_tma_desc_op
                     )
                     db = nvgpu.WarpgroupGenerateDescriptorOp(
-                        b_wgmma_ty, b_tile_slice, b_tma_desc
+                        b_wgmma_ty, b_tile_slice, b_tma_desc_op
                     )
 
                     # Step 4.3. MMA
@@ -1060,7 +1063,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         a_tma_slice,
                         mbarTMA,
-                        a_tma_desc,
+                        a_tma_desc_op,
                         coordinates=[coord, dimX],
                         mbarId=nextSlot,
                         predicate=p,
@@ -1068,7 +1071,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         b_tma_slice_1,
                         mbarTMA,
-                        b_tma_desc,
+                        b_tma_desc_op,
                         coordinates=[dimY, coord],
                         mbarId=nextSlot,
                         predicate=p,
@@ -1077,7 +1080,7 @@ def generate_matmul_multistage(
                     nvgpu.TmaAsyncLoadOp(
                         b_tma_slice_2,
                         mbarTMA,
-                        b_tma_desc,
+                        b_tma_desc_op,
                         coordinates=[dimY2, coord],
                         mbarId=nextSlot,
                         predicate=p,



More information about the Mlir-commits mailing list