[Mlir-commits] [mlir] [mlir][vector] Better handle rank-preserving shape_cast (PR #135855)

James Newling llvmlistbot at llvm.org
Wed Apr 16 09:19:35 PDT 2025


newling wrote:

> Shape cast is a generic reshape op that allows rank increasing/decreasing/preserving transformations.

Unfortunately, it isn't. The [op description](https://mlir.llvm.org/docs/Dialects/Vector/#vectorshape_cast-vectorshapecastop) says it cannot mix dimensions, so it behaves like the union of the tensor dialect's [expand_shape](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorexpand_shape-tensorexpandshapeop) and [collapse_shape](https://mlir.llvm.org/docs/Dialects/TensorOps/#tensorcollapse_shape-tensorcollapseshapeop).
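The "no mixing" rule can be made concrete with a minimal Python sketch. It assumes the contiguous-grouping reading of the op description: every dimension of the lower-rank shape must be the product of a contiguous run of dimensions of the higher-rank shape. `reassociation` is a hypothetical helper, not an MLIR API, and it ignores scalable dimensions.

```python
from typing import List, Optional, Sequence


def reassociation(small: Sequence[int], large: Sequence[int]) -> Optional[List[List[int]]]:
    """Hypothetical helper: split the dims of `large` into contiguous groups whose
    products equal the dims of `small` (requires len(small) <= len(large)).

    The result has the same form as the reassociation indices of
    tensor.expand_shape / tensor.collapse_shape. None means no such grouping
    exists, i.e. the cast would have to mix dimensions.
    Simplified model: fixed-size (non-scalable) dims only.
    """
    groups: List[List[int]] = []
    j = 0
    for dim in small:
        group: List[int] = []
        prod = 1
        # Consume dims of `large` until the running product reaches `dim`
        # (always consume at least one, so every group is non-empty).
        while j < len(large) and (prod < dim or not group):
            prod *= large[j]
            group.append(j)
            j += 1
        if not group or prod != dim:
            return None
        groups.append(group)
    # Leftover dims of `large` must all be 1; fold them into the last group.
    rest = list(range(j, len(large)))
    if any(large[k] != 1 for k in rest):
        return None
    if groups:
        groups[-1].extend(rest)
    return groups


# 6x4 <-> 2x3x4: the 6 corresponds to the contiguous dims [0, 1] of 2x3x4.
print(reassociation([6, 4], [2, 3, 4]))  # [[0, 1], [2]]
# 4x3 -> 2x3x2: no contiguous run of result dims multiplies to 4 (2, then 2 * 3 = 6).
print(reassociation([4, 3], [2, 3, 2]))  # None
```

A non-`None` result has the same shape as the reassociation list taken by `tensor.collapse_shape` (or `tensor.expand_shape` in the other direction), which is why the op reads as the union of those two.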

Perhaps it should behave like the tensor dialect's `reshape` instead, and I would be happy to go down that path. But this PR only aims to make the verifier consistent with how the op is currently defined.

For example, with the current verifier:

```mlir
// error: 'vector.shape_cast' op invalid shape cast
func.func @invalid_0(%arg : vector<4x3xi32>) -> vector<2x3x2xi32> {
  %1 = vector.shape_cast %arg : vector<4x3xi32> to vector<2x3x2xi32>
  return %1 : vector<2x3x2xi32>
}

// no error: verifies successfully
func.func @valid_0(%arg : vector<4x3x1xi32>) -> vector<2x3x2xi32> {
  %1 = vector.shape_cast %arg : vector<4x3x1xi32> to vector<2x3x2xi32>
  return %1 : vector<2x3x2xi32>
}
```

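IMO it doesn't make sense that the second one verifies but the first one doesn't: both mix dimensions, since in neither case can the result dimensions be grouped contiguously to give the source's leading 4 (2, then 2 * 3 = 6).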

`valid_0` verifies because the current logic never checks the case where the rank is unchanged (see [here](https://github.com/llvm/llvm-project/blob/34598fdadc06bd3b21aa97342dda05ecd9233912/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L5591)). My best guess is that the [initial commit](https://github.com/llvm/llvm-project/commit/3ce8095c295e6a9ef7c946ad8c035a8b5a392ec1) of this op simply overlooked this case.
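The gap can be modeled in a few lines of Python. This is a simplified sketch of the rank dispatch described above, not the C++ in `VectorOps.cpp`: the grouping check only runs when the ranks differ, so a rank-preserving cast with a matching element count is accepted. `groups_contiguously`, `verify_current` and `verify_all_ranks` are hypothetical names, and `verify_all_ranks` is just one way to close the gap, not necessarily what this PR implements.

```python
from math import prod
from typing import Sequence


def groups_contiguously(small: Sequence[int], large: Sequence[int]) -> bool:
    """True if each dim of `small` is the product of a non-empty contiguous run
    of dims of `large`, and any leftover dims of `large` are 1."""
    j = 0
    for dim in small:
        acc, taken = 1, 0
        while j < len(large) and (acc < dim or taken == 0):
            acc *= large[j]
            j += 1
            taken += 1
        if taken == 0 or acc != dim:
            return False
    return all(d == 1 for d in large[j:])


def verify_current(src: Sequence[int], dst: Sequence[int]) -> bool:
    """Model of the current dispatch: only rank-changing casts get the grouping check."""
    if prod(src) != prod(dst):
        return False  # element counts must always match
    if len(src) < len(dst):
        return groups_contiguously(src, dst)
    if len(src) > len(dst):
        return groups_contiguously(dst, src)
    return True  # equal rank: no grouping check at all


def verify_all_ranks(src: Sequence[int], dst: Sequence[int]) -> bool:
    """Hypothetical variant: apply the same grouping check when ranks are equal."""
    small, large = (src, dst) if len(src) <= len(dst) else (dst, src)
    return prod(src) == prod(dst) and groups_contiguously(small, large)


print(verify_current([4, 3], [2, 3, 2]))       # False (invalid_0)
print(verify_current([4, 3, 1], [2, 3, 2]))    # True  (valid_0 slips through)
print(verify_all_ranks([4, 3, 1], [2, 3, 2]))  # False
```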
 

https://github.com/llvm/llvm-project/pull/135855

