[Mlir-commits] [mlir] [MLIR][Linalg] Introduce broadcast/transpose semantic to 'linalg.batc… (PR #122275)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Mon Jan 13 08:19:29 PST 2025
================
@@ -1002,3 +1002,26 @@ func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7
// -----
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>, %[[VAL_1:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_2]] : tensor<?x?x?xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+
+func.func @batch_matmul(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
----------------
shahidact wrote:
> Also, how about dedicated tests for transpose and broadcasts?
In the "matmul" PR, I received a comment that one test, in conjunction with the invalid tests, is enough to verify the generalization, and that is true. Hence, I refrained from adding redundant tests.
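For reference, the computation that the CHECK lines in the hunk above expect can be sketched as a plain loop nest. This is only an illustrative reading of the three indexing maps and the iterator types, not code from the PR; the function name `batch_matmul_generic` and the nested-list representation are assumptions made for the example.

```python
def batch_matmul_generic(a, b, c):
    """Illustrative loop nest for the expected linalg.generic (hypothetical helper).

    a is indexed as (d0, d1, d3), b as (d0, d3, d2), c as (d0, d1, d2),
    mirroring ATTR_0, ATTR_1 and ATTR_2 in the CHECK lines.
    Inputs are non-empty nested lists of floats.
    """
    batch, m, k = len(a), len(a[0]), len(a[0][0])
    n = len(b[0][0])
    # Copy the accumulator so the caller's `c` is left unchanged.
    out = [[row[:] for row in mat] for mat in c]
    for d0 in range(batch):          # "parallel"
        for d1 in range(m):          # "parallel"
            for d2 in range(n):      # "parallel"
                for d3 in range(k):  # "reduction"
                    # arith.mulf followed by arith.addf
                    out[d0][d1][d2] += a[d0][d1][d3] * b[d0][d3][d2]
    return out
```

A transpose or broadcast variant would differ only in how `a` and `b` are indexed inside the innermost statement, which is the part the indexing maps control.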
https://github.com/llvm/llvm-project/pull/122275