[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 30 07:08:35 PDT 2024
================
@@ -107,6 +107,138 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
b.getStringAttr("expected strictly positive tile size and divisor"));
}
+FailureOr<StaticContinuousTileSizeSpecification>
+mlir::linalg::computeStaticContinuousTileSizes(TilingInterface op,
+ unsigned dimension,
+ unsigned targetSize) {
+
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+
+ assert(!linalgOp.hasDynamicShape() &&
+ "cannot compute static multi-tile sizes for an op with dynamic shape");
+ assert(targetSize > 0 && "target size must be non-negative");
+ assert(dimension < linalgOp.getNumLoops() && "dimension overflow");
+
+ StaticContinuousTileSizeSpecification spec;
+ int64_t loopRange = linalgOp.getStaticLoopRanges()[dimension];
+ int64_t tripCount = loopRange / targetSize;
+
+ unsigned tileSize = targetSize;
+
+ spec.tileSizes.push_back(tileSize);
+ spec.tripCounts.push_back(tripCount);
+
+ int64_t remainderChunk = loopRange % targetSize;
+
+ while (tileSize > 1 && remainderChunk != 0) {
+
+ uint64_t maxPower = llvm::bit_floor(tileSize);
+ tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
+
+ tripCount = remainderChunk / tileSize;
+
+ if (tripCount > 0) {
+ spec.tileSizes.push_back(tileSize);
+ spec.tripCounts.push_back(tripCount);
+ }
+
+ remainderChunk = remainderChunk % tileSize;
+ }
+
+ auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
+ SmallVector<int64_t> tripCounts,
+ int64_t range) -> bool {
+ int64_t computedRange = 0;
+ for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
+ computedRange += tileSize * tripCount;
+ return range == computedRange;
+ };
+
+ if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
+ return failure();
+
+ return spec;
+}
+
+FailureOr<ContinuousTileSizeSpecification>
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+ unsigned dimension,
+ OpFoldResult targetSize,
+ bool emitAssertions) {
+
+ unsigned numLoops = op.getIterationDomain(builder).size();
+
+ // Bail out on dimension overflow.
+ if (dimension >= numLoops)
+ return failure();
+
+ // The code below works only on values.
+ Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(loc, builder);
+ if (emitAssertions) {
+ emitIsPositiveIndexAssertion(b, targetSize);
+ }
+ Value targetSizeValue =
+ getValueOrCreateConstantIndexOp(builder, loc, targetSize);
+
+ // Find the trip count of the iteration space dimension for which the tile
+ // sizes are computed.
+ SmallVector<Range> loopRanges = op.getIterationDomain(builder);
----------------
muneebkhan85 wrote:
Fixed.
https://github.com/llvm/llvm-project/pull/82792
More information about the Mlir-commits
mailing list