[Mlir-commits] [mlir] andrzej/add more tests for scalable p2 (PR #70425)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Oct 27 01:08:30 PDT 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/70425

- [mlir][vector] Update v.contract -> v.outerproduct tests
- [mlir][Vector] Update v.contract -> v.outerproduct tests (2/N)


>From c262dec81d4d89c6fbd514d0f0afed456947dc79 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 26 Oct 2023 20:53:46 +0000
Subject: [PATCH 1/2] [mlir][vector] Update v.contract -> v.outerproduct tests

Tests for conversions from vector.contract to vector.outerproduct
for _matvec_ operations are updated with cases for scalable vectors.

This patch updates one specific test file:

    vector-contract-to-outerproduct-transforms.mlir.

The remaining _matmul_ operations in this file will be updated in a
separate patch. Only the parallel dimension is made scalable. Making the
reduction dimension scalable would lead to different patterns without
vector.outerproduct (that would need to be added to some other file).
---
 ...r-contract-to-outerproduct-transforms.mlir | 280 ++++++++++++++----
 1 file changed, 224 insertions(+), 56 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 44fb23088cea933..ec88759cd4927cb 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -313,6 +313,16 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+#matvec_accesses_1 = [
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_1 = {
+  indexing_maps = #matvec_accesses_1,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -323,17 +333,38 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
   // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (m, k)>,
-                       affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_2 = [
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_2 = {
+  indexing_maps = #matvec_accesses_2,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_km_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -344,17 +375,38 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k, m)>,
-                       affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_3 = [
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_3 = {
+  indexing_maps = #matvec_accesses_3,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_k_mk_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -365,17 +417,54 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
   // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m, k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+        : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+        : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_4 = [
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_4 = {
+  indexing_maps = #matvec_accesses_4,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_k_km_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -386,17 +475,22 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (k, m)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+#matvec_accesses_5 = [
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+  indexing_maps = #matvec_accesses_5,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -407,17 +501,38 @@ func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (m, k)>,
-                       affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_6 = [
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+  indexing_maps = #matvec_accesses_6,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_km_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -428,17 +543,38 @@ func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k, m)>,
-                       affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_7 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+  indexing_maps = #matvec_accesses_7,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_k_mk_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -449,17 +585,38 @@ func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m, k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_8 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_8 = {
+  indexing_maps = #matvec_accesses_8,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_k_km_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -470,17 +627,28 @@ func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (k, m)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {

>From fb801b3266e3a1f4e38eb4a0b8e628a2161a3563 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 27 Oct 2023 07:52:09 +0000
Subject: [PATCH 2/2] [mlir][Vector] Update v.contract -> v.outerproduct tests
 (2/N)

The remaining tests for conversions from vector.contract to
vector.outerproduct for _matmul_ operations in:

  * "vector-contract-to-outerproduct-transforms.mlir"

are updated with cases for scalable vectors. One duplicated test is
removed.

In addition:

  * tests are re-organised so that _matvec_ tests and _matmul_ tests are
    "clustered" together,
  * one duplicate case for _matvec_ is removed,
  * function formatting is unified,
  * added comments to document and to separate different cases,
  * unified the naming for matrix/vector dimensions: (i, j, k) -> (m, n,
    k),

While this does add a bit of noise to this patch, I wanted to avoid
sending separate patches to refactor this file.

Depends on #70379
---
 ...r-contract-to-outerproduct-transforms.mlir | 830 +++++++++++-------
 1 file changed, 496 insertions(+), 334 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index ec88759cd4927cb..5f560017cad312d 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -1,36 +1,31 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
-#matvec_accesses = [
-  affine_map<(i, j) -> (i, j)>,
-  affine_map<(i, j) -> (j)>,
-  affine_map<(i, j) -> (i)>
-]
-#matvec_trait = {
-  indexing_maps = #matvec_accesses,
-  iterator_types = ["parallel", "reduction"]
-}
-
-#matmat_accesses = [
-  affine_map<(i, j, k) -> (i, k)>,
-  affine_map<(i, j, k) -> (k, j)>,
-  affine_map<(i, j, k) -> (i, j)>
-]
-#matmat_trait = {
-  indexing_maps = #matmat_accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
+// NOTE - tests in this file are duplicated so that there's a version for
+//    * _fixed width_ and for _scalable_ vectors.
+// In order for the "vector.contract -> vector.outerproduct" patterns to work,
+// only the non-reduction dimension can be scalable (*). For Matmul operations
+// that is set to be the N dimension (i.e. rows of the output matrix), which
+// matches how matrix multiplication is normally implemented for e.g.
+// Arm's SVE. However, making the M dimension scalable (i.e. columns of the
+// output matrix) should work as well.
+//
+// (*) The conversion tested in this file unrolls along the reduction
+// dimension, which is not supported for scalable vectors.
 
-#matmat_accesses_0 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
+// ============================================================================
+//  Matvec 1
+// ============================================================================
+#matvec_accesses_1 = [
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
 ]
-#matmat_trait_0 = {
-  indexing_maps = #matmat_accesses_0,
-  iterator_types = ["parallel", "parallel", "reduction"]
+#matvec_trait_1 = {
+  indexing_maps = #matvec_accesses_1,
+  iterator_types = ["parallel", "reduction"]
 }
 
-// CHECK-LABEL:   func.func @masked_extract_contract2(
+// CHECK-LABEL:   func.func @masked_matvec_mk_k_m(
 // CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<2xf32>,
@@ -45,17 +40,16 @@
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
 // CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
 
-func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
-                                    %arg1: vector<3xf32>,
-                                    %arg2: vector<2xf32>,
-                                    %m: vector<2x3xi1>) -> vector<2xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
+                                %arg1: vector<3xf32>,
+                                %arg2: vector<2xf32>,
+                                %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
           : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
   return %0 : vector<2xf32>
 }
 
-
-// CHECK-LABEL:   func.func @masked_extract_contract2_scalable_parallel_dim(
+// CHECK-LABEL:   func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
 // CHECK-SAME:      %{{.*}}: vector<[2]x3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<[2]xf32>,
@@ -69,326 +63,58 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
 // CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
-                                    %arg1: vector<3xf32>,
-                                    %arg2: vector<[2]xf32>,
-                                    %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+                                                      %arg1: vector<3xf32>,
+                                                      %arg2: vector<[2]xf32>,
+                                                      %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
           : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
   return %0 : vector<[2]xf32>
 }
 
-// CHECK-LABEL: func.func @masked_extract_contract4(
-// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
-// CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
-// CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
-// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
-// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
-// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
-
-func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
-                                    %arg1: vector<5x7xf32>,
-                                    %arg2: vector<3x7xf32>,
-                                    %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
-  : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
-  return %0 : vector<3x7xf32>
-}
-
-// CHECK-LABEL: func.func @masked_extract_contract4_scalable_J_dim(
-// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
-// CHECK-SAME:    %{{.*}}: vector<5x[7]xf32>,
-// CHECK-SAME:    %{{.*}}: vector<3x[7]xf32>,
-// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
-// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
-// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK:         %[[VAL_13:.*]] = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
-// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
-
-// Note that only the J dimension is scalable in this example. In theory, all
-// dimensions could be be scalable, but there is no target yet for which this
-// would make sense.
-func.func @masked_extract_contract4_scalable_J_dim(%arg0: vector<3x5xf32>,
-                                    %arg1: vector<5x[7]xf32>,
-                                    %arg2: vector<3x[7]xf32>,
-                                    %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
-  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
-  : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
-  return %0 : vector<3x[7]xf32>
-}
-
-// CHECK-LABEL: func @matmul
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-// CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
-//
-//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
-//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<4x3xf32>
-//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
-//
-//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
-//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<4x3xf32>
-//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
-// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
-//
-//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
-//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<3xf32> from vector<4x3xf32>
-//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
-// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
-//
-//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
-//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
-//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
-// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
-//
-//      CHECK: return %[[c3]] : vector<2x3xf32>
-func.func @matmul(%arg0: vector<2x4xf32>,
-                          %arg1: vector<4x3xf32>,
-                          %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
-  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
-    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
-  return %0 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @matmul_0
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
-//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
-//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-//      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
-{
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
-    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
-  return %0 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @matmul_0_mixed
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
-//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf16> from vector<1x3xf16>
-//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
-//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
-//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
-//      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
-{
-  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
-    : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
-  return %0 : vector<2x3xf32>
-}
-
-#matmat_accesses_1 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (n, k)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_1 = {
-  indexing_maps = #matmat_accesses_1,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-// CHECK-LABEL: func @matmul_1
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
-//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
-//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
-//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-//      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
-{
-  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
-    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
-  return %0 : vector<2x3xf32>
-}
-
-#matmat_accesses_2 = [
-  affine_map<(m, n, k) -> (k, m)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_2 = {
-  indexing_maps = #matmat_accesses_2,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-// CHECK-LABEL: func @matmul_2
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
-//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
-//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-//      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
-{
-  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
-    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
-  return %0 : vector<2x3xf32>
-}
-
-#matmat_accesses_3 = [
-  affine_map<(m, n, k) -> (k, m)>,
-  affine_map<(m, n, k) -> (n, k)>,
-  affine_map<(m, n, k) -> (m, n)>
-]
-#matmat_trait_3 = {
-  indexing_maps = #matmat_accesses_3,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-// CHECK-LABEL: func @matmul_3
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
-//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
-//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
-//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
-//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
-//      CHECK: return %[[c0]] : vector<2x3xf32>
-func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>)
--> vector<2x3xf32>
-{
-  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
-    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
-  return %0 : vector<2x3xf32>
-}
-
-#matmat_accesses_4 = [
-  affine_map<(m, n, k) -> (m, k)>,
-  affine_map<(m, n, k) -> (k, n)>,
-  affine_map<(m, n, k) -> (n, m)>
-]
-#matmat_trait_4 = {
-  indexing_maps = #matmat_accesses_4,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-// CHECK-LABEL: func @matmul_4
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
-// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
-// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
-//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
-//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
-//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
-//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
-//      CHECK: return %[[c0]] : vector<3x2xf32>
-func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>)
--> vector<3x2xf32>
-{
-  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
-    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
-  return %0 : vector<3x2xf32>
-}
-
-#matvec_accesses_1 = [
-  affine_map<(m, k) -> (m, k)>,
+// ============================================================================
+//  Matvec 2
+// ============================================================================
+#matvec_accesses_2 = [
+  affine_map<(m, k) -> (k, m)>,
   affine_map<(m, k) -> (k)>,
   affine_map<(m, k) -> (m)>
 ]
-#matvec_trait_1 = {
-  indexing_maps = #matvec_accesses_1,
+#matvec_trait_2 = {
+  indexing_maps = #matvec_accesses_2,
   iterator_types = ["parallel", "reduction"]
 }
 
-// CHECK-LABEL: @masked_matvec_mk_k_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
 // CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
+                                %arg1: vector<2xf32>,
+                                %arg2: vector<4xf32>,
+                                %mask: vector<4x2xi1>) -> vector<4xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
-      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
-// CHECK-LABEL: @masked_matvec_mk_k_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
 // CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+                                                      %arg1: vector<2xf32>,
+                                                      %arg2: vector<[4]xf32>,
+                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[MAT]]
-  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
-  %res = vector.mask %mask {
-    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
-      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
-  } : vector<[4]x2xi1> -> vector<[4]xf32>
-  return %res : vector<[4]xf32>
-}
-
-#matvec_accesses_2 = [
-  affine_map<(m, k) -> (k, m)>,
-  affine_map<(m, k) -> (k)>,
-  affine_map<(m, k) -> (m)>
-]
-#matvec_trait_2 = {
-  indexing_maps = #matvec_accesses_2,
-  iterator_types = ["parallel", "reduction"]
-}
-
-// CHECK-LABEL: @masked_matvec_km_k_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
-// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
-  // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
-  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
-  %res = vector.mask %mask {
-    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
-      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
-  } : vector<4x2xi1> -> vector<4xf32>
-  return %res : vector<4xf32>
-}
-
-// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
-// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
-  // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
     vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
@@ -397,6 +123,9 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
+// ============================================================================
+//  Matvec 3
+// ============================================================================
 #matvec_accesses_3 = [
   affine_map<(m, k) -> (k)>,
   affine_map<(m, k) -> (m, k)>,
@@ -412,7 +141,10 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
 // CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
+                                %arg1: vector<2xf32>,
+                                %arg2: vector<4xf32>,
+                                %mask: vector<4x2xi1>) -> vector<4xf32> {
   // CHECK:         vector.transpose %[[MASK]]
   // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
@@ -428,7 +160,10 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
 // CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
+                                                      %arg1: vector<2xf32>,
+                                                      %arg2: vector<[4]xf32>,
+                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
   // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
@@ -439,6 +174,9 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
   return %res : vector<[4]xf32>
 }
 
+// ============================================================================
+//  Matvec 4
+// ============================================================================
 #matvec_accesses_4 = [
   affine_map<(m, k) -> (k)>,
   affine_map<(m, k) -> (k, m)>,
@@ -454,7 +192,10 @@ func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
 // CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+                                                      %arg1: vector<2xf32>,
+                                                      %arg2: vector<[4]xf32>,
+                                                      %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
   // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
@@ -470,7 +211,10 @@ func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
 // CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
+                                %arg1: vector<2xf32>,
+                                %arg2: vector<4xf32>,
+                                %mask: vector<4x2xi1>) -> vector<4xf32> {
   // CHECK:         vector.transpose %[[MASK]]
   // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
@@ -481,6 +225,9 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   return %res : vector<4xf32>
 }
 
+// ============================================================================
+//  Matvec 5
+// ============================================================================
 #matvec_accesses_5 = [
   affine_map<(k, m) -> (m, k)>,
   affine_map<(k, m) -> (k)>,
@@ -523,6 +270,9 @@ func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
   return %res : vector<[4]xf32>
 }
 
+// ============================================================================
+//  Matvec 6
+// ============================================================================
 #matvec_accesses_6 = [
   affine_map<(k, m) -> (k, m)>,
   affine_map<(k, m) -> (k)>,
@@ -565,6 +315,9 @@ func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
+// ============================================================================
+//  Matvec 7
+// ============================================================================
 #matvec_accesses_7 = [
   affine_map<(k, m) -> (k)>,
   affine_map<(k, m) -> (m, k)>,
@@ -607,6 +360,9 @@ func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
   return %res : vector<[4]xf32>
 }
 
+// ============================================================================
+//  Matvec 8
+// ============================================================================
 #matvec_accesses_8 = [
   affine_map<(k, m) -> (k)>,
   affine_map<(k, m) -> (k, m)>,
@@ -649,7 +405,413 @@ func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
   return %res : vector<[4]xf32>
 }
 
+// ============================================================================
+//  Matmul (plain and masked)
+// ============================================================================
+#matmat_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait = {
+  indexing_maps = #matmat_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @matmul(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
+//
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<4x3xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
+//
+//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<3xf32> from vector<4x3xf32>
+//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
+//
+//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<3xf32> from vector<4x3xf32>
+//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
+//
+//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<3xf32> from vector<4x3xf32>
+//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME:  : vector<2xf32>, vector<3xf32>
+//
+//      CHECK: return %[[c3]] : vector<2x3xf32>
+func.func @matmul(%arg0: vector<2x4xf32>,
+                  %arg1: vector<4x3xf32>,
+                  %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @matmul_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+// CHECK-SAME:  : vector<2x4xf32> to vector<4x2xf32>
+//
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
+//      CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
+// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
+//      CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
+// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
+//
+//      CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
+//      CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
+//      CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
+// CHECK-SAME:  : vector<2xf32>, vector<[3]xf32>
+//
+//      CHECK: return %[[c3]] : vector<2x[3]xf32>
+func.func @matmul_scalable(%arg0: vector<2x4xf32>,
+                           %arg1: vector<4x[3]xf32>,
+                           %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
+  %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
+    : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
+// CHECK-LABEL: func.func @masked_matmul(
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x7xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x7xf32>,
+// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x7xi1> from vector<5x3x7xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+func.func @masked_matmul(%arg0: vector<3x5xf32>,
+                         %arg1: vector<5x7xf32>,
+                         %arg2: vector<3x7xf32>,
+                         %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
+  return %0 : vector<3x7xf32>
+}
+
+// CHECK-LABEL: func.func @masked_matmul_scalable(
+// CHECK-SAME:    %{{.*}}: vector<3x5xf32>,
+// CHECK-SAME:    %{{.*}}: vector<5x[7]xf32>,
+// CHECK-SAME:    %{{.*}}: vector<3x[7]xf32>,
+// CHECK-SAME:    %[[IN_MASK:.*]]: vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+// CHECK:         %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [2, 0, 1] : vector<3x[7]x5xi1> to vector<5x3x[7]xi1>
+// CHECK:         %[[T_MASK_R0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R0]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R1]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R2]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R3:.*]] = vector.extract %[[T_MASK]][3] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R3]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+// CHECK:         %[[T_MASK_R4:.*]] = vector.extract %[[T_MASK]][4] : vector<3x[7]xi1> from vector<5x3x[7]xi1>
+// CHECK:         %{{.*}} = vector.mask %[[T_MASK_R4]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<[7]xf32> } : vector<3x[7]xi1> -> vector<3x[7]xf32>
+
+func.func @masked_matmul_scalable(%arg0: vector<3x5xf32>,
+                                  %arg1: vector<5x[7]xf32>,
+                                  %arg2: vector<3x[7]xf32>,
+                                  %m : vector<3x[7]x5xi1>) -> vector<3x[7]xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<3x5xf32>, vector<5x[7]xf32> into vector<3x[7]xf32> } : vector<3x[7]x5xi1> -> vector<3x[7]xf32>
+  return %0 : vector<3x[7]xf32>
+}
+
+// ============================================================================
+//  Matmul 0
+// ============================================================================
+#matmat_accesses_0 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_0 = {
+  indexing_maps = #matmat_accesses_0,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @matmul_0(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_0(%arg0: vector<2x1xf32>,
+                    %arg1: vector<1x3xf32>,
+                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @matmul_0_mixed(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf16> from vector<1x3xf16>
+//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
+//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_0_mixed(%arg0: vector<2x1xf16>,
+                          %arg1: vector<1x3xf16>,
+                          %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @matmul_0_mixed_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
+//      CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
+//      CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>,
+                                   %arg1: vector<1x[3]xf16>,
+                                   %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
+// ============================================================================
+//  Matmul 1
+// ============================================================================
+#matmat_accesses_1 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_1 = {
+  indexing_maps = #matmat_accesses_1,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @matmul_1(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_1(%arg0: vector<2x1xf32>,
+                    %arg1: vector<3x1xf32>,
+                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @matmul_1_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_1_scalable(%arg0: vector<2x1xf32>,
+                             %arg1: vector<[3]x1xf32>,
+                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
+// ============================================================================
+//  Matmul 2
+// ============================================================================
+#matmat_accesses_2 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_2 = {
+  indexing_maps = #matmat_accesses_2,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @matmul_2(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_2(%arg0: vector<1x2xf32>,
+                    %arg1: vector<1x3xf32>,
+                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @matmul_2_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_2_scalable(%arg0: vector<1x2xf32>,
+                             %arg1: vector<1x[3]xf32>,
+                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
+// ============================================================================
+//  Matmul 3
+// ============================================================================
+#matmat_accesses_3 = [
+  affine_map<(m, n, k) -> (k, m)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+#matmat_trait_3 = {
+  indexing_maps = #matmat_accesses_3,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @matmul_3(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32>
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x3xf32>
+func.func @matmul_3(%arg0: vector<1x2xf32>,
+                    %arg1: vector<3x1xf32>,
+                    %arg2: vector<2x3xf32>) -> vector<2x3xf32>
+{
+  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func @matmul_3_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
+//      CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
+//      CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<2x[3]xf32>
+func.func @matmul_3_scalable(%arg0: vector<1x2xf32>,
+                             %arg1: vector<[3]x1xf32>,
+                             %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32>
+{
+  %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
+    : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+
+// ============================================================================
+//  Matmul 4
+// ============================================================================
+#matmat_accesses_4 = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (k, n)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#matmat_trait_4 = {
+  indexing_maps = #matmat_accesses_4,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @matmul_4(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<3x2xf32>
+func.func @matmul_4(%arg0: vector<2x1xf32>,
+                    %arg1: vector<1x3xf32>,
+                    %arg2: vector<3x2xf32>) -> vector<3x2xf32>
+{
+  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+    : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
+  return %0 : vector<3x2xf32>
+}
+
+// CHECK-LABEL: func @matmul_4_scalable
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
+//      CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
+//      CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
+//      CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
+//      CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
+//      CHECK: return %[[c0]] : vector<3x[2]xf32>
+func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>,
+                             %arg1: vector<1x3xf32>,
+                             %arg2: vector<3x[2]xf32>) -> vector<3x[2]xf32>
+{
+  %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
+    : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
+  return %0 : vector<3x[2]xf32>
+}
 
+// ============================================================================
+//  TD sequence
+// ============================================================================
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op



More information about the Mlir-commits mailing list