[Mlir-commits] [mlir] [mlir][linalg] Add FoldReshapeWithGenericOpByCollapsing pattern (PR #131029)

Nirvedh Meshram llvmlistbot at llvm.org
Wed Mar 12 14:07:07 PDT 2025


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/131029

From da917f143a7866ecf54ed6169d7e060fafe10669 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Thu, 6 Mar 2025 15:17:12 -0600
Subject: [PATCH] [linalg] Add FoldReshapeWithGenericOpByCollapsing pattern

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
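
A rough sketch of the rewrite this pattern performs (a simplified,
hand-written example distilled from the new tests below, not output
copied verbatim from the pattern):

  #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  func.func @example(%arg0 : tensor<2x3x4xi32>) -> tensor<2x12xi32> {
    %init = tensor.empty() : tensor<2x3x4xi32>
    %0 = linalg.generic {
        indexing_maps = [#map, #map],
        iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%arg0 : tensor<2x3x4xi32>) outs(%init : tensor<2x3x4xi32>) {
      ^bb0(%b0 : i32, %b1 : i32):
        linalg.yield %b0 : i32
    } -> tensor<2x3x4xi32>
    %1 = tensor.collapse_shape %0 [[0], [1, 2]]
        : tensor<2x3x4xi32> into tensor<2x12xi32>
    return %1 : tensor<2x12xi32>
  }

After applying the pattern, roughly: the operands are collapsed, the
generic op iterates over the collapsed space, and the trailing
collapse_shape disappears.

  %arg0_c = tensor.collapse_shape %arg0 [[0], [1, 2]]
      : tensor<2x3x4xi32> into tensor<2x12xi32>
  %init_c = tensor.collapse_shape %init [[0], [1, 2]]
      : tensor<2x3x4xi32> into tensor<2x12xi32>
  // 2-D generic with identity maps and two parallel iterators.
  %1 = linalg.generic { ... } ins(%arg0_c : tensor<2x12xi32>)
      outs(%init_c : tensor<2x12xi32>) { ... } -> tensor<2x12xi32>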
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  73 +++++++++
 .../fuse-with-reshape-by-collapsing.mlir      | 160 ++++++++++++++++++
 .../Linalg/TestLinalgElementwiseFusion.cpp    |   6 +
 3 files changed, 239 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 33667e7ab0c5c..57fcafd6dbb56 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1848,6 +1848,7 @@ namespace {
 class FoldWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<GenericOp> {
 public:
+  // TODO: Support fusion with all Linalg ops, not just generic.
   FoldWithProducerReshapeOpByCollapsing(MLIRContext *context,
                                         ControlFusionFn foldReshapes,
                                         PatternBenefit benefit = 1)
@@ -1887,6 +1888,76 @@ class FoldWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.collapse_shape op with its producer generic op
+/// by collapsing the dimensionality of the loops in the producer op.
+struct FoldReshapeWithGenericOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+  FoldReshapeWithGenericOpByCollapsing(MLIRContext *context,
+                                       ControlFusionFn foldReshapes,
+                                       PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Fold only if all constraints of fusing with reshape by collapsing are
+    // met.
+    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
+    if (!producerResult) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "source not produced by an operation");
+    }
+
+    // TODO: Support fusion with all Linalg producers, not just generic.
+    auto producer = dyn_cast<GenericOp>(producerResult.getOwner());
+    if (!producer) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "producer not a generic op");
+    }
+
+    SmallVector<ReassociationIndices> collapsableIterationDims =
+        getCollapsableIterationSpaceDims(
+            producer,
+            producer.getDpsInitOperand(producerResult.getResultNumber()),
+            reshapeOp.getReassociationIndices());
+    if (collapsableIterationDims.empty()) {
+      return rewriter.notifyMatchFailure(
+          reshapeOp, "failed preconditions of fusion with producer generic op");
+    }
+
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(reshapeOp,
+                                         "fusion blocked by control function");
+    }
+
+    std::optional<CollapseResult> collapseResult =
+        collapseOpIterationDims(producer, collapsableIterationDims, rewriter);
+    if (!collapseResult) {
+      return rewriter.notifyMatchFailure(
+          producer, "failed to do the fusion by collapsing transformation");
+    }
+
+    // Find the replacement for the reshape op. The replacements of the
+    // generic op's results are expand_shape ops that restore the original
+    // result types, so the consumer collapse_shape op can be replaced
+    // directly by the source of the corresponding expand_shape op.
+    Value reshapeReplacement =
+        (collapseResult
+             ->results)[cast<OpResult>(reshapeOp.getSrc()).getResultNumber()];
+    if (auto expandOp =
+            reshapeReplacement.getDefiningOp<tensor::ExpandShapeOp>()) {
+      reshapeReplacement = expandOp.getSrc();
+    }
+    rewriter.replaceOp(reshapeOp, reshapeReplacement);
+    rewriter.replaceOp(producer, collapseResult->results);
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 class FoldPadWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -2215,6 +2291,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                       controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
+                                                     controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 89734e7542801..21178fd7e783f 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -638,3 +638,163 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
 //      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
 // CHECK-SAME:       output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 //      CHECK:   return %[[EXPAND]]
+
+// -----
+// Static problem sizes. Checks all aspects of fusion by collapsing when bubbling up collapse_shape ops.
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d0, d7, d3, d4, d5, d6)>
+func.func @fuse_by_collapsing_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> (tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>) {
+  %init_0 = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %init_1 = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+  %generic:2 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+    ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+    outs(%init_0, %init_1 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, %b4 : i32):
+        %t0 = arith.addi %b0, %b1 : i32
+        %t1 = arith.addi %t0, %b2 : i32
+        linalg.yield %t1, %t1 : i32, i32
+    } -> (tensor<2x3x4x5x6x7x8x9xi32>, tensor<3x4x2x9x5x6x7x8xi32>)
+  %collapse_1 =  tensor.collapse_shape %generic#0 [[0], [1, 2], [3], [4, 5, 6], [7]] :  tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  %collapse_2 =  tensor.collapse_shape %generic#1 [[0, 1], [2], [3], [4], [5, 6, 7]] : tensor<3x4x2x9x5x6x7x8xi32> into tensor<12x2x9x5x336xi32>
+  return %collapse_1, %collapse_2 : tensor<2x12x5x336x9xi32>, tensor<12x2x9x5x336xi32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d4, d2, d3)>
+//      CHECK: func @fuse_by_collapsing_bubblecollapse(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<2x3x4xi32>
+// CHECK-SAME:   %[[ARG2:.+]]: tensor<5x6x7x8xi32>
+//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+//  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<3x4x2x9x5x6x7x8xi32>
+//  CHECK-DAG:   %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+//  CHECK-DAG:   %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}}
+//  CHECK-DAG:   %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+//  CHECK-DAG:   %[[INIT0_RESHAPE:.+]] = tensor.collapse_shape %[[INIT0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+//  CHECK-DAG:   %[[INIT1_RESHAPE:.+]] = tensor.collapse_shape %[[INIT1]] {{\[}}[0, 1], [2], [3], [4], [5, 6, 7]{{\]}}
+//      CHECK:   %[[COLLAPSED_OP:.+]]:2 = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]], #[[MAP3]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME:       outs(%[[INIT0_RESHAPE]], %[[INIT1_RESHAPE]] :
+//      CHECK:   return %[[COLLAPSED_OP]]#0, %[[COLLAPSED_OP]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_indexing_op_bubblecollapse(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>,
+    %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x12x5x336x9xi32> {
+  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+    ins(%arg0, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>)
+    outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+        %iv0 = linalg.index 0: index
+        %iv1 = linalg.index 1: index
+        %t0 = arith.addi %iv0, %iv1 : index
+        %iv2 = linalg.index 2 : index
+        %t1 = arith.addi %t0, %iv2 : index
+        %iv3 = linalg.index 3 : index
+        %t2 = arith.addi %t1, %iv3 : index
+        %iv4 = linalg.index 4 : index
+        %t3 = arith.addi %t2, %iv4 : index
+        %iv5 = linalg.index 5 : index
+        %t4 = arith.addi %t3, %iv5 : index
+        %iv6 = linalg.index 6 : index
+        %t5 = arith.addi %t4, %iv6 : index
+        %iv7 = linalg.index 7 : index
+        %t6 = arith.addi %t5, %iv7 : index
+        %yield = arith.index_cast %t6 : index to i32
+        linalg.yield %yield : i32
+    } -> tensor<2x3x4x5x6x7x8x9xi32>
+  %collapse = tensor.collapse_shape %generic [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
+  return %collapse : tensor<2x12x5x336x9xi32>
+}
+// CHECK-LABEL: func @fuse_by_collapsing_indexing_op_bubblecollapse(
+//   CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+//   CHECK-DAG:   %[[C7:.+]] = arith.constant 7 : index
+//       CHECK:     %[[IV0:.+]] = linalg.index 0
+//       CHECK:     %[[IV1:.+]] = linalg.index 1
+//       CHECK:     %[[REM_IV1:.+]] = arith.remsi %[[IV1]], %[[C4]]
+//       CHECK:     %[[DIV_IV1:.+]] = arith.divsi %[[IV1]], %[[C4]]
+//       CHECK:     %[[IV2:.+]] = linalg.index 2
+//       CHECK:     %[[IV3:.+]] = linalg.index 3
+//       CHECK:     %[[REM1_IV3:.+]] = arith.remsi %[[IV3]], %[[C8]]
+//       CHECK:     %[[DIV1_IV3:.+]] = arith.divsi %[[IV3]], %[[C8]]
+//       CHECK:     %[[REM2_IV3:.+]] = arith.remsi %[[DIV1_IV3]], %[[C7]]
+//       CHECK:     %[[DIV2_IV3:.+]] = arith.divsi %[[DIV1_IV3]], %[[C7]]
+//       CHECK:     %[[IV4:.+]] = linalg.index 4
+//       CHECK:     %[[T0:.+]] = arith.addi %[[IV0]], %[[DIV_IV1]]
+//       CHECK:     %[[T1:.+]] = arith.addi %[[T0]], %[[REM_IV1]]
+//       CHECK:     %[[T2:.+]] = arith.addi %[[T1]], %[[IV2]]
+//       CHECK:     %[[T3:.+]] = arith.addi %[[T2]], %[[DIV2_IV3]]
+//       CHECK:     %[[T4:.+]] = arith.addi %[[T3]], %[[REM2_IV3]]
+//       CHECK:     %[[T5:.+]] = arith.addi %[[T4]], %[[REM1_IV3]]
+//       CHECK:     %[[T6:.+]] = arith.addi %[[T5]], %[[IV4]]
+//       CHECK:     %[[YIELD:.+]] = arith.index_cast %[[T6]]
+//       CHECK:     linalg.yield %[[YIELD]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
+func.func @fuse_by_collapsing_change_reshape_order_bubblecollapse(%arg0 : tensor<9x7x8x2x3x4x5x6xi32>,
+    %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x60x6x56x9xi32> {
+  %init = tensor.empty() : tensor<2x3x4x5x6x7x8x9xi32>
+  %generic = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2, #map3],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+    ins(%arg0, %arg1, %arg2 : tensor<9x7x8x2x3x4x5x6xi32>, tensor<7x8x2xi32>, tensor<6x3x4x5xi32>)
+    outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) {
+      ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+        %t0 = arith.addi %b0, %b1 : i32
+        %t1 = arith.addi %t0, %b2 : i32
+        linalg.yield %t1 : i32
+    } -> tensor<2x3x4x5x6x7x8x9xi32>
+  %collapse = tensor.collapse_shape %generic [[0], [1, 2, 3], [4], [5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x60x6x56x9xi32>
+  return %collapse : tensor<2x60x6x56x9xi32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3, d0, d1, d2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//      CHECK: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CHECK-SAME:   %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty()
+//  CHECK-DAG:   %[[ARG0_RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}}
+//  CHECK-DAG:   %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2]{{\]}}
+//  CHECK-DAG:   %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}}
+//  CHECK-DAG:   %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}}
+//      CHECK:   %[[COLLAPSED_OP:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME:       ins(%[[ARG0_RESHAPE]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] :
+// CHECK-SAME:       outs(%[[INIT_RESHAPE]] :
+//      CHECK:   return %[[COLLAPSED_OP]]
+
+//      CONTROL: func @fuse_by_collapsing_change_reshape_order_bubblecollapse(
+// CONTROL-SAME:   %[[ARG0:.+]]: tensor<9x7x8x2x3x4x5x6xi32>
+// CONTROL-SAME:   %[[ARG1:.+]]: tensor<7x8x2xi32>
+// CONTROL-SAME:   %[[ARG2:.+]]: tensor<6x3x4x5xi32>
+//      CONTROL:   %[[GENERIC:.+]] = linalg.generic
+// CONTROL-SAME:       ins(%[[ARG0]],
+//      CONTROL:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+//      CONTROL:   return %[[COLLAPSE]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index e4883e47f2063..34a396f18f90e 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -235,6 +235,12 @@ struct TestLinalgElementwiseFusion
           // Skip fusing the first operand.
           return fusedOperand->getOperandNumber();
         }
+        Operation *consumer = fusedOperand->getOwner();
+        if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(consumer)) {
+          auto producerResult = dyn_cast<OpResult>(collapseOp.getSrc());
+          // skip fusing first result.
+          return producerResult.getResultNumber();
+        }
         return true;
       };
       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
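
For downstream users, a minimal sketch of how the collapsing fusion patterns
populated above are typically applied. The function name collapseReshapes and
the overall structure are illustrative, not part of this patch; the driver
entry point may appear as applyPatternsAndFoldGreedily in slightly older trees.

  #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  // Collapse reshape ops into adjacent linalg.generic ops on `root`.
  static LogicalResult collapseReshapes(Operation *root) {
    RewritePatternSet patterns(root->getContext());
    // Fuse unconditionally; a real control function would inspect the fused
    // operand, e.g. to skip dynamic shapes or particular producers.
    linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
      return true;
    };
    linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
    return applyPatternsGreedily(root, std::move(patterns));
  }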


