[Mlir-commits] [mlir] [MLIR] Improve compose expand(collapse) pattern (PR #117768)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 26 11:02:39 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ian Wood (IanWood1)

<details>
<summary>Changes</summary>

If an expand(collapse) sequence has a dimension that gets collapsed and then expanded back to the same shape, the pattern would fail to canonicalize the sequence to a single `collapse_shape`. The check on line 341 was changed from comparing the source and result types to comparing their ranks, because the expand(collapse) could be a reinterpret-cast-like sequence in which the shapes differ but the rank is the same. Such a sequence cannot be represented by a single `collapse_shape` op and should be converted to a `cast` instead.

---
Full diff: https://github.com/llvm/llvm-project/pull/117768.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+14-10) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+28) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba..0357e34a2e0963 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
 
     int64_t srcRank = srcType.getRank();
     int64_t resultRank = resultType.getRank();
-    if (srcType == resultType)
+    if (srcRank == resultRank)
       return failure();
 
     auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,16 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape &&
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-          composedReassociation.push_back(srcIndices);
-        } else {
+        if (srcSubShape != resultSubShape ||
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
           return std::nullopt;
         }
+        for (auto dim : llvm::seq<int64_t>(0, srcSubShape.size())) {
+          ReassociationIndices reassoc;
+          reassoc.push_back(srcIndices.front() + dim);
+          composedReassociation.push_back(reassoc);
+        }
+        continue;
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +407,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
         return std::nullopt;
 
       // Remap the subshape indices back to the original srcShape.
-      for (auto &subshape_indices : *subShapeReassociation) {
-        ReassociationIndices shape_indices;
-        for (int64_t index : subshape_indices)
-          shape_indices.push_back(srcIndices.front() + index);
-        composedReassociation.push_back(shape_indices);
+      for (auto &subshapeIndices : *subShapeReassociation) {
+        ReassociationIndices shapeIndices;
+        for (int64_t index : subshapeIndices)
+          shapeIndices.push_back(srcIndices.front() + index);
+        composedReassociation.push_back(shapeIndices);
       }
     }
     return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..613ec066337294 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
 
 // -----
 
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+  return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1,  10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+  return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0

``````````

</details>


https://github.com/llvm/llvm-project/pull/117768


More information about the Mlir-commits mailing list