[Mlir-commits] [mlir] [mlir][vector] Add scalable vectors to tests for vector.contract (PR #70039)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Oct 27 01:26:01 PDT 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/70039
>From 74d86703223e3c5fba3f424189643889a11ae9a4 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 20 Oct 2023 16:13:49 +0000
Subject: [PATCH 1/2] [mlir][SVE] Add an e2e test for vector.contract
Adds an end-to-end test for `vector.contract` that targets SVE (i.e.
scalable vectors). Note that this requires lifting the restriction on
`vector.outerproduct` (to which `vector.contract` is lowered) that
would deem the following as invalid by the Op verifier (*):
```
vector.outerproduct %27, %28, %26 {kind = #vector.kind<add>} : vector<3xf32>, vector<[2]xf32>
```
This is indeed valid as the end-to-end test demonstrates (at least when
compiling for SVE).
Depends on #68794
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +-
...r-contract-to-outerproduct-transforms.mlir | 59 ++++++--
.../Vector/vector-scalable-outerproduct.mlir | 5 +-
.../Vector/CPU/ArmSVE/test-contraction.mlir | 137 ++++++++++++++++++
4 files changed, 191 insertions(+), 19 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8de132daf3c6a5d..d77476c10908395 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3067,9 +3067,12 @@ LogicalResult OuterProductOp::verify() {
return emitOpError("expected #1 operand dim to match result dim #1");
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
return emitOpError("expected #2 operand dim to match result dim #2");
- if (vRHS.isScalable() != vLHS.isScalable())
- return emitOpError("expected either all or none of vector operands #1 "
- "and #2 to be scalable");
+ if (vLHS.isScalable() && !vRHS.isScalable()) {
+ // This restriction reflects what's currently supported in terms of
+ // scalable vectors. However, we could relax this if there's a use case.
+ return emitOpError(
+ "expected either both or only #2 operand dim to be scalable");
+ }
} else {
// An AXPY operation.
if (vRES.getRank() != 1)
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 3530393c013a1ae..44fb23088cea933 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -79,21 +79,21 @@ func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf3
}
// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
%arg1: vector<5x7xf32>,
@@ -104,6 +104,35 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
return %0 : vector<3x7xf32>
}
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
+// CHECK-SAME: %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME: %{{.*}}: vector<5x[7]xf32>,
+// CHECK-SAME: %{{.*}}: vector<3x[7]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
+// CHECK: %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK: %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK: %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+
+// Note that only the J dimension is scalable in this example. In theory, all
+// dimensions could be scalable, but there is no target yet for which this
+// would make sense.
+func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x[7]xf32>,
+ %arg2: vector<3x[7]xf32>,
+ %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
+ return %0 : vector<3x[7]xf32>
+}
+
// CHECK-LABEL: func @matmul
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
diff --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
index 7d9923e036660c9..3b4e24da92aaacc 100644
--- a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -21,9 +21,12 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
%1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
- // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
+ // expected-error @+1 {{expected either both or only #2 operand dim to be scalable}}
%op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
+
+ return
}
+
// -----
func.func @invalid_outerproduct1(%src : memref<?xf32>) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
new file mode 100644
index 000000000000000..12187dfd7787155
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir
@@ -0,0 +1,137 @@
+// DEFINE: %{compile} = mlir-opt %s -test-transform-dialect-interpreter -test-transform-dialect-erase-schedule\
+// DEFINE: -cse -canonicalize -convert-vector-to-scf -arm-sve-legalize-vector-storage\
+// DEFINE: -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm
+// DEFINE: %{entry} =
+// DEFINE: %{run} = %mcr_aarch64_cmd -e=%{entry} -entry-point-result=void --march=aarch64 --mattr="+sve" -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext
+
+// REDEFINE: %{entry} = entry_i32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=I32
+
+// REDEFINE: %{entry} = entry_f32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=F32
+
+#matmat_accesses = [
+ affine_map<(i, j, k) -> (i, k)>,
+ affine_map<(i, j, k) -> (k, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+#matmat_trait = {
+ indexing_maps = #matmat_accesses,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+func.func @entry_i32() {
+ %vscale = vector.vscale
+
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c5 = arith.constant 5 : index
+ %n_rows = arith.muli %vscale, %c2 : index
+
+ %cst = arith.constant 0: i32
+ %i32_123 = arith.constant 123 : i32
+ %i32_314 = arith.constant 314 : i32
+
+ // Allocate and initialize matrix A
+ %A_alloc = memref.alloca() : memref<3x5xi32>
+ linalg.fill ins(%i32_123 : i32) outs(%A_alloc :memref<3x5xi32>)
+ %mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1>
+ %vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xi32>, vector<3x5xi32>
+
+ // Allocate and initialize matrix B
+ %B_alloc = memref.alloca(%n_rows) : memref<5x?xi32>
+ linalg.fill ins(%i32_123 : i32) outs(%B_alloc :memref<5x?xi32>)
+ %mask_b = vector.create_mask %c5, %n_rows : vector<5x[2]xi1>
+ %vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xi32>, vector<5x[2]xi32>
+
+ // Allocate and initialize matrix C
+ %C_alloc = memref.alloca(%n_rows) : memref<3x?xi32>
+ linalg.fill ins(%i32_314 : i32) outs(%C_alloc :memref<3x?xi32>)
+ %mask_c = vector.create_mask %c3, %n_rows : vector<3x[2]xi1>
+ %vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xi32>, vector<3x[2]xi32>
+
+ // Matmul
+ %m = vector.create_mask %c3, %n_rows, %c5 : vector<3x[2]x5xi1>
+ %0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+ : vector<3x5xi32>, vector<5x[2]xi32> into vector<3x[2]xi32> } : vector<3x[2]x5xi1> -> vector<3x[2]xi32>
+
+ // Print the output
+ %slice1 = vector.extract %0[0] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32: ( 75959, 75959, 75959, 75959
+ vector.print %slice1 : vector<[2]xi32>
+ %slice2 = vector.extract %0[1] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32-NEXT: ( 75959, 75959, 75959, 75959
+ vector.print %slice2 : vector<[2]xi32>
+ %slice3 = vector.extract %0[2] : vector<[2]xi32> from vector<3x[2]xi32>
+ // I32-NEXT: ( 75959, 75959, 75959, 75959
+ vector.print %slice3 : vector<[2]xi32>
+
+ // CHECK: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+func.func @entry_f32() {
+ %vscale = vector.vscale
+
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c5 = arith.constant 5 : index
+ %n_rows = arith.muli %vscale, %c2 : index
+
+ %cst = arith.constant 0.0: f32
+ %f32_123 = arith.constant 1.23 : f32
+ %f32_314 = arith.constant 3.14 : f32
+
+ // Allocate and initialize matrix A
+ %A_alloc = memref.alloca() : memref<3x5xf32>
+ linalg.fill ins(%f32_123 : f32) outs(%A_alloc :memref<3x5xf32>)
+ %mask_a = vector.create_mask %c3, %c5 : vector<3x5xi1>
+ %vector_a = vector.transfer_read %A_alloc[%c0, %c0], %cst, %mask_a {in_bounds = [true, true]} : memref<3x5xf32>, vector<3x5xf32>
+
+ // Allocate and initialize matrix B
+ %B_alloc = memref.alloca(%n_rows) : memref<5x?xf32>
+ linalg.fill ins(%f32_123 : f32) outs(%B_alloc :memref<5x?xf32>)
+ %mask_b = vector.create_mask %c5, %n_rows : vector<5x[2]xi1>
+ %vector_b = vector.transfer_read %B_alloc[%c0, %c0], %cst, %mask_b {in_bounds = [true, true]} : memref<5x?xf32>, vector<5x[2]xf32>
+
+ // Allocate and initialize matrix C
+ %C_alloc = memref.alloca(%n_rows) : memref<3x?xf32>
+ linalg.fill ins(%f32_314 : f32) outs(%C_alloc :memref<3x?xf32>)
+ %mask_c = vector.create_mask %c3, %n_rows : vector<3x[2]xi1>
+ %vector_c = vector.transfer_read %C_alloc[%c0, %c0], %cst, %mask_c {in_bounds = [true, true]} : memref<3x?xf32>, vector<3x[2]xf32>
+
+ // Matmul
+ %m = vector.create_mask %c3, %n_rows, %c5 : vector<3x[2]x5xi1>
+ %0 = vector.mask %m { vector.contract #matmat_trait %vector_a, %vector_b, %vector_c
+ : vector<3x5xf32>, vector<5x[2]xf32> into vector<3x[2]xf32> } : vector<3x[2]x5xi1> -> vector<3x[2]xf32>
+
+ // Print the output
+ %slice1 = vector.extract %0[0] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32: ( 10.7045, 10.7045, 10.7045, 10.7045
+ vector.print %slice1 : vector<[2]xf32>
+ %slice2 = vector.extract %0[1] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32-NEXT: ( 10.7045, 10.7045, 10.7045, 10.7045
+ vector.print %slice2 : vector<[2]xf32>
+ %slice3 = vector.extract %0[2] : vector<[2]xf32> from vector<3x[2]xf32>
+ // F32-NEXT: ( 10.7045, 10.7045, 10.7045, 10.7045
+ vector.print %slice3 : vector<[2]xf32>
+
+ // CHECK: SVE: END OF TEST OUTPUT
+ vector.print str "SVE: END OF TEST OUTPUT"
+
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ } : !transform.any_op
+}
>From 3f8efb7af0be0aec716b7b979a9034db9029f15a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 24 Oct 2023 09:24:34 +0000
Subject: [PATCH 2/2] [mlir][vector] Add scalable vectors to tests for
vector.contract
Update the remaining tests for matrix multiplication (_matmul_) in:
* vector-contract-to-outerproduct-transforms.mlir
with cases for scalable vectors.
Note that in order for the "vector.contract -> vector.outerproduct"
patterns to work, only the non-reduction dimension can be scalable (*).
For Matmul operations that is set to be the N dimension (i.e. rows of
the output matrix), which matches how matrix multiplications are normally
implemented for e.g. Arm's SVE. However, making the M dimension scalable
(i.e. columns of the output matrix) should work as well.
Making both parallel dimensions scalable is left as a TODO for when
support for 2-D scalable vectors is more established (this is
work-in-progress as part of the effort to support Arm's SME in MLIR).
(*) The conversion tested in this file unrolls along the reduction
dimension, which is not supported for scalable vectors.
---
.../Vector/Transforms/LowerVectorContract.cpp | 2 +-
...r-contract-to-outerproduct-transforms.mlir | 150 ++++++++++++++++++
2 files changed, 151 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 5463a7bd8f4c840..6dbe36e605e9a78 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
return v;
Type promotedType = dstElementType;
if (vecType)
- promotedType = VectorType::get(vecType.getShape(), promotedType);
+ promotedType = vecType.clone(promotedType);
if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 44fb23088cea933..6933b24a32a830d 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -169,6 +169,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
+//
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
+// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
+// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
+//
+// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
+// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
+// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
+//
+// CHECK: return %[[c3]] : vector<2x[3]xf32>
+func.func @matmul_scalable(%arg0: vector<2x4xf32>,
+ %arg1: vector<4x[3]xf32>,
+ %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+ %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
// CHECK-LABEL: func @matmul_0
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -186,6 +222,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_0_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
// CHECK-LABEL: func @matmul_0_mixed
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -205,6 +258,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
+// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
+// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf16>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
#matmat_accesses_1 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -233,6 +305,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_1_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+ : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
#matmat_accesses_2 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -259,6 +349,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_2_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+ : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
#matmat_accesses_3 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (n, k)>,
@@ -286,6 +392,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}
+// CHECK-LABEL: func @matmul_3_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_3_scalable(%arg0: vector<1x2xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+ %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+ : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
+ return %0 : vector<2x[3]xf32>
+}
+
#matmat_accesses_4 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
@@ -313,6 +436,33 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<3x2xf32>
}
+// CHECK-LABEL: func @matmul_4_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
+// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
+// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+// CHECK: return %[[c0]] : vector<3x[2]xf32>
+func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x[2]xf32>)
+-> vector<3x[2]xf32>
+{
+ %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+ : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
+ return %0 : vector<3x[2]xf32>
+}
+
+#matmat_accesses_5 = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_5 = {
+ indexing_maps = #matmat_accesses_5,
+ iterator_types = ["parallel", "parallel", "reduction"]
+}
+
// CHECK-LABEL: @masked_matvec_mk_k_m
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
More information about the Mlir-commits
mailing list