[Mlir-commits] [mlir] [vector][mlir] Canonicalize to shape_cast where possible (PR #140583)

Kunwar Grover llvmlistbot at llvm.org
Wed Dec 17 07:26:37 PST 2025


Groverkss wrote:

> Thanks @ftynse, and everyone who attended the call, I'm sorry I couldn't make it. I'm happy to create a new PR with 2 of the 3 patterns in this PR, will do so in the coming weeks.
> 
> > Alternatively, show where having shape_cast instead of transpose is beneficial.
> 
> One small advantage is that we get the following simplification for free:
> 
> ```mlir
> %0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
> %1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
> %2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>
> ```
> 
> ====>
> 
> ```mlir
> %0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
> %1 = vector.shape_cast %0 : vector<1x4xf32> to vector<4x1xf32>
> %2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>
> ```
> 
> ====>
> 
> ```mlir
> %2 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<4xf32>
> ```

This is not a benefit of having shape_cast here. This is a special case of a more general canonicalization pattern that can be written on shape_cast(transpose).

```mlir
%0 = vector.shape_cast %arg0 : vector<1x2x2xf32> to vector<1x4xf32>
%1 = vector.transpose %0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
%2 = vector.shape_cast %1 : vector<4x1xf32> to vector<4xf32>
```

You can canonicalize shape_cast(transpose) -> transpose(shape_cast) if the transpose only works on a subset of dimensions that isn't affected by the shape_cast:

```mlir
shape_cast : 1x2x2 -> 1x4
transpose: 1x4 -> 4x1
shape_cast: 4x1 -> 4
```

Apply shape_cast(transpose) -> transpose(shape_cast):

```mlir
shape_cast: 1x2x2 -> 1x4
shape_cast: 1x4 -> 4
transpose: 4 -> 4 // no-op
```

and you get the same result later:

```mlir
shape_cast: 1x2x2 -> 4
```

It also works for more interesting cases:

```mlir
shape_cast: 1x2x2x2x3 -> 1x4x2x3
transpose: 1x4x2x3 -> 4x1x3x2
shape_cast: 4x1x3x2 -> 4x3x2
```

canonicalizes to:

```mlir
shape_cast: 1x2x2x2x3 -> 4x2x3
transpose: 4x2x3 -> 4x3x2
```
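
As a sanity check on the rewrite above, here is a minimal Python sketch (not MLIR, and not part of the PR) that models a vector as a row-major flat list plus a shape and compares the three-op chain against a single shape_cast to 4x2x3 followed by a transpose to 4x3x2. The helper names and the permutations `[1, 0, 3, 2]` and `[0, 2, 1]` are assumptions inferred from the shapes; the snippets above only give shapes.

```python
from itertools import product
from math import prod


def strides(shape):
    """Row-major strides for a shape, e.g. (4, 2, 3) -> [6, 3, 1]."""
    out, acc = [], 1
    for dim in reversed(shape):
        out.append(acc)
        acc *= dim
    return out[::-1]


def shape_cast(t, new_shape):
    """Model of a shape cast: same row-major element order, new shape."""
    flat, shape = t
    assert prod(shape) == prod(new_shape), "element count must match"
    return flat, tuple(new_shape)


def transpose(t, perm):
    """Model of a transpose: result dim i is source dim perm[i]."""
    flat, shape = t
    new_shape = tuple(shape[p] for p in perm)
    src = strides(shape)
    out = []
    # product() iterates with the last index fastest, i.e. row-major order.
    for idx in product(*(range(d) for d in new_shape)):
        offset = sum(idx[i] * src[perm[i]] for i in range(len(perm)))
        out.append(flat[offset])
    return out, new_shape


# Original chain: shape_cast, transpose, shape_cast.
x = (list(range(24)), (1, 2, 2, 2, 3))
before = shape_cast(
    transpose(shape_cast(x, (1, 4, 2, 3)), [1, 0, 3, 2]), (4, 3, 2)
)

# Rewritten chain: one shape_cast, then a transpose on the untouched dims.
after = transpose(shape_cast(x, (4, 2, 3)), [0, 2, 1])

# Small case from earlier in the thread: the transpose becomes a no-op.
y = (list(range(4)), (1, 2, 2))
small_before = shape_cast(transpose(shape_cast(y, (1, 4)), [1, 0]), (4,))
small_after = shape_cast(y, (4,))

print(before == after, small_before == small_after)
```

Both comparisons hold because the transpose only moves a unit dimension past the collapsed `4`, while the `2x3 -> 3x2` swap is on dimensions neither shape_cast touches.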

This actually shows a good reason not to focus on these special-cased patterns, as they just hide more general patterns.

For shape_cast, you just have to choose a direction to go (always up or always down, or always expand or always collapse), and then you can write patterns like this.

https://github.com/llvm/llvm-project/pull/140583

