[Mlir-commits] [mlir] [MLIR][TilingInterface] Extend consumer fusion for multi-use of producer (PR #110105)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 28 18:51:56 PDT 2024


================
@@ -1481,21 +1481,33 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
 /// failure otherwise.
 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
                                                   Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
-    return failure();
-  return &operand;
+  // Check that the value has exactly one use which isn't a scf.yield or a
+  // tensor.parallel_insert_slice op.
+  Operation *visitedConsumerOp = nullptr;
+  for (OpOperand &opOperand : val.getUses()) {
+    Operation *consumerOp = opOperand.getOwner();
+    if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
+      continue;
+    if (visitedConsumerOp && visitedConsumerOp != consumerOp)
+      return failure();
+    // TODO: We have to init result of consumer before scf.for, use
+    //       DestinationStyleOpInterface to get result shape from init for now.
+    //       Add support for other op such as op has InferTypeOpInterface.
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp))
+      return failure();
+    if (containingOpBlock != consumerOp->getBlock())
+      return failure();
+    visitedConsumerOp = consumerOp;
+  }
+
+  for (OpOperand &opOperand : val.getUses()) {
+    Operation *consumerOp = opOperand.getOwner();
----------------
Yun-Fly wrote:

Maybe we can cache the unique `operand` during the above traversal of `val.getUses()` to avoid an additional traversal?

https://github.com/llvm/llvm-project/pull/110105


More information about the Mlir-commits mailing list