[Mlir-commits] [mlir] 4a4b233 - [mlir][linalg] Improve getPreservedProducerResults estimation in ElementwiseOpFusion (#104409)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 20 16:14:04 PDT 2024
Author: DanielLevi6
Date: 2024-08-20T16:14:01-07:00
New Revision: 4a4b233f35adaed44e50157db3846d0d23f2f6e1
URL: https://github.com/llvm/llvm-project/commit/4a4b233f35adaed44e50157db3846d0d23f2f6e1
DIFF: https://github.com/llvm/llvm-project/commit/4a4b233f35adaed44e50157db3846d0d23f2f6e1.diff
LOG: [mlir][linalg] Improve getPreservedProducerResults estimation in ElementwiseOpFusion (#104409)
This commit changes the getPreservedProducerResults function so that it
takes the consumer (and the fused operand) into account along with the
producer, in order to predict which of the producer’s outputs can be
dropped during the fusion process. This gives a more accurate prediction,
because whether an output can be dropped also depends on the consumer.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 861e14d22d9625..bee3452ebb685f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -497,12 +497,19 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
struct ElementwiseOpFusionResult {
Operation *fusedOp;
llvm::DenseMap<Value, Value> replacements;
- static llvm::SmallDenseSet<int>
- getPreservedProducerResults(GenericOp producer, GenericOp consumer);
};
FailureOr<ElementwiseOpFusionResult>
fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
+/// Returns the set of indices of `producer`'s results that are predicted to
+/// remain results of the fused op when `producer` is fused into `consumer`
+/// through `fusedOperand` (an input operand of `consumer`).
+/// * This is only a prediction based on an optimized fusion; the actual
+/// transformation is not guaranteed to preserve exactly this set of results.
+llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
+ GenericOp consumer,
+ OpOperand *fusedOperand);
+
/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where relevant.
// TODO: Add support for scf.parallel and affine.for loops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 9f1b6fdc55df3b..e510302840d32c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -71,20 +71,57 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
return t1.compose(fusedConsumerArgIndexMap);
}
+// Checks whether all operands in `opOperandsToIgnore` can be dropped from the
+// fused producer & consumer while the indexing maps of the remaining operands
+// of both ops can still be used to compute the loop bounds of the fused op.
+static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
+    GenericOp producer, GenericOp consumer,
+    ArrayRef<OpOperand *> opOperandsToIgnore) {
+  SmallVector<AffineMap> indexingMaps;
+
+  SmallVector<GenericOp> ops = {producer, consumer};
+  for (auto &op : ops) {
+    for (auto &opOperand : op->getOpOperands()) {
+      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
+        continue;
+      }
+      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
+    }
+  }
+
+  // Every operand was ignored, so there is no map left to derive bounds from.
+  // Bail out before calling concatAffineMaps, which must not receive an empty
+  // list (NOTE(review): it takes its MLIRContext from the first map). With no
+  // remaining maps, the bounds are only computable if neither op has loops.
+  if (indexingMaps.empty())
+    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
+
+  // The concatenation of the remaining indexing maps must be invertible, so
+  // that the bounds of the op can still be computed after dropping the
+  // selected operands. inversePermutation returns a null AffineMap when the
+  // concatenated indexing maps are not invertible.
+  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
+}
+
/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
-llvm::SmallDenseSet<int>
-ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
- GenericOp consumer) {
+/// A result is predicted to be preserved if the producer payload reads its
+/// init operand, if dropping that init operand would leave the fused op's
+/// loop bounds uncomputable, or if the result has users other than `consumer`.
+/// * There is a chance that the implementation of the transformation does not
+/// agree with the result of this method; it only gives a prediction based on
+/// an optimized fusion.
+llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
+ GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
+ llvm::SmallVector<OpOperand *> opOperandsToIgnore;
+
+ // The fused operand is removed during the fusion, so its indexing map never
+ // takes part in computing the fused op's bounds.
+ opOperandsToIgnore.emplace_back(fusedOperand);
+
for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
+ // Tentatively treat this result's init operand as dropped; it is removed
+ // from the ignore list again below if the result ends up preserved.
+ opOperandsToIgnore.emplace_back(outputOperand);
if (producer.payloadUsesValueFromOperand(outputOperand) ||
- !producer.canOpOperandsBeDropped(outputOperand) ||
+ !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
+ opOperandsToIgnore) ||
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
return user != consumer.getOperation();
})) {
preservedProducerResults.insert(producerResult.index());
+
+ // The result is preserved (for any of the reasons above), so its init
+ // operand stays in the fused op: stop ignoring it for later results.
+ opOperandsToIgnore.pop_back_val();
}
}
return preservedProducerResults;
@@ -301,10 +338,11 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
- /// Find the results of the producer that have uses outside of the consumer.
+ /// Find the results of the producer that have uses outside of the consumer,
+ /// after the fusion.
llvm::SmallDenseSet<int> preservedProducerResults =
- ElementwiseOpFusionResult::getPreservedProducerResults(producer,
- consumer);
+ mlir::linalg::getPreservedProducerResults(producer, consumer,
+ fusedOperand);
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
More information about the Mlir-commits
mailing list