[Mlir-commits] [mlir] 7f28d27 - [mlir][linalg] Allow controlling folding unit dim reshapes
Lei Zhang
llvmlistbot at llvm.org
Wed Mar 24 15:19:39 PDT 2021
Author: Lei Zhang
Date: 2021-03-24T18:17:57-04:00
New Revision: 7f28d27cb614c47e6cf68f5deae729270d13cb08
URL: https://github.com/llvm/llvm-project/commit/7f28d27cb614c47e6cf68f5deae729270d13cb08
DIFF: https://github.com/llvm/llvm-project/commit/7f28d27cb614c47e6cf68f5deae729270d13cb08.diff
LOG: [mlir][linalg] Allow controlling folding unit dim reshapes
This commit exposes an option to the pattern
FoldWithProducerReshapeOpByExpansion to allow
folding unit dim reshapes. This gives callers
more fine-grained control.
Differential Revision: https://reviews.llvm.org/D99114
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index ecec2a3c05d2..18820d4316b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,7 +65,8 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
-void populateFoldReshapeOpsByExpansionPatterns(RewritePatternSet &patterns);
+void populateFoldReshapeOpsByExpansionPatterns(
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -83,7 +84,8 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns);
+void populateLinalgTensorOpsFusionPatterns(
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index aad11179be69..786b9ec85dcf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -37,6 +37,12 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
let summary = "Fuse operations on RankedTensorType in linalg dialect";
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
+ let options = [
+ Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
+ "bool", /*default=*/"false",
+ "Allow fusing linalg.tensor_reshape ops that performs unit "
+ "dimension collapsing">
+ ];
let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 4b0951ea4c1c..7e89a0887d0d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -897,9 +897,14 @@ struct FoldProducerReshapeOpByLinearization
/// generic/indexed_generic op, when the reshape op is collapsing
/// dimensions. The dimensionality of the loop in the consumer is expanded.
template <typename GenericOpTy>
-struct FoldWithProducerReshapeOpByExpansion
+class FoldWithProducerReshapeOpByExpansion
: public OpRewritePattern<GenericOpTy> {
- using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+public:
+ FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
+ bool foldUnitDimReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<GenericOpTy>(context, benefit),
+ allowFoldingUnitDimReshapes(foldUnitDimReshapes) {}
LogicalResult matchAndRewrite(GenericOpTy genericOp,
PatternRewriter &rewriter) const override {
@@ -916,8 +921,9 @@ struct FoldWithProducerReshapeOpByExpansion
if (reshapeOp.getSrcType().getRank() <
reshapeOp.getResultType().getRank() ||
!isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
- isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
- reshapeOp.getReassociationMaps()))
+ (!allowFoldingUnitDimReshapes &&
+ isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+ reshapeOp.getReassociationMaps())))
continue;
Optional<SmallVector<Value, 1>> replacementValues =
@@ -930,6 +936,9 @@ struct FoldWithProducerReshapeOpByExpansion
}
return failure();
}
+
+private:
+ bool allowFoldingUnitDimReshapes;
};
/// Pattern to fold tensor_reshape op with its producer. The corresponding index
@@ -1134,7 +1143,8 @@ struct FusionOfTensorOpsPass
void runOnOperation() override {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
- populateLinalgTensorOpsFusionPatterns(patterns);
+ populateLinalgTensorOpsFusionPatterns(patterns,
+ allowFoldingUnitDimReshapes);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
@@ -1171,20 +1181,22 @@ void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
}
void mlir::populateFoldReshapeOpsByExpansionPatterns(
- RewritePatternSet &patterns) {
- patterns.add<FoldReshapeWithGenericOpByExpansion,
- FoldWithProducerReshapeOpByExpansion<GenericOp>,
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+ patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+ patterns.add<FoldWithProducerReshapeOpByExpansion<GenericOp>,
FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
- patterns.getContext());
+ patterns.getContext(), allowFoldingUnitDimReshapes);
}
-void mlir::populateLinalgTensorOpsFusionPatterns(RewritePatternSet &patterns) {
+void mlir::populateLinalgTensorOpsFusionPatterns(
+ RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
auto *context = patterns.getContext();
patterns
.add<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
context);
- populateFoldReshapeOpsByExpansionPatterns(patterns);
+ populateFoldReshapeOpsByExpansionPatterns(patterns,
+ allowFoldingUnitDimReshapes);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index fbaf47c9ac4d..d5dc176f1fdf 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file -verify-each=0 | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file -verify-each=0 | FileCheck %s --check-prefix=FOLDUNITDIM
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
@@ -300,7 +301,7 @@ func @reshape_as_consumer_permutation
%5 = addi %3, %4 : i32
%6 = index_cast %arg2 : index to i32
%7 = addi %5, %6 : i32
- linalg.yield %7 : i32
+ linalg.yield %7 : i32
} -> tensor<6x4x210xi32>
%d = linalg.tensor_reshape %c
[affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -531,3 +532,11 @@ func @unit_dim_reshape_expansion_full
// CHECK-DAG: linalg.tensor_reshape
// CHECK-DAG: linalg.init_tensor
// CHECK: linalg.generic
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+
+// FOLDUNITDIM: func @unit_dim_reshape_expansion_full
+// FOLDUNITDIM: linalg.init_tensor
+// FOLDUNITDIM-COUNT-2: linalg.tensor_reshape
+// FOLDUNITDIM: linalg.generic
+// FOLDUNITDIM-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
+
More information about the Mlir-commits
mailing list