[Mlir-commits] [mlir] [vector][mlir] Canonicalize to shape_cast where possible (PR #140583)
James Newling
llvmlistbot at llvm.org
Thu Jun 26 11:28:06 PDT 2025
================
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// -----
+
// CHECK-LABEL: negative_fold_extract_broadcast
-// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
- %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+ %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
----------------
newling wrote:
> Keep both tests, one with the original shape and one with the new ones?
Makes sense, will do.
> Unrelated: it looks like we are missing a canonicalization pattern here? This should be turned into a single vector.broadcast to vector<4xf32>?
No, because you can't broadcast vector<1x1xf32> to vector<4xf32>: broadcasts can never reduce rank in the Vector dialect. FWIW, this is slightly related to my comment [here](https://github.com/llvm/llvm-project/pull/145740#discussion_r2167614549): this would be simpler if ops didn't do implicit shape casting. In this case, if it were something like
```
%s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
%b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x2x4xf32>
// Hypothetical rank-retaining extract (today's vector.extract at [0, 0] yields vector<4xf32>).
%r = vector.extract %b[0, 0] : vector<1x1x4xf32> from vector<1x2x4xf32>
%t = vector.shape_cast %r : vector<1x1x4xf32> to vector<4xf32>
```
i.e. if we constrained broadcasts and extracts to be rank-retaining, then this would be canonicalized to
```
%s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
%b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x1x4xf32>
%t = vector.shape_cast %b : vector<1x1x4xf32> to vector<4xf32>
```
which, if you have faith that the shape_casts will vanish at a later point, is simpler!
P.S. I plan to reply in https://github.com/llvm/llvm-project/pull/145740 later today.
https://github.com/llvm/llvm-project/pull/140583