[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org
Wed May 22 10:07:31 PDT 2024


================
@@ -1100,6 +1103,408 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// A utility function that checks whether the only use of the result of a
+/// tensor.insert_slice op is in a scf.yield op.
+static LogicalResult
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+  Value result = candidateSliceOp.getResult();
+  Value::use_range uses = result.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = (*uses.begin());
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp));
+    return failure();
+  }
+  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Fetches the OpOperand of the only use of the value `val`, provided the
+/// user implements `TilingInterface` and `DestinationStyleOpInterface` and
+/// lives in the block `containingOpBlock`. Returns failure otherwise.
+static FailureOr<OpOperand *> getConsumerFromUses(Value val,
+                                                  Block *containingOpBlock) {
+  // Step 1. Check that the value has exactly one use.
+  if (!llvm::hasSingleElement(val.getUses()))
+    return failure();
+  // Step 2. Get the single use and its owner (the candidate consumer op).
+  OpOperand &operand = (*val.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to initialize the result of the consumer before the scf.for,
+  //       so for now we use DestinationStyleOpInterface to get the result
+  //       shape from the init operand. Add support for other ops, e.g. ops
+  //       implementing InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp))
+    return failure();
+  if (containingOpBlock != consumerOp->getBlock())
+    return failure();
+  return &operand;
+}
+
+/// Fetch the untiled consumer of a scf.for's result which is yielded by a
+/// tensor.insert_slice. This function makes the following assumptions:
+/// 1.  tensor.insert_slice has scf.yield as its only user.
+/// 2.  scf.for's corresponding result has only one use.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
+    return failure();
+  Value sliceResult = candidateSliceOp.getResult();
+  // Step 1. Fetch the number of the scf.for result that the slice feeds.
+  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  // Step 2. Check containing op is scf.for.
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp)
+    return failure();
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  return getConsumerFromUses(resultingValue, containingOp->getBlock());
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  if (!iterArg)
+    return failure();
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
+    return failure();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp)
+    return failure();
+  Value resultingValue =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+
+  return getConsumerFromUses(resultingValue, containingOp->getBlock());
+}
+
+/// This utility checks that the loop either:
+/// 1. Yields exactly one result, or
+/// 2. Has the consumer op as its first user, with all other users in the same
+///    block as the consumer op.
+/// Since we clone the loop op right before the consumer op in order to
+/// maintain a valid def-use chain, this check ensures that no invalid IR is
+/// formed by the transformation.
+static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+                                            Operation *consumerOp) {
+  // Check if the loop op yields one result.
+  if (loopOp->getNumResults() == 1)
+    return success();
+  // Check if the consumerOp is the first user of the loopOp and if the other
+  // users are in the same block as the consumerOp.
+  Block *parentBlock = consumerOp->getBlock();
+  for (Operation *userOp : loopOp->getUsers()) {
+    if (userOp == consumerOp)
+      continue;
+    if (parentBlock != userOp->getBlock() ||
+        !consumerOp->isBeforeInBlock(userOp))
+      return failure();
+  }
+  return success();
+}
+
+/// A utility to fetch an untiled consumer of
+/// tensor.insert_slice/tensor.parallel_insert_slice.
+static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
+    return getUntiledConsumerFromSlice(insertSlice);
+  } else if (auto parallelInsertSlice =
+                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+  } else {
+    return failure();
+  }
+}
+
+/// After fusing the consumer into scf.for, update the scf.yield terminator so
+/// that it additionally yields the results of the tiled consumer, inserted
+/// into the new iter_args via tensor.insert_slice ops.
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      TilingResult &tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ArrayRef<BlockArgument> bbArgs) {
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  unsigned totalOldResults = oldTerminatorOp.getResults().size();
+  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  SmallVector<Value> newYieldOperands;
+  newYieldOperands.reserve(totalOldResults + totalTiledResults);
+  for (auto oldResult : oldTerminatorOp.getResults()) {
+    newYieldOperands.push_back(oldResult);
+  }
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  Location loc = newForOp.getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
+                       resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+}
+
+/// After fusing the consumer into scf.forall, yield each of the values
+/// produced by the tiled consumer from within the scf.forall.in_parallel
+/// region.
+static void
+fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
+                           SmallVector<Value> tiledResults,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                           ArrayRef<BlockArgument> bbArgs) {
+  scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
+  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+  Location firstYieldOpLoc =
+      (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
+  }
+}
+
+/// Implementation of fusing the consumer of a single slice by computing the
+/// slice of the consumer in place within the scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+          candidateSliceOp))
+    return failure();
+
+  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
+
+  // 1. Get the consumer of the scf.for/scf.forall result that is yielded by
+  // the tensor.insert_slice/parallel_insert_slice.
+  FailureOr<OpOperand *> maybeConsumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(maybeConsumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
+  Operation *consumerOp = consumerOpOperand->getOwner();
+  unsigned operandNumber = consumerOpOperand->getOperandNumber();
+  unsigned resultNumber = 0;
+  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
+    resultNumber = producerResult.getResultNumber();
+  } else {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+  }
+
+  Operation *oldLoopOp = nullptr;
+  SmallVector<Value> newOuts;
+  Block *oldLoopBody = nullptr;
+  unsigned initSize = 0;
+  unsigned rank = 1;
+  if (isInsertSliceOp) {
+    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+    oldLoopOp = forOp;
+    llvm::append_range(newOuts, forOp.getInits());
+    oldLoopBody = forOp.getBody();
+    initSize = forOp.getInits().size();
+  } else {
+    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
+    oldLoopOp = forallOp;
+    llvm::append_range(newOuts, forallOp.getOutputs());
+    oldLoopBody = forallOp.getBody();
+    initSize = forallOp.getOutputs().size();
+    rank = forallOp.getRank();
+  }
+
+  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+    return rewriter.notifyMatchFailure(
+        oldLoopOp, "containing loop op should either yield just one value or "
+                   "have the consumer op as its first user");
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 2. Check that the consumer does not use the scf loop's result as its init.
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+  newOuts.append(dpsInits);
+
+  Location loc = oldLoopOp->getLoc();
+
+  // 3. Create new scf loop op.
+  rewriter.setInsertionPoint(consumerOp);
+  Operation *newLoopOp = nullptr;
+  Block *newLoopBody = nullptr;
+  if (isInsertSliceOp) {
+    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                                forOp.getUpperBound(),
+                                                forOp.getStep(), newOuts);
+    newLoopOp = newForOp;
+    newLoopBody = newForOp.getBody();
+  } else {
+    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto newForallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+    newLoopOp = newForallOp;
+    rewriter.eraseOp(newForallOp.getTerminator());
+    newLoopBody = newForallOp.getBody();
+  }
+
+  // 4. Move the loop body to the new op.
+  unsigned oldNumArguments = oldLoopBody->getNumArguments();
+  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(oldNumArguments));
+
+  // 5.a. Clone the consumer after the tensor.insert_slice/parallel_insert_slice
+  // op, which now lives inside the new loop body.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(oldNumArguments);
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 5.b. In the cloned consumer, replace the operand that consumed the loop
+  // result with the result of the tensor.insert_slice (resp. the source of
+  // the tensor.parallel_insert_slice).
+  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+    if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
+      operandToReplace.set(sliceOp.getResult());
+    } else if (auto sliceOp =
+                   dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+      operandToReplace.set(sliceOp.getSource());
----------------
MaheshRavishankar wrote:

> > By construction all the slice uses of the tensor.insert_slice are exactly the same shape as the source.
> 
> Where does this guarantee come from? Looking at the implementation of `replaceInsertSliceWithTiledConsumer` it is just calling `getTiledImplementationFromOperandTile` without passing in the source of the `insert_slice`.

You are using the slice of the operand to compute the slice of the iteration space that produces that operand. That is only possible if that mapping is a bijection, in which case the tile of the operand recomputed from that iteration-space slice is necessarily the same. Things have to be consistent.

Yes, we are indexing on `tensor.extract_slice` and `tensor.insert_slice`; we can maybe eventually switch to an interface that allows different kinds of "slices", but this is the current state right now. If we miss a case, it's a "missed optimization" and not a correctness issue.
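
As a minimal sketch of how the new entry point might be driven from a caller (only the `scf::tileAndFuseConsumerOfSlice` signature comes from this patch; the header path, the walk over the loop body, and the candidate-selection logic below are illustrative assumptions, not part of the PR):

```cpp
// Sketch only: fuse the untiled consumer of a tiled loop's result back into
// the loop, reusing the offsets/sizes of a candidate insert_slice.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

static LogicalResult fuseConsumerIntoLoop(RewriterBase &rewriter,
                                          scf::ForOp tiledLoop) {
  // Pick a candidate: a tensor.insert_slice feeding the loop's scf.yield.
  // (A real driver would select the slice for the result it wants to fuse.)
  tensor::InsertSliceOp candidate;
  for (Operation &op : *tiledLoop.getBody())
    if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(&op))
      candidate = sliceOp;
  if (!candidate)
    return failure();

  // Tile the consumer of the corresponding loop result and fuse it into the
  // loop at the position of the candidate slice.
  FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
      scf::tileAndFuseConsumerOfSlice(rewriter, candidate);
  return failed(fused) ? failure() : success();
}
```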

https://github.com/llvm/llvm-project/pull/88712

