[Mlir-commits] [mlir] [mlir][tensor] Fold unpadding collapse_shape into extract_slice (PR #93554)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 28 07:27:37 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/93554.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+40-6)
- (modified) mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir (+69)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 6cf0f845f59db..d7c608a773bb7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -48,6 +48,39 @@ struct FoldExpandOfRankReducingExtract
}
};
+/// Fold a collapse_shape that only removes static size-`1` dimensions into the
+/// single-use extract_slice feeding it, producing one rank-reducing extract_slice.
+struct FoldUnPaddingCollapseIntoExtract
+ : public OpRewritePattern<tensor::CollapseShapeOp> {
+ using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
+ PatternRewriter &rewriter) const override {
+ auto extractSliceOp =
+ collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
+ // Require the slice to have no other users: otherwise the original
+ // extract_slice stays alive and the rewrite would merely trade the collapse
+ // for a second extract_slice, which is not necessarily beneficial.
+ if (!extractSliceOp || !extractSliceOp.getResult().hasOneUse())
+ return failure();
+
+ // Only fold a simple "unpadding" collapse, i.e. one whose result type is its
+ // source type with some static size `1` dimensions removed.
+ SliceVerificationResult res = isRankReducedType(
+ collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
+ if (res != SliceVerificationResult::Success)
+ return rewriter.notifyMatchFailure(collapseShapeOp,
+ "expected unpadding collapse");
+
+ Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
+ extractSliceOp.getLoc(), collapseShapeOp.getResultType(), // reuses the slice's loc; yields the collapsed type
+ extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
+ extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); // same source/offsets/sizes/strides as the original slice
+ rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice); // the old extract_slice is not erased here; it is left with no users
+ return success();
+ }
+};
+
/// Fold insert_slice(collapse_shape) ops that cancel itself out.
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
@@ -111,10 +144,11 @@ struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
RewritePatternSet &patterns) {
- patterns.add<FoldExpandOfRankReducingExtract,
- FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
- FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
- FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
- FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
- patterns.getContext());
+ patterns // folds of expand_shape/collapse_shape into adjacent extract_slice, insert_slice and parallel_insert_slice ops
+ .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
+ FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
+ FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
+ FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
+ FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
+ patterns.getContext());
}
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 644d9a918f6ca..c2368c4bf2c91 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -22,6 +22,75 @@ func.func @expand_shape_of_rank_reducing_extract(
// -----
+// CHECK-LABEL: func @unpadding_collapse_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+// CHECK: return %[[extract]]
+func.func @unpadding_collapse_of_extract_slice( // positive case: the collapse only drops static unit dims
+ %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
+ -> tensor<?x?xf32> {
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
+ %sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
+ %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
+ : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]] // dims 0 and 2 are static 1s, so only unit dims are removed
+ : tensor<1x?x1x?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_unpadding_collapse_of_extract_slice(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [%{{.*}}, %{{.*}}, %[[sz]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
+// CHECK: %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0], [1, 2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+// CHECK: return %[[collapse]]
+func.func @non_unpadding_collapse_of_extract_slice( // negative case: both ops are expected to remain
+ %t: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+ -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %sz0 = tensor.dim %t, %c0 : tensor<?x?x?x?xf32>
+ %sz1 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
+ %0 = tensor.extract_slice %t[%x, %y, 0, 0] [%sz0, %sz1, %sz, 1] [1, 1, 1, 1]
+ : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
+ %1 = tensor.collapse_shape %0 [[0], [1, 2]] // merges two dynamic dims, so this is not an unpadding collapse
+ : tensor<?x?x?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unpadding_collapse_of_extract_slice_with_multiple_users(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
+// CHECK: %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0, 1], [2, 3]] : tensor<1x?x1x?xf32> into tensor<?x?xf32>
+// CHECK: return %[[extract]], %[[collapse]]
+func.func @unpadding_collapse_of_extract_slice_with_multiple_users( // negative case: the slice has a second user
+ %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
+ -> (tensor<1x?x1x?xf32>, tensor<?x?xf32>) {
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
+ %sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
+ %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
+ : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+ : tensor<1x?x1x?xf32> into tensor<?x?xf32>
+ return %0, %1 : tensor<1x?x1x?xf32>, tensor<?x?xf32> // returning %0 gives the slice a second use, so the collapse is kept
+}
+
+// -----
+
// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x1x1x5xf32>
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/93554
More information about the Mlir-commits mailing list