[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)
Manish Gupta
llvmlistbot at llvm.org
Mon Feb 19 16:09:14 PST 2024
================
@@ -0,0 +1,1078 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+TMA_LAST_DIM_F16 = 64  # 128B float16
+WARP_SIZE = 32
+WARP_GROUP_SIZE = WARP_SIZE * 4
+
+# Register counts for nvvm.setmaxregister: the producer warp group releases
+# registers so the consumer (MMA) warp group can claim more.
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+# Warp specialization: the producer warp group starts at thread 128, the
+# consumer warp group at thread 0.
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+MLIR_DYNAMIC = -9223372036854775808  # ShapedType::kDynamic (INT64_MIN)
+f16_byte = 2  # element sizes in bytes
+f32_byte = 4
+
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty_format = None
+        if ir.IndexType.isinstance(arg.type):
+            ty_format = "%llu"
+        if ir.IntegerType.isinstance(arg.type):
+            width = ir.IntegerType(arg.type).width
+            if width == 64:
+                ty_format = "%llu"
+            elif width == 32:
+                ty_format = "%d"
+            elif width == 1:
+                ty_format = "%i"
+        if ir.F32Type.isinstance(arg.type):
+            ty_format = "%f"
+        if ty_format is None:
+            raise NotImplementedError(arg.type)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(fmt.format(*type_formats) + "\n", args)
+        scf.yield_([])
+
+
+def c(value, ty=None):
+    ty = ir.IndexType.get() if ty is None else ty
+    return arith.constant(ty, value)
+
+
+def generate_matmul_ws(
+    input_type=np.float16,
+    output_type=np.float32,
+    M=4096,
+    N=4096,
+    K=4096,
+    BLOCK_M=128,
+    BLOCK_N=128,
+    BLOCK_K=128,
+    max_num_stages=3,
+):
+    # Limitations for now
+    assert input_type == np.float16
+    assert output_type == np.float32
+    assert M % BLOCK_M == 0
+    assert N % BLOCK_N == 0
+    assert K % BLOCK_K == 0
+
+    # Cap the pipeline depth by how many A/B tiles the problem provides.
+    required_stages = (M * K + K * N) // (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N)
+    num_stages = min(required_stages, max_num_stages)
+
+    module = ir.Module.create()
+    f16 = ir.F16Type.get()
+    f32 = ir.F32Type.get()
+    i1 = ir.IntegerType.get_signless(1)
+    i32 = ir.IntegerType.get_signless(32)
+    index = ir.IndexType.get()
+    i8 = ir.IntegerType.get_signless(8)
+    token_ty = ir.Type.parse("!gpu.async.token")
+    a_ty = ir.MemRefType.get([M, K], f16)
+    b_ty = ir.MemRefType.get((K, N), f16)
+    c_elem_ty = f16 if output_type == np.float16 else f32
+    c_ty = ir.MemRefType.get((M, N), c_elem_ty)
+    a_tile_shape = a_tma_shape = (BLOCK_M, TMA_LAST_DIM_F16)
+    b_tma_shape = (BLOCK_K, TMA_LAST_DIM_F16)
+    b_tile_shape = (BLOCK_K, BLOCK_N)
+    # Expected TMA transaction byte count per stage: one A tile + one B tile.
+    txcount = (
+        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
+    ) * f16_byte
----------------
manishucsd wrote:
Should we replace f16_byte with a function of input_type? Do we have an MLIR utility that gives the size in bits or bytes of an MLIR builtin primitive type?
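
A minimal sketch of both options, assuming input_type stays a NumPy type as in
generate_matmul_ws above; bytes_of and bits_of are hypothetical helper names.
At the C++ level mlir::DataLayout::getTypeSize answers this, but I am not aware
of a Python-bindings equivalent, so the MLIR-side helper below just
pattern-matches the builtin types this test uses:

    import numpy as np
    from mlir import ir

    def bytes_of(np_type):
        # NumPy knows its element sizes: np.float16 -> 2, np.float32 -> 4.
        return np.dtype(np_type).itemsize

    def bits_of(mlir_type):
        # Integer types expose their width directly (debug_print already uses this).
        if ir.IntegerType.isinstance(mlir_type):
            return ir.IntegerType(mlir_type).width
        if ir.F16Type.isinstance(mlir_type):
            return 16
        if ir.F32Type.isinstance(mlir_type):
            return 32
        raise NotImplementedError(mlir_type)

With that, txcount could drop the hard-coded f16_byte:

    txcount = (
        (b_tile_shape[0] * b_tile_shape[1]) + (a_tile_shape[0] * a_tile_shape[1])
    ) * bytes_of(input_type)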
https://github.com/llvm/llvm-project/pull/81478