[Mlir-commits] [mlir] feat(linalg): add a way to pass controlFn to `foldIntoPackUnpackPatterns` (PR #143685)

Ege Beysel llvmlistbot at llvm.org
Sun Jun 29 13:58:49 PDT 2025


https://github.com/egebeysel updated https://github.com/llvm/llvm-project/pull/143685

>From 462c173f011ccfdce03752181a41824969d44e5a Mon Sep 17 00:00:00 2001
From: Ege Beysel <beyselege at gmail.com>
Date: Wed, 11 Jun 2025 11:06:38 +0200
Subject: [PATCH 1/2] feat(linalg): add a way to pass controlFn to
 `foldIntoPackUnpackPatterns` (#22)

This PR adds a mechanism that lets downstream consumers pass a control function to gate the application of these patterns. Consumers that do not specify a `controlFn` are unaffected: the default allows every fold, exactly as before.

In IREE, we will use it to prevent folds that would inhibit fusion. See IREE issue #20896 for more details.
---
 .../Dialect/Linalg/Transforms/Transforms.h    | 10 ++-
 .../Transforms/PackAndUnpackPatterns.cpp      | 83 +++++++++++++++++--
 2 files changed, 85 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..2f0e57ca9f5a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1894,10 +1894,18 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+/// Function type which is used to control folding operations like `tensor.pad`
+/// and `tensor.extract_slice` into linalg.pack/unpack ops.
+using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
+inline bool defaultControlFoldIntoPackUnpackFn(OpOperand *opOperand) {
+  return true;
+};
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
-void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
+void populateFoldIntoPackAndUnpackPatterns(
+    RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn =
+                                     defaultControlFoldIntoPackUnpackFn);
 
 /// Populates `patterns` with patterns that fold operations like `linalg.pack`
 /// and `linalg.unpack` into `tensor.empty`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..01cebb0f8e80d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
-  using OpRewritePattern<PackOp>::OpRewritePattern;
+public:
+  FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
@@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
     if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&packOp.getSourceMutable()))
+      return failure();
+
     Value constantPaddingValue = padOp.getConstantPaddingValue();
     if (!constantPaddingValue)
       return failure();
@@ -220,13 +227,20 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
         packOp.getOuterDimsPerm());
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
 /// has extract_slice semantics.
 struct FoldUnpackWithExtractSliceOp
     : public OpRewritePattern<tensor::ExtractSliceOp> {
-  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+public:
+  FoldUnpackWithExtractSliceOp(MLIRContext *context,
+                               ControlFoldIntoPackUnpackFn controlFn)
+      : OpRewritePattern<tensor::ExtractSliceOp>(context),
+        controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
@@ -234,6 +248,10 @@ struct FoldUnpackWithExtractSliceOp
     if (!unpackOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&sliceOp.getSourceMutable()))
+      return failure();
+
     if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
       return rewriter.notifyMatchFailure(
           sliceOp, "rank-reduced folding is not supported");
@@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
         unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 // Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 /// semantics.
 struct FoldProducerPackWithConsumerLinalgTransposeOp
     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+public:
+  FoldProducerPackWithConsumerLinalgTransposeOp(
+      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+        controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
@@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!packOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&linalgOp->getOpOperand(0)))
+      return failure();
+
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
     if (failed(maybePerm))
@@ -331,13 +361,20 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
 
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
 /// semantics.
 struct FoldConsumerPackWithProducerLinalgTransposeOp
     : public OpRewritePattern<PackOp> {
-  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+public:
+  FoldConsumerPackWithProducerLinalgTransposeOp(
+      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
@@ -345,6 +382,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
     if (!linalgOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&packOp.getSourceMutable()))
+      return failure();
+
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
     if (failed(maybePerm))
@@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+public:
+  FoldProducerUnPackWithConsumerLinalgTransposeOp(
+      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+        controlFn(std::move(controlFn)) {}
 
   LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
@@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
     if (!unPackOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&linalgOp->getOpOperand(0)))
+      return failure();
+
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
     if (failed(maybePerm))
@@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
 
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     : public OpRewritePattern<UnPackOp> {
   using OpRewritePattern<UnPackOp>::OpRewritePattern;
 
+public:
+  FoldConsumerUnPackWithProducerLinalgTransposeOp(
+      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+      : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
+
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
     auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
     if (!linalgOp)
       return failure();
 
+    // User controlled folding function.
+    if (!controlFn(&unPackOp.getSourceMutable()))
+      return failure();
+
     FailureOr<SmallVector<int64_t>> maybePerm =
         getTransposeOpPermutation(linalgOp);
     if (failed(maybePerm))
@@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
     return success();
   }
+
+private:
+  ControlFoldIntoPackUnpackFn controlFn;
 };
 
 /// tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +589,14 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
 
 } // namespace
 
-void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
+void populateFoldIntoPackAndUnpackPatterns(
+    RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
   patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                   FoldProducerPackWithConsumerLinalgTransposeOp,
                   FoldConsumerPackWithProducerLinalgTransposeOp,
                   FoldConsumerUnPackWithProducerLinalgTransposeOp,
                   FoldProducerUnPackWithConsumerLinalgTransposeOp>(
-      patterns.getContext());
+      patterns.getContext(), controlFn);
 }
 
 void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {

>From b1c2a1e4fbd9630ffec7b190751f82d34023d9d5 Mon Sep 17 00:00:00 2001
From: Ege Beysel <beyselege at gmail.com>
Date: Sun, 29 Jun 2025 20:57:55 +0000
Subject: [PATCH 2/2] add tests and address comments

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 +-
 .../Transforms/PackAndUnpackPatterns.cpp      | 12 +--
 .../Tensor/fold-into-pack-and-unpack.mlir     | 97 ++++++++++++++++++-
 .../Dialect/Linalg/TestLinalgTransforms.cpp   | 24 ++++-
 4 files changed, 126 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2f0e57ca9f5a7..34bef0e56e1e4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1897,15 +1897,12 @@ void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 /// Function type which is used to control folding operations like `tensor.pad`
 /// and `tensor.extract_slice` into linalg.pack/unpack ops.
 using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
-inline bool defaultControlFoldIntoPackUnpackFn(OpOperand *opOperand) {
-  return true;
-};
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
 void populateFoldIntoPackAndUnpackPatterns(
-    RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn =
-                                     defaultControlFoldIntoPackUnpackFn);
+    RewritePatternSet &patterns,
+    const ControlFoldIntoPackUnpackFn &controlFn = nullptr);
 
 /// Populates `patterns` with patterns that fold operations like `linalg.pack`
 /// and `linalg.unpack` into `tensor.empty`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 01cebb0f8e80d..9d8f90d991720 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -210,7 +210,7 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
       return failure();
 
     // User controlled folding function.
-    if (!controlFn(&packOp.getSourceMutable()))
+    if (controlFn && !controlFn(&packOp.getSourceMutable()))
       return failure();
 
     Value constantPaddingValue = padOp.getConstantPaddingValue();
@@ -249,7 +249,7 @@ struct FoldUnpackWithExtractSliceOp
       return failure();
 
     // User controlled folding function.
-    if (!controlFn(&sliceOp.getSourceMutable()))
+    if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
       return failure();
 
     if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
@@ -320,7 +320,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
       return failure();
 
     // User controlled folding function.
-    if (!controlFn(&linalgOp->getOpOperand(0)))
+    if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
       return failure();
 
     FailureOr<SmallVector<int64_t>> maybePerm =
@@ -383,7 +383,7 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
       return failure();
 
     // User controlled folding function.
-    if (!controlFn(&packOp.getSourceMutable()))
+    if (controlFn && !controlFn(&packOp.getSourceMutable()))
       return failure();
 
     FailureOr<SmallVector<int64_t>> maybePerm =
@@ -440,7 +440,7 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
       return failure();
 
     // User controlled folding function.
-    if (!controlFn(&linalgOp->getOpOperand(0)))
+    if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
       return failure();
 
     FailureOr<SmallVector<int64_t>> maybePerm =
@@ -492,7 +492,7 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
       return failure();
 
     // User controlled folding function.
-    if (!controlFn(&unPackOp.getSourceMutable()))
+    if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
       return failure();
 
     FailureOr<SmallVector<int64_t>> maybePerm =
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index 84eb60248b8be..e749ffbe0bd8f 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack  %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control  %s | FileCheck %s --check-prefix=CONTROL
 
 func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
@@ -373,6 +374,51 @@ func.func @linalg_transpose_linalg.pack_fold(%arg0: tensor<56x57x1x64xf32>) -> t
 
 // -----
 
+func.func @linalg_transpose_linalg.pack_fold_multi_result(%arg0: tensor<56x57x1x64xf32>) -> (tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>) {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+
+  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
+  %pack = linalg.pack %transposed
+    outer_dims_perm = [0, 2, 1, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
+  return %transposed, %pack : tensor<1x56x57x64xf32>, tensor<1x57x56x2x32xf32>
+}
+// CHECK-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//       CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x56x57x64xf32>
+//       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[INIT]]
+//  CHECK-SAME:      permutation = [2, 0, 1, 3]
+//       CHECK:   %[[INIT1:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
+//       CHECK:   %[[PACK:.+]] = linalg.pack %[[ARG0]]
+//  CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
+//  CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32]
+//  CHECK-SAME:       into %[[INIT1]]
+//       CHECK:   return %[[TRANSPOSE]], %[[PACK]]
+
+// CONTROL-LABEL: func @linalg_transpose_linalg.pack_fold_multi_result(
+//  CONTROL-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//       CONTROL:   %[[INIT:.+]] = tensor.empty() : tensor<1x56x57x64xf32>
+//       CONTROL:   %[[TRANSPOSE:.+]] = linalg.transpose
+//  CONTROL-SAME:      ins(%[[ARG0]]
+//  CONTROL-SAME:      outs(%[[INIT]]
+//  CONTROL-SAME:      permutation = [2, 0, 1, 3]
+//       CONTROL:   %[[INIT1:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
+//       CONTROL:   %[[PACK:.+]] = linalg.pack %[[TRANSPOSE]]
+//  CONTROL-SAME:      outer_dims_perm = [0, 2, 1, 3]
+//  CONTROL-SAME:      inner_dims_pos = [3] inner_tiles = [32]
+//  CONTROL-SAME:       into %[[INIT1]]
+//       CONTROL:   return %[[TRANSPOSE]], %[[PACK]]
+
+// -----
+
 func.func @linalg_transpose_linalg.pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
   %0 = tensor.empty() : tensor<1x56x57x55xf32>
   %transpose = linalg.transpose
@@ -550,6 +596,55 @@ func.func @linalg_transpose_linalg.unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
 
 // -----
 
+func.func @linalg_transpose_linalg.unpack_fold_multi_result(%arg0: tensor<1x1x4x16xi32>) -> (tensor<1x1x16x4xi32>, tensor<16x4xi32>) {
+  %0 = tensor.empty() : tensor<1x1x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
+                outs(%0 : tensor<1x1x16x4xi32>)
+                permutation = [1, 0, 3, 2]
+  %1 = tensor.empty() : tensor<16x4xi32>
+  %unpack = linalg.unpack %transposed
+            outer_dims_perm = [0, 1]
+            inner_dims_pos = [0, 1]
+            inner_tiles = [16, 4] into
+            %1 : tensor<1x1x16x4xi32> -> tensor<16x4xi32>
+  return %transposed, %unpack : tensor<1x1x16x4xi32>, tensor<16x4xi32>
+}
+// CHECK-LABEL:  func.func @linalg_transpose_linalg.unpack_fold_multi_result(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x4x16xi32>)
+// CHECK-SAME:   -> (tensor<1x1x16x4xi32>, tensor<16x4xi32>)
+//      CHECK:     %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x16x4xi32>
+//      CHECK:     %[[TRANSPOSE:.+]] = linalg.transpose
+// CHECK-SAME:        ins(%[[ARG0]] : tensor<1x1x4x16xi32>)
+// CHECK-SAME:        outs(%[[EMPTY]] : tensor<1x1x16x4xi32>)
+// CHECK-SAME:        permutation = [1, 0, 3, 2]
+//      CHECK:     %[[OUT:.+]] = tensor.empty() : tensor<16x4xi32>
+//      CHECK:     %[[UNPACK:.+]] = linalg.unpack %[[ARG0]]
+// CHECK-SAME:        outer_dims_perm = [1, 0]
+// CHECK-SAME:        inner_dims_pos = [1, 0]
+// CHECK-SAME:        inner_tiles = [4, 16]
+// CHECK-SAME:        into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<16x4xi32>
+//      CHECK:     return %[[TRANSPOSE]], %[[UNPACK]] : tensor<1x1x16x4xi32>, tensor<16x4xi32>
+//      CHECK:   }

+// CONTROL-LABEL:  func.func @linalg_transpose_linalg.unpack_fold_multi_result(
+// CONTROL-SAME:   %[[ARG0:.+]]: tensor<1x1x4x16xi32>)
+// CONTROL-SAME:   -> (tensor<1x1x16x4xi32>, tensor<16x4xi32>)
+//      CONTROL:     %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x16x4xi32>
+//      CONTROL:     %[[TRANSPOSE:.+]] = linalg.transpose
+// CONTROL-SAME:        ins(%[[ARG0]] : tensor<1x1x4x16xi32>)
+// CONTROL-SAME:        outs(%[[EMPTY]] : tensor<1x1x16x4xi32>)
+// CONTROL-SAME:        permutation = [1, 0, 3, 2]
+//      CONTROL:     %[[OUT:.+]] = tensor.empty() : tensor<16x4xi32>
+//      CONTROL:     %[[UNPACK:.+]] = linalg.unpack %[[TRANSPOSE]]
+// CONTROL-SAME:        outer_dims_perm = [0, 1]
+// CONTROL-SAME:        inner_dims_pos = [0, 1]
+// CONTROL-SAME:        inner_tiles = [16, 4]
+// CONTROL-SAME:        into %[[OUT]] : tensor<1x1x16x4xi32> -> tensor<16x4xi32>
+//      CONTROL:     return %[[TRANSPOSE]], %[[UNPACK]] : tensor<1x1x16x4xi32>, tensor<16x4xi32>
+//      CONTROL:   }
+
+// -----
+
 func.func @linalg_transpose_linalg.unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
   %0 = tensor.empty() : tensor<1x1x16x4xi32>
   %transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
@@ -797,4 +892,4 @@ func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> ten
 //  CHECK-SAME:        inner_tiles = [1, 64]
 //  CHECK-SAME:        into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
 //       CHECK:       return %[[UNPACK]] : tensor<3648x3x56xf32>
-//       CHECK:    }
+//       CHECK:    }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 046b9a65f3359..5612ed2f40d12 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -135,6 +135,11 @@ struct TestLinalgTransforms
       *this, "test-fold-into-pack-and-unpack",
       llvm::cl::desc("Test folding ops into linalg.pack and linalg.unpack"),
       llvm::cl::init(false)};
+  Option<bool> testFoldIntoPackAndUnpackWithControlFn{
+      *this, "test-fold-into-pack-and-unpack-control",
+      llvm::cl::desc(
+          "Test controlling folding ops into linalg.pack and linalg.unpack"),
+      llvm::cl::init(false)};
   Option<bool> testSimplifyPackUnpackPatterns{
       *this, "test-simplify-pack-unpack-patterns",
       llvm::cl::desc("Test patterns to simplify linalg.pack and linalg.unpack"),
@@ -236,9 +241,11 @@ static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
   (void)applyPatternsGreedily(funcOp, std::move(patterns));
 }
 
-static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
+static void applyFoldIntoPackAndUnpackPatterns(
+    Operation *rootOp,
+    linalg::ControlFoldIntoPackUnpackFn controlFn = nullptr) {
   RewritePatternSet patterns(rootOp->getContext());
-  linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
+  linalg::populateFoldIntoPackAndUnpackPatterns(patterns, controlFn);
   (void)applyPatternsGreedily(rootOp, std::move(patterns));
 }
 
@@ -279,6 +286,19 @@ void TestLinalgTransforms::runOnOperation() {
   Operation *rootOp = getOperation();
   if (testFoldIntoPackAndUnpack)
     applyFoldIntoPackAndUnpackPatterns(rootOp);
+  if (testFoldIntoPackAndUnpackWithControlFn) {
+    linalg::ControlFoldIntoPackUnpackFn controlFn = [](OpOperand *opOperand) {
+      Operation *producer = opOperand->get().getDefiningOp();
+      Operation *consumer = opOperand->getOwner();
+      // If we have a pack/unpack consumer and a producer that has multiple
+      // uses, do not apply the folding patterns. The producer is null when
+      // the operand is a block argument, hence `isa_and_nonnull`.
+      return !(isa<linalg::PackOp, linalg::UnPackOp>(consumer) &&
+               isa_and_nonnull<TilingInterface>(producer) &&
+               !producer->hasOneUse());
+    };
+    applyFoldIntoPackAndUnpackPatterns(rootOp, controlFn);
+  }
   if (testSimplifyPackUnpackPatterns)
     applySimplifyPackUnpackPatterns(rootOp);
 }



More information about the Mlir-commits mailing list