[mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Jan 15 11:53:08 PST 2025
================
@@ -1487,6 +1487,154 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_m_dim_A(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK: return
+// CHECK: }
+func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
----------------
banach-space wrote:
[nit] This name suggests that the `batch` and `m` dims are broadcast:
* `@batch_matmul_bcast_batch_and_m_dim_A`.
That's not quite true though, right? It's the `k` dim that's broadcast and `batch` and `m` are the missing dims.
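To make the naming question concrete, here is a minimal pure-Python sketch of what the three `indexing_maps` in this test compute. It assumes the usual `linalg.batch_matmul` iteration order d0=batch, d1=m, d2=n, d3=k. The test does not state this ordering, so treat it as an assumption. The function name `batch_matmul_bcast_a` is hypothetical. The sketch shows that A's map `(d3)` contains only `k`, while `batch` (d0) and `m` (d1) are absent from it.

```python
def batch_matmul_bcast_a(a, b, c):
    """Hypothetical reference for the maps in this test (assumed dim order:
    d0=batch, d1=m, d2=n, d3=k):
      A: (d0, d1, d2, d3) -> (d3)          a is a list of K values
      B: (d0, d1, d2, d3) -> (d0, d3, d2)  b is [batch][K][N]
      C: (d0, d1, d2, d3) -> (d0, d1, d2)  c is [batch][M][N], updated in place
    """
    batch, m_size, n_size = len(c), len(c[0]), len(c[0][0])
    k_size = len(a)
    for d0 in range(batch):
        for d1 in range(m_size):
            for d2 in range(n_size):
                for d3 in range(k_size):
                    # A is indexed by d3 (k) only, so the same A value is
                    # reused for every batch index (d0) and every row (d1).
                    c[d0][d1][d2] += a[d3] * b[d0][d3][d2]
    return c
```

With the shapes from the test (A: 5, B: 2x5x7, C: 2x3x7), every row of every batch in C receives the same contribution from A. Only B varies along `batch`.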
https://github.com/llvm/llvm-project/pull/122275