[Mlir-commits] [mlir] [mlir] Add bubbling patterns for non intersecting reshapes (PR #94637)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 6 09:23:59 PDT 2024
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/94637
This PR adds a fusion-by-expansion pattern that pushes a tensor.expand_shape up through a producer tensor.collapse_shape when the two ops have non-intersecting reassociations. Such parallel collapse_shape ops can otherwise block the propagation of expand_shape ops; this pattern lets the two reshapes pass through each other.
>From 930a4d72e8d818af62744070c00cd667aaacbd9e Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 6 Jun 2024 10:46:45 -0400
Subject: [PATCH 1/4] [mlir] Fix bugs in expand_shape patterns after semantics
changes
---
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 56 ++++++++++++++----
mlir/test/Dialect/Tensor/canonicalize.mlir | 57 ++++++++++++++++++-
2 files changed, 101 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e8f6edc3f133e..3b986f4a60064 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,21 +85,55 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
ArrayRef<Attribute> operands) {
-
+ // Fold identity reshape.
if (reshapeOp.getSrcType() == reshapeOp.getType())
return reshapeOp.getSrc();
- // Fold producer-consumer reshape ops where the operand type of the
- // producer is same as the return type of the consumer.
- auto reshapeSrcOp =
- reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
- if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
- return reshapeSrcOp.getSrc();
-
// Reshape of a constant can be replaced with a new constant.
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+ // Fold if the producer reshape source has the same shape with at most 1
+ // dynamic dimension.
+ auto reshapeSrcOp =
+ reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+ if (!reshapeSrcOp)
+ return nullptr;
+ auto srcType = reshapeSrcOp.getSrcType();
+ auto resultType = reshapeOp.getResultType();
+ if (srcType != resultType)
+ return nullptr;
+
+ // If the reshapes are expanding and then collapsing, the ops can be folded
+ // despite multiple dynamic dimensions.
+ if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+ return reshapeSrcOp.getSrc();
+ // Otherwise, only 1 dynamic dimension is allowed.
+ if (srcType == resultType &&
+ llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+ return reshapeSrcOp.getSrc();
+ }
+
+ // Fold producer-consumer reshape ops when they are perfect inverses of each
+ // other:
+ // 1) Reassociation indices are equivalent.
+ // 2) Boundary types are equivalent.
+ // 3) No reassociations have more than 1 dynamic dimension, and reassociated
+ // shapes are equal for each reassociation.
+ auto reassociations = reshapeOp.getReassociationIndices();
+ auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
+ if (reassociations != inverseReassociations)
+ return nullptr;
+ ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
+ ArrayRef<int64_t> expandedResultShape = resultType.getShape();
+ if (llvm::none_of(reassociations, [&](auto reInd) {
+ auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size());
+ auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size());
+ return srcSlice == resSlice &&
+ llvm::count_if(srcSlice, ShapedType::isDynamic) > 1;
+ })) {
+ return reshapeSrcOp.getSrc();
+ }
return nullptr;
}
@@ -360,10 +394,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
resultShape.slice(resultIndices.front(), resultIndices.size());
if (srcSubShape.size() == resultSubShape.size()) {
- if (srcSubShape == resultSubShape)
+ if (srcSubShape == resultSubShape &&
+ llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
composedReassociation.push_back(srcIndices);
- else
+ } else {
return std::nullopt;
+ }
}
// Find reassociation to collapse `srcSubShape` into `resultSubShape`.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f7fbd3834288b..4a04d37d4be29 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand
-// CHECK-NOT: linalg.{{.*}}shape
+// CHECK-NOT: tensor.{{.*}}_shape
// -----
@@ -1152,7 +1152,60 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @fold_collapse_of_expand_dynamic
-// CHECK-NOT: linalg.{{.*}}_shape
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+ -> tensor<?x?xf32> {
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+ : tensor<?x?x?xf32> into tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<3x4x4xf32> into tensor<12x4xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+ : tensor<12x4xf32> into tensor<3x4x4xf32>
+ return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x4x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<?x4x?xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+ : tensor<?x?xf32> into tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-NOT: tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+ -> tensor<?x?x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+ : tensor<?x?x?xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+ : tensor<?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
// -----
>From 0536c7ec051aa8b2d6b2b9cc04a54a2d5bcfdb8d Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 6 Jun 2024 12:07:17 -0400
Subject: [PATCH 2/4] fix bug
---
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 9 ++++-----
mlir/test/Dialect/Tensor/canonicalize.mlir | 15 +++++++++++++++
2 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3b986f4a60064..31a23be26d5a7 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -104,11 +104,6 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
if (srcType != resultType)
return nullptr;
- // If the reshapes are expanding and then collapsing, the ops can be folded
- // despite multiple dynamic dimensions.
- if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
- return reshapeSrcOp.getSrc();
- // Otherwise, only 1 dynamic dimension is allowed.
if (srcType == resultType &&
llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
return reshapeSrcOp.getSrc();
@@ -124,6 +119,10 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
if (reassociations != inverseReassociations)
return nullptr;
+ // If the reshapes are expanding and then collapsing, the ops can be folded
+ // despite multiple dynamic dimensions.
+ if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+ return reshapeSrcOp.getSrc();
ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
ArrayRef<int64_t> expandedResultShape = resultType.getShape();
if (llvm::none_of(reassociations, [&](auto reInd) {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4a04d37d4be29..9a6b03986ccb6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1169,6 +1169,21 @@ func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1:
// -----
+func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
+ -> tensor<?x?x?xf32> {
+ %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
+ : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
+ : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
+// CHECK: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<3x4x4xf32> into tensor<12x4xf32>
>From 2dc8fea7edb4797e15bf1c555dae57ec42a393b4 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 6 Jun 2024 12:16:45 -0400
Subject: [PATCH 3/4] address comments
---
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 15 ++++++---------
1 file changed, 6 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 31a23be26d5a7..96f0f7bf1aa49 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -104,8 +104,7 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
if (srcType != resultType)
return nullptr;
- if (srcType == resultType &&
- llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+ if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
return reshapeSrcOp.getSrc();
}
@@ -116,8 +115,7 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
// 3) No reassociations have more than 1 dynamic dimension, and reassociated
// shapes are equal for each reassociation.
auto reassociations = reshapeOp.getReassociationIndices();
- auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
- if (reassociations != inverseReassociations)
+ if (reassociations != reshapeSrcOp.getReassociationIndices())
return nullptr;
// If the reshapes are expanding and then collapsing, the ops can be folded
// despite multiple dynamic dimensions.
@@ -125,11 +123,10 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
return reshapeSrcOp.getSrc();
ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
ArrayRef<int64_t> expandedResultShape = resultType.getShape();
- if (llvm::none_of(reassociations, [&](auto reInd) {
- auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size());
- auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size());
- return srcSlice == resSlice &&
- llvm::count_if(srcSlice, ShapedType::isDynamic) > 1;
+ if (llvm::all_of(reassociations, [&](auto reInd) {
+ ArrayRef<int64_t> srcSlice =
+ expandedSrcShape.slice(reInd.front(), reInd.size());
+ return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
})) {
return reshapeSrcOp.getSrc();
}
>From 8b5a6be14375f194de1827de3193b416c2f15589 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 23 May 2024 17:24:08 -0400
Subject: [PATCH 4/4] [mlir] Add bubbling patterns for non intersecting
reshapes
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 71 +++++++++++++++++++
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 34 +++++++++
2 files changed, 105 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..579116904aad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1023,6 +1023,76 @@ struct FoldReshapeWithGenericOpByExpansion
private:
ControlFusionFn controlFoldingReshapes;
};
+
+/// Pattern to bubble up a tensor.expand_shape op through a producer
+/// tensor.collapse_shape op that has non-intersecting reassociations.
+struct BubbleUpExpandThroughParallelCollapse
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp =
+        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+    // Skip 0-D intermediates: empty reassociations yield an invalid expand.
+    if (!collapseOp || !collapseOp->hasOneUse() ||
+        collapseOp.getResultType().getRank() == 0)
+      return failure();
+    auto expandReInds = expandOp.getReassociationIndices();
+    auto collapseReInds = collapseOp.getReassociationIndices();
+    // Reshapes are parallel if no group has >1 index in both reshapes.
+    for (auto [expandReassociation, collapseReassociation] :
+         llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+        return failure();
+    }
+
+    // Compute new reassociation indices and expanded/collapsed shapes.
+    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> collapseSizes =
+        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+    SmallVector<OpFoldResult> expandSizes(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> newExpandSizes;
+    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+      if (collapseReassociation.size() != 1) {
+        ReassociationIndices newCollapseReassociation;
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReassociation.push_back(index);
+          newExpandReInds.push_back({index++});
+          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+        }
+        newCollapseReInds.push_back(std::move(newCollapseReassociation));
+        expandIndex++;
+        continue;
+      }
+      ReassociationIndices newExpandReassociation;
+      const ReassociationIndices &expandReassociation = expandReInds[idx];
+      for (size_t i = 0; i < expandReassociation.size(); ++i) {
+        newExpandReassociation.push_back(index);
+        newCollapseReInds.push_back({index++});
+        newExpandSizes.push_back(expandSizes[expandIndex++]);
+      }
+      newExpandReInds.push_back(std::move(newExpandReassociation));
+      collapseIndex++;
+    }
+
+    // Swap reshape order: expand the collapse source, then re-collapse.
+    SmallVector<Value> dynamicSizes;
+    SmallVector<int64_t> staticSizes;
+    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+    auto expandResultType = expandOp.getResultType().clone(staticSizes);
+    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+        newExpandSizes);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        expandOp, newExpand.getResult(), newCollapseReInds);
+    return success();
+  }
+};
+
} // namespace
//===---------------------------------------------------------------------===//
@@ -1939,6 +2009,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..1354b138983a0 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,37 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T4]]
+
+// -----
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+// CHECK: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+ output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @no_bubble_intersecting_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+// CHECK: return %[[EXPAND]]
More information about the Mlir-commits
mailing list