[Mlir-commits] [mlir] 9619a24 - [MLIR][Vector] Refactor tests for contract -> OP transforms (3/N) (#73447)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 29 06:24:47 PST 2023


Author: Andrzej Warzyński
Date: 2023-11-29T14:24:42Z
New Revision: 9619a2420eac885060bae6c45e85f85abfc7d6a9

URL: https://github.com/llvm/llvm-project/commit/9619a2420eac885060bae6c45e85f85abfc7d6a9
DIFF: https://github.com/llvm/llvm-project/commit/9619a2420eac885060bae6c45e85f85abfc7d6a9.diff

LOG:     [MLIR][Vector] Refactor tests for contract -> OP transforms (3/N) (#73447)

This patch refactors tests for:

      vector.contract -> vector.outerproduct

for matvec operations (b += Ax). Summary of changes:
  * names of LIT variables are unified,
  * "plain" tests (i.e. without masking and with fixed-width vectors)
    are moved to the top of their respective sections,
  * missing "plain" cases are added.

This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect. I am refactoring these tests so that it's
easier to identify what cases are tested and where to add tests for
scalable vectors.

Implements #72834.

Added: 
    

Modified: 
    mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
index 7456e122e946f7d..e84a43feaff39dc 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -1,5 +1,17 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
+/// Tests for `vector.contract` -> `vector.outerproduct` transformations for
+/// Matvec operations:
+///   b += A * x.
+/// (b and x are 1-d vectors, A is a 2-d matrix). ATM three different variants
+/// are tested:
+///   * plain (no mask, fixed-width vectors),
+///   * masked (fixed-width vectors),
+///   * scalable (mask + scalable vectors).
+///
+/// TODO: These tests were extracted from 2 different files. If you find the
+/// formatting inconsistent, please update accordingly.
+
 #matvec_accesses_1 = [
   affine_map<(m, k) -> (m, k)>,
   affine_map<(m, k) -> (k)>,
@@ -46,6 +58,36 @@
   iterator_types = ["parallel", "reduction"]
 }
 
+#matvec_accesses_5 = [
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+  indexing_maps = #matvec_accesses_5,
+  iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_6 = [
+  affine_map<(k, m) -> (k, m)>,
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+  indexing_maps = #matvec_accesses_6,
+  iterator_types = ["reduction", "parallel"]
+}
+
+#matvec_accesses_7 = [
+  affine_map<(k, m) -> (k)>,
+  affine_map<(k, m) -> (m, k)>,
+  affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+  indexing_maps = #matvec_accesses_7,
+  iterator_types = ["reduction", "parallel"]
+}
+
 #matvec_accesses_8 = [
   affine_map<(k, m) -> (k)>,
   affine_map<(k, m) -> (k, m)>,
@@ -59,6 +101,24 @@
 // ============================================================================
 //  Matvec 1 (plain + masked + scalable)
 // ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL:   func.func @masked_matvec_mk_k_m(
 // CHECK-SAME:      %{{.*}}: vector<2x3xf32>,
 // CHECK-SAME:      %{{.*}}: vector<3xf32>,
@@ -73,12 +133,11 @@
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
 // CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
-
-func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
-                                %arg1: vector<3xf32>,
-                                %arg2: vector<2xf32>,
+func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>,
+                                %x: vector<3xf32>,
+                                %b: vector<2xf32>,
                                 %m: vector<2x3xi1>) -> vector<2xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
           : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
   return %0 : vector<2xf32>
 }
@@ -97,46 +156,28 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
 
 // CHECK:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
 // CHECK:           vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
-                                                      %arg1: vector<3xf32>,
-                                                      %arg2: vector<[2]xf32>,
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>,
+                                                      %x: vector<3xf32>,
+                                                      %b: vector<[2]xf32>,
                                                       %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+  %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b
           : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
   return %0 : vector<[2]xf32>
 }
 
-// CHECK-LABEL: func @matvec_mk_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
 // ============================================================================
 //  Matvec 1  - max (plain)
 // ============================================================================
 // CHECK-LABEL: func @matvec_mk_k_m_max
 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
 // CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
 func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
                              %x: vector<2xf32>,
@@ -149,38 +190,38 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
 //  Matvec 2 (plain + masked + scalable)
 // ============================================================================
 // CHECK-LABEL: @masked_matvec_km_k_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
-                                %arg1: vector<2xf32>,
-                                %arg2: vector<4xf32>, 
+func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>,
+                                %x: vector<2xf32>,
+                                %b: vector<4xf32>, 
                                 %mask: vector<4x2xi1>) -> vector<4xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+    vector.contract #matvec_trait_2 %A, %x, %b
       : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
-                                                      %arg1: vector<2xf32>,
-                                                      %arg2: vector<[4]xf32>,
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+                                                      %x: vector<2xf32>,
+                                                      %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+    vector.contract #matvec_trait_2 %A, %x, %b
       : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<[4]x2xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
@@ -188,13 +229,13 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
 
 // CHECK-LABEL: func @matvec_km_k_m
 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
 // CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 func.func @matvec_km_k_m(%A: vector<2x2xf32>,
                          %x: vector<2xf32>,
@@ -207,54 +248,53 @@ func.func @matvec_km_k_m(%A: vector<2x2xf32>,
 //  Matvec 3 (plain + masked + scalable)
 // ============================================================================
 // CHECK-LABEL: @masked_matvec_k_mk_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
-                                %arg1: vector<2xf32>,
-                                %arg2: vector<4xf32>,
+func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>,
+                                %x: vector<2xf32>,
+                                %b: vector<4xf32>,
                                 %mask: vector<4x2xi1>) -> vector<4xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+      vector.contract #matvec_trait_3 %x, %A, %b
         : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
-                                                      %arg1: vector<2xf32>,
-                                                      %arg2: vector<[4]xf32>,
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
+                                                      %x: vector<2xf32>,
+                                                      %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-      vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+      vector.contract #matvec_trait_3 %x, %A, %b
         : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<[4]x2xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
 }
 
-
 // CHECK-LABEL: func @matvec_k_mk_m
 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
 // CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
 // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
 // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 func.func @matvec_k_mk_m(%A: vector<2x2xf32>, 
                          %x: vector<2xf32>,
@@ -266,191 +306,232 @@ func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
 // ============================================================================
 //  Matvec 4 (plain + masked + scalable)
 // ============================================================================
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+                         %x: vector<2xf32>,
+                         %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
-                                                      %arg1: vector<2xf32>,
-                                                      %arg2: vector<[4]xf32>,
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+                                                      %x: vector<2xf32>,
+                                                      %b: vector<[4]xf32>,
                                                       %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+    vector.contract #matvec_trait_4 %x, %A, %b
       : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<[4]x2xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
 }
 
 // CHECK-LABEL: @masked_matvec_k_km_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
-                                %arg1: vector<2xf32>,
-                                %arg2: vector<4xf32>,
+func.func @masked_matvec_k_km_m(%A: vector<2x4xf32>,
+                                %x: vector<2xf32>,
+                                %b: vector<4xf32>,
                                 %mask: vector<4x2xi1>) -> vector<4xf32> {
   // CHECK:         vector.transpose %[[MASK]]
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+    vector.contract #matvec_trait_4 %x, %A, %b
       : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
   } : vector<4x2xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
-// CHECK-LABEL: func @matvec_k_km_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_km_m(%A: vector<2x2xf32>,
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
-  %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
-  return %0 : vector<2xf32>
-}
-
 // ============================================================================
-//  Matvec 5 (masked + scalable)
+//  Matvec 5 (plain + masked + scalable)
 // ============================================================================
-#matvec_accesses_5 = [
-  affine_map<(k, m) -> (m, k)>,
-  affine_map<(k, m) -> (k)>,
-  affine_map<(k, m) -> (m)>
-]
-#matvec_trait_5 = {
-  indexing_maps = #matvec_accesses_5,
-  iterator_types = ["reduction", "parallel"]
+// CHECK-LABEL:   func.func @tmatvec_mk_k_m(
+// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
+// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
+// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
+// CHECK:           %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK:           %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @tmatvec_mk_k_m(%A: vector<2x2xf32>,
+                          %x: vector<2xf32>,
+                          %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_5 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
 }
 
 // CHECK-LABEL: @masked_tmatvec_mk_k_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
-func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
-  // CHECK:         vector.transpose %[[MAT]]
+func.func @masked_tmatvec_mk_k_m(%A: vector<4x2xf32>,
+                                 %x: vector<2xf32>,
+                                 %b: vector<4xf32>,
+                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[A]]
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+    vector.contract #matvec_trait_5 %A, %x, %b
       : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
-func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
-  // CHECK:         vector.transpose %[[MAT]]
+func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
+                                                       %x: vector<2xf32>,
+                                                       %b: vector<[4]xf32>,
+                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[A]]
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+    vector.contract #matvec_trait_5 %A, %x, %b
       : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<2x[4]xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
 }
 
 // ============================================================================
-//  Matvec 6 (masked + scalable)
+//  Matvec 6 (plain + masked + scalable)
 // ============================================================================
-#matvec_accesses_6 = [
-  affine_map<(k, m) -> (k, m)>,
-  affine_map<(k, m) -> (k)>,
-  affine_map<(k, m) -> (m)>
-]
-#matvec_trait_6 = {
-  indexing_maps = #matvec_accesses_6,
-  iterator_types = ["reduction", "parallel"]
+// CHECK-LABEL:   func.func @tmatvec_km_k_m(
+// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
+// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
+// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
+// CHECK:           %[[VAL_3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.outerproduct %[[VAL_3]], %[[VAL_4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK:           %[[VAL_6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.outerproduct %[[VAL_6]], %[[VAL_7]], %[[VAL_5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @tmatvec_km_k_m(%A: vector<2x2xf32>,
+                          %x: vector<2xf32>,
+                          %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_6 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
 }
 
 // CHECK-LABEL: @masked_tmatvec_km_k_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
-func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+func.func @masked_tmatvec_km_k_m(%A: vector<2x4xf32>,
+                                 %x: vector<2xf32>,
+                                 %b: vector<4xf32>,
+                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+    vector.contract #matvec_trait_6 %A, %x, %b
       : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
-func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+                                                       %x: vector<2xf32>,
+                                                       %b: vector<[4]xf32>,
+                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+    vector.contract #matvec_trait_6 %A, %x, %b
       : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<2x[4]xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
 }
 
 // ============================================================================
-//  Matvec 7 (masked + scalable)
+//  Matvec 7 (plain + masked + scalable)
 // ============================================================================
-#matvec_accesses_7 = [
-  affine_map<(k, m) -> (k)>,
-  affine_map<(k, m) -> (m, k)>,
-  affine_map<(k, m) -> (m)>
-]
-#matvec_trait_7 = {
-  indexing_maps = #matvec_accesses_7,
-  iterator_types = ["reduction", "parallel"]
+// CHECK-LABEL:   func.func @tmatvec_k_mk_m(
+// CHECK-SAME:      %[[A:.*]]: vector<2x2xf32>,
+// CHECK-SAME:      %[[X:.*]]: vector<2xf32>,
+// CHECK-SAME:      %[[B:.*]]: vector<2xf32>) -> vector<2xf32> {
+// CHECK:           %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK:           %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @tmatvec_k_mk_m(%A: vector<2x2xf32>,
+                          %x: vector<2xf32>,
+                          %b: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.contract #matvec_trait_7 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
 }
 
 // CHECK-LABEL: @masked_tmatvec_k_mk_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
-func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
-  // CHECK:         vector.transpose %[[MAT]]
+func.func @masked_tmatvec_k_mk_m(%A: vector<4x2xf32>,
+                                 %x: vector<2xf32>,
+                                 %b: vector<4xf32>,
+                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[A]]
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+    vector.contract #matvec_trait_7 %x, %A, %b
       : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<[4]x2xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
-func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
-  // CHECK:         vector.transpose %[[MAT]]
+func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>,
+                                                       %x: vector<2xf32>,
+                                                       %b: vector<[4]xf32>,
+                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK:         vector.transpose %[[A]]
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+    vector.contract #matvec_trait_7 %x, %A, %b
       : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<2x[4]xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>
@@ -459,50 +540,56 @@ func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
 // ============================================================================
 //  Matvec 8 (plain + masked + scalable)
 // ============================================================================
-// CHECK-LABEL: func @matvec_m_mk_k
+// CHECK-LABEL: func @tmatvec_m_mk_k
 // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK-SAME: %[[X:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[B:.*2]]: vector<2xf32>
 // CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32
 // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32>
 // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
-                         %x: vector<2xf32>,
-                         %b: vector<2xf32>) -> vector<2xf32> {
+func.func @tmatvec_m_mk_k(%A: vector<2x2xf32>,
+                          %x: vector<2xf32>,
+                          %b: vector<2xf32>) -> vector<2xf32> {
   %0 = vector.contract #matvec_trait_8 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
 
 // CHECK-LABEL: @masked_tmatvec_k_km_m
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
-func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+func.func @masked_tmatvec_k_km_m(%A: vector<2x4xf32>,
+                                 %x: vector<2xf32>,
+                                 %b: vector<4xf32>,
+                                 %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+    vector.contract #matvec_trait_8 %x, %A, %b
       : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
   } : vector<2x4xi1> -> vector<4xf32>
   return %res : vector<4xf32>
 }
 
 // CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
-// CHECK-SAME:  %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME:  %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME:  %[[A:.+]]: vector<2x[4]xf32>
+// CHECK-SAME:  %[[X:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[B:.+]]: vector<[4]xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<2x[4]xi1>
-func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
-  // CHECK-NOT:     vector.transpose %[[MAT]]
+func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>,
+                                                       %x: vector<2xf32>,
+                                                       %b: vector<[4]xf32>,
+                                                       %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+  // CHECK-NOT:     vector.transpose %[[A]]
   // CHECK-NOT:     vector.transpose %[[MASK]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
   %res = vector.mask %mask {
-    vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+    vector.contract #matvec_trait_8 %x, %A, %b
       : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
   } : vector<2x[4]xi1> -> vector<[4]xf32>
   return %res : vector<[4]xf32>


        


More information about the Mlir-commits mailing list