[Mlir-commits] [mlir] [mlir][Vector] Add support for masks in castAwayContractionLeadingOneDim (PR #81906)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 15 11:15:48 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
Partial fix for #78787
---
Full diff: https://github.com/llvm/llvm-project/pull/81906.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+34-11)
- (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+74-30)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..f7f2b934056185 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -332,9 +332,12 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
RewriterBase &rewriter) {
- // TODO(#78787): Not supported masked op yet.
- if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
- return failure();
+ // Specifically for masked Ops for which we need to update the insertion
+ // point
+ PatternRewriter::InsertionGuard guard(rewriter);
+
+ auto isMasked =
+ cast<MaskableOpInterface>(contractOp.getOperation()).isMasked();
VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
if (oldAccType == nullptr)
return failure();
@@ -346,6 +349,12 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
// greedily to drop more.
int64_t dropDim = 1;
+ if (isMasked) {
+ // Update the insertion point to avoid adding more ops to the vector.mask
+ // region corresponding to `mask`
+ rewriter.setInsertionPointAfter(contractOp->getParentOp());
+ }
+
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
SmallVector<AffineMap> newIndexingMaps;
@@ -368,6 +377,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
contractOp.getAcc()};
SmallVector<Value> newOperands;
+ auto loc = contractOp.getLoc();
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
// Check if the dim to be dropped exists as a leading dim in the operand
@@ -405,7 +415,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
operands[it.index()] = rewriter.create<vector::TransposeOp>(
- contractOp.getLoc(), operands[it.index()], perm);
+ loc, operands[it.index()], perm);
}
}
// We have taken care to have the dim to be dropped be
@@ -429,17 +439,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
// Extract if its a valid extraction, otherwise use the operand
// without extraction.
newOperands.push_back(
- validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
- operands[it.index()],
- splatZero(dropDim))
+ validExtract ? rewriter.create<vector::ExtractOp>(
+ loc, operands[it.index()], splatZero(dropDim))
: operands[it.index()]);
}
- auto newContractOp = rewriter.create<vector::ContractionOp>(
- contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+ Operation *newContractOp = rewriter.create<vector::ContractionOp>(
+ loc, newOperands[0], newOperands[1], newOperands[2],
rewriter.getAffineMapArrayAttr(newIndexingMaps),
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- contractOp, contractOp->getResultTypes()[0], newContractOp);
+
+ if (isMasked) {
+ auto mask = contractOp.getMaskingOp();
+ auto newMask = rewriter.create<vector::ExtractOp>(loc, mask.getMask(),
+ splatZero(dropDim));
+
+ newContractOp =
+ mlir::vector::maskOperation(rewriter, newContractOp, newMask);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ mask, contractOp->getResultTypes()[0], newContractOp->getResults()[0]);
+ } else {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ contractOp, contractOp->getResultTypes()[0],
+ newContractOp->getResults()[0]);
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..4ba51c5953d13c 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
}
// -----
+// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask
+// CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
+// CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK: return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+ affine_map<(l, i, j, k) -> (l, i, k)>,
+ affine_map<(l, i, j, k) -> (l, k, j)>,
+ affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+ indexing_maps = #contraction_accesses0,
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+ %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
+ %0 = vector.mask %mask {
+ vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+ } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+ return %0 : vector<1x16x16xf32>
+}
+
+// -----
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask
+// CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
+// CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] {
+// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+ affine_map<(l, i, j, k) -> (l, i, k)>,
+ affine_map<(l, i, j, k) -> (l, k, j)>,
+ affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+ indexing_maps = #contraction_accesses0,
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_mask(
+ %arg0: vector<1x16x8xf32>,
+ %arg1: vector<1x8x16xf32>,
+ %arg2: vector<1x16x16xf32>,
+ %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
+ %0 = vector.mask %mask {
+ vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+ } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+ return %0: vector<1x16x16xf32>
+}
+
+// -----
+
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
return %0: vector<1x1x2x16xf32>
}
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-
-// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
-// CHECK: %[[MASK:.+]] = vector.constant_mask
-// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
-// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
-// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
-// CHECK: return %[[RET]] : vector<1x16x16xf32>
-
-#contraction_accesses0 = [
- affine_map<(l, i, j, k) -> (l, i, k)>,
- affine_map<(l, i, j, k) -> (l, k, j)>,
- affine_map<(l, i, j, k) -> (l, i, j)>
-]
-#contraction_trait0 = {
- indexing_maps = #contraction_accesses0,
- iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-}
-
-func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
- %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
- %0 = vector.mask %mask {
- vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
- } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
- return %0 : vector<1x16x16xf32>
-}
// -----
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
``````````
</details>
https://github.com/llvm/llvm-project/pull/81906
More information about the Mlir-commits mailing list