[Mlir-commits] [mlir] [mlir][SCF] Deprecate `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes` (PR #91878)
Quinn Dawkins
llvmlistbot at llvm.org
Wed May 22 17:13:15 PDT 2024
================
@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
}
}
-/// Returns a vector of bools representing whether, for each axis, `op` can be
-/// tiled without incurring a race condition, i.e. whether the tiling is
-/// thread-safe. This is checked by iterating over `numThreads` and ensuring
-/// that the corresponding iterator type is "parallel". If it is not, then we
-/// know that the dimension is unsafe to tile.
-SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
- ArrayRef<OpFoldResult> numThreads) {
- auto iterators = linalgOp.getIteratorTypesArray();
- SmallVector<bool> safeToTile(numThreads.size(), true);
-
- for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
- if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
- if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
- safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
- }
- } else {
- safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
- }
- }
- return safeToTile;
-}
-
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
-/// tiling is specified by the number of tiles/threads `numThreads` and the
-/// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
-/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `mapping` is added as an
-/// attribute to the resulting `scf.forall`. A zero tile size indicates
-/// that the dimension is not tiled, and can be thought of as tiling by the
-/// full size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification, i.e. that it only tiles parallel dimensions (e.g. in
-/// the Linalg case). If a dimension is not parallelizable, a warning is
-/// issued to notify the user that the generated code is not safe to
-/// parallelize. If `omitTileOffsetBoundsCheck` is true, then the function
-/// will assume that `tileSize[i] * (numThreads[i] - 1) <= dimSize[i]` holds.
-static FailureOr<ForallTilingResult> tileToForallOpImpl(
- RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
- std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
- std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
- Location loc = op->getLoc();
- OpBuilder::InsertionGuard g(b);
-
- SmallVector<Range> loopRanges = op.getIterationDomain(b);
- if (loopRanges.empty())
- return op->emitOpError("expected non-empty loop ranges");
- auto hasNonUnitStride = [](Range r) { return !isConstantIntValue(r.stride, 1); };
- if (llvm::any_of(loopRanges, hasNonUnitStride))
- return op->emitOpError("only stride-1 supported atm");
-
- // Gather destination tensors.
- SmallVector<Value> dest;
- if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
- return op->emitOpError("failed to get destination tensors");
-
- SmallVector<OpFoldResult> nonZeroNumThreads =
- llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 0);
- }));
- SmallVector<Value> materializedNonZeroNumThreads =
- llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(b, loc, ofr);
- }));
-
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
- if (linalgOp) {
- // Check if tiling is thread safe and print a warning if not.
- SmallVector<bool> tilingSafety =
- safeToTileToForall(b.getContext(), linalgOp, numThreads);
- for (size_t i = 0; i < tilingSafety.size(); i++)
- if (!tilingSafety[i])
- op.emitWarning() << "tiling is not thread safe at axis #" << i;
- }
-
- // 1. Create the ForallOp. We don't use the lambda body-builder
- // version because we require the use of RewriterBase in the body, so we
- // manually move the insertion point to the body below.
- scf::ForallOp forallOp = b.create<scf::ForallOp>(
- loc, getAsOpFoldResult(materializedNonZeroNumThreads), dest, mapping);
-
- // 2. Fill out the ForallOp body.
- SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
- calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
- omitTileOffsetBoundsCheck, nominalTileSizes,
- tiledOffsets, tiledSizes);
-
- // 3. Clone the tileable op and update its destination operands to use the
- // output bbArgs of the ForallOp.
- ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
- Operation *tiledOp = nullptr;
- SmallVector<Value> tiledValues;
- {
- // 3.a. RAII guard, inserting within forallOp, before terminator.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(forallOp.getTerminator());
- Operation *clonedOp = b.clone(*op.getOperation());
- auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
- if (destinationStyleOp) {
- for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
- // Swap tensor inits with the corresponding block argument of the
- // scf.forall op. Memref inits remain as is.
- if (isa<TensorType>(outOperand.get().getType())) {
- auto *it = llvm::find(dest, outOperand.get());
- assert(it != dest.end() && "could not find destination tensor");
- unsigned destNum = std::distance(dest.begin(), it);
- outOperand.set(destBbArgs[destNum]);
- }
- }
- }
-
- // 4. Tile the cloned op and delete the clone.
- FailureOr<TilingResult> tilingResult =
- cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
- tiledSizes);
- if (failed(tilingResult))
- return clonedOp->emitError("Failed to tile op: ");
- if (tilingResult->tiledOps.size() != 1) {
- return clonedOp->emitError("expected a single produced tiled op, got ")
- << tilingResult->tiledOps.size();
- }
-
- b.eraseOp(clonedOp);
- tiledOp = tilingResult->tiledOps.front();
- tiledValues = tilingResult->tiledValues;
- }
-
- // 5. Parallel insert back into the result tensor.
- for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
- tiledValues, destBbArgs)) {
- // 5.a. Partial subset information is inserted just before the terminator.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(forallOp.getTerminator());
-
- SmallVector<OpFoldResult> resultOffsets, resultSizes;
- if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
- tiledSizes, resultOffsets,
- resultSizes)))
- return op->emitOpError("output offsets couldn't be calculated");
- SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
-
- // 5.b. Parallel insertions are inserted at the end of the combining
- // terminator.
- b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
- b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
- std::get<2>(it), resultOffsets,
- resultSizes, strides);
- }
- return ForallTilingResult{forallOp, tiledOp};
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
- ArrayRef<OpFoldResult> numThreads,
- std::optional<ArrayAttr> mapping) {
- return tileToForallOpImpl(b, op, numThreads,
- /*nominalTileSizes=*/std::nullopt, mapping,
- /*omitTileOffsetBoundsCheck=*/false);
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
- ArrayRef<OpFoldResult> tileSizes,
- std::optional<ArrayAttr> mapping) {
- SmallVector<Range> loopRanges = op.getIterationDomain(b);
- unsigned nLoops = loopRanges.size();
- SmallVector<OpFoldResult> numThreads;
- numThreads.reserve(nLoops);
- AffineExpr s0, s1;
- bindSymbols(b.getContext(), s0, s1);
- AffineExpr divExpr = s0.ceilDiv(s1);
- for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
- OpFoldResult numTiles = std::get<0>(it);
- if (!isConstantIntValue(numTiles, 0))
- numTiles = makeComposedFoldedAffineApply(
- b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
- numThreads.push_back(numTiles);
- }
- return tileToForallOpImpl(b, op, numThreads,
- /*nominalTileSizes=*/tileSizes, mapping,
- /*omitTileOffsetBoundsCheck=*/true);
-}
-
----------------
qedawkins wrote:
Would it be a courtesy to mark these functions as deprecated first and remove them in a later patch, in case there are any downstream users?
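
For reference, a minimal sketch of what that interim step could look like, assuming the declarations live in mlir/Dialect/Linalg/Transforms/Transforms.h and that the LLVM_DEPRECATED macro from llvm/Support/Compiler.h is acceptable here:

    // Hypothetical interim deprecation, to be removed in a follow-up patch.
    #include "llvm/Support/Compiler.h"

    LLVM_DEPRECATED("use scf::tileUsingSCF with SCFTilingOptions instead",
                    "scf::tileUsingSCF")
    FailureOr<ForallTilingResult>
    tileToForallOp(RewriterBase &b, TilingInterface op,
                   ArrayRef<OpFoldResult> numThreads,
                   std::optional<ArrayAttr> mapping);

    LLVM_DEPRECATED("use scf::tileUsingSCF with SCFTilingOptions instead",
                    "scf::tileUsingSCF")
    FailureOr<ForallTilingResult>
    tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
                                 ArrayRef<OpFoldResult> tileSizes,
                                 std::optional<ArrayAttr> mapping);

That would give downstream users at least one release cycle of compiler warnings before the definitions are deleted.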
https://github.com/llvm/llvm-project/pull/91878
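
Since the PR deprecates these entry points in favor of the SCF tiling infrastructure, a rough migration sketch follows. It assumes the scf::tileUsingSCF entry point and the SCFTilingOptions setters (including the setNumThreads option this PR series introduces) from mlir/Dialect/SCF/Transforms/TileUsingInterface.h; treat the exact option names as assumptions, not a confirmed API:

    #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

    // Hypothetical drop-in replacement for linalg::tileToForallOp,
    // sketched against the SCF tiling options as of this PR.
    static FailureOr<scf::SCFTilingResult>
    tileToForallViaSCF(RewriterBase &b, TilingInterface op,
                       ArrayRef<OpFoldResult> numThreads,
                       std::optional<ArrayAttr> mapping) {
      scf::SCFTilingOptions options;
      // Generate an scf.forall loop rather than an scf.for nest.
      options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
      // Tile by thread count; tile sizes are derived as
      // ceilDiv(dimSize[i], numThreads[i]), matching the removed helper.
      options.setNumThreads(numThreads);
      if (mapping)
        options.setMapping(mapping->getValue());
      return scf::tileUsingSCF(b, op, options);
    }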