[Mlir-commits] [mlir] cf2d625 - [mlir][linalg] Expose getPreservedProducerResults method from ElementwiseOpFusion file (#73850)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 8 01:50:38 PST 2023


Author: Amir Bishara
Date: 2023-12-08T11:50:33+02:00
New Revision: cf2d625a5d328ab4af6292be7b47c645ffef0e2b

URL: https://github.com/llvm/llvm-project/commit/cf2d625a5d328ab4af6292be7b47c645ffef0e2b
DIFF: https://github.com/llvm/llvm-project/commit/cf2d625a5d328ab4af6292be7b47c645ffef0e2b.diff

LOG: [mlir][linalg] Expose getPreservedProducerResults method from ElementwiseOpFusion file (#73850)

Declare the `getPreservedProducerResults` function, which returns the
indices of the producer linalg generic operation's results that are
preserved by elementwise fusion.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3f4dfe42b71fde..a848d12fbbb50e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -493,6 +493,8 @@ LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
 struct ElementwiseOpFusionResult {
   Operation *fusedOp;
   llvm::DenseMap<Value, Value> replacements;
+  static llvm::SmallDenseSet<int>
+  getPreservedProducerResults(GenericOp producer, GenericOp consumer);
 };
 FailureOr<ElementwiseOpFusionResult>
 fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index dc5ea28b67cdc0..3eb91190751ef1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -71,6 +71,25 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
   return t1.compose(fusedConsumerArgIndexMap);
 }
 
+/// Returns indices of producer results kept after fusion: those whose init is
+/// read by the payload or not droppable, or that have non-consumer users.
+llvm::SmallDenseSet<int>
+ElementwiseOpFusionResult::getPreservedProducerResults(GenericOp producer,
+                                                       GenericOp consumer) {
+  llvm::SmallDenseSet<int> preservedProducerResults;
+  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
+    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
+    if (producer.payloadUsesValueFromOperand(outputOperand) ||
+        !producer.canOpOperandsBeDropped(outputOperand) ||
+        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
+          return user != consumer.getOperation();
+        })) {
+      preservedProducerResults.insert(producerResult.index());
+    }
+  }
+  return preservedProducerResults;
+}
+
 /// Conditions for elementwise fusion of generic operations.
 bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (!fusedOperand)
@@ -285,17 +304,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   assert(consumer.isDpsInput(fusedOperand) &&
          "expected producer of input operand");
   /// Find the results of the producer that have uses outside of the consumer.
-  llvm::SmallDenseSet<int> preservedProducerResults;
-  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
-    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
-    if (producer.payloadUsesValueFromOperand(outputOperand) ||
-        !producer.canOpOperandsBeDropped(outputOperand) ||
-        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
-          return user != consumer.getOperation();
-        })) {
-      preservedProducerResults.insert(producerResult.index());
-    }
-  }
+  llvm::SmallDenseSet<int> preservedProducerResults =
+      ElementwiseOpFusionResult::getPreservedProducerResults(producer,
+                                                             consumer);
 
   // Compute the fused operands list and indexing maps.
   SmallVector<Value> fusedInputOperands, fusedOutputOperands;


        


More information about the Mlir-commits mailing list