[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 19 00:09:22 PDT 2024


================
@@ -1100,6 +1101,428 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// In the following function `source` is the source operand of a
+/// tensor.insert_slice op whose result is yielded by the enclosing scf.for.
+/// We follow the use-def chain from the insert_slice through the scf.yield to
+/// the corresponding loop result and return the first user of that result as
+/// the untiled consumer. On success, `operandNumber` is set to the index of
+/// the consumer operand that uses the loop result. A null consumer is
+/// returned when the chain cannot be followed or the loop result has no
+/// users. The second tuple element is currently always empty.
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  // Not computed yet; returned empty for interface compatibility.
+  std::optional<OpOperand *> destinationIterArg;
+
+  // Step 1. Fetch the scf.yield operand fed by the insert_slice result. The
+  // yield operand index (not the index of `source` inside the insert_slice,
+  // which is always 0) selects the matching loop result.
+  // TODO(avarma): Make it generic for multiple values yielding scf.for.
+  Operation *sliceOp = source.getOwner();
+  Value sliceResult = sliceOp->getResult(0);
+  if (sliceResult.use_empty())
+    return {nullptr, destinationIterArg};
+  OpOperand &yieldOperand = *sliceResult.use_begin();
+  Operation *yieldOp = yieldOperand.getOwner();
+  if (!isa<scf::YieldOp>(yieldOp) ||
+      yieldOp->getParentOp() != sliceOp->getParentOp())
+    return {nullptr, destinationIterArg};
+
+  // Step 2. Map the yield operand to the result of the enclosing loop.
+  Value resultingValue =
+      yieldOp->getParentOp()->getResult(yieldOperand.getOperandNumber());
+
+  // Step 3. Take the first user of the loop result as the consumer.
+  // TODO(avarma): Address the case where the consumer op itself can return
+  //               more than one result.
+  if (resultingValue.use_empty())
+    return {nullptr, destinationIterArg};
+  OpOperand &consumerUse = *resultingValue.use_begin();
+  operandNumber = consumerUse.getOperandNumber();
+  return {consumerUse.getOwner(), destinationIterArg};
+}
+
+/// Returns true if `op` has the shape consumer fusion relies on: exactly one
+/// result, that result has exactly one use, and, when `op` is a
+/// tensor.insert_slice, that single user is an scf.yield.
+static bool checkAssumptionForFusingConsumer(Operation *op) {
+  // Guard before getResult(0): a result-less op would index out of range, and
+  // the callers only handle single-result ops.
+  if (op->getNumResults() != 1) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected the candidate op to have exactly one result\n");
+    return false;
+  }
+  Value result = op->getResult(0);
+  // hasOneUse() also rejects a result with no uses at all.
+  if (!result.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected exactly one use of the candidate op result\n");
+    return false;
+  }
+  if (isa<tensor::InsertSliceOp>(op) &&
+      !isa<scf::YieldOp>(*result.user_begin())) {
+    LLVM_DEBUG(llvm::dbgs() << "Expected scf.yield to be the only user\n");
+    return false;
+  }
+  return true;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+///
+/// The match fails unless all of the following hold:
+/// - the insert_slice result has a single use, in an scf.yield;
+/// - a consumer of the corresponding loop result is found;
+/// - the consumer implements TilingInterface and DestinationStyleOpInterface;
+/// - the insert_slice is directly nested in an scf.for (asserted to have
+///   exactly one result) whose result has a single use;
+/// - the consumer neither reads two different results of the loop nor uses
+///   the loop result as a DPS init;
+/// - every insert_slice stride is the constant 1.
+///
+/// On success, a new scf.for is built whose iter_args are the original inits
+/// followed by the consumer's DPS inits, and the old body is merged into it.
+/// The consumer is then tiled against the inserted slice, and the tiled
+/// result is inserted into the first of the new iter_args and yielded. Uses
+/// of the old loop and consumer results are redirected to the new loop, and
+/// both old ops are erased.
+///
+/// NOTE(review): a failure from replaceInsertSliceWithTiledConsumer is
+/// reported only after the new loop was created and the old body merged into
+/// it, so the IR is left partially rewritten on that path.
+/// NOTE(review): `consumerOp` is erased before it is stored in the returned
+/// result; callers must not dereference that field — confirm intent.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // ASSUMING THAT YIELD OP IS ONLY YIELDING JUST ONE VALUE.
+  if (!checkAssumptionForFusingConsumer(candidateSliceOp)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "needs only scf.yield as its user");
+  }
+  // 1. Get the consumer of the source. `operandNumber` receives the index of
+  // the consumer operand that reads the loop result.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] = getUntiledConsumerFromSliceDestSCFFor(
+      candidateSliceOp->getOpOperand(0), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer is not a TileableInterface");
+  }
+
+  // Check containing op is "scf::ForOp".
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp) {
+    return rewriter.notifyMatchFailure(containingOp,
+                                       "containing op is not a scf.for");
+  }
+
+  // Check containingOp has exactly one use.
+  assert(forOp.getResults().size() == 1 &&
+         "expect exactly one result of the containing op");
+  if (!checkAssumptionForFusingConsumer(forOp)) {
+    return rewriter.notifyMatchFailure(forOp, "scf.for has more than 1 uses");
+  }
+
+  // Check consumer doesn't read more than one distinct result of the loop.
+  Value bridge(nullptr);
+  for (Value opd : consumerOp->getOperands()) {
+    if (opd.getDefiningOp() != containingOp)
+      continue;
+    if (!bridge) {
+      bridge = opd;
+    } else if (bridge != opd) {
+      return rewriter.notifyMatchFailure(
+          consumerOp,
+          "consumer's operand use more than one containingOp's result");
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op should have destination style op interface");
+  }
+
+  // Check consumer is not using scf.for's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, forOp.getResult(0))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  // The new loop carries the original inits followed by the consumer's DPS
+  // inits as iter_args.
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  // Move the loop body to the new op. mergeBlocks erases the old body block,
+  // so its argument count must be read before the merge. The old block
+  // arguments are remapped to the leading arguments of the new body; the
+  // remaining arguments are the new iter_args for the consumer's destinations.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  unsigned oldNumArguments = loopBody->getNumArguments();
+  rewriter.mergeBlocks(loopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(oldNumArguments));
+
+  // Clone the consumer after the insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest;
+  for (unsigned i = oldNumArguments, n = newLoopBody->getNumArguments(); i < n;
+       ++i) {
+    newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
+  }
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // Replace scf.for result's use in the consumer with insert_slice result.
+  rewriter.replaceAllUsesWith(forOp.getResult(0), candidateSliceOp.getResult());
+
+  // Generate the tiled implementation of the consumer of the source.
+  rewriter.setInsertionPoint(candidateSliceOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter, candidateSliceOp,
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(tileableConsumer,
+                                       "failed to tile consumer op: ");
+  }
+
+  // Update the source of the candidateSlice to be the cloned consumer.
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
+  tensor::InsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+  // Argument 0 of the new body is the induction variable, followed by the
+  // original iter_args; the next argument is the consumer's destination.
+  auto bbArgs = newforOp.getBody()->getArguments();
+  clonedCandidateSliceOp->getOpOperands()[1].set(
+      bbArgs[1 + forOp.getInits().size() + 0]);
+
+  // Redirect remaining uses of the insert_slice result to its source before
+  // the clone is erased; the terminator rebuilt below yields the insert_slice
+  // result again in place of that source.
+  rewriter.replaceAllUsesWith(candidateSliceOp, candidateSliceOp.getSource());
+  rewriter.eraseOp(clonedConsumerOp);
+
+  // Fix terminator: yield the original values plus the tiled consumer result.
+  auto oldTerminatorOp =
+      cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+
+  SmallVector<Value> newYieldOperands;
+  for (Value val : oldTerminatorOp.getResults()) {
+    if (val == candidateSliceOp.getSource()) {
+      newYieldOperands.push_back(candidateSliceOp.getResult());
+    } else {
+      newYieldOperands.push_back(val);
+    }
+  }
+  newYieldOperands.push_back(clonedCandidateSliceOp.getResult());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+
+  // Replace the result of for and consumer op.
+  for (auto result : llvm::enumerate(forOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforOp->getResult(result.index()));
+  }
+
+  // The consumer's results map to the trailing results of the new loop.
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforOp->getResult(forOp.getInits().size() + consumerResult.index()));
+  }
+
+  // Need to erase the old for.
+  rewriter.eraseOp(forOp);
+  rewriter.eraseOp(consumerOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0]->getResult(0), {}};
+}
+
+/// In the following function `dest` is the destination operand of a
+/// tensor.parallel_insert_slice op, expected to be a shared-output block
+/// argument of the enclosing scf.forall. We map that block argument to the
+/// corresponding scf.forall result and return the first user of that result
+/// as the untiled consumer. On success, `operandNumber` is set to the index of
+/// the consumer operand that uses the scf.forall result. A null consumer is
+/// returned when `dest` is not such a block argument or the result has no
+/// users. The second tuple element is currently always empty.
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFForall(OpOperand *dest,
+                                         unsigned &operandNumber) {
+  // Not computed yet; returned empty for interface compatibility.
+  std::optional<OpOperand *> destinationIterArg;
+
+  // Step 1. Fetch the scf.forall result tied to the `dest` block argument.
+  // TODO(avarma): Make it generic for multiple values yielding scf.forall.
+  auto iterArg = dyn_cast<BlockArgument>(dest->get());
+  if (!iterArg)
+    return {nullptr, destinationIterArg};
+  auto forallOp =
+      dyn_cast_if_present<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+  if (!forallOp)
+    return {nullptr, destinationIterArg};
+  // The body's block arguments are the induction variables followed by the
+  // shared outputs, so output `i` lives at argument `getRank() + i`.
+  int64_t numInductionVars = forallOp.getRank();
+  if (iterArg.getArgNumber() < numInductionVars)
+    return {nullptr, destinationIterArg};
+  Value resultingValue =
+      forallOp->getResult(iterArg.getArgNumber() - numInductionVars);
+
+  // Step 2. Take the first user of the scf.forall result as the consumer.
+  // TODO(avarma): Address the case where the consumer op itself can return
+  //               more than one result.
+  if (resultingValue.use_empty())
+    return {nullptr, destinationIterArg};
+  OpOperand &consumerUse = *resultingValue.use_begin();
+  operandNumber = consumerUse.getOperandNumber();
+  return {consumerUse.getOwner(), destinationIterArg};
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the dest.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] =
+      getUntiledConsumerFromSliceDestSCFForall(
+          &candidateSliceOp.getDestMutable(), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    return rewriter.notifyMatchFailure(consumerOp,
+                                       "consumer is not a TileableInterface");
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    return rewriter.notifyMatchFailure(containingOp,
+                                       "containing op is not a scf.forall");
+  }
+
+  // Check consumer don't use more than one result of containingOp.
+  // Check containingOp has exactly one use.
+  assert(forallOp.getResults().size() == 1 &&
+         "expect exactly one result of the containing op");
+  if (!checkAssumptionForFusingConsumer(forallOp)) {
+    return rewriter.notifyMatchFailure(forallOp,
+                                       "scf.forall has more than 1 uses");
+  }
+
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    return rewriter.notifyMatchFailure(
+        consumerOp, "consumer op should have destination style op interface");
+  }
+
+  // Check consumer doon't use scf.forall's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, forallOp.getResult(0))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.forall as init is not supported");
+  }
+
+  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  Location loc = forallOp.getLoc();
+  // Create new scf.forall op.
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+  // Move the loop body to the new op.
+  rewriter.eraseOp(newforallOp.getTerminator());
+  Block *loopBody = forallOp.getBody();
+  Block *newLoopBody = newforallOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // Clone the consumer after the insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest;
+  for (unsigned i = loopBody->getNumArguments(),
+                n = newLoopBody->getArguments().size();
+       i < n; i++) {
+    newForOpBlockArgsForConsumerDest.push_back(newLoopBody->getArgument(i));
+  }
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // Replace scf.forall result's use in the consumer with parallel_insert_slice
+  // source.
+  rewriter.replaceAllUsesWith(forallOp.getResult(0),
+                              candidateSliceOp.getSource());
+
+  // Generate the tiled implementation of the consumer of the source.
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter, candidateSliceOp,
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(tileableConsumer,
+                                       "failed to tile consumer op: ");
+  }
+
+  // Update the source of the candidateSlice to be the cloned consumer.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  candidateSliceOpOperands[0] = tileAndFuseResult->tiledValues[0];
+  tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  clonedCandidateSliceOp->getOpOperands()[1].set(
+      bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + 0]);
+
+  rewriter.eraseOp(clonedConsumerOp);
+
+  // Replace the result of scf.forall and consumer op.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforallOp->getResult(forallOp.getOutputs().size() +
+                               consumerResult.index()));
+  }
+
+  // Need to erase the old scf.forall and consumer.
+  rewriter.eraseOp(forallOp);
----------------
MaheshRavishankar wrote:

Ok, that makes sense... we probably need a DCE on the `scf.forall` to remove the unused original result.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list