[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging (PR #79704)
Maksim Levental
llvmlistbot at llvm.org
Thu Dec 12 13:18:42 PST 2024
makslevental wrote:
Can we iterate on this? Preserving legible SSA names through parsing, and even through passes, is immensely valuable for debugging. Note: I don't understand the relationship to formatters, so what follows doesn't address that use case.
I propose we store the identifier in the Location info as a `NameLoc`. As far as I can tell this avoids the issue of polluting the IR.
I prototyped this on top of @Wheest's current PR (with a few small changes) and got this:
**Source IR**:
```mlir
#loc = loc("triton/python/examples/empty.py":17:0)
module {
tt.func public @add_kernel(
%in_ptr0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0),
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0),
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0)) -> tensor<1024xf32> attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
%pid = tt.get_program_id x : i32 loc(#loc2)
%block_start = arith.muli %pid, %c1024_i32 : i32 loc(#loc3)
%make_range = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
%block_start_splat = tt.splat %block_start : i32 -> tensor<1024xi32> loc(#loc5)
%offsets = arith.addi %block_start_splat, %make_range : tensor<1024xi32> loc(#loc5)
%in_ptr0_splat = tt.splat %in_ptr0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc6)
%addr = tt.addptr %in_ptr0_splat, %offsets : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc6)
%val = tt.load %addr : tensor<1024x!tt.ptr<f32>> loc(#loc7)
tt.return %val : tensor<1024xf32> loc(#loc8)
} loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("triton/python/examples/empty.py":24:24)
#loc3 = loc("triton/python/examples/empty.py":25:24)
#loc4 = loc("triton/python/examples/empty.py":26:41)
#loc5 = loc("triton/python/examples/empty.py":26:28)
#loc6 = loc("triton/python/examples/empty.py":27:26)
#loc7 = loc("triton/python/examples/empty.py":27:16)
#loc8 = loc("triton/python/examples/empty.py":29:11)
```
**Just parsing**:
```mlir
#loc = triton/python/examples/empty.py:17:0
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0) -> tensor<1024xf32> attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32 "c1024_i32"
%0 = tt.get_program_id x : i32 "pid"(#loc2)
%1 = arith.muli %0, %c1024_i32 : i32 "block_start"(#loc3)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(#loc4)
%3 = tt.splat %1 : i32 -> tensor<1024xi32> "block_start_splat"(#loc5)
%4 = arith.addi %3, %2 : tensor<1024xi32> "offsets"(#loc5)
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "in_ptr0_splat"(#loc6)
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "addr"(#loc6)
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>> "val"(#loc7)
tt.return %7 : tensor<1024xf32> ""(#loc8)
} triton/python/examples/empty.py:17:0
} triton/python/examples/empty.py:17:0
#loc1 = "c1024_i32"
#loc2 = triton/python/examples/empty.py:24:24
#loc3 = triton/python/examples/empty.py:25:24
#loc4 = triton/python/examples/empty.py:26:41
#loc5 = triton/python/examples/empty.py:26:28
#loc6 = triton/python/examples/empty.py:27:26
#loc7 = triton/python/examples/empty.py:27:16
#loc8 = triton/python/examples/empty.py:29:11
```
**After some passes**:
```mlir
#loc = triton/python/examples/empty.py:17:0
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0) -> tensor<1024xf32> attributes {noinline = false, rewritten} {
%c0_i32 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
%0 = builtin.unrealized_conversion_cast %arg0, %c0_i32 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
%c0_i32_0 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
%1 = builtin.unrealized_conversion_cast %arg1, %c0_i32_0 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
%c0_i32_1 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
%2 = builtin.unrealized_conversion_cast %arg2, %c0_i32_1 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
%c1024_i32 = arith.constant 1024 : i32 "c1024_i32"
%3 = tt.get_program_id x : i32 "pid"(#loc2)
%4 = arith.muli %3, %c1024_i32 : i32 "block_start"(#loc3)
%5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(#loc4)
%6 = tt.splat %4 : i32 -> tensor<1024xi32> "block_start_splat"(#loc5)
%7 = arith.addi %6, %5 : tensor<1024xi32> "offsets"(#loc5)
%8 = tt.splat %c0_i32 : i32 -> tensor<1024xi32> "in_ptr0_splat"(#loc6)
%cst = arith.constant dense<0> : tensor<1024xi32> "addr"(#loc6)
%c0_i32_2 = arith.constant 0 : i32 "addr"(#loc6)
%9 = arith.addi %4, %c0_i32_2 : i32 "addr"(#loc6)
%10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(#loc6)
%11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(#loc6)
%12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(#loc6)
%13 = tt.splat %12 {rewritten} : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "val"(#loc7)
%14 = tt.addptr %13, %11 {rewritten} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "val"(#loc7)
%15 = tt.load %14 {rewritten} : tensor<1024x!tt.ptr<f32>> "val"(#loc7)
tt.return %15 : tensor<1024xf32> ""(#loc8)
} triton/python/examples/empty.py:17:0
} triton/python/examples/empty.py:17:0
#loc1 = "c1024_i32"
#loc2 = triton/python/examples/empty.py:24:24
#loc3 = triton/python/examples/empty.py:25:24
#loc4 = triton/python/examples/empty.py:26:41
#loc5 = triton/python/examples/empty.py:26:28
#loc6 = triton/python/examples/empty.py:27:26
#loc7 = triton/python/examples/empty.py:27:16
#loc8 = triton/python/examples/empty.py:29:11
```
I'll point out that what's nice here is that the name `addr` (from `%addr` in the source) propagates to all of the created/inserted ops that had `addPtrOp.getLoc()` passed to them:
```mlir
%cst = arith.constant dense<0> : tensor<1024xi32> "addr"(#loc6)
%c0_i32_2 = arith.constant 0 : i32 "addr"(#loc6)
%9 = arith.addi %4, %c0_i32_2 : i32 "addr"(#loc6)
%10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(#loc6)
%11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(#loc6)
%12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(#loc6)
```
Note that this example started with debug info in the source, and the changes I currently have are compatible with that, but they of course also work if the original source doesn't have any location info:
```mlir
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} within split...
%c1024_i32 = arith.constant 1024 : i32 "c1024_i32"(within split...)
%0 = tt.get_program_id x : i32 "pid"(within split...)
%1 = arith.muli %0, %c1024_i32 : i32 "block_start"(within split...)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(within split...)
%3 = tt.splat %1 : i32 -> tensor<1024xi32> "block_start_splat"(within split...)
%4 = arith.addi %3, %2 : tensor<1024xi32> "offsets"(within split...)
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "in_ptr0_splat"(within split...)
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "addr"(within split...)
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>> "val"(within split...)
tt.return %7 : tensor<1024xf32> within split...
} within split...
} within split...
```
and, after the same passes:
```mlir
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} within split...
%c0_i32 = arith.constant 0 : i32 within split...
%0 = builtin.unrealized_conversion_cast %arg0, %c0_i32 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
%c0_i32_0 = arith.constant 0 : i32 within split...
%1 = builtin.unrealized_conversion_cast %arg1, %c0_i32_0 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
%c0_i32_1 = arith.constant 0 : i32 within split...
%2 = builtin.unrealized_conversion_cast %arg2, %c0_i32_1 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
%c1024_i32 = arith.constant 1024 : i32 "c1024_i32"(within split...)
%3 = tt.get_program_id x : i32 "pid"(within split...)
%4 = arith.muli %3, %c1024_i32 : i32 "block_start"(within split...)
%5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(within split...)
%6 = tt.splat %4 : i32 -> tensor<1024xi32> "block_start_splat"(within split...)
%7 = arith.addi %6, %5 : tensor<1024xi32> "offsets"(within split...)
%8 = tt.splat %c0_i32 : i32 -> tensor<1024xi32> "in_ptr0_splat"(within split...)
%cst = arith.constant dense<0> : tensor<1024xi32> "addr"(within split...)
%c0_i32_2 = arith.constant 0 : i32 "addr"(within split...)
%9 = arith.addi %4, %c0_i32_2 : i32 "addr"(within split...)
%10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(within split...)
%11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(within split...)
%12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(within split...)
%13 = tt.splat %12 {rewritten} : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "val"(within split...)
%14 = tt.addptr %13, %11 {rewritten} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "val"(within split...)
%15 = tt.load %14 {rewritten} : tensor<1024x!tt.ptr<f32>> "val"(within split...)
tt.return %15 : tensor<1024xf32> within split...
} within split...
} within split...
```
https://github.com/llvm/llvm-project/pull/79704
More information about the Mlir-commits
mailing list