[Mlir-commits] [mlir] 446981b - [mlir][tensor] ExtractSliceFromReshape: handle collapsing of unit dim edge cases

Christopher Bate llvmlistbot at llvm.org
Sat Oct 22 12:29:41 PDT 2022


Author: Christopher Bate
Date: 2022-10-22T13:29:34-06:00
New Revision: 446981bdb64d0ae24ac77b8ba07f3ee3808c3936

URL: https://github.com/llvm/llvm-project/commit/446981bdb64d0ae24ac77b8ba07f3ee3808c3936
DIFF: https://github.com/llvm/llvm-project/commit/446981bdb64d0ae24ac77b8ba07f3ee3808c3936.diff

LOG: [mlir][tensor] ExtractSliceFromReshape: handle collapsing of unit dim edge cases

Prior to this change, the "ExtractSliceFromReshape" pattern would transform

```
%collapsed = tensor.collapse_shape %input [[0, 1], [2]]
                : tensor<1x11x100xf32> into tensor<11x100xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1]
                : tensor<11x100xf32> to tensor<?x100xf32>
```

into a loop that iterated over the range `[0, %size)` and pieced together
multiple sub-slices of `%input` along the first dimension. This is correct
but obviously inefficient. The technical condition is that collapsing a group
of `%src` dimensions that contains at most one non-unit dimension will not
cause a subsequent slice along the corresponding dimension of `%collapsed` to
map across discontinuities in the index space of `%src`. Thus, the
definition of a "linearized dimension" (from the perspective of
`tensor.collapse_shape`) is updated to reflect this condition.

The transform will now generate

```
%slice = tensor.extract_slice %input [0, %offt, 0] [1, %size, 100] [1, 1, 1]
            : tensor<1x11x100xf32> to tensor<1x?x100xf32>
%result = tensor.collapse_shape %slice [[0, 1], [2]]
            : tensor<1x?x100xf32> into tensor<?x100xf32>
```

which can be further canonicalized.

Additional tests are added to check this family of edge cases.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D135726

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
    mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
    mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
index 96b7f99baf59f..13e38af8ae906 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
@@ -204,6 +204,66 @@ class ExtractSliceFromCollapseHelper {
   SmallVector<Value> tiledSizes;
 };
 
+/// Tries to simplify a `tensor.collapse_shape` operation by inserting a single
+/// rank-reducing `tensor.extract_slice` operation. The `extract_slice` op will
+/// either take the place of the source, allowing for a new, simpler
+/// `collapse_shape` op to replace `op`, or the `collapse_shape` op will be
+/// completely replaced by the `extract_slice` result. Either way, `op` is
+/// replaced and the new op is returned.
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///    : tensor<?x1x30x10xf32> to tensor<?x300xf32>
+/// ```
+/// can be transformed to
+///
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                         [%dim0, 1, 30, 10]
+///                         [1, 1, 1, 1]
+///   : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
+/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
+///   : tensor<?x30x10xf32> to tensor<?x300xf32>
+/// ```
+///
+/// ### Example:
+///
+/// ```
+/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
+///    : tensor<?x1x30xf32> to tensor<?x30xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %result = tensor.extract_slice %1 [0, 0, 0]
+///                                   [%dim2, 1, 30]
+///                                   [1, 1, 1]
+///    : tensor<?x1x30xf32> to tensor<?x30xf32>
+/// ```
+///
+/// ### Unsupported cases:
+///
+/// This transform doesn't yet support reducing the rank of the reassociation
+/// indices, which would require inserting a `tensor.expand_shape` op similar to
+/// the following example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///    : tensor<1x1x30x10xf32> to tensor<1x300xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                         [1, 1, 30, 10]
+///                         [1, 1, 1, 1]
+///   : tensor<1x1x30x10xf32> to tensor<30x10xf32>
+/// %result0 = tensor.collapse_shape %tmp [[0, 1]]
+///   : tensor<30x10xf32> to tensor<300xf32>
+/// %result1 = tensor.expand_shape %result0 [[0, 1]] : ... to tensor<1x300xf32>
+/// ```
+///
+FailureOr<Operation *>
+simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op,
+                                                  RewriterBase &rewriter);
 } // namespace tensor
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 1c584d2742011..dba055d9fd992 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -460,6 +460,58 @@ class SliceFromCollapseHelper {
   llvm::SmallBitVector linearizedDimensions;
   llvm::SmallBitVector slicedDimensions;
 };
+
+/// Parameters required to simplify a collapsing reshape op with a rank-reducing
+/// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
+struct CollapseShapeRankReducingSliceSimplificationInfo {
+  /// The shape of the output of the rank-reducing slice.
+  RankedTensorType sliceResultType;
+  /// The reassociation indices for the new collapse shape op, if required. If
+  /// `None`, the slice should replace the collapse shape op.
+  Optional<SmallVector<ReassociationIndices>> newReassociationIndices;
+};
+
+/// A collapsing reshape operation can sometimes be simplified or eliminated by
+/// inserting a single rank-reducing slice operation between it and the source
+/// tensor. The slice op will either take the place of the source, allowing for
+/// a new, simpler reshape op to replace the original, or the reshape op will be
+/// completely replaced by the slice result.
+///
+/// This function returns the parameters required to implement this pattern. If
+/// the pattern is not applicable, then failure is returned.
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+///    : tensor<?x1x30x10xf32> to tensor<?x300xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                         [%dim0, 1, 30, 10]
+///                         [1, 1, 1, 1]
+///   : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
+/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
+///   : tensor<?x30x10xf32> to tensor<?x300xf32>
+/// ```
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
+///    : tensor<?x1x30xf32> to tensor<?x30xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %result = tensor.extract_slice %1 [0, 0, 0]
+///                                   [%dim2, 1, 30]
+///                                   [1, 1, 1]
+///    : tensor<?x1x30xf32> to tensor<?x30xf32>
+/// ```
+FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
+getSimplifyCollapseShapeWithRankReducingSliceInfo(
+    RankedTensorType sourceType,
+    ArrayRef<ReassociationIndices> reassociationIndices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
index 98430da084d87..67c949c706c09 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -26,8 +26,8 @@ using namespace mlir;
 using namespace mlir::tensor;
 
 /// Get the dimension size of a value of RankedTensor type at the
-OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor,
-                             int64_t dimIdx) {
+static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
+                                    Value rankedTensor, int64_t dimIdx) {
   RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
   if (!tensorType.isDynamicDim(dimIdx)) {
     return b.getIndexAttr(tensorType.getDimSize(dimIdx));
@@ -103,6 +103,11 @@ FailureOr<ExtractSliceFromCollapseHelper>
 tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
                                                tensor::CollapseShapeOp op,
                                                ArrayRef<Range> sliceParams) {
+  // Don't perform this pattern if the collapse op can be simplified by
+  // a rank-reducing extract slice.
+  if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
+          op.getSrcType(), op.getReassociationIndices())))
+    return failure();
 
   // Materialize the output shape of the collapse_shape operation. This will
   // create IR describing the output shape in terms of the input shape.
@@ -125,9 +130,6 @@ tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
 
   auto collapseShapeInputShape = getShapeDimSizes(b, op.getLoc(), op.getSrc());
 
-  SmallVector<OpFoldResult> srcShape =
-      getShapeDimSizes(b, op->getLoc(), op.getSrc());
-
   SmallVector<Value> tileSizes;
   for (unsigned i = 0; i < sliceParams.size(); i++) {
     if (slicedDimensions[i] && linearizedDimensions[i])
@@ -178,3 +180,36 @@ tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
       loc, subTileResult, reassociationIndices);
   return std::make_pair(collapsedResult, insertParams);
 }
+
+FailureOr<Operation *>
+tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
+    tensor::CollapseShapeOp op, RewriterBase &rewriter) {
+  SmallVector<ReassociationIndices> reassociationIndices =
+      op.getReassociationIndices();
+  RankedTensorType sourceType = op.getSrcType();
+  FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
+      getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
+                                                        reassociationIndices);
+  if (failed(info))
+    return failure();
+
+  // Create the rank-reducing extract slice op.
+  auto zero = rewriter.getIndexAttr(0);
+  auto one = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
+  SmallVector<OpFoldResult> sizes =
+      getShapeDimSizes(rewriter, op.getLoc(), op.getSrc());
+  SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
+  auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
+      op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
+
+  if (!info->newReassociationIndices.has_value()) {
+    rewriter.replaceOp(op, sliceOp.getResult());
+    return sliceOp.getOperation();
+  }
+
+  return rewriter
+      .replaceOpWithNewOp<tensor::CollapseShapeOp>(
+          op, sliceOp.getResult(), info->newReassociationIndices.value())
+      .getOperation();
+}

diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 9bca50f643216..e31d069b99900 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -352,3 +352,99 @@ SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
   }
   return insertParams;
 }
+
+/// Returns the index of the only non-unit dimension among `indices` of `shape`,
+/// if such a dimension exists and `indices` has more than one element.
+/// Otherwise, return none.
+static Optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
+                                             ArrayRef<int64_t> shape) {
+  // Return None if more than one of the dimensions in this group is not 1.
+  Optional<int64_t> dimIndex = None;
+  if (indices.size() < 2)
+    return None;
+  for (int64_t idx : indices) {
+    if (shape[idx] != 1) {
+      if (dimIndex != None)
+        return None;
+      dimIndex = idx;
+    }
+  }
+  return dimIndex;
+}
+
+// For each segment in the reassociation indices, check whether we can
+// simplify that segment with a rank-reducing extract slice. We can do this if
+// the segment has two or more dims and exactly one of its source dims is not 1.
+static SmallVector<Optional<int64_t>> getCollapseShapeTrivialSegments(
+    RankedTensorType sourceType,
+    ArrayRef<ReassociationIndices> reassociationIndices) {
+  SmallVector<Optional<int64_t>> trivialSegments;
+  for (const auto &indices : reassociationIndices)
+    trivialSegments.push_back(
+        getUniqueNonUnitDim(indices, sourceType.getShape()));
+  return trivialSegments;
+}
+
+/// Returns the per-segment unique non-unit dims (see getUniqueNonUnitDim), or
+/// failure if no segment can be simplified by a rank-reducing slice.
+static FailureOr<SmallVector<Optional<int64_t>>>
+canCollapseShapeBeSimplifiedByRankReducingSlice(
+    RankedTensorType sourceType,
+    ArrayRef<ReassociationIndices> reassociationIndices) {
+  SmallVector<Optional<int64_t>> trivialSegments =
+      getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
+  if (!llvm::any_of(trivialSegments, [](const Optional<int64_t> &idx) {
+        return idx.has_value();
+      }))
+    return failure();
+  return trivialSegments;
+}
+
+FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
+mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
+    RankedTensorType sourceType,
+    ArrayRef<ReassociationIndices> reassociationIndices) {
+  FailureOr<SmallVector<Optional<int64_t>>> trivialSegments =
+      canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
+                                                      reassociationIndices);
+  if (failed(trivialSegments))
+    return failure();
+
+  // Create the expected result shape of the rank-reducing slice.
+  SmallVector<int64_t> sliceShape;
+  for (const auto &[nonUnitDim, indices] :
+       llvm::zip(*trivialSegments, reassociationIndices)) {
+    if (nonUnitDim) {
+      sliceShape.push_back(sourceType.getDimSize(nonUnitDim.value()));
+      continue;
+    }
+    llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
+                         return sourceType.getDimSize(idx);
+                       }));
+  }
+  auto sliceType =
+      RankedTensorType::get(sliceShape, sourceType.getElementType());
+
+  // If the rank-reducing slice simplified every segment, then we are done.
+  if (sliceShape.size() == reassociationIndices.size())
+    return CollapseShapeRankReducingSliceSimplificationInfo{sliceType, None};
+
+  // Otherwise, we need to create a new collapse_shape op for the segments that
+  // weren't covered by the slice. By design, the new reassociation indices have
+  // the same number of groups as the old reassociation indices.
+  SmallVector<ReassociationIndices> newReassociationIndices;
+  SmallVector<int64_t, 2> reassociation;
+  int64_t groupIdx = 0;
+  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
+    reassociation.push_back(dimIdx);
+    if ((*trivialSegments)[groupIdx] ||
+        reassociation.size() == reassociationIndices[groupIdx].size()) {
+      newReassociationIndices.push_back(reassociation);
+      reassociation.clear();
+      groupIdx++;
+    }
+  }
+
+  return CollapseShapeRankReducingSliceSimplificationInfo{
+      sliceType, newReassociationIndices};
+}

diff  --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
index 9a022787c464f..ccbba9013ab29 100644
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -177,3 +177,65 @@ func.func @no_sliced_linearized_dims(%input: tensor<30x11x100xf32>, %offt: index
   // CHECK: return %[[res]]
   return %slice : tensor<330x?xf32>
 }
+
+// -----
+
+// The below tests verify that a dimension which is the result of collapsing at
+// most one non-unit dim is handled properly.
+
+// CHECK: @collapse_and_slice_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_unit_dim(%input: tensor<1x11x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<1x11x100xf32> into tensor<11x100xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<11x100xf32> to tensor<?x100xf32>
+  // CHECK-NOT: scf.for
+  // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0] [1, 11, 100] [1, 1, 1]
+  // CHECK-SAME:           tensor<1x11x100xf32> to tensor<11x100xf32>
+  // CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
+  // CHECK-SAME:           tensor<11x100xf32> to tensor<?x100xf32>    
+  return %slice : tensor<?x100xf32>
+}
+
+// CHECK: @collapse_and_slice_multiple_unit_dim_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_multiple_unit_dim_dynamic(%input: tensor<1x?x1x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x?x1x100xf32> into tensor<?x100xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<?x100xf32> to tensor<?x100xf32>
+  // CHECK-NOT: scf.for
+  // CHECK: %[[c1:.+]] = arith.constant 1 : index
+  // CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] : 
+  // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0] [1, %[[dim]], 1, 100] [1, 1, 1, 1]
+  // CHECK-SAME:           tensor<1x?x1x100xf32> to tensor<?x100xf32>
+  // CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
+  // CHECK-SAME:           tensor<?x100xf32> to tensor<?x100xf32>  
+  return %slice : tensor<?x100xf32>
+}
+
+// CHECK: @collapse_and_slice_multiple_unit_dim_mixed(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_multiple_unit_dim_mixed(%input: tensor<1x?x1x100x10xf32>, %offt: index, %size: index) -> tensor<?x?xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<1x?x1x100x10xf32> into tensor<?x1000xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, %offt] [%size, %size] [1, 1] : tensor<?x1000xf32> to tensor<?x?xf32>
+  // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+  // CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]]
+  // CHECK: %[[rank_reduced:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0, 0] [1, %[[dim]], 1, 100, 10] [1, 1, 1, 1, 1]
+  // CHECK: %[[empty:.+]] = tensor.empty
+  // CHECK: %[[result:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[arg2]] step %[[c1]] iter_args(%[[ia:.+]] = %[[empty]])
+  // CHECK:     %[[idx:.+]] = affine.apply
+  // CHECK:     %[[multi_index:.+]] = affine.delinearize_index %[[idx]] into
+  // CHECK:     %[[collapsed:.+]] = tensor.collapse_shape
+  // CHECK:     %[[updated:.+]] = tensor.insert_slice
+  // CHECK:     scf.yield %[[updated]]
+  // CHECK: return %[[result]]
+  return %slice : tensor<?x?xf32>
+}
+
+// Edge case where all collapsed dims are unit dims. This pattern can't eliminate
+// the collapse shape; that should be handled by `linalg-fold-unit-extent-dims`.
+
+// CHECK: @collapse_and_slice_multiple_all_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>)
+func.func @collapse_and_slice_multiple_all_unit_dim(%input: tensor<1x1x1x100xf32>) -> tensor<1x100xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
+  %slice = tensor.extract_slice %collapsed [0, 0] [1, 100] [1, 1] : tensor<1x100xf32> to tensor<1x100xf32>  
+  return %slice : tensor<1x100xf32>  
+  // CHECK: %[[collapse:.+]] = tensor.collapse_shape %[[arg0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
+  // CHECK: return %[[collapse]]  
+}

diff  --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index df9e62e64a54b..461da29095465 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -128,7 +128,22 @@ struct RewriteExtractSliceFromCollapseShapeBase
       return rewriter.notifyMatchFailure(
           op, "producer is not a tensor.collapse_shape op");
 
-    // Materialize the output shape values of the slice operation.a
+    // Try to simplify the collapse shape using a rank-reducing slice, if
+    // possible.
+    FailureOr<Operation *> simplifiedCollapseShapeResult =
+        tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp,
+                                                                  rewriter);
+    if (succeeded(simplifiedCollapseShapeResult)) {
+      auto newCollapseOp =
+          dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult);
+      // The collapse shape op might have been simplified away, so we can just
+      // return.
+      if (!newCollapseOp)
+        return success();
+      collapseOp = newCollapseOp;
+    }
+
+    // Materialize the output shape values of the slice operation.
     ReifiedRankedShapedTypeDims reifiedShapes;
     if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
       return rewriter.notifyMatchFailure(op, "failed to reify result shapes");


        


More information about the Mlir-commits mailing list