[Mlir-commits] [mlir] [mlir] Fix ComposeExpandOfCollapseOp for dynamic case (PR #142663)

Ian Wood llvmlistbot at llvm.org
Thu Jun 5 13:22:29 PDT 2025


https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/142663

>From 510f3c7f7433de0b951e68e3f024fd7a9d4c6027 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 3 Jun 2025 12:33:01 -0700
Subject: [PATCH 1/2] [mlir] Fix ComposeExpandOfCollapseOp for dynamic case

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h |  8 +++++---
 mlir/test/Dialect/Tensor/canonicalize.mlir        | 14 ++++++++++++++
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index af575e10acc8e..c9f29ab8f15d5 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -387,11 +387,13 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
       auto resultSubShape =
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
+      if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2)
+        return std::nullopt;
+
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape != resultSubShape ||
-            llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
+        if (srcSubShape != resultSubShape)
           return std::nullopt;
-        }
+
         for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
           composedReassociation.emplace_back(1, srcIndices.front() + index);
         }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..7e5e57423f44d 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1319,6 +1319,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
 
 // -----
 
+func.func @no_compose_collapse_of_expand_dynamic(%arg0 : tensor<?x8x128x?xf16>, %arg1: index) -> tensor<?x128x?xf16> {
+  %collapse = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<?x8x128x?xf16> into tensor<?xf16>
+  %expanded_19 = tensor.expand_shape %collapse [[0, 1, 2]] output_shape [%arg1, 8, %arg1] : tensor<?xf16> into tensor<?x128x?xf16>
+  return %expanded_19 : tensor<?x128x?xf16>
+}
+// CHECK-LABEL: func @no_compose_collapse_of_expand_dynamic
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor
+//  CHECK-SAME:   %[[ARG1:.+]]: index
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]]
+//       CHECK:   return %[[EXPAND]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0

>From 37a0e5b042951b6de625c4c85b717a4865801ee7 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Thu, 5 Jun 2025 13:23:12 -0700
Subject: [PATCH 2/2] Check >=2 dyn for source and result

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index c9f29ab8f15d5..61c2a50e514ca 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -387,7 +387,8 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
       auto resultSubShape =
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
-      if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2)
+      if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
+          llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
         return std::nullopt;
 
       if (srcSubShape.size() == resultSubShape.size()) {



More information about the Mlir-commits mailing list