[Mlir-commits] [mlir] 17afa5b - [mlir][nfc] Update tests for Contract -> Op transforms (#76054)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 21 05:20:21 PST 2023
Author: Andrzej Warzyński
Date: 2023-12-21T13:20:16Z
New Revision: 17afa5befb4cbe86c22c25ae1603433c8bd21551
URL: https://github.com/llvm/llvm-project/commit/17afa5befb4cbe86c22c25ae1603433c8bd21551
DIFF: https://github.com/llvm/llvm-project/commit/17afa5befb4cbe86c22c25ae1603433c8bd21551.diff
LOG: [mlir][nfc] Update tests for Contract -> Op transforms (#76054)
Updates two tests for vector.contract -> vector.outerproduct
transformations:
1. Rename "vector-contract-to-outerproduct-transforms.mlir" as
"vector-contract-to-outerproduct-matmul-transforms.mlir". The new
name more accurately captures what's being tested. It is also
consistent with
"vector-contract-to-outerproduct-matvec-transforms.mlir", which
covers vector matvec operations and makes finding relevant tests
easier.
2. For matmul tests, move the traits defining the iteration spaces to
the top of the file. This is consistent with how matvec tests are
defined and also makes it easy to quickly identify what cases are
covered.
3. For matmul tests, use more meaningful names for function arguments.
This helps keep things consistent across the file (i.e. function
definitions with check lines and comments).
4. For matvec test, move a few tests around so that the most basic case
(without masking) is first.
5. Update comments.
Added:
mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
Modified:
mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
Removed:
mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
################################################################################
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
similarity index 81%
rename from mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
rename to mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
index 7588b738ff9aa3..7a60ff8ea85897 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
@@ -1,20 +1,22 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-// NOTE - tests in this file are duplicated so that there's a version for
-// * _fixed width_ and for _scalable_ vectors.
-// In order for the "vector.contract -> vector.outerproduct" patterns to work,
-// only the non-reduction dimension can be scalable (*). For Matmul operations
-// that is set to be the N dimension (i.e. rows of the output matrix), which
-// matches how matrix multiplication are normally implemented for e.g.
-// Arm SVE. However, making the M dimension scalable (i.e. columns of the
-// output matrix) should work as well.
-//
-// (*) The conversion tested in this file unrolls along the reduction
-// dimension, which is not supported for scalable vectors.
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// matmul operations:
+/// C += A * B.
+/// (A, B and C are 2-d matrices). ATM three different variants / are tested:
+/// * plain (no mask, fixed-wdith vectors),
+/// * masked (fixed-width vectors,
+/// * scalable (mask + scalable vectors).
+/// In order for the "vector.contract -> vector.outerproduct" patterns to work,
+/// only the non-reduction dimension can be scalable (*). For matmul operations
+/// that is set to be the N dimension (i.e. rows of the output matrix), which
+/// matches how matrix multiplication are normally implemented for e.g.
+/// Arm SVE. However, making the M dimension scalable (i.e. columns of the
+/// output matrix) should work as well.
+///
+/// (*) The conversion tested in this file unrolls along the reduction
+/// dimension, which is not supported for scalable vectors.
-// ============================================================================
-// Matmul 0 (plain + masked + mixed types)
-// ============================================================================
#matmat_accesses_0 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -25,6 +27,49 @@
iterator_types = ["parallel", "parallel", "reduction"]
}
+#matmat_accesses_1 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_1 = {
+ indexing_maps = #matmat_accesses_1,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_2 = [
+ affine_map<(m, n, k) -> (k, m)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_2 = {
+ indexing_maps = #matmat_accesses_2,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_3 = [
+ affine_map<(m, n, k) -> (k, m)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_3 = {
+ indexing_maps = #matmat_accesses_3,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_4 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_4 = {
+ indexing_maps = #matmat_accesses_4,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// ============================================================================
+// Matmul 0 (plain + masked + mixed types)
+// ============================================================================
// CHECK-LABEL: func @matmul
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
@@ -53,10 +98,10 @@
// CHECK-SAME: : vector<2xf32>, vector<3xf32>
//
// CHECK: return %[[c3]] : vector<2x3xf32>
-func.func @matmul(%arg0: vector<2x4xf32>,
- %arg1: vector<4x3xf32>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+func.func @matmul(%A: vector<2x4xf32>,
+ %B: vector<4x3xf32>,
+ %C: vector<2x3xf32>) -> vector<2x3xf32> {
+ %0 = vector.contract #matmat_trait_0 %A, %B, %C
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
@@ -89,10 +134,10 @@ func.func @matmul(%arg0: vector<2x4xf32>,
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
//
// CHECK: return %[[c3]] : vector<2x[3]xf32>
-func.func @matmul_scalable(%arg0: vector<2x4xf32>,
- %arg1: vector<4x[3]xf32>,
- %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+func.func @matmul_scalable(%A: vector<2x4xf32>,
+ %B: vector<4x[3]xf32>,
+ %C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %0 = vector.contract #matmat_trait_0 %A, %B, %C
: vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
@@ -114,11 +159,11 @@ func.func @matmul_scalable(%arg0: vector<2x4xf32>,
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-func.func @masked_matmul(%arg0: vector<3x5xf32>,
- %arg1: vector<5x7xf32>,
- %arg2: vector<3x7xf32>,
+func.func @masked_matmul(%A: vector<3x5xf32>,
+ %B: vector<5x7xf32>,
+ %C: vector<3x7xf32>,
%m : vector<3x7x5xi1>) -> vector<3x7xf32> {
- %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
: vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
return %0 : vector<3x7xf32>
}
@@ -140,11 +185,11 @@ func.func @masked_matmul(%arg0: vector<3x5xf32>,
// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
- %arg1: vector<5x[7]xf32>,
- %arg2: vector<3x[7]xf32>,
+func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
+ %B: vector<5x[7]xf32>,
+ %C: vector<3x[7]xf32>,
%m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
- %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
: vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
return %0 : vector<3x[7]xf32>
}
@@ -160,11 +205,11 @@ func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_mixed(%arg0: vector<2x1xf16>,
- %arg1: vector<1x3xf16>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_mixed(%A: vector<2x1xf16>,
+ %B: vector<1x3xf16>,
+ %C: vector<2x3xf32>) -> vector<2x3xf32>
{
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_0 %A, %B, %C
: vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
@@ -180,28 +225,18 @@ func.func @matmul_mixed(%arg0: vector<2x1xf16>,
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
- %arg1: vector<1x[3]xf16>,
- %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
+ %B: vector<1x[3]xf16>,
+ %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_0 %A, %B, %C
: vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
// ============================================================================
-// Matmul 1 (plain)
+// Matmul 1 (plain + scalable)
// ============================================================================
-#matmat_accesses_1 = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (n, k)>,
- affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_1 = {
- indexing_maps = #matmat_accesses_1,
- iterator_types = ["parallel", "parallel", "reduction"]
-}
-
// CHECK-LABEL: func @matmul_1
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -212,11 +247,11 @@ func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_1(%arg0: vector<2x1xf32>,
- %arg1: vector<3x1xf32>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_1(%A: vector<2x1xf32>,
+ %B: vector<3x1xf32>,
+ %C: vector<2x3xf32>) -> vector<2x3xf32>
{
- %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_1 %A, %B, %C
: vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
@@ -231,28 +266,18 @@ func.func @matmul_1(%arg0: vector<2x1xf32>,
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
- %arg1: vector<[3]x1xf32>,
- %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_1_scalable(%A: vector<2x1xf32>,
+ %B: vector<[3]x1xf32>,
+ %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
- %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_1 %A, %B, %C
: vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
// ============================================================================
-// Matmul 2 (plain)
+// Matmul 2 (plain + scalable)
// ============================================================================
-#matmat_accesses_2 = [
- affine_map<(m, n, k) -> (k, m)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_2 = {
- indexing_maps = #matmat_accesses_2,
- iterator_types = ["parallel", "parallel", "reduction"]
-}
-
// CHECK-LABEL: func @matmul_2
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -261,11 +286,11 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_2(%arg0: vector<1x2xf32>,
- %arg1: vector<1x3xf32>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_2(%A: vector<1x2xf32>,
+ %B: vector<1x3xf32>,
+ %C: vector<2x3xf32>) -> vector<2x3xf32>
{
- %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_2 %A, %B, %C
: vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
@@ -278,28 +303,18 @@ func.func @matmul_2(%arg0: vector<1x2xf32>,
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
- %arg1: vector<1x[3]xf32>,
- %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_2_scalable(%A: vector<1x2xf32>,
+ %B: vector<1x[3]xf32>,
+ %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
- %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_2 %A, %B, %C
: vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
// ============================================================================
-// Matmul 3 (plain)
+// Matmul 3 (plain + scalable)
// ============================================================================
-#matmat_accesses_3 = [
- affine_map<(m, n, k) -> (k, m)>,
- affine_map<(m, n, k) -> (n, k)>,
- affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_3 = {
- indexing_maps = #matmat_accesses_3,
- iterator_types = ["parallel", "parallel", "reduction"]
-}
-
// CHECK-LABEL: func @matmul_3
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -309,11 +324,11 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_3(%arg0: vector<1x2xf32>,
- %arg1: vector<3x1xf32>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_3(%A: vector<1x2xf32>,
+ %B: vector<3x1xf32>,
+ %C: vector<2x3xf32>) -> vector<2x3xf32>
{
- %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_3 %A, %B, %C
: vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
@@ -327,28 +342,18 @@ func.func @matmul_3(%arg0: vector<1x2xf32>,
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
- %arg1: vector<[3]x1xf32>,
- %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_3_scalable(%A: vector<1x2xf32>,
+ %B: vector<[3]x1xf32>,
+ %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
- %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_3 %A, %B, %C
: vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
// ============================================================================
-// Matmul 4 (plain)
+// Matmul 4 (plain + scalable)
// ============================================================================
-#matmat_accesses_4 = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (n, m)>
-]
-#matmat_trait_4 = {
- indexing_maps = #matmat_accesses_4,
- iterator_types = ["parallel", "parallel", "reduction"]
-}
-
// CHECK-LABEL: func @matmul_4
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -358,11 +363,11 @@ func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// CHECK: return %[[c0]] : vector<3x2xf32>
-func.func @matmul_4(%arg0: vector<2x1xf32>,
- %arg1: vector<1x3xf32>,
- %arg2: vector<3x2xf32>) -> vector<3x2xf32>
+func.func @matmul_4(%A: vector<2x1xf32>,
+ %B: vector<1x3xf32>,
+ %C: vector<3x2xf32>) -> vector<3x2xf32>
{
- %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_4 %A, %B, %C
: vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
return %0 : vector<3x2xf32>
}
@@ -376,11 +381,11 @@ func.func @matmul_4(%arg0: vector<2x1xf32>,
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// CHECK: return %[[c0]] : vector<3x[2]xf32>
-func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>,
- %arg1: vector<1x3xf32>,
- %arg2: vector<3x[2]xf32>) -> vector<3x[2]xf32>
+func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
+ %B: vector<1x3xf32>,
+ %C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
{
- %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+ %0 = vector.contract #matmat_trait_4 %A, %B, %C
: vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
return %0 : vector<3x[2]xf32>
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index c09a4d569638a5..d86c6158bcdf2f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -235,6 +235,23 @@ func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
// ============================================================================
// Matvec 2 (plain + masked + scalable)
// ============================================================================
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: @masked_matvec_km_k_m
// CHECK-SAME: %[[A:.+]]: vector<2x4xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
@@ -273,26 +290,27 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
return %res : vector<[4]xf32>
}
-// CHECK-LABEL: func @matvec_km_k_m
+// ============================================================================
+// Matvec 3 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: func @matvec_k_mk_m
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
%x: vector<2xf32>,
%b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ %0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
-// ============================================================================
-// Matvec 3 (plain + masked + scalable)
-// ============================================================================
// CHECK-LABEL: @masked_matvec_k_mk_m
// CHECK-SAME: %[[A:.+]]: vector<4x2xf32>
// CHECK-SAME: %[[X:.+]]: vector<2xf32>
@@ -331,24 +349,6 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
return %res : vector<[4]xf32>
}
-// CHECK-LABEL: func @matvec_k_mk_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
// ============================================================================
// Matvec 4 (plain + masked + scalable)
// ============================================================================
More information about the Mlir-commits
mailing list