[Mlir-commits] [mlir] [mlir][linalg] Improve getPreservedProducerResults estimation in ElementwiseOpFusion (PR #104409)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 14 23:53:58 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: None (DanielLevi6)
<details>
<summary>Changes</summary>
This commit changes the getPreservedProducerResults function so that, when predicting which of the producer’s outputs can be dropped during the fusion process, it considers the indexing maps of the consumer and the fused operand in addition to those of the producer. This gives a more accurate prediction, because whether a producer output can be dropped also depends on the consumer it is fused into.
---
Full diff: https://github.com/llvm/llvm-project/pull/104409.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+1-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+40-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 477ef7bfafb181..8c1d6449bf6bae 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -498,7 +498,7 @@ struct ElementwiseOpFusionResult {
Operation *fusedOp;
llvm::DenseMap<Value, Value> replacements;
static llvm::SmallDenseSet<int>
- getPreservedProducerResults(GenericOp producer, GenericOp consumer);
+ getPreservedProducerResults(GenericOp producer, GenericOp consumer, OpOperand *fusedOperand);
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e73df61c964341..9058cdcbc68ac0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -70,20 +70,56 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
return t1.compose(fusedConsumerArgIndexMap);
}
+// Returns true if the operands in `opOperandsToIgnore` can be dropped and the
+// indexing maps of the operands that remain on the fused producer & consumer
+// can still compute the loop bounds of the fused op.
+static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
+    GenericOp producer, GenericOp consumer,
+    ArrayRef<OpOperand *> opOperandsToIgnore) {
+  SmallVector<AffineMap> indexingMaps;
+  for (GenericOp op : {producer, consumer}) {
+    for (OpOperand &opOperand : op->getOpOperands()) {
+      if (llvm::is_contained(opOperandsToIgnore, &opOperand))
+        continue;
+      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
+    }
+  }
+
+  // concatAffineMaps() takes its MLIRContext from the first map, so it must not
+  // be called with an empty list. With no operands left, the bounds can only be
+  // "computed" when neither op has any loops.
+  if (indexingMaps.empty())
+    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
+
+  // The concatenation of the remaining indexing maps must be invertible so the
+  // loop bounds can still be computed after dropping the selected operands;
+  // inversePermutation() returns a null AffineMap when it is not.
+  // NOTE(review): the producer maps are expressed in the producer's loop space
+  // and the consumer maps in the consumer's. Consider composing the producer
+  // maps into fused-op coordinates (as
+  // getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp does) before
+  // concatenating them.
+  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
+}
+
/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
llvm::SmallDenseSet<int>
ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
- GenericOp consumer) {
+ GenericOp consumer,
+ OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
+ llvm::SmallVector<OpOperand*> opOperandsToIgnore;
+
+ // The fusedOperand will be removed during the fusion
+ opOperandsToIgnore.emplace_back(fusedOperand);
+
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
+ opOperandsToIgnore.emplace_back(outputOperand);
if (producer.payloadUsesValueFromOperand(outputOperand) ||
- !producer.canOpOperandsBeDropped(outputOperand) ||
+ !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
+ opOperandsToIgnore) ||
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
return user != consumer.getOperation();
})) {
preservedProducerResults.insert(producerResult.index());
+
+ // In case the operand can't be dropped
+ opOperandsToIgnore.pop_back_val();
}
}
return preservedProducerResults;
@@ -303,7 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
/// Find the results of the producer that have uses outside of the consumer.
llvm::SmallDenseSet<int> preservedProducerResults =
ElementwiseOpFusionResult::getPreservedProducerResults(producer,
- consumer);
+ consumer,
+ fusedOperand);
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
``````````
</details>
https://github.com/llvm/llvm-project/pull/104409
More information about the Mlir-commits mailing list