[Mlir-commits] [mlir] [mlir][linalg][conv] Flatten the channel dimension when vectorizing (PR #71918)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Nov 16 02:47:13 PST 2023


banach-space wrote:

> @banach-space I really like how your recent change significantly reduced the complexity compared to the previous approach. 

@nicolasvasilache Thank you! I was about to write a long reply thanking you for your observation (which, indeed, makes things much simpler), but you have already noticed that :)

> Hopefully the shape_cast swaps nicely with the vector transfers next :)

That would happen as a post-vectorisation canonicalisation, right?

> Should we schedule a meeting to discuss this? 

Yes - let me ping you offline :) 

Regarding `linalg.generic`s with indexing maps that are not permutations: we need to take a closer look and make sure that the resulting code for convolutions is equally good. I can revisit my examples, but I need to prioritise landing this first :) But yes, that is an important TODO.

As for "strided" loads/stores - super important TODO as well :)

In any case, I have refactored this patch following Nicolas' observation (apologies for over-complicating it so much before). Thank you, that was incredibly helpful 🙏🏻.

Now, this change is still a "one-off" pattern application within the vectoriser. Should this be done elsewhere instead? That would require matching whatever `depthwiseConv1dSliceAsMulAcc` generates post-vectorisation, which would make the pattern considerably more complex.

Btw, I am aware that the tests still need refining, and I still need to benchmark this.

Finally, to better visualise the current changes:

**Input**
```mlir
func.func @conv_dill_2(%input: memref<3x5x4xf32>,
                       %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
    outs(%output : memref<3x2x4xf32>)
  return
}
```
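For reference, here is a minimal Python sketch of what the op above computes. It assumes the scalar semantics from the Linalg OpDSL definition of `depthwise_conv_1d_nwc_wc`, i.e. `out[n, w, c] += in[n, w * stride + kw * dilation, c] * filter[kw, c]`. Memrefs are modelled as nested lists, and the function name and test data are purely illustrative. The sketch also records which input W indices are read. With dilation 2 and an output width of 2, only indices 0 to 3 are touched, which is consistent with the `vector<3x4x4xf32>` transfer_read of the 3x5x4 input in the vectorised IR.

```python
def depthwise_conv_1d_nwc_wc(inp, flt, out, stride=1, dilation=1):
    """Scalar model (assumed semantics):
    out[n][w][c] += inp[n][w * stride + kw * dilation][c] * flt[kw][c].

    Updates `out` in place, like the memref-based op, and returns the set of
    input W indices that were actually read.
    """
    touched = set()
    n_size, w_size, c_size = len(out), len(out[0]), len(out[0][0])
    for n in range(n_size):
        for w in range(w_size):
            for kw in range(len(flt)):
                iw = w * stride + kw * dilation
                touched.add(iw)
                for c in range(c_size):
                    out[n][w][c] += inp[n][iw][c] * flt[kw][c]
    return touched


# Shapes from the example: input 3x5x4, filter 2x4, output 3x2x4, dilation 2.
inp = [[[float(100 * n + 10 * w + c) for c in range(4)] for w in range(5)] for n in range(3)]
flt = [[float(kw + 1)] * 4 for kw in range(2)]
out = [[[0.0] * 4 for _ in range(2)] for _ in range(3)]
touched = depthwise_conv_1d_nwc_wc(inp, flt, out, stride=1, dilation=2)
print(sorted(touched))  # [0, 1, 2, 3]: input column 4 is never read
```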

**Vectorisation _without_ shape casting**
```mlir
  func.func @conv_dill_2(%arg0: memref<3x5x4xf32>, %arg1: memref<2x4xf32>, %arg2: memref<3x2x4xf32>) {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
    %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
    %3 = vector.extract_strided_slice %0 {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
    %4 = vector.extract_strided_slice %0 {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
    %5 = vector.extract %1[0] : vector<4xf32> from vector<2x4xf32>
    %6 = vector.extract %1[1] : vector<4xf32> from vector<2x4xf32>
    %7 = vector.broadcast %5 : vector<4xf32> to vector<3x2x4xf32>
    %8 = vector.fma %3, %7, %2 : vector<3x2x4xf32>
    %9 = vector.broadcast %6 : vector<4xf32> to vector<3x2x4xf32>
    %10 = vector.fma %4, %9, %8 : vector<3x2x4xf32>
    vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
    return
  }
```
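The unflattened IR above processes one filter tap at a time. For tap `kw`, it takes a slice of the input at W offset `kw * dilation` (hence offsets `[0, 0, 0]` and `[0, 2, 0]`), broadcasts row `kw` of the filter to `vector<3x2x4xf32>`, and accumulates with `vector.fma`. The sketch below models that sequence with Python nested lists and checks it against a scalar loop nest. It handles stride 1 only. `fma3d`, `vectorised`, `reference` and `zeros` are hypothetical helper names, not MLIR APIs.

```python
N, WO, C, DILATION = 3, 2, 4, 2  # shapes and dilation from the example (stride 1)


def fma3d(a, b, acc):
    """Elementwise a * b + acc on N x W x C nested lists (models vector.fma)."""
    return [[[x * y + z for x, y, z in zip(ra, rb, rc)]
             for ra, rb, rc in zip(pa, pb, pc)]
            for pa, pb, pc in zip(a, b, acc)]


def vectorised(inp, flt, acc):
    """Per-tap decomposition mirroring the unflattened IR."""
    for kw, frow in enumerate(flt):  # frow models %5 / %6 (vector.extract)
        # Models vector.extract_strided_slice with offsets [0, kw * DILATION, 0].
        lhs = [[inp[n][kw * DILATION + w] for w in range(WO)] for n in range(N)]
        # Models vector.broadcast of the filter row to vector<3x2x4xf32>.
        rhs = [[list(frow) for _ in range(WO)] for _ in range(N)]
        acc = fma3d(lhs, rhs, acc)  # models %8 / %10
    return acc


def reference(inp, flt, out):
    """Scalar loop nest: out[n][w][c] += inp[n][w + kw * DILATION][c] * flt[kw][c]."""
    for n in range(N):
        for w in range(WO):
            for kw in range(len(flt)):
                for c in range(C):
                    out[n][w][c] += inp[n][w + kw * DILATION][c] * flt[kw][c]
    return out


def zeros():
    return [[[0.0] * C for _ in range(WO)] for _ in range(N)]


inp = [[[float(100 * n + 10 * w + c) for c in range(C)] for w in range(5)] for n in range(N)]
flt = [[float(kw + 1)] * C for kw in range(2)]
vec = vectorised(inp, flt, zeros())
ref = reference(inp, flt, zeros())
print(vec == ref)  # True
```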

**Vectorisation _with_ shape casting**
```mlir
  func.func @conv_dill_2(%arg0: memref<3x5x4xf32>, %arg1: memref<2x4xf32>, %arg2: memref<3x2x4xf32>) {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
    %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
    %3 = vector.extract_strided_slice %0 {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
    %4 = vector.extract_strided_slice %0 {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
    %5 = vector.extract %1[0] : vector<4xf32> from vector<2x4xf32>
    %6 = vector.extract %1[1] : vector<4xf32> from vector<2x4xf32>
    %7 = vector.shape_cast %3 : vector<3x2x4xf32> to vector<3x8xf32>
    %8 = vector.shape_cast %2 : vector<3x2x4xf32> to vector<3x8xf32>
    %9 = vector.broadcast %5 : vector<4xf32> to vector<3x2x4xf32>
    %10 = vector.shape_cast %9 : vector<3x2x4xf32> to vector<3x8xf32>
    %11 = vector.fma %7, %10, %8 : vector<3x8xf32>
    %12 = vector.shape_cast %4 : vector<3x2x4xf32> to vector<3x8xf32>
    %13 = vector.broadcast %6 : vector<4xf32> to vector<3x2x4xf32>
    %14 = vector.shape_cast %13 : vector<3x2x4xf32> to vector<3x8xf32>
    %15 = vector.fma %12, %14, %11 : vector<3x8xf32>
    %16 = vector.shape_cast %15 : vector<3x8xf32> to vector<3x2x4xf32>
    vector.transfer_write %16, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
    return
  }
```
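Flattening is legal here because the op being reshaped, `vector.fma`, is elementwise, and `vector.shape_cast` preserves row-major element order. Collapsing `3x2x4` into `3x8` therefore only regroups the same 24 lanes. One detail worth noting: broadcasting a `vector<4xf32>` filter row to `3x2x4` and then casting to `3x8` makes each row `[f0, f1, f2, f3, f0, f1, f2, f3]`, i.e. the filter row tiled across the output width.

The minimal sketch below models this with Python nested lists standing in for vectors. It handles stride 1 only, and all helper names are hypothetical rather than MLIR APIs.

```python
N, WO, C, DILATION = 3, 2, 4, 2  # shapes and dilation from the example (stride 1)


def shape_cast_collapse(v):
    """Row-major collapse of the two trailing dims: vector<AxBxC> -> vector<Ax(B*C)>."""
    return [[x for row in plane for x in row] for plane in v]


def shape_cast_expand(v, b, c):
    """Inverse of the collapse: vector<Ax(B*C)> -> vector<AxBxC>."""
    return [[row[i * c:(i + 1) * c] for i in range(b)] for row in v]


def fma2d(a, b, acc):
    """Elementwise a * b + acc on 2-D nested lists (models vector.fma on 3x8)."""
    return [[x * y + z for x, y, z in zip(ra, rb, rc)] for ra, rb, rc in zip(a, b, acc)]


def flattened(inp, flt, acc3d):
    """Per-tap FMAs on the collapsed 3x8 shape, mirroring the shape-cast IR."""
    acc = shape_cast_collapse(acc3d)  # models %8
    for kw, frow in enumerate(flt):
        # Slice at W offset kw * DILATION, as in %3 / %4, then collapse (%7 / %12).
        lhs = [[inp[n][kw * DILATION + w] for w in range(WO)] for n in range(N)]
        # Broadcast the filter row to 3x2x4, then collapse (%10 / %14).
        rhs = [[list(frow) for _ in range(WO)] for _ in range(N)]
        acc = fma2d(shape_cast_collapse(lhs), shape_cast_collapse(rhs), acc)
    return shape_cast_expand(acc, WO, C)  # models %16


def reference(inp, flt, out):
    """Scalar loop nest: out[n][w][c] += inp[n][w + kw * DILATION][c] * flt[kw][c]."""
    for n in range(N):
        for w in range(WO):
            for kw in range(len(flt)):
                for c in range(C):
                    out[n][w][c] += inp[n][w + kw * DILATION][c] * flt[kw][c]
    return out


def zeros():
    return [[[0.0] * C for _ in range(WO)] for _ in range(N)]


inp = [[[float(100 * n + 10 * w + c) for c in range(C)] for w in range(5)] for n in range(N)]
flt = [[float(kw + 1)] * C for kw in range(2)]
flat = flattened(inp, flt, zeros())
ref = reference(inp, flt, zeros())
print(flat == ref)  # True
```

Each lane performs the same sequence of operations in both versions, so the results should be identical rather than merely close.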

https://github.com/llvm/llvm-project/pull/71918

