[Mlir-commits] [mlir] 52525cb - [mlir][linalg][NFC] Make reshape folding control more fine grain
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 6 10:12:07 PDT 2021
Author: thomasraoux
Date: 2021-05-06T10:11:39-07:00
New Revision: 52525cb20ff300d634453fdb3986adf4801f205c
URL: https://github.com/llvm/llvm-project/commit/52525cb20ff300d634453fdb3986adf4801f205c
DIFF: https://github.com/llvm/llvm-project/commit/52525cb20ff300d634453fdb3986adf4801f205c.diff
LOG: [mlir][linalg][NFC] Make reshape folding control more fine grain
This exposes a lambda control instead of just a boolean to control
unit-dimension folding.
This gives the user more control to pick a good heuristic.
Folding reshapes helps fusion opportunities but may generate sub-optimal
generic ops.
Differential Revision: https://reviews.llvm.org/D101917
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bc8775a169d9d..de0f5888550f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -28,6 +28,10 @@ struct LinalgElementwiseFusionOptions;
struct LinalgFusionOptions;
struct LinalgTilingOptions;
+/// Default function to control reshape folding. Skips folding unit dimension
+/// reshapes.
+bool skipUnitDimReshape(const OpResult &producer, const OpOperand &consumer);
+
//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
@@ -42,11 +46,15 @@ void populateConvVectorizationPatterns(
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);
+using ControlElementwiseOpsFusionFn =
+ std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
+
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
- RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+ RewritePatternSet &patterns,
+ ControlElementwiseOpsFusionFn controlFoldingReshapes = skipUnitDimReshape);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -71,17 +79,15 @@ void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
/// tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
-using ControlElementwiseOpsFusionFn =
- std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
-
/// Options that control fusion of elementwise operations.
struct LinalgElementwiseFusionOptions {
- /// Enable fusion of reshapes that are introducing unit-dimensions into the
- /// shape with elementwise operations. By default this is disabled.
- bool allowFoldingUnitDimReshapes = false;
+ /// Enable fusion of reshapes into the shape with elementwise operations. By
+ /// default it is disabled for unit dimensions reshape.
+ ControlElementwiseOpsFusionFn controlFoldingReshapesFn = skipUnitDimReshape;
- LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
- allowFoldingUnitDimReshapes = val;
+ LinalgElementwiseFusionOptions &
+ setControlFoldingReshapes(ControlElementwiseOpsFusionFn fun) {
+ controlFoldingReshapesFn = std::move(fun);
return *this;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 7fd6d245ccb59..fab957e7fca43 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1164,11 +1164,11 @@ template <typename GenericOpTy>
class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOpTy> {
public:
- FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
- bool foldUnitDimReshapes,
- PatternBenefit benefit = 1)
+ FoldWithProducerReshapeOpByExpansion(
+ MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
: OpRewritePattern<GenericOpTy>(context, benefit),
- allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
+ controlFoldingReshapes(foldReshapes) {}
LogicalResult matchAndRewrite(GenericOpTy genericOp,
PatternRewriter &rewriter) const override {
@@ -1178,16 +1178,15 @@ class FoldWithProducerReshapeOpByExpansion
operand.value().getDefiningOp<TensorReshapeOp>();
if (!reshapeOp)
continue;
-
// Fold only if
// - The tensor reshape op is folding.
// - All constraints of fusing with reshape by expansion are met.
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
!isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
- (!allowFoldingUnitDimReshapes &&
- isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
- reshapeOp.getReassociationMaps())))
+ (!controlFoldingReshapes(
+ reshapeOp->getResult(0),
+ linalgOp.getInputOpOperands()[operand.index()])))
continue;
Optional<SmallVector<Value, 1>> replacementValues =
@@ -1202,7 +1201,7 @@ class FoldWithProducerReshapeOpByExpansion
}
private:
- bool allowFoldingUnitDimReshapes;
+ ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
/// Pattern to fold tensor_reshape op with its producer. The corresponding index
@@ -1394,6 +1393,13 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
controlFn, rewriter);
}
+bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
+ const OpOperand &consumer) {
+ auto reshapeOp = producer.getDefiningOp<linalg::TensorReshapeOp>();
+ return !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+ reshapeOp.getReassociationMaps());
+}
+
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
template <typename LinalgOpTy>
@@ -1431,10 +1437,14 @@ struct FusionOfTensorOpsPass
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
+ ControlElementwiseOpsFusionFn allowFoldingFn =
+ [](const OpResult &producer, const OpOperand &consumer) {
+ return true;
+ };
populateElementwiseOpsFusionPatterns(
patterns,
- LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
- allowFoldingUnitDimReshapes));
+ LinalgElementwiseFusionOptions().setControlFoldingReshapes(
+ allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
@@ -1471,11 +1481,12 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
}
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
- RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+ RewritePatternSet &patterns,
+ ControlElementwiseOpsFusionFn controlFoldingReshapes) {
patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
- patterns.getContext(), allowFoldingUnitDimReshapes);
+ patterns.getContext(), controlFoldingReshapes);
}
void mlir::linalg::populateElementwiseOpsFusionPatterns(
@@ -1485,8 +1496,8 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns(
.add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
context, options.controlElementwiseOpsFusionFn);
- populateFoldReshapeOpsByExpansionPatterns(
- patterns, options.allowFoldingUnitDimReshapes);
+ populateFoldReshapeOpsByExpansionPatterns(patterns,
+ options.controlFoldingReshapesFn);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
More information about the Mlir-commits
mailing list