[Mlir-commits] [mlir] d2b070d - [mlir][linalg][NFC] Split populateFoldUnitExtentDimsViaReshapesPatterns
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 15 03:05:05 PST 2022
Author: Matthias Springer
Date: 2022-12-15T12:04:40+01:00
New Revision: d2b070d3c95c50579da83be97f79bd2c52f188c8
URL: https://github.com/llvm/llvm-project/commit/d2b070d3c95c50579da83be97f79bd2c52f188c8
DIFF: https://github.com/llvm/llvm-project/commit/d2b070d3c95c50579da83be97f79bd2c52f188c8.diff
LOG: [mlir][linalg][NFC] Split populateFoldUnitExtentDimsViaReshapesPatterns
MoveInitOperandsToInput is moved into a separate populate function (`populateMoveInitOperandsToInputPattern`) because it can interfere with certain transformations.
Differential Revision: https://reviews.llvm.org/D140091
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 13a7e6f3f2aec..06d0448c2ad70 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -136,6 +136,9 @@ void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
/// tensors via rank-reducing slices.
void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
+/// A pattern that converts init operands to input operands.
+void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);
+
/// Patterns that are used to inline constant operands into linalg generic ops.
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index ab288c0bb5f1a..e4b0593194af1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -233,7 +233,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
}
};
-/// Pattern to add init operands to ins when all the loops are parallel and
+/// Pattern to move init operands to ins when all the loops are parallel and
/// blockArgument corresponding to init is used in the region. This is a fix-up
/// when unit reduction dimensions are all folded away. In this context, it
/// becomes a elementwise generic op. E.g., it converts
@@ -269,7 +269,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
/// %4 = arith.addf %in, %in_0 : f32
/// linalg.yield %4 : f32
/// } -> tensor<1x1xf32>
-struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
+struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
@@ -667,10 +667,10 @@ void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
patterns.add<ReplaceUnitExtents>(context,
RankReductionStrategy::ReassociativeReshape);
// TODO: Patterns unrelated to unit dim folding should be factored out.
- patterns
- .add<FoldUnitDimLoops, AddInitOperandsToInput, RankReducedExtractSliceOp,
- RankReducedInsertSliceOp<tensor::InsertSliceOp>,
- RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(context);
+ patterns.add<FoldUnitDimLoops, RankReducedExtractSliceOp,
+ RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+ RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
+ context);
linalg::FillOp::getCanonicalizationPatterns(patterns, context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
@@ -688,6 +688,11 @@ void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
patterns.add<FoldUnitDimLoops>(context);
}
+void mlir::linalg::populateMoveInitOperandsToInputPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<MoveInitOperandsToInput>(patterns.getContext());
+}
+
namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
@@ -697,11 +702,13 @@ struct LinalgFoldUnitExtentDimsPass
MLIRContext *context = op->getContext();
RewritePatternSet patterns(context);
if (foldOneTripLoopsOnly) {
- patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
+ patterns.add<FoldUnitDimLoops, MoveInitOperandsToInput>(context);
} else if (useRankReducingSlices) {
populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+ populateMoveInitOperandsToInputPattern(patterns);
} else {
populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+ populateMoveInitOperandsToInputPattern(patterns);
}
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
}
More information about the Mlir-commits mailing list