[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 30 07:08:29 PDT 2024


================
@@ -107,6 +107,138 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
       b.getStringAttr("expected strictly positive tile size and divisor"));
 }
 
+/// Statically computes a sequence of tile sizes, and how many tiles of each
+/// size to use, that together cover loop `dimension` of `op` exactly.
+///
+/// The first tile size is `targetSize`. While a remainder is left and the
+/// current tile size is larger than 1, the tile size is lowered to the next
+/// smaller power of two (halved if it already is a power of two) and as many
+/// tiles of that size as fit in the remainder are recorded; sizes that do not
+/// fit at all are skipped.
+///
+/// \param op         op to tile; must be a `LinalgOp` with a static shape.
+/// \param dimension  index of the loop to tile; must be less than the number
+///                   of loops of the op.
+/// \param targetSize size of the first (largest) tile; must be positive.
+/// \returns the tile sizes with their trip counts, or failure if `op` is not a
+///          `LinalgOp` or the computed tiles do not add up to the loop range.
+FailureOr<StaticContinuousTileSizeSpecification>
+mlir::linalg::computeStaticContinuousTileSizes(TilingInterface op,
+                                               unsigned dimension,
+                                               unsigned targetSize) {
+
+  // `dyn_cast` yields a null `LinalgOp` when `op` is not a linalg op; report
+  // failure rather than dereferencing it below.
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+  if (!linalgOp)
+    return failure();
+
+  assert(!linalgOp.hasDynamicShape() &&
+         "cannot compute static multi-tile sizes for an op with dynamic shape");
+  assert(targetSize > 0 && "target size must be positive");
+  assert(dimension < linalgOp.getNumLoops() && "dimension overflow");
+
+  StaticContinuousTileSizeSpecification spec;
+  int64_t loopRange = linalgOp.getStaticLoopRanges()[dimension];
+  int64_t tripCount = loopRange / targetSize;
+
+  unsigned tileSize = targetSize;
+
+  // The first entry is always recorded, even when its trip count is zero.
+  spec.tileSizes.push_back(tileSize);
+  spec.tripCounts.push_back(tripCount);
+
+  int64_t remainderChunk = loopRange % targetSize;
+
+  while (tileSize > 1 && remainderChunk != 0) {
+
+    // Step down to the largest power of two strictly below the current tile
+    // size: halve it if it is already a power of two, otherwise round down.
+    uint64_t maxPower = llvm::bit_floor(tileSize);
+    tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
+
+    tripCount = remainderChunk / tileSize;
+
+    // Only record tile sizes that are actually used.
+    if (tripCount > 0) {
+      spec.tileSizes.push_back(tileSize);
+      spec.tripCounts.push_back(tripCount);
+    }
+
+    remainderChunk = remainderChunk % tileSize;
+  }
+
+  // Sanity check: the recorded tiles must cover the loop range exactly.
+  int64_t computedRange = 0;
+  for (auto [size, count] : llvm::zip(spec.tileSizes, spec.tripCounts))
+    computedRange += size * count;
+
+  if (computedRange != loopRange)
+    return failure();
+
+  return spec;
+}
+
+FailureOr<ContinuousTileSizeSpecification>
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+                                         unsigned dimension,
+                                         OpFoldResult targetSize,
+                                         bool emitAssertions) {
+
+  unsigned numLoops = op.getIterationDomain(builder).size();
+
+  // Bail out on dimension overflow.
+  if (dimension >= numLoops)
+    return failure();
+
+  // The code below works only on values.
+  Location loc = op->getLoc();
+  ImplicitLocOpBuilder b(loc, builder);
+  if (emitAssertions) {
+    emitIsPositiveIndexAssertion(b, targetSize);
+  }
+  Value targetSizeValue =
+      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
+
+  // Find the trip count of the iteration space dimension for which the tile
+  // sizes are computed.
+  SmallVector<Range> loopRanges = op.getIterationDomain(builder);
+  Value loopRange = getValueOrCreateConstantIndexOp(b, op->getLoc(),
+                                                    loopRanges[dimension].size);
+
+  ContinuousTileSizeSpecification spec;
+
+  // Compute the tile sizes and the respective numbers of tiles.
+  AffineExpr s0 = b.getAffineSymbolExpr(0);
+  AffineExpr s1 = b.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
+  };
+
+  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
+  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
+
+  OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
+      b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
+
+  uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
----------------
muneebkhan85 wrote:

The `emitAssertions` check (enabled at the call site) asserts that `targetSize` is a positive integer. I have now added a comment about it here too.

https://github.com/llvm/llvm-project/pull/82792


More information about the Mlir-commits mailing list