[Mlir-commits] [mlir] [mlir][vector] Add `extract(transpose(broadcast(x)))` canonicalization (PR #72616)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Dec 14 05:41:19 PST 2023
================
@@ -2524,3 +2524,17 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @extract_of_transposed_broadcast_dim(
+// CHECK-SAME: %[[arg0:.*]]: vector<4x1xf32>
+// CHECK: %[[bc:.*]] = vector.broadcast %[[arg0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
+// CHECK: %[[tp:.*]] = vector.transpose %[[bc]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+// CHECK: return %[[tp]]
+func.func @extract_of_transposed_broadcast_dim(%arg0: vector<4x1xf32>) -> vector<1x100x4x5xf32> {
+ %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<100x5x1x4x1xf32>
+ %1 = vector.transpose %0, [2, 4, 0, 3, 1] : vector<100x5x1x4x1xf32> to vector<1x1x100x4x5xf32>
+ %2 = vector.extract %1[0] : vector<1x100x4x5xf32> from vector<1x1x100x4x5xf32>
+ return %2 : vector<1x100x4x5xf32>
----------------
banach-space wrote:
[nit] I think that using a non-unit dim for this example would be more interesting and also easier to parse. For example, you could replace `1` with `123` and then:
```mlir
%2 = vector.extract %1[100] : vector<1x100x4x5xf32> from vector<123x1x100x4x5xf32>
```
That would make it super easy to see what is going on and which broadcast dim is being removed.
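Spelled out in full, the suggested test input might look like the sketch below. Only the `vector.extract` line comes from the suggestion above; the `vector.broadcast` and `vector.transpose` types are inferred by substituting `123` for the unit dim at position 2 of the broadcast result, so treat them as an assumption rather than the final test. Whether the pattern in this PR folds this variant to the same CHECK lines is also assumed, not verified.

```mlir
// Sketch only: shapes derived from the suggestion, not taken from the PR.
func.func @extract_of_transposed_broadcast_dim(%arg0: vector<4x1xf32>) -> vector<1x100x4x5xf32> {
  // Dim 2 of the broadcast result (size 123) is a pure broadcast dim.
  %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<100x5x123x4x1xf32>
  // The permutation moves that broadcast dim to the front.
  %1 = vector.transpose %0, [2, 4, 0, 3, 1] : vector<100x5x123x4x1xf32> to vector<123x1x100x4x5xf32>
  // Every slice along the leading dim is identical, so any valid index gives the same value.
  %2 = vector.extract %1[100] : vector<1x100x4x5xf32> from vector<123x1x100x4x5xf32>
  return %2 : vector<1x100x4x5xf32>
}
```

Dropping dim 2 from the broadcast leaves `vector<100x5x4x1xf32>`, and renumbering the remaining permutation entries `[4, 0, 3, 1]` gives `[3, 0, 2, 1]`, which matches the existing CHECK lines.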
https://github.com/llvm/llvm-project/pull/72616