[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)
Guray Ozen
llvmlistbot at llvm.org
Sun Mar 3 08:38:01 PST 2024
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/81478
>From 8548c413400f7adbcc4728b7e33f68d388e2aefe Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:33:26 +0000
Subject: [PATCH 01/10] [mlir] GEMM Hopper Tensor Core Integration Test
This test validates the correctness of the GEMM kernels supported by the
NVGPU dialect. Multistage and Warp Specialization kernels are currently
covered.
The test builds the IR programmatically through the MLIR Python bindings,
which keeps the IR generation generic: the shape, tile size, and data type
of the GEMM can be changed for testing purposes.
The entry function is `matmul`, where one can specify GEMM shape, tile size,
data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
number of stages.
Verification is done via numpy's matmul operation.
Example:
```
matmul(input_type=np.float16, # input types
output_type=np.float32, # output type
M=4096, N=4096, K=4096, # Shape
BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
use_warp_specialization=True, # Enable Warp Specialization
max_num_stages=3) # Number of stages in shared memory
```
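Concretely, the verification in `matmul` amounts to the following check (taken
from the test), where `a` and `b` are the randomly initialized host inputs and
`c` is the result copied back from the device:
```
ref = a.astype(input_type) @ b.astype(input_type)
np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
```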
### Parallelism Across CTAs
A GEMM consists of three loops whose bounds are given by the shape passed to
the `matmul` function.
The program builds IR with the following loop structure, tiling the loops
with the given tile sizes and parallelizing the two outermost loops across the
first and second dimensions of the CTA grid.
```
for(bi = 0; bi < M; bi += BLOCK_M)      # parallelize across blockIdx.x
 for(bj = 0; bj < N; bj += BLOCK_N)     # parallelize across blockIdx.y
  for(bk = 0; bk < K; bk += BLOCK_K)
   for(i = bi; i < (bi + BLOCK_M); ++i)
    for(j = bj; j < (bj + BLOCK_N); ++j)
     for(k = bk; k < (bk + BLOCK_K); ++k)
```
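The same loop nest can be written as a plain NumPy reference. The sketch below
(small illustrative shapes, not part of the test) shows the tiling: on the GPU
each `(bi, bj)` pair becomes one CTA (`blockIdx.x`, `blockIdx.y`), and only the
`bk` loop plus the per-tile work remains inside the kernel:
```
import numpy as np

M, N, K = 8, 8, 8
BLOCK_M, BLOCK_N, BLOCK_K = 4, 4, 4
a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
c = np.zeros((M, N), dtype=np.float32)

for bi in range(0, M, BLOCK_M):        # -> blockIdx.x
    for bj in range(0, N, BLOCK_N):    # -> blockIdx.y
        for bk in range(0, K, BLOCK_K):
            a_tile = a[bi:bi + BLOCK_M, bk:bk + BLOCK_K].astype(np.float32)
            b_tile = b[bk:bk + BLOCK_K, bj:bj + BLOCK_N].astype(np.float32)
            c[bi:bi + BLOCK_M, bj:bj + BLOCK_N] += a_tile @ b_tile

assert np.allclose(c, a.astype(np.float32) @ b.astype(np.float32), atol=1e-1)
```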
## Multistage Kernel
This kernel launches a single warp group (128 threads). The primary thread
(pthread) issues the TMA load requests. All threads collectively wait for the
data and perform mma operations. Once the K loop is complete, the threads first
store the fragmented accumulator registers to shared memory and then copy from
shared memory to global memory; this part is called the epilogue.
Execution Timeline of Multistage Kernel with 3 stages:
```
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|       |Prologue ---->  |MainLoop ---->                                                      |Epilogue               |
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
|pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
|wgroup |   ........     |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
```
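The builder emits a simple circular-buffer schedule for this kernel. The
standalone sketch below (plain Python with illustrative `num_stages`/`k_iters`
values, not part of the test) prints that schedule: the prologue issues the
first `num_stages - 1` TMA loads, and main-loop iteration `iv` waits on
shared-memory slot `iv % num_stages`, runs the MMA, prefetches the tile for
iteration `iv + num_stages - 1`, and flips the mbarrier phase parity after each
full wrap of the buffers:
```
num_stages, k_iters = 3, 6          # k_iters corresponds to K // BLOCK_K

for iv in range(num_stages - 1):    # Prologue: fill the pipeline
    print(f"prologue: tma load tile {iv} -> slot {iv}")

phase = 0                           # mbarrier phase parity
for iv in range(k_iters):           # Main loop
    slot = iv % num_stages
    print(f"iter {iv}: wait mbarTMA[{slot}] phase={phase}; mma on slot {slot}")
    nxt = iv + num_stages - 1       # tile consumed num_stages-1 iterations later
    if nxt < k_iters:
        print(f"iter {iv}: tma load tile {nxt} -> slot {nxt % num_stages}")
    if slot == num_stages - 1:      # flip parity after a full wrap
        phase ^= 1
```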
## Warp Specialization Kernel
This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
as the `producer warp group` and the other as the `consumer warp group`. The
`producer warp group` is responsible for issuing the TMA loads, while the
`consumer warp group` performs the mma operations. The epilogue is also
handled by the `consumer warp group`, as its threads own the fragmented registers.
Execution Timeline of Warp Specialization Kernel with 2 stages:
```
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
|wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
|wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
+--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
```
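The handshake between the two warp groups can be modeled on the host with
classic per-slot semaphores. In the sketch below (plain Python, not part of the
test) `full` plays the role of `mbarTMA` and `free` the role of `mbarDONE`; the
real kernel signals `mbarDONE` one iteration later, after the asynchronous
WGMMA reading the buffer has been issued, but the buffer-reuse pattern is the
same:
```
import threading

NUM_STAGES = 2            # shared-memory buffers (stages)
NUM_ITERS = 8             # corresponds to K // BLOCK_K
buffers = [None] * NUM_STAGES

free = [threading.Semaphore(1) for _ in range(NUM_STAGES)]  # ~ mbarDONE
full = [threading.Semaphore(0) for _ in range(NUM_STAGES)]  # ~ mbarTMA

def producer():                        # TMA warp group
    for iv in range(NUM_ITERS):
        stage = iv % NUM_STAGES
        free[stage].acquire()          # wait mbarDONE[stage]
        buffers[stage] = f"tile-{iv}"  # TMA load into smem buffer `stage`
        full[stage].release()          # arrive mbarTMA[stage]

def consumer():                        # MMA warp group
    for iv in range(NUM_ITERS):
        stage = iv % NUM_STAGES
        full[stage].acquire()          # wait mbarTMA[stage]
        print("mma on", buffers[stage])
        free[stage].release()          # arrive mbarDONE[stage]

threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```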
---
.../GPU/CUDA/sm90/python/lit.local.cfg | 2 +
.../GPU/CUDA/sm90/python/matmul.py | 186 +++++
.../GPU/CUDA/sm90/python/tools/lit.local.cfg | 3 +
.../CUDA/sm90/python/tools/matmulBuilder.py | 676 ++++++++++++++++++
.../CUDA/sm90/python/tools/nvgpucompiler.py | 44 ++
5 files changed, 911 insertions(+)
create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
create mode 100644 mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+ config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..e153fcb44b9860
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,186 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+# CHECK: PASS
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test validates the correctness of the GEMM kernels supported by the
+# NVGPU dialect. Multistage and Warp Specialization kernels are currently
+# covered.
+# The test builds the IR programmatically through the MLIR Python bindings,
+# which keeps the IR generation generic: the shape, tile size, and data type
+# of the GEMM can be changed for testing purposes.
+# The entry function is `matmul`, where one can specify GEMM shape, tile size,
+# data type, GEMM algorithm (Multistage or Warp Specialization), and the maximum
+# number of stages.
+# Verification is done via numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16, # input types
+# output_type=np.float32, # output type
+# M=4096, N=4096, K=4096, # Shape
+# BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+# use_warp_specialization=True, # Enable Warp Specialization
+# max_num_stages=3) # Number of stages in shared memory
+#
+# ===--- Parallelism Across CTAs ---===
+#
+# A GEMM consists of three loops whose bounds are given by the shape passed to
+# the `matmul` function.
+# The program builds IR with the following loop structure, tiling the loops
+# with the given tile sizes and parallelizing the two outermost loops across the
+# first and second dimensions of the CTA grid.
+#
+# for(bi = 0; bi < M; bi += BLOCK_M)      # parallelize across blockIdx.x
+#  for(bj = 0; bj < N; bj += BLOCK_N)     # parallelize across blockIdx.y
+#   for(bk = 0; bk < K; bk += BLOCK_K)
+#    for(i = bi; i < (bi + BLOCK_M); ++i)
+#     for(j = bj; j < (bj + BLOCK_N); ++j)
+#      for(k = bk; k < (bk + BLOCK_K); ++k)
+#
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA load requests. All threads collectively wait for the
+# data and perform mma operations. Once the K loop is complete, the threads first
+# store the fragmented accumulator registers to shared memory and then copy from
+# shared memory to global memory; this part is called the epilogue.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |       |Prologue ---->  |MainLoop ---->                                                      |Epilogue               |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2]     |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup |   ........     |[wait-0][mma]       |[wait-1][mma]       |[wait-2][mma]       | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+#
+# ===--- Warp Specialization Kernel ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as the `producer warp group` and the other as the `consumer warp group`. The
+# `producer warp group` is responsible for issuing the TMA loads, while the
+# `consumer warp group` performs the mma operations. The epilogue is also
+# handled by the `consumer warp group`, as its threads own the fragmented registers.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |        |MainLoop ---->                                                    | 1st Epilogue | 2nd Epilogue    |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ...........  | [shmem->global] |
+# |wgroup1 | .......|         |         |         |                           |              | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+import ctypes
+from mlir import runtime as rt
+
+def generate_matmul(input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3):
+ with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+ if use_warp_specilization:
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
+ BLOCK_K, max_num_stages)
+ else:
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
+ BLOCK_N, BLOCK_K, max_num_stages)
+
+ mlir_nvgpu_module.operation.verify()
+
+ # Save generated IR
+ if saveIR:
+ # print(mlir_nvgpu_module)
+ original_stdout = sys.stdout
+ with open('gemm.mlir', 'w') as f:
+ sys.stdout = f
+ print(mlir_nvgpu_module)
+ sys.stdout = original_stdout
+
+ # Get compiler
+ options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
+ support_lib = os.getenv("SUPPORT_LIB")
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
+ compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+
+ # Compile
+ engine = compiler.compile_and_jit(mlir_nvgpu_module)
+ return engine
+
+
+def matmul(input_type=np.float16,
+ output_type=np.float32,
+ M=128,
+ N=128,
+ K=128,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3,
+ print_results=False,
+ no_verify=False):
+ # Print the configuration
+ ity = "f16" if input_type == np.float16 else "f32"
+ oty = "f16" if output_type == np.float16 else "f32"
+ gemmty = "Warp Specialization" if use_warp_specilization else "Multistage"
+ print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
+ "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
+ str(max_num_stages) + " --===")
+
+ # Build IR and compile
+ engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
+ saveIR, max_num_stages)
+
+ # Allocate matrices and invoke the matmul
+ c = np.zeros((M, N), output_type)
+ a = np.random.randn(M, K).astype(input_type)
+ b = np.random.randn(K, N).astype(input_type)
+ mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+ mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+ mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+ kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+
+ # Launch the MLIR generated kernel
+ engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+ float_formatter = "{:.2f}".format
+ np.set_printoptions(formatter={'float_kind': float_formatter})
+
+ if print_results:
+ print(c)
+
+ # Verify the results
+ if not no_verify:
+ ref = a.astype(input_type) @ b.astype(input_type)
+ if print_results:
+ print(ref)
+ np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+ print("PASS ")
+
+
+# GEMM Multistage f32 += f16 * f16
+matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+# GEMM Warp Specialized f32 += f16 * f16
+matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..09ab35d4c5f15e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,676 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+
+TMA_LAST_DIM_F16 = 64 # 128B float16
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808
+f16_byte = 2
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+ if not DEBUG and not forcePrint:
+ return
+ type_formats = []
+ for arg in args:
+ ty_format = None
+ if ir.IndexType.isinstance(arg.type):
+ ty_format = "%llu"
+ if ir.IntegerType.isinstance(arg.type):
+ width = ir.IntegerType(arg.type).width
+ if width == 64:
+ ty_format = "%llu"
+ elif width == 32:
+ ty_format = "%d"
+ elif width == 1:
+ ty_format = "%i"
+ if ir.F32Type.isinstance(arg.type):
+ ty_format = "%f"
+ if ty_format is None:
+ raise NotImplementedError(arg.type)
+ type_formats.append(ty_format)
+ if threadNumber != -1:
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+ scf.yield_([])
+ if_op = scf.IfOp(predicate)
+ with ir.InsertionPoint(if_op.then_block):
+ gpu.printf(fmt.format(*type_formats) + "\n", args)
+ scf.yield_([])
+
+
+def c(value, ty=None):
+ ty = ir.IndexType.get() if ty is None else ty
+ return arith.constant(ty, value)
+
+
+def generate_matmul_ws(input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=128,
+ max_num_stages=3):
+ # Limitations for now
+ assert input_type == np.float16
+ assert output_type == np.float32
+ assert M % BLOCK_M == 0
+ assert N % BLOCK_N == 0
+ assert K % BLOCK_K == 0
+
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+ num_stages = min(required_stages, max_num_stages)
+
+ module = ir.Module.create()
+ f16 = ir.F16Type.get()
+ f32 = ir.F32Type.get()
+ i1 = ir.IntegerType.get_signless(1)
+ i32 = ir.IntegerType.get_signless(32)
+ index = ir.IndexType.get()
+ i8 = ir.IntegerType.get_signless(8)
+ token_ty = ir.Type.parse("!gpu.async.token")
+ a_ty = ir.MemRefType.get([M, K], f16)
+ b_ty = ir.MemRefType.get((K, N), f16)
+ c_elem_ty = f16 if output_type == np.float16 else f32
+ c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+ a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+ b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+ b_tile_shape = (BLOCK_K, BLOCK_N)
+ txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ smem_space_str = "#gpu.address_space<workgroup>"
+ smem_space = ir.Attribute.parse(smem_space_str)
+ input_type_str = "f16" if input_type == np.float16 else "f32"
+ output_type_str = "f16" if output_type == np.float16 else "f32"
+ mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+ str(num_stages) + ">")
+ a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+ str(output_type_str) + ">>")
+ a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+ str(input_type_str) + ", " + smem_space_str + ">>")
+ b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+ str(input_type_str) + ", " + smem_space_str + ">>")
+
+ with ir.InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+ lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+ rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+ smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+ smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+ smem_size = max(smem_size_input, smem_size_output)
+
+ # Step 1. Allocate device memory and memcpy
+ t1 = gpu.wait(token_ty, [])
+ a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+ b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+ c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+ t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+ t7 = gpu.wait(token_ty, [t6])
+
+ # Step 2. Create TMA Descriptors
+ tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_descs = []
+ for x_device, tensor_map_ty, tile_shape in tma_specs:
+ x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+ tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+ a_tma_desc, b_tma_desc = tma_descs
+
+ # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+ cta_m = M // BLOCK_M
+ cta_n = N // BLOCK_N
+ assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+ grid = (cta_m, cta_n, 1)
+ block = (WARP_GROUP_SIZE * 2, 1, 1)
+ launch_op = gpu.LaunchOp(token_ty, [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=i32))
+ launch_op.body.blocks.append(*([index] * 12))
+ with ir.InsertionPoint(launch_op.body.blocks[0]):
+ # GPU Step 0. This is needed for vectorized ld/st
+ memref.assume_alignment(c_device, 16)
+ dynamic_smem = gpu.dynamic_shared_memory(
+ ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+ ticks = c(10000000)
+
+ # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+ warp_id = arith.divui(tidx, c(32))
+ warpgroup_id = arith.divui(warp_id, c(4))
+ is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+ c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+ is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
+ c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+ producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
+ consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = arith.muli(bidx, c(BLOCK_M))
+ dimY = arith.muli(bidy, c(BLOCK_N))
+
+ # GPU Step 2. Initialize mbarrier groups
+ mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+ mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+ for i in range(num_stages):
+ nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+ nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+ gpu.barrier()
+
+ # GPU Step 3. Prefetch TMA descriptors
+ nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+
+ # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+ with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+
+ # Step 5.1. Reduce register size
+ nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+
+ # Step 5.2. TMA Main Loop
+ for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+ with ir.InsertionPoint(for_op.body):
+ phaseParity = for_op.inner_iter_args[0]
+ iv = for_op.induction_variable
+ stage = arith.remui(iv, c(num_stages))
+
+ # Step 5.2.1. Wait mbarDONE
+ debug_print("[prod] {} | mbarDONE[{}] try_wait phase={}",
+ iv,
+ stage,
+ phaseParity,
+ predicate=producerPrimaryThread)
+ nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
+ debug_print("[prod] {} | mbarDONE[{}] try_wait phase={} [done]",
+ iv,
+ stage,
+ phaseParity,
+ predicate=producerPrimaryThread)
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+ phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+ # Step 5.2.2. Load TMA
+ a_offset = arith.muli(stage, c(lhs_tile_bytes))
+ a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+ b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, b_offset2, [])
+ debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
+ a_offset,
+ b_offset,
+ b_offset2,
+ predicate=producerPrimaryThread)
+ coord = arith.muli(c(64), iv)
+ nvgpu.TmaAsyncLoadOp(a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=stage,
+ predicate=producerPrimaryThread)
+ nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=stage,
+ predicate=producerPrimaryThread)
+ dimY2 = arith.addi(dimY, c(64))
+ nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=stage,
+ predicate=producerPrimaryThread)
+
+ # Step 5.2.3. Arrive mbarTMA
+ debug_print("[prod] {} | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
+ nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+ debug_print("[prod] {} | mbarTMA[{}] arrive [done]",
+ iv,
+ stage,
+ predicate=producerPrimaryThread)
+ scf.yield_([phaseParity])
+ scf.yield_([])
+
+ # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
+ if_op = scf.IfOp(is_consumer)
+ with ir.InsertionPoint(if_op.then_block):
+
+ # Step 6.1. Increase register size
+ nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
+
+ # GPU Step 6.2. Initialize MMA registers
+ acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+
+ # Step 6.3. MMA Main Loop
+ for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+ with ir.InsertionPoint(for_op.body):
+ # Step 6.3.1. Wait mbar1
+ phaseParity = for_op.inner_iter_args[1]
+ iv = for_op.induction_variable
+ stage = arith.remui(iv, c(num_stages))
+ debug_print("[cons] {} | mbarTMA[{}] try_wait phase={}",
+ iv,
+ stage,
+ phaseParity,
+ predicate=consumerPrimaryThread)
+ nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+ debug_print("[cons] {} | mbarTMA[{}] try_wait phase={} [done]",
+ iv,
+ stage,
+ phaseParity,
+ predicate=consumerPrimaryThread)
+
+ # Step 6.3.2. Create WGMMA Descriptors
+ a_offset = arith.muli(stage, c(lhs_tile_bytes))
+ a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+ b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ debug_print("[cons] a_offset={} b_offset={}",
+ a_offset,
+ b_offset,
+ predicate=consumerPrimaryThread)
+ da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+ db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+ # Step 6.3.3. MMA
+ carry_acc = for_op.inner_iter_args[0]
+ new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+ # Step 6.3.4. Arrive mbarDONE
+ p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+ p_arrive = arith.andi(consumerPrimaryThread, p1)
+ with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
+ barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+ remoteCtaId = c(0)
+ pred = consumerPrimaryThread
+ debug_print("[cons] {} | mbarDONE[{}] arrive.mapa pred={} ",
+ iv,
+ barId,
+ remoteCtaId,
+ pred,
+ predicate=consumerPrimaryThread)
+ nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+ debug_print("[cons] {} | mbarDONE[{}] arrive pred={} [done]",
+ iv,
+ barId,
+ remoteCtaId,
+ pred,
+ predicate=consumerPrimaryThread)
+ scf.yield_([])
+
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+ phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+ # Step 6.3.5. Yield
+ scf.yield_([new_acc, phaseParity])
+
+ # Step 6.3. Wait All WGMMA
+ nvvm.WgmmaWaitGroupSyncOp(0)
+
+ with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+ barId = c((K // BLOCK_K) % num_stages)
+ nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+ scf.yield_([])
+
+ # Step 6.4. Epilogue (registers --> shared memory)
+ acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+ acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+ debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
+ nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+ scf.yield_([])
+ gpu.barrier()
+
+ # GPU Step 9. Epilogue (shared memory --> global memory)
+ fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+ collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+ rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+ c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+ [BLOCK_M, BLOCK_N], [1, 1])
+ vlen = 1
+ for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+ with ir.InsertionPoint(for_op.body):
+ x = arith.divui(for_op.induction_variable, c(BLOCK_M))
+ y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+ vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+ [for_op.induction_variable])
+ vector.store(vdata, c_device_per_block, [x, y])
+ scf.yield_([])
+
+ gpu.terminator()
+
+ # Step 4. Copy back to host
+ t8 = gpu.wait(token_ty, [launch_op])
+ t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+ gpu.wait(token_ty, [t9])
+
+ mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ module.operation.verify()
+ return module
+
+
+def generate_matmul_multistage(input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ max_num_stages=3):
+ # Limitations for now
+ assert input_type == np.float16
+ assert output_type == np.float32
+ assert M % BLOCK_M == 0
+ assert N % BLOCK_N == 0
+ assert K % BLOCK_K == 0
+
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+ num_stages = min(required_stages, max_num_stages)
+
+ module = ir.Module.create()
+ f16 = ir.F16Type.get()
+ f32 = ir.F32Type.get()
+ i1 = ir.IntegerType.get_signless(1)
+ i32 = ir.IntegerType.get_signless(32)
+ index = ir.IndexType.get()
+ i8 = ir.IntegerType.get_signless(8)
+ token_ty = ir.Type.parse("!gpu.async.token")
+ a_ty = ir.MemRefType.get([M, K], f16)
+ b_ty = ir.MemRefType.get((K, N), f16)
+ c_elem_ty = f16 if output_type == np.float16 else f32
+ c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+ a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+ b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+ b_tile_shape = (BLOCK_K, BLOCK_N)
+ txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ smem_space_str = "#gpu.address_space<workgroup>"
+ smem_space = ir.Attribute.parse(smem_space_str)
+ input_type_str = "f16" if input_type == np.float16 else "f32"
+ output_type_str = "f16" if output_type == np.float16 else "f32"
+ mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+ str(num_stages) + ">")
+ a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+ str(output_type_str) + ">>")
+ a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+ str(input_type_str) + ", " + smem_space_str + ">>")
+ b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+ str(input_type_str) + ", " + smem_space_str + ">>")
+
+ with ir.InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
+ def mlir_matmul_multistage(a_host, b_host, c_host):
+ lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
+ rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+ smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+ smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+ smem_size = max(smem_size_input, smem_size_output)
+
+ # Step 1. Allocate device memory and memcpy
+ t1 = gpu.wait(token_ty, [])
+ a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+ b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+ c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+ t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+ t7 = gpu.wait(token_ty, [t6])
+
+ # Step 2. Create TMA Descriptors
+ tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_descs = []
+ for x_device, tensor_map_ty, tile_shape in tma_specs:
+ x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
+ tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+ a_tma_desc, b_tma_desc = tma_descs
+
+ # Step 3. Launch Kernel with 1 Warpgroup
+ cta_m = M // BLOCK_M
+ cta_n = N // BLOCK_N
+ assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+ grid = (cta_m, cta_n, 1)
+ block = (WARP_GROUP_SIZE, 1, 1)
+ launch_op = gpu.LaunchOp(token_ty, [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=i32))
+ launch_op.body.blocks.append(*([index] * 12))
+ with ir.InsertionPoint(launch_op.body.blocks[0]):
+ # GPU Step 0. Bootstrapping
+ memref.assume_alignment(c_device, 16)
+ dynamic_smem = gpu.dynamic_shared_memory(
+ ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+ ticks = c(10000000)
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
+ warpId = arith.divui(tidx, c(32))
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = arith.muli(bidx, c(BLOCK_M))
+ dimY = arith.muli(bidy, c(BLOCK_N))
+
+ # GPU Step 1. Initialize mbarrier groups
+ mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+ for i in range(num_stages):
+ nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+ gpu.barrier()
+
+ # GPU Step 2. Prefetch TMA descriptors
+ nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+
+ # GPU Step 3. Prologue (global memory --> shared memory)
+ for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+ with ir.InsertionPoint(for_op.body):
+ iv = for_op.induction_variable
+
+ # Step 3.1. Calculate offsets
+ a_offset = arith.muli(iv, c(lhs_tile_bytes))
+ a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+ b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, b_offset2, [])
+
+ # Step 3.2. TMA Load
+ coord = arith.muli(c(64), iv)
+ dimY2 = arith.addi(dimY, c(64))
+ debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ a_offset,
+ b_offset,
+ b_offset2,
+ coord, dimX,
+ dimY, coord,
+ predicate=primaryThread)
+ nvgpu.TmaAsyncLoadOp(a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=iv,
+ predicate=primaryThread)
+ nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=iv,
+ predicate=primaryThread)
+ nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=iv,
+ predicate=primaryThread)
+
+ # Step 3.2. mbarTMA arrive
+ debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
+ nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
+ debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+ scf.yield_([])
+
+ # GPU Step 4. Main Loop
+ acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+ for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+ with ir.InsertionPoint(for_op.body):
+ # Step 4.1. Wait mbarTMA
+ phaseParity = for_op.inner_iter_args[1]
+ iv = for_op.induction_variable
+ stage = arith.remui(iv, c(num_stages))
+ debug_print("[MainLoop] mbarTMA[{}] try_wait phase={}", stage, phaseParity, predicate=primaryThread)
+ nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
+ debug_print("[MainLoop] mbarTMA[{}] try_wait phase={} [done]", stage, phaseParity, predicate=primaryThread)
+
+ # Step 4.2. Create WGMMA Descriptors
+ a_offset = arith.muli(stage, c(lhs_tile_bytes))
+ a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+ b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
+ da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
+ db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+
+ # Step 4.3. MMA
+ carry_acc = for_op.inner_iter_args[0]
+ new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+
+ # Step 4.4. Load TMA for next stage
+ p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+ p = arith.andi(primaryThread, p1)
+ with ir.InsertionPoint(scf.IfOp(p).then_block):
+ nextStage = arith.addi(iv, c(num_stages-1))
+ nextSlot = arith.remui(nextStage, c(num_stages))
+ a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
+ a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
+ b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem, b_offset2, [])
+
+ coord = arith.muli(c(64), nextStage)
+ debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ iv, a_offset,
+ b_offset,
+ b_offset2,
+ coord, dimX,
+ dimY, coord,
+ predicate=primaryThread)
+ nvgpu.TmaAsyncLoadOp(a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=nextSlot,
+ predicate=primaryThread)
+ nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=nextSlot,
+ predicate=primaryThread)
+ dimY2 = arith.addi(dimY, c(64))
+ nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=nextSlot,
+ predicate=primaryThread)
+
+ debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
+ nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
+ debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+ scf.yield_([])
+ # Step 4.5. Change the phaseParity
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+ phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+
+ # Step 4.5. Yield
+ scf.yield_([new_acc, phaseParity])
+
+ # Step 5. Wait All WGMMA groups
+ nvvm.WgmmaWaitGroupSyncOp(0)
+ gpu.barrier()
+
+ # Step 6. Epilogue (registers --> shared memory)
+ acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+ acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+ debug_print("Storing", predicate=primaryThread)
+ nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+ gpu.barrier()
+
+ # GPU Step 7. Epilogue (shared memory --> global memory)
+ fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+ collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+ rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
+ c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
+ [BLOCK_M, BLOCK_N], [1, 1])
+ vlen = 1
+ for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+ with ir.InsertionPoint(for_op.body):
+ x = arith.divui(for_op.induction_variable, c(BLOCK_M))
+ y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+ vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+ [for_op.induction_variable])
+ vector.store(vdata, c_device_per_block, [x, y])
+ scf.yield_([])
+
+ gpu.terminator()
+
+ # Step 4. Copy back to host
+ t8 = gpu.wait(token_ty, [launch_op])
+ t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+ gpu.wait(token_ty, [t9])
+
+ mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ module.operation.verify()
+ return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..2a4f67326ee718
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -0,0 +1,44 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file contains the Nvgpu class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+class NvgpuCompiler:
+ """Nvgpu class for compiling and building MLIR modules."""
+
+ def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+ pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+ self.pipeline = pipeline
+ self.shared_libs = shared_libs
+ self.opt_level = opt_level
+
+ def __call__(self, module: ir.Module):
+ """Convenience application method."""
+ self.compile(module)
+
+ def compile(self, module: ir.Module):
+ """Compiles the module by invoking the nvgpu pipeline."""
+ passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+ def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Wraps the module in a JIT execution engine."""
+ return execution_engine.ExecutionEngine(
+ module, opt_level=self.opt_level, shared_libs=self.shared_libs
+ )
+
+ def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Compiles and jits the module."""
+ self.compile(module)
+ return self.jit(module)
>From 379fa54d3583a7ecb0753d12124485c7050dac69 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:39:08 +0000
Subject: [PATCH 02/10] format with yapf
---
.../GPU/CUDA/sm90/python/matmul.py | 51 +-
.../CUDA/sm90/python/tools/matmulBuilder.py | 655 ++++++++++++------
.../CUDA/sm90/python/tools/nvgpucompiler.py | 15 +-
3 files changed, 483 insertions(+), 238 deletions(-)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index e153fcb44b9860..00313270d723d6 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -85,6 +85,7 @@
import ctypes
from mlir import runtime as rt
+
def generate_matmul(input_type=np.float16,
output_type=np.float32,
M=4096,
@@ -96,16 +97,19 @@ def generate_matmul(input_type=np.float16,
use_warp_specilization=True,
saveIR=False,
max_num_stages=3):
- with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+ with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
+ ):
if use_warp_specilization:
- mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
- BLOCK_K, max_num_stages)
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+ input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+ max_num_stages)
else:
- mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
- BLOCK_N, BLOCK_K, max_num_stages)
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+ input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
+ max_num_stages)
mlir_nvgpu_module.operation.verify()
-
+
# Save generated IR
if saveIR:
# print(mlir_nvgpu_module)
@@ -119,8 +123,11 @@ def generate_matmul(input_type=np.float16,
options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
support_lib = os.getenv("SUPPORT_LIB")
if not os.path.exists(support_lib):
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
- compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+ raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
+ support_lib)
+ compiler = nvgpucompiler.NvgpuCompiler(options,
+ opt_level=3,
+ shared_libs=[support_lib])
# Compile
engine = compiler.compile_and_jit(mlir_nvgpu_module)
@@ -144,13 +151,15 @@ def matmul(input_type=np.float16,
ity = "f16" if input_type == np.float16 else "f32"
oty = "f16" if output_type == np.float16 else "f32"
gemmty = "Warp Specialization" if use_warp_specilization else "Multistage"
- print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
- "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
- str(max_num_stages) + " --===")
+ print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
+ ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
+ str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
+ ", stages " + str(max_num_stages) + " --===")
# Build IR and compile
- engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
- saveIR, max_num_stages)
+ engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
+ BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
+ max_num_stages)
# Allocate matrices and invoke the matmul
c = np.zeros((M, N), output_type)
@@ -181,6 +190,18 @@ def matmul(input_type=np.float16,
# GEMM Multistage f32 += f16 * f16
-matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+matmul(np.float16,
+ np.float32,
+ 128,
+ 128,
+ 4096,
+ max_num_stages=3,
+ use_warp_specilization=False)
# GEMM Warp Specialized f32 += f16 * f16
-matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
+matmul(np.float16,
+ np.float32,
+ 256,
+ 1024,
+ 512,
+ max_num_stages=3,
+ use_warp_specilization=True)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 09ab35d4c5f15e..0fda8698606cd5 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -11,7 +11,6 @@
from mlir.dialects import scf
from mlir.dialects import vector
-
TMA_LAST_DIM_F16 = 64 # 128B float16
WARP_SIZE = 32
WARP_GROUP_SIZE = WARP_SIZE * 4
@@ -81,7 +80,8 @@ def generate_matmul_ws(input_type=np.float16,
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
- required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
+ BLOCK_K * BLOCK_N)
num_stages = min(required_stages, max_num_stages)
module = ir.Module.create()
@@ -99,25 +99,36 @@ def generate_matmul_ws(input_type=np.float16,
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
- txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
+ (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
input_type_str = "f16" if input_type == np.float16 else "f32"
output_type_str = "f16" if output_type == np.float16 else "f32"
- mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+ mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
+ str(smem_space) + ", num_barriers = " +
str(num_stages) + ">")
- a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+ a_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+ str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ b_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+ str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
+ str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
str(output_type_str) + ">>")
- a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
- str(input_type_str) + ", " + smem_space_str + ">>")
- b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
- str(input_type_str) + ", " + smem_space_str + ">>")
+ a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+ str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+ str(input_type_str) + ", " + smem_space_str +
+ ">>")
+ b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+ str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+ str(input_type_str) + ", " + smem_space_str +
+ ">>")
with ir.InsertionPoint(module.body):
@@ -139,13 +150,18 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
t7 = gpu.wait(token_ty, [t6])
# Step 2. Create TMA Descriptors
- tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
+ (b_device, b_tma_desc_ty, b_tma_shape)]
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
- x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
- tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+ x_unranked = memref.cast(
+ ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
+ x_device)
+ tma_descs.append(
+ nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
+ map(c, tile_shape)).result)
a_tma_desc, b_tma_desc = tma_descs
-
+
# Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
cta_m = M // BLOCK_M
cta_n = N // BLOCK_N
@@ -155,26 +171,37 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
launch_op = gpu.LaunchOp(token_ty, [t7],
*map(c, grid),
*map(c, block),
- dynamicSharedMemorySize=c(smem_size, ty=i32))
+ dynamicSharedMemorySize=c(smem_size,
+ ty=i32))
launch_op.body.blocks.append(*([index] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
# GPU Step 0. This is needed for vectorized ld/st
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
- ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+ ir.MemRefType.get((MLIR_DYNAMIC, ),
+ i8,
+ memory_space=smem_space))
ticks = c(10000000)
# GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
tidx = gpu.thread_id(gpu.Dimension.x)
- wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+ wgPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq,
+ arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
warp_id = arith.divui(tidx, c(32))
warpgroup_id = arith.divui(warp_id, c(4))
- is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
- c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
- is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
- c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
- producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
- consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+ is_producer = arith.cmpi(
+ arith.CmpIPredicate.eq, warpgroup_id,
+ c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+ is_consumer = arith.cmpi(
+ arith.CmpIPredicate.eq, warpgroup_id,
+ c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
+ producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
+ tidx,
+ c(PRODUCER_PRIMARY_THREAD))
+ consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
+ tidx,
+ c(CONSUMER_PRIMARY_THREAD))
bidx = gpu.block_id(gpu.Dimension.x)
bidy = gpu.block_id(gpu.Dimension.y)
dimX = arith.muli(bidx, c(BLOCK_M))
@@ -184,57 +211,88 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
mbarTMA = nvgpu.mbarrier_create(mbar_ty)
mbarDONE = nvgpu.mbarrier_create(mbar_ty)
for i in range(num_stages):
- nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
- nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+ nvgpu.mbarrier_init(mbarTMA,
+ c(1),
+ c(i),
+ predicate=wgPrimaryThread)
+ nvgpu.mbarrier_init(mbarDONE,
+ c(1),
+ c(i),
+ predicate=wgPrimaryThread)
gpu.barrier()
# GPU Step 3. Prefetch TMA descriptors
- nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
- nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
-
+ nvgpu.tma_prefetch_descriptor(a_tma_desc,
+ predicate=wgPrimaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc,
+ predicate=wgPrimaryThread)
+
# GPU Step 5. Producer Warpgroup (TMA Warpgroup)
with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-
+
# Step 5.1. Reduce register size
- nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+ nvvm.setmaxregister(PRODUCER_REGISTER_SIZE,
+ nvvm.SetMaxRegisterAction.decrease)
# Step 5.2. TMA Main Loop
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+ for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+ [arith.constant(i1, 1)])
with ir.InsertionPoint(for_op.body):
phaseParity = for_op.inner_iter_args[0]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
# Step 5.2.1. Wait mbarDONE
- debug_print("[prod] {} | mbarDONE[{}] try_wait phase={}",
- iv,
- stage,
- phaseParity,
- predicate=producerPrimaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
- debug_print("[prod] {} | mbarDONE[{}] try_wait phase={} [done]",
- iv,
- stage,
- phaseParity,
- predicate=producerPrimaryThread)
- p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
- phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+ debug_print(
+ "[prod] {} | mbarDONE[{}] try_wait phase={}",
+ iv,
+ stage,
+ phaseParity,
+ predicate=producerPrimaryThread)
+ nvgpu.MBarrierTryWaitParityOp(mbarDONE,
+ phaseParity,
+ ticks,
+ mbarId=stage)
+ debug_print(
+ "[prod] {} | mbarDONE[{}] try_wait phase={} [done]",
+ iv,
+ stage,
+ phaseParity,
+ predicate=producerPrimaryThread)
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+ c(num_stages - 1))
+ phaseParity = arith.select(
+ p, arith.xori(phaseParity, arith.constant(i1, 1)),
+ phaseParity)
# Step 5.2.2. Load TMA
a_offset = arith.muli(stage, c(lhs_tile_bytes))
- a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
- b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset2, [])
- debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
- a_offset,
- b_offset,
- b_offset2,
- predicate=producerPrimaryThread)
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(a_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages))
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(b_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ b_offset2 = arith.addi(
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(b_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, b_offset2, [])
+ debug_print(
+ "[prod] a_offset={} b_offset={} b_offset2={}",
+ a_offset,
+ b_offset,
+ b_offset2,
+ predicate=producerPrimaryThread)
coord = arith.muli(c(64), iv)
nvgpu.TmaAsyncLoadOp(a_tma_slice,
mbarTMA,
@@ -257,8 +315,15 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
predicate=producerPrimaryThread)
# Step 5.2.3. Arrive mbarTMA
- debug_print("[prod] {} | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
- nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
+ debug_print("[prod] {} | mbarTMA[{}] arrive",
+ iv,
+ stage,
+ predicate=producerPrimaryThread)
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA,
+ c(txcount),
+ stage,
+ predicate=producerPrimaryThread)
debug_print("[prod] {} | mbarTMA[{}] arrive [done]",
iv,
stage,
@@ -269,75 +334,104 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
# GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
if_op = scf.IfOp(is_consumer)
with ir.InsertionPoint(if_op.then_block):
-
+
# Step 6.1. Increase register size
- nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
-
+ nvvm.setmaxregister(CONSUMER_REGISTER_SIZE,
+ nvvm.SetMaxRegisterAction.increase)
+
# GPU Step 6.2. Initialize MMA registers
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
# Step 6.3. MMA Main Loop
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+ for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+ [acc, arith.constant(i1, 0)])
with ir.InsertionPoint(for_op.body):
# Step 6.3.1. Wait mbar1
phaseParity = for_op.inner_iter_args[1]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
- debug_print("[cons] {} | mbarTMA[{}] try_wait phase={}",
- iv,
- stage,
- phaseParity,
- predicate=consumerPrimaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
- debug_print("[cons] {} | mbarTMA[{}] try_wait phase={} [done]",
- iv,
- stage,
- phaseParity,
- predicate=consumerPrimaryThread)
-
+ debug_print(
+ "[cons] {} | mbarTMA[{}] try_wait phase={}",
+ iv,
+ stage,
+ phaseParity,
+ predicate=consumerPrimaryThread)
+ nvgpu.MBarrierTryWaitParityOp(mbarTMA,
+ phaseParity,
+ ticks,
+ mbarId=stage)
+ debug_print(
+ "[cons] {} | mbarTMA[{}] try_wait phase={} [done]",
+ iv,
+ stage,
+ phaseParity,
+ predicate=consumerPrimaryThread)
+
# Step 6.3.2. Create WGMMA Descriptors
a_offset = arith.muli(stage, c(lhs_tile_bytes))
- a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
+ a_tile_slice = memref.view(
+ ir.MemRefType.get(a_tile_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages))
+ b_tile_slice = memref.view(
+ ir.MemRefType.get(b_tile_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, b_offset, [])
debug_print("[cons] a_offset={} b_offset={}",
a_offset,
b_offset,
predicate=consumerPrimaryThread)
- da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
- db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+ da = nvgpu.WarpgroupGenerateDescriptorOp(
+ a_wgmma_ty, a_tile_slice, a_tma_desc)
+ db = nvgpu.WarpgroupGenerateDescriptorOp(
+ b_wgmma_ty, b_tile_slice, b_tma_desc)
# Step 6.3.3. MMA
carry_acc = for_op.inner_iter_args[0]
- new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+ new_acc = nvgpu.WarpgroupMmaOp(acc.type,
+ da,
+ db,
+ carry_acc,
+ transposeB=True)
# Step 6.3.4. Arrive mbarDONE
p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
p_arrive = arith.andi(consumerPrimaryThread, p1)
with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
- barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+ barId = arith.select(p, c(num_stages - 1),
+ arith.subi(stage, c(1)))
remoteCtaId = c(0)
pred = consumerPrimaryThread
- debug_print("[cons] {} | mbarDONE[{}] arrive.mapa pred={} ",
- iv,
- barId,
- remoteCtaId,
- pred,
- predicate=consumerPrimaryThread)
- nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
- debug_print("[cons] {} | mbarDONE[{}] arrive pred={} [done]",
- iv,
- barId,
- remoteCtaId,
- pred,
- predicate=consumerPrimaryThread)
+ debug_print(
+ "[cons] {} | mbarDONE[{}] arrive.mapa pred={} ",
+ iv,
+ barId,
+ remoteCtaId,
+ pred,
+ predicate=consumerPrimaryThread)
+ nvgpu.mbarrier_arrive(
+ ir.Type.parse("!nvgpu.mbarrier.token"),
+ mbarDONE, barId)
+ debug_print(
+ "[cons] {} | mbarDONE[{}] arrive pred={} [done]",
+ iv,
+ barId,
+ remoteCtaId,
+ pred,
+ predicate=consumerPrimaryThread)
scf.yield_([])
- p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
- phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+ c(num_stages - 1))
+ phaseParity = arith.select(
+ p, arith.xori(phaseParity, arith.constant(i1, 1)),
+ phaseParity)
# Step 6.3.5. Yield
scf.yield_([new_acc, phaseParity])
@@ -345,32 +439,45 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
# Step 6.3. Wait All WGMMA
nvvm.WgmmaWaitGroupSyncOp(0)
- with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+ with ir.InsertionPoint(
+ scf.IfOp(consumerPrimaryThread).then_block):
barId = c((K // BLOCK_K) % num_stages)
- nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+ nvgpu.mbarrier_arrive(
+ ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE,
+ barId)
scf.yield_([])
# Step 6.4. Epilogue (registers --> shared memory)
- acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+ acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
+ c_elem_ty,
+ memory_space=smem_space)
acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
- debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
+ debug_print("[cons] | Storing",
+ predicate=consumerPrimaryThread)
nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
scf.yield_([])
gpu.barrier()
# GPU Step 9. Epilogue (shared memory --> global memory)
- fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+ fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
+ c_elem_ty,
+ memory_space=smem_space)
collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
- rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
- ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
- c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
- [BLOCK_M, BLOCK_N], [1, 1])
+ rty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N), c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) +
+ ", 1], offset: ?>"))
+ c_device_per_block = memref.SubViewOp(
+ rty, c_device, [dimX, dimY], [], [],
+ [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
vlen = 1
- for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+ for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
+ c(vlen * WARP_GROUP_SIZE * 2))
with ir.InsertionPoint(for_op.body):
x = arith.divui(for_op.induction_variable, c(BLOCK_M))
y = arith.remui(for_op.induction_variable, c(BLOCK_N))
- vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+ vdata = vector.load(ir.VectorType.get(
+ (vlen, ), c_elem_ty), collapsed_smem,
[for_op.induction_variable])
vector.store(vdata, c_device_per_block, [x, y])
scf.yield_([])
@@ -382,7 +489,8 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
gpu.wait(token_ty, [t9])
- mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ mlir_matmul_warpspecialized.func_op.attributes[
+ "llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
return module
@@ -403,7 +511,8 @@ def generate_matmul_multistage(input_type=np.float16,
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
- required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
+ BLOCK_K * BLOCK_N)
num_stages = min(required_stages, max_num_stages)
module = ir.Module.create()
@@ -421,25 +530,36 @@ def generate_matmul_multistage(input_type=np.float16,
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
- txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
+ (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
input_type_str = "f16" if input_type == np.float16 else "f32"
output_type_str = "f16" if output_type == np.float16 else "f32"
- mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
+ mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
+ str(smem_space) + ", num_barriers = " +
str(num_stages) + ">")
- a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
+ a_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+ str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ b_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
+ str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
+ str(smem_space) +
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
+ str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
str(output_type_str) + ">>")
- a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
- str(input_type_str) + ", " + smem_space_str + ">>")
- b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
- str(input_type_str) + ", " + smem_space_str + ">>")
+ a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+ str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
+ str(input_type_str) + ", " + smem_space_str +
+ ">>")
+ b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
+ str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
+ str(input_type_str) + ", " + smem_space_str +
+ ">>")
with ir.InsertionPoint(module.body):
@@ -461,11 +581,16 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
t7 = gpu.wait(token_ty, [t6])
# Step 2. Create TMA Descriptors
- tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
+ (b_device, b_tma_desc_ty, b_tma_shape)]
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
- x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
- tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+ x_unranked = memref.cast(
+ ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
+ x_device)
+ tma_descs.append(
+ nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
+ map(c, tile_shape)).result)
a_tma_desc, b_tma_desc = tma_descs
# Step 3. Launch Kernel with 1 Warpgroup
@@ -477,13 +602,16 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
launch_op = gpu.LaunchOp(token_ty, [t7],
*map(c, grid),
*map(c, block),
- dynamicSharedMemorySize=c(smem_size, ty=i32))
+ dynamicSharedMemorySize=c(smem_size,
+ ty=i32))
launch_op.body.blocks.append(*([index] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
# GPU Step 0. Bootstrapping
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
- ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+ ir.MemRefType.get((MLIR_DYNAMIC, ),
+ i8,
+ memory_space=smem_space))
ticks = c(10000000)
tidx = gpu.thread_id(gpu.Dimension.x)
primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
@@ -496,39 +624,58 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# GPU Step 1. Initialize mbarrier groups
mbarTMA = nvgpu.mbarrier_create(mbar_ty)
for i in range(num_stages):
- nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+ nvgpu.mbarrier_init(mbarTMA,
+ c(1),
+ c(i),
+ predicate=primaryThread)
gpu.barrier()
# GPU Step 2. Prefetch TMA descriptors
- nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
- nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+ nvgpu.tma_prefetch_descriptor(a_tma_desc,
+ predicate=primaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc,
+ predicate=primaryThread)
# GPU Step 3. Prologue (global memory --> shared memory)
- for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+ for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
with ir.InsertionPoint(for_op.body):
iv = for_op.induction_variable
# Step 3.1. Calculate offsets
a_offset = arith.muli(iv, c(lhs_tile_bytes))
- a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
- b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset2, [])
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(a_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages))
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(b_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ b_offset2 = arith.addi(
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(b_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, b_offset2, [])
# Step 3.2. TMA Load
coord = arith.muli(c(64), iv)
dimY2 = arith.addi(dimY, c(64))
- debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
- a_offset,
- b_offset,
- b_offset2,
- coord, dimX,
- dimY, coord,
- predicate=primaryThread)
+ debug_print(
+ "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ a_offset,
+ b_offset,
+ b_offset2,
+ coord,
+ dimX,
+ dimY,
+ coord,
+ predicate=primaryThread)
nvgpu.TmaAsyncLoadOp(a_tma_slice,
mbarTMA,
a_tma_desc,
@@ -549,89 +696,153 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
predicate=primaryThread)
# Step 3.2. mbarTMA arrive
- debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
- nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
- debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+ debug_print("[Prologue] mbarTMA[{}] arrive",
+ iv,
+ predicate=primaryThread)
+ nvgpu.mbarrier_arrive_expect_tx(mbarTMA,
+ c(txcount),
+ iv,
+ predicate=primaryThread)
+ debug_print("[Prologue] mbarTMA[{}] arrive [done]",
+ iv,
+ predicate=primaryThread)
scf.yield_([])
# GPU Step 4. Main Loop
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+ for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
+ [acc, arith.constant(i1, 0)])
with ir.InsertionPoint(for_op.body):
# Step 4.1. Wait mbarTMA
phaseParity = for_op.inner_iter_args[1]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
- debug_print("[MainLoop] mbarTMA[{}] try_wait phase={}", stage, phaseParity, predicate=primaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
- debug_print("[MainLoop] mbarTMA[{}] try_wait phase={} [done]", stage, phaseParity, predicate=primaryThread)
+ debug_print("[MainLoop] mbarTMA[{}] try_wait phase={}",
+ stage,
+ phaseParity,
+ predicate=primaryThread)
+ nvgpu.MBarrierTryWaitParityOp(mbarTMA,
+ phaseParity,
+ ticks,
+ mbarId=stage)
+ debug_print(
+ "[MainLoop] mbarTMA[{}] try_wait phase={} [done]",
+ stage,
+ phaseParity,
+ predicate=primaryThread)
# Step 4.2. Create WGMMA Descriptors
a_offset = arith.muli(stage, c(lhs_tile_bytes))
- a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
- da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
- db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+ a_tile_slice = memref.view(
+ ir.MemRefType.get(a_tile_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages))
+ b_tile_slice = memref.view(
+ ir.MemRefType.get(b_tile_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}",
+ iv,
+ a_offset,
+ b_offset,
+ predicate=primaryThread)
+ da = nvgpu.WarpgroupGenerateDescriptorOp(
+ a_wgmma_ty, a_tile_slice, a_tma_desc)
+ db = nvgpu.WarpgroupGenerateDescriptorOp(
+ b_wgmma_ty, b_tile_slice, b_tma_desc)
# Step 4.3. MMA
carry_acc = for_op.inner_iter_args[0]
- new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+ new_acc = nvgpu.WarpgroupMmaOp(acc.type,
+ da,
+ db,
+ carry_acc,
+ transposeB=True)
# Step 4.4. Load TMA for next stage
- p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+ p1 = arith.cmpi(arith.CmpIPredicate.ult,
+ arith.addi(iv, c(num_stages - 1)),
+ c(K // BLOCK_K))
p = arith.andi(primaryThread, p1)
with ir.InsertionPoint(scf.IfOp(p).then_block):
- nextStage = arith.addi(iv, c(num_stages-1))
+ nextStage = arith.addi(iv, c(num_stages - 1))
nextSlot = arith.remui(nextStage, c(num_stages))
a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
- a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
- b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset2, [])
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(a_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, a_offset, [])
+ b_offset = arith.addi(
+ arith.muli(nextSlot, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages))
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(b_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, b_offset, [])
+ b_offset2 = arith.addi(
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(b_tma_shape,
+ f16,
+ memory_space=smem_space),
+ dynamic_smem, b_offset2, [])
coord = arith.muli(c(64), nextStage)
- debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
- iv, a_offset,
- b_offset,
- b_offset2,
- coord, dimX,
- dimY, coord,
- predicate=primaryThread)
+ debug_print(
+ "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ iv,
+ a_offset,
+ b_offset,
+ b_offset2,
+ coord,
+ dimX,
+ dimY,
+ coord,
+ predicate=primaryThread)
nvgpu.TmaAsyncLoadOp(a_tma_slice,
- mbarTMA,
- a_tma_desc,
- coordinates=[coord, dimX],
- mbarId=nextSlot,
- predicate=primaryThread)
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=nextSlot,
+ predicate=primaryThread)
nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY, coord],
- mbarId=nextSlot,
- predicate=primaryThread)
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=nextSlot,
+ predicate=primaryThread)
dimY2 = arith.addi(dimY, c(64))
nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY2, coord],
- mbarId=nextSlot,
- predicate=primaryThread)
-
- debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
- nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
- debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=nextSlot,
+ predicate=primaryThread)
+
+ debug_print("[MainLoop] mbarTMA[{}] arrive",
+ nextSlot,
+ predicate=primaryThread)
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA,
+ c(txcount),
+ nextSlot,
+ predicate=primaryThread)
+ debug_print("[MainLoop] mbarTMA[{}] arrive [done]",
+ nextSlot,
+ predicate=primaryThread)
scf.yield_([])
# Step 4.5. Change the phaseParity
- p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
- phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage,
+ c(num_stages - 1))
+ phaseParity = arith.select(
+ p, arith.xori(phaseParity, arith.constant(i1, 1)),
+ phaseParity)
# Step 4.5. Yield
scf.yield_([new_acc, phaseParity])
@@ -641,25 +852,34 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
gpu.barrier()
# Step 6. Epilogue (registers --> shared memory)
- acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+ acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
+ c_elem_ty,
+ memory_space=smem_space)
acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
debug_print("Storing", predicate=primaryThread)
nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
gpu.barrier()
# GPU Step 7. Epilogue (shared memory --> global memory)
- fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+ fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
+ c_elem_ty,
+ memory_space=smem_space)
collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
- rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
- ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
- c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
- [BLOCK_M, BLOCK_N], [1, 1])
+ rty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N), c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) +
+ ", 1], offset: ?>"))
+ c_device_per_block = memref.SubViewOp(
+ rty, c_device, [dimX, dimY], [], [],
+ [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
vlen = 1
- for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+ for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
+ c(vlen * WARP_GROUP_SIZE))
with ir.InsertionPoint(for_op.body):
x = arith.divui(for_op.induction_variable, c(BLOCK_M))
y = arith.remui(for_op.induction_variable, c(BLOCK_N))
- vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
+ vdata = vector.load(ir.VectorType.get(
+ (vlen, ), c_elem_ty), collapsed_smem,
[for_op.induction_variable])
vector.store(vdata, c_device_per_block, [x, y])
scf.yield_([])
@@ -671,6 +891,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
gpu.wait(token_ty, [t9])
- mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ mlir_matmul_multistage.func_op.attributes[
+ "llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
return module
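Both kernels above rely on the same phase-parity bookkeeping for the mbarrier waits: an i1 parity value is XOR-ed with 1 each time the stage counter wraps past `num_stages - 1`, so `MBarrierTryWaitParityOp` alternates the phase it waits on whenever a barrier slot is reused. A plain-Python model of that bookkeeping (illustrative only; `num_stages` and the trip count stand in for the IR values) behaves as follows:
```
# Host-side model of the parity toggling carried through the scf.for iter_args.
num_stages = 3
k_iterations = 8  # stands in for K // BLOCK_K

phase_parity = 0
for iv in range(k_iterations):
    stage = iv % num_stages          # which mbarrier slot this iteration uses
    # ... wait on mbarTMA[stage] with phase_parity, then do the TMA/MMA work ...
    if stage == num_stages - 1:      # the next iteration reuses slot 0,
        phase_parity ^= 1            # so flip the expected phase
```
This mirrors the `arith.select`/`arith.xori` pattern repeated in both main loops above.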
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
index 2a4f67326ee718..ab218812bb4d96 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -15,10 +15,12 @@
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
+
class NvgpuCompiler:
"""Nvgpu class for compiling and building MLIR modules."""
- def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+ def __init__(self, options: str, opt_level: int,
+ shared_libs: Sequence[str]):
pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
self.pipeline = pipeline
self.shared_libs = shared_libs
@@ -31,14 +33,15 @@ def __call__(self, module: ir.Module):
def compile(self, module: ir.Module):
"""Compiles the module by invoking the nvgpu pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module.operation)
-
+
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
- return execution_engine.ExecutionEngine(
- module, opt_level=self.opt_level, shared_libs=self.shared_libs
- )
+ return execution_engine.ExecutionEngine(module,
+ opt_level=self.opt_level,
+ shared_libs=self.shared_libs)
- def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ def compile_and_jit(self,
+ module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)
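For context on how the test driver consumes this wrapper, here is a minimal usage sketch; the import path and the `SUPPORT_LIB` handling are assumptions mirroring the test layout, not part of this patch:
```
# Minimal sketch: compile an already-built MLIR module with the nvgpu pipeline
# and JIT it. Assumes `module` is an ir.Module and the runtime support library
# path is exported in SUPPORT_LIB, as matmul.py does.
import os
import nvgpucompiler  # assumed importable from the tools/ directory

options = "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
compiler = nvgpucompiler.NvgpuCompiler(
    options, opt_level=3, shared_libs=[os.getenv("SUPPORT_LIB")]
)
engine = compiler.compile_and_jit(module)  # runs the pass pipeline, then JITs
```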
>From 39f7213b299b316f59555a6b095b57217d0fdaab Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Mon, 12 Feb 2024 13:59:39 +0000
Subject: [PATCH 03/10] format it with black
---
.../GPU/CUDA/sm90/python/matmul.py | 177 ++--
.../CUDA/sm90/python/tools/matmulBuilder.py | 975 +++++++++++-------
.../CUDA/sm90/python/tools/nvgpucompiler.py | 12 +-
3 files changed, 701 insertions(+), 463 deletions(-)
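The changes below are mechanical reformatting. One way to re-check the result locally (assuming the `black` package is installed; the exact invocation is an assumption of this note, not something recorded in the patch) is:
```
# Verify the three touched files against black's formatting without rewriting them.
import subprocess

files = [
    "mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py",
    "mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py",
    "mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py",
]
subprocess.run(["python", "-m", "black", "--check", *files], check=False)
```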
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index 00313270d723d6..ff9c950193e695 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -86,27 +86,44 @@
from mlir import runtime as rt
-def generate_matmul(input_type=np.float16,
- output_type=np.float32,
- M=4096,
- N=4096,
- K=4096,
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=64,
- use_warp_specilization=True,
- saveIR=False,
- max_num_stages=3):
- with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown(
- ):
+def generate_matmul(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3,
+):
+ with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
if use_warp_specilization:
mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
- input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
- max_num_stages)
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ max_num_stages,
+ )
else:
mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
- input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K,
- max_num_stages)
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ max_num_stages,
+ )
mlir_nvgpu_module.operation.verify()
@@ -114,7 +131,7 @@ def generate_matmul(input_type=np.float16,
if saveIR:
# print(mlir_nvgpu_module)
original_stdout = sys.stdout
- with open('gemm.mlir', 'w') as f:
+ with open("gemm.mlir", "w") as f:
sys.stdout = f
print(mlir_nvgpu_module)
sys.stdout = original_stdout
@@ -123,43 +140,77 @@ def generate_matmul(input_type=np.float16,
options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
support_lib = os.getenv("SUPPORT_LIB")
if not os.path.exists(support_lib):
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT),
- support_lib)
- compiler = nvgpucompiler.NvgpuCompiler(options,
- opt_level=3,
- shared_libs=[support_lib])
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+ )
+ compiler = nvgpucompiler.NvgpuCompiler(
+ options, opt_level=3, shared_libs=[support_lib]
+ )
# Compile
engine = compiler.compile_and_jit(mlir_nvgpu_module)
return engine
-def matmul(input_type=np.float16,
- output_type=np.float32,
- M=128,
- N=128,
- K=128,
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=64,
- use_warp_specilization=True,
- saveIR=False,
- max_num_stages=3,
- print_results=False,
- no_verify=False):
+def matmul(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=128,
+ N=128,
+ K=128,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3,
+ print_results=False,
+ no_verify=False,
+):
# Print the configuration
ity = "f16" if input_type == np.float16 else "f32"
oty = "f16" if output_type == np.float16 else "f32"
gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
- print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " +
- ity + ", Size " + str(M) + "x" + str(N) + "x" + str(K) + ", Tile " +
- str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) +
- ", stages " + str(max_num_stages) + " --===")
+ print(
+ "===-- Running GEMM "
+ + gemmty
+ + " "
+ + oty
+ + " += "
+ + ity
+ + " * "
+ + ity
+ + ", Size "
+ + str(M)
+ + "x"
+ + str(N)
+ + "x"
+ + str(K)
+ + ", Tile "
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(BLOCK_K)
+ + ", stages "
+ + str(max_num_stages)
+ + " --==="
+ )
# Build IR and compile
- engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M,
- BLOCK_N, BLOCK_K, use_warp_specilization, saveIR,
- max_num_stages)
+ engine = generate_matmul(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ use_warp_specilization,
+ saveIR,
+ max_num_stages,
+ )
# Allocate matrices and invoke the matmul
c = np.zeros((M, N), output_type)
@@ -168,13 +219,17 @@ def matmul(input_type=np.float16,
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
- kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+ kernelName = (
+ "mlir_matmul_warpspecialized"
+ if use_warp_specilization
+ else "mlir_matmul_multistage"
+ )
# Launch the MLIR generated kernel
engine.invoke(kernelName, mem_a, mem_b, mem_c)
float_formatter = "{:.2f}".format
- np.set_printoptions(formatter={'float_kind': float_formatter})
+ np.set_printoptions(formatter={"float_kind": float_formatter})
if print_results:
print(c)
@@ -190,18 +245,22 @@ def matmul(input_type=np.float16,
# GEMM Multistage f32 += f16 * f16
-matmul(np.float16,
- np.float32,
- 128,
- 128,
- 4096,
- max_num_stages=3,
- use_warp_specilization=False)
+matmul(
+ np.float16,
+ np.float32,
+ 128,
+ 128,
+ 4096,
+ max_num_stages=3,
+ use_warp_specilization=False,
+)
# GEMM Warp Specialized f32 += f16 * f16
-matmul(np.float16,
- np.float32,
- 256,
- 1024,
- 512,
- max_num_stages=3,
- use_warp_specilization=True)
+matmul(
+ np.float16,
+ np.float32,
+ 256,
+ 1024,
+ 512,
+ max_num_stages=3,
+ use_warp_specilization=True,
+)
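The two invocations above are the configurations the test actually runs. Since `matmul` exposes the full set of knobs shown earlier, an additional configuration would be invoked the same way; the call below is purely illustrative and its shape and options are not part of the test:
```
# Hypothetical extra configuration, mirroring the matmul() signature above.
matmul(
    np.float16,
    np.float32,
    M=512,
    N=512,
    K=1024,
    BLOCK_M=128,
    BLOCK_N=128,
    BLOCK_K=64,
    max_num_stages=3,
    use_warp_specilization=True,
    saveIR=True,  # also dumps the generated IR to gemm.mlir
)
```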
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 0fda8698606cd5..4c132ec992a461 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -64,15 +64,17 @@ def c(value, ty=None):
return arith.constant(ty, value)
-def generate_matmul_ws(input_type=np.float16,
- output_type=np.float32,
- M=4096,
- N=4096,
- K=4096,
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=128,
- max_num_stages=3):
+def generate_matmul_ws(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=128,
+ max_num_stages=3,
+):
    # Limitations for now
assert input_type == np.float16
assert output_type == np.float32
@@ -80,8 +82,7 @@ def generate_matmul_ws(input_type=np.float16,
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
- required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
- BLOCK_K * BLOCK_N)
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
num_stages = min(required_stages, max_num_stages)
module = ir.Module.create()
@@ -99,36 +100,73 @@ def generate_matmul_ws(input_type=np.float16,
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
- txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
- (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ txcount = (
+ (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+ ) * f16_byte
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
input_type_str = "f16" if input_type == np.float16 else "f32"
output_type_str = "f16" if output_type == np.float16 else "f32"
- mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
- str(smem_space) + ", num_barriers = " +
- str(num_stages) + ">")
+ mbar_ty = ir.Type.parse(
+ "!nvgpu.mbarrier.group<memorySpace = "
+ + str(smem_space)
+ + ", num_barriers = "
+ + str(num_stages)
+ + ">"
+ )
a_tma_desc_ty = ir.Type.parse(
- "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
- str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
b_tma_desc_ty = ir.Type.parse(
- "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
- str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
- str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
- str(output_type_str) + ">>")
- a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
- str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
- str(input_type_str) + ", " + smem_space_str +
- ">>")
- b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
- str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
- str(input_type_str) + ", " + smem_space_str +
- ">>")
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ acc_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(output_type_str)
+ + ">>"
+ )
+ a_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_K)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
+ b_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
with ir.InsertionPoint(module.body):
@@ -150,16 +188,20 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
t7 = gpu.wait(token_ty, [t6])
# Step 2. Create TMA Descriptors
- tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
- (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_specs = [
+ (a_device, a_tma_desc_ty, a_tma_shape),
+ (b_device, b_tma_desc_ty, b_tma_shape),
+ ]
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
x_unranked = memref.cast(
- ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
- x_device)
+ ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+ )
tma_descs.append(
- nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
- map(c, tile_shape)).result)
+ nvgpu.TmaCreateDescriptorOp(
+ tensor_map_ty, x_unranked, map(c, tile_shape)
+ ).result
+ )
a_tma_desc, b_tma_desc = tma_descs
# Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
@@ -168,40 +210,45 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
assert M % BLOCK_M == 0 and N % BLOCK_N == 0
grid = (cta_m, cta_n, 1)
block = (WARP_GROUP_SIZE * 2, 1, 1)
- launch_op = gpu.LaunchOp(token_ty, [t7],
- *map(c, grid),
- *map(c, block),
- dynamicSharedMemorySize=c(smem_size,
- ty=i32))
+ launch_op = gpu.LaunchOp(
+ token_ty,
+ [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=i32)
+ )
launch_op.body.blocks.append(*([index] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
            # GPU Step 0. This is needed for vectorized ld/st
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
- ir.MemRefType.get((MLIR_DYNAMIC, ),
- i8,
- memory_space=smem_space))
+ ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+ )
ticks = c(10000000)
            # GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups, etc.
tidx = gpu.thread_id(gpu.Dimension.x)
wgPrimaryThread = arith.cmpi(
- arith.CmpIPredicate.eq,
- arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+ arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
+ )
warp_id = arith.divui(tidx, c(32))
warpgroup_id = arith.divui(warp_id, c(4))
is_producer = arith.cmpi(
- arith.CmpIPredicate.eq, warpgroup_id,
- c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
+ arith.CmpIPredicate.eq,
+ warpgroup_id,
+ c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
+ )
is_consumer = arith.cmpi(
- arith.CmpIPredicate.eq, warpgroup_id,
- c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
- producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
- tidx,
- c(PRODUCER_PRIMARY_THREAD))
- consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq,
- tidx,
- c(CONSUMER_PRIMARY_THREAD))
+ arith.CmpIPredicate.eq,
+ warpgroup_id,
+ c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
+ )
+ producerPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
+ )
+ consumerPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
+ )
bidx = gpu.block_id(gpu.Dimension.x)
bidy = gpu.block_id(gpu.Dimension.y)
dimX = arith.muli(bidx, c(BLOCK_M))
@@ -211,32 +258,25 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
mbarTMA = nvgpu.mbarrier_create(mbar_ty)
mbarDONE = nvgpu.mbarrier_create(mbar_ty)
for i in range(num_stages):
- nvgpu.mbarrier_init(mbarTMA,
- c(1),
- c(i),
- predicate=wgPrimaryThread)
- nvgpu.mbarrier_init(mbarDONE,
- c(1),
- c(i),
- predicate=wgPrimaryThread)
+ nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+ nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
gpu.barrier()
# GPU Step 3. Prefetch TMA descriptors
- nvgpu.tma_prefetch_descriptor(a_tma_desc,
- predicate=wgPrimaryThread)
- nvgpu.tma_prefetch_descriptor(b_tma_desc,
- predicate=wgPrimaryThread)
+ nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
# GPU Step 5. Producer Warpgroup (TMA Warpgroup)
with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-
# Step 5.1. Reduce register size
- nvvm.setmaxregister(PRODUCER_REGISTER_SIZE,
- nvvm.SetMaxRegisterAction.decrease)
+ nvvm.setmaxregister(
+ PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
+ )
# Step 5.2. TMA Main Loop
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
- [arith.constant(i1, 1)])
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)]
+ )
with ir.InsertionPoint(for_op.body):
phaseParity = for_op.inner_iter_args[0]
iv = for_op.induction_variable
@@ -248,103 +288,126 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
iv,
stage,
phaseParity,
- predicate=producerPrimaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarDONE,
- phaseParity,
- ticks,
- mbarId=stage)
+ predicate=producerPrimaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarDONE, phaseParity, ticks, mbarId=stage
+ )
debug_print(
"[prod] {} | mbarDONE[{}] try_wait phase={} [done]",
iv,
stage,
phaseParity,
- predicate=producerPrimaryThread)
- p = arith.cmpi(arith.CmpIPredicate.eq, stage,
- c(num_stages - 1))
+ predicate=producerPrimaryThread,
+ )
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
- p, arith.xori(phaseParity, arith.constant(i1, 1)),
- phaseParity)
+ p,
+ arith.xori(phaseParity, arith.constant(i1, 1)),
+ phaseParity,
+ )
# Step 5.2.2. Load TMA
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tma_slice = memref.view(
- ir.MemRefType.get(a_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, a_offset, [])
+ ir.MemRefType.get(
+ a_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
b_offset = arith.addi(
arith.muli(stage, c(rhs_tile_bytes)),
- c(lhs_tile_bytes * num_stages))
+ c(lhs_tile_bytes * num_stages),
+ )
b_tma_slice_1 = memref.view(
- ir.MemRefType.get(b_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, b_offset, [])
+ ir.MemRefType.get(
+ b_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
b_offset2 = arith.addi(
- b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ )
b_tma_slice_2 = memref.view(
- ir.MemRefType.get(b_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, b_offset2, [])
+ ir.MemRefType.get(
+ b_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
debug_print(
"[prod] a_offset={} b_offset={} b_offset2={}",
a_offset,
b_offset,
b_offset2,
- predicate=producerPrimaryThread)
+ predicate=producerPrimaryThread,
+ )
coord = arith.muli(c(64), iv)
- nvgpu.TmaAsyncLoadOp(a_tma_slice,
- mbarTMA,
- a_tma_desc,
- coordinates=[coord, dimX],
- mbarId=stage,
- predicate=producerPrimaryThread)
- nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY, coord],
- mbarId=stage,
- predicate=producerPrimaryThread)
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
dimY2 = arith.addi(dimY, c(64))
- nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY2, coord],
- mbarId=stage,
- predicate=producerPrimaryThread)
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
# Step 5.2.3. Arrive mbarTMA
- debug_print("[prod] {} | mbarTMA[{}] arrive",
- iv,
- stage,
- predicate=producerPrimaryThread)
+ debug_print(
+ "[prod] {} | mbarTMA[{}] arrive",
+ iv,
+ stage,
+ predicate=producerPrimaryThread,
+ )
nvgpu.mbarrier_arrive_expect_tx(
- mbarTMA,
- c(txcount),
+ mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
+ )
+ debug_print(
+ "[prod] {} | mbarTMA[{}] arrive [done]",
+ iv,
stage,
- predicate=producerPrimaryThread)
- debug_print("[prod] {} | mbarTMA[{}] arrive [done]",
- iv,
- stage,
- predicate=producerPrimaryThread)
+ predicate=producerPrimaryThread,
+ )
scf.yield_([phaseParity])
scf.yield_([])
# GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
if_op = scf.IfOp(is_consumer)
with ir.InsertionPoint(if_op.then_block):
-
# Step 6.1. Increase register size
- nvvm.setmaxregister(CONSUMER_REGISTER_SIZE,
- nvvm.SetMaxRegisterAction.increase)
+ nvvm.setmaxregister(
+ CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
+ )
# GPU Step 6.2. Initialize MMA registers
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
# Step 6.3. MMA Main Loop
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
- [acc, arith.constant(i1, 0)])
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+ )
with ir.InsertionPoint(for_op.body):
# Step 6.3.1. Wait mbar1
phaseParity = for_op.inner_iter_args[1]
@@ -355,57 +418,68 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
iv,
stage,
phaseParity,
- predicate=consumerPrimaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarTMA,
- phaseParity,
- ticks,
- mbarId=stage)
+ predicate=consumerPrimaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarTMA, phaseParity, ticks, mbarId=stage
+ )
debug_print(
"[cons] {} | mbarTMA[{}] try_wait phase={} [done]",
iv,
stage,
phaseParity,
- predicate=consumerPrimaryThread)
+ predicate=consumerPrimaryThread,
+ )
# Step 6.3.2. Create WGMMA Descriptors
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tile_slice = memref.view(
- ir.MemRefType.get(a_tile_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, a_offset, [])
+ ir.MemRefType.get(
+ a_tile_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
b_offset = arith.addi(
arith.muli(stage, c(rhs_tile_bytes)),
- c(lhs_tile_bytes * num_stages))
+ c(lhs_tile_bytes * num_stages),
+ )
b_tile_slice = memref.view(
- ir.MemRefType.get(b_tile_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, b_offset, [])
- debug_print("[cons] a_offset={} b_offset={}",
- a_offset,
- b_offset,
- predicate=consumerPrimaryThread)
+ ir.MemRefType.get(
+ b_tile_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ debug_print(
+ "[cons] a_offset={} b_offset={}",
+ a_offset,
+ b_offset,
+ predicate=consumerPrimaryThread,
+ )
da = nvgpu.WarpgroupGenerateDescriptorOp(
- a_wgmma_ty, a_tile_slice, a_tma_desc)
+ a_wgmma_ty, a_tile_slice, a_tma_desc
+ )
db = nvgpu.WarpgroupGenerateDescriptorOp(
- b_wgmma_ty, b_tile_slice, b_tma_desc)
+ b_wgmma_ty, b_tile_slice, b_tma_desc
+ )
# Step 6.3.3. MMA
carry_acc = for_op.inner_iter_args[0]
- new_acc = nvgpu.WarpgroupMmaOp(acc.type,
- da,
- db,
- carry_acc,
- transposeB=True)
+ new_acc = nvgpu.WarpgroupMmaOp(
+ acc.type, da, db, carry_acc, transposeB=True
+ )
# Step 6.3.4. Arrive mbarDONE
p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
p_arrive = arith.andi(consumerPrimaryThread, p1)
with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
- barId = arith.select(p, c(num_stages - 1),
- arith.subi(stage, c(1)))
+ barId = arith.select(
+ p, c(num_stages - 1), arith.subi(stage, c(1))
+ )
remoteCtaId = c(0)
pred = consumerPrimaryThread
debug_print(
@@ -414,24 +488,27 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
barId,
remoteCtaId,
pred,
- predicate=consumerPrimaryThread)
+ predicate=consumerPrimaryThread,
+ )
nvgpu.mbarrier_arrive(
- ir.Type.parse("!nvgpu.mbarrier.token"),
- mbarDONE, barId)
+ ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+ )
debug_print(
"[cons] {} | mbarDONE[{}] arrive pred={} [done]",
iv,
barId,
remoteCtaId,
pred,
- predicate=consumerPrimaryThread)
+ predicate=consumerPrimaryThread,
+ )
scf.yield_([])
- p = arith.cmpi(arith.CmpIPredicate.eq, stage,
- c(num_stages - 1))
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
- p, arith.xori(phaseParity, arith.constant(i1, 1)),
- phaseParity)
+ p,
+ arith.xori(phaseParity, arith.constant(i1, 1)),
+ phaseParity,
+ )
# Step 6.3.5. Yield
scf.yield_([new_acc, phaseParity])
@@ -439,46 +516,55 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
# Step 6.3. Wait All WGMMA
nvvm.WgmmaWaitGroupSyncOp(0)
- with ir.InsertionPoint(
- scf.IfOp(consumerPrimaryThread).then_block):
+ with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
barId = c((K // BLOCK_K) % num_stages)
nvgpu.mbarrier_arrive(
- ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE,
- barId)
+ ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+ )
scf.yield_([])
# Step 6.4. Epilogue (registers --> shared memory)
- acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
- c_elem_ty,
- memory_space=smem_space)
+ acc_smem_ty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+ )
acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
- debug_print("[cons] | Storing",
- predicate=consumerPrimaryThread)
+ debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
scf.yield_([])
gpu.barrier()
# GPU Step 9. Epilogue (shared memory --> global memory)
- fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
- c_elem_ty,
- memory_space=smem_space)
+ fd = ir.MemRefType.get(
+ [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+ )
collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
rty = ir.MemRefType.get(
- (BLOCK_M, BLOCK_N), c_elem_ty,
- ir.Attribute.parse("strided<[" + str(N) +
- ", 1], offset: ?>"))
+ (BLOCK_M, BLOCK_N),
+ c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+ )
c_device_per_block = memref.SubViewOp(
- rty, c_device, [dimX, dimY], [], [],
- [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
+ rty,
+ c_device,
+ [dimX, dimY],
+ [],
+ [],
+ [MLIR_DYNAMIC, MLIR_DYNAMIC],
+ [BLOCK_M, BLOCK_N],
+ [1, 1],
+ )
vlen = 1
- for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
- c(vlen * WARP_GROUP_SIZE * 2))
+ for_op = scf.ForOp(
+ tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
+ )
with ir.InsertionPoint(for_op.body):
x = arith.divui(for_op.induction_variable, c(BLOCK_M))
y = arith.remui(for_op.induction_variable, c(BLOCK_N))
- vdata = vector.load(ir.VectorType.get(
- (vlen, ), c_elem_ty), collapsed_smem,
- [for_op.induction_variable])
+ vdata = vector.load(
+ ir.VectorType.get((vlen,), c_elem_ty),
+ collapsed_smem,
+ [for_op.induction_variable],
+ )
vector.store(vdata, c_device_per_block, [x, y])
scf.yield_([])
@@ -490,20 +576,23 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
gpu.wait(token_ty, [t9])
mlir_matmul_warpspecialized.func_op.attributes[
- "llvm.emit_c_interface"] = ir.UnitAttr.get()
+ "llvm.emit_c_interface"
+ ] = ir.UnitAttr.get()
module.operation.verify()
return module
-def generate_matmul_multistage(input_type=np.float16,
- output_type=np.float32,
- M=4096,
- N=4096,
- K=4096,
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=64,
- max_num_stages=3):
+def generate_matmul_multistage(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ max_num_stages=3,
+):
    # Limitations for now
assert input_type == np.float16
assert output_type == np.float32
@@ -511,8 +600,7 @@ def generate_matmul_multistage(input_type=np.float16,
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
- required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K +
- BLOCK_K * BLOCK_N)
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
num_stages = min(required_stages, max_num_stages)
module = ir.Module.create()
@@ -530,36 +618,73 @@ def generate_matmul_multistage(input_type=np.float16,
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
- txcount = ((b_tile_shape[0] * b_tile_shape[1]) +
- (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ txcount = (
+ (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+ ) * f16_byte
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
input_type_str = "f16" if input_type == np.float16 else "f32"
output_type_str = "f16" if output_type == np.float16 else "f32"
- mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " +
- str(smem_space) + ", num_barriers = " +
- str(num_stages) + ">")
+ mbar_ty = ir.Type.parse(
+ "!nvgpu.mbarrier.group<memorySpace = "
+ + str(smem_space)
+ + ", num_barriers = "
+ + str(num_stages)
+ + ">"
+ )
a_tma_desc_ty = ir.Type.parse(
- "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
- str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
b_tma_desc_ty = ir.Type.parse(
- "!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " +
- str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" +
- str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
- str(output_type_str) + ">>")
- a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
- str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
- str(input_type_str) + ", " + smem_space_str +
- ">>")
- b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" +
- str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
- str(input_type_str) + ", " + smem_space_str +
- ">>")
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ acc_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(output_type_str)
+ + ">>"
+ )
+ a_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_K)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
+ b_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
with ir.InsertionPoint(module.body):
@@ -581,16 +706,20 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
t7 = gpu.wait(token_ty, [t6])
# Step 2. Create TMA Descriptors
- tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape),
- (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_specs = [
+ (a_device, a_tma_desc_ty, a_tma_shape),
+ (b_device, b_tma_desc_ty, b_tma_shape),
+ ]
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
x_unranked = memref.cast(
- ir.UnrankedMemRefType.get(f16, a_ty.memory_space),
- x_device)
+ ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+ )
tma_descs.append(
- nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked,
- map(c, tile_shape)).result)
+ nvgpu.TmaCreateDescriptorOp(
+ tensor_map_ty, x_unranked, map(c, tile_shape)
+ ).result
+ )
a_tma_desc, b_tma_desc = tma_descs
# Step 3. Launch Kernel with 1 Warpgroup
@@ -599,19 +728,20 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
assert M % BLOCK_M == 0 and N % BLOCK_N == 0
grid = (cta_m, cta_n, 1)
block = (WARP_GROUP_SIZE, 1, 1)
- launch_op = gpu.LaunchOp(token_ty, [t7],
- *map(c, grid),
- *map(c, block),
- dynamicSharedMemorySize=c(smem_size,
- ty=i32))
+ launch_op = gpu.LaunchOp(
+ token_ty,
+ [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=i32)
+ )
launch_op.body.blocks.append(*([index] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
# GPU Step 0. Bootstrapping
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
- ir.MemRefType.get((MLIR_DYNAMIC, ),
- i8,
- memory_space=smem_space))
+ ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+ )
ticks = c(10000000)
tidx = gpu.thread_id(gpu.Dimension.x)
primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
@@ -624,17 +754,12 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# GPU Step 1. Initialize mbarrier groups
mbarTMA = nvgpu.mbarrier_create(mbar_ty)
for i in range(num_stages):
- nvgpu.mbarrier_init(mbarTMA,
- c(1),
- c(i),
- predicate=primaryThread)
+ nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
gpu.barrier()
# GPU Step 2. Prefetch TMA descriptors
- nvgpu.tma_prefetch_descriptor(a_tma_desc,
- predicate=primaryThread)
- nvgpu.tma_prefetch_descriptor(b_tma_desc,
- predicate=primaryThread)
+ nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
# GPU Step 3. Prologue (global memory --> shared memory)
for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
@@ -644,24 +769,30 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# Step 3.1. Calculate offsets
a_offset = arith.muli(iv, c(lhs_tile_bytes))
a_tma_slice = memref.view(
- ir.MemRefType.get(a_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)),
- c(lhs_tile_bytes * num_stages))
+ ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(iv, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
b_tma_slice_1 = memref.view(
- ir.MemRefType.get(b_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, b_offset, [])
+ ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
b_offset2 = arith.addi(
- b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ )
b_tma_slice_2 = memref.view(
- ir.MemRefType.get(b_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, b_offset2, [])
+ ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
# Step 3.2. TMA Load
coord = arith.muli(c(64), iv)
@@ -675,123 +806,153 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
dimX,
dimY,
coord,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(a_tma_slice,
- mbarTMA,
- a_tma_desc,
- coordinates=[coord, dimX],
- mbarId=iv,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY, coord],
- mbarId=iv,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY2, coord],
- mbarId=iv,
- predicate=primaryThread)
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
# Step 3.2. mbarTMA arrive
- debug_print("[Prologue] mbarTMA[{}] arrive",
- iv,
- predicate=primaryThread)
- nvgpu.mbarrier_arrive_expect_tx(mbarTMA,
- c(txcount),
- iv,
- predicate=primaryThread)
- debug_print("[Prologue] mbarTMA[{}] arrive [done]",
- iv,
- predicate=primaryThread)
+ debug_print(
+ "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), iv, predicate=primaryThread
+ )
+ debug_print(
+ "[Prologue] mbarTMA[{}] arrive [done]",
+ iv,
+ predicate=primaryThread,
+ )
scf.yield_([])
# GPU Step 4. Main Loop
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1),
- [acc, arith.constant(i1, 0)])
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+ )
with ir.InsertionPoint(for_op.body):
# Step 4.1. Wait mbarTMA
phaseParity = for_op.inner_iter_args[1]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
- debug_print("[MainLoop] mbarTMA[{}] try_wait phase={}",
- stage,
- phaseParity,
- predicate=primaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarTMA,
- phaseParity,
- ticks,
- mbarId=stage)
+ debug_print(
+ "[MainLoop] mbarTMA[{}] try_wait phase={}",
+ stage,
+ phaseParity,
+ predicate=primaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarTMA, phaseParity, ticks, mbarId=stage
+ )
debug_print(
"[MainLoop] mbarTMA[{}] try_wait phase={} [done]",
stage,
phaseParity,
- predicate=primaryThread)
+ predicate=primaryThread,
+ )
# Step 4.2. Create WGMMA Descriptors
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tile_slice = memref.view(
- ir.MemRefType.get(a_tile_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)),
- c(lhs_tile_bytes * num_stages))
+ ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
b_tile_slice = memref.view(
- ir.MemRefType.get(b_tile_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, b_offset, [])
- debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}",
- iv,
- a_offset,
- b_offset,
- predicate=primaryThread)
+ ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ debug_print(
+ "[MainLoop] iv={} MMA a_offset={} b_offset={}",
+ iv,
+ a_offset,
+ b_offset,
+ predicate=primaryThread,
+ )
da = nvgpu.WarpgroupGenerateDescriptorOp(
- a_wgmma_ty, a_tile_slice, a_tma_desc)
+ a_wgmma_ty, a_tile_slice, a_tma_desc
+ )
db = nvgpu.WarpgroupGenerateDescriptorOp(
- b_wgmma_ty, b_tile_slice, b_tma_desc)
+ b_wgmma_ty, b_tile_slice, b_tma_desc
+ )
# Step 4.3. MMA
carry_acc = for_op.inner_iter_args[0]
- new_acc = nvgpu.WarpgroupMmaOp(acc.type,
- da,
- db,
- carry_acc,
- transposeB=True)
+ new_acc = nvgpu.WarpgroupMmaOp(
+ acc.type, da, db, carry_acc, transposeB=True
+ )
# Step 4.4. Load TMA for next stage
- p1 = arith.cmpi(arith.CmpIPredicate.ult,
- arith.addi(iv, c(num_stages - 1)),
- c(K // BLOCK_K))
+ p1 = arith.cmpi(
+ arith.CmpIPredicate.ult,
+ arith.addi(iv, c(num_stages - 1)),
+ c(K // BLOCK_K),
+ )
p = arith.andi(primaryThread, p1)
with ir.InsertionPoint(scf.IfOp(p).then_block):
nextStage = arith.addi(iv, c(num_stages - 1))
nextSlot = arith.remui(nextStage, c(num_stages))
a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
a_tma_slice = memref.view(
- ir.MemRefType.get(a_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, a_offset, [])
+ ir.MemRefType.get(
+ a_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
b_offset = arith.addi(
arith.muli(nextSlot, c(rhs_tile_bytes)),
- c(lhs_tile_bytes * num_stages))
+ c(lhs_tile_bytes * num_stages),
+ )
b_tma_slice_1 = memref.view(
- ir.MemRefType.get(b_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, b_offset, [])
+ ir.MemRefType.get(
+ b_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
b_offset2 = arith.addi(
- b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ )
b_tma_slice_2 = memref.view(
- ir.MemRefType.get(b_tma_shape,
- f16,
- memory_space=smem_space),
- dynamic_smem, b_offset2, [])
+ ir.MemRefType.get(
+ b_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
coord = arith.muli(c(64), nextStage)
debug_print(
@@ -804,45 +965,53 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
dimX,
dimY,
coord,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(a_tma_slice,
- mbarTMA,
- a_tma_desc,
- coordinates=[coord, dimX],
- mbarId=nextSlot,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY, coord],
- mbarId=nextSlot,
- predicate=primaryThread)
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=nextSlot,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=nextSlot,
+ predicate=primaryThread,
+ )
dimY2 = arith.addi(dimY, c(64))
- nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY2, coord],
- mbarId=nextSlot,
- predicate=primaryThread)
-
- debug_print("[MainLoop] mbarTMA[{}] arrive",
- nextSlot,
- predicate=primaryThread)
- nvgpu.mbarrier_arrive_expect_tx(
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
mbarTMA,
- c(txcount),
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=nextSlot,
+ predicate=primaryThread,
+ )
+
+ debug_print(
+ "[MainLoop] mbarTMA[{}] arrive",
+ nextSlot,
+ predicate=primaryThread,
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), nextSlot, predicate=primaryThread
+ )
+ debug_print(
+ "[MainLoop] mbarTMA[{}] arrive [done]",
nextSlot,
- predicate=primaryThread)
- debug_print("[MainLoop] mbarTMA[{}] arrive [done]",
- nextSlot,
- predicate=primaryThread)
+ predicate=primaryThread,
+ )
scf.yield_([])
# Step 4.5. Change the phaseParity
- p = arith.cmpi(arith.CmpIPredicate.eq, stage,
- c(num_stages - 1))
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
- p, arith.xori(phaseParity, arith.constant(i1, 1)),
- phaseParity)
+ p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity
+ )
# Step 4.5. Yield
scf.yield_([new_acc, phaseParity])
@@ -852,35 +1021,46 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
gpu.barrier()
# Step 6. Epilogue (registers --> shared memory)
- acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N),
- c_elem_ty,
- memory_space=smem_space)
+ acc_smem_ty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+ )
acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
debug_print("Storing", predicate=primaryThread)
nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
gpu.barrier()
# GPU Step 7. Epilogue (shared memory --> global memory)
- fd = ir.MemRefType.get([BLOCK_M * BLOCK_N],
- c_elem_ty,
- memory_space=smem_space)
+ fd = ir.MemRefType.get(
+ [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+ )
collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
rty = ir.MemRefType.get(
- (BLOCK_M, BLOCK_N), c_elem_ty,
- ir.Attribute.parse("strided<[" + str(N) +
- ", 1], offset: ?>"))
+ (BLOCK_M, BLOCK_N),
+ c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+ )
c_device_per_block = memref.SubViewOp(
- rty, c_device, [dimX, dimY], [], [],
- [MLIR_DYNAMIC, MLIR_DYNAMIC], [BLOCK_M, BLOCK_N], [1, 1])
+ rty,
+ c_device,
+ [dimX, dimY],
+ [],
+ [],
+ [MLIR_DYNAMIC, MLIR_DYNAMIC],
+ [BLOCK_M, BLOCK_N],
+ [1, 1],
+ )
vlen = 1
- for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N),
- c(vlen * WARP_GROUP_SIZE))
+ for_op = scf.ForOp(
+ tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
+ )
with ir.InsertionPoint(for_op.body):
x = arith.divui(for_op.induction_variable, c(BLOCK_M))
y = arith.remui(for_op.induction_variable, c(BLOCK_N))
- vdata = vector.load(ir.VectorType.get(
- (vlen, ), c_elem_ty), collapsed_smem,
- [for_op.induction_variable])
+ vdata = vector.load(
+ ir.VectorType.get((vlen,), c_elem_ty),
+ collapsed_smem,
+ [for_op.induction_variable],
+ )
vector.store(vdata, c_device_per_block, [x, y])
scf.yield_([])
@@ -892,6 +1072,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
gpu.wait(token_ty, [t9])
mlir_matmul_multistage.func_op.attributes[
- "llvm.emit_c_interface"] = ir.UnitAttr.get()
+ "llvm.emit_c_interface"
+ ] = ir.UnitAttr.get()
module.operation.verify()
return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
index ab218812bb4d96..1c9cc74fcd169c 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -19,8 +19,7 @@
class NvgpuCompiler:
"""Nvgpu class for compiling and building MLIR modules."""
- def __init__(self, options: str, opt_level: int,
- shared_libs: Sequence[str]):
+ def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
self.pipeline = pipeline
self.shared_libs = shared_libs
@@ -36,12 +35,11 @@ def compile(self, module: ir.Module):
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
- return execution_engine.ExecutionEngine(module,
- opt_level=self.opt_level,
- shared_libs=self.shared_libs)
+ return execution_engine.ExecutionEngine(
+ module, opt_level=self.opt_level, shared_libs=self.shared_libs
+ )
- def compile_and_jit(self,
- module: ir.Module) -> execution_engine.ExecutionEngine:
+ def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
self.compile(module)
return self.jit(module)
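A minimal sketch of how this compiler helper is driven (the pipeline options string, the placeholder module, and the import path are assumptions for illustration; the shared library follows the SUPPORT_LIB environment variable used in the RUN line of matmul.py):
```python
import os

from mlir import ir
from tools import nvgpucompiler  # assumed import path inside this test directory

with ir.Context(), ir.Location.unknown():
    # Placeholder module; the test passes the GEMM module built by matmulBuilder.
    module = ir.Module.parse("module {}")
    compiler = nvgpucompiler.NvgpuCompiler(
        "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",  # assumed options
        opt_level=3,
        shared_libs=[os.environ["SUPPORT_LIB"]],  # CUDA runtime wrapper library
    )
    # Runs builtin.module(gpu-lower-to-nvvm-pipeline{...}), then JITs the result.
    engine = compiler.compile_and_jit(module)
```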
>From 9b65ffd405b5f5a241d0129223d04fc0a68730e3 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 13 Feb 2024 09:09:46 +0000
Subject: [PATCH 04/10] fix the spelling mistake
---
.../Integration/GPU/CUDA/sm90/python/matmul.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index ff9c950193e695..f5d5b30c70d1e1 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -95,12 +95,12 @@ def generate_matmul(
BLOCK_M=128,
BLOCK_N=128,
BLOCK_K=64,
- use_warp_specilization=True,
+ use_warp_specialization=True,
saveIR=False,
max_num_stages=3,
):
with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
- if use_warp_specilization:
+ if use_warp_specialization:
mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
input_type,
output_type,
@@ -161,7 +161,7 @@ def matmul(
BLOCK_M=128,
BLOCK_N=128,
BLOCK_K=64,
- use_warp_specilization=True,
+ use_warp_specialization=True,
saveIR=False,
max_num_stages=3,
print_results=False,
@@ -170,7 +170,7 @@ def matmul(
# Print the configuration
ity = "f16" if input_type == np.float16 else "f32"
oty = "f16" if output_type == np.float16 else "f32"
- gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
+ gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
print(
"===-- Running GEMM "
+ gemmty
@@ -207,7 +207,7 @@ def matmul(
BLOCK_M,
BLOCK_N,
BLOCK_K,
- use_warp_specilization,
+ use_warp_specialization,
saveIR,
max_num_stages,
)
@@ -221,7 +221,7 @@ def matmul(
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
kernelName = (
"mlir_matmul_warpspecialized"
- if use_warp_specilization
+ if use_warp_specialization
else "mlir_matmul_multistage"
)
@@ -252,7 +252,7 @@ def matmul(
128,
4096,
max_num_stages=3,
- use_warp_specilization=False,
+ use_warp_specialization=False,
)
# GEMM Warp Specialized f32 += f16 * f16
matmul(
@@ -262,5 +262,5 @@ def matmul(
1024,
512,
max_num_stages=3,
- use_warp_specilization=True,
+ use_warp_specialization=True,
)
>From a31f9b17d8b90c5828b8f2373b31573e1add52bc Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 20 Feb 2024 10:09:11 +0000
Subject: [PATCH 05/10] address comments
---
.../CUDA/sm90/python/tools/matmulBuilder.py | 201 +++++++++++-------
1 file changed, 121 insertions(+), 80 deletions(-)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 4c132ec992a461..434af5e0920466 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -10,6 +10,7 @@
from mlir.dialects import builtin
from mlir.dialects import scf
from mlir.dialects import vector
+from mlir.extras import types as T
TMA_LAST_DIM_F16 = 64 # 128B float16
WARP_SIZE = 32
@@ -22,12 +23,8 @@
CONSUMER_PRIMARY_THREAD = 0
MLIR_DYNAMIC = -9223372036854775808
-f16_byte = 2
-f32_byte = 4
DEBUG = False
-
-
def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
if not DEBUG and not forcePrint:
return
@@ -59,11 +56,52 @@ def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
scf.yield_([])
+def get_type_str(ty):
+ if ir.F16Type.isinstance(ty):
+ return "f16"
+ if ir.F32Type.isinstance(ty):
+ return "f32"
+ if ir.F64Type.isinstance(ty):
+ return "f64"
+ if ir.IntegerType.isinstance(ty):
+ return "i" + str(ir.IntegerType(ty).width)
+ if ir.IndexType.isinstance(ty):
+ return "T.index()"
+ raise NotImplementedError(ty)
+
+
+def get_type_size(ty):
+ if ir.F16Type.isinstance(ty):
+ return 2
+ if ir.F32Type.isinstance(ty):
+ return 4
+ if ir.F64Type.isinstance(ty):
+ return 8
+ if ir.IntegerType.isinstance(ty):
+ return ir.IntegerType(ty).width // 8
+ if ir.IndexType.isinstance(ty):
+ return 8
+ raise NotImplementedError(ty)
+
+
+def get_mlir_ty(dtype):
+ if dtype == np.float16:
+ return T.f16()
+ if dtype == np.float32:
+ return T.f32()
+ if dtype == np.float64:
+ return T.f64()
+ if dtype == np.int32:
+ return T.i32()
+ if dtype == np.int64:
+ return T.i64()
+ raise NotImplementedError(dtype)
+
+
def c(value, ty=None):
- ty = ir.IndexType.get() if ty is None else ty
+ ty = T.index() if ty is None else ty
return arith.constant(ty, value)
-
def generate_matmul_ws(
input_type=np.float16,
output_type=np.float32,
@@ -86,27 +124,21 @@ def generate_matmul_ws(
num_stages = min(required_stages, max_num_stages)
module = ir.Module.create()
- f16 = ir.F16Type.get()
- f32 = ir.F32Type.get()
- i1 = ir.IntegerType.get_signless(1)
- i32 = ir.IntegerType.get_signless(32)
- index = ir.IndexType.get()
- i8 = ir.IntegerType.get_signless(8)
token_ty = ir.Type.parse("!gpu.async.token")
- a_ty = ir.MemRefType.get([M, K], f16)
- b_ty = ir.MemRefType.get((K, N), f16)
- c_elem_ty = f16 if output_type == np.float16 else f32
+ a_elem_ty = get_mlir_ty(input_type)
+ b_elem_ty = get_mlir_ty(input_type)
+ c_elem_ty = get_mlir_ty(output_type)
+ a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+ b_ty = ir.MemRefType.get((K, N), b_elem_ty)
c_ty = ir.MemRefType.get((M, N), c_elem_ty)
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
- txcount = (
- (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
- ) * f16_byte
+ txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
+ a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
+ )
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
- input_type_str = "f16" if input_type == np.float16 else "f32"
- output_type_str = "f16" if output_type == np.float16 else "f32"
mbar_ty = ir.Type.parse(
"!nvgpu.mbarrier.group<memorySpace = "
+ str(smem_space)
@@ -120,7 +152,7 @@ def generate_matmul_ws(
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
- + str(input_type_str)
+ + get_type_str(a_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -131,7 +163,7 @@ def generate_matmul_ws(
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
- + str(input_type_str)
+ + get_type_str(b_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -142,7 +174,7 @@ def generate_matmul_ws(
+ "x"
+ str(BLOCK_N)
+ "x"
- + str(output_type_str)
+ + get_type_str(c_elem_ty)
+ ">>"
)
a_wgmma_ty = ir.Type.parse(
@@ -151,7 +183,7 @@ def generate_matmul_ws(
+ "x"
+ str(BLOCK_K)
+ "x"
- + str(input_type_str)
+ + get_type_str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
@@ -162,7 +194,7 @@ def generate_matmul_ws(
+ "x"
+ str(BLOCK_N)
+ "x"
- + str(input_type_str)
+ + get_type_str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
@@ -172,10 +204,10 @@ def generate_matmul_ws(
@func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
def mlir_matmul_warpspecialized(a_host, b_host, c_host):
- lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
- rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+ lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+ rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
- smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+ smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
smem_size = max(smem_size_input, smem_size_output)
# Step 1. Allocate device memory and memcpy
@@ -195,7 +227,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
x_unranked = memref.cast(
- ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+ ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
)
tma_descs.append(
nvgpu.TmaCreateDescriptorOp(
@@ -215,14 +247,14 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
[t7],
*map(c, grid),
*map(c, block),
- dynamicSharedMemorySize=c(smem_size, ty=i32)
+ dynamicSharedMemorySize=c(smem_size, ty=T.i32())
)
- launch_op.body.blocks.append(*([index] * 12))
+ launch_op.body.blocks.append(*([T.index()] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
# GPU Step 0. This is needed for vectorized ld/st
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
- ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+ ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
)
ticks = c(10000000)
@@ -275,7 +307,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
# Step 5.2. TMA Main Loop
for_op = scf.ForOp(
- c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)]
+ c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
)
with ir.InsertionPoint(for_op.body):
phaseParity = for_op.inner_iter_args[0]
@@ -303,7 +335,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
p,
- arith.xori(phaseParity, arith.constant(i1, 1)),
+ arith.xori(phaseParity, arith.constant(T.bool(), 1)),
phaseParity,
)
@@ -311,7 +343,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tma_slice = memref.view(
ir.MemRefType.get(
- a_tma_shape, f16, memory_space=smem_space
+ a_tma_shape, a_elem_ty, memory_space=smem_space
),
dynamic_smem,
a_offset,
@@ -323,18 +355,19 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
)
b_tma_slice_1 = memref.view(
ir.MemRefType.get(
- b_tma_shape, f16, memory_space=smem_space
+ b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset,
[],
)
b_offset2 = arith.addi(
- b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ b_offset,
+ c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
)
b_tma_slice_2 = memref.view(
ir.MemRefType.get(
- b_tma_shape, f16, memory_space=smem_space
+ b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset2,
@@ -406,7 +439,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
# Step 6.3. MMA Main Loop
for_op = scf.ForOp(
- c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+ c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
)
with ir.InsertionPoint(for_op.body):
# Step 6.3.1. Wait mbar1
@@ -435,7 +468,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tile_slice = memref.view(
ir.MemRefType.get(
- a_tile_shape, f16, memory_space=smem_space
+ a_tile_shape, a_elem_ty, memory_space=smem_space
),
dynamic_smem,
a_offset,
@@ -447,7 +480,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
)
b_tile_slice = memref.view(
ir.MemRefType.get(
- b_tile_shape, f16, memory_space=smem_space
+ b_tile_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset,
@@ -506,7 +539,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
p,
- arith.xori(phaseParity, arith.constant(i1, 1)),
+ arith.xori(phaseParity, arith.constant(T.bool(), 1)),
phaseParity,
)
@@ -604,27 +637,21 @@ def generate_matmul_multistage(
num_stages = min(required_stages, max_num_stages)
module = ir.Module.create()
- f16 = ir.F16Type.get()
- f32 = ir.F32Type.get()
- i1 = ir.IntegerType.get_signless(1)
- i32 = ir.IntegerType.get_signless(32)
- index = ir.IndexType.get()
- i8 = ir.IntegerType.get_signless(8)
token_ty = ir.Type.parse("!gpu.async.token")
- a_ty = ir.MemRefType.get([M, K], f16)
- b_ty = ir.MemRefType.get((K, N), f16)
- c_elem_ty = f16 if output_type == np.float16 else f32
+ a_elem_ty = get_mlir_ty(input_type)
+ b_elem_ty = get_mlir_ty(input_type)
+ c_elem_ty = get_mlir_ty(output_type)
+ a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+ b_ty = ir.MemRefType.get((K, N), b_elem_ty)
c_ty = ir.MemRefType.get((M, N), c_elem_ty)
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
- txcount = (
- (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
- ) * f16_byte
+ txcount = (b_tile_shape[0] * b_tile_shape[1] * get_type_size(a_elem_ty)) + (
+ a_tile_shape[0] * a_tile_shape[1] * get_type_size(b_elem_ty)
+ )
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
- input_type_str = "f16" if input_type == np.float16 else "f32"
- output_type_str = "f16" if output_type == np.float16 else "f32"
mbar_ty = ir.Type.parse(
"!nvgpu.mbarrier.group<memorySpace = "
+ str(smem_space)
@@ -638,7 +665,7 @@ def generate_matmul_multistage(
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
- + str(input_type_str)
+ + get_type_str(a_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -649,7 +676,7 @@ def generate_matmul_multistage(
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
- + str(input_type_str)
+ + get_type_str(b_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -660,7 +687,7 @@ def generate_matmul_multistage(
+ "x"
+ str(BLOCK_N)
+ "x"
- + str(output_type_str)
+ + get_type_str(c_elem_ty)
+ ">>"
)
a_wgmma_ty = ir.Type.parse(
@@ -669,7 +696,7 @@ def generate_matmul_multistage(
+ "x"
+ str(BLOCK_K)
+ "x"
- + str(input_type_str)
+ + get_type_str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
@@ -680,7 +707,7 @@ def generate_matmul_multistage(
+ "x"
+ str(BLOCK_N)
+ "x"
- + str(input_type_str)
+ + get_type_str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
@@ -690,10 +717,10 @@ def generate_matmul_multistage(
@func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
def mlir_matmul_multistage(a_host, b_host, c_host):
- lhs_tile_bytes = BLOCK_M * BLOCK_K * f16_byte
- rhs_tile_bytes = BLOCK_N * BLOCK_K * f16_byte
+ lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+ rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
- smem_size_output = BLOCK_M * BLOCK_N * f32_byte
+ smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
smem_size = max(smem_size_input, smem_size_output)
# Step 1. Allocate device memory and memcpy
@@ -713,7 +740,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
x_unranked = memref.cast(
- ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+ ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
)
tma_descs.append(
nvgpu.TmaCreateDescriptorOp(
@@ -733,14 +760,14 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
[t7],
*map(c, grid),
*map(c, block),
- dynamicSharedMemorySize=c(smem_size, ty=i32)
+ dynamicSharedMemorySize=c(smem_size, ty=T.i32())
)
- launch_op.body.blocks.append(*([index] * 12))
+ launch_op.body.blocks.append(*([T.index()] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
# GPU Step 0. Bootstrapping
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
- ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+ ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
)
ticks = c(10000000)
tidx = gpu.thread_id(gpu.Dimension.x)
@@ -769,7 +796,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# Step 3.1. Calculate offsets
a_offset = arith.muli(iv, c(lhs_tile_bytes))
a_tma_slice = memref.view(
- ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+ ir.MemRefType.get(
+ a_tma_shape, a_elem_ty, memory_space=smem_space
+ ),
dynamic_smem,
a_offset,
[],
@@ -779,16 +808,21 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
c(lhs_tile_bytes * num_stages),
)
b_tma_slice_1 = memref.view(
- ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
dynamic_smem,
b_offset,
[],
)
b_offset2 = arith.addi(
- b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ b_offset,
+ c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
)
b_tma_slice_2 = memref.view(
- ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
dynamic_smem,
b_offset2,
[],
@@ -850,7 +884,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# GPU Step 4. Main Loop
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
for_op = scf.ForOp(
- c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+ c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
)
with ir.InsertionPoint(for_op.body):
# Step 4.1. Wait mbarTMA
@@ -876,7 +910,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# Step 4.2. Create WGMMA Descriptors
a_offset = arith.muli(stage, c(lhs_tile_bytes))
a_tile_slice = memref.view(
- ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+ ir.MemRefType.get(
+ a_tile_shape, a_elem_ty, memory_space=smem_space
+ ),
dynamic_smem,
a_offset,
[],
@@ -886,7 +922,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
c(lhs_tile_bytes * num_stages),
)
b_tile_slice = memref.view(
- ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+ ir.MemRefType.get(
+ b_tile_shape, b_elem_ty, memory_space=smem_space
+ ),
dynamic_smem,
b_offset,
[],
@@ -924,7 +962,7 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
a_tma_slice = memref.view(
ir.MemRefType.get(
- a_tma_shape, f16, memory_space=smem_space
+ a_tma_shape, a_elem_ty, memory_space=smem_space
),
dynamic_smem,
a_offset,
@@ -936,18 +974,19 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
)
b_tma_slice_1 = memref.view(
ir.MemRefType.get(
- b_tma_shape, f16, memory_space=smem_space
+ b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset,
[],
)
b_offset2 = arith.addi(
- b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ b_offset,
+ c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
)
b_tma_slice_2 = memref.view(
ir.MemRefType.get(
- b_tma_shape, f16, memory_space=smem_space
+ b_tma_shape, b_elem_ty, memory_space=smem_space
),
dynamic_smem,
b_offset2,
@@ -1010,7 +1049,9 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# Step 4.5. Change the phaseParity
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
- p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity
+ p,
+ arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+ phaseParity,
)
# Step 4.5. Yield
>From 777f2089b144b935ba97d291ea3f66c0622a6e2a Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Tue, 20 Feb 2024 10:23:23 +0000
Subject: [PATCH 06/10] format
---
.../Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 434af5e0920466..52732b1c01426e 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -25,6 +25,8 @@
MLIR_DYNAMIC = -9223372036854775808
DEBUG = False
+
+
def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
if not DEBUG and not forcePrint:
return
@@ -102,6 +104,7 @@ def c(value, ty=None):
ty = T.index() if ty is None else ty
return arith.constant(ty, value)
+
def generate_matmul_ws(
input_type=np.float16,
output_type=np.float32,
>From 0934dcb2d2037f0595e982fc05577fbf8cea3fa6 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sun, 3 Mar 2024 08:23:20 +0000
Subject: [PATCH 07/10] Allow multiple stages, and fix the kernels. Add
 test_short that tests multiple cases.
---
.../GPU/CUDA/sm90/python/matmul.py | 157 +++++++--
.../CUDA/sm90/python/tools/matmulBuilder.py | 306 ++++++++++--------
2 files changed, 308 insertions(+), 155 deletions(-)
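The first hunk below moves the stage selection into matmul(): required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N), clamped by max_num_stages. A worked instance of that arithmetic, with shapes taken from one of the CHECK lines (for illustration only):
```python
# Stage selection for the 128x512x64 case (Tile 128x128x64) in the CHECK lines.
M, N, K = 128, 512, 64
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64
max_num_stages = 3

# (128*64 + 64*512) // (128*64 + 64*128) = 40960 // 16384 = 2
required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
num_stages = min(required_stages, max_num_stages)  # 2, reported as "stages 2"
```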
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index f5d5b30c70d1e1..88609273a8b548 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -1,6 +1,6 @@
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN: %PYTHON %s | FileCheck %s
-# CHECK: PASS
+
# ===--- GEMM Hopper Tensor Core Integration Test ---===
#
@@ -168,6 +168,8 @@ def matmul(
no_verify=False,
):
# Print the configuration
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+ num_stages = min(required_stages, max_num_stages)
ity = "f16" if input_type == np.float16 else "f32"
oty = "f16" if output_type == np.float16 else "f32"
gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
@@ -193,7 +195,7 @@ def matmul(
+ "x"
+ str(BLOCK_K)
+ ", stages "
- + str(max_num_stages)
+ + str(num_stages)
+ " --==="
)
@@ -209,7 +211,7 @@ def matmul(
BLOCK_K,
use_warp_specialization,
saveIR,
- max_num_stages,
+ num_stages,
)
# Allocate matrices and invoke the matmul
@@ -219,10 +221,17 @@ def matmul(
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
- kernelName = (
- "mlir_matmul_warpspecialized"
- if use_warp_specialization
- else "mlir_matmul_multistage"
+ kernelName = matmulBuilder.make_kernel_name(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ num_stages,
+ use_warp_specialization,
)
# Launch the MLIR generated kernel
@@ -243,24 +252,118 @@ def matmul(
print("PASS ")
+# Takes a long time to run
+def test_long():
+ for stages in range(1,7):
+ for M in [128, 512, 1024, 4096, 8192]:
+ for N in [128, 512, 1024, 4096, 8192]:
+ for K in [64, 128, 512, 1024, 4096, 8192]:
+ matmul(
+ np.float16,
+ np.float32,
+ M,N,
+ K,
+ max_num_stages=stages,
+ use_warp_specialization=False,
+ no_verify=True,
+ saveIR=True
+ )
+ matmul(
+ np.float16,
+ np.float32,
+ M,N,
+ K,
+ max_num_stages=stages,
+ use_warp_specialization=True,
+ )
-# GEMM Multistage f32 += f16 * f16
-matmul(
- np.float16,
- np.float32,
- 128,
- 128,
- 4096,
- max_num_stages=3,
- use_warp_specialization=False,
-)
-# GEMM Warp Specialized f32 += f16 * f16
-matmul(
- np.float16,
- np.float32,
- 256,
- 1024,
- 512,
- max_num_stages=3,
- use_warp_specialization=True,
-)
+def test_short():
+ for stages in [1, 3]:
+ for M in [128, 512]:
+ for N in [128, 512]:
+ for K in [64, 512]:
+ matmul(
+ np.float16,
+ np.float32,
+ M,N,
+ K,
+ max_num_stages=stages,
+ use_warp_specialization=False,
+ no_verify=True,
+ saveIR=True
+ )
+ matmul(
+ np.float16,
+ np.float32,
+ M,N,
+ K,
+ max_num_stages=stages,
+ use_warp_specialization=True,
+ )
+
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+
+test_short()
\ No newline at end of file
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 52732b1c01426e..a541378359fe97 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -105,6 +105,41 @@ def c(value, ty=None):
return arith.constant(ty, value)
+def make_kernel_name(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=128,
+ num_stages=3,
+ use_warp_specialization=False,
+):
+ kernelName = (
+ "warpspecialized"
+ if use_warp_specialization
+ else "multistage"
+ )
+ return (
+ kernelName
+ + "_"
+ + str(M)
+ + "x"
+ + str(N)
+ + "x"
+ + str(K)
+ + "_"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(BLOCK_K)
+ + "_"
+ + str(num_stages)
+ )
+
def generate_matmul_ws(
input_type=np.float16,
output_type=np.float32,
@@ -114,7 +149,7 @@ def generate_matmul_ws(
BLOCK_M=128,
BLOCK_N=128,
BLOCK_K=128,
- max_num_stages=3,
+ num_stages=3,
):
# Limitations for now
assert input_type == np.float16
@@ -123,9 +158,6 @@ def generate_matmul_ws(
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
- required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
- num_stages = min(required_stages, max_num_stages)
-
module = ir.Module.create()
token_ty = ir.Type.parse("!gpu.async.token")
a_elem_ty = get_mlir_ty(input_type)
@@ -202,11 +234,15 @@ def generate_matmul_ws(
+ smem_space_str
+ ">>"
)
-
+ kernelName = make_kernel_name(
+ input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
+ )
with ir.InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
- def mlir_matmul_warpspecialized(a_host, b_host, c_host):
+ fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+ with ir.InsertionPoint(fop.add_entry_block()):
+ a_host = fop.arguments[0]
+ b_host = fop.arguments[1]
+ c_host = fop.arguments[2]
lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
@@ -301,6 +337,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+ ns = num_stages if num_stages == 1 else num_stages - 1
# GPU Step 5. Producer Warpgroup (TMA Warpgroup)
with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
# Step 5.1. Reduce register size
@@ -319,7 +356,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
# Step 5.2.1. Wait mbarDONE
debug_print(
- "[prod] {} | mbarDONE[{}] try_wait phase={}",
+ "[prod] iv={} | mbarDONE[{}] try_wait phase={}",
iv,
stage,
phaseParity,
@@ -329,7 +366,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
mbarDONE, phaseParity, ticks, mbarId=stage
)
debug_print(
- "[prod] {} | mbarDONE[{}] try_wait phase={} [done]",
+ "[prod] iv={} | mbarDONE[{}] try_wait phase={} [done]",
iv,
stage,
phaseParity,
@@ -412,7 +449,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
# Step 5.2.3. Arrive mbarTMA
debug_print(
- "[prod] {} | mbarTMA[{}] arrive",
+ "[prod] iv={} | mbarTMA[{}] arrive",
iv,
stage,
predicate=producerPrimaryThread,
@@ -421,7 +458,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
)
debug_print(
- "[prod] {} | mbarTMA[{}] arrive [done]",
+ "[prod] iv={} | mbarTMA[{}] arrive [done]",
iv,
stage,
predicate=producerPrimaryThread,
@@ -450,7 +487,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
debug_print(
- "[cons] {} | mbarTMA[{}] try_wait phase={}",
+ "[cons] iv={} | mbarTMA[{}] try_wait phase={}",
iv,
stage,
phaseParity,
@@ -460,7 +497,7 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
mbarTMA, phaseParity, ticks, mbarId=stage
)
debug_print(
- "[cons] {} | mbarTMA[{}] try_wait phase={} [done]",
+ "[cons] iv={} | mbarTMA[{}] try_wait phase={} [done]",
iv,
stage,
phaseParity,
@@ -509,32 +546,29 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
)
# Step 6.3.4. Arrive mbarDONE
- p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
- p_arrive = arith.andi(consumerPrimaryThread, p1)
+ if num_stages == 1:
+ p_arrive = consumerPrimaryThread
+ else:
+ p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+ p_arrive = arith.andi(consumerPrimaryThread, p1)
with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
barId = arith.select(
p, c(num_stages - 1), arith.subi(stage, c(1))
)
- remoteCtaId = c(0)
- pred = consumerPrimaryThread
debug_print(
- "[cons] {} | mbarDONE[{}] arrive.mapa pred={} ",
+ "[cons] iv={} | mbarDONE[{}] arrive ",
iv,
barId,
- remoteCtaId,
- pred,
predicate=consumerPrimaryThread,
)
nvgpu.mbarrier_arrive(
ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
)
debug_print(
- "[cons] {} | mbarDONE[{}] arrive pred={} [done]",
+ "[cons] iv={} | mbarDONE[{}] arrive [done]",
iv,
barId,
- remoteCtaId,
- pred,
predicate=consumerPrimaryThread,
)
scf.yield_([])
@@ -609,11 +643,13 @@ def mlir_matmul_warpspecialized(a_host, b_host, c_host):
# Step 4. Copy back to host
t8 = gpu.wait(token_ty, [launch_op])
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+ gpu.dealloc(token_ty, [t8], a_device)
+ gpu.dealloc(token_ty, [t8], b_device)
gpu.wait(token_ty, [t9])
+ gpu.dealloc(token_ty, [t8], c_device)
+ func.ReturnOp([])
- mlir_matmul_warpspecialized.func_op.attributes[
- "llvm.emit_c_interface"
- ] = ir.UnitAttr.get()
+ fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
return module
@@ -627,7 +663,7 @@ def generate_matmul_multistage(
BLOCK_M=128,
BLOCK_N=128,
BLOCK_K=64,
- max_num_stages=3,
+ num_stages=3,
):
# Limitaitons for now
assert input_type == np.float16
@@ -636,9 +672,6 @@ def generate_matmul_multistage(
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
- required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
- num_stages = min(required_stages, max_num_stages)
-
module = ir.Module.create()
token_ty = ir.Type.parse("!gpu.async.token")
a_elem_ty = get_mlir_ty(input_type)
@@ -717,9 +750,23 @@ def generate_matmul_multistage(
)
with ir.InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
- def mlir_matmul_multistage(a_host, b_host, c_host):
+ kernelName = make_kernel_name(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ num_stages,
+ False,
+ )
+ fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+ with ir.InsertionPoint(fop.add_entry_block()):
+ a_host = fop.arguments[0]
+ b_host = fop.arguments[1]
+ c_host = fop.arguments[2]
lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
@@ -792,7 +839,8 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
# GPU Step 3. Prologue (global memory --> shared memory)
- for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
+ ns = num_stages if num_stages == 1 else num_stages - 1
+ for_op = scf.ForOp(c(0), c(ns), c(1))
with ir.InsertionPoint(for_op.body):
iv = for_op.induction_variable
@@ -951,104 +999,105 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
new_acc = nvgpu.WarpgroupMmaOp(
acc.type, da, db, carry_acc, transposeB=True
)
+ if num_stages == 1:
+ nvvm.WgmmaWaitGroupSyncOp(0)
# Step 4.4. Load TMA for next stage
p1 = arith.cmpi(
arith.CmpIPredicate.ult,
- arith.addi(iv, c(num_stages - 1)),
+ arith.addi(iv, c(ns)),
c(K // BLOCK_K),
)
p = arith.andi(primaryThread, p1)
- with ir.InsertionPoint(scf.IfOp(p).then_block):
- nextStage = arith.addi(iv, c(num_stages - 1))
- nextSlot = arith.remui(nextStage, c(num_stages))
- a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
- a_tma_slice = memref.view(
- ir.MemRefType.get(
- a_tma_shape, a_elem_ty, memory_space=smem_space
- ),
- dynamic_smem,
- a_offset,
- [],
- )
- b_offset = arith.addi(
- arith.muli(nextSlot, c(rhs_tile_bytes)),
- c(lhs_tile_bytes * num_stages),
- )
- b_tma_slice_1 = memref.view(
- ir.MemRefType.get(
- b_tma_shape, b_elem_ty, memory_space=smem_space
- ),
- dynamic_smem,
- b_offset,
- [],
- )
- b_offset2 = arith.addi(
- b_offset,
- c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
- )
- b_tma_slice_2 = memref.view(
- ir.MemRefType.get(
- b_tma_shape, b_elem_ty, memory_space=smem_space
- ),
- dynamic_smem,
- b_offset2,
- [],
- )
+ nextStage = arith.addi(iv, c(ns))
+ nextSlot = arith.remui(nextStage, c(num_stages))
+ a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
- coord = arith.muli(c(64), nextStage)
- debug_print(
- "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
- iv,
- a_offset,
- b_offset,
- b_offset2,
- coord,
- dimX,
- dimY,
- coord,
- predicate=primaryThread,
- )
- nvgpu.TmaAsyncLoadOp(
- a_tma_slice,
- mbarTMA,
- a_tma_desc,
- coordinates=[coord, dimX],
- mbarId=nextSlot,
- predicate=primaryThread,
- )
- nvgpu.TmaAsyncLoadOp(
- b_tma_slice_1,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY, coord],
- mbarId=nextSlot,
- predicate=primaryThread,
- )
- dimY2 = arith.addi(dimY, c(64))
- nvgpu.TmaAsyncLoadOp(
- b_tma_slice_2,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY2, coord],
- mbarId=nextSlot,
- predicate=primaryThread,
- )
+ debug_print(
+ "[MainLoop] mbarTMA[{}] arrive",
+ nextSlot,
+ predicate=p,
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), nextSlot, predicate=p
+ )
+ debug_print(
+ "[MainLoop] mbarTMA[{}] arrive [done]",
+ nextSlot,
+ predicate=p,
+ )
- debug_print(
- "[MainLoop] mbarTMA[{}] arrive",
- nextSlot,
- predicate=primaryThread,
- )
- nvgpu.mbarrier_arrive_expect_tx(
- mbarTMA, c(txcount), nextSlot, predicate=primaryThread
- )
- debug_print(
- "[MainLoop] mbarTMA[{}] arrive [done]",
- nextSlot,
- predicate=primaryThread,
- )
- scf.yield_([])
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(
+ a_tma_shape, a_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(nextSlot, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ b_offset2 = arith.addi(
+ b_offset,
+ c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+ )
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
+
+ coord = arith.muli(c(64), nextStage)
+ debug_print(
+ "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ iv,
+ a_offset,
+ b_offset,
+ b_offset2,
+ coord,
+ dimX,
+ dimY,
+ coord,
+ predicate=p,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=nextSlot,
+ predicate=p,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=nextSlot,
+ predicate=p,
+ )
+ dimY2 = arith.addi(dimY, c(64))
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=nextSlot,
+ predicate=p,
+ )
# Step 4.5. Change the phaseParity
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
phaseParity = arith.select(
@@ -1062,7 +1111,6 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# Step 5. Wait All WGMMA groups
nvvm.WgmmaWaitGroupSyncOp(0)
- gpu.barrier()
# Step 6. Epilogue (registers --> shared memory)
acc_smem_ty = ir.MemRefType.get(
@@ -1113,10 +1161,12 @@ def mlir_matmul_multistage(a_host, b_host, c_host):
# Step 4. Copy back to host
t8 = gpu.wait(token_ty, [launch_op])
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+ gpu.dealloc(token_ty, [t8], a_device)
+ gpu.dealloc(token_ty, [t8], b_device)
gpu.wait(token_ty, [t9])
+ gpu.dealloc(token_ty, [t8], c_device)
+ func.ReturnOp([])
- mlir_matmul_multistage.func_op.attributes[
- "llvm.emit_c_interface"
- ] = ir.UnitAttr.get()
+ fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
module.operation.verify()
return module
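With the builders now emitting a named func.FuncOp instead of the fixed mlir_matmul_* symbols, matmul.py and matmulBuilder.py must agree on the string produced by make_kernel_name. A small illustration (the import path is an assumption):
```python
import numpy as np

from tools import matmulBuilder  # assumed import path, mirroring matmul.py

# Variant prefix, GEMM shape, tile shape, and stage count, joined with "_"/"x".
name = matmulBuilder.make_kernel_name(
    np.float16, np.float32, 128, 128, 4096, 128, 128, 64, 3, False
)
assert name == "multistage_128x128x4096_128x128x64_3"
```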
>From 861fe27fe18f398ada57b5940f01df39e77ea3d6 Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sun, 3 Mar 2024 08:42:00 +0000
Subject: [PATCH 08/10] format
---
.../GPU/CUDA/sm90/python/matmul.py | 108 +++++++-----------
.../CUDA/sm90/python/tools/matmulBuilder.py | 11 +-
2 files changed, 44 insertions(+), 75 deletions(-)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index 88609273a8b548..dd474b01ffabef 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -252,118 +252,90 @@ def matmul(
print("PASS ")
+
# Takes a long time to run
def test_long():
- for stages in range(1,7):
+ for stages in range(1, 7):
for M in [128, 512, 1024, 4096, 8192]:
for N in [128, 512, 1024, 4096, 8192]:
for K in [64, 128, 512, 1024, 4096, 8192]:
matmul(
np.float16,
np.float32,
- M,N,
+ M,
+ N,
K,
max_num_stages=stages,
use_warp_specialization=False,
no_verify=True,
- saveIR=True
)
matmul(
np.float16,
np.float32,
- M,N,
+ M,
+ N,
K,
max_num_stages=stages,
- use_warp_specialization=True,
+ use_warp_specialization=True,
)
+
def test_short():
for stages in [1, 3]:
for M in [128, 512]:
- for N in [128, 512]:
- for K in [64, 512]:
+ for N in [128]:
+ for K in [64, 256]:
matmul(
np.float16,
np.float32,
- M,N,
+ M,
+ N,
K,
max_num_stages=stages,
use_warp_specialization=False,
- no_verify=True,
- saveIR=True
)
matmul(
np.float16,
np.float32,
- M,N,
+ M,
+ N,
K,
max_num_stages=stages,
- use_warp_specialization=True,
+ use_warp_specialization=True,
)
+
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
+# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS
+# CHECK: PASS
# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x64, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
-# CHECK: PASS
-# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x512x512, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
# CHECK: PASS
-test_short()
\ No newline at end of file
+test_short()
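For orientation, the CHECK lines above now track the resized sweep (K capped at 256 and the N=512 configurations dropped). The following is an illustrative sketch only, not the actual body of `test_short()`: it shows how one configuration pair maps onto the first two CHECK pairs, assuming it runs inside matmul.py where `matmul()` is defined.
```python
# Illustrative only, not the real test_short(): one multistage run plus one
# warp-specialized run of a 128x128x64 problem, matching the first two CHECK
# pairs above. Assumes matmul() from this file is in scope.
import numpy as np

M, N, K = 128, 128, 64
for ws in (False, True):      # multistage first, then warp specialization
    matmul(
        np.float16,           # input element type
        np.float32,           # accumulator / output element type
        M, N, K,
        max_num_stages=1,
        use_warp_specialization=ws,
    )
```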
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index a541378359fe97..fefc65b854040f 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -117,11 +117,7 @@ def make_kernel_name(
num_stages=3,
use_warp_specialization=False,
):
- kernelName = (
- "warpspecialized"
- if use_warp_specialization
- else "multistage"
- )
+ kernelName = "warpspecialized" if use_warp_specialization else "multistage"
return (
kernelName
+ "_"
@@ -140,6 +136,7 @@ def make_kernel_name(
+ str(num_stages)
)
+
def generate_matmul_ws(
input_type=np.float16,
output_type=np.float32,
@@ -644,7 +641,7 @@ def generate_matmul_ws(
t8 = gpu.wait(token_ty, [launch_op])
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
gpu.dealloc(token_ty, [t8], a_device)
- gpu.dealloc(token_ty, [t8], b_device)
+ gpu.dealloc(token_ty, [t8], b_device)
gpu.wait(token_ty, [t9])
gpu.dealloc(token_ty, [t8], c_device)
func.ReturnOp([])
@@ -1162,7 +1159,7 @@ def generate_matmul_multistage(
t8 = gpu.wait(token_ty, [launch_op])
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
gpu.dealloc(token_ty, [t8], a_device)
- gpu.dealloc(token_ty, [t8], b_device)
+ gpu.dealloc(token_ty, [t8], b_device)
gpu.wait(token_ty, [t9])
gpu.dealloc(token_ty, [t8], c_device)
func.ReturnOp([])
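The two dealloc hunks above only touch whitespace, but the surrounding host code is worth spelling out: the gpu dialect builders chain work through async tokens. Below is a commented restatement of that epilogue using the same builder calls that appear in matmulBuilder.py (`token_ty` is the async token type built earlier in the file); it is a sketch, not a standalone program.
```python
# Host-side epilogue as built in matmulBuilder.py (restated with comments).
t8 = gpu.wait(token_ty, [launch_op])               # token ready once the launch completes
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)  # async copy of C back to the host
gpu.dealloc(token_ty, [t8], a_device)              # A and B can be freed once the
gpu.dealloc(token_ty, [t8], b_device)              # launch token is available
gpu.wait(token_ty, [t9])                           # block until the copy finishes
gpu.dealloc(token_ty, [t8], c_device)              # then free the device C buffer
func.ReturnOp([])                                  # end of the host function
```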
>From 89370ef6180d2ce6f0b81ac106ecd623b89f3cda Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sun, 3 Mar 2024 08:44:07 +0000
Subject: [PATCH 09/10] Add asserts 128x128x64
---
.../Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index fefc65b854040f..0e496d897c8373 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -151,6 +151,9 @@ def generate_matmul_ws(
    # Limitations for now
assert input_type == np.float16
assert output_type == np.float32
+ assert BLOCK_M == 128
+ assert BLOCK_N == 128
+ assert BLOCK_K == 64
assert M % BLOCK_M == 0
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
@@ -665,6 +668,9 @@ def generate_matmul_multistage(
    # Limitations for now
assert input_type == np.float16
assert output_type == np.float32
+ assert BLOCK_M == 128
+ assert BLOCK_N == 128
+ assert BLOCK_K == 64
assert M % BLOCK_M == 0
assert N % BLOCK_N == 0
assert K % BLOCK_K == 0
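With these asserts in place, both generators accept only the 128x128x64 tile for now. A hedged usage sketch follows: the `BLOCK_*` keyword names are taken from the asserts themselves, while the assumption that `matmul()` forwards them this way is based on its other call sites.
```python
# Accepted: the only tile shape the generators currently support.
import numpy as np

matmul(
    np.float16, np.float32,
    M=512, N=128, K=256,
    BLOCK_M=128, BLOCK_N=128, BLOCK_K=64,  # anything else trips the new asserts
    max_num_stages=3,
    use_warp_specialization=True,
)

# Rejected: would fail `assert BLOCK_K == 64` inside generate_matmul_ws /
# generate_matmul_multistage.
# matmul(np.float16, np.float32, M=512, N=128, K=256, BLOCK_K=128)
```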
>From 8f7efd6db8cba383349462e4e89b04d57225cb3f Mon Sep 17 00:00:00 2001
From: grypp <guray.ozen at gmail.com>
Date: Sun, 3 Mar 2024 16:37:45 +0000
Subject: [PATCH 10/10] Address comments
---
.../GPU/CUDA/sm90/python/matmul.py | 2 +-
.../CUDA/sm90/python/tools/matmulBuilder.py | 45 ++++++-------------
2 files changed, 14 insertions(+), 33 deletions(-)
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
index dd474b01ffabef..cb7248ef23cd9e 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -98,6 +98,7 @@ def generate_matmul(
use_warp_specialization=True,
saveIR=False,
max_num_stages=3,
+ options=f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
):
with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
if use_warp_specialization:
@@ -137,7 +138,6 @@ def generate_matmul(
sys.stdout = original_stdout
# Get compiler
- options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
support_lib = os.getenv("SUPPORT_LIB")
if not os.path.exists(support_lib):
raise FileNotFoundError(
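The hunk above promotes the hard-coded compilation options string to a parameter of `generate_matmul`, keeping the previous sm_90a/+ptx80/opt-level=3 string as its default. A sketch of overriding it: only `options` and `use_warp_specialization` are taken from the diff; the remaining arguments are assumptions mirroring the `matmul()` calls in matmul.py.
```python
# Sketch: lower the optimization level for debugging without editing the
# builder. Argument order before the keywords is an assumption about
# generate_matmul's signature.
import numpy as np

generate_matmul(
    np.float16, np.float32,
    128, 128, 64,
    use_warp_specialization=False,
    options="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=0",
)
```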
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
index 0e496d897c8373..fac138dce605a7 100644
--- a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -22,6 +22,7 @@
PRODUCER_PRIMARY_THREAD = 128
CONSUMER_PRIMARY_THREAD = 0
+# The C++ side uses this sentinel value to detect whether a dimension is dynamic.
MLIR_DYNAMIC = -9223372036854775808
DEBUG = False
@@ -58,31 +59,11 @@ def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
scf.yield_([])
-def get_type_str(ty):
- if ir.F16Type.isinstance(ty):
- return "f16"
- if ir.F32Type.isinstance(ty):
- return "f32"
- if ir.F64Type.isinstance(ty):
- return "f64"
- if ir.IntegerType.isinstance(ty):
- return "i" + str(ir.IntegerType(ty).width)
- if ir.IndexType.isinstance(ty):
- return "T.index()"
- raise NotImplementedError(ty)
-
-
def get_type_size(ty):
- if ir.F16Type.isinstance(ty):
- return 2
- if ir.F32Type.isinstance(ty):
- return 4
- if ir.F64Type.isinstance(ty):
- return 8
+ if ir.FloatType.isinstance(ty):
+ return ir.FloatType(ty).width // 8
if ir.IntegerType.isinstance(ty):
return ir.IntegerType(ty).width // 8
- if ir.IndexType.isinstance(ty):
- return 8
raise NotImplementedError(ty)
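This hunk folds the per-type branches into a single width-based computation and drops `get_type_str` entirely, since `str()` on an `ir.Type` already prints the MLIR textual form used in the parsed type strings below. A standalone sketch, assuming the bindings are importable as `from mlir import ir`:
```python
# Consolidated helper as it reads after this patch, plus a tiny smoke test.
from mlir import ir


def get_type_size(ty):
    if ir.FloatType.isinstance(ty):
        return ir.FloatType(ty).width // 8   # f16 -> 2, f32 -> 4, f64 -> 8
    if ir.IntegerType.isinstance(ty):
        return ir.IntegerType(ty).width // 8
    raise NotImplementedError(ty)


with ir.Context():
    f16 = ir.F16Type.get()
    i8 = ir.IntegerType.get_signless(8)
    assert get_type_size(f16) == 2 and str(f16) == "f16"
    assert get_type_size(i8) == 1 and str(i8) == "i8"
```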
@@ -187,7 +168,7 @@ def generate_matmul_ws(
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
- + get_type_str(a_elem_ty)
+ + str(a_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -198,7 +179,7 @@ def generate_matmul_ws(
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
- + get_type_str(b_elem_ty)
+ + str(b_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -209,7 +190,7 @@ def generate_matmul_ws(
+ "x"
+ str(BLOCK_N)
+ "x"
- + get_type_str(c_elem_ty)
+ + str(c_elem_ty)
+ ">>"
)
a_wgmma_ty = ir.Type.parse(
@@ -218,7 +199,7 @@ def generate_matmul_ws(
+ "x"
+ str(BLOCK_K)
+ "x"
- + get_type_str(a_elem_ty)
+ + str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
@@ -229,7 +210,7 @@ def generate_matmul_ws(
+ "x"
+ str(BLOCK_N)
+ "x"
- + get_type_str(a_elem_ty)
+ + str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
@@ -704,7 +685,7 @@ def generate_matmul_multistage(
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
- + get_type_str(a_elem_ty)
+ + str(a_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -715,7 +696,7 @@ def generate_matmul_multistage(
+ "x"
+ str(TMA_LAST_DIM_F16)
+ "x"
- + get_type_str(b_elem_ty)
+ + str(b_elem_ty)
+ ", "
+ str(smem_space)
+ ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
@@ -726,7 +707,7 @@ def generate_matmul_multistage(
+ "x"
+ str(BLOCK_N)
+ "x"
- + get_type_str(c_elem_ty)
+ + str(c_elem_ty)
+ ">>"
)
a_wgmma_ty = ir.Type.parse(
@@ -735,7 +716,7 @@ def generate_matmul_multistage(
+ "x"
+ str(BLOCK_K)
+ "x"
- + get_type_str(a_elem_ty)
+ + str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"
@@ -746,7 +727,7 @@ def generate_matmul_multistage(
+ "x"
+ str(BLOCK_N)
+ "x"
- + get_type_str(a_elem_ty)
+ + str(a_elem_ty)
+ ", "
+ smem_space_str
+ ">>"