[mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)

Andrzej Warzyński llvmlistbot at llvm.org
Fri May 2 07:31:04 PDT 2025


================
@@ -1484,6 +1484,201 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
 }
 
 
+// -----
+
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_reduce_matmul
+         indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                          affine_map<(batch, m, n, k) -> (batch, n, k)>]
+         ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+         outs(%arg2: memref<?x?xf32>)
+     return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+     %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+     // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+     linalg.batch_reduce_matmul
+         indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>]
+         ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+         outs(%arg2: memref<?x?xf32>)
+     return
+
+}
+
+// -----
+
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{expected attribute value}}
+  linalg.batch_reduce_matmul indexing_maps = [
+                       ,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, n, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, m)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+      outs(%arg2 :memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch)>,
+                       affine_map<(batch, m, n, k) -> (batch, k, n)>,
+                       affine_map<(batch, m, n, k) -> (m, n)>]
+      ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>)
+      outs(%arg2: memref<?x?xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+  linalg.batch_reduce_matmul
+      indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k)>,
----------------
banach-space wrote:

>  For example, if M not equal to N and we broadcast M to K for A matrix and N to K for B matrix, K no longer remains common/shared, consequently violating the matmul definition.

I don't quite follow this - this sounds like we could end up with ... two different `K`s?

In general, there are two possibilities:
* `A = tensor<M x i32>` and `B = tensor<K x N x i32>` - what's the problem with broadcasting `A` to `[M x K]` where `K` is taken from `B`?
* `A = tensor<M x K x i32>` and `B = tensor<N x i32>` - what's the problem with broadcasting `B` to `[K x N]` where `K` is taken from `A`?

There's a 3rd option that indeed does not make sense (there's nowhere to obtain `K` from):
* `A = tensor<M x i32>` and `B = tensor<N x i32>`.

Is this consistent with your understanding? If yes, what am I missing here?
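
To make the three cases concrete, here is a minimal Python sketch of the shape reasoning only. It is not how the Linalg verifier works, and the names `resolve_k` and `broadcast_operands`, as well as the dict-based shape encoding, are hypothetical and chosen purely for illustration. The assumption is that a broadcast operand takes `K` from whichever operand actually carries it.

```python
from typing import Dict, Optional, Tuple

Shape = Dict[str, int]  # hypothetical encoding: dimension name -> extent


def resolve_k(a: Shape, b: Shape) -> Optional[int]:
    """Return the shared K extent, or None if neither operand carries K."""
    k_a, k_b = a.get("K"), b.get("K")
    if k_a is not None and k_b is not None and k_a != k_b:
        # Both operands carry K but disagree: two different K's.
        raise ValueError("A and B disagree on K")
    return k_a if k_a is not None else k_b


def broadcast_operands(a: Shape, b: Shape) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return the (M x K) and (K x N) shapes after broadcasting, if possible."""
    k = resolve_k(a, b)
    if k is None:
        # Third case: there is nowhere to obtain K from.
        raise ValueError("neither A nor B provides K")
    return (a["M"], k), (k, b["N"])


# Case 1: A = tensor<M>, B = tensor<K x N>; K comes from B.
case1 = broadcast_operands({"M": 4}, {"K": 3, "N": 5})
# Case 2: A = tensor<M x K>, B = tensor<N>; K comes from A.
case2 = broadcast_operands({"M": 4, "K": 3}, {"N": 5})
```

In both of the first two cases a single `K` is used for both operands, so the contraction dimension stays shared; only the third case has no valid answer.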

https://github.com/llvm/llvm-project/pull/130944

