[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri May 17 03:36:39 PDT 2024


================
@@ -1100,6 +1102,398 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Checks the preconditions a candidate slice result must satisfy before its
+/// consumer can be fused:
+/// 1.  the value has exactly one use,
+/// 2.  that use is an scf.yield, and
+/// 3.  the scf.yield lives in the same block as the op defining the value.
+///
+/// Returns failure() (emitting a message under LLVM_DEBUG) as soon as one of
+/// the conditions does not hold, and success() otherwise.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  Value::use_range uses = result.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = (*uses.begin());
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    // Terminate the message with a newline, like the other debug messages.
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp) << "\n");
+    return failure();
+  }
+  // A block argument has no defining op; treat that as a failed precondition
+  // instead of dereferencing a null pointer.
+  Operation *definingOp = result.getDefiningOp();
+  if (!definingOp || definingOp->getBlock() != userOp->getBlock()) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
+                               "be in the same block\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Returns the operand through which the untiled consumer uses the scf.for
+/// result that `candidateSliceOp` (a tensor.insert_slice) yields, or nullptr
+/// when no fusable consumer is found. The lookup relies on :-
+/// 1.  tensor.insert_slice having scf.yield as its only user.
+/// 2.  the corresponding scf.for result having exactly one use.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value insertedValue = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp.getResult())))
+    return nullptr;
+
+  // The position of the slice within the scf.yield selects the loop result.
+  unsigned yieldIndex = (*insertedValue.getUses().begin()).getOperandNumber();
+
+  // Only an enclosing scf.for is handled by this overload.
+  Operation *parentOp = candidateSliceOp->getParentOp();
+  auto enclosingFor = dyn_cast<scf::ForOp>(parentOp);
+  if (!enclosingFor)
+    return nullptr;
+
+  // The loop result must feed exactly one operand.
+  Value loopResult = enclosingFor->getResult(yieldIndex);
+  if (!llvm::hasSingleElement(loopResult.getUses()))
+    return nullptr;
+
+  OpOperand &consumerUse = (*loopResult.getUses().begin());
+  Operation *consumerOp = consumerUse.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  bool isSupportedConsumer = isa<TilingInterface>(consumerOp) &&
+                             isa<DestinationStyleOpInterface>(consumerOp);
+  if (!isSupportedConsumer)
+    return nullptr;
+
+  // The consumer has to sit in the same block as the loop.
+  if (parentOp->getBlock() != consumerOp->getBlock())
+    return nullptr;
+  return &consumerUse;
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice.
+///
+/// Returns the consumer's operand that uses the scf.forall result tied to the
+/// destination of `candidateSliceOp`, or nullptr when:
+/// - the destination is not a block argument,
+/// - the op owning that block argument is not an scf.forall,
+/// - the tied scf.forall result does not have exactly one use,
+/// - the user does not implement both TilingInterface and
+///   DestinationStyleOpInterface, or
+/// - the user is not in the same block as the scf.forall.
+static OpOperand *
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output
+  Value sliceDest = candidateSliceOp.getDest();
+  // The destination may be an arbitrary value; bail out gracefully instead of
+  // asserting when it is not a block argument.
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  if (!iterArg) {
+    return nullptr;
+  }
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    return nullptr;
+  }
+  Value resultingValue =
+      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  // Step 3. Check resulting value of scf.forall has exactly one use.
+  Value::use_range uses = resultingValue.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    return nullptr;
+  }
+
+  // Step 4. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return nullptr;
+  }
+  if (containingOp->getBlock() != consumerOp->getBlock()) {
+    return nullptr;
+  }
+  return &operand;
+}
+
+/// Succeeds when either of the following holds :-
+/// 1. The loop yields exactly one result.
+/// 2. Every user of the loop other than the consumer op is in the consumer
+/// op's block and comes after the consumer op, i.e. the consumer op is the
+/// loop's first user. Currently we clone the loop op right before the consumer
+/// op in order to maintain a valid def-use chain, so this utility helps
+/// ensuring that no invalid IR is formed due to the same.
+static LogicalResult checkAssumptionForLoop(Operation *loopOp,
+                                            Operation *consumerOp) {
+  // A loop with a single result is always accepted.
+  if (loopOp->getNumResults() == 1)
+    return success();
+
+  // Otherwise every other user must share the consumer's block and be
+  // positioned after the consumer.
+  Block *consumerBlock = consumerOp->getBlock();
+  bool consumerIsFirstUser =
+      llvm::all_of(loopOp->getUsers(), [&](Operation *user) {
+        if (user == consumerOp)
+          return true;
+        return user->getBlock() == consumerBlock &&
+               consumerOp->isBeforeInBlock(user);
+      });
+  return success(consumerIsFirstUser);
+}
+
+/// Dispatches to the tensor.insert_slice or tensor.parallel_insert_slice
+/// overload depending on the kind of `sliceOp`; any other operation yields
+/// nullptr.
+static OpOperand *getUntiledConsumerFromSlice(Operation *sliceOp) {
+  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(insertSlice);
+
+  if (auto parallelInsertSlice =
+          dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp))
+    return getUntiledConsumerFromSlice(parallelInsertSlice);
+
+  return nullptr;
+}
+
+/// Replaces the scf.yield terminating `newForOp`'s body with one that yields
+/// the previously yielded values followed by one tensor.insert_slice per
+/// result of the first tiled op in `tilingResult`. Each slice is inserted with
+/// unit strides at `resultOffsets[idx]` / `resultSizes[idx]`.
+/// NOTE(review): the `strides` parameter is not read; unit strides are
+/// rebuilt per result inside the loop.
+static void
+fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
+                      TilingResult tilingResult,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+                      SmallVector<OpFoldResult> &strides, unsigned initSize) {
+  Block *loopBody = newForOp.getBody();
+  auto staleYield = cast<scf::YieldOp>(loopBody->getTerminator());
+  SmallVector<Value> yieldedValues(staleYield.getResults());
+  rewriter.setInsertionPointAfter(staleYield);
+  MutableArrayRef<BlockArgument> bodyArgs = loopBody->getArguments();
+  Location loc = newForOp.getLoc();
+
+  Operation *tiledConsumer = tilingResult.tiledOps[0];
+  for (auto [resultIdx, tiledValue] :
+       llvm::enumerate(tiledConsumer->getResults())) {
+    // One unit stride per dimension of this result's slice.
+    SmallVector<OpFoldResult> unitStrides(resultOffsets[resultIdx].size(),
+                                          rewriter.getIndexAttr(1));
+    // The destination argument index skips `1 + initSize` leading body
+    // arguments — presumably the induction variable and the pre-existing
+    // iter_args; TODO confirm against the caller.
+    BlockArgument destArg = bodyArgs[1 + initSize + resultIdx];
+    Value insertedSlice = rewriter.create<tensor::InsertSliceOp>(
+        loc, tiledValue, destArg, resultOffsets[resultIdx],
+        resultSizes[resultIdx], unitStrides);
+    yieldedValues.push_back(insertedSlice);
+  }
+
+  // Emit the extended terminator, then drop the one it supersedes.
+  rewriter.create<scf::YieldOp>(loc, yieldedValues);
+  rewriter.eraseOp(staleYield);
+}
+
+/// Prepends, to the body of `newForallOp`'s scf.forall.in_parallel
+/// terminator, one tensor.parallel_insert_slice per result of the first tiled
+/// op in `tilingResult`. Each slice is inserted with unit strides at
+/// `resultOffsets[idx]` / `resultSizes[idx]`, and is created at the location
+/// of the terminator's first yielding op.
+/// NOTE(review): the `strides` parameter is not read; unit strides are
+/// rebuilt per result inside the loop.
+static void fixTerminatorSCFInParallel(
+    RewriterBase &rewriter, scf::ForallOp newForallOp,
+    TilingResult tilingResult,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+    SmallVector<OpFoldResult> &strides, unsigned initSize, unsigned rank) {
+  scf::InParallelOp inParallelOp = newForallOp.getTerminator();
+  rewriter.setInsertionPointToStart(inParallelOp.getBody());
+  Location insertLoc = (*(inParallelOp.getYieldingOps().begin())).getLoc();
+  MutableArrayRef<BlockArgument> bodyArgs =
+      newForallOp.getBody()->getArguments();
+
+  Operation *tiledConsumer = tilingResult.tiledOps[0];
+  for (auto [resultIdx, tiledValue] :
+       llvm::enumerate(tiledConsumer->getResults())) {
+    // One unit stride per dimension of this result's slice.
+    SmallVector<OpFoldResult> unitStrides(resultOffsets[resultIdx].size(),
+                                          rewriter.getIndexAttr(1));
+    // The destination argument index skips `rank + initSize` leading body
+    // arguments — presumably the induction variables and the pre-existing
+    // shared outputs; TODO confirm against the caller.
+    BlockArgument destArg = bodyArgs[rank + initSize + resultIdx];
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        insertLoc, tiledValue, destArg, resultOffsets[resultIdx],
+        resultSizes[resultIdx], unitStrides);
+  }
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
+          candidateSliceOp))
+    return failure();
+
+  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
+
+  // 1. Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  OpOperand *consumerOpOperand = getUntiledConsumerFromSlice(candidateSliceOp);
+  if (!consumerOpOperand) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = consumerOpOperand->getOwner();
+  unsigned operandNumber = consumerOpOperand->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>(consumerOpOperand->get()).getResultNumber();
+
+  Operation *oldLoopOp = nullptr;
+  SmallVector<Value> newOuts;
+  Block *oldLoopBody = nullptr;
+  unsigned initSize = 0;
+  unsigned rank = 1;
+  if (isInsertSliceOp) {
+    auto forOp = candidateSliceOp->template getParentOfType<scf::ForOp>();
----------------
ftynse wrote:

```suggestion
    auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
```

Nit: I believe the `template` disambiguator is no longer necessary, here and below.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list