[Mlir-commits] [mlir] d95e6d0 - [mlir] GEMM Hopper Tensor Core Integration Test (#81478)
llvmlistbot at llvm.org
Mon Mar 4 13:04:03 PST 2024
Author: Guray Ozen
Date: 2024-03-04T21:03:59Z
New Revision: d95e6d027486876559f1a2a96c33b8ad93cc0ae4
URL: https://github.com/llvm/llvm-project/commit/d95e6d027486876559f1a2a96c33b8ad93cc0ae4
DIFF: https://github.com/llvm/llvm-project/commit/d95e6d027486876559f1a2a96c33b8ad93cc0ae4.diff
LOG: [mlir] GEMM Hopper Tensor Core Integration Test (#81478)
Added:
mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
Modified:
Removed:
################################################################################
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
new file mode 100644
index 00000000000000..2d5a9d00e73226
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner or not config.mlir_run_cuda_sm90_tests:
+ config.unsupported = True
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
new file mode 100644
index 00000000000000..cb7248ef23cd9e
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py
@@ -0,0 +1,341 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN: %PYTHON %s | FileCheck %s
+
+
+# ===--- GEMM Hopper Tensor Core Integration Test ---===
+#
+# This test validates the correctness of the GEMM kernels supported by the
+# NVGPU dialect; it currently covers the Multistage and Warp Specialization
+# kernels.
+# The test metaprograms the IR through the MLIR Python bindings, so the GEMM
+# shape, tile size, and data type can be changed freely for testing purposes.
+# The entry function is `matmul`, where one can specify the GEMM shape, tile
+# size, data type, GEMM algorithm (Multistage or Warp Specialization), and the
+# maximum number of stages.
+# Results are verified against numpy's matmul operation.
+#
+# Example:
+# matmul(input_type=np.float16, # input types
+# output_type=np.float32, # output type
+# M=4096, N=4096, K=4096, # Shape
+# BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, # Tile Size
+# use_warp_specialization=True, # Enable Warp Specialization
+# max_num_stages=3) # Number of stages in shared memory
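+#
+# On success, the call above prints a configuration banner and a PASS line, e.g.:
+#   ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 4096x4096x4096, Tile 128x128x64, stages 3 --===
+#   PASS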
+#
+# ===--- Parallelism Across CTAs ---===
+#
+# GEMM is defined by three loops over its shape, which is specified in the
+# `matmul` function.
+# The program builds IR with the following loop structure, tiling the loops
+# with the given tile sizes and parallelizing the two outermost loops across
+# the first and second CTA dimensions (blockIdx.x and blockIdx.y); a NumPy
+# sketch of the same tiling follows the loop nest below.
+#
+# for(bi = 0; bi < M; bi += BLOCK_M) # parallelize across blockIdx.x
+# for(bj = 0; bj < N; bj += BLOCK_N) # parallelize across blockIdx.y
+# for(bk = 0; bk < K; bk += BLOCK_K)
+# for(i = bi; i < (bi + BLOCK_M); ++i)
+# for(j = bj; j < (bj + BLOCK_N); ++j)
+# for(k = bk; k < (bk + BLOCK_K); ++k)
+#
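+# A minimal NumPy sketch of the same tiling scheme (host-side and illustrative
+# only; the test builds the equivalent IR instead of running this code):
+#
+#   def tiled_matmul_ref(a, b, BLOCK_M=128, BLOCK_N=128, BLOCK_K=64):
+#       M, K = a.shape
+#       _, N = b.shape
+#       c = np.zeros((M, N), dtype=np.float32)
+#       for bi in range(0, M, BLOCK_M):          # maps to blockIdx.x
+#           for bj in range(0, N, BLOCK_N):      # maps to blockIdx.y
+#               for bk in range(0, K, BLOCK_K):  # sequential main loop
+#                   c[bi : bi + BLOCK_M, bj : bj + BLOCK_N] += (
+#                       a[bi : bi + BLOCK_M, bk : bk + BLOCK_K].astype(np.float32)
+#                       @ b[bk : bk + BLOCK_K, bj : bj + BLOCK_N].astype(np.float32)
+#                   )
+#       return c
+#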
+# ===--- Multistage Kernel ---===
+#
+# This kernel launches a single warp group (128 threads). The primary thread
+# (pthread) issues the TMA load requests. All threads collectively wait for the
+# data and perform the mma operations. Once the K loop is done, the threads
+# together store the fragmented accumulator registers to shared memory, and then
+# from shared memory to global memory; this part is called the epilogue. A
+# sketch of the stage/parity bookkeeping follows the timeline below.
+#
+# Execution Timeline of Multistage Kernel with 3 stages:
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# | |Prologue ----> |MainLoop ----> |Epilogue |
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
+# |pthread|[tma-0,1,2] |[wait-0][mma][tma-2]|[wait-1][mma][tma-0]|[wait-2][mma][tma-1]| ... | [mma-wait] |[epilogue]|
+# |wgroup | ........ |[wait-0][mma] |[wait-1][mma] |[wait-2][mma] | ... | [mma-wait] |[epilogue]|
+# +-------+----------------+--------------------+--------------------+--------------------+-----+-----------------------+
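+#
+# The stage/parity bookkeeping shared by both kernels can be summarized with a
+# small illustrative sketch (names mirror those used in matmulBuilder.py):
+#
+#   phaseParity = 0
+#   for iv in range(K // BLOCK_K):
+#       stage = iv % num_stages        # shared-memory slot / mbarrier to use
+#       # wait on mbarTMA[stage] with phaseParity, run wgmma, issue next TMA
+#       if stage == num_stages - 1:    # the ring of slots wrapped around
+#           phaseParity ^= 1           # expected parity flips for the next pass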
+#
+# ===--- Warp Specialization Kernel ---===
+#
+# This kernel launches 2 warp groups (2x128 threads) per CTA, specializing one
+# as the `producer warp group` and the other as the `consumer warp group`. The
+# `producer warp group` is responsible for requesting TMA loads, while the
+# `consumer warp group` performs the mma operations. The epilogue section is
+# handled by the `consumer warp group`, as its threads own the fragmented
+# registers; a sketch of the role selection follows the timeline below.
+#
+# Execution Timeline of Warp Specialization Kernel with 2 stages:
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# | |MainLoop ----> | 1st Epilogue | 2nd Epilogue |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |pthread1|[tma-0] | [tma-1] | [tma-0] | [tma-1] | ..........................| ........... | [shmem->global] |
+# |wgroup1 | .......| | | | | | [shmem->global] |
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
+# |wgroup2 |[wait-0][mma], [wait-1][mma], [wait-0][mma], [wait-1][mma], ......| [reg->shmem] | [shmem->global]|
+# +--------+--------+---------+---------+---------+-----------------------+---+--------------+-----------------+
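+#
+# The role selection of the warp-specialized kernel can be pictured as follows
+# (illustrative pseudocode; the real selection is emitted as IR in matmulBuilder.py):
+#
+#   warpgroup_id = threadIdx.x // 128          # two warp groups per CTA
+#   if warpgroup_id == 1:                      # producer warp group
+#       setmaxregister(40, decrease)           # free registers for the consumer
+#       # loop: wait mbarDONE[stage], issue TMA loads, arrive mbarTMA[stage]
+#   else:                                      # consumer warp group
+#       setmaxregister(232, increase)
+#       # loop: wait mbarTMA[stage], run wgmma, arrive mbarDONE, then epilogue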
+
+import errno
+import numpy as np
+import subprocess
+import ctypes
+from tools import nvgpucompiler
+from tools import matmulBuilder
+import contextlib
+import os
+import sys
+import pathlib
+from mlir import runtime as rt
+
+
+def generate_matmul(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specialization=True,
+ saveIR=False,
+ max_num_stages=3,
+ options=f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
+):
+ with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
+ if use_warp_specialization:
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ max_num_stages,
+ )
+ else:
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ max_num_stages,
+ )
+
+ mlir_nvgpu_module.operation.verify()
+
+ # Save the generated IR
+ if saveIR:
+ with open("gemm.mlir", "w") as f:
+ print(mlir_nvgpu_module, file=f)
+
+ # Get compiler
+ support_lib = os.getenv("SUPPORT_LIB")
+ if not os.path.exists(support_lib):
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+ )
+ compiler = nvgpucompiler.NvgpuCompiler(
+ options, opt_level=3, shared_libs=[support_lib]
+ )
+
+ # Compile
+ engine = compiler.compile_and_jit(mlir_nvgpu_module)
+ return engine
+
+
+def matmul(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=128,
+ N=128,
+ K=128,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specialization=True,
+ saveIR=False,
+ max_num_stages=3,
+ print_results=False,
+ no_verify=False,
+):
+ # Print the configuration
+ required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+ num_stages = min(required_stages, max_num_stages)
+ ity = "f16" if input_type == np.float16 else "f32"
+ oty = "f16" if output_type == np.float16 else "f32"
+ gemmty = "Warp specialization" if use_warp_specialization else "Multistage"
+ print(
+ f"===-- Running GEMM {gemmty} {oty} += {ity} * {ity}, "
+ f"Size {M}x{N}x{K}, Tile {BLOCK_M}x{BLOCK_N}x{BLOCK_K}, "
+ f"stages {num_stages} --==="
+ )
+
+ # Build IR and compile
+ engine = generate_matmul(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ use_warp_specialization,
+ saveIR,
+ num_stages,
+ )
+
+ # Allocate matrices and invoke the matmul
+ c = np.zeros((M, N), output_type)
+ a = np.random.randn(M, K).astype(input_type)
+ b = np.random.randn(K, N).astype(input_type)
+ mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+ mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+ mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+ kernelName = matmulBuilder.make_kernel_name(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ num_stages,
+ use_warp_specialization,
+ )
+
+ # Launch the MLIR generated kernel
+ engine.invoke(kernelName, mem_a, mem_b, mem_c)
+
+ float_formatter = "{:.2f}".format
+ np.set_printoptions(formatter={"float_kind": float_formatter})
+
+ if print_results:
+ print(c)
+
+ # Verify the results
+ if not no_verify:
+ ref = a.astype(input_type) @ b.astype(input_type)
+ if print_results:
+ print(ref)
+ np.testing.assert_allclose(c, ref, rtol=5e-03, atol=1e-01)
+
+ print("PASS ")
+
+
+# Takes a long time to run (exhaustive shape sweep)
+def test_long():
+ for stages in range(1, 7):
+ for M in [128, 512, 1024, 4096, 8192]:
+ for N in [128, 512, 1024, 4096, 8192]:
+ for K in [64, 128, 512, 1024, 4096, 8192]:
+ matmul(
+ np.float16,
+ np.float32,
+ M,
+ N,
+ K,
+ max_num_stages=stages,
+ use_warp_specialization=False,
+ no_verify=True,
+ )
+ matmul(
+ np.float16,
+ np.float32,
+ M,
+ N,
+ K,
+ max_num_stages=stages,
+ use_warp_specialization=True,
+ )
+
+
+def test_short():
+ for stages in [1, 3]:
+ for M in [128, 512]:
+ for N in [128]:
+ for K in [64, 256]:
+ matmul(
+ np.float16,
+ np.float32,
+ M,
+ N,
+ K,
+ max_num_stages=stages,
+ use_warp_specialization=False,
+ )
+ matmul(
+ np.float16,
+ np.float32,
+ M,
+ N,
+ K,
+ max_num_stages=stages,
+ use_warp_specialization=True,
+ )
+
+
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x64, Tile 128x128x64, stages 1 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 128x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x64, Tile 128x128x64, stages 2 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Multistage f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+# CHECK: ===-- Running GEMM Warp specialization f32 += f16 * f16, Size 512x128x256, Tile 128x128x64, stages 3 --===
+# CHECK: PASS
+
+test_short()
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
new file mode 100644
index 00000000000000..d9f34f219c4d95
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/lit.local.cfg
@@ -0,0 +1,3 @@
+# Files in this directory are tools, not tests.
+config.unsupported = True
+
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
new file mode 100644
index 00000000000000..fac138dce605a7
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py
@@ -0,0 +1,1156 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+from mlir.extras import types as T
+
+TMA_LAST_DIM_F16 = 64  # 128B of float16 (64 elements)
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+# MLIR's C++ side uses this sentinel value (INT64_MIN, i.e. ShapedType::kDynamic)
+# to mark a dimension as dynamic.
+MLIR_DYNAMIC = -9223372036854775808
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+ if not DEBUG and not forcePrint:
+ return
+ type_formats = []
+ for arg in args:
+ ty_format = None
+ if ir.IndexType.isinstance(arg.type):
+ ty_format = "%llu"
+ if ir.IntegerType.isinstance(arg.type):
+ width = ir.IntegerType(arg.type).width
+ if width == 64:
+ ty_format = "%llu"
+ elif width == 32:
+ ty_format = "%d"
+ elif width == 1:
+ ty_format = "%i"
+ if ir.F32Type.isinstance(arg.type):
+ ty_format = "%f"
+ if ty_format is None:
+ raise NotImplementedError(arg.type)
+ type_formats.append(ty_format)
+ if threadNumber != -1:
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+ scf.yield_([])
+ if_op = scf.IfOp(predicate)
+ with ir.InsertionPoint(if_op.then_block):
+ gpu.printf(fmt.format(*type_formats) + "\n", args)
+ scf.yield_([])
+
+
+def get_type_size(ty):
+ if ir.FloatType.isinstance(ty):
+ return ir.FloatType(ty).width // 8
+ if ir.IntegerType.isinstance(ty):
+ return ir.IntegerType(ty).width // 8
+ raise NotImplementedError(ty)
+
+
+def get_mlir_ty(dtype):
+ if dtype == np.float16:
+ return T.f16()
+ if dtype == np.float32:
+ return T.f32()
+ if dtype == np.float64:
+ return T.f64()
+ if dtype == np.int32:
+ return T.i32()
+ if dtype == np.int64:
+ return T.i64()
+ raise NotImplementedError(dtype)
+
+
+def c(value, ty=None):
+ ty = T.index() if ty is None else ty
+ return arith.constant(ty, value)
+
+
+def make_kernel_name(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=128,
+ num_stages=3,
+ use_warp_specialization=False,
+):
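+ # For example, (np.float16, np.float32, 128, 128, 64, 128, 128, 64, 1, False)
+ # produces "multistage_128x128x64_128x128x64_1".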
+ kernelName = "warpspecialized" if use_warp_specialization else "multistage"
+ return f"{kernelName}_{M}x{N}x{K}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_{num_stages}"
+
+
+def generate_matmul_ws(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=128,
+ num_stages=3,
+):
+ # Limitations for now
+ assert input_type == np.float16
+ assert output_type == np.float32
+ assert BLOCK_M == 128
+ assert BLOCK_N == 128
+ assert BLOCK_K == 64
+ assert M % BLOCK_M == 0
+ assert N % BLOCK_N == 0
+ assert K % BLOCK_K == 0
+
+ module = ir.Module.create()
+ token_ty = ir.Type.parse("!gpu.async.token")
+ a_elem_ty = get_mlir_ty(input_type)
+ b_elem_ty = get_mlir_ty(input_type)
+ c_elem_ty = get_mlir_ty(output_type)
+ a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+ b_ty = ir.MemRefType.get((K, N), b_elem_ty)
+ c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+ a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+ b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+ b_tile_shape = (BLOCK_K, BLOCK_N)
+ txcount = (a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)) + (
+ b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)
+ )
+ smem_space_str = "#gpu.address_space<workgroup>"
+ smem_space = ir.Attribute.parse(smem_space_str)
+ mbar_ty = ir.Type.parse(
+ "!nvgpu.mbarrier.group<memorySpace = "
+ + str(smem_space)
+ + ", num_barriers = "
+ + str(num_stages)
+ + ">"
+ )
+ a_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(a_elem_ty)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ b_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(b_elem_ty)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ acc_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(c_elem_ty)
+ + ">>"
+ )
+ a_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_K)
+ + "x"
+ + str(a_elem_ty)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
+ b_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(a_elem_ty)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
+ kernelName = make_kernel_name(
+ input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_stages, True
+ )
+ with ir.InsertionPoint(module.body):
+ fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+ with ir.InsertionPoint(fop.add_entry_block()):
+ a_host = fop.arguments[0]
+ b_host = fop.arguments[1]
+ c_host = fop.arguments[2]
+ lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+ rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
+ smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+ smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
+ smem_size = max(smem_size_input, smem_size_output)
+
+ # Step 1. Allocate device memory and memcpy
+ t1 = gpu.wait(token_ty, [])
+ a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+ b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+ c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+ t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+ t7 = gpu.wait(token_ty, [t6])
+
+ # Step 2. Create TMA Descriptors
+ tma_specs = [
+ (a_device, a_tma_desc_ty, a_tma_shape),
+ (b_device, b_tma_desc_ty, b_tma_shape),
+ ]
+ tma_descs = []
+ for x_device, tensor_map_ty, tile_shape in tma_specs:
+ x_unranked = memref.cast(
+ ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
+ )
+ tma_descs.append(
+ nvgpu.TmaCreateDescriptorOp(
+ tensor_map_ty, x_unranked, map(c, tile_shape)
+ ).result
+ )
+ a_tma_desc, b_tma_desc = tma_descs
+
+ # Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
+ cta_m = M // BLOCK_M
+ cta_n = N // BLOCK_N
+ assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+ grid = (cta_m, cta_n, 1)
+ block = (WARP_GROUP_SIZE * 2, 1, 1)
+ launch_op = gpu.LaunchOp(
+ token_ty,
+ [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+ )
+ launch_op.body.blocks.append(*([T.index()] * 12))
+ with ir.InsertionPoint(launch_op.body.blocks[0]):
+ # GPU Step 0. This is needed for vectorized ld/st
+ memref.assume_alignment(c_device, 16)
+ dynamic_smem = gpu.dynamic_shared_memory(
+ ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+ )
+ ticks = c(10000000)
+
+ # GPU Step 1. Bootstrapping: find the primary thread, warps, and warp groups.
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ wgPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
+ )
+ warp_id = arith.divui(tidx, c(32))
+ warpgroup_id = arith.divui(warp_id, c(4))
+ is_producer = arith.cmpi(
+ arith.CmpIPredicate.eq,
+ warpgroup_id,
+ c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
+ )
+ is_consumer = arith.cmpi(
+ arith.CmpIPredicate.eq,
+ warpgroup_id,
+ c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
+ )
+ producerPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
+ )
+ consumerPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
+ )
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = arith.muli(bidx, c(BLOCK_M))
+ dimY = arith.muli(bidy, c(BLOCK_N))
+
+ # GPU Step 2. Initialize mbarrier groups
+ mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+ mbarDONE = nvgpu.mbarrier_create(mbar_ty)
+ for i in range(num_stages):
+ nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=wgPrimaryThread)
+ nvgpu.mbarrier_init(mbarDONE, c(1), c(i), predicate=wgPrimaryThread)
+ gpu.barrier()
+
+ # GPU Step 3. Prefetch TMA descriptors
+ nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
+
+ ns = num_stages if num_stages == 1 else num_stages - 1
+ # GPU Step 5. Producer Warpgroup (TMA Warpgroup)
+ with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
+ # Step 5.1. Reduce register size
+ nvvm.setmaxregister(
+ PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
+ )
+
+ # Step 5.2. TMA Main Loop
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [arith.constant(T.bool(), 1)]
+ )
+ with ir.InsertionPoint(for_op.body):
+ phaseParity = for_op.inner_iter_args[0]
+ iv = for_op.induction_variable
+ stage = arith.remui(iv, c(num_stages))
+
+ # Step 5.2.1. Wait mbarDONE
+ debug_print(
+ "[prod] iv={} | mbarDONE[{}] try_wait phase={}",
+ iv,
+ stage,
+ phaseParity,
+ predicate=producerPrimaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarDONE, phaseParity, ticks, mbarId=stage
+ )
+ debug_print(
+ "[prod] iv={} | mbarDONE[{}] try_wait phase={} [done]",
+ iv,
+ stage,
+ phaseParity,
+ predicate=producerPrimaryThread,
+ )
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+ phaseParity = arith.select(
+ p,
+ arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+ phaseParity,
+ )
+
+ # Step 5.2.2. Load TMA
+ a_offset = arith.muli(stage, c(lhs_tile_bytes))
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(
+ a_tma_shape, a_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ b_offset2 = arith.addi(
+ b_offset,
+ c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+ )
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
+ debug_print(
+ "[prod] a_offset={} b_offset={} b_offset2={}",
+ a_offset,
+ b_offset,
+ b_offset2,
+ predicate=producerPrimaryThread,
+ )
+ coord = arith.muli(c(64), iv)
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
+ dimY2 = arith.addi(dimY, c(64))
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
+
+ # Step 5.2.3. Arrive mbarTMA
+ debug_print(
+ "[prod] iv={} | mbarTMA[{}] arrive",
+ iv,
+ stage,
+ predicate=producerPrimaryThread,
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
+ )
+ debug_print(
+ "[prod] iv={} | mbarTMA[{}] arrive [done]",
+ iv,
+ stage,
+ predicate=producerPrimaryThread,
+ )
+ scf.yield_([phaseParity])
+ scf.yield_([])
+
+ # GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
+ if_op = scf.IfOp(is_consumer)
+ with ir.InsertionPoint(if_op.then_block):
+ # Step 6.1. Increase register size
+ nvvm.setmaxregister(
+ CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
+ )
+
+ # GPU Step 6.2. Initialize MMA registers
+ acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+
+ # Step 6.3. MMA Main Loop
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
+ )
+ with ir.InsertionPoint(for_op.body):
+ # Step 6.3.1. Wait mbarTMA
+ phaseParity = for_op.inner_iter_args[1]
+ iv = for_op.induction_variable
+ stage = arith.remui(iv, c(num_stages))
+ debug_print(
+ "[cons] iv={} | mbarTMA[{}] try_wait phase={}",
+ iv,
+ stage,
+ phaseParity,
+ predicate=consumerPrimaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarTMA, phaseParity, ticks, mbarId=stage
+ )
+ debug_print(
+ "[cons] iv={} | mbarTMA[{}] try_wait phase={} [done]",
+ iv,
+ stage,
+ phaseParity,
+ predicate=consumerPrimaryThread,
+ )
+
+ # Step 6.3.2. Create WGMMA Descriptors
+ a_offset = arith.muli(stage, c(lhs_tile_bytes))
+ a_tile_slice = memref.view(
+ ir.MemRefType.get(
+ a_tile_shape, a_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tile_slice = memref.view(
+ ir.MemRefType.get(
+ b_tile_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ debug_print(
+ "[cons] a_offset={} b_offset={}",
+ a_offset,
+ b_offset,
+ predicate=consumerPrimaryThread,
+ )
+ da = nvgpu.WarpgroupGenerateDescriptorOp(
+ a_wgmma_ty, a_tile_slice, a_tma_desc
+ )
+ db = nvgpu.WarpgroupGenerateDescriptorOp(
+ b_wgmma_ty, b_tile_slice, b_tma_desc
+ )
+
+ # Step 6.3.3. MMA
+ carry_acc = for_op.inner_iter_args[0]
+ new_acc = nvgpu.WarpgroupMmaOp(
+ acc.type, da, db, carry_acc, transposeB=True
+ )
+
+ # Step 6.3.4. Arrive mbarDONE
+ if num_stages == 1:
+ p_arrive = consumerPrimaryThread
+ else:
+ p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
+ p_arrive = arith.andi(consumerPrimaryThread, p1)
+ with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
+ barId = arith.select(
+ p, c(num_stages - 1), arith.subi(stage, c(1))
+ )
+ debug_print(
+ "[cons] iv={} | mbarDONE[{}] arrive ",
+ iv,
+ barId,
+ predicate=consumerPrimaryThread,
+ )
+ nvgpu.mbarrier_arrive(
+ ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+ )
+ debug_print(
+ "[cons] iv={} | mbarDONE[{}] arrive [done]",
+ iv,
+ barId,
+ predicate=consumerPrimaryThread,
+ )
+ scf.yield_([])
+
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+ phaseParity = arith.select(
+ p,
+ arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+ phaseParity,
+ )
+
+ # Step 6.3.5. Yield
+ scf.yield_([new_acc, phaseParity])
+
+ # Step 6.3. Wait All WGMMA
+ nvvm.WgmmaWaitGroupSyncOp(0)
+
+ with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
+ barId = c((K // BLOCK_K) % num_stages)
+ nvgpu.mbarrier_arrive(
+ ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+ )
+ scf.yield_([])
+
+ # Step 6.4. Epilogue (registers --> shared memory)
+ acc_smem_ty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+ )
+ acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+ debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
+ nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+ scf.yield_([])
+ gpu.barrier()
+
+ # GPU Step 9. Epilogue (shared memory --> global memory)
+ fd = ir.MemRefType.get(
+ [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+ )
+ collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+ rty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N),
+ c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+ )
+ c_device_per_block = memref.SubViewOp(
+ rty,
+ c_device,
+ [dimX, dimY],
+ [],
+ [],
+ [MLIR_DYNAMIC, MLIR_DYNAMIC],
+ [BLOCK_M, BLOCK_N],
+ [1, 1],
+ )
+ vlen = 1
+ for_op = scf.ForOp(
+ tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
+ )
+ with ir.InsertionPoint(for_op.body):
+ x = arith.divui(for_op.induction_variable, c(BLOCK_M))
+ y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+ vdata = vector.load(
+ ir.VectorType.get((vlen,), c_elem_ty),
+ collapsed_smem,
+ [for_op.induction_variable],
+ )
+ vector.store(vdata, c_device_per_block, [x, y])
+ scf.yield_([])
+
+ gpu.terminator()
+
+ # Step 4. Copy back to host
+ t8 = gpu.wait(token_ty, [launch_op])
+ t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+ gpu.dealloc(token_ty, [t8], a_device)
+ gpu.dealloc(token_ty, [t8], b_device)
+ gpu.wait(token_ty, [t9])
+ gpu.dealloc(token_ty, [t8], c_device)
+ func.ReturnOp([])
+
+ fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ module.operation.verify()
+ return module
+
+
+def generate_matmul_multistage(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ num_stages=3,
+):
+ # Limitations for now
+ assert input_type == np.float16
+ assert output_type == np.float32
+ assert BLOCK_M == 128
+ assert BLOCK_N == 128
+ assert BLOCK_K == 64
+ assert M % BLOCK_M == 0
+ assert N % BLOCK_N == 0
+ assert K % BLOCK_K == 0
+
+ module = ir.Module.create()
+ token_ty = ir.Type.parse("!gpu.async.token")
+ a_elem_ty = get_mlir_ty(input_type)
+ b_elem_ty = get_mlir_ty(input_type)
+ c_elem_ty = get_mlir_ty(output_type)
+ a_ty = ir.MemRefType.get([M, K], a_elem_ty)
+ b_ty = ir.MemRefType.get((K, N), b_elem_ty)
+ c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+ a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+ b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+ b_tile_shape = (BLOCK_K, BLOCK_N)
+ txcount = (a_tile_shape[0] * a_tile_shape[1] * get_type_size(a_elem_ty)) + (
+ b_tile_shape[0] * b_tile_shape[1] * get_type_size(b_elem_ty)
+ )
+ smem_space_str = "#gpu.address_space<workgroup>"
+ smem_space = ir.Attribute.parse(smem_space_str)
+ mbar_ty = ir.Type.parse(
+ "!nvgpu.mbarrier.group<memorySpace = "
+ + str(smem_space)
+ + ", num_barriers = "
+ + str(num_stages)
+ + ">"
+ )
+ a_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(a_elem_ty)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ b_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(b_elem_ty)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ acc_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(c_elem_ty)
+ + ">>"
+ )
+ a_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_K)
+ + "x"
+ + str(a_elem_ty)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
+ b_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(a_elem_ty)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
+
+ with ir.InsertionPoint(module.body):
+ kernelName = make_kernel_name(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ num_stages,
+ False,
+ )
+ fop = func.FuncOp(kernelName, ([a_ty, b_ty, c_ty], []))
+ with ir.InsertionPoint(fop.add_entry_block()):
+ a_host = fop.arguments[0]
+ b_host = fop.arguments[1]
+ c_host = fop.arguments[2]
+ lhs_tile_bytes = BLOCK_M * BLOCK_K * get_type_size(a_elem_ty)
+ rhs_tile_bytes = BLOCK_N * BLOCK_K * get_type_size(b_elem_ty)
+ smem_size_input = (lhs_tile_bytes + rhs_tile_bytes) * num_stages
+ smem_size_output = BLOCK_M * BLOCK_N * get_type_size(c_elem_ty)
+ smem_size = max(smem_size_input, smem_size_output)
+
+ # Step 1. Allocate device memory and memcpy
+ t1 = gpu.wait(token_ty, [])
+ a_device, t2 = gpu.alloc(a_ty, token_ty, [t1], [], [])
+ b_device, t3 = gpu.alloc(b_ty, token_ty, [t2], [], [])
+ c_device, t4 = gpu.alloc(c_ty, token_ty, [t3], [], [])
+ t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
+ t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
+ t7 = gpu.wait(token_ty, [t6])
+
+ # Step 2. Create TMA Descriptors
+ tma_specs = [
+ (a_device, a_tma_desc_ty, a_tma_shape),
+ (b_device, b_tma_desc_ty, b_tma_shape),
+ ]
+ tma_descs = []
+ for x_device, tensor_map_ty, tile_shape in tma_specs:
+ x_unranked = memref.cast(
+ ir.UnrankedMemRefType.get(a_elem_ty, a_ty.memory_space), x_device
+ )
+ tma_descs.append(
+ nvgpu.TmaCreateDescriptorOp(
+ tensor_map_ty, x_unranked, map(c, tile_shape)
+ ).result
+ )
+ a_tma_desc, b_tma_desc = tma_descs
+
+ # Step 3. Launch Kernel with 1 Warpgroup
+ cta_m = M // BLOCK_M
+ cta_n = N // BLOCK_N
+ assert M % BLOCK_M == 0 and N % BLOCK_N == 0
+ grid = (cta_m, cta_n, 1)
+ block = (WARP_GROUP_SIZE, 1, 1)
+ launch_op = gpu.LaunchOp(
+ token_ty,
+ [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=T.i32())
+ )
+ launch_op.body.blocks.append(*([T.index()] * 12))
+ with ir.InsertionPoint(launch_op.body.blocks[0]):
+ # GPU Step 0. Bootstrapping
+ memref.assume_alignment(c_device, 16)
+ dynamic_smem = gpu.dynamic_shared_memory(
+ ir.MemRefType.get((MLIR_DYNAMIC,), T.i8(), memory_space=smem_space)
+ )
+ ticks = c(10000000)
+ tidx = gpu.thread_id(gpu.Dimension.x)
+ primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
+ warpId = arith.divui(tidx, c(32))
+ bidx = gpu.block_id(gpu.Dimension.x)
+ bidy = gpu.block_id(gpu.Dimension.y)
+ dimX = arith.muli(bidx, c(BLOCK_M))
+ dimY = arith.muli(bidy, c(BLOCK_N))
+
+ # GPU Step 1. Initialize mbarrier groups
+ mbarTMA = nvgpu.mbarrier_create(mbar_ty)
+ for i in range(num_stages):
+ nvgpu.mbarrier_init(mbarTMA, c(1), c(i), predicate=primaryThread)
+ gpu.barrier()
+
+ # GPU Step 2. Prefetch TMA descriptors
+ nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
+ nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
+
+ # GPU Step 3. Prologue (global memory --> shared memory)
+ ns = num_stages if num_stages == 1 else num_stages - 1
+ for_op = scf.ForOp(c(0), c(ns), c(1))
+ with ir.InsertionPoint(for_op.body):
+ iv = for_op.induction_variable
+
+ # Step 3.1. Calculate offsets
+ a_offset = arith.muli(iv, c(lhs_tile_bytes))
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(
+ a_tma_shape, a_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(iv, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ b_offset2 = arith.addi(
+ b_offset,
+ c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+ )
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
+
+ # Step 3.2. TMA Load
+ coord = arith.muli(c(64), iv)
+ dimY2 = arith.addi(dimY, c(64))
+ debug_print(
+ "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ a_offset,
+ b_offset,
+ b_offset2,
+ coord,
+ dimX,
+ dimY,
+ coord,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
+
+ # Step 3.3. mbarTMA arrive
+ debug_print(
+ "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), iv, predicate=primaryThread
+ )
+ debug_print(
+ "[Prologue] mbarTMA[{}] arrive [done]",
+ iv,
+ predicate=primaryThread,
+ )
+ scf.yield_([])
+
+ # GPU Step 4. Main Loop
+ acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(T.bool(), 0)]
+ )
+ with ir.InsertionPoint(for_op.body):
+ # Step 4.1. Wait mbarTMA
+ phaseParity = for_op.inner_iter_args[1]
+ iv = for_op.induction_variable
+ stage = arith.remui(iv, c(num_stages))
+ debug_print(
+ "[MainLoop] mbarTMA[{}] try_wait phase={}",
+ stage,
+ phaseParity,
+ predicate=primaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarTMA, phaseParity, ticks, mbarId=stage
+ )
+ debug_print(
+ "[MainLoop] mbarTMA[{}] try_wait phase={} [done]",
+ stage,
+ phaseParity,
+ predicate=primaryThread,
+ )
+
+ # Step 4.2. Create WGMMA Descriptors
+ a_offset = arith.muli(stage, c(lhs_tile_bytes))
+ a_tile_slice = memref.view(
+ ir.MemRefType.get(
+ a_tile_shape, a_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tile_slice = memref.view(
+ ir.MemRefType.get(
+ b_tile_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ debug_print(
+ "[MainLoop] iv={} MMA a_offset={} b_offset={}",
+ iv,
+ a_offset,
+ b_offset,
+ predicate=primaryThread,
+ )
+ da = nvgpu.WarpgroupGenerateDescriptorOp(
+ a_wgmma_ty, a_tile_slice, a_tma_desc
+ )
+ db = nvgpu.WarpgroupGenerateDescriptorOp(
+ b_wgmma_ty, b_tile_slice, b_tma_desc
+ )
+
+ # Step 4.3. MMA
+ carry_acc = for_op.inner_iter_args[0]
+ new_acc = nvgpu.WarpgroupMmaOp(
+ acc.type, da, db, carry_acc, transposeB=True
+ )
+ if num_stages == 1:
+ nvvm.WgmmaWaitGroupSyncOp(0)
+
+ # Step 4.4. Load TMA for next stage
+ p1 = arith.cmpi(
+ arith.CmpIPredicate.ult,
+ arith.addi(iv, c(ns)),
+ c(K // BLOCK_K),
+ )
+ p = arith.andi(primaryThread, p1)
+ nextStage = arith.addi(iv, c(ns))
+ nextSlot = arith.remui(nextStage, c(num_stages))
+ a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
+
+ debug_print(
+ "[MainLoop] mbarTMA[{}] arrive",
+ nextSlot,
+ predicate=p,
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), nextSlot, predicate=p
+ )
+ debug_print(
+ "[MainLoop] mbarTMA[{}] arrive [done]",
+ nextSlot,
+ predicate=p,
+ )
+
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(
+ a_tma_shape, a_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(nextSlot, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ b_offset2 = arith.addi(
+ b_offset,
+ c(BLOCK_K * TMA_LAST_DIM_F16 * get_type_size(b_elem_ty)),
+ )
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, b_elem_ty, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
+
+ coord = arith.muli(c(64), nextStage)
+ debug_print(
+ "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ iv,
+ a_offset,
+ b_offset,
+ b_offset2,
+ coord,
+ dimX,
+ dimY,
+ coord,
+ predicate=p,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=nextSlot,
+ predicate=p,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=nextSlot,
+ predicate=p,
+ )
+ dimY2 = arith.addi(dimY, c(64))
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=nextSlot,
+ predicate=p,
+ )
+ # Step 4.5. Change the phaseParity
+ p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
+ phaseParity = arith.select(
+ p,
+ arith.xori(phaseParity, arith.constant(T.bool(), 1)),
+ phaseParity,
+ )
+
+ # Step 4.6. Yield
+ scf.yield_([new_acc, phaseParity])
+
+ # Step 5. Wait All WGMMA groups
+ nvvm.WgmmaWaitGroupSyncOp(0)
+
+ # Step 6. Epilogue (registers --> shared memory)
+ acc_smem_ty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+ )
+ acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
+ debug_print("Storing", predicate=primaryThread)
+ nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
+ gpu.barrier()
+
+ # GPU Step 7. Epilogue (shared memory --> global memory)
+ fd = ir.MemRefType.get(
+ [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+ )
+ collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
+ rty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N),
+ c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+ )
+ c_device_per_block = memref.SubViewOp(
+ rty,
+ c_device,
+ [dimX, dimY],
+ [],
+ [],
+ [MLIR_DYNAMIC, MLIR_DYNAMIC],
+ [BLOCK_M, BLOCK_N],
+ [1, 1],
+ )
+ vlen = 1
+ for_op = scf.ForOp(
+ tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
+ )
+ with ir.InsertionPoint(for_op.body):
+ x = arith.divui(for_op.induction_variable, c(BLOCK_M))
+ y = arith.remui(for_op.induction_variable, c(BLOCK_N))
+ vdata = vector.load(
+ ir.VectorType.get((vlen,), c_elem_ty),
+ collapsed_smem,
+ [for_op.induction_variable],
+ )
+ vector.store(vdata, c_device_per_block, [x, y])
+ scf.yield_([])
+
+ gpu.terminator()
+
+ # Step 4. Copy back to host
+ t8 = gpu.wait(token_ty, [launch_op])
+ t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
+ gpu.dealloc(token_ty, [t8], a_device)
+ gpu.dealloc(token_ty, [t8], b_device)
+ gpu.wait(token_ty, [t9])
+ gpu.dealloc(token_ty, [t8], c_device)
+ func.ReturnOp([])
+
+ fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ module.operation.verify()
+ return module
diff --git a/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
new file mode 100644
index 00000000000000..1c9cc74fcd169c
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
@@ -0,0 +1,45 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file contains the NvgpuCompiler class.
+
+from mlir import execution_engine
+from mlir import ir
+from mlir import passmanager
+from typing import Sequence
+import errno
+import os
+import sys
+
+_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(_SCRIPT_PATH)
+
+
+class NvgpuCompiler:
+ """Nvgpu class for compiling and building MLIR modules."""
+
+ def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
+ pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
+ self.pipeline = pipeline
+ self.shared_libs = shared_libs
+ self.opt_level = opt_level
+
+ def __call__(self, module: ir.Module):
+ """Convenience application method."""
+ self.compile(module)
+
+ def compile(self, module: ir.Module):
+ """Compiles the module by invoking the nvgpu pipeline."""
+ passmanager.PassManager.parse(self.pipeline).run(module.operation)
+
+ def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Wraps the module in a JIT execution engine."""
+ return execution_engine.ExecutionEngine(
+ module, opt_level=self.opt_level, shared_libs=self.shared_libs
+ )
+
+ def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
+ """Compiles and jits the module."""
+ self.compile(module)
+ return self.jit(module)
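+
+
+# Example usage (mirrors matmul.py in this commit; SUPPORT_LIB is expected to
+# point at the MLIR CUDA runtime library):
+#
+#   compiler = NvgpuCompiler(
+#       "cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3",
+#       opt_level=3,
+#       shared_libs=[os.getenv("SUPPORT_LIB")],
+#   )
+#   engine = compiler.compile_and_jit(module)
+#   engine.invoke(kernel_name, *memref_pointers)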