[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Quinn Dawkins llvmlistbot at llvm.org
Wed May 22 08:47:07 PDT 2024


================
@@ -1100,6 +1103,408 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Validates the structural preconditions for fusing a consumer through
+/// `candidateSliceOp`. Succeeds only when the tensor.insert_slice result has
+/// exactly one use, that use is owned by an scf.yield, and the scf.yield sits
+/// in the same block as the tensor.insert_slice. Returns failure otherwise.
+static LogicalResult
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceValue = candidateSliceOp.getResult();
+  // Anything other than exactly one use disqualifies the candidate.
+  if (!llvm::hasSingleElement(sliceValue.getUses())) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  // The single use must belong to an scf.yield.
+  Operation *soleUser = (*sliceValue.getUses().begin()).getOwner();
+  if (!isa<scf::YieldOp>(soleUser)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*soleUser));
+    return failure();
+  }
+  // `candidateSliceOp` is the op defining `sliceValue`, so its block is the
+  // one to compare against the yield's block.
+  if (candidateSliceOp->getBlock() != soleUser->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Returns the sole use of `val` provided that its owner implements both
+/// `TilingInterface` and `DestinationStyleOpInterface` and is located directly
+/// in `containingOpBlock`. Returns failure if `val` does not have exactly one
+/// use or if the owner does not satisfy those requirements.
+static FailureOr<OpOperand *> getConsumerFromUses(Value val,
+                                                  Block *containingOpBlock) {
+  // Only a value with exactly one use is eligible.
+  Value::use_range valueUses = val.getUses();
+  if (!llvm::hasSingleElement(valueUses))
+    return failure();
+  OpOperand &soleUse = *valueUses.begin();
+  Operation *userOp = soleUse.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  bool isTilableDpsOp =
+      isa<TilingInterface>(userOp) && isa<DestinationStyleOpInterface>(userOp);
+  if (!isTilableDpsOp)
+    return failure();
+  // The user has to live in the block the caller expects.
+  if (userOp->getBlock() != containingOpBlock)
+    return failure();
+  return &soleUse;
+}
+
+/// Locates the untiled consumer of the scf.for result that `candidateSliceOp`
+/// feeds through the loop's scf.yield. This relies on the following:
+/// 1.  tensor.insert_slice is used solely by an scf.yield in its own block.
+/// 2.  The matching scf.for result has exactly one use.
+/// Returns failure when either assumption is violated or when the parent op
+/// is not an scf.for.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
+    return failure();
+  // The op directly enclosing the slice must be an scf.for.
+  Operation *parentOp = candidateSliceOp->getParentOp();
+  auto enclosingForOp = dyn_cast<scf::ForOp>(parentOp);
+  if (!enclosingForOp)
+    return failure();
+  // The position of the slice among the yield's operands selects the loop
+  // result to inspect.
+  OpOperand &yieldUse = (*candidateSliceOp.getResult().getUses().begin());
+  Value loopResult = enclosingForOp->getResult(yieldUse.getOperandNumber());
+
+  return getConsumerFromUses(loopResult, parentOp->getBlock());
+}
+
+/// Locates the untiled consumer of the scf.forall result tied to the
+/// destination that `candidateSliceOp` (a tensor.parallel_insert_slice) writes
+/// into. Returns failure if the destination is not a block argument, if the
+/// op owning that argument is not the op two levels above the slice, or if
+/// that op is not an scf.forall.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // The destination has to be a block argument.
+  auto destArg = dyn_cast<BlockArgument>(candidateSliceOp.getDest());
+  if (!destArg)
+    return failure();
+  // The op owning that argument must be the grandparent of the slice op.
+  Operation *argOwnerOp = destArg.getOwner()->getParentOp();
+  Operation *sliceGrandParentOp =
+      candidateSliceOp->getParentOp()->getParentOp();
+  if (argOwnerOp != sliceGrandParentOp)
+    return failure();
+  // Only scf.forall is supported as the containing op.
+  auto enclosingForallOp = dyn_cast<scf::ForallOp>(argOwnerOp);
+  if (!enclosingForallOp)
+    return failure();
+  // Map the block argument to its tied operand, and that to the loop result.
+  auto *tiedOperand = enclosingForallOp.getTiedOpOperand(destArg);
+  Value loopResult = enclosingForallOp.getTiedOpResult(tiedOperand);
+
+  return getConsumerFromUses(loopResult, argOwnerOp->getBlock());
+}
+
+/// Checks that `loopOp` can be re-created right before `consumerOp` without
+/// breaking any def-use chain. This holds when either:
+/// 1. The loop yields exactly one result, or
+/// 2. Every user of the loop other than `consumerOp` is in the same block as
+///    `consumerOp` and comes after it in that block.
+/// Returns success in those cases and failure otherwise.
+static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+                                            Operation *consumerOp) {
+  // A loop with a single result trivially satisfies the requirement.
+  if (loopOp->getNumResults() == 1)
+    return success();
+  // Otherwise every other user must follow the consumer within the
+  // consumer's block.
+  Block *consumerBlock = consumerOp->getBlock();
+  for (Operation *loopUser : loopOp->getUsers()) {
+    if (loopUser == consumerOp)
+      continue;
+    // The block comparison comes first so the ordering query only runs on
+    // ops that share a block.
+    bool followsConsumer = loopUser->getBlock() == consumerBlock &&
+                           consumerOp->isBeforeInBlock(loopUser);
+    if (!followsConsumer)
+      return failure();
+  }
+  return success();
+}
+
+/// Dispatches on the concrete kind of `sliceOp` to find its untiled consumer.
+/// Handles tensor.insert_slice and tensor.parallel_insert_slice; any other
+/// operation yields failure.
+static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(insertSlice);
+  if (auto parallelInsertSlice =
+          dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+  return failure();
+}
+
+/// After fusing a consumer into scf.for, rebuilds the loop's scf.yield so that
+/// it also yields the values produced by the tiled consumer.
+///
+/// The new terminator yields every value the old terminator yielded, followed
+/// by one tensor.insert_slice per result of `tilingResult.tiledOps[0]`. Each
+/// insert_slice writes the tiled result into the matching entry of `bbArgs`
+/// at the matching entry of `resultOffsets` / `resultSizes`, using unit
+/// strides. The old terminator is erased afterwards.
+///
+/// The tiled results, `bbArgs`, `resultOffsets` and `resultSizes` are walked
+/// with `llvm::zip_equal` and are therefore expected to have equal lengths.
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      TilingResult &tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ArrayRef<BlockArgument> bbArgs) {
+  scf::YieldOp oldTerminatorOp =
+      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  // The values yielded by scf.yield are its operands; as a terminator the op
+  // itself produces no results, so count the operands to size the new list.
+  unsigned totalOldYieldedValues = oldTerminatorOp->getNumOperands();
+  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  SmallVector<Value> newYieldOperands;
+  newYieldOperands.reserve(totalOldYieldedValues + totalTiledResults);
+  llvm::append_range(newYieldOperands, oldTerminatorOp.getResults());
+  // The new insert_slice ops and the new terminator are created after the old
+  // terminator, which is erased at the end.
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  Location loc = newForOp.getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
+                       resultOffsets, resultSizes)) {
+    // Unit stride along every dimension of the slice.
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+}
+
+/// After fusing a consumer into scf.forall, adds one
+/// tensor.parallel_insert_slice per tiled consumer result at the start of the
+/// loop terminator's body. Each op inserts an entry of `tiledResults` into the
+/// matching entry of `bbArgs` at the matching `resultOffsets` / `resultSizes`
+/// with unit strides, and reuses the location of the terminator's first
+/// yielding op.
+static void
+fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
+                           SmallVector<Value> tiledResults,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                           ArrayRef<BlockArgument> bbArgs) {
+  scf::InParallelOp inParallelOp = newForallOp.getTerminator();
+  // Borrow the location of the first op already present in the terminator.
+  Location insertLoc = (*(inParallelOp.getYieldingOps().begin())).getLoc();
+  rewriter.setInsertionPointToStart(inParallelOp.getBody());
+  for (auto [tiledValue, destArg, offsets, sizes] :
+       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+    // Unit stride along every dimension of the slice.
+    SmallVector<OpFoldResult> unitStrides(offsets.size(),
+                                          rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        insertLoc, tiledValue, destArg, offsets, sizes, unitStrides);
+  }
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+          candidateSliceOp))
+    return failure();
+
+  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
+
+  // 1. Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  FailureOr<OpOperand *> maybeConsumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(maybeConsumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
+  Operation *consumerOp = consumerOpOperand->getOwner();
+  unsigned operandNumber = consumerOpOperand->getOperandNumber();
+  unsigned resultNumber = 0;
+  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
+    resultNumber = producerResult.getResultNumber();
+  } else {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+  }
+
+  Operation *oldLoopOp = nullptr;
+  SmallVector<Value> newOuts;
+  Block *oldLoopBody = nullptr;
+  unsigned initSize = 0;
+  unsigned rank = 1;
+  if (isInsertSliceOp) {
+    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+    oldLoopOp = forOp;
+    llvm::append_range(newOuts, forOp.getInits());
+    oldLoopBody = forOp.getBody();
+    initSize = forOp.getInits().size();
+  } else {
+    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
+    oldLoopOp = forallOp;
+    llvm::append_range(newOuts, forallOp.getOutputs());
+    oldLoopBody = forallOp.getBody();
+    initSize = forallOp.getOutputs().size();
+    rank = forallOp.getRank();
+  }
+
+  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+    return rewriter.notifyMatchFailure(
+        oldLoopOp, "containing loop op should either yield just one value or "
+                   "have the consumer op as its first user");
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 2. Check consumer is not using scf loop's output as init.
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+  newOuts.append(dpsInits);
+
+  Location loc = oldLoopOp->getLoc();
+
+  // 3. Create new scf loop op.
+  rewriter.setInsertionPoint(consumerOp);
+  Operation *newLoopOp = nullptr;
+  Block *newLoopBody = nullptr;
+  if (isInsertSliceOp) {
+    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                                forOp.getUpperBound(),
+                                                forOp.getStep(), newOuts);
+    newLoopOp = newForOp;
+    newLoopBody = newForOp.getBody();
+  } else {
+    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto newForallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+    newLoopOp = newForallOp;
+    rewriter.eraseOp(newForallOp.getTerminator());
+    newLoopBody = newForallOp.getBody();
+  }
+
+  // 4. Move the loop body to the new op.
+  unsigned oldNumArguments = oldLoopBody->getNumArguments();
+  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(oldNumArguments));
+
+  // 5.a. Clone consumer after the cloned
+  // tensor.insert_slice/parallel_insert_slice op.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(oldNumArguments);
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 5.b. Replace all uses of the loop result with the result of the cloned
+  // tensor.insert_slice/parallel_insert_slice.
+  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+    if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
+      operandToReplace.set(sliceOp.getResult());
+    } else if (auto sliceOp =
+                   dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+      operandToReplace.set(sliceOp.getSource());
----------------
qedawkins wrote:

That's for linalg. Other operations can do whatever they want (not trying to be difficult, just trying to make sure we at least clarify any implicit interface API requirements like this).

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list