[Mlir-commits] [mlir] b546f43 - [mlir][Linalg] Allow controlling fusion of linalg.generic -> linalg.tensor_expand_shape.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 23 16:28:25 PDT 2021


Author: MaheshRavishankar
Date: 2021-08-23T16:28:10-07:00
New Revision: b546f4347b87c20502c747b366a3303949a34d69

URL: https://github.com/llvm/llvm-project/commit/b546f4347b87c20502c747b366a3303949a34d69
DIFF: https://github.com/llvm/llvm-project/commit/b546f4347b87c20502c747b366a3303949a34d69.diff

LOG: [mlir][Linalg] Allow controlling fusion of linalg.generic -> linalg.tensor_expand_shape.

Differential Revision: https://reviews.llvm.org/D108565

Added: 
    mlir/test/Dialect/Linalg/reshape_control_fusion.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 43a4105e4c3f8..2b29f58428dce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1133,7 +1133,13 @@ struct FoldConsumerReshapeOpByLinearization
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
     : public OpRewritePattern<TensorExpandShapeOp> {
-  using OpRewritePattern<TensorExpandShapeOp>::OpRewritePattern;
+  /// Folds only when `foldReshapes` accepts the producer/reshape pair.
+  FoldReshapeWithGenericOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<TensorExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
   LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
     // Fold only if all constraints of fusing with reshape by expansion are met.
@@ -1141,7 +1147,8 @@ struct FoldReshapeWithGenericOpByExpansion
     if (!producer || producer.getNumOutputs() != 1 ||
         !isFusableWithReshapeByDimExpansion(producer,
                                             producer.getOutputOperand(0)) ||
-        isUnitDimExpansionOnly(reshapeOp))
+        !controlFoldingReshapes(producer->getResult(0),
+                                reshapeOp->getOpOperand(0)))
       return failure();
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
@@ -1150,6 +1157,9 @@ struct FoldReshapeWithGenericOpByExpansion
     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
     return success();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
 /// Pattern to fold a generic op with a splat constant.
@@ -1242,12 +1252,15 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
 
 bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                       OpOperand &consumer) {
-  auto expandShapeOp = producer.getDefiningOp<linalg::TensorExpandShapeOp>();
-  if (expandShapeOp)
-    return !isUnitDimExpansionOnly(expandShapeOp);
-  auto collapseShapeOp =
-      producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
-  return !isUnitDimExpansionOnly(collapseShapeOp);
+  // Refuse reshapes that isUnitDimExpansionOnly() accepts, whether the reshape
+  // is the producer (collapse) or the consumer (expand); allow everything else.
+  if (auto collapseOp =
+          dyn_cast<linalg::TensorCollapseShapeOp>(producer.getOwner()))
+    return !isUnitDimExpansionOnly(collapseOp);
+  if (auto expandOp =
+          dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner()))
+    return !isUnitDimExpansionOnly(expandOp);
+  return true;
 }
 
 namespace {
@@ -1389,7 +1402,8 @@ void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
     RewritePatternSet &patterns,
     ControlElementwiseOpsFusionFn controlFoldingReshapes) {
-  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext());
+  patterns.add<FoldReshapeWithGenericOpByExpansion>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }

diff  --git a/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
new file mode 100644
index 0000000000000..f72723f36804c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_control_fusion.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
+
+// The pass's control function refuses to fold a tensor_collapse_shape whose
+// source is not produced by a LinalgOp (here it is a function argument), so
+// the reshape must stay in front of the linalg.generic.
+func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+      ^bb0(%arg2 : f32, %arg3:f32, %arg4 : f32):
+        %2 = addf %arg2, %arg3 : f32
+        linalg.yield %2 : f32
+      } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)>
+//      CHECK: builtin.func @control_producer_reshape_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//      CHECK:   %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]]
+// CHECK-SAME:       {{\[}}[0, 1], [2]{{\]}} : tensor<?x?x?xf32> into tensor<?x?xf32>
+//      CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP0]]]
+// CHECK-SAME:       ins(%[[RESHAPE]], %[[ARG1]] : tensor<?x?xf32>, tensor<?xf32>)
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
+// The pass's control function always folds a tensor_expand_shape whose only
+// use is an output operand of a LinalgOp, even though this expansion only
+// adds a unit dimension; the generic then operates on tensor<1x?x?xf32>.
+func @control_consumer_reshape_fusion(%arg0 : tensor<1x?x?xf32>, %arg1 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %cst = constant 0.0 : f32
+  %d0 = tensor.dim %arg0, %c1 : tensor<1x?x?xf32>
+  %d1 = tensor.dim %arg1, %c2 : tensor<1x?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %fill = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      outs(%init : tensor<?x?xf32>) {
+      ^bb0(%arg2: f32):
+        linalg.yield %cst : f32
+      } -> tensor<?x?xf32>
+  %0 = linalg.tensor_expand_shape %fill [[0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+      outs(%0 : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//      CHECK: builtin.func @control_consumer_reshape_fusion
+//      CHECK:   %[[FILL:.+]] = linalg.generic
+// CHECK-SAME:       indexing_maps = [#[[MAP]]]
+// CHECK-SAME:       outs(%{{.+}} : tensor<1x?x?xf32>)
+//      CHECK:   linalg.batch_matmul
+// CHECK-SAME:       outs(%[[FILL]] : tensor<1x?x?xf32>)

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 487d42f8e9848..e9f72e500686a 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -73,6 +73,52 @@ struct TestLinalgElementwiseFusion
   }
 };
 
+/// Test pass that applies the fold-reshape-by-expansion patterns with a
+/// custom control function deciding which reshapes may be folded.
+struct TestLinalgControlFuseByExpansion
+    : public PassWrapper<TestLinalgControlFuseByExpansion, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final {
+    return "test-linalg-control-fusion-by-expansion";
+  }
+  StringRef getDescription() const final {
+    return "Test controlling of fusion of elementwise ops with reshape by "
+           "expansion";
+  }
+
+  void runOnFunction() override {
+    // Folding policy under test:
+    //  - never fold a collapse_shape whose source is not a LinalgOp result;
+    //  - always fold an expand_shape whose only use is a LinalgOp output;
+    //  - otherwise defer to linalg::skipUnitDimReshape.
+    linalg::ControlElementwiseOpsFusionFn controlFn =
+        [](const OpResult &producer, OpOperand &consumer) {
+          auto collapseOp =
+              producer.getDefiningOp<linalg::TensorCollapseShapeOp>();
+          if (collapseOp && !collapseOp.src().getDefiningOp<linalg::LinalgOp>())
+            return false;
+          auto expandOp =
+              dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner());
+          if (expandOp && expandOp->hasOneUse()) {
+            OpOperand &use = *expandOp->getUses().begin();
+            auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
+            if (linalgOp && linalgOp.isOutputTensor(&use))
+              return true;
+          }
+          return linalg::skipUnitDimReshape(producer, consumer);
+        };
+
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
+    (void)applyPatternsAndFoldGreedily(getFunction().getBody(),
+                                       std::move(patterns));
+  }
+};
+
 struct TestPushExpandingReshape
     : public PassWrapper<TestPushExpandingReshape, FunctionPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -99,6 +145,10 @@ void registerTestLinalgElementwiseFusion() {
   PassRegistration<TestLinalgElementwiseFusion>();
 }
 
+void registerTestLinalgControlFuseByExpansion() {
+  PassRegistration<TestLinalgControlFuseByExpansion> registration;
+}
+
 void registerTestPushExpandingReshape() {
   PassRegistration<TestPushExpandingReshape>();
 }

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3059b1fafb96a..2fba77f76b433 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,7 @@ void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgControlFuseByExpansion();
 void registerTestLinalgDistribution();
 void registerTestLinalgElementwiseFusion();
 void registerTestPushExpandingReshape();
@@ -165,6 +166,7 @@ void registerTestPasses() {
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLinalgCodegenStrategy();
+  mlir::test::registerTestLinalgControlFuseByExpansion();
   mlir::test::registerTestLinalgDistribution();
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestPushExpandingReshape();


        


More information about the Mlir-commits mailing list