[Mlir-commits] [mlir] [mlir][vector] Add scalable vectors to tests for vector.contract (PR #70039)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Oct 27 01:36:55 PDT 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/70039

>From 49c7d2e67c018ddb714395f2a2dc168b0d7f56cc Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 24 Oct 2023 09:24:34 +0000
Subject: [PATCH] [mlir][vector] Add scalable vectors to tests for
 vector.contract

Update the remaining tests for matrix multiplication (_matmul_) in:

  * vector-contract-to-outerproduct-transforms.mlir

with cases for scalable vectors.

Note that in order for the "vector.contract -> vector.outerproduct"
patterns to work, only a non-reduction (i.e. parallel) dimension can be
scalable (*). For Matmul operations, that is chosen to be the N dimension
(i.e. the columns of the output matrix), which matches how matrix
multiplication is normally implemented for e.g. Arm's SVE. However,
making the M dimension scalable (i.e. the rows of the output matrix)
should work as well.

Making both parallel dimensions scalable is left as a TODO for when
support for 2-D scalable vectors is more established (this is
work-in-progress as part of the effort to support Arm's SME in MLIR).

The change in:

  * `UnrolledOuterProductGenerator`

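is a "bug fix" to make sure that the conversion pattern correctly
propagates scalability when creating `arith.extf` (and `arith.extsi`)
operations. Previously the promoted type was rebuilt from the vector's
shape alone, which dropped the scalable-dimension flags.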

(*) The conversion tested in this file unrolls along the reduction
dimension, and unrolling is not supported for scalable dimensions, so
the reduction dimension must remain fixed-size.
---
 .../Vector/Transforms/LowerVectorContract.cpp |   2 +-
 ...r-contract-to-outerproduct-transforms.mlir | 150 ++++++++++++++++++
 2 files changed, 151 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 5463a7bd8f4c840..6dbe36e605e9a78 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
       return v;
     Type promotedType = dstElementType;
     if (vecType)
-      promotedType = VectorType::get(vecType.getShape(), promotedType);
+      promotedType = vecType.clone(promotedType);
     if (isa<FloatType>(dstElementType))
       return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 44fb23088cea933..6933b24a32a830d 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -169,6 +169,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
+//
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
+//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
+//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
+//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
+//
+//      CHECK: return %[[c3]] : vector<2x[3]xf32>
+func.func @matmul_scalable(%arg0: vector<2x4xf32>,
+                          %arg1: vector<4x[3]xf32>,
+                          %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
 // CHECK-LABEL: func @matmul_0
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -186,6 +222,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_0_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
 // CHECK-LABEL: func @matmul_0_mixed
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -205,6 +258,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
+//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
+//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf16>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
 #matmat_accesses_1 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (n, k)>,
@@ -233,6 +305,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_1_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
 #matmat_accesses_2 = [
   affine_map<(m, n, k) -> (k, m)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -259,6 +349,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_2_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
 #matmat_accesses_3 = [
   affine_map<(m, n, k) -> (k, m)>,
   affine_map<(m, n, k) -> (n, k)>,
@@ -286,6 +392,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_3_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_3_scalable(%arg0: vector<1x2xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
+-> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
 #matmat_accesses_4 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -313,6 +436,33 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: func @matmul_4_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<3x[2]xf32>
+func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x[2]xf32>)
+-> vector<3x[2]xf32>
+{
+  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+    : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
+  return %0 : vector<3x[2]xf32>
+}
+
+#matmat_accesses_5 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_5 = {
+  indexing_maps = #matmat_accesses_5,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>



More information about the Mlir-commits mailing list