[Mlir-commits] [mlir] [mlir][tensor] Generalize/restrict `GeneralizeOuterUnitDimsPackOpPattern` (PR #114315)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Nov 5 09:40:31 PST 2024
banach-space wrote:
Thanks for taking a look! Quick reply to your high-level question (emphasis mine):
> This test has a non-unit outer dimension, but that dimension is not packed (no inner_tile for it). I think this may actually be intended behavior, since the pattern checks that `packOp.getTiledOuterDims()` are all 1. **Maybe the comments are just misleading**, and this case is meant to be supported.
Indeed, I'd really appreciate it if somebody could clarify this 😅
> I'm not sure what the motivation of this pattern was to begin with, so I can't say if this type of case needs to be supported, but I would be wary of removing that functionality without hearing from whoever wrote this pattern.
100% agreed, thanks for bringing this up! As you can see, I'm actually struggling a bit to get feedback for this - I really appreciate you not shying away! 🙏🏻
While waiting for Quinn to chime in, I will make a couple of points.
1. The current logic to compute the necessary sizes is quite convoluted. Adding support for the case mentioned above has been quite tricky. I can try to add support for non-unit not-tiled-outer-dims, but would really prefer to avoid complexities that are not required.
2. The current logic for non-unit not-tiled-outer-dims is quite limited and breaks when the padding value is set. You can try this example:
```mlir
func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>, %pad: f32) -> tensor<3x1x1x1x8x32xf32> {
  %0 = tensor.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
  return %0 : tensor<3x1x1x1x8x32xf32>
}
```
The logic that fails is in [getPackOpSourceOrPaddedSource](https://github.com/llvm/llvm-project/blob/32473864cb4631780095e25a0378663b98a11188/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp#L1019-L1081) (i.e. you will hit the assert in that method).
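To make "non-unit not-tiled-outer-dims" concrete, here is a rough sketch of the shape arithmetic for the example above. It is plain Python and only approximates what `tensor.pack` does for static shapes; it is not the MLIR implementation, and `packed_shape` and `tiled_outer_dims` are hypothetical helper names used for illustration.

```python
def ceil_div(a, b):
    """Integer ceiling division (a partial tile still needs a full outer slot)."""
    return (a + b - 1) // b


def tiled_outer_dims(src_shape, inner_dims_pos, inner_tiles):
    """Outer sizes of only those source dims that are tiled.

    Assumption: this mirrors what the pattern checks, i.e. that every
    tiled outer dim is 1. Untiled dims are not looked at here.
    """
    return [ceil_div(src_shape[pos], tile)
            for pos, tile in zip(inner_dims_pos, inner_tiles)]


def packed_shape(src_shape, inner_dims_pos, inner_tiles, outer_dims_perm=None):
    """Static result shape of a pack: permuted outer dims followed by inner tiles."""
    outer = list(src_shape)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        outer[pos] = ceil_div(src_shape[pos], tile)
    if outer_dims_perm is not None:
        # Result outer dim i is taken from unpermuted outer dim perm[i].
        outer = [outer[p] for p in outer_dims_perm]
    return outer + list(inner_tiles)


src = [3, 1, 32, 8]
# All tiled outer dims are 1, so the unit-dim check on tiled dims passes ...
print(tiled_outer_dims(src, [3, 2], [8, 32]))             # [1, 1]
# ... even though the leading, untiled outer dim is 3.
print(packed_shape(src, [3, 2], [8, 32], [0, 2, 3, 1]))   # [3, 1, 1, 1, 8, 32]
```

So the example satisfies the "all tiled outer dims are 1" check while still carrying a non-unit outer dimension (the leading `3`), which is exactly the case under discussion.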
https://github.com/llvm/llvm-project/pull/114315