[Mlir-commits] [mlir] 7f28d27 - [mlir][linalg] Allow controlling folding unit dim reshapes

Lei Zhang llvmlistbot at llvm.org
Wed Mar 24 15:19:39 PDT 2021


Author: Lei Zhang
Date: 2021-03-24T18:17:57-04:00
New Revision: 7f28d27cb614c47e6cf68f5deae729270d13cb08

URL: https://github.com/llvm/llvm-project/commit/7f28d27cb614c47e6cf68f5deae729270d13cb08
DIFF: https://github.com/llvm/llvm-project/commit/7f28d27cb614c47e6cf68f5deae729270d13cb08.diff

LOG: [mlir][linalg] Allow controlling folding unit dim reshapes

This commit exposes an option on the pattern
FoldWithProducerReshapeOpByExpansion that allows
folding unit-dim reshapes. This gives callers
finer-grained control.

Differential Revision: https://reviews.llvm.org/D99114

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/reshape_fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index ecec2a3c05d2..18820d4316b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,7 +65,8 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns);
+void populateFoldReshapeOpsByExpansionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -83,7 +84,8 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns);
 
 /// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns);
+void populateLinalgTensorOpsFusionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
 
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index aad11179be69..786b9ec85dcf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -37,6 +37,12 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
+  let options = [
+    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
+           "bool", /*default=*/"false",
+           "Allow fusing linalg.tensor_reshape ops that performs unit "
+           "dimension collapsing">
+  ];
   let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 4b0951ea4c1c..7e89a0887d0d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -897,9 +897,14 @@ struct FoldProducerReshapeOpByLinearization
 /// generic/indexed_generic op, when the reshape op is collapsing
 /// dimensions. The dimensionality of the loop in the consumer is expanded.
 template <typename GenericOpTy>
-struct FoldWithProducerReshapeOpByExpansion
+class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOpTy> {
-  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+public:
+  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
+                                       bool foldUnitDimReshapes,
+                                       PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOpTy>(context, benefit),
+        allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
 
   LogicalResult matchAndRewrite(GenericOpTy genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -916,8 +921,9 @@ struct FoldWithProducerReshapeOpByExpansion
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
           !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
-          isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                 reshapeOp.getReassociationMaps()))
+          (!allowFoldingUnitDimReshapes &&
+           isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                  reshapeOp.getReassociationMaps())))
         continue;
 
       Optional<SmallVector<Value, 1>> replacementValues =
@@ -930,6 +936,9 @@ struct FoldWithProducerReshapeOpByExpansion
     }
     return failure();
   }
+
+private:
+  bool allowFoldingUnitDimReshapes;
 };
 
 /// Pattern to fold tensor_reshape op with its producer. The corresponding index
@@ -1134,7 +1143,8 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
-    populateLinalgTensorOpsFusionPatterns(patterns);
+    populateLinalgTensorOpsFusionPatterns(patterns,
+                                          allowFoldingUnitDimReshapes);
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1171,20 +1181,22 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
 }
 
 void mlir::populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<FoldReshapeWithGenericOpByExpansion,
-               FoldWithProducerReshapeOpByExpansion<GenericOp>,
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+  patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
                FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      patterns.getContext());
+      patterns.getContext(), allowFoldingUnitDimReshapes);
 }
 
-void mlir::populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns) {
+void mlir::populateLinalgTensorOpsFusionPatterns(
+    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
   auto *context = patterns.getContext();
   patterns
       .add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
            FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
           context);
-  populateFoldReshapeOpsByExpansionPatterns(patterns);
+  populateFoldReshapeOpsByExpansionPatterns(patterns,
+                                            allowFoldingUnitDimReshapes);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);

diff  --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index fbaf47c9ac4d..d5dc176f1fdf 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file -verify-each=0 | FileCheck %s --check-prefix=FOLDUNITDIM
 
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
@@ -300,7 +301,7 @@ func @reshape_as_consumer_permutation
          %5 = addi %3, %4 : i32
          %6 = index_cast %arg2 : index to i32
          %7 = addi %5, %6 : i32
-	 linalg.yield %7 : i32
+         linalg.yield %7 : i32
        } -> tensor<6x4x210xi32>
   %d = linalg.tensor_reshape %c
          [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -531,3 +532,11 @@ func @unit_dim_reshape_expansion_full
 //  CHECK-DAG:   linalg.tensor_reshape
 //  CHECK-DAG:   linalg.init_tensor
 //      CHECK:   linalg.generic
+// CHECK-SAME:     ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+
+//         FOLDUNITDIM: func @unit_dim_reshape_expansion_full
+//         FOLDUNITDIM:   linalg.init_tensor
+// FOLDUNITDIM-COUNT-2:   linalg.tensor_reshape
+//         FOLDUNITDIM:   linalg.generic
+//    FOLDUNITDIM-SAME:     ins(%{{.+}}, %{{.+}} : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+


        


More information about the Mlir-commits mailing list