[Mlir-commits] [mlir] 7e749d4 - [mlir][bufferization]-Add ControlBuildSubsetExtractionFn to TensorEmptyElimination (#120851)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Dec 28 03:28:13 PST 2024
Author: Amir Bishara
Date: 2024-12-28T13:28:09+02:00
New Revision: 7e749d4fb7327ce2da307ed020c02a07e8279992
URL: https://github.com/llvm/llvm-project/commit/7e749d4fb7327ce2da307ed020c02a07e8279992
DIFF: https://github.com/llvm/llvm-project/commit/7e749d4fb7327ce2da307ed020c02a07e8279992.diff
LOG: [mlir][bufferization]-Add ControlBuildSubsetExtractionFn to TensorEmptyElimination (#120851)
This PR adds a `ControlBuildSubsetExtractionFn` to the tensor empty
elimination util. It controls how the subset extraction of the
`SubsetInsertionOpInterface` is built.
This control function returns the subset extraction value that will
replace the `emptyTensorOp` use that is being consumed by a specific
user (the use that the util expects to eliminate).
The default control function preserves today's behavior without any
additional changes.
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index 892675954493b9..a4ee893ca53416 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -10,7 +10,9 @@
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
namespace mlir {
namespace bufferization {
@@ -34,13 +36,35 @@ struct OneShotBufferizationOptions;
/// "tensor.empty" op.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);
+/// A function type that defines a callback to control the construction
+/// of the subset extraction of the `SubsetInsertionOpInterface`.
+/// The subset extraction value can be used as a replacement for the
+/// `emptyTensorOp` value that is being consumed by `user`. Failure to
+/// build such a value should be indicated by returning an empty value.
+/// The callback should guarantee the legality of the replacement, i.e.
+/// the replacement should dominate the user of the `emptyTensorOp` that
+/// is being eliminated.
+using ControlBuildSubsetExtractionFn =
+ std::function<Value(RewriterBase &, SubsetInsertionOpInterface,
+ tensor::EmptyOp emptyTensorOp, Operation *user)>;
+
+/// Default `ControlBuildSubsetExtractionFn`: builds and returns a subset
+/// extraction value for the destination tensor that the given `op`
+/// inserts into. The value is meant to replace the use of `emptyTensorOp`
+/// that is consumed by `user`. Returns an empty Value if no such value can
+/// be built, e.g. when no valid insertion point for it exists.
+Value buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp, Operation *user);
+
/// Try to eliminate "tensor.empty" ops inside `op`.
///
/// This function overload accepts an existing `OneShotAnalysisState`, which
/// contains in-place bufferization decisions. This overload is useful if an
/// existing analysis should be reused for empty tensor elimination.
+/// `subsetsExtractionFn` builds the replacement value for each candidate
+/// "tensor.empty" use; when it returns an empty value, that use is skipped.
-LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
- OneShotAnalysisState &state);
+LogicalResult eliminateEmptyTensors(
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction);
/// Within the given operation, hoist buffers from loops where possible. See
/// "BufferLoopHoistingPass" for more information.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index abc0635a2cdff0..98c3d8d0adc6d2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
return nullptr;
}
+Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp,
+ Operation *user) {
+ // Restore the rewriter's insertion point when this function returns.
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ // All values that are needed to create the replacement op.
+ SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
+ // Find a suitable insertion point. If no suitable insertion point
+ // for the replacement can be found, return an empty value to skip
+ // this replacement.
+ Operation *insertionPoint =
+ findValidInsertionPoint(emptyTensorOp, user, neededValues);
+ if (!insertionPoint)
+ return {};
+ // Create the extraction immediately before the chosen insertion point.
+ rewriter.setInsertionPoint(insertionPoint);
+ Value replacement =
+ op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ return replacement;
+}
+
LogicalResult mlir::bufferization::eliminateEmptyTensors(
- RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn) {
OpBuilder::InsertionGuard g(rewriter);
llvm::DenseSet<OpOperand *> visitedOpOperands;
op->walk([&](SubsetInsertionOpInterface op) {
@@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
if (!state.isInPlace(source))
return WalkResult::skip();
- // All values that are needed to create the replacement op.
- SmallVector<Value> neededValues =
- op.getValuesNeededToBuildSubsetExtraction();
-
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
@@ -129,8 +148,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
&visitedOpOperands);
for (Value v : emptyTensors) {
- Operation *emptyTensorOp = v.getDefiningOp();
-
+ auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
+ assert(emptyTensorOp && "expected tensor.empty op");
// Find the use to be replaced from the use-def chain.
auto iter = llvm::find_if(
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
@@ -142,17 +161,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
continue;
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();
-
- // Find a suitable insertion point. If no suitable insertion point for
- // the replacement can be found, skip this replacement.
- Operation *insertionPoint =
- findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- continue;
-
- rewriter.setInsertionPoint(insertionPoint);
- Value replacement =
- op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
if (!replacement)
continue;
if (emptyTensorOp == replacement.getDefiningOp())
More information about the Mlir-commits
mailing list