[Mlir-commits] [mlir] 4652b69 - [mlir][linalg] Fix partial fuse by collapse (#136326)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 24 20:06:20 PDT 2025


Author: Ian Wood
Date: 2025-04-24T20:06:17-07:00
New Revision: 4652b69b0a512ff248b08cfa7ba9547860d1cc10

URL: https://github.com/llvm/llvm-project/commit/4652b69b0a512ff248b08cfa7ba9547860d1cc10
DIFF: https://github.com/llvm/llvm-project/commit/4652b69b0a512ff248b08cfa7ba9547860d1cc10.diff

LOG: [mlir][linalg] Fix partial fuse by collapse (#136326)

Similar to `FoldWithProducerReshapeOpByCollapsing`,
`FoldReshapeWithGenericOpByCollapsing` needs to be able to handle
partial fusion of a reshape by collapsing. In that case, the source of
the generated `expand_shape` op (i.e., the result of the collapsed linalg
op) might not have the same type as the result of the original
`collapse_shape` op, so it cannot be used directly to replace the
`collapse_shape` op. Instead, this change only replaces the original
linalg op with the results of the new `expand_shape` op. Those results
are guaranteed to have the same types as the original linalg op's
results.

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 62d016b87d627..8c8f8594b81af 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1907,23 +1907,6 @@ struct FoldReshapeWithGenericOpByCollapsing
           producer, "failed to do the fusion by collapsing transformation");
     }
 
-    if (!collapseResult) {
-      return rewriter.notifyMatchFailure(reshapeOp,
-                                         "fusion by expansion failed");
-    }
-
-    // Find the replacement for the reshape op. Since the replacements have the
-    // same type as the returns of the original generic op, the consumer reshape
-    // op can be replaced by the source of the expand_shape op that defines
-    // the replacement.
-    Value reshapeReplacement =
-        (collapseResult
-             ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
-    if (auto expandOp =
-            reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
-      reshapeReplacement = expandOp.getSrc();
-    }
-    rewriter.replaceOp(reshapeOp, reshapeReplacement);
     rewriter.replaceOp(producer, collapseResult->results);
     return success();
   }

diff  --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index dba53b4192cd5..2bf3d21c35526 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -830,3 +830,31 @@ func.func @fuse_by_collapsing_correct_insertion(%arg0 : tensor<?x?xf32>,
 // CHECK:     %[[EXPANDED:.+]] = tensor.expand_shape %[[OUT]]
 // CHECK:     %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
 // CHECK:      return %[[OUT]], %[[DIM]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4, d1, d2)>
+func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1: tensor<4x128x192x?x32xf32>) -> tensor<512x192x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x32x128x192xf16>) outs(%arg1 : tensor<4x128x192x?x32xf32>) {
+  ^bb0(%in: f16, %out: f32):
+    linalg.yield %out : f32
+  } -> tensor<4x128x192x?x32xf32>
+  %collapsed = tensor.collapse_shape %0 [[0, 1], [2], [3, 4]] : tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+  return %collapsed : tensor<512x192x?xf32>
+}
+// CHECK-LABEL: func @partial_fuse_by_collapsing
+//  CHECK-SAME:  %[[ARG0:.+]]: tensor<4x?x32x128x192xf16>
+//  CHECK-SAME:  %[[ARG1:.+]]: tensor<4x128x192x?x32xf32>
+//   CHECK-DAG:   %[[COLLAPSED0:.+]] = tensor.collapse_shape %[[ARG0]]
+//  CHECK-SAME:     tensor<4x?x32x128x192xf16> into tensor<4x?x128x192xf16>
+//   CHECK-DAG:   %[[COLLAPSED1:.+]] = tensor.collapse_shape %[[ARG1]]
+//  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<4x128x192x?xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:     ins(%[[COLLAPSED0]]
+//  CHECK-SAME:     outs(%[[COLLAPSED1]]
+//       CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[GENERIC]]
+//  CHECK-SAME:     tensor<4x128x192x?xf32> into tensor<4x128x192x?x32xf32>
+//       CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+//  CHECK-SAME:     tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+//       CHECK:   return %[[COLLAPSED]] : tensor<512x192x?xf32>


        


More information about the Mlir-commits mailing list