[Mlir-commits] [mlir] [mlir][vector] proper masking support for contract lowering (PR #67145)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Sep 25 02:56:43 PDT 2023


https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/67145

>From a8f17e70037abbfa0727d40e81619475ed37ef75 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Fri, 22 Sep 2023 14:06:47 +0000
Subject: [PATCH 1/2] [mlir][vector] proper masking support for contract
 lowering

Support all known permutations when lowering masked vector.contract to
vector.outerproduct, and not just the canonical permutation.
---
 .../Vector/Transforms/LowerVectorContract.cpp |  43 +++--
 ...r-contract-to-outerproduct-transforms.mlir | 168 ++++++++++++++++++
 2 files changed, 193 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 64ab0abda26e640..de7aba6e84b748b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -457,38 +457,44 @@ struct UnrolledOuterProductGenerator
     // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
     bindDims(rewriter.getContext(), m, n, k);
+    Value transposedMask = t(mask, {2, 0, 1});
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
-                       t(mask, {2, 0, 1}));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
       Value tlhs = t(lhs);
-      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
+      return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+                       transposedMask);
     }
     // No need to permute anything.
     if (layout({{k, m}, {k, n}, {m, n}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
     // Just permute the rhs.
     if (layout({{k, m}, {n, k}, {m, n}}))
-      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
+      return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
     // Transposed output: swap RHS and LHS.
     // Classical row-major matmul: permute the lhs.
     if (layout({{m, k}, {k, n}, {n, m}}))
-      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
+      return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {n, m}})) {
       Value trhs = t(rhs);
-      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
+      return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+                       transposedMask);
     }
     if (layout({{k, m}, {k, n}, {n, m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
     if (layout({{k, m}, {n, k}, {n, m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
     return failure();
   }
 
-  /// One outer parallel, one inner reduction (matvec flavor)
+  //
+  // One outer parallel, one inner reduction (matvec flavor).
+  // The mask must be transposed in every case so that the reduction dimension
+  // becomes outermost, as required by outerproduct.
+  //
   FailureOr<Value> matvec() {
     if (!iters({Par(), Red()}))
       return failure();
@@ -500,18 +506,19 @@ struct UnrolledOuterProductGenerator
       return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), t(mask));
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), t(mask));
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), t(mask));
     return failure();
   }
 
   //
-  // One outer reduction, one inner parallel (tmatvec flavor)
+  // One outer reduction, one inner parallel (tmatvec flavor).
+  // Mask already has the shape of the outer product.
   //
   FailureOr<Value> tmatvec() {
     if (!iters({Red(), Par()}))
@@ -521,16 +528,16 @@ struct UnrolledOuterProductGenerator
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index aee4cf6b1379bd0..1b4428431a92ba3 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -341,6 +341,174 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
   return %0 : vector<3x2xf32>
 }
 
+// CHECK-LABEL: @masked_matvec_mk_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (m, k)>,
+                       affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k, m)>,
+                       affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (m, k)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(m, k) -> (k)>,
+                       affine_map<(m, k) -> (k, m)>,
+                       affine_map<(m, k) -> (m)>],
+      iterator_types = ["parallel", "reduction"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<4x2xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (m, k)>,
+                       affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k, m)>,
+                       affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK:         vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (m, k)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m
+// CHECK-SAME:  %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME:  %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME:  %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+  // CHECK-NOT:     vector.transpose %[[MAT]]
+  // CHECK-NOT:     vector.transpose %[[MASK]]
+  // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+  %res = vector.mask %mask {
+    vector.contract {
+      indexing_maps = [affine_map<(k, m) -> (k)>,
+                       affine_map<(k, m) -> (k, m)>,
+                       affine_map<(k, m) -> (m)>],
+      iterator_types = ["reduction", "parallel"],
+      kind = #vector.kind<add>
+    } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+  } : vector<2x4xi1> -> vector<4xf32>
+  return %res : vector<4xf32>
+}
+
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):

>From d259722f8355feacd25fa48eb9f79512c4c3563b Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Mon, 25 Sep 2023 09:55:16 +0000
Subject: [PATCH 2/2] fix implementation-defined behavior mismatch

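Calling t() constructs IR, so it must not appear in more than one
argument of the same call: C++ leaves the evaluation order of function
arguments unspecified, and in practice it differs between Windows and
Linux builds, which changes the order of the generated operations.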
---
 .../Dialect/Vector/Transforms/LowerVectorContract.cpp    | 9 +++++----
 .../vector-contract-to-outerproduct-transforms.mlir      | 8 ++++----
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index de7aba6e84b748b..fb674690a0d3a54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -500,19 +500,20 @@ struct UnrolledOuterProductGenerator
       return failure();
     AffineExpr m, k;
     bindDims(rewriter.getContext(), m, k);
+    Value transposedMask = t(mask);
 
     // Case mat-vec: transpose.
     if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
-      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), t(mask));
+      return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
     // Case vec-mat: swap and transpose.
     if (layout({{k}, {m, k}, {m}}))
-      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), t(mask));
+      return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
     // Case vec-mat-trans: swap and ready to go.
     if (layout({{k}, {k, m}, {m}}))
-      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), t(mask));
+      return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
     return failure();
   }
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 1b4428431a92ba3..982de77786745b6 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -347,8 +347,8 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
 // CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
 func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
-  // CHECK:         vector.transpose %[[MAT]]
   // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
     vector.contract {
@@ -368,8 +368,8 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
 // CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
 func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
-  // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
     vector.contract {
@@ -389,8 +389,8 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
 // CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
 func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
-  // CHECK:         vector.transpose %[[MAT]]
   // CHECK:         vector.transpose %[[MASK]]
+  // CHECK:         vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
     vector.contract {
@@ -410,8 +410,8 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
 // CHECK-SAME:  %[[INIT:.+]]: vector<4xf32>
 // CHECK-SAME:  %[[MASK:.+]]: vector<4x2xi1>
 func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
-  // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK:         vector.transpose %[[MASK]]
+  // CHECK-NOT:     vector.transpose %[[MAT]]
   // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
   %res = vector.mask %mask {
     vector.contract {



More information about the Mlir-commits mailing list