[Mlir-commits] [mlir] [nvgpu][mlir] Fix wgmma store offset (PR #154581)

lonely eagle llvmlistbot at llvm.org
Wed Aug 20 10:57:04 PDT 2025


linuxlonelyeagle wrote:

error case
```
// NOTE(review): #map and #map1 are not referenced anywhere in this reproducer.
#map = affine_map<()[s0] -> (s0 mod 8)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 8)>

// TMA descriptor types for the A tile (128x64xf16) and the B tile (64x64xf16),
// both targeting memory space 3 with swizzle_128b.
!a_descriptor = !nvgpu.tensormap.descriptor<tensor = memref<128x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
!b_descriptor = !nvgpu.tensormap.descriptor<tensor = memref<64x64xf16, 3>, swizzle = swizzle_128b, l2promo=none, oob=zero, interleave=none>
// Plain memref types used for the host allocations (memref.alloc) and the
// device buffers (gpu.alloc) of A, B and the host-side reference C.
!a_type = memref<128x64xf16>
!b_type = memref<64x64xf16>
!c_type = memref<128x64xf32>
// Workgroup (shared) memory views carved out of the kernel's dynamic shared
// memory. A is 16384 bytes, B is 8192 bytes.
!a_smem_type = memref<128x64xf16, #gpu.address_space<workgroup>>
!b_smem_type = memref<64x64xf16, #gpu.address_space<workgroup>>
// NOTE(review): !c_smem_type is not referenced in this reproducer; the kernel
// stores into a 128x64 subview of !c_smem_type_double instead.
!c_smem_type = memref<128x64xf32, #gpu.address_space<workgroup>>
// 128x128xf32 = 65536 bytes; twice as wide as the MMA result.
!c_smem_type_double = memref<128x128xf32, #gpu.address_space<workgroup>>
!barrierType = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>

// External declaration (no body in this file), called from @main to print the
// host reference result. Presumably provided by the MLIR runner-utils library
// at link time; confirm against the run pipeline.
func.func private @printMemrefF32(memref<*xf32>)

// Reproducer for the nvgpu.warpgroup.mma.store offset problem
// (https://github.com/llvm/llvm-project/pull/154581).
//
// Host: fills A (!a_type) and B (!b_type) with a running f16 sequence that
// grows by 0.01 per element in row-major order, uploads both to the device
// and builds TMA descriptors for them.
// Device (1 block x 128 threads): TMA-loads A and B into dynamic shared
// memory, runs one nvgpu.warpgroup.mma with {transposeB}, and stores the
// 128x64xf32 accumulator through a subview with row stride 128. That places
// the result in columns [0, 64) of a 128x128xf32 shared buffer whose
// contents were zeroed first. Thread 0 prints the inputs and the full
// 128x128 buffer so the store layout can be inspected.
// Host again: computes a reference C[m][n] = sum_k A[m][k] * B[k][n] in f32
// and prints it with @printMemrefF32.
func.func @main() {
  %a_host = memref.alloc() : !a_type
  %b_host = memref.alloc() : !b_type
  %f0 = arith.constant 0.0 : f16
  %f001 = arith.constant 0.01 : f16
  %c1 = arith.constant 1 : index
  %c64 = arith.constant 64 : index
  %c128 = arith.constant 128 : index

  // A[i][j]: running value, +0.01 per element in row-major order.
  // The accumulation is in f16, so each step is subject to f16 rounding.
  affine.for %i = 0 to 128 iter_args(%arg = %f0) -> f16 {
    %y = affine.for %j = 0 to 64 iter_args(%arg_1 = %arg) -> f16 {
      memref.store %arg_1, %a_host[%i, %j] : !a_type
      %iter_arg = arith.addf %arg_1, %f001 : f16
      affine.yield %iter_arg : f16
    }
    affine.yield %y : f16
  }

  // B[i][j]: same scheme, restarting from 0.
  affine.for %i = 0 to 64 iter_args(%arg = %f0) -> f16 {
    %y = affine.for %j = 0 to 64 iter_args(%arg_1 = %arg) -> f16 {
      memref.store %arg_1, %b_host[%i, %j] : !b_type
      %iter_arg = arith.addf %arg_1, %f001 : f16
      affine.yield %iter_arg : f16
    }
    affine.yield %y : f16
  }

  // Allocate the device buffers and upload A and B. The async tokens are
  // chained linearly, and the host blocks on the last one. Both copies are
  // therefore complete before the (non-async) launch below reads the buffers.
  %token_0 = gpu.wait async
  %a_device, %token_1 = gpu.alloc async [%token_0] () : !a_type
  %b_device, %token_2 = gpu.alloc async [%token_1] () : !b_type
  %a_cp = gpu.memcpy async [%token_2] %a_device, %a_host : !a_type, !a_type
  %b_cp = gpu.memcpy async [%a_cp] %b_device, %b_host : !b_type, !b_type
  gpu.wait [%b_cp]

  %a_device_unranked = memref.cast %a_device : !a_type to memref<*xf16>
  %b_device_unranked = memref.cast %b_device : !b_type to memref<*xf16>
  %a_device_map = nvgpu.tma.create.descriptor %a_device_unranked box[%c128, %c64] : memref<*xf16> -> !a_descriptor
  %b_device_map = nvgpu.tma.create.descriptor %b_device_unranked box[%c64, %c64] : memref<*xf16> -> !b_descriptor

  // Dynamic shared memory layout (bytes):
  //   A   128x64xf16  at offset 0     : 16384
  //   B    64x64xf16  at offset 16384 :  8192   (A + B = 24576)
  //   C  128x128xf32  at offset 0     : 65536   (aliases A and B)
  // The kernel therefore needs 65536 bytes of dynamic shared memory.
  %smem_size = arith.constant 65536 : i32

  gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c1, %sz_by = %c1, %sz_bz = %c1)
             threads(%tx, %ty, %tz) in (%sz_tx = %c128, %sz_ty = %c1, %sz_tz = %c1)
             dynamic_shared_memory_size %smem_size {
    %c0 = arith.constant 0 : index
    %c16384 = arith.constant 16384 : index
    // Total bytes delivered by the two TMA loads (16384 for A + 8192 for B).
    %c24576 = arith.constant 24576 : index
    %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
    %a_smem = memref.view %0[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to !a_smem_type
    %b_smem = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to !b_smem_type
    // C reuses the bytes of A and B. It is only written after the MMA has
    // consumed A and B and the block has synchronized (barrier below).
    %c_smem = memref.view %0[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to !c_smem_type_double

    %thread_size = gpu.block_dim x
    %thread_id = gpu.thread_id x
    // One mbarrier expecting one arrival per thread of the block.
    %mbarrier = nvgpu.mbarrier.create -> !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
    nvgpu.mbarrier.init %mbarrier[%c0], %thread_size : !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
    %thread_0 = arith.cmpi eq, %thread_id, %c0 : index
    scf.if %thread_0 {
      gpu.printf "print a matrix:\n"
      affine.for %i = 0 to 128 {
        affine.for %j = 0 to 64 {
          %value = memref.load %a_device[%i, %j] : !a_type
          %value_f32 = arith.extf %value : f16 to f32
          gpu.printf "%4.2f ", %value_f32 : f32
        }
        gpu.printf "\n"
      }
    }

    scf.if %thread_0 {
      gpu.printf "print b matrix:\n"
      affine.for %i = 0 to 64 {
        affine.for %j = 0 to 64 {
          %value = memref.load %b_device[%i, %j] : !b_type
          %value_f32 = arith.extf %value : f16 to f32
          gpu.printf "%4.2f ", %value_f32 : f32
        }
        gpu.printf "\n"
      }
    }

    // Thread 0 issues both TMA loads and arrives once, expecting exactly the
    // number of bytes they deliver. Every other thread arrives once with 0.
    // The barrier thus sees exactly %thread_size arrivals for this phase.
    scf.if %thread_0 {
      nvgpu.tma.async.load %a_device_map[%c0, %c0], %mbarrier[%c0] to %a_smem : !a_descriptor, !barrierType -> !a_smem_type
      nvgpu.tma.async.load %b_device_map[%c0, %c0], %mbarrier[%c0] to %b_smem : !b_descriptor, !barrierType -> !b_smem_type
      nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c24576 : !barrierType
    } else {
      nvgpu.mbarrier.arrive.expect_tx %mbarrier[%c0], %c0 : !barrierType
    }

    %phase_c0 = arith.constant 0 : i1
    %c10000000 = arith.constant 10000000 : index
    nvgpu.mbarrier.try_wait.parity %mbarrier[%c0], %phase_c0, %c10000000 : !barrierType
    scf.if %thread_0 {
      gpu.printf "print a smem:\n"
      affine.for %i = 0 to 128 {
        affine.for %j = 0 to 64 {
          %value = memref.load %a_smem[%i, %j] : !a_smem_type
          %value_f32 = arith.extf %value : f16 to f32
          gpu.printf "%4.2f ", %value_f32 : f32
        }
        gpu.printf "\n"
      }
    }

    scf.if %thread_0 {
      gpu.printf "print b smem:\n"
      affine.for %i = 0 to 64 {
        affine.for %j = 0 to 64 {
          %value = memref.load %b_smem[%i, %j] : !b_smem_type
          %value_f32 = arith.extf %value : f16 to f32
          gpu.printf "%4.2f ", %value_f32 : f32
        }
        gpu.printf "\n"
      }
    }

    %acc = nvgpu.warpgroup.mma.init.accumulator -> <fragmented =vector<128x64xf32>>
    %dA = nvgpu.warpgroup.generate.descriptor %a_smem, %a_device_map : !a_smem_type, !a_descriptor -> !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, #gpu.address_space<workgroup>>>
    %dB = nvgpu.warpgroup.generate.descriptor %b_smem, %b_device_map : !b_smem_type, !b_descriptor -> !nvgpu.warpgroup.descriptor<tensor=memref<64x64xf16, #gpu.address_space<workgroup>>>
    %md  = nvgpu.warpgroup.mma %dA, %dB, %acc {transposeB}: !nvgpu.warpgroup.descriptor<tensor=memref<128x64xf16, #gpu.address_space<workgroup>>>, !nvgpu.warpgroup.descriptor<tensor=memref<64x64xf16, #gpu.address_space<workgroup>>> , <fragmented =vector<128x64xf32>> -> <fragmented =vector<128x64xf32>>
    nvvm.wgmma.wait.group.sync.aligned 0
    // Synchronize the whole block so that no warp can still be reading A/B
    // when thread 0 starts overwriting the aliased C buffer below.
    gpu.barrier
    // Columns [0, 64) of the 128x128 buffer, row stride 128.
    %subview = memref.subview %c_smem [0, 0][128, 64][1, 1] : !c_smem_type_double to memref<128x64xf32, strided<[128, 1]>, #gpu.address_space<workgroup>>
    scf.if %thread_0 {
      %f0_f32 = arith.constant 0.0 : f32
      affine.for %i = 0 to 128 {
        affine.for %j = 0 to 128 {
          memref.store %f0_f32, %c_smem[%i, %j] : !c_smem_type_double
        }
      }
    }
    gpu.barrier
    nvgpu.warpgroup.mma.store %md, %subview : <fragmented =vector<128x64xf32>> to memref<128x64xf32, strided<[128, 1]>, #gpu.address_space<workgroup>>
    gpu.barrier
    scf.if %thread_0 {
      gpu.printf "print result:\n"
      affine.for %i = 0 to 128 {
        affine.for %j = 0 to 128 {
          %value = memref.load %c_smem[%i, %j] : !c_smem_type_double
          gpu.printf "%4.4f ", %value : f32
        }
        gpu.printf "\n"
      }
    }

    gpu.terminator
  }

  // Host reference: C[m][n] = sum_k A[m][k] * B[k][n], accumulated in f32.
  // memref.alloc leaves its contents uninitialized and the k-loop
  // accumulates into c_host, so zero it first.
  // NOTE(review): the device MMA uses {transposeB}, while this reference
  // indexes B as [k, n]. Confirm that both use the same convention.
  %c_host = memref.alloc() : !c_type
  %zero_f32 = arith.constant 0.0 : f32
  affine.for %m = 0 to 128 {
    affine.for %n = 0 to 64 {
      memref.store %zero_f32, %c_host[%m, %n] : !c_type
    }
  }
  affine.for %m = 0 to 128 {
    affine.for %n = 0 to 64 {
      affine.for %k = 0 to 64 {
        %a = memref.load %a_host[%m, %k] : !a_type
        %b = memref.load %b_host[%k, %n] : !b_type
        %c = memref.load %c_host[%m, %n] : !c_type
        %a_f32 = arith.extf %a : f16 to f32
        %b_f32 = arith.extf %b : f16 to f32
        %mul = arith.mulf %a_f32, %b_f32 : f32
        %add = arith.addf %mul, %c : f32
        memref.store %add, %c_host[%m, %n] : !c_type
      }
    }
  }
  %c_cast = memref.cast %c_host : !c_type to memref<*xf32>
  call @printMemrefF32(%c_cast) : (memref<*xf32>) -> ()

  // Release the device buffers (the launch above is blocking, so the kernel
  // has finished with them) and then the host buffers.
  %free_0 = gpu.wait async
  %free_1 = gpu.dealloc async [%free_0] %a_device : !a_type
  %free_2 = gpu.dealloc async [%free_1] %b_device : !b_type
  gpu.wait [%free_2]
  memref.dealloc %c_host : !c_type
  memref.dealloc %b_host : !b_type
  memref.dealloc %a_host : !a_type
  return
}
```

https://github.com/llvm/llvm-project/pull/154581


More information about the Mlir-commits mailing list