[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Tue Feb 27 06:21:00 PST 2024
================
@@ -309,6 +311,188 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
+static void continuousLoopNestHelper(
+ OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
+ SmallVector<LoopLikeOpInterface> &loops, uint64_t loopLevelIdx,
+ uint64_t &loopIdx, ArrayRef<OpFoldResult> tileSizes,
+ SmallVector<bool> &CTileVector, std::map<int, OpFoldResult> &sizesMap,
+ SmallVector<scf::ForOp> &innermostLoops, ValueRange destinationTensors = {},
+ bool isHeadOrInsideHeadLoop = false) {
+
+ Value offset = getValueOrCreateConstantIndexOp(
+ builder, loc, loopRanges[loopLevelIdx].offset);
+ Value size = getValueOrCreateConstantIndexOp(builder, loc,
+ loopRanges[loopLevelIdx].size);
+ Value tileSize =
+ getValueOrCreateConstantIndexOp(builder, loc, tileSizes[loopLevelIdx]);
+
+ AffineExpr sym0, sym1, sym2;
+ bindSymbols(builder.getContext(), sym0, sym1, sym2);
+ AffineMap defaultSplitMap =
+ AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
+ // Simplified map for use when step is power of 2 and lower bound
+ // is exactly divisible by step.
+ AffineMap powerSplitMap = AffineMap::get(0, 3, {sym1 - (sym1 % sym2)});
+
+ uint64_t tileSizeInt = *getConstantIntValue(tileSize);
+
+ // Enforce no tiling when tile size is zero.
+ // No need to create a loop here.
+ if (tileSizeInt == 0) {
+ continuousLoopNestHelper(builder, loc, loopRanges, loops, loopLevelIdx + 1,
+ loopIdx, tileSizes, CTileVector, sizesMap,
+ innermostLoops, destinationTensors,
+ isHeadOrInsideHeadLoop);
+ return;
+ }
+
+ // The head loop is always tiled using the tile size specified
+ // in the size parameters to tile_using_for transform.
+ auto loop = builder.create<scf::ForOp>(
+ loc, offset, size, tileSize, destinationTensors,
+ [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+ ValueRange /*iterArgs*/) {
+ sizesMap[loopIdx] =
+ getBoundedTileSize(bodyBuilder, bodyLoc, loopRanges[loopLevelIdx],
+ iv, getAsOpFoldResult(tileSize));
+ });
+
+ loop->setAttr(kLoopIndexLabel, builder.getIndexAttr(loopIdx));
+ ++loopIdx;
+
+ scf::ForOp currentLoop = loop;
+ auto lbInt = getConstantIntValue(currentLoop.getLowerBound());
+ // Use simplified powerSplitMap instead of the default when possible.
+ bool usePowerSplit = (lbInt.has_value()) &&
+ (*lbInt % tileSizeInt == static_cast<int64_t>(0)) &&
+ (tileSizeInt == llvm::bit_floor(tileSizeInt));
+
+ AffineMap splitMap = usePowerSplit ? powerSplitMap : defaultSplitMap;
+
+ bool isInnermostLoop = loopLevelIdx == loopRanges.size() - 1;
+ if (isInnermostLoop)
+ innermostLoops.push_back(currentLoop);
+
+ if (isHeadOrInsideHeadLoop)
+ loops.push_back(loop);
+
+ builder.setInsertionPointToEnd(loop.getBody());
+
+ // Create the nested loop inside current loop.
+ if (!isInnermostLoop)
+ continuousLoopNestHelper(builder, loop->getLoc(), loopRanges, loops,
+ loopLevelIdx + 1, loopIdx, tileSizes, CTileVector,
+ sizesMap, innermostLoops, loop.getRegionIterArgs(),
+ isHeadOrInsideHeadLoop);
+
+ // Apply continuous tiling to current loop if continuous_tiles
+ // specifies so.
+ while (CTileVector[loopLevelIdx] && tileSizeInt > 1) {
+
+ uint64_t maxPower = llvm::bit_floor(tileSizeInt);
+ tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
+
+ builder.setInsertionPoint(currentLoop);
+
+ auto constStepOp = builder.create<arith::ConstantIndexOp>(loc, tileSizeInt);
+
+ Value splitBound = builder.createOrFold<affine::AffineApplyOp>(
+ loc, splitMap,
+ ValueRange{currentLoop.getLowerBound(), currentLoop.getUpperBound(),
+ currentLoop.getStep()});
+
+ builder.setInsertionPointAfter(currentLoop);
+ auto additionalLoop =
+ builder.create<scf::ForOp>(currentLoop->getLoc(), splitBound, size,
+ constStepOp, destinationTensors);
+
+ additionalLoop.getInitArgsMutable().assign(currentLoop->getResults());
----------------
ftynse wrote:
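Nit: fold this into the op constructor above, i.e. pass `currentLoop->getResults()` as the init args when creating `additionalLoop` instead of assigning them afterwards.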
https://github.com/llvm/llvm-project/pull/82792
More information about the Mlir-commits
mailing list