[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Apr 22 04:57:19 PDT 2024


================
@@ -1100,6 +1102,459 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Checks that `result` has exactly one use and that the owner of that use is
+/// an scf.yield. Used to verify that a candidate tensor.insert_slice result is
+/// consumed only by a loop terminator.
+///
+/// Returns failure (emitting a debug message) when `result` has zero or more
+/// than one use, or when its single user is not an scf.yield.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  if (!result.hasOneUse()) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  Operation *userOp = result.getUses().begin()->getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    // Terminate the debug line so subsequent debug output is not glued to it.
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp) << "\n");
+    return failure();
+  }
+  return success();
+}
+
+/// Fetch the untiled consumer of a scf.for's result which is yielded by a
+/// tensor.insert_slice. This function makes the following assumptions:
+/// 1.  tensor.insert_slice has scf.yield as its only user, and that scf.yield
+///     is the terminator of the scf.for directly containing the slice.
+/// 2.  scf.for's corresponding result has only one use.
+/// 3.  That use belongs to an op implementing both TilingInterface and
+///     DestinationStyleOpInterface.
+/// Returns the consumer's operand that uses the scf.for result, or failure if
+/// any assumption does not hold.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceResult = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(sliceResult))) {
+    return failure();
+  }
+  // Step 1. Fetch the corresponding output.
+  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  // Step 2. Check containing op is scf.for.
+  auto forOp = dyn_cast_or_null<scf::ForOp>(candidateSliceOp->getParentOp());
+  if (!forOp) {
+    return failure();
+  }
+  // The yield must be this loop's own terminator: a scf.yield nested in
+  // another region (e.g. an scf.if inside the loop body) has operand numbers
+  // that do not correspond to the scf.for results.
+  if (yieldOpOperand.getOwner()->getParentOp() != forOp.getOperation()) {
+    return failure();
+  }
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  // Step 3. Check resulting value of scf.for has exactly one use.
+  if (!resultingValue.hasOneUse()) {
+    return failure();
+  }
+
+  // Step 4. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.for, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return failure();
+  }
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+///
+/// `candidateSliceOp` must be a tensor.insert_slice yielded by the terminator
+/// of its enclosing scf.for, whose corresponding loop result has a single user
+/// implementing TilingInterface and DestinationStyleOpInterface (see
+/// getUntiledConsumerFromSlice). On success, the scf.for is replaced by a new
+/// scf.for that also carries the consumer's inits as trailing iter_args. The
+/// consumer is tiled inside the loop body, and all uses of the original
+/// consumer's results are redirected to the new loop's trailing results. The
+/// original consumer op is left in place (without uses) and returned.
+///
+/// NOTE(review): failures reported after step 3 leave the IR partially
+/// rewritten (the new scf.for already exists and owns the loop body). Checks
+/// that only need the original IR are therefore performed before step 3.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+  auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+
+  OpBuilder::InsertionGuard g(rewriter);
+
+  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
+  // 2.a. Check consumer is not using any of scf.for's results as init. Besides
+  // the fused result itself, any other loop result used as init would become
+  // an init of the new scf.for that is defined by that same loop once the old
+  // loop's results are replaced in step 12.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::any_of(dpsInits, [&](Value init) {
+        return init.getDefiningOp() == forOp.getOperation();
+      })) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+
+  // 2.b. Check all insert strides are 1. This only inspects the original IR,
+  // so it is done before rewriting to avoid leaving half-transformed IR.
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // 3. Create new scf.for op; its iter_args are the original inits followed by
+  // the consumer's inits.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  // 4. Move the loop body to the new op. mergeBlocks erases `loopBody`, so its
+  // argument count must be read before the merge.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  unsigned numOrigBodyArgs = loopBody->getNumArguments();
+  rewriter.mergeBlocks(loopBody, newLoopBody,
+                       newLoopBody->getArguments().take_front(numOrigBodyArgs));
+
+  // 5. Clone tensor.insert_slice after original tensor.insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  tensor::InsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+  // 6.a. Clone consumer after the cloned tensor.insert_slice op. The new
+  // loop's block arguments beyond the original ones become its destinations.
+  rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(numOrigBodyArgs),
+      [](BlockArgument b) -> Value { return b; });
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 6.b. Replace all uses of the loop result with the result of the cloned
+  // tensor.insert_slice.
+  rewriter.replaceUsesWithIf(forOp.getResult(resultNumber),
+                             clonedCandidateSliceOp.getResult(),
+                             [&](OpOperand &operand) {
+                               return operand.getOwner() == clonedConsumerOp;
+                             });
+
+  // 7. Perform tiling of the cloned consumer.
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedCandidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                       "failed to tile consumer op");
+  }
+  assert(!(tileAndFuseResult->tiledOps.empty()) && "tiled consumer not found");
+
+  // 8. Extract offsets/sizes of the inserted slice (strides were checked to be
+  // 1 in step 2.b).
+  SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+
+  // 9. Try to get iter domain position from input position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  rewriter.setInsertionPointAfter(clonedConsumerOp);
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 10. Try to get all containing op result's position from iter domain
+  // position.
+  unsigned numConsumerResults = clonedConsumerOp->getNumResults();
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(numConsumerResults);
+  for (unsigned idx = 0; idx < numConsumerResults; ++idx) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 11. Fix terminator: keep the original yielded values and additionally
+  // yield each tiled consumer result inserted into its consumer iter_arg.
+  auto oldTerminatorOp =
+      cast<scf::YieldOp>(newforOp.getBody()->getTerminator());
+  SmallVector<Value> newYieldOperands(oldTerminatorOp.getResults());
+  rewriter.setInsertionPointAfter(oldTerminatorOp);
+  for (auto [idx, v] :
+       llvm::enumerate(tileAndFuseResult->tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+        clonedCandidateSliceOp->getLoc(), v,
+        newForOpBlockArgsForConsumerDest[idx], resultPositions[idx].first,
+        resultPositions[idx].second, strides);
+    newYieldOperands.push_back(newInsertSliceOp);
+  }
+  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
+  rewriter.eraseOp(oldTerminatorOp);
+
+  // 12. Replace the result of scf.for and consumer op.
+  for (auto [idx, oldResult] : llvm::enumerate(forOp.getResults())) {
+    rewriter.replaceAllUsesWith(oldResult, newforOp->getResult(idx));
+  }
+
+  unsigned numOrigInits = forOp.getInits().size();
+  for (auto [idx, consumerResult] : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(consumerResult,
+                                newforOp->getResult(numOrigInits + idx));
+  }
+
+  rewriter.replaceOp(candidateSliceOp, clonedCandidateSliceOp);
+
+  // 13. Erase the old scf.for and the cloned consumer op.
+  rewriter.eraseOp(forOp);
+  rewriter.eraseOp(clonedConsumerOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOp, tileAndFuseResult->tiledOps[0], {}};
+}
+
+/// Fetch the untiled consumer of a scf.forall's result which is yielded by a
+/// tensor.parallel_insert_slice. Requires that:
+/// 1.  the slice destination is one of the scf.forall's region out arguments,
+/// 2.  the tied scf.forall result has exactly one use, and
+/// 3.  that use belongs to an op implementing both TilingInterface and
+///     DestinationStyleOpInterface.
+/// Returns the consumer's operand that uses the scf.forall result, or failure.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // Step 1. Fetch the corresponding output. Bail out instead of dereferencing
+  // a null BlockArgument when the destination is not a block argument.
+  Value sliceDest = candidateSliceOp.getDest();
+  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
+  if (!iterArg) {
+    return failure();
+  }
+  Operation *containingOp = iterArg.getOwner()->getParentOp();
+  // Step 2. Check that the containing op is scf.forall.
+  auto forallOp = dyn_cast_or_null<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    return failure();
+  }
+  // Map the destination to its tied scf.forall result. If it is not a region
+  // out argument (e.g. an induction variable), the index would run past the
+  // op's results, so fail instead.
+  auto outArgs = forallOp.getRegionOutArgs();
+  auto outArgIt = llvm::find(outArgs, iterArg);
+  if (outArgIt == outArgs.end()) {
+    return failure();
+  }
+  unsigned resultNumber = std::distance(outArgs.begin(), outArgIt);
+  Value resultingValue = forallOp->getResult(resultNumber);
+  // Step 3. Check resulting value of scf.forall has exactly one use.
+  if (!resultingValue.hasOneUse()) {
+    return failure();
+  }
+
+  // Step 4. Get uses.
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return failure();
+  }
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the dest.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  auto forallOp = static_cast<scf::ForallOp>(containingOp);
+
+  auto dstOp = static_cast<DestinationStyleOpInterface>(consumerOp);
+  // 2. Check consumer is not using scf.forall's output as init.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forallOp.getResult(resultNumber))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.forall as init is not supported");
+  }
+
+  Location loc = forallOp.getLoc();
+  // 3. Create new scf.forall op.
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+
+  // 4. Move the loop body to the new op.
+  rewriter.eraseOp(newforallOp.getTerminator());
+  Block *loopBody = forallOp.getBody();
+  Block *newLoopBody = newforallOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // 5. Clone tensor.parallel_insert_slice after the original
+  // tensor.parallel_insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  tensor::ParallelInsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+  // 6.a. Clone the consumer after the cloned tensor.parallel_insert_slice.
+  rewriter.setInsertionPointAfter(clonedCandidateSliceOp);
+  SmallVector<Value> newForOpBlockArgsForConsumerDest = llvm::map_to_vector(
+      newLoopBody->getArguments().drop_front(loopBody->getNumArguments()),
+      [](BlockArgument b) -> Value { return b; });
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 6.b. Replace all uses of the scf.forall's result use in the consumer with
+  // the source of the cloned tensor.parallel_insert_slice.
+  rewriter.replaceUsesWithIf(forallOp.getResult(resultNumber),
+                             clonedCandidateSliceOp.getSource(),
+                             [&](OpOperand &operand) {
+                               return operand.getOwner() == clonedConsumerOp;
+                             });
+
+  // 7. Perform tiling of the cloned consumer.
+  rewriter.setInsertionPoint(newforallOp.getTerminator());
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter,
+          cast<OffsetSizeAndStrideOpInterface>(
+              clonedCandidateSliceOp.getOperation()),
+          clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return rewriter.notifyMatchFailure(clonedConsumerOp,
+                                       "failed to tile consumer op: ");
+  }
+  assert(!(tileAndFuseResult->tiledOps.empty()) && "tiled consumer not found");
+
+  // 8. Extract offset/sizes/strides required to create the
+  // tensor.parallel_insert_slice for each result of the consumer.
+  SmallVector<OpFoldResult> offsets = clonedCandidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = clonedCandidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = clonedCandidateSliceOp.getMixedStrides();
+  // 9. Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        clonedCandidateSliceOp, "containingOp's result yield with stride");
+  }
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  // 10. Try to get iter domain position from input position.
+  rewriter.setInsertionPointAfter(tileAndFuseResult->tiledOps[0]);
+  ;
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 11. Try to get all containing op result's position from iter domain
+  // position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(clonedConsumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  // 12. Fix terminator.
+  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+      newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+  Operation *firstYieldOp = yieldingOps.front();
----------------
ftynse wrote:

There is no need to go over all yielding ops if only the first one is necessary. There is also no need to materialize a vector. All these things have a cost. Even if that cost is small, having plenty of such costs everywhere in the codebase accumulates significantly. And while the compiler may be able to optimize this away, increased compile time is also a cost.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list