[Mlir-commits] [mlir] [mlir][tosa] Fold tensor.cast into tosa.transpose (PR #170029)

Tomer Solomon llvmlistbot at llvm.org
Sat Dec 6 08:12:39 PST 2025


recursion-man wrote:

Thanks for the review @lhutton1, happy to add some context.

In our pipeline we have a mix of TOSA ops and some custom ones. We do tiling and peeling, and then run _--canonicalize_. The tiling tail plus the _OpWithOffsetSizesAndStridesConstantArgumentFolder_ canonicalization pattern of tensor.extract_slice insert a relaxing tensor.cast, and we end up with IR that looks like:

```mlir
scf.for{
%lhs = tensor.extract_slice %lhs_full[...] : tensor<12x32x256xi8> to tensor<4x8x256xi8>
%rhs = tensor.extract_slice %rhs_full[...] : tensor<12x8x32x32xi8> to tensor<4x8x8x32xi8>
%cast = tensor.cast %rhs : tensor<4x8x8x32xi8> to tensor<?x8x8x32xi8>
%tr = tosa.transpose %cast {perms = array<i32: 0, 2, 1, 3>} : (tensor<?x8x8x32xi8>) -> tensor<?x8x8x32xi8>
%collapsed = tensor.collapse_shape %tr {...} : tensor<?x8x8x32xi8> into tensor<?x8x256xi8>
%fc = "custom.fc"(%collapsed, ...) : (tensor<?x8x256xi8>, ...) -> tensor<?x8x256xi8>
%add = "custom.add"(%lhs, %fc, ...) : (tensor<4x8x256xi8>, tensor<?x8x256xi8>, ...) -> tensor<4x8x256xi8>
...
}
```

Our custom ops already implement their own tensor.cast folding, and custom.add’s verifier requires all dims to match. With the IR in this form, canonicalization fails at custom.add, because one operand is tensor<4x8x256xi8> and the other is tensor<?x8x256xi8>.

With the transpose cast‑folding pattern from this patch, the IR gets canonicalized to:
```mlir
scf.for{
%rhs = tensor.extract_slice %rhs_full[...] : tensor<12x8x32x32xi8> to tensor<4x8x8x32xi8>
%tr = tosa.transpose %rhs {perms = array<i32: 0, 2, 1, 3>} : (tensor<4x8x8x32xi8>) -> tensor<4x8x8x32xi8>
%collapsed = tensor.collapse_shape %tr {...} : tensor<4x8x8x32xi8> into tensor<4x8x256xi8>
%fc = "custom.fc"(%collapsed, ...) : (tensor<4x8x256xi8>, ...) -> tensor<4x8x256xi8>
%add = "custom.add"(%lhs, %fc, ...) : (tensor<4x8x256xi8>, tensor<4x8x256xi8>, ...) -> tensor<4x8x256xi8>
...
}
```
So the transpose outputs the more precise type that the slice canonicalizer just produced, and everything downstream becomes fully static (our custom ops also implement tensor.cast folding). 
Since it’s the canonicalization pass that introduces this relaxing tensor.cast in the first place, I thought it made sense to also propagate the static shapes during canonicalization, so that by the end of _--canonicalize_ we’ve recovered as much static shape info as possible. That’s why I initially attached this to tosa::TransposeOp’s canonicalizations. In practice, I also can’t easily move this logic to a later legalization pass in our pipeline: unless the tensor.cast is folded through tosa.transpose, we already hit verifier failures for some of these custom ops during canonicalization.
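
To make the type effect concrete, here is a small Python model of the shape bookkeeping only. It is not the C++ pattern from the patch. The names `is_relaxing_cast`, `transpose_shape`, and `fold_cast_into_transpose` are made up for illustration, `None` stands in for a `?` dim, and the foldability check is an assumption about what "relaxing" means here: same rank, and the cast only drops static dims.

```python
from typing import Optional, Sequence, Tuple

# A tensor shape; None models a dynamic ('?') dimension.
Shape = Tuple[Optional[int], ...]


def is_relaxing_cast(src: Shape, dst: Shape) -> bool:
    """Return True if casting src -> dst only forgets static dims (assumed rule)."""
    if len(src) != len(dst):
        return False
    # Every dim that stays static must be unchanged.
    if any(d is not None and d != s for s, d in zip(src, dst)):
        return False
    # At least one static dim must have become dynamic.
    return any(s is not None and d is None for s, d in zip(src, dst))


def transpose_shape(shape: Shape, perms: Sequence[int]) -> Shape:
    """Output dim i of the transpose is input dim perms[i]."""
    return tuple(shape[p] for p in perms)


def fold_cast_into_transpose(src: Shape, cast_dst: Shape,
                             perms: Sequence[int]) -> Shape:
    """Result shape of transpose(cast(x)), bypassing the cast when it only relaxes."""
    operand = src if is_relaxing_cast(src, cast_dst) else cast_dst
    return transpose_shape(operand, perms)


sliced = (4, 8, 8, 32)       # type produced by tensor.extract_slice
relaxed = (None, 8, 8, 32)   # type after the relaxing tensor.cast
perms = (0, 2, 1, 3)

print(transpose_shape(relaxed, perms))                   # without the fold
print(fold_cast_into_transpose(sliced, relaxed, perms))  # with the fold
```

The first call keeps the `?` in the leading dim, while the second returns the fully static shape, which is what lets the downstream custom ops see matching operand types.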

I’m starting to wonder if the root cause here is that by the time I introduce scf and tensor.extract_slice, I should already have transformed most of my TOSA ops out of the IR.
Does that match how you’d expect this to be layered?

Thanks again for the feedback!

https://github.com/llvm/llvm-project/pull/170029

