[Mlir-commits] [mlir] [mlir][nfc] Add tests for linalg.mmt4d (PR #81422)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Feb 13 04:18:38 PST 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/81422

>From e8c57d1a65cd6803de6b27c7a343bdd9f3b8ef7d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 11 Feb 2024 15:31:47 +0000
Subject: [PATCH 1/2] [mlir][nfc] Add tests for linalg.mmt4d

linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there are virtually no tests in-tree. In the spirit of documenting
through tests, this PR adds a few basic examples.
---
 mlir/test/Dialect/Linalg/invalid.mlir         | 26 +++++++
 mlir/test/Dialect/Linalg/named-ops.mlir       | 11 +++
 .../Linalg/transform-op-mmt4d-to-fma.mlir     | 70 +++++++++++++++++++
 mlir/test/Dialect/Linalg/vectorization.mlir   | 32 +++++++++
 4 files changed, 139 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir

diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 56890df3f3ee52..916c04f33e9c67 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
     -> tensor<2x16xf32>
   return %1 : tensor<2x16xf32>
 }
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+                               %B: tensor<16x16x8x1xf32>,
+                               %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+    // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+    %res = linalg.mmt4d
+                     ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                     outs(%C_in: tensor<16x16x8x1xf32>)
+                     -> tensor<16x16x8x1xf32>
+    return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+                 %B: tensor<16x16x8x1xf32>,
+                 %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+    // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+    %res = linalg.mmt4d
+                     ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                     outs(%C_in: tensor<8x8xf32>)
+                     -> tensor<8x8xf32>
+    return %res : tensor<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..7064e1b3f9dc76 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
 
 // -----
 
+// CHECK-LABEL: func @mmt4d
+func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
+  // CHECK: %{{.+}} = linalg.mmt4d
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+  %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+  return %0: tensor<10x80x8x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @batch_mmt4d
 func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
   // CHECK: %{{.+}} = linalg.batch_mmt4d
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
new file mode 100644
index 00000000000000..26e36d6c88557b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+  %res = linalg.mmt4d
+                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+                   outs(%C_in: tensor<16x16x8x8xf32>)
+                   -> tensor<16x16x8x8xf32>
+  return %res : tensor<16x16x8x8xf32>
+}
+
+
+// CHECK-LABEL:     @mmt4d_to_fma
+// CHECK-COUNT-8:         vector.fma
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
+
+    // Step 1: Tile
+      : (!transform.op<"func.func">) -> !transform.any_op
+    // Tile parallel dims
+    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    // Tile reduction dims
+    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+    // Step 2: Vectorize
+    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+    // Step 3: Simplify
+    // vector.multi_reduction --> vector.contract
+    // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+    // and with the following split int parallel and reduction dims:
+    //    * parallel, parallel, reduction, parallel, parallel, reduction
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.reduction_to_contract
+      // Reduce the rank of xfer ops. This transforms vector.contract to be
+      // more matmul-like and to enable the lowering to outer product Ops.
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.op<"func.func">
+
+    // Hoisting and LICM - not strictly required
+    %func_h = transform.structured.hoist_redundant_vector_transfers %func
+      : (!transform.op<"func.func">) -> !transform.op<"func.func">
+    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+      : (!transform.op<"func.func">) -> !transform.any_op
+    transform.apply_licm to %all_loops : !transform.any_op
+    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+    // Simplify the 6-dim vector.contract into a 3-dim matmul-like
+    // vector.contract with the following split splitn parallel and reduction
+    // dims:
+    //    * parallel, parallel, reduction
+    transform.apply_patterns to %func_h {
+      transform.apply_patterns.vector.reduction_to_contract
+      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+    // Step 4: Lower vector.contract to vector.fma
+    transform.apply_patterns to %func_h {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_outerproduct
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..548c9c7ba76485 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -639,6 +639,38 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+  linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+               outs(%C_in: memref<16x16x8x8xf32>)
+  return
+}
+
+// CHECK-LABEL:   func.func @mmt4d(
+// CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK:           %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK:           %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK:           vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %mmt4d : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
   linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
             outs(%C: memref<?x?xf32>)

>From 89fd5a5d63e64d72249042f35d8ca1c57bbaa653 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 13 Feb 2024 12:18:07 +0000
Subject: [PATCH 2/2] fixup! [mlir][nfc] Remove leftover print stmt in a test

Address Cullen's comments
---
 mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir | 6 +++---
 mlir/test/Dialect/Linalg/vectorization.mlir             | 7 -------
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
index 26e36d6c88557b..61e13d1bfa9c62 100644
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
     // Step 3: Simplify
     // vector.multi_reduction --> vector.contract
     // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
-    // and with the following split int parallel and reduction dims:
+    // and with the following split into parallel and reduction dims:
     //    * parallel, parallel, reduction, parallel, parallel, reduction
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.reduction_to_contract
@@ -51,7 +51,7 @@ module attributes {transform.with_named_sequence} {
     transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
 
     // Simplify the 6-dim vector.contract into a 3-dim matmul-like
-    // vector.contract with the following split splitn parallel and reduction
+    // vector.contract with the following split into parallel and reduction
     // dims:
     //    * parallel, parallel, reduction
     transform.apply_patterns to %func_h {
@@ -60,7 +60,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.canonicalization
     } : !transform.op<"func.func">
 
-    // Step 4: Lower vector.contract to vector.fma
+    // Step 4: Lower vector.contract to vector.fma via vector.outerproduct
     transform.apply_patterns to %func_h {
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
       transform.apply_patterns.vector.lower_outerproduct
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 548c9c7ba76485..0272ac599aa3db 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -647,13 +647,6 @@ func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: me
 
 // CHECK-LABEL:   func.func @mmt4d(
 // CHECK-SAME:      %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 16 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 16 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 8 : index
-// CHECK:           %[[VAL_7:.*]] = arith.constant 8 : index
-// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
 // CHECK:           %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
 // CHECK:           %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
 // CHECK:           %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>



More information about the Mlir-commits mailing list