[mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Apr 23 00:52:31 PDT 2025
================
@@ -1024,6 +1024,34 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
// -----
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL: func.func @batch_reduce_matmul(
+// CHECK-SAME: %[[ARG_A:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME: %[[ARG_B:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME: %[[ARG_C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]],
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+
+func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
----------------
banach-space wrote:
Please use consistent names between the MLIR argument names and the LIT (FileCheck) capture variables. I am suggesting `A`, `B` and `C`, but that's secondary (see also the sketch after the suggestion below).
```suggestion
// CHECK-SAME: %[[A:.*]]: tensor<2x3x5xf32>,
// CHECK-SAME: %[[B:.*]]: tensor<2x5x7xf32>,
// CHECK-SAME: %[[C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]],
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
func.func @batch_reduce_matmul(%A: tensor<2x3x5xf32>, %B: tensor<2x5x7xf32>, %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
```
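For context, here is how the renamed test source would likely read end-to-end. The op body is not part of the quoted hunk, so this is only a sketch under the assumption that the test uses the plain (no transpose/broadcast) `linalg.batch_reduce_matmul` form, which matches the indexing maps and iterator types captured in the CHECK lines; the `%0` result name is likewise assumed.
```mlir
func.func @batch_reduce_matmul(%A: tensor<2x3x5xf32>, %B: tensor<2x5x7xf32>, %C: tensor<3x7xf32>) -> tensor<3x7xf32> {
  // Named op; generalization is expected to lower this to the linalg.generic
  // with #[[$ACCESS_A]]/#[[$ACCESS_B]]/#[[$ACCESS_C]] and the
  // ["reduction", "parallel", "parallel", "reduction"] iterator types above.
  %0 = linalg.batch_reduce_matmul ins(%A, %B : tensor<2x3x5xf32>, tensor<2x5x7xf32>)
                                  outs(%C : tensor<3x7xf32>) -> tensor<3x7xf32>
  return %0 : tensor<3x7xf32>
}
```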
https://github.com/llvm/llvm-project/pull/130944