[Mlir-commits] [mlir] [mlir][linalg] Improve getPreservedProducerResults estimation in ElementwiseOpFusion (PR #104409)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 19 00:30:22 PDT 2024


https://github.com/DanielLevi6 updated https://github.com/llvm/llvm-project/pull/104409

>From 9004bf6c6eb1a319e2b4c64e24d10d11391a1d77 Mon Sep 17 00:00:00 2001
From: Daniel Levi <daniel60030 at gmail.com>
Date: Tue, 13 Aug 2024 16:10:16 +0300
Subject: [PATCH] Improve getPreservedProducerResults estimation in
 ElementwiseOpFusion
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This commit changes getPreservedProducerResults so that it also takes the consumer and the fused operand into account, not just the producer, when predicting which of the producer's outputs can be dropped during fusion. Whether an output can be dropped also depends on the consumer's operands and indexing maps, so this yields a more accurate prediction.
---
 .../Dialect/Linalg/Transforms/Transforms.h    | 11 +++-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 52 ++++++++++++++++---
 2 files changed, 54 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 477ef7bfafb181..6640682b4cb250 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -497,12 +497,19 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
 struct ElementwiseOpFusionResult {
   Operation *fusedOp;
   llvm::DenseMap<Value, Value> replacements;
-  static llvm::SmallDenseSet<int>
-  getPreservedProducerResults(GenericOp producer, GenericOp consumer);
 };
 FailureOr<ElementwiseOpFusionResult>
 fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 
+/// Returns the indices of `producer`'s results that are predicted to still be
+/// produced after fusing `producer` into `consumer` through `fusedOperand`.
+/// Note: this is only a prediction that assumes an optimal fusion; the actual
+/// transformation (e.g. `fuseElementwiseOps`) may preserve a different set of
+/// results.
+llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
+                                                     GenericOp consumer,
+                                                     OpOperand *fusedOperand);
+
 /// Try to peel and canonicalize loop `op` and return the new result.
 /// Also applies affine_min/max bounds simplification on the fly where relevant.
 // TODO: Add support for scf.parallel and affine.for loops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e73df61c964341..d1ccae3121f683 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -70,20 +70,57 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
   return t1.compose(fusedConsumerArgIndexMap);
 }
 
+// Returns true if, after dropping `opOperandsToIgnore`, the remaining operands
+// of `producer` and `consumer` still determine the loop bounds. Producer maps
+// are used in the producer's own loop space (not composed into the fused
+// op's), so this is a heuristic rather than an exact check.
+static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
+    GenericOp producer, GenericOp consumer,
+    ArrayRef<OpOperand *> opOperandsToIgnore) {
+  SmallVector<AffineMap> indexingMaps;
+  for (GenericOp op : {producer, consumer})
+    for (OpOperand &opOperand : op->getOpOperands())
+      if (!llvm::is_contained(opOperandsToIgnore, &opOperand))
+        indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
+
+  // No maps left: the bounds cannot be derived, so (as in
+  // LinalgOp::canOpOperandsBeDropped) dropping is only possible without loops.
+  // Bailing out here also avoids calling concatAffineMaps on an empty list.
+  if (indexingMaps.empty())
+    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
+
+  // The concatenation of the remaining indexing maps must be invertible so
+  // that the loop bounds can still be computed; inversePermutation returns a
+  // null AffineMap when it is not.
+  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
+}
+
 /// Returns a set of indices of the producer's results which would
 /// be preserved after the fusion.
-llvm::SmallDenseSet<int>
-ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
-                                                       GenericOp consumer) {
+/// Note: this is only a prediction that assumes an optimal fusion; the
+/// actual transformation may preserve a different set of results.
+/// `fusedOperand` is the consumer input operand fed by the producer.
+llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
+    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
   llvm::SmallDenseSet<int> preservedProducerResults;
+  llvm::SmallVector<OpOperand *> opOperandsToIgnore;
+
+  // The fused operand is always removed by the fusion.
+  opOperandsToIgnore.push_back(fusedOperand);
+
   for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
     auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
+    opOperandsToIgnore.push_back(outputOperand); // Tentatively dropped.
     if (producer.payloadUsesValueFromOperand(outputOperand) ||
-        !producer.canOpOperandsBeDropped(outputOperand) ||
+        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
+                                                  opOperandsToIgnore) ||
         llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
           return user != consumer.getOperation();
         })) {
       preservedProducerResults.insert(producerResult.index());
+
+      // The init operand is kept, so stop treating it as dropped.
+      opOperandsToIgnore.pop_back();
     }
   }
   return preservedProducerResults;
@@ -300,10 +337,11 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   // TODO: allow fusing the producer of an output operand.
   assert(consumer.isDpsInput(fusedOperand) &&
          "expected producer of input operand");
-  /// Find the results of the producer that have uses outside of the consumer.
+  /// Find the results of the producer that the fused op must still produce
+  /// (e.g. results that have uses outside of the consumer).
   llvm::SmallDenseSet<int> preservedProducerResults =
-      ElementwiseOpFusionResult::getPreservedProducerResults(producer,
-                                                             consumer);
+      mlir::linalg::getPreservedProducerResults(producer, consumer,
+                                                fusedOperand);
 
   // Compute the fused operands list and indexing maps.
   SmallVector<Value> fusedInputOperands, fusedOutputOperands;



More information about the Mlir-commits mailing list