[Mlir-commits] [mlir] [mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate `scf::forall`. (PR #67083)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Oct 18 17:09:43 PDT 2023
================
@@ -728,6 +741,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
getAsOperations(forLoops), replacements};
}
+//===----------------------------------------------------------------------===//
+// tileUsingSCFForallOp implementation.
+//===----------------------------------------------------------------------===//
+
+FailureOr<scf::SCFTilingResult>
+mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+ const scf::SCFTilingOptions &options) {
+ Location loc = op->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+
+ // 1. Get the range of loops that are represented by the operation.
+ SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+ if (loopRanges.empty())
+ return op->emitOpError("expected non-empty loop ranges");
+ auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+ if (llvm::any_of(loopRanges, hasStrideOne))
+ return op->emitOpError("only stride-1 supported atm");
+
+ // 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
+ // To make it easier, pad the tile sizes to loopRanges.size with value 0.
+ SmallVector<OpFoldResult> tileSizeVector =
+ options.tileSizeComputationFunction(rewriter, op);
+ tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
+
+ // 3. Build the offsets, sizes and steps for the tile and distributed loops.
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ for (auto [index, tileSize, loopRange] :
+ llvm::enumerate(tileSizeVector, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0))
+ continue;
+ lbs.push_back(loopRange.offset);
+ ubs.push_back(loopRange.size);
+ steps.push_back(tileSize);
+ }
+
+ // 4. Gather destination tensors.
+ SmallVector<Value> dest;
+ if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
+ return op->emitOpError("failed to get destination tensors");
+
+ // 5. Build the device mapping attribute;
+ std::optional<ArrayAttr> mappingAttr;
+ if (!options.mappingVector.empty()) {
+ mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
+ }
+
+ // 6. Create the ForallOp. We don't use the lambda body-builder
+ // version because we require the use of RewriterBase in the body, so we
+ // manually move the insertion point to the body below.
+ auto forallOp =
+ rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
+
+ // 7. Get the tile offset and sizes.
+ rewriter.setInsertionPoint(forallOp.getTerminator());
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ tiledOffsets.reserve(loopRanges.size());
+ tiledSizes.reserve(loopRanges.size());
+ ValueRange ivs = forallOp.getInductionVars();
+ {
+ int materializedLoopNum = 0;
+ for (auto [index, tileSize, loopRange] :
+ llvm::enumerate(tileSizeVector, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0)) {
+ tiledOffsets.push_back(loopRange.offset);
+ tiledSizes.push_back(loopRange.size);
+ continue;
+ }
+ Value iv = ivs[materializedLoopNum++];
+ tiledOffsets.push_back(iv);
+ tiledSizes.push_back(
+ getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+ }
+ }
+
+ // 8. Tile the operation. Clone the operation to allow fix up of destination
+ // operands
+ ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+ Operation *clonedOp =
+ cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
+ FailureOr<TilingResult> tilingResult =
+ cast<TilingInterface>(clonedOp).getTiledImplementation(
+ rewriter, tiledOffsets, tiledSizes);
+ if (failed(tilingResult))
+ return clonedOp->emitError("Failed to tile op: ");
+ rewriter.eraseOp(clonedOp);
+
+ // 9. Parallel insert back into the result tensor.
+ for (auto [index, tiledValue, destBBArg] :
+ llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
+ // 9.a. Partial subset information is inserted just before the terminator.
+ rewriter.setInsertionPoint(forallOp.getTerminator());
+
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
+ tiledSizes, resultOffsets,
+ resultSizes)))
+ return op->emitOpError("output offsets couldn't be calculated");
+ SmallVector<OpFoldResult> strides(resultSizes.size(),
+ rewriter.getIndexAttr(1));
+
+ // 5.b. Parallel insertions are inserted at the end of the combining
----------------
hanhanW wrote:
nit: `9.b`
https://github.com/llvm/llvm-project/pull/67083
More information about the Mlir-commits
mailing list