[Mlir-commits] [mlir] [mlir][nvgpu] Improve `tensormap.descriptor` Type Verifier (PR #77904)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 12 02:39:49 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

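This PR improves the verifier for the `nvgpu.tensormap.descriptor` type. The descriptor carries the information needed for TMA (Tensor Memory Accelerator) transfers, and the verifier now enforces its restrictions at compile time — for example, that the last memory dimension of a multi-dimensional descriptor spans exactly 128 bytes. Catching these violations during verification prevents crashes at runtime.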

See the CUDA driver API documentation for more details:
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7

---
Full diff: https://github.com/llvm/llvm-project/pull/77904.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h (+4) 
- (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+16) 
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+7-10) 
- (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+20) 
- (modified) mlir/test/Dialect/NVGPU/tmaload-transform.mlir (+18-18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 2888fed2779575..cc41e17a8f59cf 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -27,6 +27,10 @@ constexpr int kWarpSize = 32;
 constexpr int kWgmmaSizeM = 64;
 /// Maximum tensor dimension that TMA supports
 constexpr int kMaxTMATensorDimension = 5;
+/// Maximum any dimension for TMA
+constexpr unsigned kMaxTMADimension = 256;
+/// Last dimension of 2D+ TMA must be 128 bytes
+constexpr unsigned kMaxTMALastdimByte = 128;
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index b0a4ed1cc2697c..cbbdbb5c948e36 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -355,6 +355,22 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
   if (!descMemref.hasStaticShape())
     return op->emitError() << "the tensor map descriptor must be static shaped";
 
+  for (auto dim : descMemref.getShape()) {
+    if (dim <= 0 || dim > kMaxTMADimension) {
+      return op->emitError() << "the tensor map descriptor must not have zero "
+                                "dimension";
+    }
+  }
+  if (descMemref.getRank() > 1) {
+    unsigned lastDimensionByte =
+        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
+    if (lastDimensionByte != kMaxTMALastdimByte)
+      return op->emitError() << "the tensormap descriptor must have last "
+                                "dimension of "
+                             << kMaxTMALastdimByte << " bytes but it is "
+                             << lastDimensionByte << " bytes";
+  }
+
   // No verification if memref type is not provided
   if (!memrefType.has_value())
     return std::nullopt;
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b8a0f75d1cc8b9..3f859a2c2be884 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -753,10 +753,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
 }
 
 !lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-
-!shmemlhs = memref<128x64xf16,3>
-!shmemrhs = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
 
 module @mymodule {
   // Dynamic Shared memory
@@ -765,17 +762,17 @@ module @mymodule {
   func.func @async_tma_load(%lhsTensorMap: !lhsTensorMap, %rhsTensorMap: !rhsTensorMap, %mbarrier: !barrierType) {
     %c0 = arith.constant 0 : index
     %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
-    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to !shmemlhs
-    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2,64,128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<2x64x128xf16,3>
-    %rhsShmem3 = memref.subview %rhsShmem2[1,0,0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16,3> to memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3>
-    %rhsShmem = memref.subview %rhsShmem3[0,0,0][1, 64, 128][1, 1, 1] : memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3> to !shmemrhs
+    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 64],  strides: [4096, 64, 1] : memref<0xf16, 3> to memref<4x64x64xf16,3>
+    %rhsShmem3 = memref.subview %rhsShmem2[2, 0, 0][1, 64, 64][1, 1, 1] : memref<4x64x64xf16,3> to memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3>
+    %rhsShmem = memref.subview %rhsShmem3[0, 0, 0][1, 64, 64][1, 1, 1]  : memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3> to memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global
-    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> memref<128x64xf16,3>
     // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
     // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
     // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
     // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
-    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     return
   }
 }
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index e1949fcfad7ad6..4c070e9a0fad35 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -316,3 +316,23 @@ func.func @tma_load_4(%desc: !desc,  %buffer1: memref<128xf32,3>, %buffer2: memr
   nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
   return
 }
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+func.func @tma_generate_descriptor_incorrect_last_dim(%b0 : index, %b1 : index, %mem : memref<*xf16>) {
+  // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 256 bytes}}
+  %descA = nvgpu.tma.create.descriptor %mem box[%b0, %b1] : memref<*xf16> -> !desc
+  return
+}
+// -----
+
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc,  %buffer2: memref<64x128xf32,3>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 512 bytes}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
+  return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index ab6483151a63f2..5f3074cad926c9 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -3,8 +3,8 @@
 // RUN:     -test-transform-dialect-erase-schedule \
 // RUN: | FileCheck %s
 
-memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
 // CHECK-LABEL: func.func @main()
 func.func @main() {
@@ -12,26 +12,26 @@ func.func @main() {
   %c128 = arith.constant 128 : index
 
   %0 = gpu.wait async
-  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
-  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x32xf32>
+  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x32xf32>
 
-  //      CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+  //      CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
   //      CHECK: %[[c64:.*]] = arith.constant 64 : index
   //      CHECK: %[[c8:.*]] = arith.constant 8 : index
   //      CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
-  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
-  //      CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  //      CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
   //      CHECK: %[[c8_2:.*]] = arith.constant 8 : index
   //      CHECK: %[[c128_2:.*]] = arith.constant 128 : index
   //      CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
-  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   // CHECK: gpu.launch
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
-    //      CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    //      CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
-    %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+    //      CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    //      CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
+    %out = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
     //      CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
     //      CHECK: nvgpu.mbarrier.init %[[B]][%{{.*}}], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
@@ -45,15 +45,15 @@ func.func @main() {
     //
     //      CHECK:   %[[c0_7:.*]] = arith.constant 0 : index
     //      CHECK:   nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]][%{{.*}}] to %[[G1]]
-    // CHECK-SAME:     : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME:     : <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME:        swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME:     -> memref<64x8xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME:     -> memref<64x32xf32, #gpu.address_space<workgroup>>
     //
     //      CHECK:   %[[c0_8:.*]] = arith.constant 0 : index
     //      CHECK:   nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]][%{{.*}}] to %[[G2]]
-    // CHECK-SAME:     : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME:     : <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME:         swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME:    -> memref<8x128xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME:    -> memref<8x32xf32, #gpu.address_space<workgroup>>
     //
     //      CHECK:   %[[c6144:.*]] = arith.constant 6144 : index
     //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
@@ -67,8 +67,8 @@ func.func @main() {
     //      CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
 
     /// Both copies are matched and end up in the same async group.
-    linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
-    linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref: memref<64x32xf32>) outs(%out: memref<64x32xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref_1: memref<8x32xf32>) outs(%out_1: memref<8x32xf32, #gpu.address_space<workgroup>>)
 
     gpu.terminator
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/77904


More information about the Mlir-commits mailing list