[Mlir-commits] [mlir] [MLIR] Improve compose expand(collapse) pattern (PR #117768)
Ian Wood
llvmlistbot at llvm.org
Tue Nov 26 11:08:28 PST 2024
https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/117768
From b4ff5f270e5af5e2af784c16d6d0e51ca151b180 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 26 Nov 2024 22:51:17 -0800
Subject: [PATCH] [MLIR] Improve compose expand(collapse) pattern
If an expand(collapse) pair has a group of dimensions that gets collapsed
and then expanded back to the same shape, the pattern would fail to
canonicalize the pair into a single collapse_shape op.
Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
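For illustration, this is the kind of IR the change lets the pattern fold
(mirroring the static test added below): the [0, 1] group is collapsed to 128
and immediately expanded back to 4x32, so the composed reassociation keeps
those dimensions as singleton groups and only the trailing [3, 4] group
actually collapses:

  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
      : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]]
      output_shape [4, 32, 10, 128]
      : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>

canonicalizes to:

  %0 = tensor.collapse_shape %arg0 [[0], [1], [2], [3, 4]]
      : tensor<4x32x10x64x2xf16> into tensor<4x32x10x128xf16>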
---
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 22 ++++++++-------
mlir/test/Dialect/Tensor/canonicalize.mlir | 28 +++++++++++++++++++
2 files changed, 40 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba..3fa35bf1851a9c 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
int64_t srcRank = srcType.getRank();
int64_t resultRank = resultType.getRank();
- if (srcType == resultType)
+ if (srcRank == resultRank)
return failure();
auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
resultShape.slice(resultIndices.front(), resultIndices.size());
if (srcSubShape.size() == resultSubShape.size()) {
- if (srcSubShape == resultSubShape &&
- llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
- composedReassociation.push_back(srcIndices);
- } else {
+ if (srcSubShape != resultSubShape ||
+ llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
return std::nullopt;
}
+ for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
+ composedReassociation.emplace_back(1, srcIndices.front() + index);
+ }
+ continue;
}
// Find reassociation to collapse `srcSubShape` into `resultSubShape`.
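A note on the `>= 2` dynamic-dimension guard kept in the hunk above: when a
round-tripped group carries two or more dynamic sizes, the static sub-shapes
match but the runtime factorization may differ (a ?x? group collapsed to ?
could legally be re-expanded with a different split), so the pattern
conservatively bails out. A hypothetical example, not part of this patch,
that the pattern therefore refuses to fold:

  func.func @no_compose_two_dynamic(%arg0 : tensor<?x?x10x64x2xf16>,
      %d0 : index, %d1 : index) -> tensor<?x?x10x128xf16> {
    %c = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
        : tensor<?x?x10x64x2xf16> into tensor<?x10x128xf16>
    %e = tensor.expand_shape %c [[0, 1], [2], [3]] output_shape [%d0, %d1, 10, 128]
        : tensor<?x10x128xf16> into tensor<?x?x10x128xf16>
    return %e : tensor<?x?x10x128xf16>
  }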
@@ -403,11 +405,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
return std::nullopt;
// Remap the subshape indices back to the original srcShape.
- for (auto &subshape_indices : *subShapeReassociation) {
- ReassociationIndices shape_indices;
- for (int64_t index : subshape_indices)
- shape_indices.push_back(srcIndices.front() + index);
- composedReassociation.push_back(shape_indices);
+ for (auto &subshapeIndices : *subShapeReassociation) {
+ ReassociationIndices shapeIndices;
+ for (int64_t index : subshapeIndices)
+ shapeIndices.push_back(srcIndices.front() + index);
+ composedReassociation.push_back(shapeIndices);
}
}
return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..613ec066337294 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
// -----
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+ return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1, 10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+ return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
// CHECK-LABEL: func @zero_rank_reshape_multi
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0