[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Mon Apr 22 04:25:16 PDT 2024


================
@@ -1100,6 +1102,459 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
                                    replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileAndFuseConsumerUsingSCF implementation.
+//===----------------------------------------------------------------------===//
+
+/// Checks that `result` has exactly one use and that the owner of that use
+/// is an scf.yield; returns failure (after a debug message) otherwise.
+/// The check is unconditional: the op defining `result` is not inspected.
+static LogicalResult checkAssumptionForFusingConsumer(Value result) {
+  Value::use_range uses = result.getUses();
+  if (!llvm::hasSingleElement(uses)) {
+    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
+    return failure();
+  }
+  OpOperand &operandUse = (*uses.begin());
+  Operation *userOp = operandUse.getOwner();
+  if (!isa<scf::YieldOp>(userOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Expected scf.yield to be the only user, but got -> "
+               << (*userOp));
+    return failure();
+  }
+  return success();
+}
+
+/// Returns the sole use (an operand of the untiled consumer) of the scf.for
+/// result yielded through `candidateSliceOp`. Fails unless: (1) the slice's
+/// only user is scf.yield, (2) its parent op is scf.for, (3) that result has
+/// one use, (4) the user is a TilingInterface and DestinationStyleOpInterface.
+static FailureOr<OpOperand *>
+getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
+  Value sliceResult = candidateSliceOp.getResult();
+  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp.getResult()))) {
+    return failure();
+  }
+  // Step 1. The yield operand index is used as the scf.for result number.
+  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
+  unsigned resultNumber = yieldOpOperand.getOperandNumber();
+  // Step 2. Check that the slice's parent op is an scf.for.
+  Operation *containingOp = candidateSliceOp->getParentOp();
+  auto forOp = dyn_cast<scf::ForOp>(containingOp);
+  if (!forOp) {
+    return failure();
+  }
+  Value resultingValue = forOp->getResult(resultNumber);
+
+  // Step 3. Check that this scf.for result has exactly one use.
+  if (!llvm::hasSingleElement(resultingValue.getUses())) {
+    return failure();
+  }
+
+  // Step 4. Get the single use and the op that owns it (the consumer).
+  OpOperand &operand = (*resultingValue.getUses().begin());
+  Operation *consumerOp = operand.getOwner();
+  // TODO: The consumer's result must be initialized before the scf.for; for
+  //       now rely on DestinationStyleOpInterface to get its shape from init.
+  //       Add support for other ops, e.g. those with InferTypeOpInterface.
+  if (!isa<TilingInterface>(consumerOp) ||
+      !isa<DestinationStyleOpInterface>(consumerOp)) {
+    return failure();
+  }
+  return &operand;
+}
+
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.for.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFFor(RewriterBase &rewriter,
+                                 tensor::InsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice.
+  FailureOr<OpOperand *> consumerOpOperand =
+      getUntiledConsumerFromSlice(candidateSliceOp);
+  if (failed(consumerOpOperand)) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = (*consumerOpOperand)->getOwner();
+  unsigned operandNumber = (*consumerOpOperand)->getOperandNumber();
+  unsigned resultNumber =
+      cast<OpResult>((*consumerOpOperand)->get()).getResultNumber();
+
+  auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(candidateSliceOp);
+
+  auto dstOp = static_cast<DestinationStyleOpInterface>(consumerOp);
+  // 2. Check consumer is not using scf.for's output as init.
+  SmallVector<Value> dpsInits =
+      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
+  if (llvm::is_contained(dpsInits, forOp.getResult(0))) {
+    return rewriter.notifyMatchFailure(
+        consumerOp,
+        "consumer op taking the result of scf.for as init is not supported");
+  }
+
+  Location loc = forOp.getLoc();
+  SmallVector<Value> newOuts(forOp.getInits());
+  newOuts.append(dpsInits);
+
+  // 3. Create new scf.for op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
+                                              forOp.getUpperBound(),
+                                              forOp.getStep(), newOuts);
+  // 4. Move the loop body to the new op.
+  Block *loopBody = forOp.getBody();
+  Block *newLoopBody = newforOp.getBody();
+  rewriter.mergeBlocks(
+      loopBody, newLoopBody,
+      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
+
+  // 5. Clone tensor.insert_slice after original tensor.insert_slice.
+  rewriter.setInsertionPointAfter(candidateSliceOp);
+  SmallVector<Value> candidateSliceOpOperands =
+      llvm::to_vector(candidateSliceOp->getOperands());
+  tensor::InsertSliceOp clonedCandidateSliceOp =
+      mlir::clone(rewriter, candidateSliceOp,
+                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
----------------
ftynse wrote:

I don't quite get why we need to clone the slice op. Why can't we use the result of the original slice? Please document.

Also nit: `rewriter.clone` would be more concise here.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list