[Mlir-commits] [mlir] [mlir] Retain original identifier names for debugging (PR #79704)

Maksim Levental llvmlistbot at llvm.org
Thu Dec 12 13:18:42 PST 2024


makslevental wrote:

Can we iterate on this? Preserving legible SSA names through parsing, and even through passes, is immensely valuable for debugging everything. Note: I don't understand the relationship to formatters, so what follows doesn't address that use case.

I propose we store the identifier in the Location info as a `NameLoc`. As far as I can tell, this avoids the issue of polluting the IR.

I prototyped this using @Wheest's current PR (with a few small changes) and got the following:

**Source IR**:

```mlir
#loc = loc("triton/python/examples/empty.py":17:0)
module {
  tt.func public @add_kernel(
    %in_ptr0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0),
    %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0),
    %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} loc("triton/python/examples/empty.py":17:0)) -> tensor<1024xf32> attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %pid = tt.get_program_id x : i32 loc(#loc2)
    %block_start = arith.muli %pid, %c1024_i32 : i32 loc(#loc3)
    %make_range = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc4)
    %block_start_splat = tt.splat %block_start : i32 -> tensor<1024xi32> loc(#loc5)
    %offsets = arith.addi %block_start_splat, %make_range : tensor<1024xi32> loc(#loc5)
    %in_ptr0_splat = tt.splat %in_ptr0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc6)
    %addr = tt.addptr %in_ptr0_splat, %offsets : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc6)
    %val = tt.load %addr : tensor<1024x!tt.ptr<f32>> loc(#loc7)
    tt.return %val : tensor<1024xf32> loc(#loc8)
  } loc(#loc)
} loc(#loc)
#loc1 = loc(unknown)
#loc2 = loc("triton/python/examples/empty.py":24:24)
#loc3 = loc("triton/python/examples/empty.py":25:24)
#loc4 = loc("triton/python/examples/empty.py":26:41)
#loc5 = loc("triton/python/examples/empty.py":26:28)
#loc6 = loc("triton/python/examples/empty.py":27:26)
#loc7 = loc("triton/python/examples/empty.py":27:16)
#loc8 = loc("triton/python/examples/empty.py":29:11)
```

**Just parsing**:

```mlir
#loc = triton/python/examples/empty.py:17:0
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0) -> tensor<1024xf32> attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 "c1024_i32"
    %0 = tt.get_program_id x : i32 "pid"(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 "block_start"(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> "block_start_splat"(#loc5)
    %4 = arith.addi %3, %2 : tensor<1024xi32> "offsets"(#loc5)
    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "in_ptr0_splat"(#loc6)
    %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "addr"(#loc6)
    %7 = tt.load %6 : tensor<1024x!tt.ptr<f32>> "val"(#loc7)
    tt.return %7 : tensor<1024xf32> ""(#loc8)
  } triton/python/examples/empty.py:17:0
} triton/python/examples/empty.py:17:0
#loc1 = "c1024_i32"
#loc2 = triton/python/examples/empty.py:24:24
#loc3 = triton/python/examples/empty.py:25:24
#loc4 = triton/python/examples/empty.py:26:41
#loc5 = triton/python/examples/empty.py:26:28
#loc6 = triton/python/examples/empty.py:27:26
#loc7 = triton/python/examples/empty.py:27:16
#loc8 = triton/python/examples/empty.py:29:11
```

**After some passes**:

```mlir
#loc = triton/python/examples/empty.py:17:0
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} triton/python/examples/empty.py:17:0) -> tensor<1024xf32> attributes {noinline = false, rewritten} {
    %c0_i32 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
    %0 = builtin.unrealized_conversion_cast %arg0, %c0_i32 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
    %c0_i32_0 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
    %1 = builtin.unrealized_conversion_cast %arg1, %c0_i32_0 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
    %c0_i32_1 = arith.constant 0 : i32 triton/python/examples/empty.py:17:0
    %2 = builtin.unrealized_conversion_cast %arg2, %c0_i32_1 : !tt.ptr<f32>, i32 to !tt.ptr<f32> triton/python/examples/empty.py:17:0
    %c1024_i32 = arith.constant 1024 : i32 "c1024_i32"
    %3 = tt.get_program_id x : i32 "pid"(#loc2)
    %4 = arith.muli %3, %c1024_i32 : i32 "block_start"(#loc3)
    %5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(#loc4)
    %6 = tt.splat %4 : i32 -> tensor<1024xi32> "block_start_splat"(#loc5)
    %7 = arith.addi %6, %5 : tensor<1024xi32> "offsets"(#loc5)
    %8 = tt.splat %c0_i32 : i32 -> tensor<1024xi32> "in_ptr0_splat"(#loc6)
    %cst = arith.constant dense<0> : tensor<1024xi32> "addr"(#loc6)
    %c0_i32_2 = arith.constant 0 : i32 "addr"(#loc6)
    %9 = arith.addi %4, %c0_i32_2 : i32 "addr"(#loc6)
    %10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(#loc6)
    %11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(#loc6)
    %12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(#loc6)
    %13 = tt.splat %12 {rewritten} : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "val"(#loc7)
    %14 = tt.addptr %13, %11 {rewritten} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "val"(#loc7)
    %15 = tt.load %14 {rewritten} : tensor<1024x!tt.ptr<f32>> "val"(#loc7)
    tt.return %15 : tensor<1024xf32> ""(#loc8)
  } triton/python/examples/empty.py:17:0
} triton/python/examples/empty.py:17:0
#loc1 = "c1024_i32"
#loc2 = triton/python/examples/empty.py:24:24
#loc3 = triton/python/examples/empty.py:25:24
#loc4 = triton/python/examples/empty.py:26:41
#loc5 = triton/python/examples/empty.py:26:28
#loc6 = triton/python/examples/empty.py:27:26
#loc7 = triton/python/examples/empty.py:27:16
#loc8 = triton/python/examples/empty.py:29:11
```

What's nice here is that the `%addr` name propagates to all of the created/inserted ops that were passed `addPtrOp.getLoc()`:

```mlir
%cst = arith.constant dense<0> : tensor<1024xi32> "addr"(#loc6)
%c0_i32_2 = arith.constant 0 : i32 "addr"(#loc6)
%9 = arith.addi %4, %c0_i32_2 : i32 "addr"(#loc6)
%10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(#loc6)
%11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(#loc6)
%12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(#loc6)
```

Note that this example started with debug info in the source. The changes I currently have are compatible with that, but they also work if the original source doesn't have any location info:

```mlir
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} within split...
    %c1024_i32 = arith.constant 1024 : i32 "c1024_i32"(within split...)
    %0 = tt.get_program_id x : i32 "pid"(within split...)
    %1 = arith.muli %0, %c1024_i32 : i32 "block_start"(within split...)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(within split...)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> "block_start_splat"(within split...)
    %4 = arith.addi %3, %2 : tensor<1024xi32> "offsets"(within split...)
    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "in_ptr0_splat"(within split...)
    %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "addr"(within split...)
    %7 = tt.load %6 : tensor<1024x!tt.ptr<f32>> "val"(within split...)
    tt.return %7 : tensor<1024xf32> within split...
  } within split...
} within split...
```
and, after some passes:
```mlir
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32} within split...
    %c0_i32 = arith.constant 0 : i32 within split...
    %0 = builtin.unrealized_conversion_cast %arg0, %c0_i32 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
    %c0_i32_0 = arith.constant 0 : i32 within split...
    %1 = builtin.unrealized_conversion_cast %arg1, %c0_i32_0 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
    %c0_i32_1 = arith.constant 0 : i32 within split...
    %2 = builtin.unrealized_conversion_cast %arg2, %c0_i32_1 : !tt.ptr<f32>, i32 to !tt.ptr<f32> within split...
    %c1024_i32 = arith.constant 1024 : i32 "c1024_i32"(within split...)
    %3 = tt.get_program_id x : i32 "pid"(within split...)
    %4 = arith.muli %3, %c1024_i32 : i32 "block_start"(within split...)
    %5 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> "make_range"(within split...)
    %6 = tt.splat %4 : i32 -> tensor<1024xi32> "block_start_splat"(within split...)
    %7 = arith.addi %6, %5 : tensor<1024xi32> "offsets"(within split...)
    %8 = tt.splat %c0_i32 : i32 -> tensor<1024xi32> "in_ptr0_splat"(within split...)
    %cst = arith.constant dense<0> : tensor<1024xi32> "addr"(within split...)
    %c0_i32_2 = arith.constant 0 : i32 "addr"(within split...)
    %9 = arith.addi %4, %c0_i32_2 : i32 "addr"(within split...)
    %10 = arith.addi %cst, %5 : tensor<1024xi32> "addr"(within split...)
    %11 = arith.addi %10, %8 : tensor<1024xi32> "addr"(within split...)
    %12 = tt.addptr %0, %9 : !tt.ptr<f32>, i32 "addr"(within split...)
    %13 = tt.splat %12 {rewritten} : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> "val"(within split...)
    %14 = tt.addptr %13, %11 {rewritten} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> "val"(within split...)
    %15 = tt.load %14 {rewritten} : tensor<1024x!tt.ptr<f32>> "val"(within split...)
    tt.return %15 : tensor<1024xf32> within split...
  } within split...
} within split...
```

https://github.com/llvm/llvm-project/pull/79704


More information about the Mlir-commits mailing list