[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Abhishek Varma llvmlistbot at llvm.org
Wed May 22 23:24:53 PDT 2024


================
@@ -1100,6 +1102,413 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// A utility function that checks whether the result of a tensor.insert_slice
+/// op has exactly one use, and that this use is in a scf.yield op located in
+/// the same block as the tensor.insert_slice op.
+///
+/// Returns failure (emitting a debug message) if the result does not have
+/// exactly one use, if the single user is not a scf.yield, or if the user
+/// lives in a different block.
+static LogicalResult
+checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
+  Value result = candidateSliceOp.getResult();
+  Value::use_range uses = result.getUses();
+  // Zero uses also fails here, so the message must not claim "too many".
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected exactly one use of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = (*uses.begin());
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp) << "\n");
+    return failure();
+  }
+  // The defining op of `result` is `candidateSliceOp` itself.
+  if (candidateSliceOp->getBlock() != userOp->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Returns the single use of `val`, provided that the op owning that use
+/// implements both `TilingInterface` and `DestinationStyleOpInterface` and is
+/// located in `containingOpBlock`. Returns failure if `val` does not have
+/// exactly one use or if its owner violates any of these conditions.
+static FailureOr<OpOperand *> getConsumerFromUses(Value val,
+                                                  Block *containingOpBlock) {
+  // Only values with exactly one use are candidates.
+  if (!llvm::hasSingleElement(val.getUses()))
+    return failure();
+  OpOperand *consumerUse = &*val.getUses().begin();
+  Operation *owner = consumerUse->getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  bool isTileableDpsOp = isa<TilingInterface>(owner) &&
+                         isa<DestinationStyleOpInterface>(owner);
+  // The consumer must also sit in the same block as the containing loop op.
+  if (!isTileableDpsOp || owner->getBlock() != containingOpBlock)
+    return failure();
+  return consumerUse;
+}
+
+/// Returns the untiled consumer of the scf.for result that is fed by yielding
+/// `candidateSliceOp`. This relies on the following:
+/// 1.  `candidateSliceOp`'s only user is a scf.yield in the same block
+///     (checked by checkAssumptionForFusingConsumer).
+/// 2.  The corresponding scf.for result has exactly one use (checked by
+///     getConsumerFromUses).
+/// Returns failure if either condition fails or the parent op is not scf.for.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
+    return failure();
+  auto forOp = dyn_cast<scf::ForOp>(candidateSliceOp->getParentOp());
+  if (!forOp)
+    return failure();
+  // The position of the slice result among the yield operands identifies the
+  // loop result it flows into.
+  OpOperand &yieldUse = *candidateSliceOp.getResult().getUses().begin();
+  Value loopResult = forOp->getResult(yieldUse.getOperandNumber());
+  return getConsumerFromUses(loopResult, forOp->getBlock());
+}
+
+/// Returns the untiled consumer of the scf.forall result that is produced by
+/// the tensor.parallel_insert_slice `candidateSliceOp`. Fails unless:
+/// - the slice destination is a block argument of the scf.forall that is the
+///   grandparent of `candidateSliceOp`, and
+/// - the tied scf.forall result has exactly one suitable use (see
+///   getConsumerFromUses).
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  auto destArg = dyn_cast<BlockArgument>(candidateSliceOp.getDest());
+  if (!destArg)
+    return failure();
+  Operation *loopOp = destArg.getOwner()->getParentOp();
+  // The slice op is expected to be nested in the loop's terminator, so the
+  // loop that owns the destination argument must be its grandparent.
+  if (loopOp != candidateSliceOp->getParentOp()->getParentOp())
+    return failure();
+  auto forallOp = dyn_cast<scf::ForallOp>(loopOp);
+  if (!forallOp)
+    return failure();
+  Value tiedResult =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(destArg));
+  return getConsumerFromUses(tiedResult, loopOp->getBlock());
+}
+
+/// Checks that `loopOp` can safely be recreated right before `consumerOp`.
+/// This holds when either:
+/// 1. `loopOp` has exactly one result, or
+/// 2. every user of `loopOp` other than `consumerOp` lives in the same block
+///    as `consumerOp` and comes after it in that block.
+/// Because the new loop is created immediately before the consumer op, this
+/// check prevents users that precede the consumer from ending up with an
+/// invalid def-use chain.
+static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+                                            Operation *consumerOp) {
+  if (loopOp->getNumResults() == 1)
+    return success();
+  Block *consumerBlock = consumerOp->getBlock();
+  bool allUsersAfterConsumer =
+      llvm::all_of(loopOp->getUsers(), [&](Operation *user) {
+        return user == consumerOp || (user->getBlock() == consumerBlock &&
+                                      consumerOp->isBeforeInBlock(user));
+      });
+  return success(allUsersAfterConsumer);
+}
+
+/// A utility to fetch an untiled consumer of
+/// tensor.insert_slice/tensor.parallel_insert_slice by dispatching to the
+/// overload for the concrete slice op kind. Returns failure for any other op.
+static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(insertSlice);
+  if (auto parallelInsertSlice =
+          dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+  return failure();
+}
+
+/// After fusing the consumer into scf.for, rebuilds the loop's scf.yield so
+/// that, besides the values it already yields, it also yields one
+/// tensor.insert_slice per result of the tiled consumer
+/// (`tilingResult.tiledOps[0]`). Slice i inserts tiled result i into
+/// `bbArgs[i]` at `resultOffsets[i]` / `resultSizes[i]` with unit strides.
+/// The tiled results, `bbArgs`, `resultOffsets` and `resultSizes` must all
+/// have the same length (enforced by llvm::zip_equal).
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      const TilingResult &tilingResult,
+                      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+                      ArrayRef<BlockArgument> bbArgs) {
+  auto oldTerminatorOp =
+      cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+  Operation *tiledConsumer = tilingResult.tiledOps[0];
+  // scf.yield has no results: the values it yields are its operands, so size
+  // the new operand list from the operand count.
+  SmallVector<Value> newYieldOperands;
+  newYieldOperands.reserve(oldTerminatorOp->getNumOperands() +
+                           tiledConsumer->getNumResults());
+  llvm::append_range(newYieldOperands, oldTerminatorOp->getOperands());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  Location loc = newForOp.getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tiledConsumer->getResults(), bbArgs, resultOffsets,
+                       resultSizes)) {
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledResult, bbArg, resultOffset, resultSize, strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+}
+
+/// After fusing the consumer into scf.forall, adds one
+/// tensor.parallel_insert_slice per tiled consumer result at the start of the
+/// loop's scf.forall.in_parallel terminator. Slice i inserts `tiledResults[i]`
+/// into `bbArgs[i]` at `resultOffsets[i]` / `resultSizes[i]` with unit
+/// strides. The new ops reuse the location of the first existing yielding op,
+/// or the loop's location if the terminator has no yielding ops.
+static void
+fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
+                           ValueRange tiledResults,
+                           ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+                           ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+                           ArrayRef<BlockArgument> bbArgs) {
+  scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
+  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
+  auto yieldingOps = newTerminatorOp.getYieldingOps();
+  // Dereferencing begin() of an empty range is undefined behavior, so fall
+  // back to the loop's location when there is no existing yielding op.
+  Location firstYieldOpLoc = yieldingOps.begin() == yieldingOps.end()
+                                 ? newForallOp.getLoc()
+                                 : (*yieldingOps.begin()).getLoc();
+  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
+       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(resultOffset.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
+  }
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+          candidateSliceOp))
+    return failure();
+
+  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
+
+  // 1. Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  FailureOr<OpOperand *> maybeConsumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(maybeConsumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
+  Operation *consumerOp = consumerOpOperand->getOwner();
+  unsigned operandNumber = consumerOpOperand->getOperandNumber();
+  unsigned resultNumber = 0;
+  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
+    resultNumber = producerResult.getResultNumber();
+  } else {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
+  }
+
+  Operation *oldLoopOp = nullptr;
+  SmallVector<Value> newOuts;
+  Block *oldLoopBody = nullptr;
+  unsigned initSize = 0;
+  unsigned rank = 1;
+  if (isInsertSliceOp) {
+    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+    oldLoopOp = forOp;
+    llvm::append_range(newOuts, forOp.getInits());
+    oldLoopBody = forOp.getBody();
+    initSize = forOp.getInits().size();
+  } else {
+    auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
+    oldLoopOp = forallOp;
+    llvm::append_range(newOuts, forallOp.getOutputs());
+    oldLoopBody = forallOp.getBody();
+    initSize = forallOp.getOutputs().size();
+    rank = forallOp.getRank();
+  }
+
+  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
+    return rewriter.notifyMatchFailure(
+        oldLoopOp, "containing loop op should either yield just one value or "
+                   "have the consumer op as its first user");
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 2. Check consumer is not using scf loop's output as init.
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+  newOuts.append(dpsInits);
+
+  Location loc = oldLoopOp->getLoc();
+
+  // 3. Create new scf loop op.
+  rewriter.setInsertionPoint(consumerOp);
+  Operation *newLoopOp = nullptr;
+  Block *newLoopBody = nullptr;
+  if (isInsertSliceOp) {
+    auto forOp = cast<scf::ForOp>(oldLoopOp);
+    auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                                forOp.getUpperBound(),
+                                                forOp.getStep(), newOuts);
+    newLoopOp = newForOp;
+    newLoopBody = newForOp.getBody();
+  } else {
+    auto forallOp = cast<scf::ForallOp>(oldLoopOp);
+    auto newForallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+    newLoopOp = newForallOp;
+    rewriter.eraseOp(newForallOp.getTerminator());
+    newLoopBody = newForallOp.getBody();
+  }
+
+  // 4. Move the loop body to the new op.
+  unsigned oldNumArguments = oldLoopBody->getNumArguments();
+  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(oldNumArguments));
+
+  // 5. Set insertion point before terminator op of the loop and create a new
+  // tensor.insert_slice. In the scf.for case this is a clone of the
+  // candidateSliceOp whereas in the scf.forall case this is created from the
+  // operands of tensor.parallel_insert_slice.
+  tensor::InsertSliceOp clonedInsertSliceOp;
+  if (auto sliceOp =
+          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    rewriter.setInsertionPoint(newForallOp.getTerminator());
+    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
+        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+  } else {
+    auto newForOp = cast<scf::ForOp>(newLoopOp);
+    rewriter.setInsertionPoint(newForOp.getBody()->getTerminator());
+    clonedInsertSliceOp =
+        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+  }
+
+  // 6.a. Clone consumer op.
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(oldNumArguments);
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 6.b. Replace all uses of the loop result with the result of the cloned
+  // tensor.insert_slice.
+  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+    operandToReplace.set(clonedInsertSliceOp.getResult());
+  });
+
+  // 7 - Perform tiling of the cloned consumer and replace the operand at
+  // `operandNumber` with the source of the cloned tensor.insert_slice op.
+  auto ossSliceOp =
+      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return failure();
+  }
+  rewriter.replaceAllUsesWith(
+      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
+      clonedInsertSliceOp.getSource());
+
+  // 8 - Extract offset/sizes/strides required to create the
+  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+  // 9. Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  // 10. Try to get iter domain position from input position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 11. Try to fetch the offset and size for all results of the cloned
+  // consumer. This would then be used to form the corresponding
+  // tensor.insert_slice/parallel_insert_slice later.
+  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
+  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
+  if (isInsertSliceOp) {
+    auto newForOp = cast<scf::ForOp>(newLoopOp);
+    fixTerminatorSCFYield(
+        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
+        newForOp.getBody()->getArguments().drop_front(1 + initSize));
+  } else {
+    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    fixTerminatorSCFInParallel(
+        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
+        arrayRefOffsets, arrayRefSizes,
+        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
+  }
+
+  // 12. Replace the result of scf loop and consumer op with new loop's results.
+  for (auto &&[oldResult, newResult] :
+       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
+  }
+
+  for (auto &&[oldResult, newResult] :
+       llvm::zip(consumerOp->getResults(),
+                 newLoopOp->getResults().drop_front(initSize))) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
+  }
+
+  // 13. Need to erase the old scf loop and the cloned consumer op.
+  rewriter.eraseOp(oldLoopOp);
----------------
Abhishek-Varma wrote:

No, I can't. It still has uses, as shown below:
```
         %clone_insert_slice = cloned tensor.insert_slice :
                                    %source<32> into %dest<64> | OFFSET | STRIDES | SIZES
                           
         %tiled_operand = tensor.extract_slice:
                           %clone_insert_slice<64> to <32> | OFFSET | STRIDES | SIZES
```
So, to delete `%clone_insert_slice`, `%tiled_operand` needs to be deleted as well.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list