[mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jan 22 08:55:44 PST 2025


================
@@ -1487,6 +1487,154 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
 
 // -----
 
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_matmul_bcast_batch_and_m_dim_A(
+// CHECK-SAME:                                    %[[VAL_0:.*]]: memref<5xf32>,
+// CHECK-SAME:                                    %[[VAL_1:.*]]: memref<2x5x7xf32>,
+// CHECK-SAME:                                    %[[VAL_2:.*]]: memref<2x3x7xf32>) {
+// CHECK:           linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
+// CHECK:           return
+// CHECK:         }
+func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
----------------
banach-space wrote:

SGTM :)

https://github.com/llvm/llvm-project/pull/122275

