[Mlir-commits] [mlir] 446981b - [mlir][tensor] ExtractSliceFromReshape: handle collapsing of unit dim edge cases
Christopher Bate
llvmlistbot at llvm.org
Sat Oct 22 12:29:41 PDT 2022
Author: Christopher Bate
Date: 2022-10-22T13:29:34-06:00
New Revision: 446981bdb64d0ae24ac77b8ba07f3ee3808c3936
URL: https://github.com/llvm/llvm-project/commit/446981bdb64d0ae24ac77b8ba07f3ee3808c3936
DIFF: https://github.com/llvm/llvm-project/commit/446981bdb64d0ae24ac77b8ba07f3ee3808c3936.diff
LOG: [mlir][tensor] ExtractSliceFromReshape: handle collapsing of unit dim edge cases
Prior to this change, the "ExtractSliceFromReshape" pattern would transform
```
%collapsed = tensor.collapse_shape %input [[0, 1], [2]]
: tensor<1x11x100xf32> into tensor<11x100xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1]
: tensor<11x100xf32> to tensor<?x100xf32>
```
into a loop that iterated over the range `%size - %offt`, piecing
together multiple sub-slices of `%input` along the first dimension. This
is correct but obviously inefficient. The key condition is that
collapsing at most one non-unit dimension of the source (`%input` above)
cannot cause a subsequent slice along the corresponding dimension of
`%collapsed` to map across discontinuities in the index space of the
source. Thus, the definition of a "linearized dimension" (from the
perspective of `tensor.collapse_shape`) is updated to reflect this
condition.
The transform will now generate
```
%slice = tensor.extract_slice %input [0, %offt, 0] [1, %size, 100] [1, 1, 1]
  : tensor<1x11x100xf32> to tensor<1x?x100xf32>
%result = tensor.collapse_shape %slice [[0, 1], [2]]
  : tensor<1x?x100xf32> into tensor<?x100xf32>
```
which can be further canonicalized.
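For reference, one plausible fully canonicalized form of that pair is a single
rank-reducing slice (a sketch of the expected end state, not IR emitted
directly by this change):
```
%slice = tensor.extract_slice %input [0, %offt, 0] [1, %size, 100] [1, 1, 1]
  : tensor<1x11x100xf32> to tensor<?x100xf32>
```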
Additional tests are added to check this family of edge cases.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D135726
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
index 96b7f99baf59f..13e38af8ae906 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h
@@ -204,6 +204,66 @@ class ExtractSliceFromCollapseHelper {
SmallVector<Value> tiledSizes;
};
+/// Tries to simplify a `tensor.collapse_shape` operation by inserting a single
+/// rank-reducing `tensor.extract_slice` operation. The `extract_slice` op will
+/// either take the place of the source, allowing for a new, simpler
+/// `collapse_shape` op to replace `op`, or the `collapse_shape` op will be
+/// completely replaced by the `extract_slice` result. Either way, `op` is
+/// replaced and the new op is returned.
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+/// : tensor<?x1x30x10xf32> into tensor<?x300xf32>
+/// ```
+/// can be transformed to
+///
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                                [%dim1, 1, 30, 10]
+///                                [1, 1, 1, 1]
+///    : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
+/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
+///    : tensor<?x30x10xf32> into tensor<?x300xf32>
+/// ```
+///
+/// ### Example:
+///
+/// ```
+/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
+/// : tensor<?x1x30xf32> into tensor<?x30xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %result = tensor.extract_slice %1 [0, 0, 0]
+/// [%dim2, 1, 30]
+/// [1, 1, 1]
+/// : tensor<?x1x30xf32> to tensor<?x30xf32>
+/// ```
+///
+/// ### Unsupported cases:
+///
+/// This transform doesn't yet support reducing the rank of the reassociation
+/// indices, which would require inserting a `tensor.expand_shape` op similar to
+/// the following example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+/// : tensor<1x1x30x10xf32> into tensor<1x300xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                                [1, 1, 30, 10]
+///                                [1, 1, 1, 1]
+///    : tensor<1x1x30x10xf32> to tensor<30x10xf32>
+/// %result0 = tensor.collapse_shape %tmp [[0, 1]]
+///    : tensor<30x10xf32> into tensor<300xf32>
+/// %result1 = tensor.expand_shape %result0 [[0, 1]]
+///    : tensor<300xf32> into tensor<1x300xf32>
+/// ```
+///
+FailureOr<Operation *>
+simplifyCollapseShapeWithRankReducingExtractSlice(tensor::CollapseShapeOp op,
+ RewriterBase &rewriter);
} // namespace tensor
} // namespace mlir
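To make the return-value contract above concrete: when every reassociation
group is either trivial (exactly one non-unit dim) or a singleton, the
`extract_slice` replaces the `collapse_shape` outright and is the op returned.
A hypothetical instance (shapes chosen for illustration only):
```
// The collapse below...
%r = tensor.collapse_shape %0 [[0, 1], [2]]
    : tensor<1x7x1xf32> into tensor<7x1xf32>

// ...is fully replaced by a single rank-reducing slice:
%r = tensor.extract_slice %0 [0, 0, 0] [1, 7, 1] [1, 1, 1]
    : tensor<1x7x1xf32> to tensor<7x1xf32>
```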
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 1c584d2742011..dba055d9fd992 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -460,6 +460,58 @@ class SliceFromCollapseHelper {
llvm::SmallBitVector linearizedDimensions;
llvm::SmallBitVector slicedDimensions;
};
+
+/// Parameters required to simplify a collapsing reshape op with a rank-reducing
+/// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
+struct CollapseShapeRankReducingSliceSimplificationInfo {
+ /// The shape of the output of the rank-reducing slice.
+ RankedTensorType sliceResultType;
+ /// The reassociation indices for the new collapse shape op, if required. If
+ /// `None`, the slice should replace the collapse shape op.
+ Optional<SmallVector<ReassociationIndices>> newReassociationIndices;
+};
+
+/// A collapsing reshape operation can sometimes be simplified or eliminated by
+/// inserting a single rank-reducing slice operation between it and the source
+/// tensor. The slice op will either take the place of the source, allowing for
+/// a new, simpler reshape op to replace the original, or the reshape op will be
+/// completely replaced by the slice result.
+///
+/// This function returns the parameters required to implement this pattern. If
+/// the pattern is not applicable, then failure is returned.
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+/// : tensor<?x1x30x10xf32> into tensor<?x300xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
+///                                [%dim1, 1, 30, 10]
+///                                [1, 1, 1, 1]
+///    : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
+/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
+///    : tensor<?x30x10xf32> into tensor<?x300xf32>
+/// ```
+///
+/// ### Example:
+/// ```
+/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
+/// : tensor<?x1x30xf32> into tensor<?x30xf32>
+/// ```
+/// can be transformed to
+/// ```
+/// %result = tensor.extract_slice %1 [0, 0, 0]
+/// [%dim2, 1, 30]
+/// [1, 1, 1]
+/// : tensor<?x1x30xf32> to tensor<?x30xf32>
+/// ```
+FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
+getSimplifyCollapseShapeWithRankReducingSliceInfo(
+ RankedTensorType sourceType,
+ ArrayRef<ReassociationIndices> reassociationIndices);
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
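Conversely, the helper returns failure when no group qualifies, i.e. every
group of two or more indices contains either several non-unit dimensions or
only unit dimensions. IR of the following form is left untouched (hypothetical
shapes, for illustration):
```
// Both groups resist simplification: [0, 1] covers two non-unit dims
// (5 and 6) and [2] is a singleton.
%result = tensor.collapse_shape %0 [[0, 1], [2]]
    : tensor<5x6x10xf32> into tensor<30x10xf32>
```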
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
index 98430da084d87..67c949c706c09 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp
@@ -26,8 +26,8 @@ using namespace mlir;
using namespace mlir::tensor;
/// Get the dimension size of a value of RankedTensor type at the given
/// `dimIdx`.
-OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor,
- int64_t dimIdx) {
+static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc,
+ Value rankedTensor, int64_t dimIdx) {
RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>();
if (!tensorType.isDynamicDim(dimIdx)) {
return b.getIndexAttr(tensorType.getDimSize(dimIdx));
@@ -103,6 +103,11 @@ FailureOr<ExtractSliceFromCollapseHelper>
tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
tensor::CollapseShapeOp op,
ArrayRef<Range> sliceParams) {
+ // Don't perform this pattern if the collapse op can be simplified by
+ // a rank-reducing extract slice.
+ if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
+ op.getSrcType(), op.getReassociationIndices())))
+ return failure();
// Materialize the output shape of the collapse_shape operation. This will
// create IR describing the output shape in terms of the input shape.
@@ -125,9 +130,6 @@ tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
auto collapseShapeInputShape = getShapeDimSizes(b, op.getLoc(), op.getSrc());
- SmallVector<OpFoldResult> srcShape =
- getShapeDimSizes(b, op->getLoc(), op.getSrc());
-
SmallVector<Value> tileSizes;
for (unsigned i = 0; i < sliceParams.size(); i++) {
if (slicedDimensions[i] && linearizedDimensions[i])
@@ -178,3 +180,36 @@ tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
loc, subTileResult, reassociationIndices);
return std::make_pair(collapsedResult, insertParams);
}
+
+FailureOr<Operation *>
+tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
+ tensor::CollapseShapeOp op, RewriterBase &rewriter) {
+ SmallVector<ReassociationIndices> reassociationIndices =
+ op.getReassociationIndices();
+ RankedTensorType sourceType = op.getSrcType();
+ FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
+ getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
+ reassociationIndices);
+ if (failed(info))
+ return failure();
+
+ // Create the rank-reducing extract slice op.
+ auto zero = rewriter.getIndexAttr(0);
+ auto one = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
+ SmallVector<OpFoldResult> sizes =
+ getShapeDimSizes(rewriter, op.getLoc(), op.getSrc());
+ SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
+ auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
+ op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
+
+ if (!info->newReassociationIndices.has_value()) {
+ rewriter.replaceOp(op, sliceOp.getResult());
+ return sliceOp.getOperation();
+ }
+
+ return rewriter
+ .replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ op, sliceOp.getResult(), info->newReassociationIndices.value())
+ .getOperation();
+}
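When only some groups are trivial, the rewrite above takes the second branch:
it emits the rank-reducing slice and then a simpler `collapse_shape` over the
regrouped indices. A hypothetical static-shape instance:
```
// Input: group [0, 1] is trivial (unit dim 0), group [2, 3] is not.
%r = tensor.collapse_shape %0 [[0, 1], [2, 3]]
    : tensor<1x8x4x10xf32> into tensor<8x40xf32>

// Rewritten form: the slice drops the unit dim, and the remaining
// non-trivial group is collapsed by a new, simpler op.
%tmp = tensor.extract_slice %0 [0, 0, 0, 0] [1, 8, 4, 10] [1, 1, 1, 1]
    : tensor<1x8x4x10xf32> to tensor<8x4x10xf32>
%r = tensor.collapse_shape %tmp [[0], [1, 2]]
    : tensor<8x4x10xf32> into tensor<8x40xf32>
```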
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 9bca50f643216..e31d069b99900 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -352,3 +352,99 @@ SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
}
return insertParams;
}
+
+/// Returns the index of the only non-unit dimension among `indices` of `shape`,
+/// if such a dimension exists and `indices` has more than one element.
+/// Otherwise, returns `None`.
+static Optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
+ ArrayRef<int64_t> shape) {
+ // Return None if more than one of the dimensions in this group is not 1.
+ Optional<int64_t> dimIndex = None;
+ if (indices.size() < 2)
+ return None;
+ for (int64_t idx : indices) {
+ if (shape[idx] != 1) {
+ if (dimIndex != None)
+ return None;
+ dimIndex = idx;
+ }
+ }
+ return dimIndex;
+}
+
+// For each segment in the reassociation indices, check whether we can
+// simplify that segment with a rank-reducing extract slice. We can do this if
+// the segment contains more than one dimension and exactly one of the
+// corresponding source dims is not 1.
+static SmallVector<Optional<int64_t>> getCollapseShapeTrivialSegments(
+ RankedTensorType sourceType,
+ ArrayRef<ReassociationIndices> reassociationIndices) {
+ SmallVector<Optional<int64_t>> trivialSegments;
+ for (const auto &indices : reassociationIndices)
+ trivialSegments.push_back(
+ getUniqueNonUnitDim(indices, sourceType.getShape()));
+ return trivialSegments;
+}
+
+/// Returns the per-segment simplification info (see above) if any segment of
+/// the reassociation indices for a collapsing reshape can be simplified using
+/// a rank-reducing slice; otherwise, returns failure.
+static FailureOr<SmallVector<Optional<int64_t>>>
+canCollapseShapeBeSimplifiedByRankReducingSlice(
+ RankedTensorType sourceType,
+ ArrayRef<ReassociationIndices> reassociationIndices) {
+ SmallVector<Optional<int64_t>> trivialSegments =
+ getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
+ if (!llvm::any_of(trivialSegments, [](const Optional<int64_t> &idx) {
+ return idx.has_value();
+ }))
+ return failure();
+ return trivialSegments;
+}
+
+FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
+mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
+ RankedTensorType sourceType,
+ ArrayRef<ReassociationIndices> reassociationIndices) {
+ FailureOr<SmallVector<Optional<int64_t>>> trivialSegments =
+ canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
+ reassociationIndices);
+ if (failed(trivialSegments))
+ return failure();
+
+ // Create the expected result shape of the rank-reducing slice.
+ SmallVector<int64_t> sliceShape;
+ for (const auto &[nonUnitDim, indices] :
+ llvm::zip(*trivialSegments, reassociationIndices)) {
+ if (nonUnitDim) {
+ sliceShape.push_back(sourceType.getDimSize(nonUnitDim.value()));
+ continue;
+ }
+ llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
+ return sourceType.getDimSize(idx);
+ }));
+ }
+ auto sliceType =
+ RankedTensorType::get(sliceShape, sourceType.getElementType());
+
+ // If the rank-reducing slice simplified every segment, then we are done.
+ if (sliceShape.size() == reassociationIndices.size())
+ return CollapseShapeRankReducingSliceSimplificationInfo{sliceType, None};
+
+ // Otherwise, we need to create a new collapse_shape op for the segments that
+ // weren't covered by the slice. By design, the new reassociation indices
+ // have the same number of groups as the old reassociation indices.
+ SmallVector<ReassociationIndices> newReassociationIndices;
+ SmallVector<int64_t, 2> reassociation;
+ int64_t groupIdx = 0;
+ for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
+ reassociation.push_back(dimIdx);
+ if ((*trivialSegments)[groupIdx] ||
+ reassociation.size() == reassociationIndices[groupIdx].size()) {
+ newReassociationIndices.push_back(reassociation);
+ reassociation.clear();
+ groupIdx++;
+ }
+ }
+
+ return CollapseShapeRankReducingSliceSimplificationInfo{
+ sliceType, newReassociationIndices};
+}
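Note that a dynamic extent counts as the unique non-unit dimension in
`getUniqueNonUnitDim`, since `shape[idx] != 1` also holds for the dynamic-size
sentinel. A group mixing unit and dynamic dims is therefore trivial, mirroring
the `@collapse_and_slice_multiple_unit_dim_dynamic` test below (here `%d`
stands for a materialized `tensor.dim %0, %c1`):
```
%r = tensor.collapse_shape %0 [[0, 1, 2], [3]]
    : tensor<1x?x1x100xf32> into tensor<?x100xf32>

// Simplifies to a single rank-reducing slice:
%r = tensor.extract_slice %0 [0, 0, 0, 0] [1, %d, 1, 100] [1, 1, 1, 1]
    : tensor<1x?x1x100xf32> to tensor<?x100xf32>
```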
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
index 9a022787c464f..ccbba9013ab29 100644
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -177,3 +177,65 @@ func.func @no_sliced_linearized_dims(%input: tensor<30x11x100xf32>, %offt: index
// CHECK: return %[[res]]
return %slice : tensor<330x?xf32>
}
+
+// -----
+
+// The tests below verify that dimensions resulting from collapsing at most one
+// non-unit dim are handled properly.
+
+// CHECK: @collapse_and_slice_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_unit_dim(%input: tensor<1x11x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
+ %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<1x11x100xf32> into tensor<11x100xf32>
+ %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<11x100xf32> to tensor<?x100xf32>
+ // CHECK-NOT: scf.for
+ // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0] [1, 11, 100] [1, 1, 1]
+ // CHECK-SAME: tensor<1x11x100xf32> to tensor<11x100xf32>
+ // CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
+ // CHECK-SAME: tensor<11x100xf32> to tensor<?x100xf32>
+ return %slice : tensor<?x100xf32>
+}
+
+// CHECK: @collapse_and_slice_multiple_unit_dim_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_multiple_unit_dim_dynamic(%input: tensor<1x?x1x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
+ %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x?x1x100xf32> into tensor<?x100xf32>
+ %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<?x100xf32> to tensor<?x100xf32>
+ // CHECK-NOT: scf.for
+ // CHECK: %[[c1:.+]] = arith.constant 1 : index
+ // CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] :
+ // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0] [1, %[[dim]], 1, 100] [1, 1, 1, 1]
+ // CHECK-SAME: tensor<1x?x1x100xf32> to tensor<?x100xf32>
+ // CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
+ // CHECK-SAME: tensor<?x100xf32> to tensor<?x100xf32>
+ return %slice : tensor<?x100xf32>
+}
+
+// CHECK: @collapse_and_slice_multiple_unit_dim_mixed(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
+func.func @collapse_and_slice_multiple_unit_dim_mixed(%input: tensor<1x?x1x100x10xf32>, %offt: index, %size: index) -> tensor<?x?xf32> {
+ %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<1x?x1x100x10xf32> into tensor<?x1000xf32>
+ %slice = tensor.extract_slice %collapsed [%offt, %offt] [%size, %size] [1, 1] : tensor<?x1000xf32> to tensor<?x?xf32>
+ // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+ // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+ // CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]]
+ // CHECK: %[[rank_reduced:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0, 0] [1, %[[dim]], 1, 100, 10] [1, 1, 1, 1, 1]
+ // CHECK: %[[empty:.+]] = tensor.empty
+ // CHECK: %[[result:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[arg2]] step %[[c1]] iter_args(%[[ia:.+]] = %[[empty]])
+ // CHECK: %[[idx:.+]] = affine.apply
+ // CHECK: %[[multi_index:.+]] = affine.delinearize_index %[[idx]] into
+ // CHECK: %[[collapsed:.+]] = tensor.collapse_shape
+ // CHECK: %[[updated:.+]] = tensor.insert_slice
+ // CHECK: scf.yield %[[updated]]
+ // CHECK: return %[[result]]
+ return %slice : tensor<?x?xf32>
+}
+
+// Edge case where all collapsed dims are unit dims. This pattern can't
+// eliminate the collapse shape; that case should be handled by
+// `linalg-fold-unit-extent-dims`.
+
+// CHECK: @collapse_and_slice_multiple_all_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>)
+func.func @collapse_and_slice_multiple_all_unit_dim(%input: tensor<1x1x1x100xf32>) -> tensor<1x100xf32> {
+ %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
+ %slice = tensor.extract_slice %collapsed [0, 0] [1, 100] [1, 1] : tensor<1x100xf32> to tensor<1x100xf32>
+ return %slice : tensor<1x100xf32>
+ // CHECK: %[[collapse:.+]] = tensor.collapse_shape %[[arg0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
+ // CHECK: return %[[collapse]]
+}
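For contrast with the last test above, the all-unit-dim situation is the same
failure mode called out as unsupported in the header docs: a group that
collapses only unit dims would need a `tensor.expand_shape` to restore the
result shape, so the simplification bails out, e.g. (hypothetical shapes):
```
// Not simplified: group [0, 1] is all-unit and group [2, 3] has two
// non-unit dims, so no single rank-reducing slice applies.
%r = tensor.collapse_shape %0 [[0, 1], [2, 3]]
    : tensor<1x1x30x10xf32> into tensor<1x300xf32>
```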
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index df9e62e64a54b..461da29095465 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -128,7 +128,22 @@ struct RewriteExtractSliceFromCollapseShapeBase
return rewriter.notifyMatchFailure(
op, "producer is not a tensor.collapse_shape op");
- // Materialize the output shape values of the slice operation.a
+ // Try to simplify the collapse shape using a rank-reducing slice, if
+ // possible.
+ FailureOr<Operation *> simplifiedCollapseShapeResult =
+ tensor::simplifyCollapseShapeWithRankReducingExtractSlice(collapseOp,
+ rewriter);
+ if (succeeded(simplifiedCollapseShapeResult)) {
+ auto newCollapseOp =
+ dyn_cast<tensor::CollapseShapeOp>(*simplifiedCollapseShapeResult);
+ // The collapse shape op might have been simplified away, so we can just
+ // return.
+ if (!newCollapseOp)
+ return success();
+ collapseOp = newCollapseOp;
+ }
+
+ // Materialize the output shape values of the slice operation.
ReifiedRankedShapedTypeDims reifiedShapes;
if (failed(op.reifyResultShapes(rewriter, reifiedShapes)))
return rewriter.notifyMatchFailure(op, "failed to reify result shapes");