[Mlir-commits] [mlir] [mlir][ptr] Add int_to_ptr && ptr_to_int ops (PR #190527)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 5 22:14:09 PDT 2026


enjustli wrote:

For example, consider this case: https://github.com/triton-lang/triton/blob/f1ff6575ef2b7b24de60afa74a7c9ed3e9a48264/python/tutorials/08-grouped-gemm.py#L270

Although this Python file is only a tutorial, the operations it uses are real and fully functional.

```python

@triton.autotune(
    tma_configs,
    key=['group_size'],
)
@triton.jit
def grouped_matmul_tma_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # is the output FP8 or FP16
    FP8: tl.constexpr,
):
    dtype = tl.float8e4nv if FP8 else tl.float16
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)

            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))

            a_desc = tl.make_tensor_descriptor(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
            )

            b_desc = tl.make_tensor_descriptor(
                b_ptr,
                shape=[gn, gk],
                strides=[ldb, 1],
                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
            )
            c_desc = tl.make_tensor_descriptor(
                c_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            )

            # iterate through the tiles in the current gemm problem
            while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
                k = gk
                # figure out tile coordinates
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx = tile_idx_in_gemm // num_n_tiles
                tile_n_idx = tile_idx_in_gemm % num_n_tiles

                # do regular gemm here
                offs_am = tile_m_idx * BLOCK_SIZE_M
                offs_bn = tile_n_idx * BLOCK_SIZE_N

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                    a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
                    b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
                    accumulator += tl.dot(a, b.T)

                offs_cm = tile_m_idx * BLOCK_SIZE_M
                offs_cn = tile_n_idx * BLOCK_SIZE_N

                c = accumulator.to(dtype)
                c_desc.store([offs_cm, offs_cn], c)

                # go to the next tile by advancing NUM_SM
                tile_idx += NUM_SM

        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles

```
In the code above, the relevant lines are:
```python
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))
```

After running through some passes, this code generates the following IR:
```
      %a_ptr = tt.addptr %group_a_ptrs, %g : !tt.ptr<i64>, i32
      %a_ptr_6 = tt.load %a_ptr : !tt.ptr<i64>
      %a_ptr_7 = tt.int_to_ptr %a_ptr_6 : i64 -> !tt.ptr<f16>
      %b_ptr = tt.addptr %group_b_ptrs, %g : !tt.ptr<i64>, i32
      %b_ptr_8 = tt.load %b_ptr : !tt.ptr<i64>
      %b_ptr_9 = tt.int_to_ptr %b_ptr_8 : i64 -> !tt.ptr<f16>
      %c_ptr = tt.addptr %group_c_ptrs, %g : !tt.ptr<i64>, i32
      %c_ptr_10 = tt.load %c_ptr : !tt.ptr<i64>
      %c_ptr_11 = tt.int_to_ptr %c_ptr_10 : i64 -> !tt.ptr<f16>
```

The full MLIR file is available at: https://github.com/triton-lang/triton/blob/f1ff6575ef2b7b24de60afa74a7c9ed3e9a48264/test/TritonGPU/automatic-warp-specialization.mlir#L354
```mlir

// CHECK-LABEL: @grouped_matmul_tma_kernel
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @grouped_matmul_tma_kernel(%group_a_ptrs: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %group_b_ptrs: !tt.ptr<i64> {tt.divisibility = 16 : i32} , %group_c_ptrs: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %gm: i32 {tt.divisibility = 16 : i32}, %gn: i32 {tt.divisibility = 16 : i32}, %gk: i32 {tt.divisibility = 16 : i32}, %group_size: i32) attributes {noinline = false} {
    %false = arith.constant false
    %true = arith.constant true
    %c1_i32 = arith.constant 1 : i32
    %c3_i32 = arith.constant 3 : i32
    %c2_i32 = arith.constant 2 : i32
    %c1_i64 = arith.constant 1 : i64
    %c128_i32 = arith.constant 128 : i32
    %c64_i32 = arith.constant 64 : i32
    %c4_i32 = arith.constant 4 : i32
    %c0_i32 = arith.constant 0 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %num_m_tiles_0 = arith.divsi %gm, %c128_i32 : i32
    %num_n_tiles_1 = arith.divsi %gn, %c128_i32 : i32
    %num_tiles = arith.muli %num_m_tiles_0, %num_n_tiles_1 : i32
    %start_pid = tt.get_program_id x : i32
    %1 = arith.divsi %gk, %c64_i32 : i32
    %stride = arith.constant 1024 : i64
    // CHECK: ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>
    // CHECK: default
    // CHECK: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32}
    // CHECK: scf.for
    // CHECK: ttng.tensormap_create
    // CHECK: scf.for
    // CHECK: partition0
    // CHECK: partition1
    // CHECK: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32}
    // CHECK: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32}
    // CHECK: scf.for
    // CHECK: ttng.tensormap_create
    // CHECK: ttng.tensormap_create
    // CHECK: scf.for
    // CHECK: scf.for
    scf.for %g = %c0_i32 to %group_size step %c1_i32  : i32 {
      %a_ptr = tt.addptr %group_a_ptrs, %g : !tt.ptr<i64>, i32
      %a_ptr_6 = tt.load %a_ptr : !tt.ptr<i64>
      %a_ptr_7 = tt.int_to_ptr %a_ptr_6 : i64 -> !tt.ptr<f16>
      %b_ptr = tt.addptr %group_b_ptrs, %g : !tt.ptr<i64>, i32
      %b_ptr_8 = tt.load %b_ptr : !tt.ptr<i64>
      %b_ptr_9 = tt.int_to_ptr %b_ptr_8 : i64 -> !tt.ptr<f16>
      %c_ptr = tt.addptr %group_c_ptrs, %g : !tt.ptr<i64>, i32
      %c_ptr_10 = tt.load %c_ptr : !tt.ptr<i64>
      %c_ptr_11 = tt.int_to_ptr %c_ptr_10 : i64 -> !tt.ptr<f16>
      %a_desc_12 = tt.make_tensor_descriptor %a_ptr_7, [%gm, %gk], [%stride, %c1_i64] : <f16>, <tensor<128x64xf16, #shared>>
      %b_desc_13 = tt.make_tensor_descriptor %b_ptr_9, [%gn, %gk], [%stride, %c1_i64] : <f16>, <tensor<128x64xf16, #shared>>
      %c_desc_14 = tt.make_tensor_descriptor %c_ptr_11, [%gm, %gn], [%stride, %c1_i64] : <f16>, <tensor<128x128xf16, #shared>>
      scf.for %tile_idx = %start_pid to %num_tiles step %c4_i32  : i32 {
        %tile_m_idx = arith.divsi %tile_idx, %num_n_tiles_1 : i32
        %tile_n_idx = arith.remsi %tile_idx, %num_n_tiles_1 : i32
        %offs_am = arith.muli %tile_m_idx, %c128_i32 : i32
        %offs_bn = arith.muli %tile_n_idx, %c128_i32 : i32
        %accumulator, %accumulator_15 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
        %accumulator_16 = ttng.tmem_store %cst, %accumulator[%accumulator_15], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
        %accumulator_17:2 = scf.for %accumulator_20 = %c0_i32 to %1 step %c1_i32 iter_args(%arg11 = %false, %accumulator_21 = %accumulator_16) -> (i1, !ttg.async.token)  : i32 {
          %a = arith.muli %accumulator_20, %c64_i32 : i32
          %a_22 = tt.descriptor_load %a_desc_12[%offs_am, %a] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked1>
          %a_23 = ttg.local_alloc %a_22 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
          %b = tt.descriptor_load %b_desc_13[%offs_bn, %a] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked1>
          %accumulator_24 = ttg.local_alloc %b : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
          %accumulator_25 = ttg.memdesc_trans %accumulator_24 {order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem>
          %accumulator_26 = ttng.tc_gen5_mma %a_23, %accumulator_25, %accumulator[%accumulator_21], %arg11, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
          scf.yield %true, %accumulator_26 : i1, !ttg.async.token
        } {tt.scheduled_max_stage = 2 : i32}
        %accumulator_18, %accumulator_19 = ttng.tmem_load %accumulator[%accumulator_17#1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
        %c = arith.truncf %accumulator_18 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
        %2 = ttg.convert_layout %c : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked2>
        tt.descriptor_store %c_desc_14[%offs_am, %offs_bn], %2 : !tt.tensordesc<tensor<128x128xf16, #shared>>, tensor<128x128xf16, #blocked2>
      }
    } {tt.warp_specialize}
    tt.return
  }
}
```



https://github.com/llvm/llvm-project/pull/190527


More information about the Mlir-commits mailing list