[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Tue Apr 29 22:57:54 PDT 2025
================
@@ -1484,6 +1484,201 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
}
+// -----
+
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, n, k)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @indexing_map_size_one_batch_reduce_matmul(%arg0: memref<?x?x?xf32>,
+ %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Indexing_map attribute must have 3 affine maps}}
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+
+}
+
+// -----
+
+func.func @missing_indexing_map_batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{expected attribute value}}
+ linalg.batch_reduce_matmul indexing_maps = [
+ ,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, n, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_dim_expr_batch_reduce_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, m, k)>,
+ affine_map<(batch, m, n, k) -> (batch, k, m)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_bcast_batch_reduce_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch)>,
+ affine_map<(batch, m, n, k) -> (batch, k, n)>,
+ affine_map<(batch, m, n, k) -> (m, n)>]
+ ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_multi_dim_bcast_expr_batch_reduce_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_reduce_matmul' op Invalid broadcast requested}}
+ linalg.batch_reduce_matmul
+ indexing_maps = [affine_map<(batch, m, n, k) -> (batch, k)>,
----------------
shahidact wrote:
> My question stands :) What makes this an invalid broadcast?
`affine_map<(batch, m, n, k) -> (k, batch)>`: this is invalid because it would not represent a mathematically valid matrix multiplication; the dimensions that should be multiplied (k in A and k in B) are in positions that do not align properly for the operation.
https://github.com/llvm/llvm-project/pull/130944
More information about the Mlir-commits
mailing list