[Mlir-commits] [mlir] b546f43 - [mlir]Linalg] Allow controlling fusion of linalg.generic -> linalg.tensor_expand_shape.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 23 16:28:25 PDT 2021
Author: MaheshRavishankar
Date: 2021-08-23T16:28:10-07:00
New Revision: b546f4347b87c20502c747b366a3303949a34d69
URL: https://github.com/llvm/llvm-project/commit/b546f4347b87c20502c747b366a3303949a34d69
DIFF: https://github.com/llvm/llvm-project/commit/b546f4347b87c20502c747b366a3303949a34d69.diff
LOG: [mlir]Linalg] Allow controlling fusion of linalg.generic -> linalg.tensor_expand_shape.
Differential Revision: https://reviews.llvm.org/D108565
Added:
mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 43a4105e4c3f8..2b29f58428dce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1133,7 +1133,15 @@ struct FoldConsumerReshapeOpByLinearization
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
: public OpRewritePattern<TensorExpandShapeOp> {
- using OpRewritePattern<TensorExpandShapeOp>::OpRewritePattern;
+
+ /// `foldReshapes` is invoked with the producer's result and the reshape's
+ /// operand; the fold is skipped whenever it returns false.
+ FoldReshapeWithGenericOpByExpansion(
+ MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<TensorExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
PatternRewriter &rewriter) const override {
// Fold only if all constraints of fusing with reshape by expansion are met.
@@ -1141,7 +1149,8 @@ struct FoldReshapeWithGenericOpByExpansion
if (!producer || producer.getNumOutputs() != 1 ||
!isFusableWithReshapeByDimExpansion(producer,
producer.getOutputOperand(0)) ||
- isUnitDimExpansionOnly(reshapeOp))
+ !controlFoldingReshapes(producer->getResult(0),
+ reshapeOp->getOpOperand(0)))
return failure();
Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
producer, reshapeOp, producer.getOutputOperand(0), rewriter);
@@ -1150,6 +1159,10 @@ struct FoldReshapeWithGenericOpByExpansion
rewriter.replaceOp(reshapeOp, replacementValues.getValue());
return success();
}
+
+private:
+ /// Decides whether a given producer/reshape pair may be fused.
+ ControlElementwiseOpsFusionFn controlFoldingReshapes;
};
/// Pattern to fold a generic op with a splat constant.
@@ -1242,12 +1255,17 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
OpOperand &consumer) {
- auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
- if (expandShapeOp)
- return !isUnitDimExpansionOnly(expandShapeOp);
- auto collapseShapeOp =
- producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
- return !isUnitDimExpansionOnly(collapseShapeOp);
+ // Reject a collapse_shape producer or expand_shape consumer for which
+ // isUnitDimExpansionOnly() holds; allow every other pair.
+ if (auto producerCollapseOp =
+ dyn_cast<linalg::TensorCollapseShapeOp>(producer.getOwner())) {
+ return !isUnitDimExpansionOnly(producerCollapseOp);
+ }
+ if (auto consumerExpandOp =
+ dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
+ return !isUnitDimExpansionOnly(consumerExpandOp);
+ }
+ return true;
}
namespace {
@@ -1389,7 +1407,8 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
RewritePatternSet &patterns,
ControlElementwiseOpsFusionFn controlFoldingReshapes) {
- patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+ patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
new file mode 100644
index 0000000000000..f72723f36804c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
+
+// The collapse_shape source is a function argument, not the result of a linalg
+// op, so the pass's control function refuses the fusion and the reshape stays.
+func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+ %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+ %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+ outs(%init : tensor<?x?xf32>) {
+ ^bb0(%arg2 : f32, %arg3:f32, %arg4 : f32):
+ %2 = addf %arg2, %arg3 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: builtin.func @control_producer_reshape_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
+// CHECK-SAME: {{\[}}[0, 1], [2]{{\]}} : tensor<?x?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]]]
+// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<?x?xf32>, tensor<?xf32>)
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// The expand_shape only inserts a leading unit dimension, but its sole use is
+// the `outs` operand of a linalg op, so the pass's control function allows it
+// to be folded into the producing linalg.generic.
+func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %cst = constant 0.0 : f32
+ %d0 = tensor.dim %arg0, %c1 : tensor<1x?x?xf32>
+ %d1 = tensor.dim %arg1, %c2 : tensor<1x?x?xf32>
+ %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+ %fill = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%init : tensor<?x?xf32>) {
+ ^bb0(%arg2: f32):
+ linalg.yield %cst : f32
+ } -> tensor<?x?xf32>
+ %0 = linalg.tensor_expand_shape %fill [[0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+ outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+ return %1 : tensor<1x?x?xf32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: builtin.func @control_consumer_reshape_fusion
+// CHECK: %[[FILL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%{{.+}} : tensor<1x?x?xf32>)
+// CHECK: linalg.batch_matmul
+// CHECK-SAME: outs(%[[FILL]] : tensor<1x?x?xf32>)
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 487d42f8e9848..e9f72e500686a 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -73,6 +73,59 @@ struct TestLinalgElementwiseFusion
}
};
+/// Test pass that populates the fold-reshape-by-expansion patterns with a
+/// custom ControlElementwiseOpsFusionFn and applies them greedily.
+struct TestLinalgControlFuseByExpansion
+ : public PassWrapper<TestLinalgControlFuseByExpansion, FunctionPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-linalg-control-fusion-by-expansion";
+ }
+ StringRef getDescription() const final {
+ return "Test controlling of fusion of elementwise ops with reshape by "
+ "expansion";
+ }
+
+ void runOnFunction() override {
+ MLIRContext *context = &this->getContext();
+ FuncOp funcOp = this->getFunction();
+ RewritePatternSet fusionPatterns(context);
+
+ linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
+ [](const OpResult &producer, OpOperand &consumer) {
+ // Refuse to fuse a collapse_shape producer whose source is not
+ // produced by a linalg op.
+ if (auto collapseOp =
+ producer.getDefiningOp<linalg::TensorCollapseShapeOp>()) {
+ if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
+ return false;
+ }
+ }
+ // Always fuse an expand_shape consumer whose single use is an output
+ // operand of a linalg op, without consulting skipUnitDimReshape.
+ if (auto expandOp =
+ dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
+ if (expandOp->hasOneUse()) {
+ OpOperand &use = *expandOp->getUses().begin();
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
+ if (linalgOp && linalgOp.isOutputTensor(&use))
+ return true;
+ }
+ }
+ // Otherwise defer to skipUnitDimReshape.
+ return linalg::skipUnitDimReshape(producer, consumer);
+ };
+
+ linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
+ controlReshapeFusionFn);
+ (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+ std::move(fusionPatterns));
+ }
+};
+
struct TestPushExpandingReshape
: public PassWrapper<TestPushExpandingReshape, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@@ -99,6 +152,10 @@ void registerTestLinalgElementwiseFusion() {
PassRegistration<TestLinalgElementwiseFusion>();
}
+void registerTestLinalgControlFuseByExpansion() {
+ PassRegistration<TestLinalgControlFuseByExpansion>();
+}
+
void registerTestPushExpandingReshape() {
PassRegistration<TestPushExpandingReshape>();
}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3059b1fafb96a..2fba77f76b433 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,7 @@ void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
+void registerTestLinalgControlFuseByExpansion();
void registerTestLinalgDistribution();
void registerTestLinalgElementwiseFusion();
void registerTestPushExpandingReshape();
@@ -165,6 +166,7 @@ void registerTestPasses() {
mlir::test::registerTestIRVisitorsPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestLinalgCodegenStrategy();
+ mlir::test::registerTestLinalgControlFuseByExpansion();
mlir::test::registerTestLinalgDistribution();
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestPushExpandingReshape();
More information about the Mlir-commits
mailing list