[Mlir-commits] [mlir] [mlir][linalg] Fix partial fuse by collapse (PR #136326)

llvmlistbot at llvm.org
Fri Apr 18 10:00:46 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Ian Wood (IanWood1)

<details>
<summary>Changes</summary>

Similar to `FoldWithProducerReshapeOpByCollapsing`, `FoldReshapeWithGenericOpByCollapsing` needs to handle partial fusion of a reshape by collapsing. With partial fusion, the source of the generated `expand_shape` op (i.e. the collapsed linalg op) might not match the type of the original `collapse_shape` op, so it cannot be used as that op's replacement. This change instead replaces only the original linalg op with the new `expand_shape` op, which is guaranteed to have the same type, and leaves the `collapse_shape` op in place.
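To make the type mismatch concrete, here is a minimal standalone sketch of the shape arithmetic for the `@partial_fuse_by_collapsing` test added in this PR. It does not use any MLIR API: `Shape`, `Reassociation`, `kDyn`, and `collapseShape` are hypothetical names introduced only for illustration, and the "partial" reassociation `[[0], [1], [2], [3, 4]]` is read off the CHECK lines of the test rather than computed by the pattern.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sentinel for a dynamic extent ("?" in MLIR). This is an assumption of the
// sketch, not MLIR's ShapedType::kDynamic.
constexpr int64_t kDyn = -1;

using Shape = std::vector<int64_t>;
using Reassociation = std::vector<std::vector<int>>;

// Collapse `shape` by multiplying the extents inside each reassociation
// group. A group that contains a dynamic extent collapses to a dynamic one.
Shape collapseShape(const Shape &shape, const Reassociation &groups) {
  Shape result;
  for (const auto &group : groups) {
    int64_t extent = 1;
    for (int dim : group) {
      assert(dim >= 0 && dim < static_cast<int>(shape.size()));
      if (extent == kDyn || shape[dim] == kDyn)
        extent = kDyn;
      else
        extent *= shape[dim];
    }
    result.push_back(extent);
  }
  return result;
}

// Result type of the original linalg.generic: tensor<4x128x192x?x32xf32>.
Shape originalGenericResult() { return {4, 128, 192, kDyn, 32}; }

// Result type of the original tensor.collapse_shape, using the
// reassociation [[0, 1], [2], [3, 4]] from the test.
Shape originalCollapseResult() {
  return collapseShape(originalGenericResult(), {{0, 1}, {2}, {3, 4}});
}

// Result type of the collapsed linalg.generic when only the [3, 4] group is
// folded, as the CHECK lines of the test show.
Shape partiallyCollapsedGenericResult() {
  return collapseShape(originalGenericResult(), {{0}, {1}, {2}, {3, 4}});
}
```

In this model the collapsed generic produces `4x128x192x?`, while the original `collapse_shape` produces `512x192x?`, so the former cannot stand in for the latter. The `expand_shape` result, by contrast, restores `4x128x192x?x32`, the type of the original generic, which is why replacing the producer is always type-correct.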

---
Full diff: https://github.com/llvm/llvm-project/pull/136326.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (-17) 
- (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+28) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..f345cc5f3d172 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1907,23 +1907,6 @@ struct FoldReshapeWithGenericOpByCollapsing
           producer, "failed to do the fusion by collapsing transformation");
     }
 
-    if (!collapseResult) {
-      return rewriter.notifyMatchFailure(reshapeOp,
-                                         "fusion by expansion failed");
-    }
-
-    // Find the replacement for the reshape op. Since the replacements have the
-    // same type as the returns of the original generic op, the consumer reshape
-    // op can be replaced by the source of the expand_shape op that defines
-    // the replacement.
-    Value reshapeReplacement =
-        (collapseResult
-             ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
-    if (auto expandOp =
-            reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
-      reshapeReplacement = expandOp.getSrc();
-    }
-    rewriter.replaceOp(reshapeOp, reshapeReplacement);
     rewriter.replaceOp(producer, collapseResult->results);
     return success();
   }
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index dba53b4192cd5..2bf3d21c35526 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -830,3 +830,31 @@ func.func @fuse_by_collapsing_correct_insertion(%arg0 : tensor<?x?xf32>,
 // CHECK:     %[[EXPANDED:.+]] = tensor.expand_shape %[[OUT]]
 // CHECK:     %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
 // CHECK:      return %[[OUT]], %[[DIM]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4, d1, d2)>
+func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1: tensor<4x128x192x?x32xf32>) -> tensor<512x192x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x32x128x192xf16>) outs(%arg1 : tensor<4x128x192x?x32xf32>) {
+  ^bb0(%in: f16, %out: f32):
+    linalg.yield %out : f32
+  } -> tensor<4x128x192x?x32xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1], [2], [3, 4]] : tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+  return %collapsed : tensor<512x192x?xf32>
+}
+// CHECK-LABEL: func @partial_fuse_by_collapsing
+//  CHECK-SAME:  %[[ARG0:.+]]: tensor<4x?x32x128x192xf16>
+//  CHECK-SAME:  %[[ARG1:.+]]: tensor<4x128x192x?x32xf32>
+//   CHECK-DAG:   %[[COLLAPSED0:.+]] = tensor.collapse_shape %[[ARG0]]
+//  CHECK-SAME:     tensor<4x?x32x128x192xf16> into tensor<4x?x128x192xf16>
+//   CHECK-DAG:   %[[COLLAPSED1:.+]] = tensor.collapse_shape %[[ARG1]]
+//  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<4x128x192x?xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:     ins(%[[COLLAPSED0]]
+//  CHECK-SAME:     outs(%[[COLLAPSED1]]
+//       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[GENERIC]]
+//  CHECK-SAME:     tensor<4x128x192x?xf32> into tensor<4x128x192x?x32xf32>
+//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+//  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+//       CHECK:   return %[[COLLAPSED]] : tensor<512x192x?xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/136326

