[Mlir-commits] [mlir] [mlir][linalg] Improve getPreservedProducerResults estimation in ElementwiseOpFusion (PR #104409)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 15 00:14:25 PDT 2024
https://github.com/DanielLevi6 updated https://github.com/llvm/llvm-project/pull/104409
>From 9436558a6079dedce70e71fae5a87469ee581ddf Mon Sep 17 00:00:00 2001
From: Daniel Levi <daniel60030 at gmail.com>
Date: Tue, 13 Aug 2024 16:10:16 +0300
Subject: [PATCH] Improve getPreservedProducerResults estimation in
ElementwiseOpFusion
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
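This commit changes the getPreservedProducerResults function so that it takes the consumer, and the operand being fused, into account along with the producer when predicting which of the producer’s outputs can be dropped during the fusion process. Instead of checking only the producer's indexing maps, it now checks that the indexing maps of the producer and the consumer that remain (after dropping the fused operand and the candidate outputs) are still invertible. This provides a more accurate prediction, since whether an output can be dropped during fusion also depends on the consumer.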
---
.../Dialect/Linalg/Transforms/Transforms.h | 3 +-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 46 ++++++++++++++++---
2 files changed, 42 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 477ef7bfafb181..e832837183a15e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -498,7 +498,8 @@ struct ElementwiseOpFusionResult {
Operation *fusedOp;
llvm::DenseMap<Value, Value> replacements;
static llvm::SmallDenseSet<int>
- getPreservedProducerResults(GenericOp producer, GenericOp consumer);
+ getPreservedProducerResults(GenericOp producer, GenericOp consumer,
+ OpOperand *fusedOperand);
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e73df61c964341..c4de461a921c78 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -70,20 +70,54 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
return t1.compose(fusedConsumerArgIndexMap);
}
+// Checks whether the operands in `opOperandsToIgnore` can all be dropped from
+// the fused producer/consumer pair while the indexing maps of the remaining
+// operands of both ops can still be used to compute the loop bounds of the op.
+//
+// NOTE(review): the producer's maps are expressed in the producer's loop
+// dimensions and the consumer's maps in the consumer's loop dimensions; they
+// are concatenated here as-is, without being composed into the fused op's
+// coordinates (compare getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp).
+// Confirm that this approximation is intended.
+static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
+    GenericOp producer, GenericOp consumer,
+    ArrayRef<OpOperand *> opOperandsToIgnore) {
+  SmallVector<AffineMap> indexingMaps;
+  for (GenericOp op : {producer, consumer}) {
+    for (OpOperand &opOperand : op->getOpOperands()) {
+      if (llvm::is_contained(opOperandsToIgnore, &opOperand))
+        continue;
+      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
+    }
+  }
+
+  // concatAffineMaps must not be called with an empty list (it takes the
+  // MLIRContext from the first map). With no operands left, the bounds can
+  // only be recovered when neither op has any loops.
+  if (indexingMaps.empty())
+    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
+
+  // The concatenation of the remaining indexing maps must be invertible, so
+  // that the bounds of the op can still be computed after dropping the
+  // selected operands. inversePermutation returns a null AffineMap when the
+  // concatenated indexing maps are not invertible.
+  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
+}
+
/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
-llvm::SmallDenseSet<int>
-ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
- GenericOp consumer) {
+llvm::SmallDenseSet<int> ElementwiseOpFusionResult::getPreservedProducerResults(
+ GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
+ llvm::SmallVector<OpOperand *> opOperandsToIgnore;
+
+ // The fusedOperand will be removed during the fusion
+ opOperandsToIgnore.emplace_back(fusedOperand);
+
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
+ opOperandsToIgnore.emplace_back(outputOperand);
if (producer.payloadUsesValueFromOperand(outputOperand) ||
- !producer.canOpOperandsBeDropped(outputOperand) ||
+ !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
+ opOperandsToIgnore) ||
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
return user != consumer.getOperation();
})) {
preservedProducerResults.insert(producerResult.index());
+
+ // In case the operand can't be dropped
+ opOperandsToIgnore.pop_back_val();
}
}
return preservedProducerResults;
@@ -302,8 +336,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
"expected producer of input operand");
/// Find the results of the producer that have uses outside of the consumer.
llvm::SmallDenseSet<int> preservedProducerResults =
- ElementwiseOpFusionResult::getPreservedProducerResults(producer,
- consumer);
+ ElementwiseOpFusionResult::getPreservedProducerResults(producer, consumer,
+ fusedOperand);
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
More information about the Mlir-commits
mailing list