[Mlir-commits] [mlir] [mlir][linalg] Fix rank-reduced cases for extract/insert slice in DropUnitDims (PR #74723)
llvmlistbot at llvm.org
Thu Dec 7 06:55:53 PST 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Quinn Dawkins (qedawkins)
Changes
Inferring the reshape reassociation indices for extract/insert slice ops from the read sizes of the original slicing op generates an invalid expand/collapse shape op when the slice is already rank-reduced. Instead, infer the reassociation from the shape of the slice.
Ported from Differential Revision: https://reviews.llvm.org/D147488
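To make the failure mode concrete, the snippet below reproduces the rank-reducing `tensor.extract_slice` from the new test case in this PR: its result type already drops two of the unit dimensions. The comments sketch the form the pattern now produces, matching the CHECK lines in the test; the exact textual syntax of the rewritten ops is illustrative.

```mlir
// The source is rank 5, but the slice result tensor<1x3x3xf32> already folds
// away two of the unit dimensions (a rank-reducing extract_slice).
func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> {
  %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1]
      : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>
  return %0 : tensor<1x3x3xf32>
}

// With the reassociation inferred from the rank-3 result shape, the pattern
// rewrites this into a fully rank-reduced slice plus an expand_shape, roughly:
//   %slice  = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1]
//             : tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
//   %result = tensor.expand_shape %slice [[0, 1], [2]]
//             : tensor<3x3xf32> into tensor<1x3x3xf32>
```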
---
Full diff: https://github.com/llvm/llvm-project/pull/74723.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+13-8)
- (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+24)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 6fbf351455787..c495956fa5770 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -572,13 +572,17 @@ struct RankReducedExtractSliceOp
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
RankedTensorType resultType = sliceOp.getType();
- SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
- auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+ SmallVector<OpFoldResult> targetShape;
+ for (auto size : resultType.getShape())
+ targetShape.push_back(rewriter.getIndexAttr(size));
+ auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
if (!reassociation ||
reassociation->size() == static_cast<size_t>(resultType.getRank()))
return failure();
+
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
@@ -602,13 +606,14 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
RankedTensorType sourceType = insertSliceOp.getSourceType();
- SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
- auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+ SmallVector<OpFoldResult> targetShape;
+ for (auto size : sourceType.getShape())
+ targetShape.push_back(rewriter.getIndexAttr(size));
+ auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
if (!reassociation ||
reassociation->size() == static_cast<size_t>(sourceType.getRank()))
return failure();
+
Location loc = insertSliceOp.getLoc();
tensor::CollapseShapeOp reshapedSource;
{
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 795e9ee528717..0c51a032df901 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -489,6 +489,18 @@ func.func @slice_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
// -----
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> {
+ %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>
+ return %0 : tensor<1x3x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_extract_slice
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice
+// CHECK-SAME: tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> {
%0 = tensor.insert_slice %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32>
return %0 : tensor<1x3xf32>
@@ -501,6 +513,18 @@ func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>
// -----
+func.func @rank_reduced_insert_slice(%arg0: tensor<1x1x3x1x3xf32>, %arg1: tensor<1x3x3xf32>) -> tensor<1x1x3x1x3xf32> {
+ %0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32>
+ return %0 : tensor<1x1x3x1x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_insert_slice
+// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1], [2]]
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
+// CHECK-SAME: tensor<3x3xf32> into tensor<1x1x3x1x3xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
#accesses = [
affine_map<(i, j, k, l, m) -> (i, k, m)>,
affine_map<(i, j, k, l, m) -> ()>,
``````````
https://github.com/llvm/llvm-project/pull/74723
More information about the Mlir-commits mailing list