[Mlir-commits] [mlir] 944a2fe - [mlir][Linalg] Add callbacks to fusion of elementwise operations to control fusion.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 5 16:08:59 PDT 2021


Author: MaheshRavishankar
Date: 2021-04-05T16:08:47-07:00
New Revision: 944a2fe7633fcdd600de2772364e406514d794da

URL: https://github.com/llvm/llvm-project/commit/944a2fe7633fcdd600de2772364e406514d794da
DIFF: https://github.com/llvm/llvm-project/commit/944a2fe7633fcdd600de2772364e406514d794da.diff

LOG: [mlir][Linalg] Add callbacks to fusion of elementwise operations to control fusion.

Right now, elementwise operation fusion in Linalg fuses everything it
can. Without some checks, this can run up against the resource limits
of the target hardware. This patch adds a callback function that clients
can use to implement a cost function. When two elementwise operations
are deemed structurally fusable, the callback can be used to control
whether the fusion applies.

Differential Revision: https://reviews.llvm.org/D99820

Added: 
    mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
    mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/lib/Transforms/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8ce5677762695..7b7d6accd6a07 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -23,6 +23,7 @@ class FrozenRewritePatternSet;
 
 namespace linalg {
 
+struct LinalgElementwiseFusionOptions;
 struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
@@ -69,9 +70,40 @@ void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
 
+using ControlElementwiseOpsFusionFn =
+    std::function<bool(const OpResult &producer, const OpOperand &consumer)>;
+
+/// Options that control fusion of elementwise operations.
+struct LinalgElementwiseFusionOptions {
+  /// When true, also fold reshapes that introduce unit dimensions into the
+  /// shape with elementwise operations. Defaults to false.
+  bool allowFoldingUnitDimReshapes = false;
+
+  LinalgElementwiseFusionOptions &setAllowFoldingUnitDimReshapes(bool val) {
+    allowFoldingUnitDimReshapes = val;
+    return *this;
+  }
+
+  /// Called once `producer` is deemed structurally fusable with `consumer`;
+  /// returning false vetoes the fusion based on non-structural constraints,
+  /// making this the hook for cost models. Also consulted before folding
+  /// splat constants. The default below accepts every fusion.
+  ControlElementwiseOpsFusionFn controlElementwiseOpsFusionFn =
+      [](const OpResult & /*producer*/, const OpOperand & /*consumer*/) {
+        return true;
+      };
+
+  LinalgElementwiseFusionOptions &
+  setControlElementwiseOpsFusionFn(ControlElementwiseOpsFusionFn fun) {
+    controlElementwiseOpsFusionFn = std::move(fun);
+    return *this;
+  }
+};
+
 /// Patterns for fusing linalg operation on tensors.
 void populateElementwiseOpsFusionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes = false);
+    RewritePatternSet &patterns,
+    LinalgElementwiseFusionOptions options = LinalgElementwiseFusionOptions());
 
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 91848edee6f01..bb1a051c78e56 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -48,6 +48,10 @@ static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
     return false;
 
+  // Currently support only operations with single result.
+  if (producer.getNumOutputs() != 1)
+    return false;
+
   // Finally the index_map for the result must be invertible. For now just
   // verify it is a permutation.
   AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
@@ -209,10 +213,12 @@ generateFusedElementwiseOpRegion(PatternRewriter &rewriter, Operation *fusedOp,
 
 static Optional<SmallVector<Value, 1>>
 fuseElementwiseOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
+                       const ControlElementwiseOpsFusionFn &controlFn,
                        PatternRewriter &rewriter) {
   LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
   unsigned consumerIdx = consumerOpOperand.getOperandNumber();
-  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx))
+  if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) ||
+      !controlFn(producer->getResult(0), consumerOpOperand))
     return llvm::None;
 
   unsigned numFusedOperands =
@@ -1041,18 +1047,22 @@ struct FoldReshapeWithGenericOpByExpansion
 
 /// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
 template <typename LinalgOpTy>
-struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
-  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+class FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
+public:
+  FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
 
   LogicalResult matchAndRewrite(LinalgOpTy op,
                                 PatternRewriter &rewriter) const override {
     if (!op.hasTensorSemantics())
       return failure();
     LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
-    for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
-      ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>();
+    for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
+      ConstantOp constantOp = operand.value().get().getDefiningOp<ConstantOp>();
       if (!constantOp ||
-          !constantOp.value().cast<DenseElementsAttr>().isSplat())
+          !constantOp.value().cast<DenseElementsAttr>().isSplat() ||
+          !controlFn(constantOp->getResult(0), operand.value()))
         continue;
 
       // The indexing_maps for the operands of the fused operation are same as
@@ -1099,11 +1109,15 @@ struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
     }
     return failure();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFn;
 };
 } // namespace
 
 static Optional<SmallVector<Value, 1>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand,
+                   const ControlElementwiseOpsFusionFn &controlFn) {
   Operation *producer = consumerOpOperand.get().getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
     return llvm::None;
@@ -1114,14 +1128,17 @@ fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand) {
     return llvm::None;
 
   return fuseElementwiseOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
-                                rewriter);
+                                controlFn, rewriter);
 }
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
 template <typename LinalgOpTy>
-struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
-  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
+class FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
+public:
+  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<LinalgOpTy>(context, benefit), controlFn(fun) {}
 
   LogicalResult matchAndRewrite(LinalgOpTy op,
                                 PatternRewriter &rewriter) const override {
@@ -1132,7 +1149,7 @@ struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
       if (!producerOp || !producerOp.hasTensorSemantics())
         continue;
       Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand);
+          fuseElementwiseOps(rewriter, opOperand, controlFn);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
         return success();
@@ -1140,6 +1157,9 @@ struct FuseElementwiseOps : public OpRewritePattern<LinalgOpTy> {
     }
     return failure();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFn;
 };
 
 /// Pass that fuses generic ops on tensors. Used only for testing.
@@ -1148,7 +1168,10 @@ struct FusionOfTensorOpsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
-    populateElementwiseOpsFusionPatterns(patterns, allowFoldingUnitDimReshapes);
+    populateElementwiseOpsFusionPatterns(
+        patterns,
+        LinalgElementwiseFusionOptions().setAllowFoldingUnitDimReshapes(
+            allowFoldingUnitDimReshapes));
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
@@ -1193,14 +1216,14 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
 }
 
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
-    RewritePatternSet &patterns, bool allowFoldingUnitDimReshapes) {
+    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
   auto *context = patterns.getContext();
   patterns
       .add<FuseElementwiseOps<GenericOp>, FuseElementwiseOps<IndexedGenericOp>,
            FoldSplatConstants<GenericOp>, FoldSplatConstants<IndexedGenericOp>>(
-          context);
-  populateFoldReshapeOpsByExpansionPatterns(patterns,
-                                            allowFoldingUnitDimReshapes);
+          context, options.controlElementwiseOpsFusionFn);
+  populateFoldReshapeOpsByExpansionPatterns(
+      patterns, options.allowFoldingUnitDimReshapes);
   GenericOp::getCanonicalizationPatterns(patterns, context);
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);

diff  --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
new file mode 100644
index 0000000000000..b6c52c72ede81
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#binary2Dpointwise = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = ["parallel", "parallel"]
+}
+#ternary2Dpointwise = {
+  indexing_maps = [#map0, #map0, #map0, #map0],
+  iterator_types = ["parallel", "parallel"]
+}
+
+// The test pass only accepts a fusion if the fused op has at most four unique
+// inputs. Fusing one producer into %6 yields exactly four (e.g. %arg0, %arg1,
+// %2, %4 when fusing %0), but fusing a second producer would need five, so
+// the other two producers must remain separate ops.
+func @test_fusion_limit(
+    %arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>,
+    %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>, %arg5 : tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %0 = linalg.generic #binary2Dpointwise
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
+       %1 = mulf %arg6, %arg7 : f32
+       linalg.yield %1 : f32
+    } -> tensor<?x?xf32>
+  %2 = linalg.generic #binary2Dpointwise
+      ins(%arg2, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
+       %3 = mulf %arg6, %arg7 : f32
+       linalg.yield %3 : f32
+    } -> tensor<?x?xf32>
+  %4 = linalg.generic #binary2Dpointwise
+      ins(%arg4, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32):
+       %5 = mulf %arg6, %arg7 : f32
+       linalg.yield %5 : f32
+    } -> tensor<?x?xf32>
+  %6 = linalg.generic #ternary2Dpointwise
+      ins(%0, %2, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg6 : f32, %arg7 : f32, %arg8 : f32, %arg9 : f32):
+       %7 = addf %arg6, %arg7 : f32
+       %8 = addf %7, %arg8 : f32
+       linalg.yield %8 : f32
+    } -> tensor<?x?xf32>
+  return %6 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @test_fusion_limit
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//       CHECK:   %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
+//       CHECK:   %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
+//       CHECK:   %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+//       CHECK:   return %[[OP3]]

diff  --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index bd60cdfa78cce..e75b0fefca856 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_library(MLIRTestTransforms
   TestGpuRewrite.cpp
   TestInlining.cpp
   TestLinalgCodegenStrategy.cpp
+  TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
   TestLinalgTransforms.cpp

diff  --git a/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
new file mode 100644
index 0000000000000..c1da6737ac223
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestLinalgElementwiseFusion.cpp
@@ -0,0 +1,92 @@
+//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing fusion of elementwise operations in
+// Linalg, mainly the control callback in LinalgElementwiseFusionOptions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+
+/// Adds to `operandSet` the values `op` contributes to the operand count of a
+/// fused op: only the inputs of a Linalg op (its outputs are not counted), and
+/// every operand of any other op. A null `op` contributes nothing.
+static void addOperands(Operation *op, llvm::SetVector<Value> &operandSet) {
+  if (!op)
+    return;
+  TypeSwitch<Operation *, void>(op)
+      .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
+        operandSet.insert(linalgOp.getInputs().begin(),
+                          linalgOp.getInputs().end());
+      })
+      .Default([&](Operation *operation) {
+        operandSet.insert(operation->operand_begin(), operation->operand_end());
+      });
+}
+
+/// Fusion control callback: allows fusing `producer` into the owner of
+/// `consumer` only if the fused op would have at most `limit` unique operands
+/// (as counted by `addOperands`, excluding the fused value itself). Producers
+/// with more than one result are always rejected.
+template <unsigned limit = 3>
+static bool setFusedOpOperandLimit(const OpResult &producer,
+                                   const OpOperand &consumer) {
+  llvm::SetVector<Value> fusedOpOperands;
+  if (producer.getOwner()->getNumResults() != 1)
+    return false;
+  addOperands(consumer.getOwner(), fusedOpOperands);
+  fusedOpOperands.remove(producer);
+  addOperands(producer.getOwner(), fusedOpOperands);
+  return fusedOpOperands.size() <= limit;
+}
+
+namespace {
+/// Test pass that applies the elementwise fusion patterns with a control
+/// callback limiting every fused op to at most four unique operands.
+struct TestLinalgElementwiseFusion
+    : public PassWrapper<TestLinalgElementwiseFusion, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
+                    tensor::TensorDialect>();
+  }
+
+  void runOnFunction() override {
+    MLIRContext *context = &getContext();
+    FuncOp funcOp = getFunction();
+    RewritePatternSet fusionPatterns(context);
+
+    linalg::populateElementwiseOpsFusionPatterns(
+        fusionPatterns,
+        linalg::LinalgElementwiseFusionOptions()
+            .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
+
+    // The result is intentionally discarded; the lit test checks the IR.
+    (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                       std::move(fusionPatterns));
+  }
+};
+} // namespace
+
+namespace test {
+/// Registers the pass as `-test-linalg-elementwise-fusion-patterns`; invoked
+/// from mlir-opt's test-pass registration.
+void registerTestLinalgElementwiseFusion() {
+  PassRegistration<TestLinalgElementwiseFusion> testElementwiseFusionPass(
+      "test-linalg-elementwise-fusion-patterns",
+      "Test Linalg element wise operation fusion patterns");
+}
+} // namespace test
+
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2bef89ea7dda7..eea5d8f494221 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@ void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgElementwiseFusion();
 void registerTestLinalgFusionTransforms();
 void registerTestLinalgTensorFusionTransforms();
 void registerTestLinalgGreedyFusion();
@@ -154,6 +155,7 @@ void registerTestPasses() {
   test::registerTestIRVisitorsPass();
   test::registerTestInterfaces();
   test::registerTestLinalgCodegenStrategy();
+  test::registerTestLinalgElementwiseFusion();
   test::registerTestLinalgFusionTransforms();
   test::registerTestLinalgTensorFusionTransforms();
   test::registerTestLinalgGreedyFusion();


        


More information about the Mlir-commits mailing list