[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 19 00:09:22 PDT 2024


================
@@ -1100,6 +1101,428 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Fetches the first untiled consumer of the loop result that corresponds to
+/// a tensor.insert_slice inside the loop.
+///
+/// `source` is the source operand of the tensor.insert_slice op. The parent
+/// op of the operand's owner is taken to be the containing loop, and the loop
+/// result is looked up using `source`'s operand number.
+/// NOTE(review): the only visible caller passes operand #0 of the
+/// insert_slice, so result #0 of the parent op is used; this matches the
+/// single-result assumption stated in the TODO below.
+///
+/// `operandNumber` is an in/out parameter: it is incremented once for every
+/// operand of the consumer that precedes the first operand equal to the loop
+/// result. When the caller passes 0, it holds the index of that operand on
+/// return.
+///
+/// Returns the consumer operation (nullptr when the loop result has no users)
+/// together with an optional destination iter-arg operand, which is currently
+/// never populated and is therefore always `std::nullopt`.
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  // Step 1. Fetch the corresponding output.
+  // TODO(avarma): Make it generic for multiple values yielding scf.for.
+  unsigned yieldOperandNumber = source.getOperandNumber();
+  Value resultingValue =
+      source.getOwner()->getParentOp()->getResult(yieldOperandNumber);
+
+  // Step 2. Get users.
+  std::optional<OpOperand *> destinationIterArg;
+  // Must start out null: if the loop result has no users the loop below never
+  // executes, and the caller relies on a null check to detect that case.
+  Operation *untiledConsumer = nullptr;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    // Only the first user is inspected.
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
+/// Checks the preconditions for fusing a consumer through `op`: result #0 of
+/// `op` must have exactly one user, and when `op` is a tensor.insert_slice
+/// that user must be an scf.yield. Returns true when the preconditions hold.
+static bool checkAssumptionForFusingConsumer(Operation *op) {
+  Value::user_range resultUsers = op->getResult(0).getUsers();
+  auto firstUser = resultUsers.begin();
+
+  // Anything other than exactly one user is rejected.
+  const bool hasSingleUser = std::distance(firstUser, resultUsers.end()) == 1;
+  if (!hasSingleUser) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return false;
+  }
+
+  // The scf.yield requirement only applies to tensor.insert_slice ops.
+  const bool isInsertSlice = isa<tensor::InsertSliceOp>(op);
+  if (isInsertSlice && !isa<scf::YieldOp>(*firstUser)) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected scf.yield to be the only user\n");
+    return false;
+  }
+
+  return true;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // ASSUMING THAT YIELD OP IS ONLY YIELDING JUST ONE VALUE.
+  if (!checkAssumptionForFusingConsumer(candidateSliceOp)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "needs only scf.yield as its user");
+  }
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] = getUntiledConsumerFromSliceDestSCFFor(
+      candidateSliceOp->getOpOperand(0), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer is not a TileableInterface");
+  }
+
+  // Check containing op is "scf::ForOp".
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp) {
+    return rewriter.notifyMatchFailure(containingOp,
+                                       "containing op is not a scf.for");
+  }
+
+  // Check containingOp has exactly one use.
+  assert(forOp.getResults().size() == 1 &&
+         "expect exactly one result of the containing op");
+  if (!checkAssumptionForFusingConsumer(forOp)) {
+    return rewriter.notifyMatchFailure(forOp, "scf.for has more than 1 uses");
+  }
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        return rewriter.notifyMatchFailure(
+            consumerOp,
+            "consumer's operand use more than one containingOp's result");
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op should have destination style op interface");
+  }
+
+  // Check consumer is not using scf.for's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, forOp.getResult(0))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  // Move the loop body to the new op.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // Clone the consumer after the insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest;
+  for (unsigned i = loopBody->getNumArguments(),
+                n = newLoopBody->getArguments().size();
+       i < n; i++) {
+    newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
+  }
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // Replace scf.for result's use in the consumer with insert_slice result.
+  rewriter.replaceAllUsesWith(forOp.getResult(0), candidateSliceOp.getResult());
+
+  // Generate the tiled implementation of the consumer of the source.
+  rewriter.setInsertionPoint(candidateSliceOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter, candidateSliceOp,
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(tileableConsumer,
+                                       "failed to tile consumer op: ");
+  }
+
+  // Update the source of the candidateSlice to be the cloned consumer.
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
+  tensor::InsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+  auto bbArgs = newforOp.getBody()->getArguments();
+  clonedCandidateSliceOp->getOpOperands()[1].set(
+      bbArgs[1 + forOp.getInits().size() + 0]);
+
+  rewriter.replaceAllUsesWith(candidateSliceOp, candidateSliceOp.getSource());
----------------
MaheshRavishankar wrote:

OK, I understand what is happening here. You are cloning the consumer first and then cloning the insert slice. It should be the reverse.

1) First clone the insert slice after the original insert slice.
2) Then clone the consumer after it and replace all uses of the loop result with the result of the cloned insert slice.
At this point you haven't changed the program except for creating a clone of the insert_slice -> consumer chain within the loop.
3) Now call `tensor::replaceInsertSliceWithTiledConsumer`.
4) That returns the tiled result values. You need to use the `getResultTilePosition` method of the consumer operation to figure out the shape and offset of the result tile, and use that to create the final `tensor.insert_slice` into the destination arg you added to the loop for that result.

This should be the cleanest way of handling this.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list