[mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jan 15 11:53:08 PST 2025


================
@@ -1487,6 +1487,154 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_and_m_dim_A(
+// CHECK-SAME:                                    %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME:                                    %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_dim_A(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: memref<3x5xf32>,
+// CHECK-SAME:                                              %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_dim_A(%arg0: memref<3x5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_and_n_dim_B(
+// CHECK-SAME:                                                    %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                                    %[[VAL_1:.*]]: memref<5xf32>,
+// CHECK-SAME:                                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_and_n_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_dim_B(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: memref<2x3x5xf32>,
+// CHECK-SAME:                                              %[[VAL_1:.*]]: memref<5x7xf32>,
+// CHECK-SAME:                                              %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<2x3x5xf32>, memref<5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+
+func.func @batch_matmul_bcast_batch_dim_B(%arg0: memref<2x3x5xf32>, %arg1: memref<5x7xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul indexing_maps = [
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+                     ]
+                     ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<5x7xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batch_matmul_explicit_transpose_a
+//       CHECK:   linalg.batch_matmul
+//  CHECK-SAME:     ins(%{{.+}}, %{{.+}} : memref<2x5x3xf32>, memref<2x5x7xf32>)
+//  CHECK-SAME:     outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batch_matmul_explicit_transpose_a(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
----------------
banach-space wrote:

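For consistency with the other tests added in this patch, which use an upper-case operand suffix (e.g. `@batch_matmul_bcast_batch_dim_A`, `@batch_matmul_bcast_batch_dim_B`). Note that the matching `CHECK-LABEL` line above needs the same rename, since FileCheck matching is case-sensitive.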
```suggestion
func.func @batch_matmul_explicit_transpose_A(%arg0: memref<2x5x3xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
```

https://github.com/llvm/llvm-project/pull/122275


More information about the Mlir-commits mailing list