[Mlir-commits] [mlir] f33a0a4 - [mlir][nvgpu] Improve `tensormap.descriptor` Type Verifier (#77904)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 5 01:32:07 PST 2024


Author: Guray Ozen
Date: 2024-02-05T10:32:03+01:00
New Revision: f33a0a483550e3441aae4059d6b3d81eab6a398c

URL: https://github.com/llvm/llvm-project/commit/f33a0a483550e3441aae4059d6b3d81eab6a398c
DIFF: https://github.com/llvm/llvm-project/commit/f33a0a483550e3441aae4059d6b3d81eab6a398c.diff

LOG: [mlir][nvgpu] Improve `tensormap.descriptor` Type Verifier (#77904)

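This PR improves the verifier for the `nvgpu.tensormap.descriptor` type.
The descriptor contains information for TMA, and the compile-time check
now enforces TMA's restrictions: every dimension must be between 1 and
256 elements, and for descriptors of rank 2 or higher the last dimension
must be exactly 128 bytes. Rejecting invalid descriptors at compile time
prevents runtime crashes.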

See the CUDA driver API documentation for more details:

https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
    mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
    mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
    mlir/test/Dialect/NVGPU/invalid.mlir
    mlir/test/Dialect/NVGPU/tmaload-transform.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 2888fed277957..19070f6f062a0 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -25,8 +25,16 @@ constexpr int kWarpSize = 32;
 
 /// M size of wgmma.mma_async instruction
 constexpr int kWgmmaSizeM = 64;
-/// Maximum tensor dimension that TMA supports
-constexpr int kMaxTMATensorDimension = 5;
+
+/// Maximum TMA tile dimension (tensorRank) must be non-zero and less than or
+/// equal to the maximum supported dimensionality of 5.
+constexpr unsigned kMaxTMATensorDimension = 5;
+/// Maximum TMA tile size (boxDim), which specifies number of elements
+/// to be traversed along each of the kMaxTMATensorDimension (tensorRank)
+/// dimensions, must be non-zero and less than or equal to 256.
+constexpr unsigned kMaxTMADimension = 256;
+/// Last dimension of 2D+ TMA must be 128 bytes
+constexpr unsigned kMaxTMALastdimByte = 128;
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"

diff  --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index bdfb2f54052ae..4b6327479a219 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -355,6 +355,23 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
   if (!descMemref.hasStaticShape())
     return op->emitError() << "the tensor map descriptor must be static shaped";
 
+  for (auto dim : descMemref.getShape()) {
+    if (dim <= 0 || dim > kMaxTMADimension) {
+      return op->emitError() << "the tensor map descriptor must have "
+                                "dimensions between 1 and "
+                             << kMaxTMADimension << " but it is " << dim;
+    }
+  }
+  if (descMemref.getRank() > 1) {
+    unsigned lastDimensionByte =
+        descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
+    if (lastDimensionByte != kMaxTMALastdimByte)
+      return op->emitError() << "the tensormap descriptor must have last "
+                                "dimension of "
+                             << kMaxTMALastdimByte << " bytes but it is "
+                             << lastDimensionByte << " bytes";
+  }
+
   // No verification if memref type is not provided
   if (!memrefType.has_value())
     return std::nullopt;

diff  --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index d0ff38a1d6034..09a873fa331d5 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -799,10 +799,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
 }
 
 !lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-
-!shmemlhs = memref<128x64xf16,3>
-!shmemrhs = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
 
 module @mymodule {
   // Dynamic Shared memory
@@ -811,17 +808,17 @@ module @mymodule {
   func.func @async_tma_load(%lhsTensorMap: !lhsTensorMap, %rhsTensorMap: !rhsTensorMap, %mbarrier: !barrierType) {
     %c0 = arith.constant 0 : index
     %dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
-    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to !shmemlhs
-    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2,64,128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<2x64x128xf16,3>
-    %rhsShmem3 = memref.subview %rhsShmem2[1,0,0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16,3> to memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3>
-    %rhsShmem = memref.subview %rhsShmem3[0,0,0][1, 64, 128][1, 1, 1] : memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3> to !shmemrhs
+    %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+    %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 64],  strides: [4096, 64, 1] : memref<0xf16, 3> to memref<4x64x64xf16,3>
+    %rhsShmem3 = memref.subview %rhsShmem2[2, 0, 0][1, 64, 64][1, 1, 1] : memref<4x64x64xf16,3> to memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3>
+    %rhsShmem = memref.subview %rhsShmem3[0, 0, 0][1, 64, 64][1, 1, 1]  : memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3> to memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global
-    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+    nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> memref<128x64xf16,3>
     // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
     // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
     // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
     // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
-    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+    nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     return
   }
 }

diff  --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index e1949fcfad7ad..c3aed35153241 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -316,3 +316,23 @@ func.func @tma_load_4(%desc: !desc,  %buffer1: memref<128xf32,3>, %buffer2: memr
   nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
   return
 }
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+func.func @tma_generate_descriptor_incorrect_last_dim(%b0 : index, %b1 : index, %mem : memref<*xf16>) {
+  // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 256 bytes}}
+  %descA = nvgpu.tma.create.descriptor %mem box[%b0, %b1] : memref<*xf16> -> !desc
+  return
+}
+// -----
+
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc,  %buffer2: memref<64x128xf32,3>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 512 bytes}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
+  return
+}

diff  --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index ab6483151a63f..29e300a992d3a 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -3,8 +3,8 @@
 // RUN:     -test-transform-dialect-erase-schedule \
 // RUN: | FileCheck %s
 
-memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
 // CHECK-LABEL: func.func @main()
 func.func @main() {
@@ -12,26 +12,26 @@ func.func @main() {
   %c128 = arith.constant 128 : index
 
   %0 = gpu.wait async
-  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
-  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+  %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x32xf32>
+  %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x32xf32>
 
-  //      CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+  //      CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
   //      CHECK: %[[c64:.*]] = arith.constant 64 : index
-  //      CHECK: %[[c8:.*]] = arith.constant 8 : index
-  //      CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
-  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
-  //      CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+  //      CHECK: %[[c32:.*]] = arith.constant 32 : index
+  //      CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c32]]]
+  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  //      CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
   //      CHECK: %[[c8_2:.*]] = arith.constant 8 : index
-  //      CHECK: %[[c128_2:.*]] = arith.constant 128 : index
-  //      CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
-  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+  //      CHECK: %[[c32_2:.*]] = arith.constant 32 : index
+  //      CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c32_2]]]
+  // CHECK-SAME:   : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
   // CHECK: gpu.launch
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
             threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
-    //      CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    //      CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
-    %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+    //      CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    //      CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
+    %out = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+    %out_1 = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
 
     //      CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
     //      CHECK: nvgpu.mbarrier.init %[[B]][%{{.*}}], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
@@ -45,18 +45,18 @@ func.func @main() {
     //
     //      CHECK:   %[[c0_7:.*]] = arith.constant 0 : index
     //      CHECK:   nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]][%{{.*}}] to %[[G1]]
-    // CHECK-SAME:     : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME:     : <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME:        swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME:     -> memref<64x8xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME:     -> memref<64x32xf32, #gpu.address_space<workgroup>>
     //
     //      CHECK:   %[[c0_8:.*]] = arith.constant 0 : index
     //      CHECK:   nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]][%{{.*}}] to %[[G2]]
-    // CHECK-SAME:     : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+    // CHECK-SAME:     : <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>,
     // CHECK-SAME:         swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
-    // CHECK-SAME:    -> memref<8x128xf32, #gpu.address_space<workgroup>>
+    // CHECK-SAME:    -> memref<8x32xf32, #gpu.address_space<workgroup>>
     //
-    //      CHECK:   %[[c6144:.*]] = arith.constant 6144 : index
-    //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
+    //      CHECK:   %[[c9216:.*]] = arith.constant 9216 : index
+    //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c9216]] : <memorySpace = #gpu.address_space<workgroup>
     //      CHECK: } else {
     //      CHECK:   %[[c0_7:.*]] = arith.constant 0 : index
     //      CHECK:   nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
@@ -67,8 +67,8 @@ func.func @main() {
     //      CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
 
     /// Both copies are matched and end up in the same async group.
-    linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
-    linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref: memref<64x32xf32>) outs(%out: memref<64x32xf32, #gpu.address_space<workgroup>>)
+    linalg.copy ins(%memref_1: memref<8x32xf32>) outs(%out_1: memref<8x32xf32, #gpu.address_space<workgroup>>)
 
     gpu.terminator
   }


        


More information about the Mlir-commits mailing list