[Mlir-commits] [mlir] [mlir][nfc] Add tests for linalg.mmt4d (PR #81422)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Feb 13 04:18:38 PST 2024
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/81422
>From e8c57d1a65cd6803de6b27c7a343bdd9f3b8ef7d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 11 Feb 2024 15:31:47 +0000
Subject: [PATCH 1/2] [mlir][nfc] Add tests for linalg.mmt4d
linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there are virtually no tests in-tree. In the spirit of documenting
through tests, this PR adds a few basic examples.
---
mlir/test/Dialect/Linalg/invalid.mlir | 26 +++++++
mlir/test/Dialect/Linalg/named-ops.mlir | 11 +++
.../Linalg/transform-op-mmt4d-to-fma.mlir | 70 +++++++++++++++++++
mlir/test/Dialect/Linalg/vectorization.mlir | 32 +++++++++
4 files changed, 139 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 56890df3f3ee52..916c04f33e9c67 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
-> tensor<2x16xf32>
return %1 : tensor<2x16xf32>
}
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+ // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x1xf32>)
+ -> tensor<16x16x8x1xf32>
+ return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<8x8xf32>)
+ -> tensor<8x8xf32>
+ return %res : tensor<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a71dbb864..7064e1b3f9dc76 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
// -----
+// CHECK-LABEL: func @mmt4d
+func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
+ // CHECK: %{{.+}} = linalg.mmt4d
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+ %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+ return %0: tensor<10x80x8x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @batch_mmt4d
func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
// CHECK: %{{.+}} = linalg.batch_mmt4d
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
new file mode 100644
index 00000000000000..26e36d6c88557b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x8xf32>)
+ -> tensor<16x16x8x8xf32>
+ return %res : tensor<16x16x8x8xf32>
+}
+
+
+// CHECK-LABEL: @mmt4d_to_fma
+// CHECK-COUNT-8: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
+
+ // Step 1: Tile
+ : (!transform.op<"func.func">) -> !transform.any_op
+ // Tile parallel dims
+ %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ // Tile reduction dims
+ %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+ // Step 3: Simplify
+ // vector.multi_reduction --> vector.contract
+ // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+ // and with the following split int parallel and reduction dims:
+ // * parallel, parallel, reduction, parallel, parallel, reduction
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ // Reduce the rank of xfer ops. This transforms vector.contract to be
+ // more matmul-like and to enable the lowering to outer product Ops.
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.op<"func.func">
+
+ // Hoisting and LICM - not strictly required
+ %func_h = transform.structured.hoist_redundant_vector_transfers %func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+ %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+ : (!transform.op<"func.func">) -> !transform.any_op
+ transform.apply_licm to %all_loops : !transform.any_op
+ transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+ // Simplify the 6-dim vector.contract into a 3-dim matmul-like
+ // vector.contract with the following split splitn parallel and reduction
+ // dims:
+ // * parallel, parallel, reduction
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ // Step 4: Lower vector.contract to vector.fma
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef478ee987..548c9c7ba76485 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -639,6 +639,38 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+ linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+ outs(%C_in: memref<16x16x8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %mmt4d : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?xf32>)
>From 89fd5a5d63e64d72249042f35d8ca1c57bbaa653 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 13 Feb 2024 12:18:07 +0000
Subject: [PATCH 2/2] fixup! [mlir][nfc] Remove leftover print stmt in a test
Address Cullen's comments
---
mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir | 6 +++---
mlir/test/Dialect/Linalg/vectorization.mlir | 7 -------
2 files changed, 3 insertions(+), 10 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
index 26e36d6c88557b..61e13d1bfa9c62 100644
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
// Step 3: Simplify
// vector.multi_reduction --> vector.contract
// Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
- // and with the following split int parallel and reduction dims:
+ // and with the following split into parallel and reduction dims:
// * parallel, parallel, reduction, parallel, parallel, reduction
transform.apply_patterns to %func {
transform.apply_patterns.vector.reduction_to_contract
@@ -51,7 +51,7 @@ module attributes {transform.with_named_sequence} {
transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
// Simplify the 6-dim vector.contract into a 3-dim matmul-like
- // vector.contract with the following split splitn parallel and reduction
+ // vector.contract with the following split into parallel and reduction
// dims:
// * parallel, parallel, reduction
transform.apply_patterns to %func_h {
@@ -60,7 +60,7 @@ module attributes {transform.with_named_sequence} {
transform.apply_patterns.canonicalization
} : !transform.op<"func.func">
- // Step 4: Lower vector.contract to vector.fma
+ // Step 4: Lower vector.contract to vector.fma via vector.outerproduct
transform.apply_patterns to %func_h {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
transform.apply_patterns.vector.lower_outerproduct
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 548c9c7ba76485..0272ac599aa3db 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -647,13 +647,6 @@ func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: me
// CHECK-LABEL: func.func @mmt4d(
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
-// CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 16 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 8 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant 8 : index
-// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
More information about the Mlir-commits
mailing list