[Mlir-commits] [mlir] feat(linalg): add a way to pass controlFn to `foldIntoPackUnpackPatterns` (PR #143685)
Ege Beysel
llvmlistbot at llvm.org
Wed Jun 11 04:11:22 PDT 2025
https://github.com/egebeysel created https://github.com/llvm/llvm-project/pull/143685
This PR adds a mechanism that lets downstream consumers pass in control functions governing the application of these patterns. This change shouldn't affect any consumers of this method that do not specify a `controlFn`. In each pattern, the `controlFn` receives the source operand of the consumer op as its parameter; returning `false` prevents the fold.
In IREE, we plan to use it to prevent folds that would inhibit fusion. See IREE issue [#20896](https://github.com/iree-org/iree/issues/20896) for more details.
>From 462c173f011ccfdce03752181a41824969d44e5a Mon Sep 17 00:00:00 2001
From: Ege Beysel <beyselege at gmail.com>
Date: Wed, 11 Jun 2025 11:06:38 +0200
Subject: [PATCH] feat(linalg): add a way to pass controlFn to
`foldIntoPackUnpackPatterns` (#22)
This PR adds a mechanism that lets downstream consumers pass in control functions governing the application of these patterns. This change shouldn't affect any consumers of this method that do not specify a `controlFn`.
In IREE, we plan to use it to prevent folds that would inhibit fusion. See IREE issue #20896 for more details.
---
.../Dialect/Linalg/Transforms/Transforms.h | 10 ++-
.../Transforms/PackAndUnpackPatterns.cpp | 83 +++++++++++++++++--
2 files changed, 85 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..2f0e57ca9f5a7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1894,10 +1894,22 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+/// Function type which is used to control folding operations like `tensor.pad`
+/// and `tensor.extract_slice` into linalg.pack/unpack ops. The function is
+/// called with the source operand of the consumer op of each candidate fold;
+/// returning `false` prevents that fold.
+using ControlFoldIntoPackUnpackFn = std::function<bool(OpOperand *opOperand)>;
+/// Default control function: allows every fold.
+inline bool defaultControlFoldIntoPackUnpackFn(OpOperand * /*opOperand*/) {
+ return true;
+}
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
-/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
-/// respectively.
-void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
+/// and `tensor.extract_slice` into `linalg.pack` and `linalg.unpack` operations
+/// respectively. `controlFn` is queried before each fold and can veto it by
+/// returning `false`.
+void populateFoldIntoPackAndUnpackPatterns(
+ RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn =
+ defaultControlFoldIntoPackUnpackFn);
/// Populates `patterns` with patterns that fold operations like `linalg.pack`
/// and `linalg.unpack` into `tensor.empty`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..01cebb0f8e80d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -197,7 +198,9 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
- using OpRewritePattern<PackOp>::OpRewritePattern;
+public:
+ FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
@@ -206,6 +209,10 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
return failure();
+ // User controlled folding function.
+ if (!controlFn(&packOp.getSourceMutable()))
+ return failure();
+
Value constantPaddingValue = padOp.getConstantPaddingValue();
if (!constantPaddingValue)
return failure();
@@ -220,13 +227,20 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
packOp.getOuterDimsPerm());
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
struct FoldUnpackWithExtractSliceOp
: public OpRewritePattern<tensor::ExtractSliceOp> {
- using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+public:
+ FoldUnpackWithExtractSliceOp(MLIRContext *context,
+ ControlFoldIntoPackUnpackFn controlFn)
+ : OpRewritePattern<tensor::ExtractSliceOp>(context),
+ controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
@@ -234,6 +248,10 @@ struct FoldUnpackWithExtractSliceOp
if (!unpackOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&sliceOp.getSourceMutable()))
+ return failure();
+
if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
return rewriter.notifyMatchFailure(
sliceOp, "rank-reduced folding is not supported");
@@ -255,6 +273,9 @@ struct FoldUnpackWithExtractSliceOp
unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
// Applies 'permutation' on 'inVec' and stores the result in resVec.
@@ -284,7 +305,12 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
/// semantics.
struct FoldProducerPackWithConsumerLinalgTransposeOp
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+public:
+ FoldProducerPackWithConsumerLinalgTransposeOp(
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+ controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
@@ -293,6 +319,10 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
if (!packOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&linalgOp->getOpOperand(0)))
+ return failure();
+
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
@@ -331,13 +361,20 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldConsumerPackWithProducerLinalgTransposeOp
: public OpRewritePattern<PackOp> {
- using OpRewritePattern<PackOp>::OpRewritePattern;
+
+public:
+ FoldConsumerPackWithProducerLinalgTransposeOp(
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(PackOp packOp,
PatternRewriter &rewriter) const override {
@@ -345,6 +382,10 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
if (!linalgOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&packOp.getSourceMutable()))
+ return failure();
+
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
@@ -375,13 +416,21 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+
+public:
+ FoldProducerUnPackWithConsumerLinalgTransposeOp(
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
+ controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
@@ -390,6 +439,10 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
if (!unPackOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&linalgOp->getOpOperand(0)))
+ return failure();
+
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
@@ -416,6 +469,9 @@ struct FoldProducerUnPackWithConsumerLinalgTransposeOp
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
@@ -424,12 +480,21 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
: public OpRewritePattern<UnPackOp> {
- using OpRewritePattern<UnPackOp>::OpRewritePattern;
+public:
+ // Inherited constructors are intentionally not exposed so `controlFn` is always set.
+ FoldConsumerUnPackWithProducerLinalgTransposeOp(
+ MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
+ : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
+
LogicalResult matchAndRewrite(UnPackOp unPackOp,
PatternRewriter &rewriter) const override {
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
if (!linalgOp)
return failure();
+ // User controlled folding function.
+ if (!controlFn(&unPackOp.getSourceMutable()))
+ return failure();
+
FailureOr<SmallVector<int64_t>> maybePerm =
getTransposeOpPermutation(linalgOp);
if (failed(maybePerm))
@@ -474,6 +539,9 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
return success();
}
+
+private:
+ ControlFoldIntoPackUnpackFn controlFn;
};
/// tensor.empty does not define any tensor contents, so an unpadded pack
@@ -521,13 +589,19 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
} // namespace
-void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
+void populateFoldIntoPackAndUnpackPatterns(
+ RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
+ // The patterns invoke the callback unconditionally, so never hand them an
+ // empty one (e.g. an explicit `nullptr`); treat that as "always fold".
+ ControlFoldIntoPackUnpackFn fn = controlFn;
+ if (!fn)
+ fn = defaultControlFoldIntoPackUnpackFn;
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
FoldProducerPackWithConsumerLinalgTransposeOp,
FoldConsumerPackWithProducerLinalgTransposeOp,
FoldConsumerUnPackWithProducerLinalgTransposeOp,
FoldProducerUnPackWithConsumerLinalgTransposeOp>(
- patterns.getContext());
+ patterns.getContext(), fn);
}
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
More information about the Mlir-commits
mailing list