[Mlir-commits] [mlir] [mlir][vector] Add tests for scalable vectors (PR #67806)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Sep 29 07:07:02 PDT 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/67806

Adds tests for scalable vectors in:
  * vector-contract-to-outerproduct-transforms.mlir
Every existing test is duplicated, with fixed-width vectors replaced by
scalable vectors. One test required a fix in
  * LowerVectorContract.cpp.

This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Fixes #67804


>From e16dcb4b86a75c222ce8a10d60939890ea1cd9b7 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 29 Sep 2023 13:58:32 +0000
Subject: [PATCH] [mlir][vector] Add tests for scalable vectors

Adds tests for scalable vectors in:
  * vector-contract-to-outerproduct-transforms.mlir
Every existing test is duplicated, with fixed-width vectors replaced by
scalable vectors. One test required a fix in
  * LowerVectorContract.cpp.

This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Fixes #67804
---
 .../Vector/Transforms/LowerVectorContract.cpp |   2 +-
 ...r-contract-to-outerproduct-transforms.mlir | 495 ++++++++++++++++--
 2 files changed, 453 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 04d9ddf2183f8c5..6e63d52d22a1f6b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
       return v;
     Type promotedType = dstElementType;
     if (vecType)
-      promotedType = VectorType::get(vecType.getShape(), promotedType);
+      promotedType = vecType.clone(promotedType);
     if (isa<FloatType>(dstElementType))
       return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index deea7747f36799c..3746897bcd864f6 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -34,16 +34,16 @@
 // CHECK-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
 // CHECK-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
 // CHECK-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
-// CHECK-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
 // CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
 // CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
 func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
                                     %arg1: vector<3xf32>,
@@ -54,22 +54,46 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
   return %0 : vector<2xf32>
 }
 
+
+// CHECK-LABEL:   func.func @masked_extract_contract2_scalable(
+// CHECK-SAME:      %{{.*}}: vector<[2]x[3]xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[3]xf32>,
+// CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME:      %[[IN_MASK:.*]]: vector<[2]x[3]xi1>) -> vector<[2]xf32>
+// CHECK:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x[3]xi1> to vector<[3]x[2]xi1>
+// CHECK:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<[3]x[2]xi1>
+// CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_extract_contract2_scalable(%arg0: vector<[2]x[3]xf32>,
+                                    %arg1: vector<[3]xf32>,
+                                    %arg2: vector<[2]xf32>,
+                                    %m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+
 // CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
-// CHECK-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,
-// CHECK-SAME:                                      %[[VAL_2:.*]]: vector<3x7xf32>,
-// CHECK-SAME:                                      %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK:         %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK:         %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME:    %[[M:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK:         %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK:         %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
 
 func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
                                     %arg1: vector<5x7xf32>,
@@ -80,10 +104,36 @@ func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
   return %0 : vector<3x7xf32>
 }
 
+// CHECK-LABEL: func.func @masked_extract_contract4_scalable(
+// CHECK-SAME:    %{{.*}}: vector<[3]x[5]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<[5]x[7]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<[3]x[7]xf32>,
+// CHECK-SAME:    %[[M:.*]]: vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> {
+// CHECK:         %[[M_TRAN:.*]] = vector.transpose %[[M]], [2, 0, 1] : vector<[3]x[7]x[5]xi1> to vector<[5]x[3]x[7]xi1>
+// CHECK:         %[[M_0:.*]] = vector.extract %[[M_TRAN]][0] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_1:.*]] = vector.extract %[[M_TRAN]][1] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_2:.*]] = vector.extract %[[M_TRAN]][2] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_3:.*]] = vector.extract %[[M_TRAN]][3] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+// CHECK:         %[[M_4:.*]] = vector.extract %[[M_TRAN]][4] : vector<[3]x[7]xi1> from vector<[5]x[3]x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[M_4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<[3]xf32>, vector<[7]xf32> } : vector<[3]x[7]xi1> -> vector<[3]x[7]xf32>
+
+func.func @masked_extract_contract4_scalable(%arg0: vector<[3]x[5]xf32>,
+                                    %arg1: vector<[5]x[7]xf32>,
+                                    %arg2: vector<[3]x[7]xf32>,
+                                    %m : vector<[3]x[7]x[5]xi1>) -> vector<[3]x[7]xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<[3]x[5]xf32>, vector<[5]x[7]xf32> into vector<[3]x[7]xf32> } : vector<[3]x[7]x[5]xi1> -> vector<[3]x[7]xf32>
+  return %0 : vector<[3]x[7]xf32>
+}
+
 // CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<4x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 // CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
 //
@@ -116,6 +166,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[4]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[4]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME:  : vector<[2]x[4]xf32> to vector<[4]x[2]xf32>
+//
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<[2]xf32> from vector<[4]x[2]xf32>
+//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<[4]x[3]xf32>
+//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME:  : vector<[2]xf32>, vector<[3]xf32>
+//
+//      CHECK: return %[[c3]] : vector<[2]x[3]xf32>
+func.func @matmul_scalable(%arg0: vector<[2]x[4]xf32>,
+                          %arg1: vector<[4]x[3]xf32>,
+                          %arg2: vector<[2]x[3]xf32>) -> vector<[2]x[3]xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<[2]x[4]xf32>, vector<[4]x[3]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 // CHECK-LABEL: func @matmul_0
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -133,6 +219,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_0_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_0_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 // CHECK-LABEL: func @matmul_0_mixed
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
 // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -152,6 +255,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf16>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf16>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf16> from vector<[1]x[2]xf16>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<[1]x[3]xf16>
+//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<[2]xf16> to vector<[2]xf32>
+//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_0_mixed_scalable(%arg0: vector<[2]x[1]xf16>, %arg1: vector<[1]x[3]xf16>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf16>, vector<[1]x[3]xf16> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_1 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (n, k)>,
@@ -163,9 +285,9 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
 }
 
 // CHECK-LABEL: func @matmul_1
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
@@ -180,6 +302,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_1_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[3]x[1]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[3]x[1]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf32>, vector<[3]x[1]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_2 = [
   affine_map<(m, n, k) -> (k, m)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -191,9 +331,9 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_2
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
@@ -206,6 +346,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_2_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[1]x[2]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<[1]x[2]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+    : vector<[1]x[2]xf32>, vector<[1]x[3]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_3 = [
   affine_map<(m, n, k) -> (k, m)>,
   affine_map<(m, n, k) -> (n, k)>,
@@ -217,9 +373,9 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_3
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+// CHECK-SAME: %[[A:.*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<2x3xf32>
 //      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
 //      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
@@ -233,6 +389,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
   return %0 : vector<2x3xf32>
 }
 
+// CHECK-LABEL: func @matmul_3_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[1]x[2]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[3]x[1]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[2]x[3]xf32>
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[2]x[3]xf32>
+func.func @matmul_3_scalable(%arg0: vector<[1]x[2]xf32>, %arg1: vector<[3]x[1]xf32>, %arg2: vector<[2]x[3]xf32>)
+-> vector<[2]x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+    : vector<[1]x[2]xf32>, vector<[3]x[1]xf32> into vector<[2]x[3]xf32>
+  return %0 : vector<[2]x[3]xf32>
+}
+
 #matmat_accesses_4 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -244,9 +417,9 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_4
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<3x2xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 //      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
 //      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
@@ -260,6 +433,23 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: func @matmul_4_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[3]x[2]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[3]x[2]xf32>
+func.func @matmul_4_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[3]x[2]xf32>)
+-> vector<[3]x[2]xf32>
+{
+  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[3]x[2]xf32>
+  return %0 : vector<[3]x[2]xf32>
+}
+
 #matmat_accesses_5 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -271,9 +461,9 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_5
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<3x2xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 //      CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
@@ -287,6 +477,23 @@ func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: func @matmul_5_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[3]x[2]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[3]x[2]xf32>
+func.func @matmul_5_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[3]x[2]xf32>)
+-> vector<[3]x[2]xf32>
+{
+  %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[3]x[2]xf32>
+  return %0 : vector<[3]x[2]xf32>
+}
+
 #matmat_accesses_6 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -298,9 +505,9 @@ func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_6
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<3x2xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 //      CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
@@ -314,6 +521,23 @@ func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: func @matmul_6_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[3]x[2]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[3]x[2]xf32>
+func.func @matmul_6_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[3]x[2]xf32>)
+-> vector<[3]x[2]xf32>
+{
+  %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[3]x[2]xf32>
+  return %0 : vector<[3]x[2]xf32>
+}
+
 #matmat_accesses_7 = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
@@ -325,9 +549,9 @@ func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
 }
 
 // CHECK-LABEL: func @matmul_7
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<3x2xf32>
 //      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
 //      CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
 //      CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
@@ -341,6 +565,23 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: func @matmul_7_scalable
+// CHECK-SAME: %[[A:.*]]: vector<[2]x[1]xf32>,
+// CHECK-SAME: %[[B:.*]]: vector<[1]x[3]xf32>,
+// CHECK-SAME: %[[C:.*]]: vector<[3]x[2]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<[1]x[2]xf32>
+//      CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<[1]x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<[3]x[2]xf32>
+func.func @matmul_7_scalable(%arg0: vector<[2]x[1]xf32>, %arg1: vector<[1]x[3]xf32>, %arg2: vector<[3]x[2]xf32>)
+-> vector<[3]x[2]xf32>
+{
+  %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2
+    : vector<[2]x[1]xf32>, vector<[1]x[3]xf32> into vector<[3]x[2]xf32>
+  return %0 : vector<[3]x[2]xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -362,6 +603,27 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_mk_k_m_scalable
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x[2]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<[2]xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x[2]xi1>
+func.func @masked_matvec_mk_k_m_scalable(%arg0: vector<[4]x[2]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x[2]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (m, k)>,
+                       affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<[4]x[2]xf32>, vector<[2]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x[2]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_km_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -383,6 +645,27 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable
+// CHECK-SAME:  %[[MAT:.+]]: vector<[2]x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<[2]xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x[2]xi1>
+func.func @masked_matvec_km_k_m_scalable(%arg0: vector<[2]x[4]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x[2]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // Scalable-vector variant: masked matrix (k, m) times vector (k), accumulated into %arg2 (m). The directives above expect a transpose of the mask, no transpose of the matrix, then two masked outerproducts.
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k, m)>,  // 1st operand: the matrix, indexed (k, m)
+                       affine_map<(m, k) -> (k)>,  // 2nd operand: the vector, indexed (k)
+                       affine_map<(m, k) -> (m)>],  // accumulator and result, indexed (m)
+      iterator_types = ["parallel", "reduction"],  // m is the parallel dim, k is the reduction dim
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<[2]x[4]xf32>, vector<[2]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x[2]xi1> -> vector<[4]xf32>  // NOTE(review): the reduction dim k is scalable ([2]) yet exactly two outerproducts are expected; confirm that unrolling a scalable dim this way is intended.
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_k_mk_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -404,6 +687,27 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_k_mk_m_scalable
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x[2]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<[2]xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x[2]xi1>
+func.func @masked_matvec_k_mk_m_scalable(%arg0: vector<[4]x[2]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x[2]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // Scalable-vector variant: masked vector (k) times matrix (m, k), accumulated into %arg2 (m). The directives above expect a transpose of the mask and of the matrix, then two masked outerproducts.
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k)>,  // 1st operand: the vector, indexed (k)
+                       affine_map<(m, k) -> (m, k)>,  // 2nd operand: the matrix, indexed (m, k)
+                       affine_map<(m, k) -> (m)>],  // accumulator and result, indexed (m)
+      iterator_types = ["parallel", "reduction"],  // m is the parallel dim, k is the reduction dim
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<[2]xf32>, vector<[4]x[2]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x[2]xi1> -> vector<[4]xf32>  // NOTE(review): the reduction dim k is scalable ([2]) yet exactly two outerproducts are expected; confirm that unrolling a scalable dim this way is intended.
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_k_km_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -425,6 +729,27 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_k_km_m_scalable
+// CHECK-SAME:  %[[MAT:.+]]: vector<[2]x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<[2]xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x[2]xi1>
+func.func @masked_matvec_k_km_m_scalable(%arg0: vector<[2]x[4]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x[2]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // Scalable-vector variant: masked vector (k) times matrix (k, m), accumulated into %arg2 (m). The directives above expect a transpose of the mask, no transpose of the matrix, then two masked outerproducts.
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k)>,  // 1st operand: the vector, indexed (k)
+                       affine_map<(m, k) -> (k, m)>,  // 2nd operand: the matrix, indexed (k, m)
+                       affine_map<(m, k) -> (m)>],  // accumulator and result, indexed (m)
+      iterator_types = ["parallel", "reduction"],  // m is the parallel dim, k is the reduction dim
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<[2]xf32>, vector<[2]x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x[2]xi1> -> vector<[4]xf32>  // NOTE(review): the reduction dim k is scalable ([2]) yet exactly two outerproducts are expected; confirm that unrolling a scalable dim this way is intended.
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_tmatvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -446,6 +771,27 @@ func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x[2]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<[2]xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[2]x[4]xi1>
+func.func @masked_tmatvec_mk_k_m_scalable(%arg0: vector<[4]x[2]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[2]x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // Scalable-vector variant with the iteration space ordered (k, m): masked matrix (m, k) times vector (k), accumulated into %arg2 (m). The directives above expect a transpose of the matrix, no transpose of the mask, then two masked outerproducts.
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (m, k)>,  // 1st operand: the matrix, indexed (m, k)
+                       affine_map<(k, m) -> (k)>,  // 2nd operand: the vector, indexed (k)
+                       affine_map<(k, m) -> (m)>],  // accumulator and result, indexed (m)
+      iterator_types = ["reduction", "parallel"],  // k is the reduction dim, m is the parallel dim
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<[4]x[2]xf32>, vector<[2]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[2]x[4]xi1> -> vector<[4]xf32>  // NOTE(review): the reduction dim k is scalable ([2]) yet exactly two outerproducts are expected; confirm that unrolling a scalable dim this way is intended.
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_tmatvec_km_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -467,6 +813,27 @@ func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable
+// CHECK-SAME:  %[[MAT:.+]]: vector<[2]x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<[2]xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[2]x[4]xi1>
+func.func @masked_tmatvec_km_k_m_scalable(%arg0: vector<[2]x[4]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[2]x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // Scalable-vector variant with the iteration space ordered (k, m): masked matrix (k, m) times vector (k), accumulated into %arg2 (m). The directives above expect no transpose of the matrix or of the mask, then two masked outerproducts.
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k, m)>,  // 1st operand: the matrix, indexed (k, m)
+                       affine_map<(k, m) -> (k)>,  // 2nd operand: the vector, indexed (k)
+                       affine_map<(k, m) -> (m)>],  // accumulator and result, indexed (m)
+      iterator_types = ["reduction", "parallel"],  // k is the reduction dim, m is the parallel dim
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<[2]x[4]xf32>, vector<[2]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[2]x[4]xi1> -> vector<[4]xf32>  // NOTE(review): the reduction dim k is scalable ([2]) yet exactly two outerproducts are expected; confirm that unrolling a scalable dim this way is intended.
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_tmatvec_k_mk_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -488,6 +855,27 @@ func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x[2]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<[2]xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[2]x[4]xi1>
+func.func @masked_tmatvec_k_mk_m_scalable(%arg0: vector<[4]x[2]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[2]x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // Scalable-vector variant with the iteration space ordered (k, m): masked vector (k) times matrix (m, k), accumulated into %arg2 (m). The directives above expect a transpose of the matrix, no transpose of the mask, then two masked outerproducts.
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k)>,  // 1st operand: the vector, indexed (k)
+                       affine_map<(k, m) -> (m, k)>,  // 2nd operand: the matrix, indexed (m, k)
+                       affine_map<(k, m) -> (m)>],  // accumulator and result, indexed (m)
+      iterator_types = ["reduction", "parallel"],  // k is the reduction dim, m is the parallel dim
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<[2]xf32>, vector<[4]x[2]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[2]x[4]xi1> -> vector<[4]xf32>  // NOTE(review): the reduction dim k is scalable ([2]) yet exactly two outerproducts are expected; confirm that unrolling a scalable dim this way is intended.
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_tmatvec_k_km_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -509,6 +897,27 @@ func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable
+// CHECK-SAME:  %[[MAT:.+]]: vector<[2]x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<[2]xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[2]x[4]xi1>
+func.func @masked_tmatvec_k_km_m_scalable(%arg0: vector<[2]x[4]xf32>, %arg1: vector<[2]xf32>, %arg2: vector<[4]xf32>, %mask: vector<[2]x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // Scalable-vector variant with the iteration space ordered (k, m): masked vector (k) times matrix (k, m), accumulated into %arg2 (m). The directives above expect no transpose of the matrix or of the mask, then two masked outerproducts.
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k)>,  // 1st operand: the vector, indexed (k)
+                       affine_map<(k, m) -> (k, m)>,  // 2nd operand: the matrix, indexed (k, m)
+                       affine_map<(k, m) -> (m)>],  // accumulator and result, indexed (m)
+      iterator_types = ["reduction", "parallel"],  // k is the reduction dim, m is the parallel dim
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<[2]xf32>, vector<[2]x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[2]x[4]xi1> -> vector<[4]xf32>  // NOTE(review): the reduction dim k is scalable ([2]) yet exactly two outerproducts are expected; confirm that unrolling a scalable dim this way is intended.
+  return %res : vector<[4]xf32>
+}
+
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):



More information about the Mlir-commits mailing list