[Mlir-commits] [mlir] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)

lorenzo chelini llvmlistbot at llvm.org
Thu Oct 19 03:04:30 PDT 2023


================
@@ -728,6 +741,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
                                    getAsOperations(forLoops), replacements};
 }
 
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+                                const scf::SCFTilingOptions &options) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+
+  // 1. Get the range of loops that are represented by the operation.
+  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+  if (llvm::any_of(loopRanges, hasStrideOne))
+    return op->emitOpError("only stride-1 supported atm");
+
+  // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
+  // To make it easier, pad the tile sizes to loopRanges.size with value 0.
+  SmallVector<OpFoldResult> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+  // 3. Build the offsets, sizes and steps for the tile and distributed loops.
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  for (auto [index, tileSize, loopRange] :
+       llvm::enumerate(tileSizeVector, loopRanges)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+
+  // 4. Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+    return op->emitOpError("failed to get destination tensors");
+
+  // 5. Build the device mapping attribute;
+  std::optional<ArrayAttr> mappingAttr;
+  if (!options.mappingVector.empty()) {
+    mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
+  }
+
+  // 6. Create the ForallOp. We don't use the lambda body-builder
+  // version because we require the use of RewriterBase in the body, so we
+  // manually move the insertion point to the body below.
+  auto forallOp =
+      rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
+
+  // 7. Get the tile offset and sizes.
+  rewriter.setInsertionPoint(forallOp.getTerminator());
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  tiledOffsets.reserve(loopRanges.size());
+  tiledSizes.reserve(loopRanges.size());
+  ValueRange ivs = forallOp.getInductionVars();
+  {
+    int materializedLoopNum = 0;
+    for (auto [index, tileSize, loopRange] :
+         llvm::enumerate(tileSizeVector, loopRanges)) {
+      if (isConstantIntValue(tileSize, 0)) {
+        tiledOffsets.push_back(loopRange.offset);
+        tiledSizes.push_back(loopRange.size);
+        continue;
+      }
+      Value iv = ivs[materializedLoopNum++];
+      tiledOffsets.push_back(iv);
+      tiledSizes.push_back(
+          getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+    }
+  }
+
+  // 8. Tile the operation. Clone the operation to allow fix up of destination
+  // operands
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  Operation *clonedOp =
+      cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
+  FailureOr<TilingResult> tilingResult =
+      cast<TilingInterface>(clonedOp).getTiledImplementation(
+          rewriter, tiledOffsets, tiledSizes);
+  if (failed(tilingResult))
+    return clonedOp->emitError("Failed to tile op: ");
+  rewriter.eraseOp(clonedOp);
+
+  // 9. Parallel insert back into the result tensor.
+  for (auto [index, tiledValue, destBBArg] :
+       llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
+    // 9.a. Partial subset information is inserted just before the terminator.
+    rewriter.setInsertionPoint(forallOp.getTerminator());
+
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
----------------
chelini wrote:

nit: braces here.

https://github.com/llvm/llvm-project/pull/67083


More information about the Mlir-commits mailing list