[Mlir-commits] [mlir] [mlir][nvgpu] NVGPU Tutorials (PR #87065)

Manish Gupta llvmlistbot at llvm.org
Sat Mar 30 09:48:12 PDT 2024


================
@@ -0,0 +1,92 @@
+# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
+# RUN:   %PYTHON %s | FileCheck %s
+
+# ===----------------------------------------------------------------------===//
+#  Chapter 3 : GEMM 64x64x64 with Tensor Core
+# ===----------------------------------------------------------------------===//
+#
+# This program demonstrates a GEMM operation with 64x64x64 matrix multiplication
+#
+# This chapter demonstrates:
+# 1. Executing TMA loads for the two input matrices
+# 2. Performing a Tensor Core GEMM 64x64x64 per warpgroup
+# 3. Storing fragmented registers to global memory per warpgroup
+#
+# ===----------------------------------------------------------------------===//
+
+
+from mlir import ir
+from mlir.dialects import nvgpu, scf, arith, memref, vector, gpu
+from tools.nvdsl import *
+from mlir.extras import types as T
+import numpy as np
+
+
+@NVDSL.mlir_func
+def gemm_64_64_64(x, y, z):
+    token_ty = ir.Type.parse("!gpu.async.token")
+    t1 = gpu.wait(token_ty, [])
+    x_dev, t2 = gpu.alloc(x.type, token_ty, [t1], [], [])
+    y_dev, t3 = gpu.alloc(y.type, token_ty, [t2], [], [])
+    z_dev, t4 = gpu.alloc(z.type, token_ty, [t3], [], [])
+    t5 = gpu.memcpy(token_ty, [t4], x_dev, x)
+    t6 = gpu.memcpy(token_ty, [t5], y_dev, y)
+    t7 = gpu.wait(token_ty, [t6])
+
+    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
+    x_tma = TMA([N, N], x.type, swizzle=sw)
+    y_tma = TMA([N, N], y.type, swizzle=sw)
+    x_tma.create_descriptor(x_dev)
+    y_tma.create_descriptor(y_dev)
+
+    @NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=16384)
----------------
manishucsd wrote:

```
smem_size_in_bytes = N * N * get_type_size(x.type) + N * N * get_type_size(y.type)
@NVDSL.mlir_gpu_launch(grid=(1, 1, 1), block=(128, 1, 1), smem=smem_size_in_bytes)
```

Let's keep magic numbers in the tutorial to a minimum. Deriving the size this way also highlights APIs that the tutorial already provides.
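
For reference, the arithmetic behind the hard-coded `16384` can be sketched in plain Python. This is only an illustration under stated assumptions: `N = 64` comes from the 64x64x64 problem size, the 2-byte element width (for example f16) is assumed because it is what makes the sum match `16384`, and `tile_bytes` is a hypothetical helper, not part of `tools.nvdsl`.

```python
# Sketch only: N and the 2-byte element size are assumptions, not values read from the PR.
N = 64          # tile edge, taken from the 64x64x64 problem size
ELEM_BYTES = 2  # assumed element width in bytes (e.g. f16)


def tile_bytes(rows, cols, elem_bytes):
    """Bytes needed to hold one rows x cols tile in shared memory (hypothetical helper)."""
    return rows * cols * elem_bytes


# One tile for x plus one tile for y, mirroring the expression suggested above.
smem_size_in_bytes = tile_bytes(N, N, ELEM_BYTES) + tile_bytes(N, N, ELEM_BYTES)
print(smem_size_in_bytes)  # 16384
```

Under these assumptions each tile takes 8192 bytes, so the two input tiles together account for the full `smem=16384` in the launch decorator.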

https://github.com/llvm/llvm-project/pull/87065

