[Mlir-commits] [mlir] [mlir][nfc] Update tests for Contract -> Op transforms (PR #76054)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Dec 20 05:33:56 PST 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/76054

Updates two tests for vector.contract -> vector.outerproduct
transformations:

1. Rename "vector-contract-to-outerproduct-transforms.mlir" as
   "vector-contract-to-outerproduct-matmul-transforms.mlir". The new
   name more accurately captures what's being tested. It is also
   consistent with
   "vector-contract-to-outerproduct-matvec-transforms.mlir", which
   covers vector matvec operations and makes finding relevant tests
   easier.

2. For matmul tests, move the traits defining the iteration spaces to
   the top of the file. This is consistent with how matvec tests are
   defined and also makes it easy to quickly identify what cases are
   covered.

3. For matmul tests, use more meaningful names for function arguments.
   This helps keep things consistent across the file (i.e. function
   definitions with check lines and comments).

4. For matvec tests, move a few tests around so that the most basic case
   (without masking) is first.

5. Update comments.
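
For readers unfamiliar with the transformation under test, the following is a minimal sketch of the idea, assuming the plain layout used by `#matmat_trait_0` (A indexed by (m, k), B by (k, n), C by (m, n)). The contraction is unrolled along the reduction dimension k, and each step accumulates one outer product. The function names are hypothetical and plain Python lists stand in for vectors; this is a model of the arithmetic, not MLIR's implementation.

```python
# Hypothetical pure-Python model of "contract -> outer products".
# Names and data layout are assumptions made for illustration only.

def outer_product_acc(a, b, acc):
    """Return acc + (a outer b), similar in spirit to an outer product
    with an accumulator and kind = add."""
    return [[acc[i][j] + a[i] * b[j] for j in range(len(b))]
            for i in range(len(a))]


def contract_as_outer_products(A, B, C):
    """Compute C + A * B by unrolling along the reduction dimension k.

    A is M x K, B is K x N, C is M x N, all given as lists of rows.
    """
    acc = C
    for k in range(len(B)):
        a_col = [row[k] for row in A]  # column k of A, i.e. row k of A transposed
        b_row = B[k]                   # row k of B
        acc = outer_product_acc(a_col, b_row, acc)
    return acc


def matmul_reference(A, B, C):
    """Straightforward triple-loop reference used for comparison."""
    M, K, N = len(A), len(B), len(B[0])
    return [[C[m][n] + sum(A[m][k] * B[k][n] for k in range(K))
             for n in range(N)]
            for m in range(M)]


# Same shapes as the @matmul test: 2x4 times 4x3 accumulated into 2x3.
A = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]
B = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0],
     [3.0, 0.0, 1.0],
     [0.0, 2.0, 0.0]]
C = [[0.0, 0.0, 0.0],
     [0.0, 0.0, 0.0]]

result = contract_as_outer_products(A, B, C)
assert result == matmul_reference(A, B, C)
```

The loop runs once per element of the reduction dimension, so its trip count has to be known up front. That is consistent with the note in the test file that only the non-reduction dimensions can be scalable.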


>From 4dbfc390052c08a5a219873cf346a85bde2c287e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 20 Dec 2023 12:38:24 +0000
Subject: [PATCH 1/5] [mlir][nfc] Update tests for Contract -> Op transforms

Updates two tests for vector.contract -> vector.outerproduct
transformations:

1. Rename "vector-contract-to-outerproduct-transforms.mlir" as
   "vector-contract-to-outerproduct-matmul-transforms.mlir". The new
   name more accurately captures what's being tested. It is also
   consistent with
   "vector-contract-to-outerproduct-matvec-transforms.mlir", which
   covers vector matvec operations and makes finding relevant tests
   easier.

2. For matmul tests, move the traits defining the iteration spaces to
   the top of the file. This is consistent with how matvec tests are
   defined and also makes it easy to quickly identify what cases are
   covered.

3. For matmul tests, use more meaningful names for function arguments.
   This helps keep things consistent across the file (i.e. function
   definitions with check lines and comments).

4. For matvec tests, move a few tests around so that the most basic case
   (without masking) is first.

5. Update comments.
---
 ...lir => vector-contract-to-outerproduct-matmul-transforms.mlir} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename mlir/test/Dialect/Vector/{vector-contract-to-outerproduct-transforms.mlir => vector-contract-to-outerproduct-matmul-transforms.mlir} (100%)

diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
similarity index 100%
rename from mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
rename to mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir

>From 35e19879a3c7a42a9e5bf69498ccd8cb48ec5a44 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 20 Dec 2023 12:39:57 +0000
Subject: [PATCH 2/5] fixup! [mlir][vector] Extend `CreateMaskFolder` (#75842)

Move matmul traits
---
 ...act-to-outerproduct-matmul-transforms.mlir | 81 ++++++++++---------
 1 file changed, 41 insertions(+), 40 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
index 7588b738ff9aa3..ce4b95da1f12f9 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
@@ -25,6 +25,47 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
+#matmat_accesses_1 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_1 = {
+  indexing_maps = #matmat_accesses_1,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_2 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_2 = {
+  indexing_maps = #matmat_accesses_2,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_3 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_3 = {
+  indexing_maps = #matmat_accesses_3,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+#matmat_accesses_4 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_4 = {
+  indexing_maps = #matmat_accesses_4,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+
 // CHECK-LABEL: func @matmul
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
@@ -192,16 +233,6 @@ func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
 // ============================================================================
 //  Matmul 1 (plain)
 // ============================================================================
-#matmat_accesses_1 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (n, k)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_1 = {
-  indexing_maps = #matmat_accesses_1,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_1
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -243,16 +274,6 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
 // ============================================================================
 //  Matmul 2 (plain)
 // ============================================================================
-#matmat_accesses_2 = [
-  affine_map<(m, n, k) -> (k, m)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_2 = {
-  indexing_maps = #matmat_accesses_2,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_2
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -290,16 +311,6 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
 // ============================================================================
 //  Matmul 3 (plain)
 // ============================================================================
-#matmat_accesses_3 = [
-  affine_map<(m, n, k) -> (k, m)>,
-  affine_map<(m, n, k) -> (n, k)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_3 = {
-  indexing_maps = #matmat_accesses_3,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_3
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
@@ -339,16 +350,6 @@ func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
 // ============================================================================
 //  Matmul 4 (plain)
 // ============================================================================
-#matmat_accesses_4 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (n, m)>
-]
-#matmat_trait_4 = {
-  indexing_maps = #matmat_accesses_4,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
 // CHECK-LABEL: func @matmul_4
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,

>From c9e73e5c691ea52a1a47c57abd8ac849948bb4ad Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 20 Dec 2023 12:40:52 +0000
Subject: [PATCH 3/5] fixup! [mlir][nfc] Update tests for Contract -> Op
 transforms

Update func var names
---
 ...act-to-outerproduct-matmul-transforms.mlir | 112 +++++++++---------
 1 file changed, 56 insertions(+), 56 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
index ce4b95da1f12f9..2b0870d2aa06d4 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
@@ -94,10 +94,10 @@
 // CHECK-SAME:  : vector<2xf32>, vector<3xf32>
 //
 //      CHECK: return %[[c3]] : vector<2x3xf32>
-func.func @matmul(%arg0: vector<2x4xf32>,
-                  %arg1: vector<4x3xf32>,
-                  %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+func.func @matmul(%A: vector<2x4xf32>,
+                  %B: vector<4x3xf32>,
+                  %C: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -130,10 +130,10 @@ func.func @matmul(%arg0: vector<2x4xf32>,
 // CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
 //
 //      CHECK: return %[[c3]] : vector<2x[3]xf32>
-func.func @matmul_scalable(%arg0: vector<2x4xf32>,
-                           %arg1: vector<4x[3]xf32>,
-                           %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+func.func @matmul_scalable(%A: vector<2x4xf32>,
+                           %B: vector<4x[3]xf32>,
+                           %C: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
@@ -155,11 +155,11 @@ func.func @matmul_scalable(%arg0: vector<2x4xf32>,
 // CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
 
-func.func @masked_matmul(%arg0: vector<3x5xf32>,
-                         %arg1: vector<5x7xf32>,
-                         %arg2: vector<3x7xf32>,
+func.func @masked_matmul(%A: vector<3x5xf32>,
+                         %B: vector<5x7xf32>,
+                         %C: vector<3x7xf32>,
                          %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
   : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
   return %0 : vector<3x7xf32>
 }
@@ -181,11 +181,11 @@ func.func @masked_matmul(%arg0: vector<3x5xf32>,
 // CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
 // CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
 
-func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
-                                  %arg1: vector<5x[7]xf32>,
-                                  %arg2: vector<3x[7]xf32>,
+func.func @masked_matmul_scalable(%A: vector<3x5xf32>,
+                                  %B: vector<5x[7]xf32>,
+                                  %C: vector<3x[7]xf32>,
                                   %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matmat_trait_0 %A, %B, %C
   : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
   return %0 : vector<3x[7]xf32>
 }
@@ -201,11 +201,11 @@ func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
 //      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_mixed(%arg0: vector<2x1xf16>,
-                          %arg1: vector<1x3xf16>,
-                          %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_mixed(%A: vector<2x1xf16>,
+                        %B: vector<1x3xf16>,
+                        %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -221,11 +221,11 @@ func.func @matmul_mixed(%arg0: vector<2x1xf16>,
 //      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
-                                   %arg1: vector<1x[3]xf16>,
-                                   %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
+                                 %B: vector<1x[3]xf16>,
+                                 %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_0 %A, %B, %C
     : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
@@ -243,11 +243,11 @@ func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_1(%arg0: vector<2x1xf32>,
-                    %arg1: vector<3x1xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_1(%A: vector<2x1xf32>,
+                    %B: vector<3x1xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_1 %A, %B, %C
     : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -262,11 +262,11 @@ func.func @matmul_1(%arg0: vector<2x1xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
-                             %arg1: vector<[3]x1xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_1_scalable(%A: vector<2x1xf32>,
+                             %B: vector<[3]x1xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_1 %A, %B, %C
     : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
@@ -282,11 +282,11 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_2(%arg0: vector<1x2xf32>,
-                    %arg1: vector<1x3xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_2(%A: vector<1x2xf32>,
+                    %B: vector<1x3xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_2 %A, %B, %C
     : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -299,11 +299,11 @@ func.func @matmul_2(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
-                             %arg1: vector<1x[3]xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_2_scalable(%A: vector<1x2xf32>,
+                             %B: vector<1x[3]xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_2 %A, %B, %C
     : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
@@ -320,11 +320,11 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_3(%arg0: vector<1x2xf32>,
-                    %arg1: vector<3x1xf32>,
-                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+func.func @matmul_3(%A: vector<1x2xf32>,
+                    %B: vector<3x1xf32>,
+                    %C: vector<2x3xf32>) -> vector<2x3xf32>
 {
-  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_3 %A, %B, %C
     : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
   return %0 : vector<2x3xf32>
 }
@@ -338,11 +338,11 @@ func.func @matmul_3(%arg0: vector<1x2xf32>,
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
-                             %arg1: vector<[3]x1xf32>,
-                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+func.func @matmul_3_scalable(%A: vector<1x2xf32>,
+                             %B: vector<[3]x1xf32>,
+                             %C: vector<2x[3]xf32>) -> vector<2x[3]xf32>
 {
-  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_3 %A, %B, %C
     : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
   return %0 : vector<2x[3]xf32>
 }
@@ -359,11 +359,11 @@ func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<3x2xf32>
-func.func @matmul_4(%arg0: vector<2x1xf32>,
-                    %arg1: vector<1x3xf32>,
-                    %arg2: vector<3x2xf32>) -> vector<3x2xf32>
+func.func @matmul_4(%A: vector<2x1xf32>,
+                    %B: vector<1x3xf32>,
+                    %C: vector<3x2xf32>) -> vector<3x2xf32>
 {
-  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_4 %A, %B, %C
     : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
   return %0 : vector<3x2xf32>
 }
@@ -377,11 +377,11 @@ func.func @matmul_4(%arg0: vector<2x1xf32>,
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
 //      CHECK: return %[[c0]] : vector<3x[2]xf32>
-func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>,
-                             %arg1: vector<1x3xf32>,
-                             %arg2: vector<3x[2]xf32>) -> vector<3x[2]xf32>
+func.func @matmul_4_scalable(%A: vector<[2]x1xf32>,
+                             %B: vector<1x3xf32>,
+                             %C: vector<3x[2]xf32>) -> vector<3x[2]xf32>
 {
-  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+  %0 = vector.contract #matmat_trait_4 %A, %B, %C
     : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
   return %0 : vector<3x[2]xf32>
 }

>From 07de2f1e9e89f2c56caeee1aecf5f7a1d6a783fb Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 20 Dec 2023 12:43:13 +0000
Subject: [PATCH 4/5] fixup! [mlir][vector] Extend `CreateMaskFolder` (#75842)

Re-order matvec tests so that the one without masking is always first
---
 ...act-to-outerproduct-matvec-transforms.mlir | 60 +++++++++----------
 1 file changed, 30 insertions(+), 30 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index c09a4d569638a5..d86c6158bcdf2f 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -235,6 +235,23 @@ func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
 // ============================================================================
 //  Matvec 2 (plain + masked + scalable)
 // ============================================================================
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_km_k_m
 // CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[X:.+]]: vector<2xf32>
@@ -273,26 +290,27 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
-// CHECK-LABEL: func @matvec_km_k_m
+// ============================================================================
+//  Matvec 3 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: func @matvec_k_mk_m
 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
 // CHECK-SAME: %[[X:.*1]]: vector<2xf32>
 // CHECK-SAME: %[[B:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_k_mk_m(%A: vector<2x2xf32>, 
                          %x: vector<2xf32>,
                          %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  %0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
 
-// ============================================================================
-//  Matvec 3 (plain + masked + scalable)
-// ============================================================================
 // CHECK-LABEL: @masked_matvec_k_mk_m
 // CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[X:.+]]: vector<2xf32>
@@ -331,24 +349,6 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
   return %res : vector<[4]xf32>
 }
 
-// CHECK-LABEL: func @matvec_k_mk_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_mk_m(%A: vector<2x2xf32>, 
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
 // ============================================================================
 //  Matvec 4 (plain + masked + scalable)
 // ============================================================================

>From ee77d06dcd637a4ee9caf191a181d7d8ffe3e0da Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 20 Dec 2023 12:44:23 +0000
Subject: [PATCH 5/5] fixup! [mlir][nfc] Update tests for Contract -> Op
 transforms

Update comments
---
 ...act-to-outerproduct-matmul-transforms.mlir | 42 ++++++++++---------
 1 file changed, 23 insertions(+), 19 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
index 2b0870d2aa06d4..96023124bd3ff8 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matmul-transforms.mlir
@@ -1,20 +1,22 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
-// NOTE - tests in this file are duplicated so that there's a version for
-//    * _fixed width_ and for _scalable_ vectors.
-// In order for the "vector.contract -> vector.outerproduct" patterns to work,
-// only the non-reduction dimension can be scalable (*). For Matmul operations
-// that is set to be the N dimension (i.e. rows of the output matrix), which
-// matches how matrix multiplication are normally implemented for e.g. 
-// Arm SVE. However, making the M dimension scalable (i.e. columns of the
-// output matrix) should work as well.
-//
-// (*) The conversion tested in this file unrolls along the reduction
-// dimension, which is not supported for scalable vectors.
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matmul operations:
+///   C += A * B.
+/// (A, B and C are 2-d matrices). ATM three different variants are tested:
+///   * plain (no mask, fixed-width vectors),
+///   * masked (fixed-width vectors),
+///   * scalable (mask + scalable vectors).
+/// In order for the "vector.contract -> vector.outerproduct" patterns to work,
+/// only the non-reduction dimension can be scalable (*). For Matmul operations
+/// that is set to be the N dimension (i.e. rows of the output matrix), which
+/// matches how matrix multiplications are normally implemented for e.g.
+/// Arm SVE. However, making the M dimension scalable (i.e. columns of the
+/// output matrix) should work as well.
+///
+/// (*) The conversion tested in this file unrolls along the reduction
+/// dimension, which is not supported for scalable vectors.
 
-// ============================================================================
-//  Matmul 0 (plain + masked + mixed types)
-// ============================================================================
 #matmat_accesses_0 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -65,7 +67,9 @@
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
-
+// ============================================================================
+//  Matmul 0 (plain + masked + mixed types)
+// ============================================================================
 // CHECK-LABEL: func @matmul
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
@@ -231,7 +235,7 @@ func.func @matmul_mixed_scalable(%A: vector<2x1xf16>,
 }
 
 // ============================================================================
-//  Matmul 1 (plain)
+//  Matmul 1 (plain + scalable)
 // ============================================================================
 // CHECK-LABEL: func @matmul_1
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
@@ -272,7 +276,7 @@ func.func @matmul_1_scalable(%A: vector<2x1xf32>,
 }
 
 // ============================================================================
-//  Matmul 2 (plain)
+//  Matmul 2 (plain + scalable)
 // ============================================================================
 // CHECK-LABEL: func @matmul_2
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
@@ -309,7 +313,7 @@ func.func @matmul_2_scalable(%A: vector<1x2xf32>,
 }
 
 // ============================================================================
-//  Matmul 3 (plain)
+//  Matmul 3 (plain + scalable)
 // ============================================================================
 // CHECK-LABEL: func @matmul_3
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
@@ -348,7 +352,7 @@ func.func @matmul_3_scalable(%A: vector<1x2xf32>,
 }
 
 // ============================================================================
-//  Matmul 4 (plain)
+//  Matmul 4 (plain + scalable)
 // ============================================================================
 // CHECK-LABEL: func @matmul_4
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
