[Mlir-commits] [mlir] [MLIR][Linalg] Introduce transpose/broadcast semantic to linalg.batch… (PR #130944)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Mon Apr 28 10:30:49 PDT 2025


================
@@ -1024,6 +1024,34 @@ func.func @batch_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg
 
 // -----
 
+// CHECK: #[[$ACCESS_A:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ACCESS_B:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ACCESS_C:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK-LABEL:   func.func @batch_reduce_matmul(
+// CHECK-SAME:        %[[ARG_A:.*]]: tensor<2x3x5xf32>,
+// CHECK-SAME:        %[[ARG_B:.*]]: tensor<2x5x7xf32>,
+// CHECK-SAME:        %[[ARG_C:.*]]: tensor<3x7xf32>) -> tensor<3x7xf32> {
+// CHECK:           linalg.generic
+// CHECK-SAME:          indexing_maps = [#[[$ACCESS_A]], #[[$ACCESS_B]], #[[$ACCESS_C]]],
+// CHECK-SAME:          iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK:           arith.mulf
+// CHECK:           arith.addf
+// CHECK:           linalg.yield
+
+func.func @batch_reduce_matmul(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>, %arg2: tensor<3x7xf32>) -> tensor<3x7xf32> {
----------------
shahidact wrote:

Done.

https://github.com/llvm/llvm-project/pull/130944

