[Mlir-commits] [mlir] [mlir][nvgpu] Fix tma descriptor check (PR #152160)
lonely eagle
llvmlistbot at llvm.org
Tue Aug 5 08:31:05 PDT 2025
linuxlonelyeagle wrote:
I've put together a working example (Makefile target, MLIR source, and the resulting output below).
```makefile
tma-swizzle-run:
	@${MLIR_OPT} tma-swizzle.mlir -lower-affine -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" | \
	${MLIR_RUNNER} -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME}
```
```mlir
!a_descriptor = !nvgpu.tensormap.descriptor<tensor = memref<64x16xf16, 3>, swizzle = swizzle_32b, l2promo=none, oob=nan, interleave=none>
!a_type = memref<64x64xf16>
!a_smem_type = memref<64x16xf16, #gpu.address_space<workgroup>>
!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @main() {
  %a_host = memref.alloc() : memref<64x64xf16>
  %f0 = arith.constant 0.0 : f16
  %f1 = arith.constant 1.0 : f16
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c64 = arith.constant 64 : index
  %c128 = arith.constant 128 : index
  // Fill the 64x64 host matrix with 0, 1, 2, ... in row-major order.
  affine.for %i = 0 to 64 iter_args(%arg = %f0) -> f16 {
    %yield = affine.for %j = 0 to 64 iter_args(%arg_1 = %arg) -> f16 {
      memref.store %arg_1, %a_host[%i, %j] : memref<64x64xf16>
      %iter_arg = arith.addf %arg_1, %f1 : f16
      affine.yield %iter_arg : f16
    }
    affine.yield %yield : f16
  }
  %token_0 = gpu.wait async
  %a_device, %token_1 = gpu.alloc async [%token_0] () : memref<64x64xf16>
  %a_cp = gpu.memcpy async [%token_1] %a_device, %a_host : memref<64x64xf16>, memref<64x64xf16>
  %a_device_unranked = memref.cast %a_device : memref<64x64xf16> to memref<*xf16>
  // TMA descriptor over the device matrix with a 64x16 box and 32B swizzle.
  %a_device_map = nvgpu.tma.create.descriptor %a_device_unranked box[%c64, %c16] : memref<*xf16> -> !a_descriptor
  // Dynamic shared memory size: 64 * 16 * 2 + 16 * 16 * 2 = 2560 bytes
  // (only the first 2048 bytes, the 64x16 f16 tile, are used here).
  %c2560 = arith.constant 2560 : i32
  gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c1, %sz_by = %c1, %sz_bz = %c1)
             threads(%tx, %ty, %tz) in (%sz_tx = %c128, %sz_ty = %c1, %sz_tz = %c1)
             dynamic_shared_memory_size %c2560 {
    %c0 = arith.constant 0 : index
    %c2048 = arith.constant 2048 : index
    %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
    %a_smem = memref.view %0[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to !a_smem_type
    %thread_size = gpu.block_dim x
    %thread_id = gpu.thread_id x
    %mbarrier = nvgpu.mbarrier.create -> !barrierType
    nvgpu.mbarrier.init %mbarrier[%c0], %thread_size : !barrierType
    %thread_0 = arith.cmpi eq, %thread_id, %c0 : index
    // Thread 0 prints the source matrix from global memory.
    scf.if %thread_0 {
      gpu.printf "print a matrix:\n"
      affine.for %i = 0 to 64 {
        affine.for %j = 0 to 64 {
          %value = memref.load %a_device[%i, %j] : !a_type
          %value_f32 = arith.extf %value : f16 to f32
          gpu.printf "%4.0f ", %value_f32 : f32
        }
        gpu.printf "\n"
      }
    }
    // Thread 0 issues the TMA load of the 64x16 tile (2048 bytes) and arrives
    // on the mbarrier with the matching expected transaction count.
    scf.if %thread_0 {
      nvgpu.tma.async.load %a_device_map[%c0, %c0], %mbarrier[%c0] to %a_smem : !a_descriptor, !barrierType -> !a_smem_type
      nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c2048 : !barrierType
    } else {
      nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c0 : !barrierType
    }
    %phase_c0 = arith.constant 0 : i1
    %c10000000 = arith.constant 10000000 : index
    nvgpu.mbarrier.try_wait.parity %mbarrier[%c0], %phase_c0, %c10000000 : !barrierType
    // Thread 0 prints the tile as it landed (swizzled) in shared memory.
    scf.if %thread_0 {
      gpu.printf "print a smem:\n"
      affine.for %i = 0 to 64 {
        affine.for %j = 0 to 16 {
          %value = memref.load %a_smem[%i, %j] : !a_smem_type
          %value_f32 = arith.extf %value : f16 to f32
          gpu.printf "%4.0f ", %value_f32 : f32
        }
        gpu.printf "\n"
      }
    }
    gpu.terminator
  }
  gpu.dealloc %a_device : !a_type
  memref.dealloc %a_host : !a_type
  return
}
```
Result (truncated to the first rows of the shared-memory dump):
```
print a smem:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
264 265 266 267 268 269 270 271 256 257 258 259 260 261 262 263
328 329 330 331 332 333 334 335 320 321 322 323 324 325 326 327
392 393 394 395 396 397 398 399 384 385 386 387 388 389 390 391
456 457 458 459 460 461 462 463 448 449 450 451 452 453 454 455
512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527
576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591
640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655
704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719
776 777 778 779 780 781 782 783 768 769 770 771 772 773 774 775
840 841 842 843 844 845 846 847 832 833 834 835 836 837 838 839
904 905 906 907 908 909 910 911 896 897 898 899 900 901 902 903
968 969 970 971 972 973 974 975 960 961 962 963 964 965 966 967
1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039
1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103
1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167
1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231
1288 1289 1290 1291 1292 1293 1294 1295 1280 1281 1282 1283 1284 1285 1286 1287
1352 1353 1354 1355 1356 1357 1358 1359 1344 1345 1346 1347 1348 1349 1350 1351
1416 1417 1418 1419 1420 1421 1422 1423 1408 1409 1410 1411 1412 1413 1414 1415
1480 1481 1482 1483 1484 1485 1486 1487 1472 1473 1474 1475 1476 1477 1478 1479
1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551
1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615
```
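The swapped halves in the dump are what I'd expect from `swizzle_32b`. Here is a minimal sketch (not part of the PR, and assuming the usual 32-byte swizzle rule that XORs bit 7 of the shared-memory byte offset, the 128-byte block parity, into bit 4, the 16-byte chunk index) that reproduces the row pattern above:

```python
# Sketch: expected smem contents for a 64x16 f16 tile loaded with a 32B swizzle,
# assuming the swizzle moves the 16-byte chunk at byte offset `off` to
# `off ^ (((off >> 7) & 1) << 4)`.

ROWS, COLS, ELT_BYTES = 64, 16, 2   # 64x16 f16 tile, 32 bytes per row
CHUNK = 16                          # swizzle granularity in bytes

def swizzle_32b(byte_off: int) -> int:
    # XOR bit 7 (128-byte block parity) into bit 4 (16-byte chunk index).
    return byte_off ^ (((byte_off >> 7) & 1) << 4)

smem = [0] * (ROWS * COLS)          # shared-memory image, in f16 elements
for r in range(ROWS):
    for c in range(COLS):
        linear = (r * COLS + c) * ELT_BYTES          # pre-swizzle byte offset
        chunk_base = (linear // CHUNK) * CHUNK
        dst = swizzle_32b(chunk_base) + (linear - chunk_base)
        smem[dst // ELT_BYTES] = r * 64 + c          # value stored at a_host[r][c]

for r in range(8):                  # compare with the first rows of the dump
    print(" ".join(f"{smem[r * COLS + c]:4d}" for c in range(COLS)))
```

Rows 0-3 land unswizzled, rows 4-7 have their two 16-byte (8 x f16) halves swapped, and the pattern repeats every 8 rows, which is exactly what the dump shows.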
https://github.com/llvm/llvm-project/pull/152160