[Mlir-commits] [mlir] [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop (PR #88712)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 15 14:40:07 PDT 2024


================
@@ -820,6 +874,429 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] =
+      getUntiledConsumerFromSliceDestSCFForall(
+          &candidateSliceOp.getDestMutable(), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    llvm::outs() << "containing op is not a scf.forall: " << containingOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check consumer doesn't use more than one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        llvm::outs()
+            << "consumer's operand use more than one containingOp's result\n";
+        return failure();
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    llvm::outs() << "consumer op should have destination style op interface"
+                 << "\n";
+    return failure();
+  }
+
+  // Check consumer doesn't use scf.forall's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    llvm::outs() << "consumer op take result of scf.forall as init\n";
+    return failure();
+  }
+
+  // Check result was inserted only once.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
+
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
----------------
MaheshRavishankar wrote:

Use `isConstantIntValue(foldRes, 1)` instead.

https://github.com/llvm/llvm-project/pull/88712


More information about the Mlir-commits mailing list