[Mlir-commits] [mlir] [nvgpu][mlir] Fix wgmma store offset (PR #154581)
lonely eagle
llvmlistbot at llvm.org
Wed Aug 20 10:57:04 PDT 2025
linuxlonelyeagle wrote:
Error case (reproducer for the wgmma store offset issue):
```
#map = affine_map<()[s0] -> (s0 mod 8)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 8)>
!a_descriptor = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
!b_descriptor = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
!a_type = memref<128x64xf16>
!b_type = memref<64x64xf16>
!c_type = memref<128x64xf32>
!a_smem_type = memref<128x64xf16, #gpu.address_space<workgroup>>
!b_smem_type = memref<64x64xf16, #gpu.address_space<workgroup>>
!c_smem_type = memref<128x64xf32, #gpu.address_space<workgroup>>
!c_smem_type_double = memref<128x128xf32, #gpu.address_space<workgroup>>
!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
func.func private @printMemrefF32(memref<*xf32>)
func.func @main() {
%a_host = memref.alloc() : !a_type
%b_host = memref.alloc() : !b_type
%f0 = arith.constant 0.0 : f16
%f1 = arith.constant 1.0 : f16
%f001 = arith.constant 0.01 : f16
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%c8 = arith.constant 8 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%f100 = arith.constant 100.0 : f16
affine.for %i = 0 to 128 iter_args(%arg = %f0) -> f16 {
%y = affine.for %j = 0 to 64 iter_args(%arg_1 = %arg) -> f16 {
//%div = arith.divf %arg_1, %f100 : f16
memref.store %arg_1, %a_host[%i, %j] : !a_type
%iter_arg = arith.addf %arg_1, %f001 : f16
affine.yield %iter_arg : f16
}
affine.yield %y : f16
}
affine.for %i = 0 to 64 iter_args(%arg = %f0) -> f16 {
%y = affine.for %j = 0 to 64 iter_args(%arg_1 = %arg) -> f16 {
//%div = arith.divf %arg_1, %f100 : f16
memref.store %arg_1, %b_host[%i, %j] : !b_type
%iter_arg = arith.addf %arg_1, %f001 : f16
affine.yield %iter_arg : f16
}
affine.yield %y : f16
}
%token_0 = gpu.wait async
%a_device, %token_1 = gpu.alloc async[%token_0] () : !a_type
%b_device, %token_2 = gpu.alloc async[%token_0] () : !b_type
%a_cp = gpu.memcpy async [%token_1] %a_device, %a_host : !a_type, !a_type
%b_cp = gpu.memcpy async [%token_2] %b_device, %b_host : !b_type, !b_type
%a_device_unranked = memref.cast %a_device : !a_type to memref<*xf16>
%b_device_unranked = memref.cast %b_device : !b_type to memref<*xf16>
%a_device_map = nvgpu.tma.create.descriptor %a_device_unranked box[%c128, %c64] : memref<*xf16> -> !a_descriptor
%b_device_map = nvgpu.tma.create.descriptor %b_device_unranked box[%c64, %c64] : memref<*xf16> -> !b_descriptor
// a smem (128x64xf16) = 16384 bytes, b smem (64x64xf16) = 8192 bytes -> 24576 bytes.
// The 128x128xf32 c smem view reuses the same space starting at byte 0, so the
// dynamic shared memory size must cover its full 65536 bytes.
%smem_size = arith.constant 65536 : i32
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c1, %sz_by = %c1, %sz_bz = %c1)
threads(%tx, %ty, %tz) in (%sz_tx = %c128, %sz_ty = %c1, %sz_tz = %c1)
dynamic_shared_memory_size %smem_size {
%c16384 = arith.constant 16384 : index
%c30000 = arith.constant 30000 : index
%c20000 = arith.constant 20000 : index
%c24576 = arith.constant 24576 : index
%c8192 = arith.constant 8192 : index
%c4608 = arith.constant 4608 : index
%c6656 = arith.constant 6656 : index
%c4096 = arith.constant 4096 : index
%c2048 = arith.constant 2048 : index
%c512 = arith.constant 512 : index
%c256 = arith.constant 256 : index
%c0 = arith.constant 0 : index
%0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
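// Carve the A tile (bytes [0, 16384)) and B tile (bytes [16384, 24576)) out of dynamic
// shared memory; the 128x128xf32 C view starts back at byte 0 and is only written after
// the wgmma has consumed A and B.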
%a_smem = memref.view %0[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to !a_smem_type
%b_smem = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to !b_smem_type
%c_smem = memref.view %0[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to !c_smem_type_double
%thread_size = gpu.block_dim x
%thread_id = gpu.thread_id x
%mbarrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
nvgpu.mbarrier.init %mbarrier[%c0], %thread_size : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
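// One mbarrier, initialized with an arrival count equal to the block size (128 threads).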
%thread_0 = arith.cmpi eq, %thread_id, %c0 : index
scf.if %thread_0 {
gpu.printf "print a matrix:\n"
affine.for %i = 0 to 128 {
affine.for %j = 0 to 64 {
%value = memref.load %a_device[%i, %j] : !a_type
%value_f32 = arith.extf %value : f16 to f32
gpu.printf "%4.2f ", %value_f32 : f32
}
gpu.printf "\n"
}
}
scf.if %thread_0 {
gpu.printf "print b matrix:\n"
affine.for %i = 0 to 64 {
affine.for %j = 0 to 64 {
%value = memref.load %b_device[%i, %j] : !b_type
%value_f32 = arith.extf %value : f16 to f32
gpu.printf "%4.2f ", %value_f32 : f32
}
gpu.printf "\n"
}
}
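// Only thread 0 issues the TMA loads and registers the expected transaction bytes on the
// mbarrier; every other thread arrives with an expected count of 0.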
scf.if %thread_0 {
nvgpu.tma.async.load %a_device_map[%c0, %c0], %mbarrier[%c0] to %a_smem : !a_descriptor, !barrierType -> !a_smem_type
nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c16384 : !barrierType
} else {
nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c0 : !barrierType
}
scf.if %thread_0 {
nvgpu.tma.async.load %b_device_map[%c0, %c0], %mbarrier[%c0] to %b_smem : !b_descriptor, !barrierType -> !b_smem_type
nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c8192 : !barrierType
} else {
nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c0 : !barrierType
}
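// All threads wait until both TMA transfers (16384 + 8192 = 24576 bytes) have arrived.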
%phase_c0 = arith.constant 0 : i1
%c10000000 = arith.constant 10000000 : index
nvgpu.mbarrier.try_wait.parity %mbarrier[%c0], %phase_c0, %c10000000 : !barrierType
scf.if %thread_0 {
gpu.printf "print a smem:\n"
affine.for %i = 0 to 128 {
affine.for %j = 0 to 64 {
%value = memref.load %a_smem[%i, %j] : !a_smem_type
%value_f32 = arith.extf %value : f16 to f32
gpu.printf "%4.2f ", %value_f32 : f32
}
gpu.printf "\n"
}
}
scf.if %thread_0 {
gpu.printf "print b smem:\n"
affine.for %i = 0 to 64 {
affine.for %j = 0 to 64 {
%value = memref.load %b_smem[%i, %j] : !b_smem_type
%value_f32 = arith.extf %value : f16 to f32
gpu.printf "%4.2f ", %value_f32 : f32
}
gpu.printf "\n"
}
}
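// Build wgmma descriptors for the A and B tiles in shared memory and issue the
// 128x64 (K = 64) warpgroup MMA with B transposed.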
%acc = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x64xf32>>
%dA = nvgpu.warpgroup.generate.descriptor %a_smem, %a_device_map : !a_smem_type, !a_descriptor -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, #gpu.address_space<workgroup>>>
%dB = nvgpu.warpgroup.generate.descriptor %b_smem, %b_device_map : !b_smem_type, !b_descriptor -> !nvgpu.warpgroup.descriptor<tensor=memref<64x64xf16, #gpu.address_space<workgroup>>>
%md = nvgpu.warpgroup.mma %dA, %dB, %acc {transposeB} : !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, #gpu.address_space<workgroup>>>, !nvgpu.warpgroup.descriptor<tensor=memref<64x64xf16, #gpu.address_space<workgroup>>>, <fragmented = vector<128x64xf32>> -> <fragmented = vector<128x64xf32>>
nvvm.wgmma.wait.group.sync.aligned 0
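// Store the 128x64 accumulator into the left 128x64 tile of the 128x128 C buffer.
// The subview's row stride is 128 elements, so element (i, j) belongs at linear
// offset i * 128 + j of the underlying buffer.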
%subview = memref.subview %c_smem [0, 0][128, 64][1, 1] : !c_smem_type_double to memref<128x64xf32, strided<[128, 1]>, #gpu.address_space<workgroup>>
scf.if %thread_0 {
%f0_f32 = arith.constant 0.0 : f32
affine.for %i = 0 to 128 {
affine.for %j = 0 to 128 {
memref.store %f0_f32, %c_smem[%i, %j] : !c_smem_type_double
}
}
}
gpu.barrier
nvgpu.warpgroup.mma.store %md, %subview : <fragmented =vector<128x64xf32>> to memref<128x64xf32, strided<[128, 1]>, #gpu.address_space<workgroup>>
gpu.barrier
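// Dump the full 128x128 C buffer so that any values written outside the 128x64 subview
// are visible.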
scf.if %thread_0 {
gpu.printf "print result:\n"
affine.for %i = 0 to 128 {
affine.for %j = 0 to 128 {
%value = memref.load %c_smem[%i, %j] : !c_smem_type_double
gpu.printf "%4.4f ", %value : f32
}
gpu.printf "\n"
}
}
gpu.terminator
}
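// Host reference: zero-initialize c_host, then accumulate C = A x B in f32 for
// comparison with the device result printed above.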
%c_host = memref.alloc() : !c_type
%cf0 = arith.constant 0.0 : f32
affine.for %i = 0 to 128 {
affine.for %j = 0 to 64 {
memref.store %cf0, %c_host[%i, %j] : !c_type
}
}
affine.for %m = 0 to 128 {
affine.for %n = 0 to 64 {
affine.for %k = 0 to 64 {
%a = memref.load %a_host[%m, %k] : !a_type
%b = memref.load %b_host[%k, %n] : !b_type
%c = memref.load %c_host[%m, %n] : !c_type
%a_f32 = arith.extf %a : f16 to f32
%b_f32 = arith.extf %b : f16 to f32
%mul = arith.mulf %a_f32, %b_f32 : f32
%add = arith.addf %mul, %c : f32
memref.store %add, %c_host[%m, %n] : !c_type
}
}
}
%c_cast = memref.cast %c_host : !c_type to memref<*xf32>
call @printMemrefF32(%c_cast) : (memref<*xf32>) -> ()
return
}
```
https://github.com/llvm/llvm-project/pull/154581