[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 15 14:40:07 PDT 2024


================
@@ -820,6 +874,429 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] =
+      getUntiledConsumerFromSliceDestSCFForall(
+          &candidateSliceOp.getDestMutable(), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  // Use candidateSliceOp->getParentOp() because of the following nesting:
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    llvm::outs() << "consumer is not a TilingInterface: " << *consumerOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    llvm::outs() << "containing op is not a scf.forall: " << *containingOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check that the consumer doesn't use more than one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate(consumerOp->getOperands())) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        llvm::outs() << "consumer uses more than one result of containingOp\n";
+        return failure();
+      }
+    }
+  }
+
+  // TODO: The consumer's result has to be initialized before the scf.forall,
+  //       so for now use DestinationStyleOpInterface to get the result shape
+  //       from the init. Add support for other ops, such as ops implementing
+  //       InferTypeOpInterface.
+  // Check that the consumer implements DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    llvm::outs() << "consumer op must implement DestinationStyleOpInterface\n";
+    return failure();
+  }
+
+  // Check that the consumer doesn't use the scf.forall's result as an init.
+  SmallVector<Value> dpsInits = llvm::to_vector(dstOp.getDpsInits());
+  if (llvm::is_contained(dpsInits, bridge)) {
+    llvm::outs() << "consumer op takes the result of scf.forall as init\n";
+    return failure();
+  }
+
+  // Check result was inserted only once.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+
+  // Check that all insert strides are 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    llvm::outs() << "containingOp's result is yielded with a non-unit stride\n";
+    return failure();
+  }
+
+  Location loc = forallOp.getLoc();
+  rewriter.setInsertionPoint(terminatorOp);
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Try to get iter domain position from input position.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
----------------
MaheshRavishankar wrote:

I think this is the consumer directly? This turned out to be a pretty difficult way of doing things. If you look at the tile-and-fuse producer methods:
1) The producer is first cloned into the loop body, and all uses of the original producer within the loops are replaced with the cloned op (including in the `tensor.extract_slice` operation that is driving the fusion).
2) Then the `cloned_op` -> `tensor.extract_slice` chain is transformed into `tensor.extract_slice` -> `cloned_op` using the `replaceExtractSliceWithTiledProducer` method.
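The two producer-fusion steps can be illustrated with a toy 1-D model: for an elementwise, destination-style producer like the one below, slicing the full result of the cloned producer gives the same tile as running the producer on slices of its input and destination, which is why the step-2 rewrite is valid for such an op. This is a minimal sketch assuming plain `std::vector<float>` tensors; `extractSlice`, `producer`, `sliceOfClonedProducer`, and `tiledProducer` are hypothetical helpers, not MLIR APIs, and they model only the data flow, not the IR rewrite itself.

```cpp
#include <cstddef>
#include <vector>

// Toy 1-D "tensor"; this is not an MLIR type.
using Tensor = std::vector<float>;

// Hypothetical stand-in for a unit-stride tensor.extract_slice.
Tensor extractSlice(const Tensor &t, std::size_t offset, std::size_t size) {
  return Tensor(t.begin() + offset, t.begin() + offset + size);
}

// Hypothetical elementwise, destination-style producer: writes in[i] + 1
// into its destination operand `init` and returns the updated destination.
Tensor producer(const Tensor &in, Tensor init) {
  for (std::size_t i = 0; i < in.size(); ++i)
    init[i] = in[i] + 1.0f;
  return init;
}

// After step 1: the slice is taken from the result of the cloned, untiled
// producer, so the whole tensor is computed on every loop iteration.
Tensor sliceOfClonedProducer(const Tensor &in, const Tensor &init,
                             std::size_t offset, std::size_t size) {
  return extractSlice(producer(in, init), offset, size);
}

// After step 2: the producer is applied to slices of its input and of its
// destination, so only the requested tile is computed.
Tensor tiledProducer(const Tensor &in, const Tensor &init, std::size_t offset,
                     std::size_t size) {
  return producer(extractSlice(in, offset, size),
                  extractSlice(init, offset, size));
}
```

Slicing the destination alongside the input is what keeps destination operands simple to handle: the tiled producer writes into exactly the slice of `init` that the original slice would have read.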

Fusing producers this way makes it much simpler to handle destination operands. I think we need to do something similar here, with the following modifications:
1) Clone the loop operation (`scf.forall` or `scf.for`) just before the consumer, without cloning its region. At the same time, the `dpsInit` operands of the consumer need to be added as `iter_args` of the cloned loop operation.
2) Move the body of the original loop into the body of the cloned loop (use the `rewriter.mergeBlocks` API to handle the block arguments correctly between the old and new loops).
3) Clone the consumer operation into the body of the cloned loop just after the `tensor.insert_slice` operation. During the cloning, use the region arguments corresponding to the newly created loop `iter_args` as the destinations of the cloned consumer operation. (This gets a bit complicated for `scf.forall`, which has the `tensor.parallel_insert_slice`, but in that case you clone it just after the `tensor.insert_in_parallel` operation.)
4) For the `scf.for` case, replace all uses of the loop result in the cloned consumer operation with the `tensor.insert_slice` result. For the `scf.forall` case, replace all uses of the loop result in the cloned consumer operation with the corresponding `iter_arg` of the result. (Note that this part is a bit strange, so we need to look at the IR to see what to do here properly, but the `scf.for` case should be more straightforward. I'd suggest deferring the `scf.forall` case until the `scf.for` case works as intended; then I think I know how to fix the `scf.forall` case, but we can come to that later.)
5) Transform the `tensor.insert_slice` -> `cloned_consumer` chain into `cloned_consumer` -> `tensor.insert_slice`.
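The data flow that the `scf.for` version of this plan should produce can be checked on a toy 1-D model. A minimal sketch under stated assumptions: tensors are `std::vector<float>`, both the producer and the consumer are elementwise, and every helper (`insertSlice`, `producerTile`, `consumer`, `unfused`, `fused`) is hypothetical. It models what the rewritten loop computes, not the MLIR rewriter calls (cloning, `mergeBlocks`) that would build it.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Toy 1-D "tensor"; this is not an MLIR type.
using Tensor = std::vector<float>;

// Hypothetical stand-ins for unit-stride tensor.extract_slice/insert_slice.
Tensor extractSlice(const Tensor &t, std::size_t offset, std::size_t size) {
  return Tensor(t.begin() + offset, t.begin() + offset + size);
}
Tensor insertSlice(const Tensor &src, Tensor dest, std::size_t offset) {
  std::copy(src.begin(), src.end(), dest.begin() + offset);
  return dest;
}

// Hypothetical tiled producer computed inside the loop body: in[i] + 1.
Tensor producerTile(const Tensor &in, std::size_t offset, std::size_t size) {
  Tensor tile(size);
  for (std::size_t i = 0; i < size; ++i)
    tile[i] = in[offset + i] + 1.0f;
  return tile;
}

// Hypothetical elementwise, destination-style consumer: writes 2 * x[i]
// into its destination operand `init` and returns the updated destination.
Tensor consumer(const Tensor &x, Tensor init) {
  for (std::size_t i = 0; i < x.size(); ++i)
    init[i] = 2.0f * x[i];
  return init;
}

// Before fusion: the loop carries one iter_arg (the producer result) and the
// consumer runs on the full loop result afterwards. tileSize must be > 0.
Tensor unfused(const Tensor &in, const Tensor &prodInit,
               const Tensor &consInit, std::size_t tileSize) {
  Tensor prodAcc = prodInit;
  for (std::size_t off = 0; off < in.size(); off += tileSize) {
    std::size_t size = std::min(tileSize, in.size() - off);
    prodAcc = insertSlice(producerTile(in, off, size), prodAcc, off);
  }
  return consumer(prodAcc, consInit);
}

// After fusion: the consumer's init is an extra iter_arg (step 1); the tiled
// consumer runs right after the producer's insert_slice, reads the tile that
// was just inserted, and writes into a slice of its own iter_arg (steps 3-5).
// The loop now yields both the producer and the consumer results.
std::pair<Tensor, Tensor> fused(const Tensor &in, const Tensor &prodInit,
                                const Tensor &consInit, std::size_t tileSize) {
  Tensor prodAcc = prodInit;
  Tensor consAcc = consInit;
  for (std::size_t off = 0; off < in.size(); off += tileSize) {
    std::size_t size = std::min(tileSize, in.size() - off);
    Tensor pTile = producerTile(in, off, size);
    prodAcc = insertSlice(pTile, prodAcc, off);
    Tensor cTile = consumer(pTile, extractSlice(consAcc, off, size));
    consAcc = insertSlice(cTile, consAcc, off);
  }
  return {prodAcc, consAcc};
}
```

Because both ops here are elementwise, the consumer tile lives at the same offsets as the producer tile. A consumer with a non-identity indexing map would presumably need the operand-to-iteration-domain mapping that the patch queries via `getIterDomainTilePositionFromOperandPosition`.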

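Two of the precondition checks in the quoted `tileAndFuseConsumerOfSliceSCFForall` hunk, that the consumer reads only one result of the `scf.forall` and that every `tensor.parallel_insert_slice` stride is a static 1, can be sketched outside MLIR. This minimal sketch assumes hypothetical stand-ins: `FoldResult` plays the role of `OpFoldResult` (either a static integer or a named dynamic value), and `Operand` records only which op and which result define an operand. Neither is an MLIR type.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for OpFoldResult: a static integer (attribute) or a
// dynamic SSA value, represented here only by its name.
using FoldResult = std::variant<int64_t, std::string>;

// Mirrors the stride check: reject any stride that is dynamic or statically
// different from 1.
bool allStridesAreOne(const std::vector<FoldResult> &strides) {
  for (const FoldResult &s : strides) {
    const int64_t *v = std::get_if<int64_t>(&s);
    if (!v || *v != 1)
      return false;
  }
  return true;
}

// Hypothetical operand model: id of the defining op and its result number.
struct Operand {
  int definingOp;
  int resultNumber;
};

// Mirrors the "bridge" loop: collect indices of operands defined by the
// containing loop; fail if they refer to more than one of its results.
std::optional<std::vector<unsigned>>
collectBridgeOperands(const std::vector<Operand> &operands, int containingOp) {
  std::vector<unsigned> operandNums;
  std::optional<int> bridge;
  for (unsigned idx = 0; idx < operands.size(); ++idx) {
    const Operand &opd = operands[idx];
    if (opd.definingOp != containingOp)
      continue;
    operandNums.push_back(idx);
    if (!bridge)
      bridge = opd.resultNumber;
    else if (*bridge != opd.resultNumber)
      return std::nullopt;
  }
  return operandNums;
}
```

Using the same loop result through several operands is still accepted; only a second, different result of the containing op causes a failure.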


https://github.com/llvm/llvm-project/pull/88712

