[Mlir-commits] [mlir] [mlir][Vector] Update v.contract -> v.outerproduct tests (NFC) (PR #70449)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Oct 27 07:28:44 PDT 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/70449
>From 7416db35ffd56ac15b420d6ada7dbbf713ce2582 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 27 Oct 2023 07:52:09 +0000
Subject: [PATCH 1/2] [mlir][Vector] Update v.contract -> v.outerproduct tests
(NFC)
Re-orders tests in vector-contract-to-outerproduct-transforms.mlir so
that the file starts as follows:
1. plain matmul
2. plain matmul with scalable vectors
3. masked matmul
4. masked matmul with scalable vectors
5. plain matmul with mixed types
6. plain matmul with mixed types and scalable vectors
All of the above share the same indexing maps. This made it possible to
identify one more duplicate. Following the cases above are examples with
different maps.
In addition, added extra comments to document the tests and to split
them into categories. There is also some extra reformatting to unify the
tests.
---
...r-contract-to-outerproduct-transforms.mlir | 305 ++++++++++--------
1 file changed, 165 insertions(+), 140 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 965d55c53ba333e..6151d1365aaea37 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -1,15 +1,22 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-#matmat_accesses = [
- affine_map<(i, j, k) -> (i, k)>,
- affine_map<(i, j, k) -> (k, j)>,
- affine_map<(i, j, k) -> (i, j)>
-]
-#matmat_trait = {
- indexing_maps = #matmat_accesses,
- iterator_types = ["parallel", "parallel", "reduction"]
-}
+// NOTE - tests in this file are duplicated so that there's a version for
+// * _fixed width_ and for _scalable_ vectors.
+// In order for the "vector.contract -> vector.outerproduct" patterns to work,
+// only the non-reduction dimension can be scalable (*). For Matmul operations
+// that is set to be the N dimension (i.e. rows of the output matrix), which
+// matches how matrix multiplication are normally implemented for e.g.
+// Arm's SVE. However, making the M dimension scalable (i.e. columns of the
+// output matrix) should work as well.
+//
+// (*) The conversion tested in this file unrolls along the reduction
+// dimension, which is not supported for scalable vectors.
+//
+// TODO: Matvec without a mask
+// ============================================================================
+// Matmul 0 (plain + masked + mixed types)
+// ============================================================================
#matmat_accesses_0 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -20,62 +27,6 @@
iterator_types = ["parallel", "parallel", "reduction"]
}
-
-// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
-// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
-// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-
-func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
- %arg1: vector<5x7xf32>,
- %arg2: vector<3x7xf32>,
- %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
- %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
- : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
- return %0 : vector<3x7xf32>
-}
-
-// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
-// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
-// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
-// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
-// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
-// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-
-// Note that only the J dimension is scalable in this example. In theory, all
-// dimensions could be be scalable, but there is no target yet for which this
-// would make sense.
-func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
- %arg1: vector<5x[7]xf32>,
- %arg2: vector<3x[7]xf32>,
- %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
- %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
- : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
- return %0 : vector<3x[7]xf32>
-}
-
// CHECK-LABEL: func @matmul
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
@@ -105,9 +56,9 @@ func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
//
// CHECK: return %[[c3]] : vector<2x3xf32>
func.func @matmul(%arg0: vector<2x4xf32>,
- %arg1: vector<4x3xf32>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ %arg1: vector<4x3xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
@@ -141,48 +92,66 @@ func.func @matmul(%arg0: vector<2x4xf32>,
//
// CHECK: return %[[c3]] : vector<2x[3]xf32>
func.func @matmul_scalable(%arg0: vector<2x4xf32>,
- %arg1: vector<4x[3]xf32>,
- %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
- %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ %arg1: vector<4x[3]xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
-// CHECK-LABEL: func @matmul_0
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
-// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
-// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
-{
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
- : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
- return %0 : vector<2x3xf32>
+// CHECK-LABEL: func.func @masked_matmul(
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+func.func @masked_matmul(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x7xf32>,
+ %arg2: vector<3x7xf32>,
+ %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
+ return %0 : vector<3x7xf32>
}
-// CHECK-LABEL: func @matmul_0_scalable
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
-// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
-// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
-// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
-{
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
- : vector<2x1xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
- return %0 : vector<2x[3]xf32>
+// CHECK-LABEL: func.func @masked_matmul_scalable(
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+
+func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x[7]xf32>,
+ %arg2: vector<3x[7]xf32>,
+ %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
+ return %0 : vector<3x[7]xf32>
}
-// CHECK-LABEL: func @matmul_0_mixed
+// CHECK-LABEL: func @matmul_mixed
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
@@ -193,15 +162,16 @@ func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
+func.func @matmul_mixed(%arg0: vector<2x1xf16>,
+ %arg1: vector<1x3xf16>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32>
{
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
return %0 : vector<2x3xf32>
}
-// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-LABEL: func @matmul_mixed_scalable
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
@@ -212,14 +182,18 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf16>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
+func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
+ %arg1: vector<1x[3]xf16>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
+// ============================================================================
+// Matmul 1 (plain)
+// ============================================================================
#matmat_accesses_1 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -240,8 +214,9 @@ func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
+func.func @matmul_1(%arg0: vector<2x1xf32>,
+ %arg1: vector<3x1xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32>
{
%0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
: vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
@@ -258,14 +233,18 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
+ %arg1: vector<[3]x1xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
: vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
+// ============================================================================
+// Matmul 2 (plain)
+// ============================================================================
#matmat_accesses_2 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -284,8 +263,9 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
+func.func @matmul_2(%arg0: vector<1x2xf32>,
+ %arg1: vector<1x3xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32>
{
%0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
@@ -300,14 +280,18 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
+ %arg1: vector<1x[3]xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
+// ============================================================================
+// Matmul 3 (plain)
+// ============================================================================
#matmat_accesses_3 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -327,8 +311,9 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
+func.func @matmul_3(%arg0: vector<1x2xf32>,
+ %arg1: vector<3x1xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32>
{
%0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
@@ -344,14 +329,18 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_3_scalable(%arg0: vector<1x2xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
+func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
+ %arg1: vector<[3]x1xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
+// ============================================================================
+// Matmul 4 (plain)
+// ============================================================================
#matmat_accesses_4 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -371,8 +360,9 @@ func.func @matmul_3_scalable(%arg0: vector<1x2xf32>, %arg1: vector<[3]x1xf32>, %
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// CHECK: return %[[c0]] : vector<3x2xf32>
-func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
--> vector<3x2xf32>
+func.func @matmul_4(%arg0: vector<2x1xf32>,
+ %arg1: vector<1x3xf32>,
+ %arg2: vector<3x2xf32>) -> vector<3x2xf32>
{
%0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
: vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
@@ -388,24 +378,18 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// CHECK: return %[[c0]] : vector<3x[2]xf32>
-func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x[2]xf32>)
--> vector<3x[2]xf32>
+func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>,
+ %arg1: vector<1x3xf32>,
+ %arg2: vector<3x[2]xf32>) -> vector<3x[2]xf32>
{
%0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
: vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
return %0 : vector<3x[2]xf32>
}
-#matmat_accesses_5 = [
- affine_map<(m, n, k) -> (m, k)>,
- affine_map<(m, n, k) -> (k, n)>,
- affine_map<(m, n, k) -> (n, m)>
-]
-#matmat_trait_5 = {
- indexing_maps = #matmat_accesses_5,
- iterator_types = ["parallel", "parallel", "reduction"]
-}
-
+// ============================================================================
+// Matvec 1 (masked)
+// ============================================================================
#matvec_accesses_1 = [
affine_map<(m, k) -> (m, k)>,
affine_map<(m, k) -> (k)>,
@@ -440,7 +424,6 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
return %0 : vector<2xf32>
}
-
// CHECK-LABEL: func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
// CHECK-SAME: %{{.*}}: vector<3xf32>,
@@ -464,6 +447,9 @@ func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
return %0 : vector<[2]xf32>
}
+// ============================================================================
+// Matvec 2 (masked)
+// ============================================================================
#matvec_accesses_2 = [
affine_map<(m, k) -> (k, m)>,
affine_map<(m, k) -> (k)>,
@@ -479,7 +465,10 @@ func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
// CHECK: vector.transpose %[[MASK]]
// CHECK-NOT: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
@@ -495,7 +484,10 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
// CHECK-NOT: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
@@ -506,6 +498,9 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
return %res : vector<[4]xf32>
}
+// ============================================================================
+// Matvec 3 (masked)
+// ============================================================================
#matvec_accesses_3 = [
affine_map<(m, k) -> (k)>,
affine_map<(m, k) -> (m, k)>,
@@ -521,7 +516,10 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
// CHECK: vector.transpose %[[MASK]]
// CHECK: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
@@ -537,7 +535,10 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
// CHECK: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
@@ -548,6 +549,9 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
return %res : vector<[4]xf32>
}
+// ============================================================================
+// Matvec 4 (masked)
+// ============================================================================
#matvec_accesses_4 = [
affine_map<(m, k) -> (k)>,
affine_map<(m, k) -> (k, m)>,
@@ -563,7 +567,10 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
// CHECK: vector.transpose %[[MASK]]
// CHECK-NOT: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
@@ -579,7 +586,10 @@ func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
// CHECK: vector.transpose %[[MASK]]
// CHECK-NOT: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
@@ -590,6 +600,9 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
return %res : vector<4xf32>
}
+// ============================================================================
+// Matvec 5 (masked)
+// ============================================================================
#matvec_accesses_5 = [
affine_map<(k, m) -> (m, k)>,
affine_map<(k, m) -> (k)>,
@@ -632,6 +645,9 @@ func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
return %res : vector<[4]xf32>
}
+// ============================================================================
+// Matvec 6 (masked)
+// ============================================================================
#matvec_accesses_6 = [
affine_map<(k, m) -> (k, m)>,
affine_map<(k, m) -> (k)>,
@@ -674,6 +690,9 @@ func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
return %res : vector<[4]xf32>
}
+// ============================================================================
+// Matvec 7 (masked)
+// ============================================================================
#matvec_accesses_7 = [
affine_map<(k, m) -> (k)>,
affine_map<(k, m) -> (m, k)>,
@@ -716,6 +735,9 @@ func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
return %res : vector<[4]xf32>
}
+// ============================================================================
+// Matvec 8 (masked)
+// ============================================================================
#matvec_accesses_8 = [
affine_map<(k, m) -> (k)>,
affine_map<(k, m) -> (k, m)>,
@@ -759,6 +781,9 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
}
+// ============================================================================
+// TD sequence
+// ============================================================================
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
>From 629d1242ea923e4682591914ee43e6b526b77e92 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 27 Oct 2023 14:28:16 +0000
Subject: [PATCH 2/2] fixup! [mlir][Vector] Update v.contract -> v.outerproduct
tests (NFC)
Fix typo
---
.../Vector/vector-contract-to-outerproduct-transforms.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 6151d1365aaea37..2c228a04873e5c0 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -6,7 +6,7 @@
// only the non-reduction dimension can be scalable (*). For Matmul operations
// that is set to be the N dimension (i.e. rows of the output matrix), which
// matches how matrix multiplication are normally implemented for e.g.
-// Arm's SVE. However, making the M dimension scalable (i.e. columns of the
+// Arm SVE. However, making the M dimension scalable (i.e. columns of the
// output matrix) should work as well.
//
// (*) The conversion tested in this file unrolls along the reduction
More information about the Mlir-commits
mailing list