[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Adam Siemieniuk
llvmlistbot at llvm.org
Mon Jan 13 04:03:04 PST 2025
================
@@ -1142,3 +1142,177 @@ func.func @winograd_output_transform_output_width(%arg0: tensor<6x6x3x3x2x2xf32>
%0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x3x3x2x2xf32>) outs(%arg1 : tensor<2x12x11x2xf32>) -> tensor<2x12x11x2xf32>
return %0 : tensor<2x12x11x2xf32>
}
+
+// -----
+
+// The first (LHS) entry of `indexing_maps` is left empty (a bare `,`), so
+// parsing of the attribute list fails with "expected attribute value".
+func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{expected attribute value}}
+ linalg.batch_matmul indexing_maps = [
+ ,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2 :memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The LHS map `(d0, d2, d3)` references d2, which the output map `(d0, d1, d2)`
+// uses for its last dimension; the op is expected to reject this LHS map with
+// "Unexpected dim expression in map result".
+func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{Unexpected dim expression in map result}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The RHS map `(d0, d3, d1)` references d1, which the output map `(d0, d1, d2)`
+// uses for its middle dimension; the op is expected to reject this RHS map with
+// "Unexpected dim expression in map result".
+func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{Unexpected dim expression in map result}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The LHS operand is rank-1 (memref<?xf32>) and its map `(d0)` keeps only d0;
+// this broadcast form is expected to be rejected with "Invalid broadcast requested".
+func.func @invalid_bcast_batch_matmul_a(%arg0: memref<?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The LHS operand is rank-2 (memref<?x?xf32>) with map `(d0, d3)`, i.e. it drops
+// d1; this broadcast form is expected to be rejected with "Invalid broadcast requested".
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_a(%arg0: memref<?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The RHS operand is rank-2 (memref<?x?xf32>) with map `(d3, d0)`, i.e. it drops
+// d2 and places d0 last; expected to be rejected with "Invalid broadcast requested".
+func.func @invalid_multi_dim_bcast_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The RHS operand is rank-1 (memref<?xf32>) and its map `(d2)` keeps only d2;
+// this broadcast form is expected to be rejected with "Invalid broadcast requested".
+func.func @invalid_bcast_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid broadcast requested}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?xf32>) outs(%arg2: memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The LHS map `(d1, d0, d3)` places d0 in the second result position instead of
+// the first; expected to be rejected with "Invalid batch dimension expression".
+func.func @invalid_batch_dim_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The RHS map `(d2, d3, d0)` places d0 in the last result position instead of
+// the first; expected to be rejected with "Invalid batch dimension expression".
+func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op Invalid batch dimension expression}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2 :memref<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// The LHS map `(d0, d1, d3, d3)` has 4 results, one more than the 3 the error
+// message allows.
+// NOTE(review): %arg2 is rank-2 (memref<?x?xf32>) while the output map has 3
+// results; presumably copied from the C-map case. Confirm that this mismatch
+// cannot trigger a different verifier error before the LHS-map check.
+func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+// The RHS map `(d0, d3, d2, d3)` has 4 results, one more than the 3 the error
+// message allows.
+// NOTE(review): %arg2 is rank-2 (memref<?x?xf32>) while the output map has 3
+// results; presumably copied from the C-map case. Confirm that this mismatch
+// cannot trigger a different verifier error before the RHS-map check.
+func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
+ linalg.batch_matmul indexing_maps = [
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,
+ affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ ]
+ ins(%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%arg2: memref<?x?xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_C_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+ // expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
----------------
adam-smnk wrote:
The error message feels unrelated to this case, since the number of dims does not exceed 3.
Something along the lines of `expects 3 dims but got {N}` would be more informative.
https://github.com/llvm/llvm-project/pull/122275
More information about the Mlir-commits
mailing list