[Mlir-commits] [mlir] [mlir][vector] proper masking support for contract lowering (PR #67145)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 22 07:10:36 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
<details>
<summary>Changes</summary>
Support all known permutations when lowering a masked `vector.contract` to `vector.outerproduct`, rather than only the canonical permutation.
---
Full diff: https://github.com/llvm/llvm-project/pull/67145.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+25-18)
- (modified) mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir (+168)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 64ab0abda26e640..de7aba6e84b748b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -457,38 +457,44 @@ struct UnrolledOuterProductGenerator
// Set up the parallel/reduction structure in the right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
+ Value transposedMask = t(mask, {2, 0, 1});
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
- t(mask, {2, 0, 1}));
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = t(lhs);
- return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
+ return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
+ transposedMask);
}
// No need to permute anything.
if (layout({{k, m}, {k, n}, {m, n}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
// Just permute the rhs.
if (layout({{k, m}, {n, k}, {m, n}}))
- return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
+ return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (layout({{m, k}, {k, n}, {n, m}}))
- return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
+ return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {n, m}})) {
Value trhs = t(rhs);
- return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
+ return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
+ transposedMask);
}
if (layout({{k, m}, {k, n}, {n, m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
if (layout({{k, m}, {n, k}, {n, m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
return failure();
}
- /// One outer parallel, one inner reduction (matvec flavor)
+ //
+ // One outer parallel, one inner reduction (matvec flavor).
+ // Mask needs to be transposed everywhere to turn the reduction dimension
+ // outermost as required by outerproduct.
+ //
FailureOr<Value> matvec() {
if (!iters({Par(), Red()}))
return failure();
@@ -500,18 +506,19 @@ struct UnrolledOuterProductGenerator
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0), t(mask));
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), t(mask));
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0), t(mask));
return failure();
}
//
- // One outer reduction, one inner parallel (tmatvec flavor)
+ // One outer reduction, one inner parallel (tmatvec flavor).
+ // Mask already has the shape of the outer product.
//
FailureOr<Value> tmatvec() {
if (!iters({Red(), Par()}))
@@ -521,16 +528,16 @@ struct UnrolledOuterProductGenerator
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
- return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
+ return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
// Case vec-mat: swap and transpose.
if (layout({{k}, {m, k}, {m}}))
- return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
+ return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
// Case vec-mat-trans: swap and ready to go.
if (layout({{k}, {k, m}, {m}}))
- return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
+ return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
return failure();
}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index aee4cf6b1379bd0..1b4428431a92ba3 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -341,6 +341,174 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<3x2xf32>
}
+// CHECK-LABEL: @masked_matvec_mk_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (m)>],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (m)>],
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<add>
+ } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>],
+ iterator_types = ["reduction", "parallel"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>],
+ iterator_types = ["reduction", "parallel"],
+ kind = #vector.kind<add>
+ } %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (m)>],
+ iterator_types = ["reduction", "parallel"],
+ kind = #vector.kind<add>
+ } %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract {
+ indexing_maps = [affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (m)>],
+ iterator_types = ["reduction", "parallel"],
+ kind = #vector.kind<add>
+ } %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
``````````
</details>
https://github.com/llvm/llvm-project/pull/67145
More information about the Mlir-commits mailing list