[Mlir-commits] [mlir] [mlir] Add bubbling patterns for non intersecting reshapes (PR #103401)
Ian Wood
llvmlistbot at llvm.org
Tue Aug 13 12:38:48 PDT 2024
https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/103401
Refactored @Max191's PR https://github.com/llvm/llvm-project/pull/94637 to move the pattern into the `Tensor` dialect.
From the original PR:
>This PR adds fusion by expansion patterns to push a tensor.expand_shape up through a tensor.collapse_shape with non-intersecting reassociations. Sometimes parallel collapse_shape ops like this can block propagation of expand_shape ops, so this allows them to pass through each other.
I'm not sure whether the code and tests are in the right places; please let me know if they belong elsewhere.
cc @MaheshRavishankar @hanhanW
>From 351237c91e4f56b13171e3cf3ca453b86b79afa0 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 23 May 2024 17:24:08 -0400
Subject: [PATCH 1/2] [mlir] Add bubbling patterns for non intersecting
reshapes
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 71 +++++++++++++++++++
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 34 +++++++++
2 files changed, 105 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e73df61c964341..7aa8a0b37c219c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1086,6 +1086,76 @@ struct FoldReshapeWithGenericOpByExpansion
private:
ControlFusionFn controlFoldingReshapes;
};
+
+/// Pattern to bubble up a tensor.expand_shape op through a producer
+/// tensor.collapse_shape op that has non intersecting reassociations.
+struct BubbleUpExpandThroughParallelCollapse
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+ using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ auto collapseOp =
+ expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!collapseOp || !collapseOp->hasOneUse())
+ return failure();
+ auto expandReInds = expandOp.getReassociationIndices();
+ auto collapseReInds = collapseOp.getReassociationIndices();
+
+ // Reshapes are parallel to each other if none of the reassociation indices
+ // have greater than 1 index for both reshapes.
+ for (auto [expandReassociation, collapseReassociation] :
+ llvm::zip_equal(expandReInds, collapseReInds)) {
+ if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+ return failure();
+ }
+
+ // Compute new reassociation indices and expanded/collaped shapes.
+ SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+ Location loc = expandOp->getLoc();
+ SmallVector<OpFoldResult> collapseSizes =
+ tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+ SmallVector<OpFoldResult> expandSizes(getMixedValues(
+ expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> newExpandSizes;
+ int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+ for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+ if (collapseReassociation.size() != 1) {
+ ReassociationIndices newCollapseReassociation;
+ for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+ newCollapseReassociation.push_back(index);
+ newExpandReInds.push_back({index++});
+ newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+ }
+ newCollapseReInds.push_back(newCollapseReassociation);
+ expandIndex++;
+ continue;
+ }
+ ReassociationIndices newExpandReassociation;
+ auto expandReassociation = expandReInds[idx];
+ for (size_t i = 0; i < expandReassociation.size(); ++i) {
+ newExpandReassociation.push_back(index);
+ newCollapseReInds.push_back({index++});
+ newExpandSizes.push_back(expandSizes[expandIndex++]);
+ }
+ newExpandReInds.push_back(newExpandReassociation);
+ collapseIndex++;
+ }
+
+ // Swap reshape order.
+ SmallVector<Value> dynamicSizes;
+ SmallVector<int64_t> staticSizes;
+ dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+ auto expandResultType = expandOp.getResultType().clone(staticSizes);
+ auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+ loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+ newExpandSizes);
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ expandOp, newExpand.getResult(), newCollapseReInds);
+ return success();
+ }
+};
+
} // namespace
//===---------------------------------------------------------------------===//
@@ -2083,6 +2153,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index b8df5fc88e1999..86c2904218385c 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -887,3 +887,37 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @no_bubble_intersecting_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+// CHECK: return %[[EXPAND]]
>From 643ee8cd7a6ca805ddbd4fe482d85acd086d2782 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 13 Aug 2024 17:43:44 +0000
Subject: [PATCH 2/2] Refactor logic and tests to Tensor
---
.../Dialect/Tensor/Transforms/Transforms.h | 4 +
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 71 ------------------
.../Tensor/Transforms/ReshapePatterns.cpp | 75 +++++++++++++++++++
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 34 ---------
mlir/test/Dialect/Tensor/bubble-reshapes.mlir | 47 ++++++++++++
.../Dialect/Tensor/TestTensorTransforms.cpp | 13 ++++
6 files changed, 139 insertions(+), 105 deletions(-)
create mode 100644 mlir/test/Dialect/Tensor/bubble-reshapes.mlir
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 7f983b8b3cfd06..ae695e0326ca1a 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,10 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
/// `tensor.collapse_shape` into other ops.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that bubble up `tensor.expand_shape`
+/// through producer `tensor.collapse_shape` ops if no dim is reshaped by both.
+void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that fold tensor.empty with its
/// consumers.
///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 7aa8a0b37c219c..e73df61c964341 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1086,76 +1086,6 @@ struct FoldReshapeWithGenericOpByExpansion
private:
ControlFusionFn controlFoldingReshapes;
};
-
-/// Pattern to bubble up a tensor.expand_shape op through a producer
-/// tensor.collapse_shape op that has non intersecting reassociations.
-struct BubbleUpExpandThroughParallelCollapse
- : public OpRewritePattern<tensor::ExpandShapeOp> {
- using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
- PatternRewriter &rewriter) const override {
- auto collapseOp =
- expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
- if (!collapseOp || !collapseOp->hasOneUse())
- return failure();
- auto expandReInds = expandOp.getReassociationIndices();
- auto collapseReInds = collapseOp.getReassociationIndices();
-
- // Reshapes are parallel to each other if none of the reassociation indices
- // have greater than 1 index for both reshapes.
- for (auto [expandReassociation, collapseReassociation] :
- llvm::zip_equal(expandReInds, collapseReInds)) {
- if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
- return failure();
- }
-
- // Compute new reassociation indices and expanded/collaped shapes.
- SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
- Location loc = expandOp->getLoc();
- SmallVector<OpFoldResult> collapseSizes =
- tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
- SmallVector<OpFoldResult> expandSizes(getMixedValues(
- expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
- SmallVector<OpFoldResult> newExpandSizes;
- int64_t index = 0, expandIndex = 0, collapseIndex = 0;
- for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
- if (collapseReassociation.size() != 1) {
- ReassociationIndices newCollapseReassociation;
- for (size_t i = 0; i < collapseReassociation.size(); ++i) {
- newCollapseReassociation.push_back(index);
- newExpandReInds.push_back({index++});
- newExpandSizes.push_back(collapseSizes[collapseIndex++]);
- }
- newCollapseReInds.push_back(newCollapseReassociation);
- expandIndex++;
- continue;
- }
- ReassociationIndices newExpandReassociation;
- auto expandReassociation = expandReInds[idx];
- for (size_t i = 0; i < expandReassociation.size(); ++i) {
- newExpandReassociation.push_back(index);
- newCollapseReInds.push_back({index++});
- newExpandSizes.push_back(expandSizes[expandIndex++]);
- }
- newExpandReInds.push_back(newExpandReassociation);
- collapseIndex++;
- }
-
- // Swap reshape order.
- SmallVector<Value> dynamicSizes;
- SmallVector<int64_t> staticSizes;
- dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
- auto expandResultType = expandOp.getResultType().clone(staticSizes);
- auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
- loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
- newExpandSizes);
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- expandOp, newExpand.getResult(), newCollapseReInds);
- return success();
- }
-};
-
} // namespace
//===---------------------------------------------------------------------===//
@@ -2153,7 +2083,6 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
- patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index be0d71866a095e..061817e41d181e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -140,6 +140,95 @@ struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
return success();
}
};
+
+/// Pattern to bubble up a tensor.expand_shape op through a producer
+/// tensor.collapse_shape op that has non-intersecting reassociations.
+///
+/// The reshapes are non-intersecting when no dimension of the intermediate
+/// tensor is reshaped by both ops. The pair is then swapped: the expand is
+/// applied directly to the collapse's source, and a new collapse restores the
+/// original result shape.
+struct BubbleUpExpandThroughParallelCollapse
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp =
+        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp || !collapseOp->hasOneUse())
+      return failure();
+    auto expandReInds = expandOp.getReassociationIndices();
+    auto collapseReInds = collapseOp.getReassociationIndices();
+
+    // A 0-D intermediate tensor has empty reassociation lists. The rewrite
+    // below would then build a 0-D expand_shape result from the collapse's
+    // source, which is not a valid expansion, so do not match.
+    if (expandReInds.empty())
+      return failure();
+
+    // Reshapes are parallel to each other if none of the reassociation indices
+    // have greater than 1 index for both reshapes.
+    for (auto [expandReassociation, collapseReassociation] :
+         llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+        return failure();
+    }
+
+    // Compute new reassociation indices and expanded/collapsed shapes.
+    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> collapseSizes =
+        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+    SmallVector<OpFoldResult> expandSizes(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> newExpandSizes;
+    // `index` walks the new expand's result dims; `expandIndex` and
+    // `collapseIndex` walk the old expand's output shape and the collapse's
+    // source dims, respectively.
+    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+      if (collapseReassociation.size() != 1) {
+        // The collapse merges several source dims into dim `idx` (which the
+        // expand therefore leaves whole): keep them separate in the new expand
+        // and merge them in the new collapse.
+        ReassociationIndices newCollapseReassociation;
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReassociation.push_back(index);
+          newExpandReInds.push_back({index++});
+          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+        }
+        newCollapseReInds.push_back(newCollapseReassociation);
+        expandIndex++;
+        continue;
+      }
+      // The collapse leaves dim `idx` alone: apply the original expand group
+      // in the new expand and keep its dims separate in the new collapse.
+      ReassociationIndices newExpandReassociation;
+      const ReassociationIndices &expandReassociation = expandReInds[idx];
+      for (size_t i = 0; i < expandReassociation.size(); ++i) {
+        newExpandReassociation.push_back(index);
+        newCollapseReInds.push_back({index++});
+        newExpandSizes.push_back(expandSizes[expandIndex++]);
+      }
+      newExpandReInds.push_back(newExpandReassociation);
+      collapseIndex++;
+    }
+
+    // Swap reshape order.
+    SmallVector<Value> dynamicSizes;
+    SmallVector<int64_t> staticSizes;
+    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+    auto expandResultType = expandOp.getResultType().clone(staticSizes);
+    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+        newExpandSizes);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        expandOp, newExpand.getResult(), newCollapseReInds);
+    return success();
+  }
+};
+
} // namespace
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -152,3 +222,8 @@ void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
patterns.getContext());
}
+
+void mlir::tensor::populateBubbleUpExpandShapePatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 86c2904218385c..b8df5fc88e1999 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -887,37 +887,3 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
// CHECK: return %[[COLLAPSE]]
-
-// -----
-
-func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
- %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
- %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
- output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
- return %expand : tensor<?x?x?x?xf32>
-}
-// CHECK: func @bubble_parallel_reshapes
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
-// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
-// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
-// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
-// CHECK: return %[[COLLAPSE]]
-
-// -----
-
-func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
- %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
- %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
- output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
- return %expand : tensor<?x?x?x?xf32>
-}
-// CHECK: func @no_bubble_intersecting_reshapes
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
-// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
-// CHECK: return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
new file mode 100644
index 00000000000000..cf6b12852bcd39
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-expand-shape-bubbling %s | FileCheck %s
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @no_bubble_full_intersecting_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+// CHECK: return %[[EXPAND]]
+
+// -----
+
+func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0, 1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0, 1], [2, 3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @no_bubble_partial_intersecting_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
+// CHECK: return %[[EXPAND]]
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index ae4f77f5873e2b..34de600132f5de 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -72,6 +72,11 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
llvm::cl::init(false)};
+  Option<bool> testBubbleUpExpandShapePatterns{
+      *this, "test-expand-shape-bubbling",
+      llvm::cl::desc("Test bubbling of expand_shape through collapse_shape"),
+      llvm::cl::init(false)};
+
Option<bool> testFoldIntoPackAndUnpack{
*this, "test-fold-into-pack-and-unpack",
llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
@@ -102,6 +107,12 @@ static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+static void applyBubbleUpExpandShapePatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateBubbleUpExpandShapePatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
@@ -386,6 +397,8 @@ void TestTensorTransforms::runOnOperation() {
applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
if (testReassociativeReshapeFolding)
applyReassociativeReshapeFoldingPatterns(rootOp);
+ if (testBubbleUpExpandShapePatterns)
+ applyBubbleUpExpandShapePatterns(rootOp);
if (testFoldIntoPackAndUnpack)
applyFoldIntoPackAndUnpackPatterns(rootOp);
if (testRewriteExtractSliceWithTiledCollapseShape) {
More information about the Mlir-commits
mailing list