[Mlir-commits] [mlir] [mlir][vector] Update v.contract -> v.outerproduct tests (PR #70379)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Oct 26 13:58:52 PDT 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/70379

Tests for conversions from `vector.contract` to `vector.outerproduct`
are updated with cases for scalable vectors. This patch updates one
specific test file:

  * vector-contract-to-outerproduct-transforms.mlir,

and only updates tests for matvec operations (the remaining matmul
operations have been updated in previous patches). For consistency with
the existing tests, only the parallel dimension is made scalable. Making
the reduction dimension scalable would lead to different lowering
patterns that do not use `vector.outerproduct`.


>From e3aa2ca5948cbcb931b954114f6cc8e57c4b9ee8 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 26 Oct 2023 20:53:46 +0000
Subject: [PATCH] [mlir][vector] Update v.contract -> v.outerproduct tests

Tests for conversions from `vector.contract` to `vector.outerproduct`
are updated with cases for scalable vectors. This patch updates one
specific test file:

  * vector-contract-to-outerproduct-transforms.mlir,

and only updates tests for matvec operations (the remaining matmul
operations have been updated in previous patches). For consistency with
the existing tests, only the parallel dimension is made scalable. Making
the reduction dimension scalable would lead to different lowering
patterns that do not use `vector.outerproduct`.
---
 ...r-contract-to-outerproduct-transforms.mlir | 280 ++++++++++++++----
 1 file changed, 224 insertions(+), 56 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 44fb23088cea933..ec88759cd4927cb 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -313,6 +313,16 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+#matvec_accesses_1 = [  // masked_matvec_mk_k_m*: lhs (m, k), rhs (k), acc (m)
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_1 = {  // iteration space (m, k); no explicit kind, so the default (add) applies
+  indexing_maps = #matvec_accesses_1,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -323,17 +333,38 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
   // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (m, k)>,
-                       affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // mask follows the (m, k) iteration space: [4]x2
+    vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_2 = [  // masked_matvec_km_k_m*: lhs (k, m), rhs (k), acc (m)
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_2 = {  // iteration space (m, k); no explicit kind, so the default (add) applies
+  indexing_maps = #matvec_accesses_2,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_km_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -344,17 +375,38 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k, m)>,
-                       affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // mask follows the (m, k) iteration space ([4]x2), not the 2x[4] operand layout
+    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_3 = [  // masked_matvec_k_mk_m*: lhs (k), rhs (m, k), acc (m)
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (m, k)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_3 = {  // iteration space (m, k); no explicit kind, so the default (add) applies
+  indexing_maps = #matvec_accesses_3,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_k_mk_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -365,17 +417,38 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
   // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (m, k)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // mask follows the (m, k) iteration space: [4]x2
+    vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_4 = [  // masked_matvec_k_km_m*: lhs (k), rhs (k, m), acc (m)
+  affine_map<(m, k) -> (k)>,
+  affine_map<(m, k) -> (k, m)>,
+  affine_map<(m, k) -> (m)>
+]
+#matvec_trait_4 = {  // iteration space (m, k); no explicit kind, so the default (add) applies
+  indexing_maps = #matvec_accesses_4,
+  iterator_types = ["parallel", "reduction"]
+}
+
 // CHECK-LABEL: @masked_matvec_k_km_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -386,17 +459,38 @@ func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
   // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(m, k) -> (k)>,
-                       affine_map<(m, k) -> (k, m)>,
-                       affine_map<(m, k) -> (m)>],
-      iterator_types = ["parallel", "reduction"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // mask follows the (m, k) iteration space ([4]x2), not the 2x[4] operand layout
+    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<[4]x2xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_5 = [  // masked_tmatvec_mk_k_m*: lhs (m, k), rhs (k), acc (m)
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {  // iteration space (k, m); no explicit kind, so the default (add) applies
+  indexing_maps = #matvec_accesses_5,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_mk_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -407,17 +501,38 @@ func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (m, k)>,
-                       affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+      : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // mask follows the (k, m) iteration space (2x[4]), not the [4]x2 operand layout
+    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+      : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_6 = [  // masked_tmatvec_km_k_m*: lhs (k, m), rhs (k), acc (m)
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {  // iteration space (k, m); no explicit kind, so the default (add) applies
+  indexing_maps = #matvec_accesses_6,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_km_k_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -428,17 +543,38 @@ func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k, m)>,
-                       affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+      : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // mask follows the (k, m) iteration space: 2x[4]
+    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+      : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_7 = [  // masked_tmatvec_k_mk_m*: lhs (k), rhs (m, k), acc (m)
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {  // iteration space (k, m); no explicit kind, so the default (add) applies
+  indexing_maps = #matvec_accesses_7,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_k_mk_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -449,17 +585,38 @@ func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (m, k)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // mask follows the (k, m) iteration space (2x[4]), not the [4]x2 operand layout
+    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
+#matvec_accesses_8 = [  // masked_tmatvec_k_km_m*: lhs (k), rhs (k, m), acc (m)
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_8 = {  // iteration space (k, m); no explicit kind, so the default (add) applies
+  indexing_maps = #matvec_accesses_8,
+  iterator_types = ["reduction", "parallel"]
+}
+
 // CHECK-LABEL: @masked_tmatvec_k_km_m
 // CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
 // CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
@@ -470,17 +627,28 @@ func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract {
-      indexing_maps = [affine_map<(k, m) -> (k)>,
-                       affine_map<(k, m) -> (k, m)>,
-                       affine_map<(k, m) -> (m)>],
-      iterator_types = ["reduction", "parallel"],
-      kind = #vector.kind<add>
-    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
+// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+  %res = vector.mask %mask {  // mask follows the (k, m) iteration space: 2x[4]
+    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+      : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+  } : vector<2x[4]xi1> -> vector<[4]xf32>
+  return %res : vector<[4]xf32>
+}
+
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {



More information about the Mlir-commits mailing list