[Mlir-commits] [mlir] 3de0626 - [mlir][Linalg] Update signatures of the callback functions.

Hanhan Wang llvmlistbot at llvm.org
Tue Jul 20 17:03:49 PDT 2021


Author: Hanhan Wang
Date: 2021-07-20T17:03:34-07:00
New Revision: 3de06260f746e8a768330fdcc51ea031c39f8860

URL: https://github.com/llvm/llvm-project/commit/3de06260f746e8a768330fdcc51ea031c39f8860
DIFF: https://github.com/llvm/llvm-project/commit/3de06260f746e8a768330fdcc51ea031c39f8860.diff

LOG: [mlir][Linalg] Update signatures of the callback functions.

This allows callers to use non-const methods, e.g., `getOperandNumber`. The
callback is still expected not to modify the OpOperand.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D106322
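
To illustrate why the callback needs a non-const operand, here is a minimal sketch that builds without MLIR. It assumes stand-in types: `FakeOpResult`, `FakeOpOperand`, `ControlFn`, `fuseOnlyIntoFirstOperand` and `decideFusion` are hypothetical names invented for this example, not part of the MLIR API. The only detail taken from the commit is that `getOperandNumber` is a non-const method.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for mlir::OpResult; not the real class.
struct FakeOpResult {
  unsigned resultNumber;
};

// Hypothetical stand-in for mlir::OpOperand; not the real class.
struct FakeOpOperand {
  unsigned index;
  // Deliberately non-const, mirroring the commit's statement that
  // getOperandNumber is a non-const method.
  unsigned getOperandNumber() { return index; }
};

// Same shape as the updated callback type: const producer, non-const consumer.
using ControlFn =
    std::function<bool(const FakeOpResult &producer, FakeOpOperand &consumer)>;

// Example control function: only allow fusion into operand #0.
// With `const FakeOpOperand &consumer` this would not compile, because
// getOperandNumber() cannot be called through a const reference.
bool fuseOnlyIntoFirstOperand(const FakeOpResult & /*producer*/,
                              FakeOpOperand &consumer) {
  // By convention the callback only reads from the operand.
  return consumer.getOperandNumber() == 0;
}

// Hypothetical driver: asks the callback about every operand of a consumer.
std::vector<bool> decideFusion(const ControlFn &fn,
                               std::vector<FakeOpOperand> &operands) {
  FakeOpResult producer{0};
  std::vector<bool> decisions;
  for (FakeOpOperand &operand : operands) {
    assert(fn && "control function must be set");
    decisions.push_back(fn(producer, operand));
  }
  return decisions;
}
```

The non-const reference is a convenience for calling such methods, not permission to mutate; the read-only expectation is carried only by the documentation comment.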

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a9724943f9bef..a2ce99bc97b0e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -33,7 +33,7 @@ struct LinalgTilingOptions;
 
 /// Default function to control reshape folding. Skips folding unit dimension
 /// reshapes.
-bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer);
+bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as function calls.
@@ -49,8 +49,11 @@ void populateConvVectorizationPatterns(
 /// parallel loops.
 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
 
+/// Function type which is used to control when to stop fusion. It is expected
+/// that OpOperand is not modified in the callback. The OpOperand is not marked
+/// as const to allow callers to use non-const methods.
 using ControlElementwiseOpsFusionFn =
-    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
+    std::function<bool(const OpResult &producer, OpOperand &consumer)>;
 
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
@@ -104,7 +107,7 @@ struct LinalgElementwiseFusionOptions {
   /// can be used to abort the fusion based on non-structural constraints. This
   /// is the hook for cost models to control the amount of fusion done.
   ControlElementwiseOpsFusionFn controlElementwiseOpsFusionFn =
-      [](const OpResult & /*producer */, const OpOperand & /*consumer */) {
+      [](const OpResult & /*producer */, OpOperand & /*consumer */) {
         return true;
       };
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 11eda50415a63..fdca523b38544 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1241,7 +1241,7 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
 }
 
 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
-                                      const OpOperand &consumer) {
+                                      OpOperand &consumer) {
   auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
   if (expandShapeOp)
     return !isUnitDimExpansionOnly(expandShapeOp);


        

