[Mlir-commits] [mlir] 52525cb - [mlir][linalg][NFC] Make reshape folding control more fine-grained

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 6 10:12:07 PDT 2021


Author: thomasraoux
Date: 2021-05-06T10:11:39-07:00
New Revision: 52525cb20ff300d634453fdb3986adf4801f205c

URL: https://github.com/llvm/llvm-project/commit/52525cb20ff300d634453fdb3986adf4801f205c
DIFF: https://github.com/llvm/llvm-project/commit/52525cb20ff300d634453fdb3986adf4801f205c.diff

LOG: [mlir][linalg][NFC] Make reshape folding control more fine-grained

This exposes a lambda control function instead of just a boolean to control
unit-dimension folding.
This gives users more control to pick a good heuristic: folding reshapes
creates more fusion opportunities but may generate sub-optimal
generic ops.

Differential Revision: https://reviews.llvm.org/D101917

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bc8775a169d9d..de0f5888550f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -28,6 +28,10 @@ struct LinalgElementwiseFusionOptions;
 struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
+/// Default function to control reshape folding. Skips folding unit dimension
+/// reshapes.
+bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer);
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as function calls.
 //===----------------------------------------------------------------------===//
@@ -42,11 +46,15 @@ void populateConvVectorizationPatterns(
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
 
+using ControlElementwiseOpsFusionFn =
+    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
 void populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+    RewritePatternSet &patterns,
+    ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
 
 /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
 /// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -71,17 +79,15 @@ void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 
-using ControlElementwiseOpsFusionFn =
-    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
-
 /// Options that control fusion of elementwise operations.
 struct LinalgElementwiseFusionOptions {
-  /// Enable fusion of reshapes that are introducing unit-dimensions into the
-  /// shape with elementwise operations. By default this is disabled.
-  bool allowFoldingUnitDimReshapes = false;
+  /// Enable fusion of reshapes into the shape with elementwise operations. By
+  /// default it is disabled for unit dimensions reshape.
+  ControlElementwiseOpsFusionFn controlFoldingReshapesFn = skipUnitDimReshape;
 
-  LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
-    allowFoldingUnitDimReshapes = val;
+  LinalgElementwiseFusionOptions &
+  setControlFoldingReshapes(ControlElementwiseOpsFusionFn fun) {
+    controlFoldingReshapesFn = std::move(fun);
     return *this;
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 7fd6d245ccb59..fab957e7fca43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1164,11 +1164,11 @@ template <typename GenericOpTy>
 class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOpTy> {
 public:
-  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
-                                       bool foldUnitDimReshapes,
-                                       PatternBenefit benefit = 1)
+  FoldWithProducerReshapeOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOpTy>(context, benefit),
-        allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
+        controlFoldingReshapes(foldReshapes) {}
 
   LogicalResult matchAndRewrite(GenericOpTy genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -1178,16 +1178,15 @@ class FoldWithProducerReshapeOpByExpansion
           operand.value().getDefiningOp<TensorReshapeOp>();
       if (!reshapeOp)
         continue;
-
       // Fold only if
       // - The tensor reshape op is folding.
       // - All constraints of fusing with reshape by expansion are met.
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
           !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
-          (!allowFoldingUnitDimReshapes &&
-           isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                  reshapeOp.getReassociationMaps())))
+          (!controlFoldingReshapes(
+              reshapeOp->getResult(0),
+              linalgOp.getInputOpOperands()[operand.index()])))
         continue;
 
       Optional<SmallVector<Value, 1>> replacementValues =
@@ -1202,7 +1201,7 @@ class FoldWithProducerReshapeOpByExpansion
   }
 
 private:
-  bool allowFoldingUnitDimReshapes;
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
 /// Pattern to fold tensor_reshape op with its producer. The corresponding index
@@ -1394,6 +1393,13 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
                                 controlFn, rewriter);
 }
 
+bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
+                                      const OpOperand &consumer) {
+  auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
+  return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps());
+}
+
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
@@ -1431,10 +1437,14 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
+    ControlElementwiseOpsFusionFn allowFoldingFn =
+        [](const OpResult &producer, const OpOperand &consumer) {
+          return true;
+        };
     populateElementwiseOpsFusionPatterns(
         patterns,
-        LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
-            allowFoldingUnitDimReshapes));
+        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
+            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1471,11 +1481,12 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
 }
 
 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+    RewritePatternSet &patterns,
+    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
   patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
                FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
-      patterns.getContext(), allowFoldingUnitDimReshapes);
+      patterns.getContext(), controlFoldingReshapes);
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
@@ -1485,8 +1496,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
       .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
            FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
           context, options.controlElementwiseOpsFusionFn);
-  populateFoldReshapeOpsByExpansionPatterns(
-      patterns, options.allowFoldingUnitDimReshapes);
+  populateFoldReshapeOpsByExpansionPatterns(patterns,
+                                            options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);


        


More information about the Mlir-commits mailing list