[Mlir-commits] [mlir] [mlir][linalg] Fix rank-reduced cases for extract/insert slice in DropUnitDims (PR #74723)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 7 06:55:53 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

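Inferring the reshape reassociation indices for extract/insert slice ops from the sizes of the original slicing op generates an invalid expand/collapse shape op when the slice is already rank-reduced. The reason is that the sizes list has the rank of the full (unreduced) slice, while the sliced tensor type has fewer dimensions. For example, sizes `[1, 1, 3, 1, 3]` versus `tensor<1x3x3xf32>`. Instead, infer the reassociation from the shape of the slice itself: the result type for `tensor.extract_slice`, and the source type for `tensor.insert_slice`.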

Ported from Differential Revision: https://reviews.llvm.org/D147488

---
Full diff: https://github.com/llvm/llvm-project/pull/74723.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+13-8) 
- (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+24) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 6fbf351455787..c495956fa5770 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -572,13 +572,17 @@ struct RankReducedExtractSliceOp
   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
     RankedTensorType resultType = sliceOp.getType();
-    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
-    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    SmallVector<OpFoldResult> targetShape;
+    for (auto size : resultType.getShape())
+      targetShape.push_back(rewriter.getIndexAttr(size));
+    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(resultType.getRank()))
       return failure();
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
     auto rankReducedType = cast<RankedTensorType>(
         tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
             reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
@@ -602,13 +606,14 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
   LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     RankedTensorType sourceType = insertSliceOp.getSourceType();
-    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
-    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    SmallVector<OpFoldResult> targetShape;
+    for (auto size : sourceType.getShape())
+      targetShape.push_back(rewriter.getIndexAttr(size));
+    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
+
     Location loc = insertSliceOp.getLoc();
     tensor::CollapseShapeOp reshapedSource;
     {
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 795e9ee528717..0c51a032df901 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -489,6 +489,18 @@ func.func @slice_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
 
 // -----
 
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> {
+  %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>
+  return %0 : tensor<1x3x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_extract_slice
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice
+//  CHECK-SAME:     tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
+//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> {
   %0 = tensor.insert_slice %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32>
   return %0 : tensor<1x3xf32>
@@ -501,6 +513,18 @@ func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>
 
 // -----
 
+func.func @rank_reduced_insert_slice(%arg0: tensor<1x1x3x1x3xf32>, %arg1: tensor<1x3x3xf32>) -> tensor<1x1x3x1x3xf32> {
+  %0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32>
+  return %0 : tensor<1x1x3x1x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_insert_slice
+//       CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1], [2]]
+//       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
+//  CHECK-SAME:     tensor<3x3xf32> into tensor<1x1x3x1x3xf32>
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 #accesses = [
   affine_map<(i, j, k, l, m) -> (i, k, m)>,
   affine_map<(i, j, k, l, m) -> ()>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/74723


More information about the Mlir-commits mailing list