[Mlir-commits] [mlir] d2b070d - [mlir][linalg][NFC] Split populateFoldUnitExtentDimsViaReshapesPatterns

Matthias Springer llvmlistbot at llvm.org
Thu Dec 15 03:05:05 PST 2022


Author: Matthias Springer
Date: 2022-12-15T12:04:40+01:00
New Revision: d2b070d3c95c50579da83be97f79bd2c52f188c8

URL: https://github.com/llvm/llvm-project/commit/d2b070d3c95c50579da83be97f79bd2c52f188c8
DIFF: https://github.com/llvm/llvm-project/commit/d2b070d3c95c50579da83be97f79bd2c52f188c8.diff

LOG: [mlir][linalg][NFC] Split populateFoldUnitExtentDimsViaReshapesPatterns

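MoveInitOperandsToInput (renamed from AddInitOperandsToInput) is moved out of populateFoldUnitExtentDimsViaReshapesPatterns into a separate function, populateMoveInitOperandsToInputPattern, because it can interfere with certain transformations.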

Differential Revision: https://reviews.llvm.org/D140091

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 13a7e6f3f2aec..06d0448c2ad70 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -136,6 +136,9 @@ void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
 /// tensors via rank-reducing slices.
 void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
 
+/// A pattern that converts init operands to input operands.
+void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns);
+
 /// Patterns that are used to inline constant operands into linalg generic ops.
 void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index ab288c0bb5f1a..e4b0593194af1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -233,7 +233,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
   }
 };
 
-/// Pattern to add init operands to ins when all the loops are parallel and
+/// Pattern to move init operands to ins when all the loops are parallel and
 /// blockArgument corresponding to init is used in the region. This is a fix-up
 /// when unit reduction dimensions are all folded away. In this context, it
 /// becomes a elementwise generic op. E.g., it converts
@@ -269,7 +269,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
 ///    %4 = arith.addf %in, %in_0 : f32
 ///    linalg.yield %4 : f32
 ///  } -> tensor<1x1xf32>
-struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
+struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -667,10 +667,10 @@ void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
   patterns.add<ReplaceUnitExtents>(context,
                                    RankReductionStrategy::ReassociativeReshape);
   // TODO: Patterns unrelated to unit dim folding should be factored out.
-  patterns
-      .add<FoldUnitDimLoops, AddInitOperandsToInput, RankReducedExtractSliceOp,
-           RankReducedInsertSliceOp<tensor::InsertSliceOp>,
-           RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(context);
+  patterns.add<FoldUnitDimLoops, RankReducedExtractSliceOp,
+               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
+      context);
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
@@ -688,6 +688,11 @@ void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
   patterns.add<FoldUnitDimLoops>(context);
 }
 
+void mlir::linalg::populateMoveInitOperandsToInputPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
+}
+
 namespace {
 /// Pass that removes unit-extent dims within generic ops.
 struct LinalgFoldUnitExtentDimsPass
@@ -697,11 +702,13 @@ struct LinalgFoldUnitExtentDimsPass
     MLIRContext *context = op->getContext();
     RewritePatternSet patterns(context);
     if (foldOneTripLoopsOnly) {
-      patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
+      patterns.add<FoldUnitDimLoops, MoveInitOperandsToInput>(context);
     } else if (useRankReducingSlices) {
       populateFoldUnitExtentDimsViaSlicesPatterns(patterns);
+      populateMoveInitOperandsToInputPattern(patterns);
     } else {
       populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
+      populateMoveInitOperandsToInputPattern(patterns);
     }
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }


        


More information about the Mlir-commits mailing list