[Mlir-commits] [mlir] [mlir][vector] Add `extract(transpose(broadcast(x)))` canonicalization (PR #72616)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Dec 14 05:41:19 PST 2023


================
@@ -2524,3 +2524,17 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_of_transposed_broadcast_dim(
+//  CHECK-SAME:     %[[arg0:.*]]: vector<4x1xf32>
+//       CHECK:   %[[bc:.*]] = vector.broadcast %[[arg0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
+//       CHECK:   %[[tp:.*]] = vector.transpose %[[bc]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+//       CHECK:   return %[[tp]]
+func.func @extract_of_transposed_broadcast_dim(%arg0: vector<4x1xf32>) -> vector<1x100x4x5xf32> {
+  %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<100x5x1x4x1xf32>
+  %1 = vector.transpose %0, [2, 4, 0, 3, 1] : vector<100x5x1x4x1xf32> to vector<1x1x100x4x5xf32>
+  %2 = vector.extract %1[0] : vector<1x100x4x5xf32> from vector<1x1x100x4x5xf32>
+  return %2 : vector<1x100x4x5xf32>
----------------
banach-space wrote:

[nit] I think that using a non-unit dim for this example would be more interesting and also easier to parse. For example, you could replace the first `1` in `vector<100x5x1x4x1xf32>` with `123`, and then:
```mlir
  %2 = vector.extract %1[100] : vector<1x100x4x5xf32> from vector<123x1x100x4x5xf32>
```

That would make it super easy to see what is going on and which broadcast dim is being removed.
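The shape bookkeeping behind this test can be checked by hand with a small model. The Python below is a shape-only sketch, not the PR's C++ pattern. It assumes (as the CHECK lines show) that the rewrite drops the broadcast dim that the transpose moves to position 0 and renumbers the remaining permutation entries. It also assumes that only the leading dims added by `vector.broadcast` are treated as removable. The function names `transpose_shape` and `fold_leading_extract` are hypothetical helpers for illustration.

```python
def transpose_shape(shape, perm):
    # vector.transpose semantics: result dim i takes source dim perm[i].
    return [shape[p] for p in perm]


def fold_leading_extract(src_shape, bcast_shape, perm):
    """Model extract(transpose(broadcast(x)))[i] -> transpose(broadcast(x)).

    Returns (new_bcast_shape, new_perm). The extract index i is irrelevant
    here: every slice along a broadcast dim is identical.
    Assumption (hypothetical helper): only the leading dims added by the
    broadcast (not covered by the right-aligned source) are removable.
    """
    removed = perm[0]  # broadcast-result dim that the transpose puts first
    first_src_dim = len(bcast_shape) - len(src_shape)
    if removed >= first_src_dim:
        raise ValueError("extracted dim comes from the source, not a broadcast")
    # Drop the removed dim from the broadcast result type.
    new_bcast = bcast_shape[:removed] + bcast_shape[removed + 1:]
    # Drop the first perm entry and renumber dims that followed `removed`.
    new_perm = [p - 1 if p > removed else p for p in perm[1:]]
    return new_bcast, new_perm


if __name__ == "__main__":
    perm = [2, 4, 0, 3, 1]
    for bcast in ([100, 5, 1, 4, 1], [100, 5, 123, 4, 1]):
        new_bcast, new_perm = fold_leading_extract([4, 1], bcast, perm)
        print(transpose_shape(bcast, perm), "->", new_bcast, new_perm,
              transpose_shape(new_bcast, new_perm))
```

With the unit dim this reproduces the CHECK lines (`vector<100x5x4x1xf32>` and permutation `[3, 0, 2, 1]`). With `123` the transposed type becomes `123x1x100x4x5`, which makes it clear that the removed dim is the one of size 123, exactly as the suggestion intends.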

https://github.com/llvm/llvm-project/pull/72616

