[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (PR #73522)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Nov 27 09:13:10 PST 2023
banach-space wrote:
> QQ before I dig into the review: we have some patterns to remove unit dims from xfer ops. Have you tried running those before this flattening step?
You are probably referring to the "rank reducing patterns": https://github.com/llvm/llvm-project/blob/79b03306af5c11d354fa90db8bfd7818cd811ef5/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir#L197-L204
I have tried those and they didn't trigger for my example. I still need to dig deeper to understand why, and I'm happy to extend them as part of this work.
However, note that my actual goal is https://github.com/llvm/llvm-project/pull/73523 (which builds on top of this one). I might be able to enable #73523 by updating the "rank reducing patterns" instead, but this change feels beneficial in its own right. Most of my changes are comments (to help me understand what's going on) and tests that verify the functionality.
> I think we have a pass in IREE that is applying all these simplifications before trying to flatten.
Hm, here's the example from IREE that I am trying to "fix" (there's a lot going on here, and the pattern updated in this PR is just one piece of the puzzle):
```mlir
func.func @original(%0: memref<1x1080x1962x2xi32>, %1: memref<1x43x2xi32>, %2: memref<1x1080x1920x2xi32>, %z: index, %y: index, %x: index) {
  %cst = arith.constant dense<0> : vector<1x4x2xi32>
  %c43 = arith.constant 43 : index
  %c4 = arith.constant 4 : index
  %c64 = arith.constant 64 : index
  %c60 = arith.constant 60 : index
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %subview = memref.subview %2[0, %z, %y, %z] [1, 60, 64, 2] [1, 1, 1, 1] : memref<1x1080x1920x2xi32> to memref<1x60x64x2xi32, strided<[4147200, 3840, 2, 1], offset: ?>>
  %subview_0 = memref.subview %0[0, %z, %y, %z] [1, 60, 106, 2] [1, 1, 1, 1] : memref<1x1080x1962x2xi32> to memref<1x60x106x2xi32, strided<[4237920, 3924, 2, 1], offset: ?>>
  scf.for %arg0 = %c0 to %c60 step %c1 {
    scf.for %arg1 = %c0 to %c64 step %c4 {
      %subview_1 = memref.subview %subview[0, %arg0, %arg1, 0] [1, 1, 4, 2] [1, 1, 1, 1] : memref<1x60x64x2xi32, strided<[4147200, 3840, 2, 1], offset: ?>> to memref<1x1x4x2xi32, strided<[4147200, 3840, 2, 1], offset: ?>>
      %6 = scf.for %arg2 = %c0 to %c43 step %c1 iter_args(%arg3 = %cst) -> (vector<1x4x2xi32>) {
        %8 = arith.addi %arg2, %arg1 : index
        %9 = vector.transfer_read %subview_0[%c0, %arg0, %8, %c0], %c0_i32 {in_bounds = [true, true]} : memref<1x60x106x2xi32, strided<[4237920, 3924, 2, 1], offset: ?>>, vector<4x2xi32>
        %10 = vector.transfer_read %1[%c0, %arg2, %c0], %c0_i32 {in_bounds = [true]} : memref<1x43x2xi32>, vector<2xi32>
        %11 = vector.broadcast %10 : vector<2xi32> to vector<1x4x2xi32>
        %12 = vector.shape_cast %9 : vector<4x2xi32> to vector<8xi32>
        %13 = vector.shape_cast %11 : vector<1x4x2xi32> to vector<8xi32>
        %14 = arith.muli %12, %13 : vector<8xi32>
        %15 = vector.shape_cast %arg3 : vector<1x4x2xi32> to vector<8xi32>
        %16 = arith.addi %14, %15 : vector<8xi32>
        %17 = vector.shape_cast %16 : vector<8xi32> to vector<1x4x2xi32>
        scf.yield %17 : vector<1x4x2xi32>
      }
      %7 = vector.extract %6[0] : vector<4x2xi32> from vector<1x4x2xi32>
      %subview_2 = memref.subview %subview_1[0, 0, 0, 0] [1, 1, 4, 2] [1, 1, 1, 1] : memref<1x1x4x2xi32, strided<[4147200, 3840, 2, 1], offset: ?>> to memref<4x2xi32, affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 + s0)>>
      vector.transfer_write %7, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x2xi32>, memref<4x2xi32, affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 + s0)>>
    }
  }
  return
}
```
So the rank reducing patterns don't trigger here either. TBH, I'm trying to solve multiple issues at once, and this PR is an attempt to reduce the problem space.
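For illustration, here is a hand-written sketch of what flattening the 2-D read `%9` above into a 1-D read could look like. It is not the actual output of the patterns touched in this PR. It assumes that the two innermost dims of `%subview_0` (sizes 106x2, strides [2, 1]) are contiguous, so `memref.collapse_shape` can merge them, and that the read starts at index 0 of the innermost dim. The names `%collapsed`, `%idx` and `%flat` are hypothetical, and the fragment is meant to stand in for the `%9 = vector.transfer_read ...` line inside the loop.

```mlir
// Hypothetical sketch: merge the contiguous innermost dims (106x2 -> 212).
%collapsed = memref.collapse_shape %subview_0 [[0], [1], [2, 3]] : memref<1x60x106x2xi32, strided<[4237920, 3924, 2, 1], offset: ?>> into memref<1x60x212xi32, strided<[4237920, 3924, 1], offset: ?>>
// Index into the merged dim: %8 * 2 + 0 (the original innermost index is %c0).
%idx = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%8]
// A single 1-D read of 4 * 2 = 8 contiguous elements.
%flat = vector.transfer_read %collapsed[%c0, %arg0, %idx], %c0_i32 {in_bounds = [true]} : memref<1x60x212xi32, strided<[4237920, 3924, 1], offset: ?>>, vector<8xi32>
// Restore the original 2-D vector type for existing users of %9.
%9 = vector.shape_cast %flat : vector<8xi32> to vector<4x2xi32>
```

The trailing `vector.shape_cast` keeps the rewrite local. In the example above it would feed straight into `%12` (a shape_cast back to `vector<8xi32>`), so the two casts are a candidate for folding.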
> In general, we should aim for removing all these unit dims and avoid the complexity they introduce.
Agreed, that's part of the plan. But removing unit dims is unlikely to be sufficient to fold away the `vector.shape_cast` Ops that I introduce in:
* https://github.com/llvm/llvm-project/pull/71918
That's basically where the example above comes from. Overall, I know that I will need multiple changes to fix this 😂.
https://github.com/llvm/llvm-project/pull/73522