[Mlir-commits] [mlir] [vector][mlir] Canonicalize to shape_cast where possible (PR #140583)
Kunwar Grover
llvmlistbot at llvm.org
Wed Dec 17 07:26:37 PST 2025
Groverkss wrote:
> Thanks @ftynse, and everyone who attended the call, I'm sorry I couldn't make it. I'm happy to create a new PR with 2 of the 3 patterns in this PR, will do so in the coming weeks.
>
> > Alternatively, show where having shape_cast instead of transpose is beneficial.
>
> One small advantage is that we get the following simplification for free:
>
> ```mlir
> %0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
> %1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
> %2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>
> ```
>
> ====>
>
> ```mlir
> %0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
> %1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4x1xf32>
> %2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>
> ```
>
> ====>
>
> ```mlir
> %2 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<4xf32>
> ```
This is not a benefit of having shape_cast here. It is a special case of a more general canonicalization pattern that can be written for shape_cast(transpose).
```mlir
%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>
```
You can canonicalize shape_cast(transpose) -> transpose(shape_cast) if the transpose works on a subset of dimensions that isn't affected by the shape_cast:
```mlir
shape_cast : 1x2x2 -> 1x4
transpose: 1x4 -> 4x1
shape_cast: 4x1 -> 4
```
Applying shape_cast(transpose) -> transpose(shape_cast) gives:
```mlir
shape_cast: 1x2x2 -> 1x4
shape_cast: 1x4 -> 4
transpose: 4 -> 4 // no-op
```
and after folding away the no-op transpose and composing the two shape_casts, you get the same result:
```mlir
shape_cast: 1x2x2 -> 4
```
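For concreteness, here is the same first example written as MLIR, one function per stage. This is a sketch: the function names `@rewritten` and `@folded` are hypothetical, and it assumes, as the steps above describe, that a rank-1 `vector.transpose` with the identity permutation `[0]` folds away and that two adjacent `vector.shape_cast` ops compose into one.

```mlir
// Hypothetical function: IR right after rewriting
// shape_cast(transpose) -> transpose(shape_cast).
func.func @rewritten(%arg0: vector<1x2x2xf32>) -> vector<4xf32> {
  %0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
  %1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4xf32>
  // Rank-1 transpose with the identity permutation: a no-op.
  %2 = vector.transpose %1, [0] : vector<4xf32> to vector<4xf32>
  return %2 : vector<4xf32>
}

// Hypothetical function: IR after dropping the no-op transpose and
// composing the two shape_casts into a single one.
func.func @folded(%arg0: vector<1x2x2xf32>) -> vector<4xf32> {
  %0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<4xf32>
  return %0 : vector<4xf32>
}
```

`@folded` matches the final form in the quoted message, reached without ever turning the transpose into a shape_cast.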
It also works for more interesting cases:
```mlir
shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 4x1x3x2
shape_cast: 4x1x3x2 -> 4x3x2
```
canonicalizes to:
```mlir
shape_cast: 1x2x2x2x3 -> 4x2x3
transpose: 4x2x3 -> 4x3x2
```
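The same example in MLIR syntax makes the permutations explicit: the 4-D transpose `[1, 0, 3, 2]` becomes the 3-D transpose `[0, 2, 1]`. This is an illustrative sketch with hypothetical function names `@before` and `@after`; it shows the intended input and output of the rewrite and does not claim that an existing vector-dialect pattern produces `@after`. Both functions map input element (0, a, b, c, d) to output element (2a+b, d, c), since `vector.shape_cast` preserves row-major element order.

```mlir
// Hypothetical function: shape_cast -> transpose -> shape_cast.
func.func @before(%arg0: vector<1x2x2x2x3xf32>) -> vector<4x3x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<1x2x2x2x3xf32> to vector<1x4x2x3xf32>
  %1 = vector.transpose %0, [1, 0, 3, 2] : vector<1x4x2x3xf32> to vector<4x1x3x2xf32>
  %2 = vector.shape_cast %1 : vector<4x1x3x2xf32> to vector<4x3x2xf32>
  return %2 : vector<4x3x2xf32>
}

// Hypothetical function: a single shape_cast followed by a smaller transpose.
func.func @after(%arg0: vector<1x2x2x2x3xf32>) -> vector<4x3x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<1x2x2x2x3xf32> to vector<4x2x3xf32>
  %1 = vector.transpose %0, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
  return %1 : vector<4x3x2xf32>
}
```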
This actually shows a good reason not to focus on these special-cased patterns: they just hide more general ones.
For shape_cast, you just have to choose a direction and stick to it (either always up or always down, either always expand or always collapse), and then you can write patterns like this.
https://github.com/llvm/llvm-project/pull/140583