[Mlir-commits] [mlir] [mlir][ptr] Add int_to_ptr && ptr_to_int ops (PR #190527)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Apr 5 22:14:09 PDT 2026
enjustli wrote:
For example, see this case: https://github.com/triton-lang/triton/blob/f1ff6575ef2b7b24de60afa74a7c9ed3e9a48264/python/tutorials/08-grouped-gemm.py#L270
Although this Python file is a tutorial, the operators are real and effective.
```python
@triton.autotune(
tma_configs,
key=['group_size'],
)
@triton.jit
def grouped_matmul_tma_kernel(
# device tensor of matrices pointers
group_a_ptrs,
group_b_ptrs,
group_c_ptrs,
# device tensor of gemm sizes. its shape is [group_size, 3]
# dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
group_gemm_sizes,
# device tensor of leading dimension sizes. its shape is [group_size, 3]
# dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
g_lds,
# number of gemms
group_size,
# number of virtual SM
NUM_SM: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
# is the output FP8 or FP16
FP8: tl.constexpr,
):
dtype = tl.float8e4nv if FP8 else tl.float16
tile_idx = tl.program_id(0)
last_problem_end = 0
for g in range(group_size):
# get the gemm size of the current problem
gm = tl.load(group_gemm_sizes + g * 3)
gn = tl.load(group_gemm_sizes + g * 3 + 1)
gk = tl.load(group_gemm_sizes + g * 3 + 2)
num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles
if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
# pick up a tile from the current gemm problem
lda = tl.load(g_lds + g * 3)
ldb = tl.load(g_lds + g * 3 + 1)
ldc = tl.load(g_lds + g * 3 + 2)
a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))
a_desc = tl.make_tensor_descriptor(
a_ptr,
shape=[gm, gk],
strides=[lda, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
b_desc = tl.make_tensor_descriptor(
b_ptr,
shape=[gn, gk],
strides=[ldb, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
c_desc = tl.make_tensor_descriptor(
c_ptr,
shape=[gm, gn],
strides=[ldc, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
# iterate through the tiles in the current gemm problem
while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
k = gk
# figure out tile coordinates
tile_idx_in_gemm = tile_idx - last_problem_end
tile_m_idx = tile_idx_in_gemm // num_n_tiles
tile_n_idx = tile_idx_in_gemm % num_n_tiles
# do regular gemm here
offs_am = tile_m_idx * BLOCK_SIZE_M
offs_bn = tile_n_idx * BLOCK_SIZE_N
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
accumulator += tl.dot(a, b.T)
offs_cm = tile_m_idx * BLOCK_SIZE_M
offs_cn = tile_n_idx * BLOCK_SIZE_N
c = accumulator.to(dtype)
c_desc.store([offs_cm, offs_cn], c)
# go to the next tile by advancing NUM_SM
tile_idx += NUM_SM
# get ready to go to the next gemm problem
last_problem_end = last_problem_end + num_tiles
```
In the above code,
```python
a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))
```
after some passes, this code generates the following IR:
```
%a_ptr = tt.addptr %group_a_ptrs, %g : !tt.ptr<i64>, i32
%a_ptr_6 = tt.load %a_ptr : !tt.ptr<i64>
%a_ptr_7 = tt.int_to_ptr %a_ptr_6 : i64 -> !tt.ptr<f16>
%b_ptr = tt.addptr %group_b_ptrs, %g : !tt.ptr<i64>, i32
%b_ptr_8 = tt.load %b_ptr : !tt.ptr<i64>
%b_ptr_9 = tt.int_to_ptr %b_ptr_8 : i64 -> !tt.ptr<f16>
%c_ptr = tt.addptr %group_c_ptrs, %g : !tt.ptr<i64>, i32
%c_ptr_10 = tt.load %c_ptr : !tt.ptr<i64>
%c_ptr_11 = tt.int_to_ptr %c_ptr_10 : i64 -> !tt.ptr<f16>
```
The full MLIR file is from: https://github.com/triton-lang/triton/blob/f1ff6575ef2b7b24de60afa74a7c9ed3e9a48264/test/TritonGPU/automatic-warp-specialization.mlir#L354
```mlir
// CHECK-LABEL: @grouped_matmul_tma_kernel
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
#smem = #ttg.shared_memory
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, colStride = 1>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @grouped_matmul_tma_kernel(%group_a_ptrs: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %group_b_ptrs: !tt.ptr<i64> {tt.divisibility = 16 : i32} , %group_c_ptrs: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %gm: i32 {tt.divisibility = 16 : i32}, %gn: i32 {tt.divisibility = 16 : i32}, %gk: i32 {tt.divisibility = 16 : i32}, %group_size: i32) attributes {noinline = false} {
%false = arith.constant false
%true = arith.constant true
%c1_i32 = arith.constant 1 : i32
%c3_i32 = arith.constant 3 : i32
%c2_i32 = arith.constant 2 : i32
%c1_i64 = arith.constant 1 : i64
%c128_i32 = arith.constant 128 : i32
%c64_i32 = arith.constant 64 : i32
%c4_i32 = arith.constant 4 : i32
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%num_m_tiles_0 = arith.divsi %gm, %c128_i32 : i32
%num_n_tiles_1 = arith.divsi %gn, %c128_i32 : i32
%num_tiles = arith.muli %num_m_tiles_0, %num_n_tiles_1 : i32
%start_pid = tt.get_program_id x : i32
%1 = arith.divsi %gk, %c64_i32 : i32
%stride = arith.constant 1024 : i64
// CHECK: ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>
// CHECK: default
// CHECK: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32}
// CHECK: scf.for
// CHECK: ttng.tensormap_create
// CHECK: scf.for
// CHECK: partition0
// CHECK: partition1
// CHECK: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32}
// CHECK: ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32}
// CHECK: scf.for
// CHECK: ttng.tensormap_create
// CHECK: ttng.tensormap_create
// CHECK: scf.for
// CHECK: scf.for
scf.for %g = %c0_i32 to %group_size step %c1_i32 : i32 {
%a_ptr = tt.addptr %group_a_ptrs, %g : !tt.ptr<i64>, i32
%a_ptr_6 = tt.load %a_ptr : !tt.ptr<i64>
%a_ptr_7 = tt.int_to_ptr %a_ptr_6 : i64 -> !tt.ptr<f16>
%b_ptr = tt.addptr %group_b_ptrs, %g : !tt.ptr<i64>, i32
%b_ptr_8 = tt.load %b_ptr : !tt.ptr<i64>
%b_ptr_9 = tt.int_to_ptr %b_ptr_8 : i64 -> !tt.ptr<f16>
%c_ptr = tt.addptr %group_c_ptrs, %g : !tt.ptr<i64>, i32
%c_ptr_10 = tt.load %c_ptr : !tt.ptr<i64>
%c_ptr_11 = tt.int_to_ptr %c_ptr_10 : i64 -> !tt.ptr<f16>
%a_desc_12 = tt.make_tensor_descriptor %a_ptr_7, [%gm, %gk], [%stride, %c1_i64] : <f16>, <tensor<128x64xf16, #shared>>
%b_desc_13 = tt.make_tensor_descriptor %b_ptr_9, [%gn, %gk], [%stride, %c1_i64] : <f16>, <tensor<128x64xf16, #shared>>
%c_desc_14 = tt.make_tensor_descriptor %c_ptr_11, [%gm, %gn], [%stride, %c1_i64] : <f16>, <tensor<128x128xf16, #shared>>
scf.for %tile_idx = %start_pid to %num_tiles step %c4_i32 : i32 {
%tile_m_idx = arith.divsi %tile_idx, %num_n_tiles_1 : i32
%tile_n_idx = arith.remsi %tile_idx, %num_n_tiles_1 : i32
%offs_am = arith.muli %tile_m_idx, %c128_i32 : i32
%offs_bn = arith.muli %tile_n_idx, %c128_i32 : i32
%accumulator, %accumulator_15 = ttng.tmem_alloc : () -> (!ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
%accumulator_16 = ttng.tmem_store %cst, %accumulator[%accumulator_15], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
%accumulator_17:2 = scf.for %accumulator_20 = %c0_i32 to %1 step %c1_i32 iter_args(%arg11 = %false, %accumulator_21 = %accumulator_16) -> (i1, !ttg.async.token) : i32 {
%a = arith.muli %accumulator_20, %c64_i32 : i32
%a_22 = tt.descriptor_load %a_desc_12[%offs_am, %a] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked1>
%a_23 = ttg.local_alloc %a_22 : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
%b = tt.descriptor_load %b_desc_13[%offs_bn, %a] : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked1>
%accumulator_24 = ttg.local_alloc %b : (tensor<128x64xf16, #blocked1>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
%accumulator_25 = ttg.memdesc_trans %accumulator_24 {order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem>
%accumulator_26 = ttng.tc_gen5_mma %a_23, %accumulator_25, %accumulator[%accumulator_21], %arg11, %true : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared1, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
scf.yield %true, %accumulator_26 : i1, !ttg.async.token
} {tt.scheduled_max_stage = 2 : i32}
%accumulator_18, %accumulator_19 = ttng.tmem_load %accumulator[%accumulator_17#1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
%c = arith.truncf %accumulator_18 : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked>
%2 = ttg.convert_layout %c : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked2>
tt.descriptor_store %c_desc_14[%offs_am, %offs_bn], %2 : !tt.tensordesc<tensor<128x128xf16, #shared>>, tensor<128x128xf16, #blocked2>
}
} {tt.warp_specialize}
tt.return
}
}
```
https://github.com/llvm/llvm-project/pull/190527
More information about the Mlir-commits
mailing list