[Mlir-commits] [mlir] [mlir][bufferization]-Support unhandled cases in EmptyTensorElimination (PR #118958)
Matthias Springer
llvmlistbot at llvm.org
Mon Dec 16 08:07:32 PST 2024
================
@@ -28,47 +28,72 @@ namespace bufferization {
 using namespace mlir;
 using namespace mlir::bufferization;

+/// Return true if `val` is in scope at the given
+/// `insertionPoint`.
+static bool valueDominateInsertionPoint(const DominanceInfo &domInfo,
+                                        Operation *insertionPoint, Value val) {
+  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
+    Block *owner = bbArg.getOwner();
+    if (!owner->findAncestorOpInBlock(*insertionPoint))
+      return false;
+  } else {
+    auto opResult = cast<OpResult>(val);
+    if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
+      return false;
+  }
+  return true;
+}
+
 /// Return true if all `neededValues` are in scope at the given
 /// `insertionPoint`.
 static bool
 neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                    Operation *insertionPoint,
                                    const SmallVector<Value> &neededValues) {
-  for (Value val : neededValues) {
-    if (auto bbArg = dyn_cast<BlockArgument>(val)) {
-      Block *owner = bbArg.getOwner();
-      if (!owner->findAncestorOpInBlock(*insertionPoint))
-        return false;
-    } else {
-      auto opResult = cast<OpResult>(val);
-      if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
-        return false;
-    }
-  }
-  return true;
-}
+  for (Value val : neededValues)
+    if (!valueDominateInsertionPoint(domInfo, insertionPoint, val))
+      return false;

-/// Return true if the given `insertionPoint` dominates all uses of
-/// `emptyTensorOp`.
-static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
-                                        Operation *insertionPoint,
-                                        Operation *emptyTensorOp) {
-  return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
-    return domInfo.dominates(insertionPoint, user);
-  });
+  return true;
 }

-/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
-/// that the replacement may use any value from `neededValues`.
+/// Find a valid insertion point for a replacement of `useToBeEliminated`,
+/// assuming that the replacement may use any value from `neededValues`.
 static Operation *
-findValidInsertionPoint(Operation *emptyTensorOp,
+findValidInsertionPoint(OpOperand *useToBeEliminated,
                         const SmallVector<Value> &neededValues) {
   DominanceInfo domInfo;
-  // Gather all possible insertion points: the location of `emptyTensorOp` and
-  // right after the definition of each value in `neededValues`.
+  Operation *candidateInsertionPoint = useToBeEliminated->getOwner();
+  assert(isa<OpResult>(useToBeEliminated->get()) && "expected a result value");
+  // Both `tensor.empty` and its user are within different blocks.
+  if (useToBeEliminated->getOwner()->getBlock() !=
+      useToBeEliminated->get().getDefiningOp()->getBlock())
+    candidateInsertionPoint = useToBeEliminated->get().getDefiningOp();
+
+  // Trying to move the needed values before the `emptyTensorOp`.
+  for (Value val : neededValues) {
+    if (valueDominateInsertionPoint(domInfo, candidateInsertionPoint, val))
+      continue;
+    Operation *definingOp = val.getDefiningOp();
+    if (!definingOp)
+      continue;
+
+    bool isItSafeToMoveOp =
+        llvm::all_of(definingOp->getOperands(), [&](Value operand) {
----------------
matthias-springer wrote:
I'm not really happy with this approach because it is not a general solution; I think it only solves one edge case. Moving single ops around feels a bit like a hack. What if tomorrow somebody comes along with a use case where we have to move an entire graph of ops (i.e., also the defining ops of the operands of the op that you're moving)? Are we going to keep extending the pass like that?
But I don't know a better solution that works for all cases.
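The concern about having to move "an entire graph of ops" can be made concrete with a small standalone C++ model. It is a sketch, not MLIR code: `Op`, `Block`, `collectOpsToHoist`, `singleOpMoveSuffices` and `makeExampleBlock` are hypothetical names, the model covers a single block only (an op dominates another iff it comes earlier), and it ignores side effects and regions. Hoisting a needed value above the candidate insertion point requires moving its defining op together with every transitive producer that does not already dominate that point; a single-op move only covers the case where that set has exactly one element.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical toy model (not MLIR API): a single block whose ops are
// identified by index. Within one block, op A dominates op B iff A's index is
// smaller. Operand value -1 stands for a block argument, which is in scope
// everywhere in the block. Side effects are ignored.
struct Op {
  std::vector<int> operands; // indices of the ops defining each operand
};
using Block = std::vector<Op>;

// Collect every op that would have to move so that op `opIdx` can be placed
// right before the op at index `insertPos`. A defining op at or after
// `insertPos` would not dominate the new location, so it must move as well,
// and the same check applies recursively to its own operands.
std::set<int> collectOpsToHoist(const Block &block, int opIdx, int insertPos) {
  assert(opIdx >= insertPos && opIdx < static_cast<int>(block.size()));
  std::set<int> toMove;
  std::vector<int> worklist = {opIdx};
  while (!worklist.empty()) {
    int cur = worklist.back();
    worklist.pop_back();
    if (!toMove.insert(cur).second)
      continue; // already collected
    for (int def : block[cur].operands)
      if (def >= insertPos) // does not dominate the insertion point
        worklist.push_back(def);
  }
  return toMove;
}

// Moving a single op is only enough when that op's own operands already
// dominate the insertion point.
bool singleOpMoveSuffices(const Block &block, int opIdx, int insertPos) {
  return collectOpsToHoist(block, opIdx, insertPos).size() == 1;
}

// Example block (indices in comments):
//   0: %a = ...              (uses a block argument)
//   1: %empty = tensor.empty (candidate insertion point)
//   2: %b = ... %a
//   3: %c = ... %b if `chained`, else ... %a (value needed by the replacement)
//   4: user of %empty and %c
Block makeExampleBlock(bool chained) {
  return {Op{{-1}}, Op{{}}, Op{{0}}, Op{{chained ? 2 : 0}}, Op{{1, 3}}};
}
```

In the chained variant, hoisting `%c` (index 3) above `tensor.empty` (index 1) drags `%b` (index 2) along, so a single-op move is not enough; in the other variant only `%c` has to move.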
Can we make this transformation extensible with a lambda so that downstream users can customize it? Some options that come to mind:
* A lambda that computes the insertion point. By default (for compatibility with existing code), it does what the code does now, but users can provide a different lambda that also moves ops around and can thus compute an insertion point in cases such as yours.
* A "buildSubsetExtraction" lambda that internally computes an insertion point, calls `SubsetInsertionOpInterface::buildSubsetExtraction` and `getValuesNeededToBuildSubsetExtraction`, and may also move ops around. (This sounds slightly better to me than the previous suggestion.)
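A minimal sketch of the first option, assuming the hook is stored as a `std::function` with a default value. This is a standalone toy model rather than MLIR code: `EliminationSite`, `InsertionPointFn`, `defaultInsertionPoint`, `EmptyTensorEliminationOptions`, `runElimination` and `insertBeforeUser` are hypothetical names, and positions in a single block stand in for `Operation *`. A real hook would presumably receive MLIR types (e.g. the `OpOperand` being eliminated) and could move ops itself, for example with `Operation::moveBefore`.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical toy model, not MLIR API. Ops live in a single block and are
// identified by index; "insert at p" means "insert right before op p". A value
// is in scope at p if its defining op index is < p; -1 stands for a block
// argument, which is always in scope.
struct EliminationSite {
  int emptyIdx;                // position of the tensor.empty op
  int userIdx;                 // position of the op that uses it
  std::vector<int> neededDefs; // defining-op indices of the needed values
};

bool allInScope(const EliminationSite &s, int p) {
  return std::all_of(s.neededDefs.begin(), s.neededDefs.end(),
                     [p](int def) { return def < p; });
}

// Hook signature: return an insertion position, or std::nullopt to skip.
using InsertionPointFn =
    std::function<std::optional<int>(const EliminationSite &)>;

// Default policy, loosely modeled on the pre-patch search: try the
// tensor.empty position, then the position right after each needed value's
// definition, and accept the first candidate where every needed value is in
// scope and that is still at or before the user.
std::optional<int> defaultInsertionPoint(const EliminationSite &s) {
  std::vector<int> candidates = {s.emptyIdx};
  for (int def : s.neededDefs)
    candidates.push_back(def + 1);
  for (int p : candidates)
    if (allInScope(s, p) && p <= s.userIdx)
      return p;
  return std::nullopt;
}

// Hypothetical options struct: downstream users may replace the hook.
struct EmptyTensorEliminationOptions {
  InsertionPointFn findInsertionPoint = defaultInsertionPoint;
};

// The "pass" only consults the hook instead of hard-coding a policy.
std::optional<int> runElimination(const EliminationSite &s,
                                  const EmptyTensorEliminationOptions &opts) {
  assert(s.emptyIdx < s.userIdx && "tensor.empty must precede its user");
  return opts.findInsertionPoint(s);
}

// Example downstream override: prefer inserting right before the user and
// fall back to the default policy otherwise.
std::optional<int> insertBeforeUser(const EliminationSite &s) {
  if (allInScope(s, s.userIdx))
    return s.userIdx;
  return defaultInsertionPoint(s);
}
```

With `EliminationSite{1, 4, {0}}`, the default hook picks position 1 while `insertBeforeUser` picks position 4. The second option would follow the same pattern, with the hook building and returning the subset extraction instead of only a position.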
https://github.com/llvm/llvm-project/pull/118958