[Mlir-commits] [mlir] [mlir][nvgpu] Improve `tensormap.descriptor` Type Verifier (PR #77904)
Guray Ozen
llvmlistbot at llvm.org
Wed Jan 31 04:57:47 PST 2024
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/77904
>From bb035e2fbf3c5f5191b43079e503341ccb0f15fc Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 12 Jan 2024 11:38:35 +0100
Subject: [PATCH 1/3] [mlir][nvgpu] Improve `tensormap.descriptor` Type
Verifier
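This PR improves the verifier for the `nvgpu.tensormap.descriptor` type. The descriptor holds the information TMA needs, and the new compile-time check enforces TMA's restrictions. Every dimension of the descriptor's memref must be between 1 and 256 elements. For descriptors of rank 2 or higher, the last dimension must span exactly 128 bytes. Catching these violations at compile time prevents crashes at runtime.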
See the CUDA Driver API documentation for more details:
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
---
.../mlir/Dialect/NVGPU/IR/NVGPUDialect.h | 4 +++
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 16 +++++++++
.../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 17 ++++-----
mlir/test/Dialect/NVGPU/invalid.mlir | 20 +++++++++++
.../test/Dialect/NVGPU/tmaload-transform.mlir | 36 +++++++++----------
5 files changed, 65 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 2888fed277957..cc41e17a8f59c 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -27,6 +27,10 @@ constexpr int kWarpSize = 32;
constexpr int kWgmmaSizeM = 64;
/// Maximum tensor dimension that TMA supports
constexpr int kMaxTMATensorDimension = 5;
+/// Maximum any dimension for TMA
+constexpr unsigned kMaxTMADimension = 256;
+/// Last dimension of 2D+ TMA must be 128 bytes
+constexpr unsigned kMaxTMALastdimByte = 128;
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index bdfb2f54052ae..4507fec530c30 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -355,6 +355,22 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
if (!descMemref.hasStaticShape())
return op->emitError() << "the tensor map descriptor must be static shaped";
+ for (auto dim : descMemref.getShape()) {
+ if (dim <= 0 || dim > kMaxTMADimension) {
+ return op->emitError() << "the tensor map descriptor must not have zero "
+ "dimension";
+ }
+ }
+ if (descMemref.getRank() > 1) {
+ unsigned lastDimensionByte =
+ descMemref.getElementTypeBitWidth() * descMemref.getShape().back() / 8;
+ if (lastDimensionByte != kMaxTMALastdimByte)
+ return op->emitError() << "the tensormap descriptor must have last "
+ "dimension of "
+ << kMaxTMALastdimByte << " bytes but it is "
+ << lastDimensionByte << " bytes";
+ }
+
// No verification if memref type is not provided
if (!memrefType.has_value())
return std::nullopt;
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index edccd7e80603b..4fc16e5fdd705 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -799,10 +799,7 @@ func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
}
!lhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
-
-!shmemlhs = memref<128x64xf16,3>
-!shmemrhs = memref<64x128xf16, strided<[128, 1], offset: 8192>, 3>
+!rhsTensorMap = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>, swizzle = swizzle_128b, l2promo = none, oob = zero, interleave = none>
module @mymodule {
// Dynamic Shared memory
@@ -811,17 +808,17 @@ module @mymodule {
func.func @async_tma_load(%lhsTensorMap: !lhsTensorMap, %rhsTensorMap: !rhsTensorMap, %mbarrier: !barrierType) {
%c0 = arith.constant 0 : index
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
- %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to !shmemlhs
- %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2,64,128], strides: [8192,128,1] : memref<0xf16, 3> to memref<2x64x128xf16,3>
- %rhsShmem3 = memref.subview %rhsShmem2[1,0,0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16,3> to memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3>
- %rhsShmem = memref.subview %rhsShmem3[0,0,0][1, 64, 128][1, 1, 1] : memref<1x64x128xf16, strided<[8192, 128, 1], offset: 8192>, 3> to !shmemrhs
+ %lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [128,64], strides: [64,1] : memref<0xf16, 3> to memref<128x64xf16,3>
+ %rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 64], strides: [4096, 64, 1] : memref<0xf16, 3> to memref<4x64x64xf16,3>
+ %rhsShmem3 = memref.subview %rhsShmem2[2, 0, 0][1, 64, 64][1, 1, 1] : memref<4x64x64xf16,3> to memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3>
+ %rhsShmem = memref.subview %rhsShmem3[0, 0, 0][1, 64, 64][1, 1, 1] : memref<1x64x64xf16, strided<[4096, 64, 1], offset: 8192>, 3> to memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global
- nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> !shmemlhs
+ nvgpu.tma.async.load %lhsTensorMap[%c0, %c0], %mbarrier[%c0] to %lhsShmem : !lhsTensorMap, !barrierType -> memref<128x64xf16,3>
// CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
// CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
- nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> !shmemrhs
+ nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
return
}
}
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index e1949fcfad7ad..4c070e9a0fad3 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -316,3 +316,23 @@ func.func @tma_load_4(%desc: !desc, %buffer1: memref<128xf32,3>, %buffer2: memr
nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer1 : !desc, !mbarrier -> memref<128xf32,3>
return
}
+
+// -----
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf16,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+func.func @tma_generate_descriptor_incorrect_last_dim(%b0 : index, %b1 : index, %mem : memref<*xf16>) {
+ // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 256 bytes}}
+ %descA = nvgpu.tma.create.descriptor %mem box[%b0, %b1] : memref<*xf16> -> !desc
+ return
+}
+// -----
+
+
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<64x128xf32,3>, swizzle=swizzle_32b, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc, %buffer2: memref<64x128xf32,3>, %mbarrier: !mbarrier) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 512 bytes}}
+ nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index ab6483151a63f..5f3074cad926c 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -3,8 +3,8 @@
// RUN: -test-transform-dialect-erase-schedule \
// RUN: | FileCheck %s
-memref.global "private" @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
-memref.global "private" @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+memref.global "private" @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
// CHECK-LABEL: func.func @main()
func.func @main() {
@@ -12,26 +12,26 @@ func.func @main() {
%c128 = arith.constant 128 : index
%0 = gpu.wait async
- %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x8xf32>
- %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x128xf32>
+ %memref, %asyncToken = gpu.alloc async [%0] () : memref<64x32xf32>
+ %memref_1, %asyncToken_2 = gpu.alloc async [%0] () : memref<8x32xf32>
- // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x8xf32> to memref<*xf32>
+ // CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
// CHECK: %[[c64:.*]] = arith.constant 64 : index
// CHECK: %[[c8:.*]] = arith.constant 8 : index
// CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
- // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
- // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x128xf32> to memref<*xf32>
+ // CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+ // CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
// CHECK: %[[c8_2:.*]] = arith.constant 8 : index
// CHECK: %[[c128_2:.*]] = arith.constant 128 : index
// CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
- // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+ // CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
// CHECK: gpu.launch
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
threads(%tx, %ty, %tz) in (%block_x = %c128, %block_y = %c1, %block_z = %c1) {
- // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
- // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
- %out = memref.get_global @bufferLhsGlobal : memref<64x8xf32, #gpu.address_space<workgroup>>
- %out_1 = memref.get_global @bufferRhsGlobal : memref<8x128xf32, #gpu.address_space<workgroup>>
+ // CHECK: %[[G1:.*]] = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+ // CHECK: %[[G2:.*]] = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
+ %out = memref.get_global @bufferLhsGlobal : memref<64x32xf32, #gpu.address_space<workgroup>>
+ %out_1 = memref.get_global @bufferRhsGlobal : memref<8x32xf32, #gpu.address_space<workgroup>>
// CHECK: %[[B:.*]] = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>
// CHECK: nvgpu.mbarrier.init %[[B]][%{{.*}}], %{{.*}} : <memorySpace = #gpu.address_space<workgroup>
@@ -45,15 +45,15 @@ func.func @main() {
//
// CHECK: %[[c0_7:.*]] = arith.constant 0 : index
// CHECK: nvgpu.tma.async.load %[[D1]][%[[c0_7]], %[[c0_7]]], %[[B]][%{{.*}}] to %[[G1]]
- // CHECK-SAME: : <tensor = memref<64x8xf32, #gpu.address_space<workgroup>>,
+ // CHECK-SAME: : <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>,
// CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
- // CHECK-SAME: -> memref<64x8xf32, #gpu.address_space<workgroup>>
+ // CHECK-SAME: -> memref<64x32xf32, #gpu.address_space<workgroup>>
//
// CHECK: %[[c0_8:.*]] = arith.constant 0 : index
// CHECK: nvgpu.tma.async.load %[[D2]][%[[c0_8]], %[[c0_8]]], %[[B]][%{{.*}}] to %[[G2]]
- // CHECK-SAME: : <tensor = memref<8x128xf32, #gpu.address_space<workgroup>>,
+ // CHECK-SAME: : <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>,
// CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
- // CHECK-SAME: -> memref<8x128xf32, #gpu.address_space<workgroup>>
+ // CHECK-SAME: -> memref<8x32xf32, #gpu.address_space<workgroup>>
//
// CHECK: %[[c6144:.*]] = arith.constant 6144 : index
// CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
@@ -67,8 +67,8 @@ func.func @main() {
// CHECK: nvgpu.mbarrier.try_wait.parity %[[B]][%{{.*}}], %[[c0_6]], %[[c10000000]] : <memorySpace = #gpu.address_space<workgroup>
/// Both copies are matched and end up in the same async group.
- linalg.copy ins(%memref: memref<64x8xf32>) outs(%out: memref<64x8xf32, #gpu.address_space<workgroup>>)
- linalg.copy ins(%memref_1: memref<8x128xf32>) outs(%out_1: memref<8x128xf32, #gpu.address_space<workgroup>>)
+ linalg.copy ins(%memref: memref<64x32xf32>) outs(%out: memref<64x32xf32, #gpu.address_space<workgroup>>)
+ linalg.copy ins(%memref_1: memref<8x32xf32>) outs(%out_1: memref<8x32xf32, #gpu.address_space<workgroup>>)
gpu.terminator
}
>From 51b47fac56e306192dc133fe7658bcaaaa13f632 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 19 Jan 2024 13:31:07 +0100
Subject: [PATCH 2/3] fix test
---
mlir/test/Dialect/NVGPU/invalid.mlir | 2 +-
mlir/test/Dialect/NVGPU/tmaload-transform.mlir | 12 ++++++------
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 4c070e9a0fad3..c3aed35153241 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -335,4 +335,4 @@ func.func @tma_generate_descriptor_incorrect_last_dim(%desc: !desc, %buffer2: m
// expected-error @+1 {{the tensormap descriptor must have last dimension of 128 bytes but it is 512 bytes}}
nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer2 : !desc, !mbarrier -> memref<64x128xf32,3>
return
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
index 5f3074cad926c..29e300a992d3a 100644
--- a/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
+++ b/mlir/test/Dialect/NVGPU/tmaload-transform.mlir
@@ -17,13 +17,13 @@ func.func @main() {
// CHECK: %[[M1:.*]] = memref.cast %{{.*}} : memref<64x32xf32> to memref<*xf32>
// CHECK: %[[c64:.*]] = arith.constant 64 : index
- // CHECK: %[[c8:.*]] = arith.constant 8 : index
- // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c8]]]
+ // CHECK: %[[c32:.*]] = arith.constant 32 : index
+ // CHECK: %[[D1:.*]] = nvgpu.tma.create.descriptor %[[M1]] box[%[[c64]], %[[c32]]]
// CHECK-SAME: : memref<*xf32> -> <tensor = memref<64x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
// CHECK: %[[cast_2:.*]] = memref.cast %memref_0 : memref<8x32xf32> to memref<*xf32>
// CHECK: %[[c8_2:.*]] = arith.constant 8 : index
- // CHECK: %[[c128_2:.*]] = arith.constant 128 : index
- // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c128_2]]]
+ // CHECK: %[[c32_2:.*]] = arith.constant 32 : index
+ // CHECK: %[[D2:.*]] = nvgpu.tma.create.descriptor %cast_2 box[%[[c8_2]], %[[c32_2]]]
// CHECK-SAME: : memref<*xf32> -> <tensor = memref<8x32xf32, #gpu.address_space<workgroup>>, swizzle = none, l2promo = none, oob = zero, interleave = none>
// CHECK: gpu.launch
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
@@ -55,8 +55,8 @@ func.func @main() {
// CHECK-SAME: swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>
// CHECK-SAME: -> memref<8x32xf32, #gpu.address_space<workgroup>>
//
- // CHECK: %[[c6144:.*]] = arith.constant 6144 : index
- // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c6144]] : <memorySpace = #gpu.address_space<workgroup>
+ // CHECK: %[[c9216:.*]] = arith.constant 9216 : index
+ // CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c9216]] : <memorySpace = #gpu.address_space<workgroup>
// CHECK: } else {
// CHECK: %[[c0_7:.*]] = arith.constant 0 : index
// CHECK: nvgpu.mbarrier.arrive.expect_tx %[[B]][%{{.*}}], %[[c0_7]] : <memorySpace = #gpu.address_space<workgroup>
>From 079280f6edb0000e23aebbf444c0bab1fcfd9ec0 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 31 Jan 2024 13:57:31 +0100
Subject: [PATCH 3/3] Address comments
---
mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h | 10 +++++++---
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 5 +++--
2 files changed, 10 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index cc41e17a8f59c..19070f6f062a0 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -25,9 +25,13 @@ constexpr int kWarpSize = 32;
/// M size of wgmma.mma_async instruction
constexpr int kWgmmaSizeM = 64;
-/// Maximum tensor dimension that TMA supports
-constexpr int kMaxTMATensorDimension = 5;
-/// Maximum any dimension for TMA
+
+/// Maximum TMA tile dimension (tensorRank) must be non-zero and less than or
+/// equal to the maximum supported dimensionality of 5.
+constexpr unsigned kMaxTMATensorDimension = 5;
+/// Maximum TMA tile size (boxDim), which specifies number of elements
+/// to be traversed along each of the kMaxTMATensorDimension (tensorRank)
+/// dimensions, must be non-zero and less than or equal to 256.
constexpr unsigned kMaxTMADimension = 256;
/// Last dimension of 2D+ TMA must be 128 bytes
constexpr unsigned kMaxTMALastdimByte = 128;
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 4507fec530c30..4b6327479a219 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -357,8 +357,9 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
for (auto dim : descMemref.getShape()) {
if (dim <= 0 || dim > kMaxTMADimension) {
- return op->emitError() << "the tensor map descriptor must not have zero "
- "dimension";
+ return op->emitError() << "the tensor map descriptor must have "
+ "dimensions between 1 and "
+ << kMaxTMADimension << " but it is " << dim;
}
}
if (descMemref.getRank() > 1) {
More information about the Mlir-commits
mailing list