[Mlir-commits] [mlir] [mlir][nfc] Add tests for linalg.mmt4d (PR #81422)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 11 09:13:32 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there are virtually no tests in-tree. In the spirit of documenting
through tests, this PR adds a few basic examples.


---
Full diff: https://github.com/llvm/llvm-project/pull/81422.diff


4 Files Affected:

- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+26) 
- (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+11) 
- (added) mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir (+66) 
- (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+32) 


``````````diff
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 56890df3f3ee52..916c04f33e9c67 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
     -> tensor<2x16xf32>
   return %1 : tensor<2x16xf32>
 }
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+                               %B: tensor<16x16x8x1xf32>,
+                               %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+  // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+  %res = linalg.mmt4d
+                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                   outs(%C_in: tensor<16x16x8x1xf32>)
+                   -> tensor<16x16x8x1xf32>
+  return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+                               %B: tensor<16x16x8x1xf32>,
+                               %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+  %res = linalg.mmt4d
+                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                   outs(%C_in: tensor<8x8xf32>)
+                   -> tensor<8x8xf32>
+  return %res : tensor<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..317231908a9413 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
 
 // -----
 
+// CHECK-LABEL: func @mmt4d
+func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
+  // CHECK: %{{.+}} = linalg.mmt4d
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+  %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+  return %0: tensor<10x80x8x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @batch_mmt4d
 func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.batch_mmt4d
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
new file mode 100644
index 00000000000000..21534152a45c4b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @mmt4d_to_fma(
+// CHECK-COUNT-8: vector.fma
+func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+  %res = linalg.mmt4d
+                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                   outs(%C_in: tensor<16x16x8x8xf32>)
+                   -> tensor<16x16x8x8xf32>
+  return %res : tensor<16x16x8x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
+      : (!transform.op<"func.func">) -> !transform.any_op
+
+    // Step 1: Tile
+    // Tile parallel dims
+    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    // Tile reduction dims
+    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+    // Step 3: Simplify - vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with the dims matching the original
+    // MMT4D Op and with the following split into parallel and reduction dims:
+    //    * parallel, parallel, reduction, parallel, parallel, reduction
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transforms vector.contract to be
+      // more matmul-like and to enable the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.op<"func.func">
+
+    // Hoisting and LICM - not strictly required
+    %func_h = transform.structured.hoist_redundant_vector_transfers %func
+      : (!transform.op<"func.func">) -> !transform.op<"func.func">
+    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+      : (!transform.op<"func.func">) -> !transform.any_op
+    transform.apply_licm to %all_loops : !transform.any_op
+    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+    // Simplify the 6-dim vector.contract into a 3-dim matmul-like one, with
+    // the following split into parallel and reduction dims:
+    //    * parallel, parallel, reduction
+    transform.apply_patterns to %func_h {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.fma
+    transform.apply_patterns to %func_h {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..548c9c7ba76485 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -639,6 +639,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %{{.*}} = arith.constant 16 : index
+// CHECK:           %{{.*}} = arith.constant 16 : index
+// CHECK:           %{{.*}} = arith.constant 16 : index
+// CHECK:           %{{.*}} = arith.constant 8 : index
+// CHECK:           %{{.*}} = arith.constant 8 : index
+// CHECK:           %{{.*}} = arith.constant 1 : index
+// CHECK:           %{{.*}} = arith.constant 0 : index
+// CHECK:           %[[LHS:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RHS:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[ACC:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[PROD:.*]] = arith.mulf %[[LHS]], %[[RHS]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[SUM:.*]] = vector.multi_reduction <add>, %[[PROD]], %[[ACC]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[SUM]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %root : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/81422


More information about the Mlir-commits mailing list