[Mlir-commits] [mlir] [mlir][linalg] Fix partial fuse by collapse (PR #136326)
llvmlistbot at llvm.org
Fri Apr 18 10:00:46 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Ian Wood (IanWood1)
<details>
<summary>Changes</summary>
Similar to `FoldWithProducerReshapeOpByCollapsing`, `FoldReshapeWithGenericOpByCollapsing` needs to handle partial fusion of a reshape by collapsing. In that case, the source of the generated `expand_shape` op (i.e., the result of the collapsed linalg op) might not have the same type as the source of the original `collapse_shape` op. This change instead replaces the original linalg op with the new `expand_shape` op, which is guaranteed to have a matching type.
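To illustrate the type mismatch (a condensed restatement of the new test case added in this PR; the elided generic bodies are placeholders):

```mlir
// Before: the generic's result is reshaped with [[0, 1], [2], [3, 4]].
%0 = linalg.generic ... -> tensor<4x128x192x?x32xf32>
%collapsed = tensor.collapse_shape %0 [[0, 1], [2], [3, 4]]
    : tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>

// After partial fusion: only the [3, 4] group could be collapsed into the
// op, so the collapsed generic yields tensor<4x128x192x?xf32> -- which does
// NOT match the source type of the original collapse_shape. The generated
// expand_shape restores that type, so it (not the generic's result) is the
// correct replacement for the producer.
%1 = linalg.generic ... -> tensor<4x128x192x?xf32>
%expanded = tensor.expand_shape %1 [[0], [1], [2], [3, 4]] output_shape [...]
    : tensor<4x128x192x?xf32> into tensor<4x128x192x?x32xf32>
```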
---
Full diff: https://github.com/llvm/llvm-project/pull/136326.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (-17)
- (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+28)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..f345cc5f3d172 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1907,23 +1907,6 @@ struct FoldReshapeWithGenericOpByCollapsing
producer, "failed to do the fusion by collapsing transformation");
}
- if (!collapseResult) {
- return rewriter.notifyMatchFailure(reshapeOp,
- "fusion by expansion failed");
- }
-
- // Find the replacement for the reshape op. Since the replacements have the
- // same type as the returns of the original generic op, the consumer reshape
- // op can be replaced by the source of the expand_shape op that defines
- // the replacement.
- Value reshapeReplacement =
- (collapseResult
- ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
- if (auto expandOp =
- reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
- reshapeReplacement = expandOp.getSrc();
- }
- rewriter.replaceOp(reshapeOp, reshapeReplacement);
rewriter.replaceOp(producer, collapseResult->results);
return success();
}
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index dba53b4192cd5..2bf3d21c35526 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -830,3 +830,31 @@ func.func @fuse_by_collapsing_correct_insertion(%arg0 : tensor<?x?xf32>,
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[OUT]]
// CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]]
// CHECK: return %[[OUT]], %[[DIM]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4, d1, d2)>
+func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1: tensor<4x128x192x?x32xf32>) -> tensor<512x192x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<4x?x32x128x192xf16>) outs(%arg1 : tensor<4x128x192x?x32xf32>) {
+ ^bb0(%in: f16, %out: f32):
+ linalg.yield %out : f32
+ } -> tensor<4x128x192x?x32xf32>
+ %collapsed = tensor.collapse_shape %0 [[0, 1], [2], [3, 4]] : tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+ return %collapsed : tensor<512x192x?xf32>
+}
+// CHECK-LABEL: func @partial_fuse_by_collapsing
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x32x128x192xf16>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<4x128x192x?x32xf32>
+// CHECK-DAG: %[[COLLAPSED0:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: tensor<4x?x32x128x192xf16> into tensor<4x?x128x192xf16>
+// CHECK-DAG: %[[COLLAPSED1:.+]] = tensor.collapse_shape %[[ARG1]]
+// CHECK-SAME: tensor<4x128x192x?x32xf32> into tensor<4x128x192x?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[COLLAPSED0]]
+// CHECK-SAME: outs(%[[COLLAPSED1]]
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[GENERIC]]
+// CHECK-SAME: tensor<4x128x192x?xf32> into tensor<4x128x192x?x32xf32>
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+// CHECK-SAME: tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<512x192x?xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/136326