[Mlir-commits] [mlir] [mlir][linalg] Fix partial fuse by collapse (PR #136326)

Ian Wood llvmlistbot at llvm.org
Fri Apr 18 10:00:06 PDT 2025


https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/136326

Similar to `FoldWithProducerReshapeOpByCollapsing`, `FoldReshapeWithGenericOpByCollapsing` needs to handle partial fusion of a reshape by collapsing. In that case, the source of the generated `expand_shape` op (i.e., the result of the collapsed linalg op) might not match the result type of the original `collapse_shape` op, so replacing the `collapse_shape` with that source is invalid. This change instead leaves the original `collapse_shape` in place and only replaces the original linalg op with the expanded results of the collapsed op, which are guaranteed to have the same types as the original results.

>From efa8ff33c3735524631bfa920986a1407d1471f7 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Fri, 18 Apr 2025 09:55:00 -0700
Subject: [PATCH] [mlir][linalg] Fix partial fuse by collapse

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 17 -----------
 .../fuse-with-reshape-by-collapsing.mlir      | 28 +++++++++++++++++++
 2 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..f345cc5f3d172 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1907,23 +1907,6 @@ struct FoldReshapeWithGenericOpByCollapsing
           producer, "failed to do the fusion by collapsing transformation");
     }
 
-    if (!collapseResult) {
-      return rewriter.notifyMatchFailure(reshapeOp,
-                                         "fusion by expansion failed");
-    }
-
-    // Find the replacement for the reshape op. Since the replacements have the
-    // same type as the returns of the original generic op, the consumer reshape
-    // op can be replaced by the source of the expand_shape op that defines
-    // the replacement.
-    Value reshapeReplacement =
-        (collapseResult
-             ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
-    if (auto expandOp =
-            reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
-      reshapeReplacement = expandOp.getSrc();
-    }
-    rewriter.replaceOp(reshapeOp, reshapeReplacement);
     rewriter.replaceOp(producer, collapseResult->results);
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index dba53b4192cd5..2bf3d21c35526 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -830,3 +830,31 @@ func.func @fuse_by_collapsing_correct_insertion(%arg0 : tensor<?x?xf32>,
 // CHECK:     %[[EXPANDED:.+]] = tensor.expand_shape %[[OUT]]
 // CHECK:     %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
 // CHECK:      return %[[OUT]], %[[DIM]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4, d1, d2)>
+func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1: tensor<4x128x192x?x32xf32>) -> tensor<512x192x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x32x128x192xf16>) outs(%arg1 : tensor<4x128x192x?x32xf32>) {
+  ^bb0(%in: f16, %out: f32):
+    linalg.yield %out : f32
+  } -> tensor<4x128x192x?x32xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1], [2], [3, 4]] : tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+  return %collapsed : tensor<512x192x?xf32>
+}
+// CHECK-LABEL: func @partial_fuse_by_collapsing
+//  CHECK-SAME:  %[[ARG0:.+]]: tensor<4x?x32x128x192xf16>
+//  CHECK-SAME:  %[[ARG1:.+]]: tensor<4x128x192x?x32xf32>
+//   CHECK-DAG:   %[[COLLAPSED0:.+]] = tensor.collapse_shape %[[ARG0]]
+//  CHECK-SAME:     tensor<4x?x32x128x192xf16> into tensor<4x?x128x192xf16>
+//   CHECK-DAG:   %[[COLLAPSED1:.+]] = tensor.collapse_shape %[[ARG1]]
+//  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<4x128x192x?xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:     ins(%[[COLLAPSED0]]
+//  CHECK-SAME:     outs(%[[COLLAPSED1]]
+//       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[GENERIC]]
+//  CHECK-SAME:     tensor<4x128x192x?xf32> into tensor<4x128x192x?x32xf32>
+//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+//  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+//       CHECK:   return %[[COLLAPSED]] : tensor<512x192x?xf32>



More information about the Mlir-commits mailing list