[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)
llvmlistbot at llvm.org
Mon Feb 12 05:36:55 PST 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {darker}-->
:warning: The Python code formatter, darker, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
darker --check --diff -r 7bc079c85219ad6e954fb6071cd108151203c85e...a9b4035aa421506302e74c6f6ff19a0aec62b735 mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
``````````
</details>
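If the suggested formatting looks right, one common way to apply it locally is to re-run the same command without `--check --diff`, which lets darker rewrite the files in place (a hedged sketch; it assumes darker's default in-place mode and reuses the revision range and file list from the command above):

``````````bash
# Apply darker's suggested formatting in place (darker rewrites files when
# --check/--diff are omitted); revision range and paths copied from the check above.
darker -r 7bc079c85219ad6e954fb6071cd108151203c85e...a9b4035aa421506302e74c6f6ff19a0aec62b735 \
  mlir/test/Integration/GPU/CUDA/sm90/python/matmul.py \
  mlir/test/Integration/GPU/CUDA/sm90/python/tools/matmulBuilder.py \
  mlir/test/Integration/GPU/CUDA/sm90/python/tools/nvgpucompiler.py
``````````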
<details>
<summary>
View the diff from darker here.
</summary>
``````````diff
--- matmul.py 2024-02-12 13:33:26.000000 +0000
+++ matmul.py 2024-02-12 13:36:46.432817 +0000
@@ -83,91 +83,155 @@
import sys
import pathlib
import ctypes
from mlir import runtime as rt
-def generate_matmul(input_type=np.float16,
- output_type=np.float32,
- M=4096,
- N=4096,
- K=4096,
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=64,
- use_warp_specilization=True,
- saveIR=False,
- max_num_stages=3):
+
+def generate_matmul(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3,
+):
with matmulBuilder.ir.Context() as ctx, matmulBuilder.ir.Location.unknown():
if use_warp_specilization:
- mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N,
- BLOCK_K, max_num_stages)
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_ws(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ max_num_stages,
+ )
else:
- mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(input_type, output_type, M, N, K, BLOCK_M,
- BLOCK_N, BLOCK_K, max_num_stages)
+ mlir_nvgpu_module = matmulBuilder.generate_matmul_multistage(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ max_num_stages,
+ )
mlir_nvgpu_module.operation.verify()
-
+
# Save generated IR
if saveIR:
# print(mlir_nvgpu_module)
original_stdout = sys.stdout
- with open('gemm.mlir', 'w') as f:
+ with open("gemm.mlir", "w") as f:
sys.stdout = f
print(mlir_nvgpu_module)
sys.stdout = original_stdout
# Get compiler
options = f"cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3"
support_lib = os.getenv("SUPPORT_LIB")
if not os.path.exists(support_lib):
- raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), support_lib)
- compiler = nvgpucompiler.NvgpuCompiler(options, opt_level=3, shared_libs=[support_lib])
+ raise FileNotFoundError(
+ errno.ENOENT, os.strerror(errno.ENOENT), support_lib
+ )
+ compiler = nvgpucompiler.NvgpuCompiler(
+ options, opt_level=3, shared_libs=[support_lib]
+ )
# Compile
engine = compiler.compile_and_jit(mlir_nvgpu_module)
return engine
-def matmul(input_type=np.float16,
- output_type=np.float32,
- M=128,
- N=128,
- K=128,
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=64,
- use_warp_specilization=True,
- saveIR=False,
- max_num_stages=3,
- print_results=False,
- no_verify=False):
+def matmul(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=128,
+ N=128,
+ K=128,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ use_warp_specilization=True,
+ saveIR=False,
+ max_num_stages=3,
+ print_results=False,
+ no_verify=False,
+):
# Print the configuration
ity = "f16" if input_type == np.float16 else "f32"
oty = "f16" if output_type == np.float16 else "f32"
gemmty = "Warp Specilization" if use_warp_specilization else "Multistage"
- print("===-- Running GEMM " + gemmty + " " + oty + " += " + ity + " * " + ity + ", Size " + str(M) + "x" + str(N) +
- "x" + str(K) + ", Tile " + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" + str(BLOCK_K) + ", stages " +
- str(max_num_stages) + " --===")
+ print(
+ "===-- Running GEMM "
+ + gemmty
+ + " "
+ + oty
+ + " += "
+ + ity
+ + " * "
+ + ity
+ + ", Size "
+ + str(M)
+ + "x"
+ + str(N)
+ + "x"
+ + str(K)
+ + ", Tile "
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(BLOCK_K)
+ + ", stages "
+ + str(max_num_stages)
+ + " --==="
+ )
# Build IR and compile
- engine = generate_matmul(input_type, output_type, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, use_warp_specilization,
- saveIR, max_num_stages)
+ engine = generate_matmul(
+ input_type,
+ output_type,
+ M,
+ N,
+ K,
+ BLOCK_M,
+ BLOCK_N,
+ BLOCK_K,
+ use_warp_specilization,
+ saveIR,
+ max_num_stages,
+ )
# Allocate matrices and invoke the matmul
c = np.zeros((M, N), output_type)
a = np.random.randn(M, K).astype(input_type)
b = np.random.randn(K, N).astype(input_type)
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
- kernelName = "mlir_matmul_warpspecialized" if use_warp_specilization else "mlir_matmul_multistage"
+ kernelName = (
+ "mlir_matmul_warpspecialized"
+ if use_warp_specilization
+ else "mlir_matmul_multistage"
+ )
# Launch the MLIR generated kernel
engine.invoke(kernelName, mem_a, mem_b, mem_c)
float_formatter = "{:.2f}".format
- np.set_printoptions(formatter={'float_kind': float_formatter})
+ np.set_printoptions(formatter={"float_kind": float_formatter})
if print_results:
print(c)
# Verify the results
@@ -179,8 +243,24 @@
print("PASS ")
# GEMM Multistage f32 += f16 * f16
-matmul(np.float16, np.float32, 128, 128, 4096, max_num_stages=3, use_warp_specilization=False)
+matmul(
+ np.float16,
+ np.float32,
+ 128,
+ 128,
+ 4096,
+ max_num_stages=3,
+ use_warp_specilization=False,
+)
# GEMM Warp Specilized f32 += f16 * f16
-matmul(np.float16, np.float32, 256, 1024, 512, max_num_stages=3, use_warp_specilization=True)
+matmul(
+ np.float16,
+ np.float32,
+ 256,
+ 1024,
+ 512,
+ max_num_stages=3,
+ use_warp_specilization=True,
+)
--- tools/matmulBuilder.py 2024-02-12 13:33:26.000000 +0000
+++ tools/matmulBuilder.py 2024-02-12 13:36:46.754115 +0000
@@ -63,19 +63,21 @@
def c(value, ty=None):
ty = ir.IndexType.get() if ty is None else ty
return arith.constant(ty, value)
-def generate_matmul_ws(input_type=np.float16,
- output_type=np.float32,
- M=4096,
- N=4096,
- K=4096,
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=128,
- max_num_stages=3):
+def generate_matmul_ws(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=128,
+ max_num_stages=3,
+):
# Limitaitons for now
assert input_type == np.float16
assert output_type == np.float32
assert M % BLOCK_M == 0
assert N % BLOCK_N == 0
@@ -97,29 +99,77 @@
c_elem_ty = f16 if output_type == np.float16 else f32
c_ty = ir.MemRefType.get((M, N), c_elem_ty)
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
- txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ txcount = (
+ (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+ ) * f16_byte
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
input_type_str = "f16" if input_type == np.float16 else "f32"
output_type_str = "f16" if output_type == np.float16 else "f32"
- mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
- str(num_stages) + ">")
- a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
- str(output_type_str) + ">>")
- a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
- str(input_type_str) + ", " + smem_space_str + ">>")
- b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
- str(input_type_str) + ", " + smem_space_str + ">>")
+ mbar_ty = ir.Type.parse(
+ "!nvgpu.mbarrier.group<memorySpace = "
+ + str(smem_space)
+ + ", num_barriers = "
+ + str(num_stages)
+ + ">"
+ )
+ a_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ b_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ acc_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(output_type_str)
+ + ">>"
+ )
+ a_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_K)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
+ b_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
with ir.InsertionPoint(module.body):
@func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
def mlir_matmul_warpspecialized(a_host, b_host, c_host):
@@ -137,46 +187,71 @@
t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
t7 = gpu.wait(token_ty, [t6])
# Step 2. Create TMA Descriptors
- tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_specs = [
+ (a_device, a_tma_desc_ty, a_tma_shape),
+ (b_device, b_tma_desc_ty, b_tma_shape),
+ ]
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
- x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
- tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+ x_unranked = memref.cast(
+ ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+ )
+ tma_descs.append(
+ nvgpu.TmaCreateDescriptorOp(
+ tensor_map_ty, x_unranked, map(c, tile_shape)
+ ).result
+ )
a_tma_desc, b_tma_desc = tma_descs
-
+
# Step 3. Launch Kernel with 2 Warpgroups : 1 Producer, 1 Consumer
cta_m = M // BLOCK_M
cta_n = N // BLOCK_N
assert M % BLOCK_M == 0 and N % BLOCK_N == 0
grid = (cta_m, cta_n, 1)
block = (WARP_GROUP_SIZE * 2, 1, 1)
- launch_op = gpu.LaunchOp(token_ty, [t7],
- *map(c, grid),
- *map(c, block),
- dynamicSharedMemorySize=c(smem_size, ty=i32))
+ launch_op = gpu.LaunchOp(
+ token_ty,
+ [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=i32)
+ )
launch_op.body.blocks.append(*([index] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
# GPU Step 0. This is need for vectorized ld/st
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
- ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+ ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+ )
ticks = c(10000000)
# GPU Step 1. Bootstrapping: find the primary thread, warps, warp groups and etc.
tidx = gpu.thread_id(gpu.Dimension.x)
- wgPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0))
+ wgPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq, arith.remui(tidx, c(WARP_GROUP_SIZE)), c(0)
+ )
warp_id = arith.divui(tidx, c(32))
warpgroup_id = arith.divui(warp_id, c(4))
- is_producer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
- c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0))
- is_consumer = arith.cmpi(arith.CmpIPredicate.eq, warpgroup_id,
- c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1))
- producerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD))
- consumerPrimaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD))
+ is_producer = arith.cmpi(
+ arith.CmpIPredicate.eq,
+ warpgroup_id,
+ c(1 if PRODUCER_PRIMARY_THREAD == 128 else 0),
+ )
+ is_consumer = arith.cmpi(
+ arith.CmpIPredicate.eq,
+ warpgroup_id,
+ c(0 if CONSUMER_PRIMARY_THREAD == 0 else 1),
+ )
+ producerPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq, tidx, c(PRODUCER_PRIMARY_THREAD)
+ )
+ consumerPrimaryThread = arith.cmpi(
+ arith.CmpIPredicate.eq, tidx, c(CONSUMER_PRIMARY_THREAD)
+ )
bidx = gpu.block_id(gpu.Dimension.x)
bidy = gpu.block_id(gpu.Dimension.y)
dimX = arith.muli(bidx, c(BLOCK_M))
dimY = arith.muli(bidy, c(BLOCK_N))
@@ -189,215 +264,338 @@
gpu.barrier()
# GPU Step 3. Prefetch TMA descriptors
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=wgPrimaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=wgPrimaryThread)
-
+
# GPU Step 5. Producer Warpgroup (TMA Warpgroup)
with ir.InsertionPoint(scf.IfOp(is_producer).then_block):
-
# Step 5.1. Reduce register size
- nvvm.setmaxregister(PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease)
+ nvvm.setmaxregister(
+ PRODUCER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.decrease
+ )
# Step 5.2. TMA Main Loop
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)])
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [arith.constant(i1, 1)]
+ )
with ir.InsertionPoint(for_op.body):
phaseParity = for_op.inner_iter_args[0]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
# Step 5.2.1. Wait mbarDONE
- debug_print("[prod] {} | mbarDONE[{}] try_wait phase={}",
- iv,
- stage,
- phaseParity,
- predicate=producerPrimaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarDONE, phaseParity, ticks, mbarId=stage)
- debug_print("[prod] {} | mbarDONE[{}] try_wait phase={} [done]",
- iv,
- stage,
- phaseParity,
- predicate=producerPrimaryThread)
+ debug_print(
+ "[prod] {} | mbarDONE[{}] try_wait phase={}",
+ iv,
+ stage,
+ phaseParity,
+ predicate=producerPrimaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarDONE, phaseParity, ticks, mbarId=stage
+ )
+ debug_print(
+ "[prod] {} | mbarDONE[{}] try_wait phase={} [done]",
+ iv,
+ stage,
+ phaseParity,
+ predicate=producerPrimaryThread,
+ )
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
- phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+ phaseParity = arith.select(
+ p,
+ arith.xori(phaseParity, arith.constant(i1, 1)),
+ phaseParity,
+ )
# Step 5.2.2. Load TMA
a_offset = arith.muli(stage, c(lhs_tile_bytes))
- a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
- b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset2, [])
- debug_print("[prod] a_offset={} b_offset={} b_offset2={}",
- a_offset,
- b_offset,
- b_offset2,
- predicate=producerPrimaryThread)
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(
+ a_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ b_offset2 = arith.addi(
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ )
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
+ debug_print(
+ "[prod] a_offset={} b_offset={} b_offset2={}",
+ a_offset,
+ b_offset,
+ b_offset2,
+ predicate=producerPrimaryThread,
+ )
coord = arith.muli(c(64), iv)
- nvgpu.TmaAsyncLoadOp(a_tma_slice,
- mbarTMA,
- a_tma_desc,
- coordinates=[coord, dimX],
- mbarId=stage,
- predicate=producerPrimaryThread)
- nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY, coord],
- mbarId=stage,
- predicate=producerPrimaryThread)
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
dimY2 = arith.addi(dimY, c(64))
- nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY2, coord],
- mbarId=stage,
- predicate=producerPrimaryThread)
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=stage,
+ predicate=producerPrimaryThread,
+ )
# Step 5.2.3. Arrive mbarTMA
- debug_print("[prod] {} | mbarTMA[{}] arrive", iv, stage, predicate=producerPrimaryThread)
- nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), stage, predicate=producerPrimaryThread)
- debug_print("[prod] {} | mbarTMA[{}] arrive [done]",
- iv,
- stage,
- predicate=producerPrimaryThread)
+ debug_print(
+ "[prod] {} | mbarTMA[{}] arrive",
+ iv,
+ stage,
+ predicate=producerPrimaryThread,
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), stage, predicate=producerPrimaryThread
+ )
+ debug_print(
+ "[prod] {} | mbarTMA[{}] arrive [done]",
+ iv,
+ stage,
+ predicate=producerPrimaryThread,
+ )
scf.yield_([phaseParity])
scf.yield_([])
# GPU Step 6. Consumer Warpgroup (MMA Warpgroup)
if_op = scf.IfOp(is_consumer)
with ir.InsertionPoint(if_op.then_block):
-
# Step 6.1. Increase register size
- nvvm.setmaxregister(CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase)
-
+ nvvm.setmaxregister(
+ CONSUMER_REGISTER_SIZE, nvvm.SetMaxRegisterAction.increase
+ )
+
# GPU Step 6.2. Initialize MMA registers
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
# Step 6.3. MMA Main Loop
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+ )
with ir.InsertionPoint(for_op.body):
# Step 6.3.1. Wait mbar1
phaseParity = for_op.inner_iter_args[1]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
- debug_print("[cons] {} | mbarTMA[{}] try_wait phase={}",
- iv,
- stage,
- phaseParity,
- predicate=consumerPrimaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
- debug_print("[cons] {} | mbarTMA[{}] try_wait phase={} [done]",
- iv,
- stage,
- phaseParity,
- predicate=consumerPrimaryThread)
-
+ debug_print(
+ "[cons] {} | mbarTMA[{}] try_wait phase={}",
+ iv,
+ stage,
+ phaseParity,
+ predicate=consumerPrimaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarTMA, phaseParity, ticks, mbarId=stage
+ )
+ debug_print(
+ "[cons] {} | mbarTMA[{}] try_wait phase={} [done]",
+ iv,
+ stage,
+ phaseParity,
+ predicate=consumerPrimaryThread,
+ )
+
# Step 6.3.2. Create WGMMA Descriptors
a_offset = arith.muli(stage, c(lhs_tile_bytes))
- a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- debug_print("[cons] a_offset={} b_offset={}",
- a_offset,
- b_offset,
- predicate=consumerPrimaryThread)
- da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
- db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+ a_tile_slice = memref.view(
+ ir.MemRefType.get(
+ a_tile_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tile_slice = memref.view(
+ ir.MemRefType.get(
+ b_tile_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ debug_print(
+ "[cons] a_offset={} b_offset={}",
+ a_offset,
+ b_offset,
+ predicate=consumerPrimaryThread,
+ )
+ da = nvgpu.WarpgroupGenerateDescriptorOp(
+ a_wgmma_ty, a_tile_slice, a_tma_desc
+ )
+ db = nvgpu.WarpgroupGenerateDescriptorOp(
+ b_wgmma_ty, b_tile_slice, b_tma_desc
+ )
# Step 6.3.3. MMA
carry_acc = for_op.inner_iter_args[0]
- new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+ new_acc = nvgpu.WarpgroupMmaOp(
+ acc.type, da, db, carry_acc, transposeB=True
+ )
# Step 6.3.4. Arrive mbarDONE
p1 = arith.cmpi(arith.CmpIPredicate.sgt, iv, c(0))
p_arrive = arith.andi(consumerPrimaryThread, p1)
with ir.InsertionPoint(scf.IfOp(p_arrive).then_block):
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(0))
- barId = arith.select(p, c(num_stages - 1), arith.subi(stage, c(1)))
+ barId = arith.select(
+ p, c(num_stages - 1), arith.subi(stage, c(1))
+ )
remoteCtaId = c(0)
pred = consumerPrimaryThread
- debug_print("[cons] {} | mbarDONE[{}] arrive.mapa pred={} ",
- iv,
- barId,
- remoteCtaId,
- pred,
- predicate=consumerPrimaryThread)
- nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
- debug_print("[cons] {} | mbarDONE[{}] arrive pred={} [done]",
- iv,
- barId,
- remoteCtaId,
- pred,
- predicate=consumerPrimaryThread)
+ debug_print(
+ "[cons] {} | mbarDONE[{}] arrive.mapa pred={} ",
+ iv,
+ barId,
+ remoteCtaId,
+ pred,
+ predicate=consumerPrimaryThread,
+ )
+ nvgpu.mbarrier_arrive(
+ ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+ )
+ debug_print(
+ "[cons] {} | mbarDONE[{}] arrive pred={} [done]",
+ iv,
+ barId,
+ remoteCtaId,
+ pred,
+ predicate=consumerPrimaryThread,
+ )
scf.yield_([])
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
- phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+ phaseParity = arith.select(
+ p,
+ arith.xori(phaseParity, arith.constant(i1, 1)),
+ phaseParity,
+ )
# Step 6.3.5. Yield
scf.yield_([new_acc, phaseParity])
# Step 6.3. Wait All WGMMA
nvvm.WgmmaWaitGroupSyncOp(0)
with ir.InsertionPoint(scf.IfOp(consumerPrimaryThread).then_block):
barId = c((K // BLOCK_K) % num_stages)
- nvgpu.mbarrier_arrive(ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId)
+ nvgpu.mbarrier_arrive(
+ ir.Type.parse("!nvgpu.mbarrier.token"), mbarDONE, barId
+ )
scf.yield_([])
# Step 6.4. Epilogue (registers --> shared memory)
- acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+ acc_smem_ty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+ )
acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
debug_print("[cons] | Storing", predicate=consumerPrimaryThread)
nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
scf.yield_([])
gpu.barrier()
# GPU Step 9. Epilogue (shared memory --> global memory)
- fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+ fd = ir.MemRefType.get(
+ [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+ )
collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
- rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
- ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
- c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
- [BLOCK_M, BLOCK_N], [1, 1])
+ rty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N),
+ c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+ )
+ c_device_per_block = memref.SubViewOp(
+ rty,
+ c_device,
+ [dimX, dimY],
+ [],
+ [],
+ [MLIR_DYNAMIC, MLIR_DYNAMIC],
+ [BLOCK_M, BLOCK_N],
+ [1, 1],
+ )
vlen = 1
- for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2))
+ for_op = scf.ForOp(
+ tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE * 2)
+ )
with ir.InsertionPoint(for_op.body):
x = arith.divui(for_op.induction_variable, c(BLOCK_M))
y = arith.remui(for_op.induction_variable, c(BLOCK_N))
- vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
- [for_op.induction_variable])
+ vdata = vector.load(
+ ir.VectorType.get((vlen,), c_elem_ty),
+ collapsed_smem,
+ [for_op.induction_variable],
+ )
vector.store(vdata, c_device_per_block, [x, y])
scf.yield_([])
gpu.terminator()
# Step 4. Copy back to host
t8 = gpu.wait(token_ty, [launch_op])
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
gpu.wait(token_ty, [t9])
- mlir_matmul_warpspecialized.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ mlir_matmul_warpspecialized.func_op.attributes[
+ "llvm.emit_c_interface"
+ ] = ir.UnitAttr.get()
module.operation.verify()
return module
-def generate_matmul_multistage(input_type=np.float16,
- output_type=np.float32,
- M=4096,
- N=4096,
- K=4096,
- BLOCK_M=128,
- BLOCK_N=128,
- BLOCK_K=64,
- max_num_stages=3):
+def generate_matmul_multistage(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=64,
+ max_num_stages=3,
+):
# Limitaitons for now
assert input_type == np.float16
assert output_type == np.float32
assert M % BLOCK_M == 0
assert N % BLOCK_N == 0
@@ -419,29 +617,77 @@
c_elem_ty = f16 if output_type == np.float16 else f32
c_ty = ir.MemRefType.get((M, N), c_elem_ty)
a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
b_tile_shape = (BLOCK_K, BLOCK_N)
- txcount = ((b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])) * f16_byte
+ txcount = (
+ (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+ ) * f16_byte
smem_space_str = "#gpu.address_space<workgroup>"
smem_space = ir.Attribute.parse(smem_space_str)
input_type_str = "f16" if input_type == np.float16 else "f32"
output_type_str = "f16" if output_type == np.float16 else "f32"
- mbar_ty = ir.Type.parse("!nvgpu.mbarrier.group<memorySpace = " + str(smem_space) + ", num_barriers = " +
- str(num_stages) + ">")
- a_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_M) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- b_tma_desc_ty = ir.Type.parse("!nvgpu.tensormap.descriptor<tensor = memref<" + str(BLOCK_K) + "x" +
- str(TMA_LAST_DIM_F16) + "x" + str(input_type_str) + ", " + str(smem_space) +
- ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>")
- acc_ty = ir.Type.parse("!nvgpu.warpgroup.accumulator<fragmented=vector<" + str(BLOCK_M) + "x" + str(BLOCK_N) + "x" +
- str(output_type_str) + ">>")
- a_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_M) + "x" + str(BLOCK_K) + "x" +
- str(input_type_str) + ", " + smem_space_str + ">>")
- b_wgmma_ty = ir.Type.parse("!nvgpu.warpgroup.descriptor<tensor=memref<" + str(BLOCK_K) + "x" + str(BLOCK_N) + "x" +
- str(input_type_str) + ", " + smem_space_str + ">>")
+ mbar_ty = ir.Type.parse(
+ "!nvgpu.mbarrier.group<memorySpace = "
+ + str(smem_space)
+ + ", num_barriers = "
+ + str(num_stages)
+ + ">"
+ )
+ a_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ b_tma_desc_ty = ir.Type.parse(
+ "!nvgpu.tensormap.descriptor<tensor = memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(TMA_LAST_DIM_F16)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + str(smem_space)
+ + ">, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>"
+ )
+ acc_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.accumulator<fragmented=vector<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(output_type_str)
+ + ">>"
+ )
+ a_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_M)
+ + "x"
+ + str(BLOCK_K)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
+ b_wgmma_ty = ir.Type.parse(
+ "!nvgpu.warpgroup.descriptor<tensor=memref<"
+ + str(BLOCK_K)
+ + "x"
+ + str(BLOCK_N)
+ + "x"
+ + str(input_type_str)
+ + ", "
+ + smem_space_str
+ + ">>"
+ )
with ir.InsertionPoint(module.body):
@func.FuncOp.from_py_func(a_ty, b_ty, c_ty)
def mlir_matmul_multistage(a_host, b_host, c_host):
@@ -459,33 +705,46 @@
t5 = gpu.memcpy(token_ty, [t4], a_device, a_host)
t6 = gpu.memcpy(token_ty, [t5], b_device, b_host)
t7 = gpu.wait(token_ty, [t6])
# Step 2. Create TMA Descriptors
- tma_specs = [(a_device, a_tma_desc_ty, a_tma_shape), (b_device, b_tma_desc_ty, b_tma_shape)]
+ tma_specs = [
+ (a_device, a_tma_desc_ty, a_tma_shape),
+ (b_device, b_tma_desc_ty, b_tma_shape),
+ ]
tma_descs = []
for x_device, tensor_map_ty, tile_shape in tma_specs:
- x_unranked = memref.cast(ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device)
- tma_descs.append(nvgpu.TmaCreateDescriptorOp(tensor_map_ty, x_unranked, map(c, tile_shape)).result)
+ x_unranked = memref.cast(
+ ir.UnrankedMemRefType.get(f16, a_ty.memory_space), x_device
+ )
+ tma_descs.append(
+ nvgpu.TmaCreateDescriptorOp(
+ tensor_map_ty, x_unranked, map(c, tile_shape)
+ ).result
+ )
a_tma_desc, b_tma_desc = tma_descs
# Step 3. Launch Kernel with 1 Warpgroup
cta_m = M // BLOCK_M
cta_n = N // BLOCK_N
assert M % BLOCK_M == 0 and N % BLOCK_N == 0
grid = (cta_m, cta_n, 1)
block = (WARP_GROUP_SIZE, 1, 1)
- launch_op = gpu.LaunchOp(token_ty, [t7],
- *map(c, grid),
- *map(c, block),
- dynamicSharedMemorySize=c(smem_size, ty=i32))
+ launch_op = gpu.LaunchOp(
+ token_ty,
+ [t7],
+ *map(c, grid),
+ *map(c, block),
+ dynamicSharedMemorySize=c(smem_size, ty=i32)
+ )
launch_op.body.blocks.append(*([index] * 12))
with ir.InsertionPoint(launch_op.body.blocks[0]):
# GPU Step 0. Bootstrapping
memref.assume_alignment(c_device, 16)
dynamic_smem = gpu.dynamic_shared_memory(
- ir.MemRefType.get((MLIR_DYNAMIC, ), i8, memory_space=smem_space))
+ ir.MemRefType.get((MLIR_DYNAMIC,), i8, memory_space=smem_space)
+ )
ticks = c(10000000)
tidx = gpu.thread_id(gpu.Dimension.x)
primaryThread = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(0))
warpId = arith.divui(tidx, c(32))
bidx = gpu.block_id(gpu.Dimension.x)
@@ -502,175 +761,319 @@
# GPU Step 2. Prefetch TMA descriptors
nvgpu.tma_prefetch_descriptor(a_tma_desc, predicate=primaryThread)
nvgpu.tma_prefetch_descriptor(b_tma_desc, predicate=primaryThread)
# GPU Step 3. Prologue (global memory --> shared memory)
- for_op = scf.ForOp(c(0), c(num_stages-1), c(1))
+ for_op = scf.ForOp(c(0), c(num_stages - 1), c(1))
with ir.InsertionPoint(for_op.body):
iv = for_op.induction_variable
# Step 3.1. Calculate offsets
a_offset = arith.muli(iv, c(lhs_tile_bytes))
- a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(iv, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
- b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset2, [])
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(iv, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ b_offset2 = arith.addi(
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ )
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
# Step 3.2. TMA Load
coord = arith.muli(c(64), iv)
dimY2 = arith.addi(dimY, c(64))
- debug_print("[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
- a_offset,
- b_offset,
- b_offset2,
- coord, dimX,
- dimY, coord,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(a_tma_slice,
- mbarTMA,
- a_tma_desc,
- coordinates=[coord, dimX],
- mbarId=iv,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY, coord],
- mbarId=iv,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY2, coord],
- mbarId=iv,
- predicate=primaryThread)
+ debug_print(
+ "[Prologue] TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ a_offset,
+ b_offset,
+ b_offset2,
+ coord,
+ dimX,
+ dimY,
+ coord,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=iv,
+ predicate=primaryThread,
+ )
# Step 3.2. mbarTMA arrive
- debug_print("[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread)
- nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), iv, predicate=primaryThread)
- debug_print("[Prologue] mbarTMA[{}] arrive [done]", iv, predicate=primaryThread)
+ debug_print(
+ "[Prologue] mbarTMA[{}] arrive", iv, predicate=primaryThread
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), iv, predicate=primaryThread
+ )
+ debug_print(
+ "[Prologue] mbarTMA[{}] arrive [done]",
+ iv,
+ predicate=primaryThread,
+ )
scf.yield_([])
# GPU Step 4. Main Loop
acc = nvgpu.warpgroup_mma_init_accumulator(acc_ty)
- for_op = scf.ForOp(c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)])
+ for_op = scf.ForOp(
+ c(0), c(K // BLOCK_K), c(1), [acc, arith.constant(i1, 0)]
+ )
with ir.InsertionPoint(for_op.body):
# Step 4.1. Wait mbarTMA
phaseParity = for_op.inner_iter_args[1]
iv = for_op.induction_variable
stage = arith.remui(iv, c(num_stages))
- debug_print("[MainLoop] mbarTMA[{}] try_wait phase={}", stage, phaseParity, predicate=primaryThread)
- nvgpu.MBarrierTryWaitParityOp(mbarTMA, phaseParity, ticks, mbarId=stage)
- debug_print("[MainLoop] mbarTMA[{}] try_wait phase={} [done]", stage, phaseParity, predicate=primaryThread)
+ debug_print(
+ "[MainLoop] mbarTMA[{}] try_wait phase={}",
+ stage,
+ phaseParity,
+ predicate=primaryThread,
+ )
+ nvgpu.MBarrierTryWaitParityOp(
+ mbarTMA, phaseParity, ticks, mbarId=stage
+ )
+ debug_print(
+ "[MainLoop] mbarTMA[{}] try_wait phase={} [done]",
+ stage,
+ phaseParity,
+ predicate=primaryThread,
+ )
# Step 4.2. Create WGMMA Descriptors
a_offset = arith.muli(stage, c(lhs_tile_bytes))
- a_tile_slice = memref.view(ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(stage, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tile_slice = memref.view(ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- debug_print("[MainLoop] iv={} MMA a_offset={} b_offset={}", iv, a_offset, b_offset, predicate=primaryThread)
- da = nvgpu.WarpgroupGenerateDescriptorOp(a_wgmma_ty, a_tile_slice, a_tma_desc)
- db = nvgpu.WarpgroupGenerateDescriptorOp(b_wgmma_ty, b_tile_slice, b_tma_desc)
+ a_tile_slice = memref.view(
+ ir.MemRefType.get(a_tile_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(stage, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tile_slice = memref.view(
+ ir.MemRefType.get(b_tile_shape, f16, memory_space=smem_space),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ debug_print(
+ "[MainLoop] iv={} MMA a_offset={} b_offset={}",
+ iv,
+ a_offset,
+ b_offset,
+ predicate=primaryThread,
+ )
+ da = nvgpu.WarpgroupGenerateDescriptorOp(
+ a_wgmma_ty, a_tile_slice, a_tma_desc
+ )
+ db = nvgpu.WarpgroupGenerateDescriptorOp(
+ b_wgmma_ty, b_tile_slice, b_tma_desc
+ )
# Step 4.3. MMA
carry_acc = for_op.inner_iter_args[0]
- new_acc = nvgpu.WarpgroupMmaOp(acc.type, da, db, carry_acc, transposeB=True)
+ new_acc = nvgpu.WarpgroupMmaOp(
+ acc.type, da, db, carry_acc, transposeB=True
+ )
# Step 4.4. Load TMA for next stage
- p1 = arith.cmpi(arith.CmpIPredicate.ult, arith.addi(iv, c(num_stages-1)), c(K // BLOCK_K))
+ p1 = arith.cmpi(
+ arith.CmpIPredicate.ult,
+ arith.addi(iv, c(num_stages - 1)),
+ c(K // BLOCK_K),
+ )
p = arith.andi(primaryThread, p1)
with ir.InsertionPoint(scf.IfOp(p).then_block):
- nextStage = arith.addi(iv, c(num_stages-1))
+ nextStage = arith.addi(iv, c(num_stages - 1))
nextSlot = arith.remui(nextStage, c(num_stages))
a_offset = arith.muli(nextSlot, c(lhs_tile_bytes))
- a_tma_slice = memref.view(ir.MemRefType.get(a_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, a_offset, [])
- b_offset = arith.addi(arith.muli(nextSlot, c(rhs_tile_bytes)), c(lhs_tile_bytes * num_stages))
- b_tma_slice_1 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset, [])
- b_offset2 = arith.addi(b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte))
- b_tma_slice_2 = memref.view(ir.MemRefType.get(b_tma_shape, f16, memory_space=smem_space),
- dynamic_smem, b_offset2, [])
+ a_tma_slice = memref.view(
+ ir.MemRefType.get(
+ a_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ a_offset,
+ [],
+ )
+ b_offset = arith.addi(
+ arith.muli(nextSlot, c(rhs_tile_bytes)),
+ c(lhs_tile_bytes * num_stages),
+ )
+ b_tma_slice_1 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset,
+ [],
+ )
+ b_offset2 = arith.addi(
+ b_offset, c(BLOCK_K * TMA_LAST_DIM_F16 * f16_byte)
+ )
+ b_tma_slice_2 = memref.view(
+ ir.MemRefType.get(
+ b_tma_shape, f16, memory_space=smem_space
+ ),
+ dynamic_smem,
+ b_offset2,
+ [],
+ )
coord = arith.muli(c(64), nextStage)
- debug_print("[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
- iv, a_offset,
- b_offset,
- b_offset2,
- coord, dimX,
- dimY, coord,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(a_tma_slice,
- mbarTMA,
- a_tma_desc,
- coordinates=[coord, dimX],
- mbarId=nextSlot,
- predicate=primaryThread)
- nvgpu.TmaAsyncLoadOp(b_tma_slice_1,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY, coord],
- mbarId=nextSlot,
- predicate=primaryThread)
+ debug_print(
+ "[MainLoop] iv={} TMA Load a_offset={} b_offset={} b_offset2={} @ a=({},{}) b=({},{})",
+ iv,
+ a_offset,
+ b_offset,
+ b_offset2,
+ coord,
+ dimX,
+ dimY,
+ coord,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ a_tma_slice,
+ mbarTMA,
+ a_tma_desc,
+ coordinates=[coord, dimX],
+ mbarId=nextSlot,
+ predicate=primaryThread,
+ )
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_1,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY, coord],
+ mbarId=nextSlot,
+ predicate=primaryThread,
+ )
dimY2 = arith.addi(dimY, c(64))
- nvgpu.TmaAsyncLoadOp(b_tma_slice_2,
- mbarTMA,
- b_tma_desc,
- coordinates=[dimY2, coord],
- mbarId=nextSlot,
- predicate=primaryThread)
-
- debug_print("[MainLoop] mbarTMA[{}] arrive", nextSlot, predicate=primaryThread)
- nvgpu.mbarrier_arrive_expect_tx(mbarTMA, c(txcount), nextSlot, predicate=primaryThread)
- debug_print("[MainLoop] mbarTMA[{}] arrive [done]", nextSlot, predicate=primaryThread)
+ nvgpu.TmaAsyncLoadOp(
+ b_tma_slice_2,
+ mbarTMA,
+ b_tma_desc,
+ coordinates=[dimY2, coord],
+ mbarId=nextSlot,
+ predicate=primaryThread,
+ )
+
+ debug_print(
+ "[MainLoop] mbarTMA[{}] arrive",
+ nextSlot,
+ predicate=primaryThread,
+ )
+ nvgpu.mbarrier_arrive_expect_tx(
+ mbarTMA, c(txcount), nextSlot, predicate=primaryThread
+ )
+ debug_print(
+ "[MainLoop] mbarTMA[{}] arrive [done]",
+ nextSlot,
+ predicate=primaryThread,
+ )
scf.yield_([])
# Step 4.5. Change the phaseParity
p = arith.cmpi(arith.CmpIPredicate.eq, stage, c(num_stages - 1))
- phaseParity = arith.select(p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity)
+ phaseParity = arith.select(
+ p, arith.xori(phaseParity, arith.constant(i1, 1)), phaseParity
+ )
# Step 4.5. Yield
scf.yield_([new_acc, phaseParity])
# Step 5. Wait All WGMMA groups
nvvm.WgmmaWaitGroupSyncOp(0)
gpu.barrier()
# Step 6. Epilogue (registers --> shared memory)
- acc_smem_ty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space)
+ acc_smem_ty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N), c_elem_ty, memory_space=smem_space
+ )
acc_smem = memref.view(acc_smem_ty, dynamic_smem, c(0), [])
debug_print("Storing", predicate=primaryThread)
nvgpu.WarpgroupMmaStoreOp(for_op.results[0], acc_smem)
gpu.barrier()
# GPU Step 7. Epilogue (shared memory --> global memory)
- fd = ir.MemRefType.get([BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space)
+ fd = ir.MemRefType.get(
+ [BLOCK_M * BLOCK_N], c_elem_ty, memory_space=smem_space
+ )
collapsed_smem = memref.view(fd, dynamic_smem, c(0), [])
- rty = ir.MemRefType.get((BLOCK_M, BLOCK_N), c_elem_ty,
- ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"))
- c_device_per_block = memref.SubViewOp(rty, c_device, [dimX, dimY], [], [], [MLIR_DYNAMIC, MLIR_DYNAMIC],
- [BLOCK_M, BLOCK_N], [1, 1])
+ rty = ir.MemRefType.get(
+ (BLOCK_M, BLOCK_N),
+ c_elem_ty,
+ ir.Attribute.parse("strided<[" + str(N) + ", 1], offset: ?>"),
+ )
+ c_device_per_block = memref.SubViewOp(
+ rty,
+ c_device,
+ [dimX, dimY],
+ [],
+ [],
+ [MLIR_DYNAMIC, MLIR_DYNAMIC],
+ [BLOCK_M, BLOCK_N],
+ [1, 1],
+ )
vlen = 1
- for_op = scf.ForOp(tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE))
+ for_op = scf.ForOp(
+ tidx, c(BLOCK_M * BLOCK_N), c(vlen * WARP_GROUP_SIZE)
+ )
with ir.InsertionPoint(for_op.body):
x = arith.divui(for_op.induction_variable, c(BLOCK_M))
y = arith.remui(for_op.induction_variable, c(BLOCK_N))
- vdata = vector.load(ir.VectorType.get((vlen, ), c_elem_ty), collapsed_smem,
- [for_op.induction_variable])
+ vdata = vector.load(
+ ir.VectorType.get((vlen,), c_elem_ty),
+ collapsed_smem,
+ [for_op.induction_variable],
+ )
vector.store(vdata, c_device_per_block, [x, y])
scf.yield_([])
gpu.terminator()
# Step 4. Copy back to host
t8 = gpu.wait(token_ty, [launch_op])
t9 = gpu.memcpy(token_ty, [t8], c_host, c_device)
gpu.wait(token_ty, [t9])
- mlir_matmul_multistage.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+ mlir_matmul_multistage.func_op.attributes[
+ "llvm.emit_c_interface"
+ ] = ir.UnitAttr.get()
module.operation.verify()
return module
--- tools/nvgpucompiler.py 2024-02-12 13:33:26.000000 +0000
+++ tools/nvgpucompiler.py 2024-02-12 13:36:46.799197 +0000
@@ -13,10 +13,11 @@
import sys
_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_PATH)
+
class NvgpuCompiler:
"""Nvgpu class for compiling and building MLIR modules."""
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
pipeline = f"builtin.module(gpu-lower-to-nvvm-pipeline{{{options}}})"
@@ -29,11 +30,11 @@
self.compile(module)
def compile(self, module: ir.Module):
"""Compiles the module by invoking the nvgpu pipeline."""
passmanager.PassManager.parse(self.pipeline).run(module.operation)
-
+
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs
)
``````````
</details>
https://github.com/llvm/llvm-project/pull/81478