[Mlir-commits] [mlir] 69011a2 - [mlir][Linalg] Make Elementwise op fusion return a map from existing values to values in the fused op.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue Jan 31 12:34:03 PST 2023


Author: Mahesh Ravishankar
Date: 2023-01-31T20:33:44Z
New Revision: 69011a2ad0ce8662f69d3abef12280b1f463f99c

URL: https://github.com/llvm/llvm-project/commit/69011a2ad0ce8662f69d3abef12280b1f463f99c
DIFF: https://github.com/llvm/llvm-project/commit/69011a2ad0ce8662f69d3abef12280b1f463f99c.diff

LOG: [mlir][Linalg] Make Elementwise op fusion return a map from existing values to values in the fused op.

This replacement map can be used to eliminate all uses of the
producer/consumer in the case where the producer/consumer has other uses
outside of the producer/consumer pair. This makes the
producer/consumer dead.

Add test and minor fixup to the test harness.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D142848

Added: 
    mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7fe9dfe2fbf5f..953eb59b95134 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -162,8 +162,19 @@ bool areElementwiseOpsFusable(OpOperand *fusedOperand);
+/// Result of `fuseElementwiseOps`.
+struct ElementwiseOpFusionResult {
+  /// The operation created by fusing the producer into the consumer.
+  Operation *fusedOp;
+  /// Maps every result of the consumer, and every producer result that is
+  /// preserved by the fusion, to the corresponding result of `fusedOp`.
+  llvm::DenseMap<Value, Value> replacements;
+};
+
 /// Fuse two `linalg.generic` operations that have a producer-consumer
 /// relationship captured through `fusedOperand`. The method expects
 /// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
-FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
-                                          OpOperand *fusedOperand);
+/// Callers can use the returned `replacements` to redirect uses of the
+/// original producer/consumer results, e.g. when the producer has other uses.
+FailureOr<ElementwiseOpFusionResult>
+fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 2504a2ab0c9bd..8df324dfa3818 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -23,8 +23,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include <utility>
 #include <optional>
+#include <utility>
 
 namespace mlir {
 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
@@ -73,6 +73,9 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 
 /// Conditions for elementwise fusion of generic operations.
 bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+  if (!fusedOperand)
+    return false;
+
   auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
   auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
 
@@ -270,7 +273,7 @@ static void generateFusedElementwiseOpRegion(
          "Ill-formed GenericOp region");
 }
 
-FailureOr<Operation *>
+FailureOr<mlir::linalg::ElementwiseOpFusionResult>
 mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                  OpOperand *fusedOperand) {
   assert(areElementwiseOpsFusable(fusedOperand) &&
@@ -390,7 +393,17 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   generateFusedElementwiseOpRegion(
       rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
       consumer.getNumLoops(), preservedProducerResults);
-  return fusedOp.getOperation();
+  ElementwiseOpFusionResult result;
+  result.fusedOp = fusedOp;
+  // This relies on the fused op's results being ordered as the preserved
+  // producer results followed by all of the consumer results.
+  int resultNum = 0;
+  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
+    if (preservedProducerResults.count(index))
+      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
+  for (auto consumerResult : consumer->getResults())
+    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
+  return result;
 }

 namespace {
@@ -411,13 +424,19 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       if (!controlFn(&opOperand))
         continue;
 
-      FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
-      if (succeeded(fusedOp)) {
-        auto replacements =
-            (*fusedOp)->getResults().take_back(genericOp.getNumResults());
-        rewriter.replaceOp(genericOp, replacements);
-        return success();
+      FailureOr<ElementwiseOpFusionResult> fusionResult =
+          fuseElementwiseOps(rewriter, &opOperand);
+      if (failed(fusionResult))
+        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+      Operation *producer = opOperand.get().getDefiningOp();
+      for (auto [origVal, replacement] : fusionResult->replacements) {
+        rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) {
+          // Only replace uses of consumer results, not producer results.
+          return use.get().getDefiningOp() != producer;
+        });
       }
+      rewriter.eraseOp(genericOp);
+      return success();
     }
     return failure();
   }

diff  --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
new file mode 100644
index 0000000000000..7871ae08fd54a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=fuse-multiuse-producer -split-input-file %s | FileCheck %s
+
+// Fuse a producer whose results also have uses outside of the consumer. All
+// uses must be redirected to the fused op, so that the original producer and
+// consumer become dead and are removed.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+  ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
+    %1 = arith.addf %b0, %b1 : f32
+    linalg.yield %1, %1 : f32, f32
+  } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  %2 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0#1, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg4 : tensor<?x?xf32>) {
+  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+    %3 = arith.mulf %b0, %b1 : f32
+    linalg.yield %3 : f32
+  } -> tensor<?x?xf32>
+  return %0#0, %0#1, %2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+//      CHECK: func @multi_use_producer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//  CHECK-NOT:   linalg.generic
+//      CHECK:   %[[RESULT:.+]]:3 = linalg.generic
+//  CHECK-NOT:   linalg.generic
+//      CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index e2f61a9611b0d..21c9eaee5213b 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -51,6 +51,38 @@ static bool setFusedOpOperandLimit(OpOperand *fusedOperand) {
 }
 
 namespace {
+
+/// Pattern to test fusion of producer with consumer, even if producer has
+/// multiple uses.
+struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *fusableOperand = nullptr;
+    for (OpOperand &operand : genericOp->getOpOperands()) {
+      if (linalg::areElementwiseOpsFusable(&operand)) {
+        fusableOperand = &operand;
+        break;
+      }
+    }
+    if (!fusableOperand) {
+      return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
+    }
+    FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
+        linalg::fuseElementwiseOps(rewriter, fusableOperand);
+    if (failed(fusionResult))
+      return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+    for (auto [origValue, replacement] : fusionResult->replacements) {
+      rewriter.replaceUseIf(origValue, replacement, [&](OpOperand &use) {
+        return use.getOwner() != genericOp.getOperation();
+      });
+    }
+    rewriter.eraseOp(genericOp);
+    return success();
+  }
+};
+
 struct TestLinalgElementwiseFusion
     : public PassWrapper<TestLinalgElementwiseFusion,
                          OperationPass<func::FuncOp>> {
@@ -105,6 +137,12 @@ struct TestLinalgElementwiseFusion
                      "fusion patterns that "
                      "collapse the iteration space of the consumer"),
       llvm::cl::init(false)};
+
+  Option<bool> fuseMultiUseProducer{
+      *this, "fuse-multiuse-producer",
+      llvm::cl::desc("Test fusion of producer ops with multiple uses"),
+      llvm::cl::init(false)};
+
   ListOption<int64_t> collapseDimensions{
       *this, "collapse-dimensions-control",
       llvm::cl::desc("Test controlling dimension collapse pattern")};
@@ -117,8 +155,9 @@ struct TestLinalgElementwiseFusion
       RewritePatternSet fusionPatterns(context);
       auto controlFn = [](OpOperand *operand) { return true; };
       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                         std::move(fusionPatterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
       return;
     }
 
@@ -127,8 +166,9 @@ struct TestLinalgElementwiseFusion
       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
                                                    setFusedOpOperandLimit<4>);
 
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                         std::move(fusionPatterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
       return;
     }
 
@@ -172,8 +212,9 @@ struct TestLinalgElementwiseFusion
 
       linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
                                                         controlReshapeFusionFn);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                         std::move(fusionPatterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
       return;
     }
 
@@ -181,7 +222,10 @@ struct TestLinalgElementwiseFusion
       RewritePatternSet patterns(context);
       linalg::populateFoldReshapeOpsByCollapsingPatterns(
           patterns, [](OpOperand * /*fusedOperand */) { return true; });
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(patterns))))
+        return signalPassFailure();
+      return;
     }
 
     if (fuseWithReshapeByCollapsingWithControlFn) {
@@ -195,7 +239,19 @@ struct TestLinalgElementwiseFusion
         return true;
       };
       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(patterns))))
+        return signalPassFailure();
+      return;
+    }
+
+    if (fuseMultiUseProducer) {
+      RewritePatternSet patterns(context);
+      patterns.insert<TestMultiUseProducerFusion>(context);
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(patterns))))
+        return signalPassFailure();
+      return;
     }
 
     if (!collapseDimensions.empty()) {
@@ -209,7 +265,10 @@ struct TestLinalgElementwiseFusion
           };
       RewritePatternSet patterns(context);
       linalg::populateCollapseDimensions(patterns, collapseFn);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(patterns))))
+        return signalPassFailure();
+      return;
     }
   }
 };


        


More information about the Mlir-commits mailing list