[Mlir-commits] [mlir] [MLIR][TilingInterface] Extend consumer fusion for multi-use of producer (PR #110105)
llvmlistbot at llvm.org
Sat Sep 28 18:51:56 PDT 2024
================
@@ -1481,21 +1481,33 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
/// failure otherwise.
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
Block *containingOpBlock) {
- // Step 1. Check that the value has exactly one use.
- if (!llvm::hasSingleElement(val.getUses()))
- return failure();
- // Step 2. Get uses.
- OpOperand &operand = (*val.getUses().begin());
- Operation *consumerOp = operand.getOwner();
- // TODO: We have to init result of consumer before scf.for, use
- // DestinationStyleOpInterface to get result shape from init for now.
- // Add support for other op such as op has InferTypeOpInterface.
- if (!isa<TilingInterface>(consumerOp) ||
- !isa<DestinationStyleOpInterface>(consumerOp))
- return failure();
- if (containingOpBlock != consumerOp->getBlock())
- return failure();
- return &operand;
+ // Check that the value has exactly one use which isn't a scf.yield or a
+ // tensor.parallel_insert_slice op.
+ Operation *visitedConsumerOp = nullptr;
+ for (OpOperand &opOperand : val.getUses()) {
+ Operation *consumerOp = opOperand.getOwner();
+ if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
+ continue;
+ if (visitedConsumerOp && visitedConsumerOp != consumerOp)
+ return failure();
+ // TODO: We have to init result of consumer before scf.for, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ if (!isa<TilingInterface>(consumerOp) ||
+ !isa<DestinationStyleOpInterface>(consumerOp))
+ return failure();
+ if (containingOpBlock != consumerOp->getBlock())
+ return failure();
+ visitedConsumerOp = consumerOp;
+ }
+
+ for (OpOperand &opOperand : val.getUses()) {
+ Operation *consumerOp = opOperand.getOwner();
----------------
Yun-Fly wrote:
Maybe we can cache the unique `operand` during the traversal of `val.getUses()` above, to avoid the additional loop?
https://github.com/llvm/llvm-project/pull/110105
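A minimal sketch of the suggestion: a single pass over the uses that both validates the consumer and remembers the operand to return. It uses hypothetical stand-in types (`Op`, `Use`) instead of the real `mlir::Operation`/`mlir::OpOperand` API so that it compiles on its own, and `isTilingDps` is a placeholder for the `TilingInterface` + `DestinationStyleOpInterface` checks. It also assumes the operand to return is the first use owned by the single accepted consumer; the truncated second loop in the diff may choose it differently.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Operation (not the MLIR API).
struct Op {
  std::string name;  // e.g. "scf.yield", "linalg.add"
  bool isTilingDps;  // placeholder for TilingInterface + DestinationStyleOpInterface
  int block;         // identifier of the block that contains the op
};

// Hypothetical stand-in for mlir::OpOperand (not the MLIR API).
struct Use {
  Op *owner;              // operation that uses the value
  unsigned operandNumber; // operand slot on the owner
};

// Single traversal: returns the cached use of the unique acceptable consumer,
// or nullptr if there is no consumer, more than one distinct consumer, or the
// consumer fails the interface/block checks.
const Use *getConsumerFromUses(const std::vector<Use> &uses,
                               int containingBlock) {
  const Use *cached = nullptr;
  for (const Use &use : uses) {
    Op *consumer = use.owner;
    // Terminator-like uses are ignored, mirroring the scf.yield /
    // tensor.parallel_insert_slice skip in the diff.
    if (consumer->name == "scf.yield" ||
        consumer->name == "tensor.parallel_insert_slice")
      continue;
    // A second, different consumer means fusion is not applicable.
    if (cached && cached->owner != consumer)
      return nullptr;
    if (!consumer->isTilingDps || consumer->block != containingBlock)
      return nullptr;
    // Cache the operand on first sight instead of re-walking the uses later.
    if (!cached)
      cached = &use;
  }
  return cached;
}
```

Caching a pointer to the first matching use keeps the function to one walk over `val.getUses()`, which is the point of the review comment; multiple uses by the same consumer are still accepted.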