[Mlir-commits] [mlir] [mlir] GEMM Hopper Tensor Core Integration Test (PR #81478)
Manish Gupta
llvmlistbot at llvm.org
Thu Feb 22 16:53:17 PST 2024
================
@@ -0,0 +1,1078 @@
+import numpy as np
+from mlir import ir
+from mlir.dialects import arith
+from mlir.dialects import func
+from mlir.dialects import gpu
+from mlir.dialects import memref
+from mlir.dialects import nvgpu
+from mlir.dialects import nvvm
+from mlir.dialects import llvm
+from mlir.dialects import builtin
+from mlir.dialects import scf
+from mlir.dialects import vector
+
+# Innermost TMA box extent for float16 data: 64 elements * 2 bytes = 128B.
+TMA_LAST_DIM_F16 = 64
+WARP_SIZE = 32
+# A warpgroup is four consecutive warps.
+WARP_GROUP_SIZE = 4 * WARP_SIZE
+
+# Register counts for the producer / consumer roles
+# (NOTE(review): usage is outside this hunk — confirm semantics there).
+PRODUCER_REGISTER_SIZE = 40
+CONSUMER_REGISTER_SIZE = 232
+
+# Designated thread ids for the producer / consumer roles
+# (NOTE(review): usage is outside this hunk — confirm semantics there).
+PRODUCER_PRIMARY_THREAD = 128
+CONSUMER_PRIMARY_THREAD = 0
+
+# INT64_MIN (-2**63); presumably mirrors MLIR's dynamic-size sentinel.
+MLIR_DYNAMIC = -(2**63)
+# Element sizes in bytes.
+f16_byte = 2
+f32_byte = 4
+
+# Global switch for debug_print; when False nothing is printed unless forced.
+DEBUG = False
+
+
+def debug_print(fmt, *args, predicate=None, threadNumber=-1, forcePrint=False):
+    """Emit a device-side ``gpu.printf`` of ``args`` at the current insertion point.
+
+    Nothing is emitted unless the module-level ``DEBUG`` flag is set or
+    ``forcePrint`` is True.
+
+    Args:
+        fmt: ``str.format`` template with one ``{}`` per argument; each
+            placeholder is replaced by a printf conversion chosen from the
+            argument's MLIR type. A trailing newline is appended.
+        *args: MLIR values to print; supported types are index, i64, i32,
+            i1 and f32.
+        predicate: Optional boolean (i1) MLIR value. When given, the print
+            is wrapped in an ``scf.if`` on it. When neither ``predicate`` nor
+            ``threadNumber`` is given, the print is emitted unconditionally.
+        threadNumber: If not -1, print only from the thread whose
+            ``gpu.thread_id x`` equals this number. This replaces any
+            ``predicate`` passed in.
+        forcePrint: Emit the print even when ``DEBUG`` is False.
+
+    Raises:
+        NotImplementedError: if an argument's type is not supported.
+    """
+    if not DEBUG and not forcePrint:
+        return
+    type_formats = []
+    for arg in args:
+        ty = arg.type
+        if ir.IndexType.isinstance(ty):
+            # NOTE: unsigned conversion; negative values print as large numbers.
+            ty_format = "%llu"
+        elif ir.IntegerType.isinstance(ty):
+            width = ir.IntegerType(ty).width
+            ty_format = {64: "%llu", 32: "%d", 1: "%i"}.get(width)
+        elif ir.F32Type.isinstance(ty):
+            ty_format = "%f"
+        else:
+            ty_format = None
+        if ty_format is None:
+            raise NotImplementedError(ty)
+        type_formats.append(ty_format)
+    if threadNumber != -1:
+        tidx = gpu.thread_id(gpu.Dimension.x)
+        predicate = arith.cmpi(arith.CmpIPredicate.eq, tidx, c(threadNumber))
+        # Do NOT emit scf.yield here: we are still in the caller's block, and
+        # a yield would inject a terminator into it. The only yield belongs
+        # inside the scf.if body below.
+    printf_fmt = fmt.format(*type_formats) + "\n"
+    if predicate is None:
+        # No guard requested; scf.IfOp cannot be built without a condition.
+        gpu.printf(printf_fmt, args)
+        return
+    if_op = scf.IfOp(predicate)
+    with ir.InsertionPoint(if_op.then_block):
+        gpu.printf(printf_fmt, args)
+        scf.yield_([])
+
+
+def c(value, ty=None):
+    """Build an ``arith.constant`` holding ``value``.
+
+    ``ty`` selects the constant's MLIR type; when omitted it defaults to
+    ``index``.
+    """
+    if ty is None:
+        ty = ir.IndexType.get()
+    return arith.constant(ty, value)
+
+
+def generate_matmul_ws(
+ input_type=np.float16,
+ output_type=np.float32,
+ M=4096,
+ N=4096,
+ K=4096,
+ BLOCK_M=128,
+ BLOCK_N=128,
+ BLOCK_K=128,
+ max_num_stages=3,
+):
+ # Limitaitons for now
+ assert input_type == np.float16
+ assert output_type == np.float32
+ assert M % BLOCK_M == 0
+ assert N % BLOCK_N == 0
+ assert K % BLOCK_K == 0
----------------
manishucsd wrote:
Sounds good!
For this PR, should we add asserts on (BLOCK_M, BLOCK_N, BLOCK_K) == (128,128,64)?
https://github.com/llvm/llvm-project/pull/81478
More information about the Mlir-commits
mailing list