[Mlir-commits] [mlir] [mlir][nvgpu] Simplify TMA IR generation (PR #87153)
Guray Ozen
llvmlistbot at llvm.org
Sat Mar 30 05:15:13 PDT 2024
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/87153
This PR adds a `TmaDescriptorBuilder` class that simplifies TMA IR generation in the test and makes the code more readable.
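For reference, here is how the new builder is used in matmulBuilder.py (a sketch distilled from the patch below; `a_tma_shape`, `a_ty`, and `a_device` are the A-operand TMA box shape, MemRefType, and device buffer already defined in the test):

    # Describe the TMA transfer once, in one place.
    a_tma_desc = TmaDescriptorBuilder(
        nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
        nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
        nvgpu.TensorMapOOBKind.OOB_ZERO,
        nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
        a_tma_shape,  # TMA box shape
        a_ty,         # MemRefType of the operand
    )
    # Emit the memref.cast + nvgpu TMA-descriptor-creation ops and get the
    # descriptor SSA value consumed by tma_prefetch_descriptor / TmaAsyncLoadOp.
    a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)

This replaces the hand-assembled `ir.Type.parse` strings and the per-operand cast/create loop in both matmul generators.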
From 52030c8597ac92fce1bbe726ccf74ae8fe334a3f Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 26 Mar 2024 10:42:05 +0000
Subject: [PATCH] [mlir][nvgpu] Simplify TMA IR generation
This PR simplifies TMA IR generation in the test and makes the code more readable.
Co-authored-by: Manish Gupta <manigupta at google.com>
---
.../CUDA/sm90/python/tools/matmulBuilder.py | 189 +++++++++---------
1 file changed, 96 insertions(+), 93 deletions(-)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index fac138dce605a7..6823587801a7b0 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -28,6 +28,43 @@
DEBUG = False
+class TmaDescriptorBuilder:
+ """A class that builds a TMA descriptor."""
+
+ def __init__(self, swizzle, l2promo, oob, interleave, tma_box_shape, memref_ty):
+ self.swizzle = swizzle # mlir.nvgpu.TensorMapSwizzleKind
+ self.l2promo = l2promo # mlir.nvgpu.TensorMapL2PromoKind
+ self.oob = oob # mlir.nvgpu.TensorMapOOBKind
+ self.interleave = interleave # mlir.nvgpu.TensorMapInterleaveKind
+ self.tma_box_shape = tma_box_shape
+ self.memref_ty = memref_ty # MemRefType
+
+ @property
+ def tensormap_descriptor_ty(self):
+ """Returns a tensormap descriptor type."""
+ memref_str = f"memref<{self.tma_box_shape[0]}x{self.tma_box_shape[1]}x{self.memref_ty.element_type}, 3>"
+ parse_str = f"!nvgpu.tensormap.descriptor<tensor = {memref_str},\
+ swizzle = {self.swizzle},\
+ l2promo = {self.l2promo},\
+ oob = {self.oob},\
+ interleave = {self.interleave}>"
+ return ir.Type.parse(parse_str)
+
+ def tma_descriptor_op(self, device_ptr):
+ """Returns a tensormap descriptor op."""
+ tma_descriptor_ty = self.tensormap_descriptor_ty
+ device_unranked_memref = memref.CastOp(
+ ir.UnrankedMemRefType.get(
+ self.memref_ty.element_type, self.memref_ty.memory_space
+ ),
+ device_ptr,
+ )
+ tma_descriptor_op = nvgpu.TmaCreateDescriptorOp(
+ tma_descriptor_ty, device_unranked_memref, map(c, self.tma_box_shape)
+ )
+ return tma_descriptor_op.result
+
+
def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
if not DEBUG and not forcePrint:
return
@@ -162,28 +199,6 @@ def generate_matmul_ws(
+ str(num_stages)
+ ">"
)
- a_tma_desc_ty = ir.Type.parse(
- "!nvgpu.tensormap.descriptor<tensor = memref<"
- + str(BLOCK_M)
- + "x"
- + str(TMA_LAST_DIM_F16)
- + "x"
- + str(a_elem_ty)
- + ", "
- + str(smem_space)
- + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
- )
- b_tma_desc_ty = ir.Type.parse(
- "!nvgpu.tensormap.descriptor<tensor = memref<"
- + str(BLOCK_K)
- + "x"
- + str(TMA_LAST_DIM_F16)
- + "x"
- + str(b_elem_ty)
- + ", "
- + str(smem_space)
- + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
- )
acc_ty = ir.Type.parse(
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ str(BLOCK_M)
@@ -240,21 +255,26 @@ def generate_matmul_ws(
t7 = gpu.wait(token_ty, [t6])
# Step 2. Create TMA Descriptors
- tma_specs = [
- (a_device, a_tma_desc_ty, a_tma_shape),
- (b_device, b_tma_desc_ty, b_tma_shape),
- ]
- tma_descs = []
- for x_device, tensor_map_ty, tile_shape in tma_specs:
- x_unranked = memref.cast(
- ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
- )
- tma_descs.append(
- nvgpu.TmaCreateDescriptorOp(
- tensor_map_ty, x_unranked, map(c, tile_shape)
- ).result
- )
- a_tma_desc, b_tma_desc = tma_descs
+ a_tma_desc = TmaDescriptorBuilder(
+ nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+ nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+ nvgpu.TensorMapOOBKind.OOB_ZERO,
+ nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+ a_tma_shape,
+ a_ty,
+ )
+
+ b_tma_desc = TmaDescriptorBuilder(
+ nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+ nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+ nvgpu.TensorMapOOBKind.OOB_ZERO,
+ nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+ b_tma_shape,
+ b_ty,
+ )
+
+ a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+ b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
# Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
cta_m = M // BLOCK_M
@@ -267,7 +287,7 @@ def generate_matmul_ws(
[t7],
*map(c, grid),
*map(c, block),
- dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+ dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
)
launch_op.body.blocks.append(*([T.index()] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -315,8 +335,8 @@ def generate_matmul_ws(
gpu.barrier()
# GPU Step 3. Prefetch TMA descriptors
- nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
- nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+ nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=wgPrimaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=wgPrimaryThread)
ns = num_stages if num_stages == 1 else num_stages - 1
# GPU Step 5. Producer Warpgroup (TMA Warpgroup)
@@ -405,7 +425,7 @@ def generate_matmul_ws(
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
- a_tma_desc,
+ a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=stage,
predicate=producerPrimaryThread,
@@ -413,7 +433,7 @@ def generate_matmul_ws(
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
- b_tma_desc,
+ b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=stage,
predicate=producerPrimaryThread,
@@ -422,7 +442,7 @@ def generate_matmul_ws(
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
- b_tma_desc,
+ b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=stage,
predicate=producerPrimaryThread,
@@ -514,10 +534,10 @@ def generate_matmul_ws(
predicate=consumerPrimaryThread,
)
da = nvgpu.WarpgroupGenerateDescriptorOp(
- a_wgmma_ty, a_tile_slice, a_tma_desc
+ a_wgmma_ty, a_tile_slice, a_tma_desc_op
)
db = nvgpu.WarpgroupGenerateDescriptorOp(
- b_wgmma_ty, b_tile_slice, b_tma_desc
+ b_wgmma_ty, b_tile_slice, b_tma_desc_op
)
# Step 6.3.3. MMA
@@ -679,28 +699,6 @@ def generate_matmul_multistage(
+ str(num_stages)
+ ">"
)
- a_tma_desc_ty = ir.Type.parse(
- "!nvgpu.tensormap.descriptor<tensor = memref<"
- + str(BLOCK_M)
- + "x"
- + str(TMA_LAST_DIM_F16)
- + "x"
- + str(a_elem_ty)
- + ", "
- + str(smem_space)
- + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
- )
- b_tma_desc_ty = ir.Type.parse(
- "!nvgpu.tensormap.descriptor<tensor = memref<"
- + str(BLOCK_K)
- + "x"
- + str(TMA_LAST_DIM_F16)
- + "x"
- + str(b_elem_ty)
- + ", "
- + str(smem_space)
- + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
- )
acc_ty = ir.Type.parse(
"!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ str(BLOCK_M)
@@ -767,21 +765,26 @@ def generate_matmul_multistage(
t7 = gpu.wait(token_ty, [t6])
# Step 2. Create TMA Descriptors
- tma_specs = [
- (a_device, a_tma_desc_ty, a_tma_shape),
- (b_device, b_tma_desc_ty, b_tma_shape),
- ]
- tma_descs = []
- for x_device, tensor_map_ty, tile_shape in tma_specs:
- x_unranked = memref.cast(
- ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
- )
- tma_descs.append(
- nvgpu.TmaCreateDescriptorOp(
- tensor_map_ty, x_unranked, map(c, tile_shape)
- ).result
- )
- a_tma_desc, b_tma_desc = tma_descs
+ a_tma_desc = TmaDescriptorBuilder(
+ nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+ nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+ nvgpu.TensorMapOOBKind.OOB_ZERO,
+ nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+ a_tma_shape,
+ a_ty,
+ )
+
+ b_tma_desc = TmaDescriptorBuilder(
+ nvgpu.TensorMapSwizzleKind.SWIZZLE_128B,
+ nvgpu.TensorMapL2PromoKind.L2PROMO_NONE,
+ nvgpu.TensorMapOOBKind.OOB_ZERO,
+ nvgpu.TensorMapInterleaveKind.INTERLEAVE_NONE,
+ b_tma_shape,
+ b_ty,
+ )
+
+ a_tma_desc_op = a_tma_desc.tma_descriptor_op(a_device)
+ b_tma_desc_op = b_tma_desc.tma_descriptor_op(b_device)
# Step 3. Launch Kernel with 1 Warpgroup
cta_m = M // BLOCK_M
@@ -794,7 +797,7 @@ def generate_matmul_multistage(
[t7],
*map(c, grid),
*map(c, block),
- dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+ dynamicSharedMemorySize=c(smem_size, ty=T.i32()),
)
launch_op.body.blocks.append(*([T.index()] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
@@ -819,8 +822,8 @@ def generate_matmul_multistage(
gpu.barrier()
# GPU Step 2. Prefetch TMA descriptors
- nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
- nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+ nvgpu.tma_prefetch_descriptor(a_tma_desc_op, predicate=primaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc_op, predicate=primaryThread)
# GPU Step 3. Prologue (global memory --> shared memory)
ns = num_stages if num_stages == 1 else num_stages - 1
@@ -880,7 +883,7 @@ def generate_matmul_multistage(
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
- a_tma_desc,
+ a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=iv,
predicate=primaryThread,
@@ -888,7 +891,7 @@ def generate_matmul_multistage(
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
- b_tma_desc,
+ b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=iv,
predicate=primaryThread,
@@ -896,7 +899,7 @@ def generate_matmul_multistage(
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
- b_tma_desc,
+ b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=iv,
predicate=primaryThread,
@@ -972,10 +975,10 @@ def generate_matmul_multistage(
predicate=primaryThread,
)
da = nvgpu.WarpgroupGenerateDescriptorOp(
- a_wgmma_ty, a_tile_slice, a_tma_desc
+ a_wgmma_ty, a_tile_slice, a_tma_desc_op
)
db = nvgpu.WarpgroupGenerateDescriptorOp(
- b_wgmma_ty, b_tile_slice, b_tma_desc
+ b_wgmma_ty, b_tile_slice, b_tma_desc_op
)
# Step 4.3. MMA
@@ -1060,7 +1063,7 @@ def generate_matmul_multistage(
nvgpu.TmaAsyncLoadOp(
a_tma_slice,
mbarTMA,
- a_tma_desc,
+ a_tma_desc_op,
coordinates=[coord, dimX],
mbarId=nextSlot,
predicate=p,
@@ -1068,7 +1071,7 @@ def generate_matmul_multistage(
nvgpu.TmaAsyncLoadOp(
b_tma_slice_1,
mbarTMA,
- b_tma_desc,
+ b_tma_desc_op,
coordinates=[dimY, coord],
mbarId=nextSlot,
predicate=p,
@@ -1077,7 +1080,7 @@ def generate_matmul_multistage(
nvgpu.TmaAsyncLoadOp(
b_tma_slice_2,
mbarTMA,
- b_tma_desc,
+ b_tma_desc_op,
coordinates=[dimY2, coord],
mbarId=nextSlot,
predicate=p,