[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 19 14:32:51 PDT 2024


================
@@ -1100,6 +1102,475 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Checks the precondition for fusing the consumer of `result`: the value
+/// must have exactly one use, and the operation owning that use must be an
+/// scf.yield. Returns success() when both hold; otherwise emits a debug
+/// message describing the violated assumption and returns failure().
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  Value::use_range resultUses = result.getUses();
+  if (llvm::hasSingleElement(resultUses)) {
+    // Exactly one use: inspect the operation that owns it.
+    Operation *soleUser = (*resultUses.begin()).getOwner();
+    if (isa<scf::YieldOp>(soleUser))
+      return success();
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*soleUser));
+    return failure();
+  }
+  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+  return failure();
+}
+
+/// Fetch the first untiled consumer of a scf.for's result which is yielded by
+/// a tensor.insert_slice. This function makes the following assumptions :-
+/// 1.  tensor.insert_slice has scf.yield as its only user.
+/// 2.  That scf.yield belongs to the scf.for directly enclosing the
+///     tensor.insert_slice.
+/// 3.  scf.for's corresponding result has only one use.
+/// Returns the single OpOperand using that scf.for result, or failure() if
+/// any of the assumptions does not hold.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceResult = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(sliceResult))) {
+    return failure();
+  }
+  // Step 1. Fetch the corresponding output.
+  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  // Check containing op is "scf::ForOp".
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp) {
+    return failure();
+  }
+  // The yield must be directly inside this loop: a yield nested in another
+  // region-holding op would give an operand number that has no relation to the
+  // results of `forOp`.
+  if (yieldOpOperand.getOwner()->getParentOp() != containingOp) {
+    return failure();
+  }
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  // Check resultingValue has exactly one use.
+  if (!llvm::hasSingleElement(resultingValue.getUses())) {
+    return failure();
+  }
+
+  // Step 2. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+///
+/// On success, a new scf.for is created whose iter_args are the original
+/// loop's inits followed by the consumer's DPS inits. The original loop body
+/// is moved into it, the consumer is cloned into the body and tiled through
+/// `tensor::replaceInsertSliceWithTiledConsumer`, and every tiled result is
+/// yielded through a new tensor.insert_slice. All uses of the original loop
+/// and of the original consumer are redirected to results of the new loop;
+/// the original loop, the original consumer and the cloned consumer are
+/// erased.
+///
+/// Returns a match failure if no single consumer can be found, if the
+/// consumer implements neither TilingInterface nor
+/// DestinationStyleOpInterface, if the consumer uses a result of the scf.for
+/// as an init, if tiling fails, or if the slice is inserted with a non-unit
+/// stride.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the source.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+  // getUntiledConsumerFromSlice only succeeds when the parent is an scf.for,
+  // so the checked cast documents (and asserts) that invariant.
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  auto forOp = cast<scf::ForOp>(containingOp);
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer is not a TileableInterface");
+  }
+
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op should have destination style op interface");
+  }
+
+  // Check consumer is not using any of scf.for's outputs as init. The inits
+  // become operands of the new loop, and every result of the old loop is
+  // later replaced by a result of the new loop, so such an init would make
+  // the new loop consume its own result.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::any_of(dpsInits, [&](Value init) {
+        return init.getDefiningOp() == forOp.getOperation();
+      })) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  // Move the loop body to the new op. The argument count of the old body is
+  // captured up front because merging erases the source block, so `loopBody`
+  // must not be dereferenced afterwards.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  unsigned numOldLoopBodyArgs = loopBody->getNumArguments();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(numOldLoopBodyArgs));
+
+  // 1 - Clone tensor.insert_slice after original tensor.insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  tensor::InsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+  // 2.a - Clone consumer after the cloned tensor.insert_slice op. Its
+  // destinations are the block arguments appended for the consumer's inits.
+  rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(numOldLoopBodyArgs),
+      [](BlockArgument b) -> Value { return b; });
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+  tileableConsumer = clonedConsumerOp;
+
+  // 2.b - Replace all uses of the loop result with the result of the cloned
+  // tensor.insert_slice.
+  rewriter.replaceUsesWithIf(forOp.getResult(resultNumber),
+                             clonedCandidateSliceOp.getResult(),
+                             [&](OpOperand &operand) {
+                               return operand.getOwner() == clonedConsumerOp;
+                             });
+
+  // 3 - Perform tiling of the cloned consumer.
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedCandidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(tileableConsumer,
+                                       "failed to tile consumer op: ");
+  }
+
+  // 4 - Extract offset/sizes/strides required to create the tensor.insert_slice
+  // for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides();
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        clonedCandidateSliceOp, "containingOp's result yield with stride");
+  }
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  // Try to get iter domain position from input position.
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        tileableConsumer, "can't get iter domain position from input position");
+  }
+
+  // Try to get all containing op result's position from iter domain position.
+  // Each entry holds the (offsets, sizes) pair for one consumer result.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(clonedConsumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      return rewriter.notifyMatchFailure(
+          tileableConsumer,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 5 - Fix terminator: keep the operands of the old yield and append one
+  // tensor.insert_slice per tiled consumer result.
+  auto oldTerminatorOp = cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+  SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  auto bbArgs = newforOp.getBody()->getArguments();
+  for (auto [idx, v] :
+       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> unitStrides(resultPositions[idx].first.size(),
+                                          rewriter.getIndexAttr(1));
+    // Block arguments are laid out as: induction variable, the original
+    // iter_args, then the iter_args appended for the consumer's inits.
+    newYieldOperands.push_back(rewriter.create<tensor::InsertSliceOp>(
+        clonedCandidateSliceOp->getLoc(), v,
+        bbArgs[1 + forOp.getInits().size() + idx], resultPositions[idx].first,
+        resultPositions[idx].second, unitStrides));
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+
+  // 6 - Replace the result of scf.for and consumer op.
+  for (auto result : llvm::enumerate(forOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
+  }
+
+  rewriter.replaceOp(candidateSliceOp, clonedCandidateSliceOp);
+
+  // 7 - Need to erase the old scf.for.
+  rewriter.eraseOp(forOp);
+  rewriter.eraseOp(consumerOp);
+  rewriter.eraseOp(clonedConsumerOp);
+
+  // NOTE(review): `consumerOp` has been erased just above, so the pointer
+  // stored in the result must not be dereferenced by the caller - confirm
+  // whether erasing the original consumer should be left to the caller.
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};
+}
+
+/// Fetch the first untiled consumer of a scf.forall's result which is yielded
+/// by a tensor.parallel_insert_slice.
+///
+/// The destination of the slice must be a block argument of a scf.forall body
+/// that is one of the loop's region out-args, and the matching scf.forall
+/// result must have exactly one use. Returns that single OpOperand, or
+/// failure() if any of these conditions does not hold.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  // The destination may be defined by an operation rather than being a block
+  // argument; there is no loop output to map it to in that case.
+  if (!iterArg) {
+    return failure();
+  }
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  // Step 2. Check assumptions for the containing op.
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    return failure();
+  }
+  // The position of the block argument among the region out-args is the
+  // number of the loop result it corresponds to.
+  unsigned resultNumber = 0;
+  for (BlockArgument val : forallOp.getRegionOutArgs()) {
+    if (val == iterArg) {
+      break;
+    }
+    resultNumber++;
+  }
+  // The block argument was not one of the out-args (e.g. an induction
+  // variable), so no loop result corresponds to it.
+  if (resultNumber >= forallOp->getNumResults()) {
+    return failure();
+  }
+  Value resultingValue = forallOp->getResult(resultNumber);
+  Value::use_range uses = resultingValue.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    return failure();
+  }
+
+  // Step 3. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the dest.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer is not a TileableInterface");
+  }
+
+  auto forallOp = static_cast<scf::ForallOp>(containingOp);
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op should have destination style op interface");
+  }
+
+  // Check consumer doesn't use scf.forall's output as init.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.forall as init is not supported");
+  }
+
+  Location loc = forallOp.getLoc();
+  // Create new scf.forall op.
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+  // Move the loop body to the new op.
+  rewriter.eraseOp(newforallOp.getTerminator());
+  Block *loopBody = forallOp.getBody();
+  Block *newLoopBody = newforallOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // 1 - Clone tensor.parallel_insert_slice after the original
+  // tensor.parallel_insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+  LLVM_DEBUG(llvm::dbgs() << "Created a clone of the candidate slice op : "
+                          << clonedCandidateSliceOp << "\n");
+
+  // 2 - Clone the consumer after the clone tensor.parallel_insert_slice.
+  rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(loopBody->getNumArguments()),
+      [](BlockArgument b) -> Value { return b; });
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+  tileableConsumer = clonedConsumerOp;
+
+  // 2.b - Replace all uses of the scf.forall's result use in the consumer with
+  // the source of the cloned tensor.parallel_insert_slice.
+  rewriter.replaceUsesWithIf(forallOp.getResult(resultNumber),
+                             clonedCandidateSliceOp.getSource(),
+                             [&](OpOperand &operand) {
+                               return operand.getOwner() == clonedConsumerOp;
+                             });
+
+  // 3 - Perform tiling of the cloned consumer.
+  rewriter.setInsertionPoint(newforallOp.getTerminator());
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedCandidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(tileableConsumer,
+                                       "failed to tile consumer op: ");
+  }
+
+  // 4 - Extract offset/sizes/strides required to create the
+  // tensor.parallel_insert_slice for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides();
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        clonedCandidateSliceOp, "containingOp's result yield with stride");
+  }
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  // Try to get iter domain position from input position.
+  rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
+  ;
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        tileableConsumer, "can't get iter domain position from input position");
+  }
+
+  // Try to get all containing op result's position from iter domain position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(clonedConsumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      return rewriter.notifyMatchFailure(
+          tileableConsumer,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 5 - Fix terminator.
+  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+      newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+  Operation *firstYieldOp = yieldingOps.front();
+  rewriter.setInsertionPoint(firstYieldOp);
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  for (auto [idx, v] :
+       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOp->getLoc(), v,
+        bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+        resultPositions[idx].first, resultPositions[idx].second, strides);
+  }
+
+  // Replace the result of scf.forall and consumer op.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforallOp->getResult(forallOp.getOutputs().size() +
+                               consumerResult.index()));
+  }
+
+  // Need to erase the old scf.forall, consumer, cloned consumer and
+  // candidateSliceOp.
+  rewriter.eraseOp(forallOp);
+  rewriter.eraseOp(consumerOp);
----------------
MaheshRavishankar wrote:

Don't need to delete the `consumerOp` here. This is up to the caller.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list