[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Abhishek Varma llvmlistbot at llvm.org
Thu Apr 18 04:45:48 PDT 2024


================
@@ -1100,6 +1101,428 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// In the following function `source` is the source operand of a
+/// tensor.insert_slice op whose result is yielded (via scf.yield) by the
+/// enclosing loop. We map the yielded value to the corresponding loop result
+/// and return the first untiled consumer of that result.
+///
+/// On success, `operandNumber` is set to the index of the consumer operand
+/// that uses the loop result. If no scf.yield use of the slice or no user of
+/// the loop result is found, a null consumer is returned and `operandNumber`
+/// is left unchanged. The second tuple element is not computed yet and is
+/// always std::nullopt.
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  // Step 1. Find the scf.yield operand that carries the inserted slice. Its
+  // index (not the index of `source` within the insert_slice) selects the
+  // matching result of the loop.
+  Value sliceResult = source.getOwner()->getResult(0);
+  OpOperand *yieldUse = nullptr;
+  for (OpOperand &use : sliceResult.getUses()) {
+    if (isa<scf::YieldOp>(use.getOwner())) {
+      yieldUse = &use;
+      break;
+    }
+  }
+  if (!yieldUse)
+    return {nullptr, std::nullopt};
+
+  // The yield's parent is the op whose results mirror the yielded values.
+  Operation *loopOp = yieldUse->getOwner()->getParentOp();
+  Value loopResult = loopOp->getResult(yieldUse->getOperandNumber());
+
+  // Step 2. Take the first user of that loop result as the consumer.
+  // TODO(avarma): Address the case where the consumer op itself can return
+  //               more than one result.
+  if (loopResult.use_empty())
+    return {nullptr, std::nullopt};
+  OpOperand &consumerUse = *loopResult.getUses().begin();
+  operandNumber = consumerUse.getOperandNumber();
+  return {consumerUse.getOwner(), std::nullopt};
+}
+
+/// Returns true if `op` satisfies the preconditions for consumer fusion:
+/// its first result has exactly one use and, when `op` is a
+/// tensor.insert_slice, that single user is an scf.yield.
+static bool checkAssumptionForFusingConsumer(Operation *op) {
+  if (op->getNumResults() == 0) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected the candidate slice op to have a result\n");
+    return false;
+  }
+  Value result = op->getResult(0);
+  if (result.use_empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "The candidate slice op has no uses\n");
+    return false;
+  }
+  // hasOneUse() counts uses, exactly like walking the user range would, but
+  // stops early instead of traversing the whole use list.
+  if (!result.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return false;
+  }
+  Operation *onlyUser = *result.getUsers().begin();
+  if (isa<tensor::InsertSliceOp>(op) && !isa<scf::YieldOp>(onlyUser)) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected scf.yield to be the only user\n");
+    return false;
+  }
+  return true;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // ASSUMING THAT YIELD OP IS ONLY YIELDING JUST ONE VALUE.
+  if (!checkAssumptionForFusingConsumer(candidateSliceOp)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "needs only scf.yield as its user");
+  }
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] = getUntiledConsumerFromSliceDestSCFFor(
+      candidateSliceOp->getOpOperand(0), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer is not a TileableInterface");
+  }
+
+  // Check containing op is "scf::ForOp".
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp) {
+    return rewriter.notifyMatchFailure(containingOp,
+                                       "containing op is not a scf.for");
+  }
+
+  // Check containingOp has exactly one use.
+  assert(forOp.getResults().size() == 1 &&
----------------
Abhishek-Varma wrote:

I chose to handle it in the latest push.

I've added a `resultNumber`; this lets us correctly handle scf loops that yield multiple values.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list