[Mlir-commits] [mlir] [linalg] Add FoldReshapeWithGenericOpByCollapsing pattern (PR #131029)
Nirvedh Meshram
llvmlistbot at llvm.org
Wed Mar 12 13:58:00 PDT 2025
https://github.com/nirvedhmeshram created https://github.com/llvm/llvm-project/pull/131029
This pattern, which bubbles a `tensor.collapse_shape` up through its producer `linalg.generic`, was missing from `populateFoldReshapeOpsByCollapsingPatterns`.
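
For context, here is a minimal before/after sketch of the rewrite this pattern enables (simplified from the test cases added in this patch; the function name, element type, shapes, and identity indexing maps are purely illustrative):

```mlir
// Before: the collapse_shape consumes the result of the generic op.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @example(%arg0 : tensor<2x3x4xf32>) -> tensor<2x12xf32> {
  %init = tensor.empty() : tensor<2x3x4xf32>
  %0 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<2x3x4xf32>) outs(%init : tensor<2x3x4xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      linalg.yield %b0 : f32
  } -> tensor<2x3x4xf32>
  %1 = tensor.collapse_shape %0 [[0], [1, 2]] : tensor<2x3x4xf32> into tensor<2x12xf32>
  return %1 : tensor<2x12xf32>
}

// After: the collapse_shape ops are bubbled up to the operands and the
// generic op iterates over the collapsed 2x12 space.
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @example(%arg0 : tensor<2x3x4xf32>) -> tensor<2x12xf32> {
  %init = tensor.empty() : tensor<2x3x4xf32>
  %collapsed_arg0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor<2x3x4xf32> into tensor<2x12xf32>
  %collapsed_init = tensor.collapse_shape %init [[0], [1, 2]] : tensor<2x3x4xf32> into tensor<2x12xf32>
  %0 = linalg.generic {
      indexing_maps = [#map, #map],
      iterator_types = ["parallel", "parallel"]}
      ins(%collapsed_arg0 : tensor<2x12xf32>) outs(%collapsed_init : tensor<2x12xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      linalg.yield %b0 : f32
  } -> tensor<2x12xf32>
  return %0 : tensor<2x12xf32>
}
```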
From f44118b6984f0a755c105a2bb755a0c14272866f Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Thu, 6 Mar 2025 15:17:12 -0600
Subject: [PATCH] [linalg] Add FoldReshapeWithGenericOpByCollapsing pattern
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 78 +++++++++
.../fuse-with-reshape-by-collapsing.mlir | 160 ++++++++++++++++++
.../Linalg/TestLinalgElementwiseFusion.cpp | 6 +
3 files changed, 244 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 33667e7ab0c5c..117433f287281 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1848,6 +1848,7 @@ namespace {
class FoldWithProducerReshapeOpByCollapsing
: public OpRewritePattern<GenericOp> {
public:
+ // TODO: support fusion with all linalg producers, not just generic.
FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
ControlFusionFn foldReshapes,
PatternBenefit benefit = 1)
@@ -1887,6 +1888,81 @@ class FoldWithProducerReshapeOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
+/// Pattern to fold a tensor.collapse_shape op with its producer generic op
+/// by collapsing the dimensions of the loop in the producer op.
+struct FoldReshapeWithGenericOpByCollapsing
+ : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+ FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ // Fold only if all constraints of fusing with reshape by collapsing are
+ // met.
+ auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
+ if (!producerResult) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "source not produced by an operation");
+ }
+
+ // TODO: support fusion with all linalg producers, not just generic.
+ auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+ if (!producer) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "producer not a generic op");
+ }
+
+ SmallVector<ReassociationIndices> collapsableIterationDims =
+ getCollapsableIterationSpaceDims(
+ producer,
+ producer.getDpsInitOperand(producerResult.getResultNumber()),
+ reshapeOp.getReassociationIndices());
+ if (collapsableIterationDims.empty()) {
+ return rewriter.notifyMatchFailure(
+ reshapeOp, "failed preconditions of fusion with producer generic op");
+ }
+
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "fusion blocked by control function");
+ }
+
+ std::optional<CollapseResult> collapseResult =
+ collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
+ if (!collapseResult) {
+ return rewriter.notifyMatchFailure(
+ producer, "failed to do the fusion by collapsing transformation");
+ }
+
+ // Find the replacement for the reshape op. Since the replacements have the
+ // same type as the results of the original generic op, the consumer reshape
+ // op can be replaced by the source of the expand_shape op that defines
+ // the replacement.
+ Value reshapeReplacement =
+ (collapseResult
+ ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
+ if (auto expandOp =
+ reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
+ reshapeReplacement = expandOp.getSrc();
+ }
+ rewriter.replaceOp(reshapeOp, reshapeReplacement);
+ rewriter.replaceOp(producer, collapseResult->results);
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
class FoldPadWithProducerReshapeOpByCollapsing
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -2215,6 +2291,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
+ controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 89734e7542801..21178fd7e783f 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -638,3 +638,163 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
// CHECK: return %[[EXPAND]]
+
+// -----
+// Static problem sizes. Checks all aspects of fusion by collapsing when bubbling up collapse shapes.
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
+func.func @fuse_by_collapsing_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+ %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>) {
+ %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+ %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+ %generic:2 = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+ outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
+ %t0 = arith.addi %b0, %b1 : i32
+ %t1 = arith.addi %t0, %b2 : i32
+ linalg.yield %t1, %t1 : i32, i32
+ } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+ %collapse_1 = tensor.collapse_shape %generic#0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+ %collapse_2 = tensor.collapse_shape %generic#1 [[0, 1], [2], [3], [4], [5, 6, 7]] : tensor<3x4x2x9x5x6x7x8xi32> into tensor<12x2x9x5x336xi32>
+ return %collapse_1, %collapse_2 : tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
+// CHECK: func @fuse_by_collapsing_bubblecollapse(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32>
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+// CHECK-DAG: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
+// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-DAG: %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME: outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+// CHECK: return %[[COLLAPSED_OP]]#0, %[[COLLAPSED_OP]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_indexing_op_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+ %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x12x5x336x9xi32> {
+ %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+ %generic = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2, #map3],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+ outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+ %iv0 = linalg.index 0: index
+ %iv1 = linalg.index 1: index
+ %t0 = arith.addi %iv0, %iv1 : index
+ %iv2 = linalg.index 2 : index
+ %t1 = arith.addi %t0, %iv2 : index
+ %iv3 = linalg.index 3 : index
+ %t2 = arith.addi %t1, %iv3 : index
+ %iv4 = linalg.index 4 : index
+ %t3 = arith.addi %t2, %iv4 : index
+ %iv5 = linalg.index 5 : index
+ %t4 = arith.addi %t3, %iv5 : index
+ %iv6 = linalg.index 6 : index
+ %t5 = arith.addi %t4, %iv6 : index
+ %iv7 = linalg.index 7 : index
+ %t6 = arith.addi %t5, %iv7 : index
+ %yield = arith.index_cast %t6 : index to i32
+ linalg.yield %yield : i32
+ } -> tensor<2x3x4x5x6x7x8x9xi32>
+ %collapse = tensor.collapse_shape %generic [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+ return %collapse : tensor<2x12x5x336x9xi32>
+}
+// CHECK-LABEL: func @fuse_by_collapsing_indexing_op_bubblecollapse(
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[IV0:.+]] = linalg.index 0
+// CHECK: %[[IV1:.+]] = linalg.index 1
+// CHECK: %[[REM_IV1:.+]] = arith.remsi %[[IV1]], %[[C4]]
+// CHECK: %[[DIV_IV1:.+]] = arith.divsi %[[IV1]], %[[C4]]
+// CHECK: %[[IV2:.+]] = linalg.index 2
+// CHECK: %[[IV3:.+]] = linalg.index 3
+// CHECK: %[[REM1_IV3:.+]] = arith.remsi %[[IV3]], %[[C8]]
+// CHECK: %[[DIV1_IV3:.+]] = arith.divsi %[[IV3]], %[[C8]]
+// CHECK: %[[REM2_IV3:.+]] = arith.remsi %[[DIV1_IV3]], %[[C7]]
+// CHECK: %[[DIV2_IV3:.+]] = arith.divsi %[[DIV1_IV3]], %[[C7]]
+// CHECK: %[[IV4:.+]] = linalg.index 4
+// CHECK: %[[T0:.+]] = arith.addi %[[IV0]], %[[DIV_IV1]]
+// CHECK: %[[T1:.+]] = arith.addi %[[T0]], %[[REM_IV1]]
+// CHECK: %[[T2:.+]] = arith.addi %[[T1]], %[[IV2]]
+// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[DIV2_IV3]]
+// CHECK: %[[T4:.+]] = arith.addi %[[T3]], %[[REM2_IV3]]
+// CHECK: %[[T5:.+]] = arith.addi %[[T4]], %[[REM1_IV3]]
+// CHECK: %[[T6:.+]] = arith.addi %[[T5]], %[[IV4]]
+// CHECK: %[[YIELD:.+]] = arith.index_cast %[[T6]]
+// CHECK: linalg.yield %[[YIELD]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_change_reshape_order_bubblecollapse(%arg0 : tensor<9x7x8x2x3x4x5x6xi32>,
+ %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x60x6x56x9xi32> {
+ %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+ %generic = linalg.generic {
+ indexing_maps = [#map0, #map1, #map2, #map3],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+ ins(%arg0, %arg1, %arg2 : tensor<9x7x8x2x3x4x5x6xi32>, tensor<7x8x2xi32>, tensor<6x3x4x5xi32>)
+ outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+ ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+ %t0 = arith.addi %b0, %b1 : i32
+ %t1 = arith.addi %t0, %b2 : i32
+ linalg.yield %t1 : i32
+ } -> tensor<2x3x4x5x6x7x8x9xi32>
+ %collapse = tensor.collapse_shape %generic [[0], [1, 2, 3], [4], [5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x60x6x56x9xi32>
+ return %collapse : tensor<2x60x6x56x9xi32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3, d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
+// CHECK-DAG: %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2]{{\]}}
+// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}}
+// CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME: outs(%[[INIT_RESHAPE]] :
+// CHECK: return %[[COLLAPSED_OP]]
+
+// CONTROL: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+// CONTROL-SAME: %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+// CONTROL-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CONTROL-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+// CONTROL: %[[GENERIC:.+]] = linalg.generic
+// CONTROL-SAME: ins(%[[ARG0]],
+// CONTROL: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CONTROL: return %[[COLLAPSE]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index e4883e47f2063..34a396f18f90e 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -235,6 +235,12 @@ struct TestLinalgElementwiseFusion
// Skip fusing the first operand.
return fusedOperand->getOperandNumber();
}
+ Operation *consumer = fusedOperand->getOwner();
+ if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(consumer)) {
+ auto producerResult = dyn_cast<OpResult>(collapseOp.getSrc());
+ // Skip fusing the first result.
+ return producerResult.getResultNumber();
+ }
return true;
};
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);