[Mlir-commits] [mlir] [mlir][bufferization]-Add ControlBuildSubsetExtractionFn to TensorEmptyElimination (PR #120851)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 21 12:07:57 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-bufferization

Author: Amir Bishara (amirBish)

<details>
<summary>Changes</summary>

This PR adds a `ControlBuildSubsetExtractionFn` to the empty tensor elimination utility. This callback controls how the subset extraction of the
`SubsetInsertionOpInterface` is built.

The control function returns the subset extraction value that will replace the use of `emptyTensorOp`
that is consumed by a specific user (the use that the
utility is expected to eliminate).

The default control function preserves today's behavior without any additional changes.

---
Full diff: https://github.com/llvm/llvm-project/pull/120851.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h (+30-2) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (+29-20) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index 892675954493b9..bd9242e2caccb4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -10,7 +10,9 @@
 #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
 
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 
 namespace mlir {
 namespace bufferization {
@@ -34,13 +36,39 @@ struct OneShotBufferizationOptions;
 /// "tensor.empty" op.
 LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);
 
+/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
+/// use of `user` operation, assuming that the replacement may use any
+/// value from `neededValues`.
+Operation *findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
+                                   const SmallVector<Value> &neededValues);
+
+/// A function type that defines a callBack to control the build of the
+/// subsets extraction of the `SubsetInsertionOpInterface`.
+/// The subsets extraction value will replace the `emptyTensorOp` value
+/// which is being consumed by `user`, failing of building such a value
+/// should be indicated with an empty value.
+/// This function should guarantee the legality of the replacement.
+using ControlBuildSubsetExtractionFn =
+    std::function<Value(RewriterBase &, SubsetInsertionOpInterface,
+                        tensor::EmptyOp emptyTensorOp, Operation *user)>;
+
+/// This method Builds and returns a subset extraction value for the
+/// destination tensor that the given `op` inserts into.
+/// It returns a value which should replace the `emptyTensorOp` use
+/// that is being consumed by `user`, If no such a value found it
+/// will return an empty Value.
+Value buildSubsetExtraction(RewriterBase &rewriter,
+                            SubsetInsertionOpInterface op,
+                            tensor::EmptyOp emptyTensorOp, Operation *user);
+
 /// Try to eliminate "tensor.empty" ops inside `op`.
 ///
 /// This function overload accepts an existing `OneShotAnalysisState`, which
 /// contains in-place bufferization decisions. This overload is useful if an
 /// existing analysis should be reused for empty tensor elimination.
-LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
-                                    OneShotAnalysisState &state);
+LogicalResult eliminateEmptyTensors(
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+    ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction);
 
 /// Within the given operation, hoist buffers from loops where possible. See
 /// "BufferLoopHoistingPass" for more information.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index abc0635a2cdff0..dabb44edd32783 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -51,9 +51,9 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
 /// Find a valid insertion point for a replacement of `emptyTensorOp`'s
 /// use of `user` operation, assuming that the replacement may use any
 /// value from `neededValues`.
-static Operation *
-findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
-                        const SmallVector<Value> &neededValues) {
+Operation *mlir::bufferization::findValidInsertionPoint(
+    Operation *emptyTensorOp, Operation *user,
+    const SmallVector<Value> &neededValues) {
   DominanceInfo domInfo;
   Operation *candidateInsertionPoint = emptyTensorOp;
 
@@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
   return nullptr;
 }
 
+Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
+                                                 SubsetInsertionOpInterface op,
+                                                 tensor::EmptyOp emptyTensorOp,
+                                                 Operation *user) {
+
+  mlir::OpBuilder::InsertionGuard guard(rewriter);
+  // All values that are needed to create the replacement op.
+  SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
+  // Find a suitable insertion point. If no suitable insertion point
+  // for the replacement can be found, return an empty value to skip
+  // this replacement.
+  Operation *insertionPoint =
+      findValidInsertionPoint(emptyTensorOp, user, neededValues);
+  if (!insertionPoint)
+    return {};
+
+  rewriter.setInsertionPoint(insertionPoint);
+  Value replacement =
+      op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+  return replacement;
+}
+
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
-    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+    ControlBuildSubsetExtractionFn subsetsExtractionFn) {
   OpBuilder::InsertionGuard g(rewriter);
   llvm::DenseSet<OpOperand *> visitedOpOperands;
   op->walk([&](SubsetInsertionOpInterface op) {
@@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     if (!state.isInPlace(source))
       return WalkResult::skip();
 
-    // All values that are needed to create the replacement op.
-    SmallVector<Value> neededValues =
-        op.getValuesNeededToBuildSubsetExtraction();
-
     // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
     // equivalent tensors. I.e., stop when there are ops such as extract_slice
     // on the path.
@@ -129,7 +148,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
         &visitedOpOperands);
 
     for (Value v : emptyTensors) {
-      Operation *emptyTensorOp = v.getDefiningOp();
+      auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
 
       // Find the use to be replaced from the use-def chain.
       auto iter = llvm::find_if(
@@ -142,17 +161,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
         continue;
       OpOperand *useToBeReplaced = *iter;
       Operation *user = useToBeReplaced->getOwner();
-
-      // Find a suitable insertion point. If no suitable insertion point for
-      // the replacement can be found, skip this replacement.
-      Operation *insertionPoint =
-          findValidInsertionPoint(emptyTensorOp, user, neededValues);
-      if (!insertionPoint)
-        continue;
-
-      rewriter.setInsertionPoint(insertionPoint);
-      Value replacement =
-          op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+      auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
       if (!replacement)
         continue;
       if (emptyTensorOp == replacement.getDefiningOp())

``````````

</details>


https://github.com/llvm/llvm-project/pull/120851


More information about the Mlir-commits mailing list