[Mlir-commits] [mlir] [mlir][Vector] Update v.contract -> v.outerproduct tests (NFC) (PR #70449)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Oct 27 05:24:23 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
Re-orders tests in vector-contract-to-outerproduct-transforms.mlir so
that the file starts as follows:
1. plain matmul
2. plain matmul with scalable vectors
3. masked matmul
4. masked matmul with scalable vectors
5. plain matmul with mixed types
6. plain matmul with mixed types and scalable vectors
All of the above share the same indexing maps. This made it possible to
identify one more duplicate. The cases above are followed by examples with
different indexing maps.
In addition, extra comments were added to document the tests and to split
them into categories. Some tests were also reformatted so that they follow
a uniform style.
---
Patch is 32.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70449.diff
1 Files Affected:
- (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+161-136)
``````````diff
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 965d55c53ba333e..05611122abd2f25 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -1,15 +1,22 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-#matmat_accesses = [
- affine_map<(i, j, k) -> (i, k)>,
- affine_map<(i, j, k) -> (k, j)>,
- affine_map<(i, j, k) -> (i, j)>
-]
-#matmat_trait = {
- indexing_maps = #matmat_accesses,
- iterator_types = ["parallel", "parallel", "reduction"]
-}
+// NOTE - tests in this file are duplicated so that there's a version for
+// both _fixed width_ and _scalable_ vectors.
+// In order for the "vector.contract -> vector.outerproduct" patterns to work,
+// only the non-reduction dimensions can be scalable (*). For Matmul operations
+// that is set to be the N dimension (i.e. columns of the output matrix), which
+// matches how matrix multiplication is normally implemented for e.g.
+// Arm's SVE. However, making the M dimension scalable (i.e. rows of the
+// output matrix) should work as well.
+//
+// (*) The conversion tested in this file unrolls along the reduction
+// dimension, which is not supported for scalable vectors.
+//
+// TODO: Matvec without a mask
+// ============================================================================
+// Matmul 0 (plain + masked + mixed types)
+// ============================================================================
#matmat_accesses_0 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -20,62 +27,6 @@
iterator_types = ["parallel", "parallel", "reduction"]
}
-
-// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
-// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
-// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-
-func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
- %arg1: vector<5x7xf32>,
- %arg2: vector<3x7xf32>,
- %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
- %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
- : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
- return %0 : vector<3x7xf32>
-}
-
-// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
-// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
-// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
-// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
-// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
-// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-
-// Note that only the J dimension is scalable in this example. In theory, all
-// dimensions could be be scalable, but there is no target yet for which this
-// would make sense.
-func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
- %arg1: vector<5x[7]xf32>,
- %arg2: vector<3x[7]xf32>,
- %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
- %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
- : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
- return %0 : vector<3x[7]xf32>
-}
-
// CHECK-LABEL: func @matmul
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
@@ -105,8 +56,8 @@ func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
//
// CHECK: return %[[c3]] : vector<2x3xf32>
func.func @matmul(%arg0: vector<2x4xf32>,
- %arg1: vector<4x3xf32>,
- %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+ %arg1: vector<4x3xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
: vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
return %0 : vector<2x3xf32>
@@ -141,45 +92,63 @@ func.func @matmul(%arg0: vector<2x4xf32>,
//
// CHECK: return %[[c3]] : vector<2x[3]xf32>
func.func @matmul_scalable(%arg0: vector<2x4xf32>,
- %arg1: vector<4x[3]xf32>,
- %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %arg1: vector<4x[3]xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
: vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
-// CHECK-LABEL: func @matmul_0
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
-// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
-// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
-{
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
- : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
- return %0 : vector<2x3xf32>
+// CHECK-LABEL: func.func @masked_matmul(
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+func.func @masked_matmul(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x7xf32>,
+ %arg2: vector<3x7xf32>,
+ %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
+ return %0 : vector<3x7xf32>
}
-// CHECK-LABEL: func @matmul_0_scalable
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
-// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
-// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
-// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
-{
- %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
- : vector<2x1xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
- return %0 : vector<2x[3]xf32>
+// CHECK-LABEL: func.func @masked_matmul_scalable(
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+
+func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x[7]xf32>,
+ %arg2: vector<3x[7]xf32>,
+ %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
+ return %0 : vector<3x[7]xf32>
}
// CHECK-LABEL: func @matmul_0_mixed
@@ -193,8 +162,9 @@ func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
+func.func @matmul_mixed(%arg0: vector<2x1xf16>,
+ %arg1: vector<1x3xf16>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32>
{
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
@@ -212,14 +182,18 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf16>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
+func.func @matmul_mixed_scalable(%arg0: vector<2x1xf16>,
+ %arg1: vector<1x[3]xf16>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
+// ============================================================================
+// Matmul 1 (plain)
+// ============================================================================
#matmat_accesses_1 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -240,8 +214,9 @@ func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
+func.func @matmul_1(%arg0: vector<2x1xf32>,
+ %arg1: vector<3x1xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32>
{
%0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
: vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
@@ -258,14 +233,18 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
+ %arg1: vector<[3]x1xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
: vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
+// ============================================================================
+// Matmul 2 (plain)
+// ============================================================================
#matmat_accesses_2 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -284,8 +263,9 @@ func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
+func.func @matmul_2(%arg0: vector<1x2xf32>,
+ %arg1: vector<1x3xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32>
{
%0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
@@ -300,14 +280,18 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
-func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
--> vector<2x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
+ %arg1: vector<1x[3]xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}
+// ============================================================================
+// Matmul 3 (plain)
+// ============================================================================
#matmat_accesses_3 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -327,8 +311,9 @@ func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
+func.func @matmul_3(%arg0: vector<1x2xf32>,
+ %arg1: vector<3x1xf32>,
+ %arg2: vector<2x3xf32>) -> vector<2x3xf32>
{
%0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<3x1x...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/70449
More information about the Mlir-commits
mailing list