[Mlir-commits] [mlir] 82ab0f7 - [mlir][linalg] Fix rank-reduced cases for extract/insert slice in DropUnitDims (#74723)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Dec 16 07:08:55 PST 2023
Author: Quinn Dawkins
Date: 2023-12-16T10:08:51-05:00
New Revision: 82ab0f7f36222a0311b5220df52f4193664569e8
URL: https://github.com/llvm/llvm-project/commit/82ab0f7f36222a0311b5220df52f4193664569e8
DIFF: https://github.com/llvm/llvm-project/commit/82ab0f7f36222a0311b5220df52f4193664569e8.diff
LOG: [mlir][linalg] Fix rank-reduced cases for extract/insert slice in DropUnitDims (#74723)
Inferring the reshape reassociation indices for extract/insert slice ops
based on the read sizes of the original slicing op will generate an
invalid expand/collapse shape op for already rank-reduced cases. Instead,
just infer from the shape of the slice.
Ported from Differential Revision: https://reviews.llvm.org/D147488
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 6fbf3514557871..c495956fa57702 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -572,13 +572,17 @@ struct RankReducedExtractSliceOp
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
RankedTensorType resultType = sliceOp.getType();
- SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
- auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+ SmallVector<OpFoldResult> targetShape;
+ for (auto size : resultType.getShape())
+ targetShape.push_back(rewriter.getIndexAttr(size));
+ auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
if (!reassociation ||
reassociation->size() == static_cast<size_t>(resultType.getRank()))
return failure();
+
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
auto rankReducedType = cast<RankedTensorType>(
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
@@ -602,13 +606,14 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
RankedTensorType sourceType = insertSliceOp.getSourceType();
- SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
- SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
- auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+ SmallVector<OpFoldResult> targetShape;
+ for (auto size : sourceType.getShape())
+ targetShape.push_back(rewriter.getIndexAttr(size));
+ auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
if (!reassociation ||
reassociation->size() == static_cast<size_t>(sourceType.getRank()))
return failure();
+
Location loc = insertSliceOp.getLoc();
tensor::CollapseShapeOp reshapedSource;
{
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 795e9ee5287173..0c51a032df9016 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -489,6 +489,18 @@ func.func @slice_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
// -----
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> {
+ %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>
+ return %0 : tensor<1x3x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_extract_slice
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice
+// CHECK-SAME: tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> {
%0 = tensor.insert_slice %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32>
return %0 : tensor<1x3xf32>
@@ -501,6 +513,18 @@ func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>
// -----
+func.func @rank_reduced_insert_slice(%arg0: tensor<1x1x3x1x3xf32>, %arg1: tensor<1x3x3xf32>) -> tensor<1x1x3x1x3xf32> {
+ %0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32>
+ return %0 : tensor<1x1x3x1x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_insert_slice
+// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1], [2]]
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
+// CHECK-SAME: tensor<3x3xf32> into tensor<1x1x3x1x3xf32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
#accesses = [
affine_map<(i, j, k, l, m) -> (i, k, m)>,
affine_map<(i, j, k, l, m) -> ()>,
More information about the Mlir-commits
mailing list