[Mlir-commits] [mlir] [mlir][vector] proper masking support for contract lowering (PR #67145)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Mon Sep 25 02:56:43 PDT 2023
https://github.com/ftynse updated https://github.com/llvm/llvm-project/pull/67145
>From a8f17e70037abbfa0727d40e81619475ed37ef75 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Fri, 22 Sep 2023 14:06:47 +0000
Subject: [PATCH 1/2] [mlir][vector] proper masking support for contract
lowering
Support masking for all known operand layouts (permutations) when lowering a
masked vector.contract to vector.outerproduct. Previously, the mask was only
propagated for the canonical layouts (row-major matmul and mat-vec).
---
.../Vector/Transforms/LowerVectorContract.cpp | 43 +++--
...r-contract-to-outerproduct-transforms.mlir | 168 ++++++++++++++++++
2 files changed, 193 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 64ab0abda26e640..de7aba6e84b748b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -457,38 +457,44 @@ struct UnrolledOuterProductGenerator
// Set up the parallel/reduction structure in the right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
+ Value transposedMask = t(mask, {2, 0, 1});
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
- t(mask, {2, 0, 1}));
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
+ return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+ transposedMask);
}
// No need to permute anything.
if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
// Just permute the rhs.
if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
+ return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
+ return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) {
Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
+ return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+ transposedMask);
}
if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
return failure();
}
- /// One outer parallel, one inner reduction (matvec flavor)
+ //
+ // One outer parallel, one inner reduction (matvec flavor).
+ // Mask needs to be transposed everywhere to turn the reduction dimension
+ // outermost as required by outerproduct.
+ //
FailureOr<Value> matvec() {
if (!iters({Par(), Red()}))
return failure();
@@ -500,18 +506,19 @@ struct UnrolledOuterProductGenerator
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0), t(mask));
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), t(mask));
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0), t(mask));
return failure();
}
//
- // One outer reduction, one inner parallel (tmatvec flavor)
+ // One outer reduction, one inner parallel (tmatvec flavor).
+ // Mask already has the shape of the outer product.
//
FailureOr<Value> tmatvec() {
if (!iters({Red(), Par()}))
@@ -521,16 +528,16 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
return failure();
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index aee4cf6b1379bd0..1b4428431a92ba3 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -341,6 +341,174 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<3x2xf32>
}
+// CHECK-LABEL: @masked_matvec_mk_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (m)>],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (m)>],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>],
+ iterator_types = ["reduction", "parallel"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>],
+ iterator_types = ["reduction", "parallel"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (m)>],
+ iterator_types = ["reduction", "parallel"],
+ kind = #vector.kind<add>
+ } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (m)>],
+ iterator_types = ["reduction", "parallel"],
+ kind = #vector.kind<add>
+ } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
>From d259722f8355feacd25fa48eb9f79512c4c3563b Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Mon, 25 Sep 2023 09:55:16 +0000
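Subject: [PATCH 2/2] fix platform-dependent IR order from unspecified
 argument evaluation order
Calling t() constructs IR, so it must not be called in more than one argument
of the same function call. The evaluation order of function arguments is
unspecified in C++ and differed between Windows and Linux builds, which made
the order of the generated vector.transpose ops platform-dependent. Hoist the
mask transposition into a local variable so that it is always created first,
and update the test expectations accordingly.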
---
.../Dialect/Vector/Transforms/LowerVectorContract.cpp | 9 +++++----
.../vector-contract-to-outerproduct-transforms.mlir | 8 ++++----
2 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index de7aba6e84b748b..fb674690a0d3a54 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -500,19 +500,20 @@ struct UnrolledOuterProductGenerator
return failure();
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
+ Value transposedMask = t(mask);
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0), t(mask));
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), t(mask));
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0), t(mask));
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
return failure();
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 1b4428431a92ba3..982de77786745b6 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -347,8 +347,8 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MAT]]
// CHECK: vector.transpose %[[MASK]]
+ // CHECK: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
vector.contract {
@@ -368,8 +368,8 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK-NOT: vector.transpose %[[MAT]]
// CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
vector.contract {
@@ -389,8 +389,8 @@ func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %a
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MAT]]
// CHECK: vector.transpose %[[MASK]]
+ // CHECK: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
vector.contract {
@@ -410,8 +410,8 @@ func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %a
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK-NOT: vector.transpose %[[MAT]]
// CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
%res = vector.mask %mask {
vector.contract {
More information about the Mlir-commits
mailing list