[Mlir-commits] [mlir] [mlir][nvgpu] Fix tma descriptor check (PR #152160)
    lonely eagle
    llvmlistbot at llvm.org
    Tue Aug  5 08:31:05 PDT 2025

linuxlonelyeagle wrote:
Here is a working example:
```makefile
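# Assumes MLIR_OPT, MLIR_RUNNER, MLIR_RUNNER_UTILS and MLIR_CUDA_RUNTIME point at
# mlir-opt, mlir-runner, libmlir_runner_utils and libmlir_cuda_runtime from an
# sm_90a-capable MLIR build; the exact paths are build-specific.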
tma-swizzle-run:
	@${MLIR_OPT} tma-swizzle.mlir -lower-affine -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90a cubin-features=+ptx80 opt-level=3" | \
	${MLIR_RUNNER} -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_CUDA_RUNTIME}
```
```mlir
!a_descriptor = !nvgpu.tensormap.descriptor<tensor = memref<64x16xf16, 3>, swizzle = swizzle_32b, l2promo=none, oob=nan, interleave=none>
!a_type = memref<64x64xf16>
!a_smem_type = memref<64x16xf16, #gpu.address_space<workgroup>>
!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func @main() {
  %a_host = memref.alloc() : memref<64x64xf16>
  %f0 = arith.constant 0.0 : f16
  %f1 = arith.constant 1.0 : f16
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c64 = arith.constant 64 : index
  %c128 = arith.constant 128 : index
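  // Fill the 64x64 host matrix with the values 0, 1, 2, ... in row-major order.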
  affine.for %i = 0 to 64 iter_args(%arg = %f0) -> f16 {
    %yield = affine.for %j = 0 to 64 iter_args(%arg_1 = %arg) -> f16 {
      memref.store %arg_1, %a_host[%i, %j] : memref<64x64xf16>
      %iter_arg = arith.addf %arg_1, %f1 : f16
      affine.yield %iter_arg : f16
    }
    affine.yield %yield : f16
  }
  %token_0 = gpu.wait async
  %a_device, %token_1 = gpu.alloc async[%token_0] () : memref<64x64xf16>
  %a_cp = gpu.memcpy async [%token_1] %a_device, %a_host : memref<64x64xf16>, memref<64x64xf16>
  %a_device_unranked = memref.cast %a_device : memref<64x64xf16> to memref<*xf16>
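  // Build the TMA descriptor for the device buffer; the 64x16 box is the tile
  // that each nvgpu.tma.async.load copies into shared memory.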
  %a_device_map = nvgpu.tma.create.descriptor %a_device_unranked box[%c64, %c16] : memref<*xf16> -> !a_descriptor
  // dynamic shared memory size: 64 * 16 * 2 + 16 * 16 * 2 = 2560 bytes
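  // Only the 64x16xf16 A tile (2048 bytes) is used below; the extra 16 * 16 * 2
  // = 512 bytes look like leftover headroom for a B tile from a larger example.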
  %c2560 = arith.constant 2560 : i32
  gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c1, %sz_by = %c1, %sz_bz = %c1) 
            threads(%tx, %ty, %tz) in (%sz_tx = %c128, %sz_ty = %c1, %sz_tz = %c1)
            dynamic_shared_memory_size %c2560 {
    %c0 = arith.constant 0 : index
    %c2048 = arith.constant 2048 : index
    %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
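    // View the 64x16xf16 A tile at byte offset 0 of dynamic shared memory.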
    %a_smem = memref.view %0[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to !a_smem_type
    %thread_size = gpu.block_dim x
    %thread_id = gpu.thread_id x
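    // One mbarrier in workgroup memory, initialized with the block's thread count (128).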
    %mbarrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
    nvgpu.mbarrier.init %mbarrier[%c0], %thread_size : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>> 
    %thread_0 = arith.cmpi eq, %thread_id, %c0 : index
    scf.if %thread_0 {
      gpu.printf "print a matrix:\n"
      affine.for %i = 0 to 64 {
        affine.for %j = 0 to 64 {
          %value = memref.load %a_device[%i, %j] : !a_type
          %value_f32 = arith.extf %value : f16 to f32
          gpu.printf "%4.0f ", %value_f32 : f32
        }
        gpu.printf "\n"
      } 
    }
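    // Thread 0 issues the asynchronous TMA copy of the tile into shared memory
    // and signals the mbarrier to expect 2048 bytes (64 * 16 * 2); all other
    // threads arrive expecting 0 bytes.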
    scf.if %thread_0 {
      nvgpu.tma.async.load %a_device_map[%c0, %c0], %mbarrier[%c0] to %a_smem : !a_descriptor, !barrierType -> !a_smem_type
      nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c2048 : !barrierType
    } else {
      nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c0 : !barrierType
    }
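    // Every thread waits on phase 0 of the mbarrier (with a large timeout)
    // until the TMA transfer has landed in shared memory.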
    %phase_c0 = arith.constant 0 : i1
    %c10000000 = arith.constant 10000000 : index
    nvgpu.mbarrier.try_wait.parity %mbarrier[%c0], %phase_c0, %c10000000 : !barrierType
    scf.if %thread_0 {
      gpu.printf "print a smem:\n"
      affine.for %i = 0 to 64 {
        affine.for %j = 0 to 16 {
          %value = memref.load %a_smem[%i, %j] : !a_smem_type
          %value_f32 = arith.extf %value : f16 to f32
          gpu.printf "%4.0f ", %value_f32 : f32
        }
        gpu.printf "\n"
      } 
    }
    gpu.terminator
  }
  //%dealloc_a = gpu.dealloc async[%launch] %a_device : memref<64x16xf16>
  //%dealloc_b = gpu.dealloc async[%launch] %b_device : memref<16x16xf16>
  //gpu.wait [%dealloc_a, %dealloc_b]
  memref.dealloc %a_host : !a_type
  return
}
```
Result (the first rows of the `print a smem:` output). The 32-byte swizzle is visible: in every other group of four rows, the two 8-element (16-byte) halves of a row are swapped.
```
print a smem:
   0    1    2    3    4    5    6    7    8    9   10   11   12   13   14   15 
  64   65   66   67   68   69   70   71   72   73   74   75   76   77   78   79 
 128  129  130  131  132  133  134  135  136  137  138  139  140  141  142  143 
 192  193  194  195  196  197  198  199  200  201  202  203  204  205  206  207 
 264  265  266  267  268  269  270  271  256  257  258  259  260  261  262  263 
 328  329  330  331  332  333  334  335  320  321  322  323  324  325  326  327 
 392  393  394  395  396  397  398  399  384  385  386  387  388  389  390  391 
 456  457  458  459  460  461  462  463  448  449  450  451  452  453  454  455 
 512  513  514  515  516  517  518  519  520  521  522  523  524  525  526  527 
 576  577  578  579  580  581  582  583  584  585  586  587  588  589  590  591 
 640  641  642  643  644  645  646  647  648  649  650  651  652  653  654  655 
 704  705  706  707  708  709  710  711  712  713  714  715  716  717  718  719 
 776  777  778  779  780  781  782  783  768  769  770  771  772  773  774  775 
 840  841  842  843  844  845  846  847  832  833  834  835  836  837  838  839 
 904  905  906  907  908  909  910  911  896  897  898  899  900  901  902  903 
 968  969  970  971  972  973  974  975  960  961  962  963  964  965  966  967 
1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 
1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 
1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 
1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 
1288 1289 1290 1291 1292 1293 1294 1295 1280 1281 1282 1283 1284 1285 1286 1287 
1352 1353 1354 1355 1356 1357 1358 1359 1344 1345 1346 1347 1348 1349 1350 1351 
1416 1417 1418 1419 1420 1421 1422 1423 1408 1409 1410 1411 1412 1413 1414 1415 
1480 1481 1482 1483 1484 1485 1486 1487 1472 1473 1474 1475 1476 1477 1478 1479 
1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 
1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 
```
https://github.com/llvm/llvm-project/pull/152160
    
    