[Mlir-commits] [mlir] [mlir][SCF] Allow tiling by specifying maximum number of tiles. (PR #91878)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat May 11 22:43:43 PDT 2024


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/91878

None

>From ff45ad2e0dc347f9e5cfff8eba65a2e7a886b6ef Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Sat, 11 May 2024 13:38:36 -0700
Subject: [PATCH] [mlir][SCF] Allow tiling by specifying maximum number of
 tiles.

---
 .../Linalg/TransformOps/LinalgTransformOps.h  |   6 +-
 .../SCF/Transforms/TileUsingInterface.h       |  35 ++-
 .../TransformOps/LinalgTransformOps.cpp       |  31 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 287 +++++++++++++-----
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  |   1 -
 .../TestTilingInterfaceTransformOps.cpp       |   6 +-
 6 files changed, 270 insertions(+), 96 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 3af642752724c..db25c9b241734 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -30,6 +30,10 @@ class GenericOp;
 class LinalgOp;
 } // namespace linalg
 
+namespace scf {
+struct SCFTilingResult;
+} // namespace scf
+
 namespace tensor {
 class InsertSliceOp;
 class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
                    ArrayRef<OpFoldResult> mixedNumThreads,
                    ArrayRef<OpFoldResult> mixedTileSizes,
                    std::optional<ArrayAttr> mapping,
-                   linalg::ForallTilingResult &tilingResult);
+                   scf::SCFTilingResult &tilingResult);
 
 } // namespace transform
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..c1775ea4818c7 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -31,9 +31,13 @@ using SCFTileSizeComputationFunction =
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
-  /// Computation function that returns the tile sizes for each operation.
-  /// Delayed construction of constant tile sizes should occur to interoperate
-  /// with folding.
+  /// Computation function that returns the tile sizes to use for each loop.
+  /// Returning a tile size of zero implies no tiling for that loop. If the
+  /// size of the returned vector is smaller than the number of loops, the inner
+  /// loops are not tiled. If the size of the returned vector is larger, then
+  /// the vector is truncated to the number of loops. Only one of
+  /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
+  /// be used.
   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
 
   SCFTilingOptions &
@@ -44,7 +48,25 @@ struct SCFTilingOptions {
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
-  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
+  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
+
+  /// Computation function that returns the maximum number of tiles to use for
+  /// each loop. Returning zero for a loop implies no tiling for that loop.
+  /// If the size of the returned vector is smaller than the number of loops,
+  /// the inner loops are not tiled. If the size of the returned vector is
+  /// larger, then the vector is truncated to the number of loops. Only one of
+  /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
+  /// be used.
+  SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
+
+  SCFTilingOptions &
+  setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
+    maxNumTilesComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Convenience function to set the `maxNumTilesComputationFunction` to a
+  /// function that returns the given maximum number of tiles.
+  SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
@@ -66,9 +88,8 @@ struct SCFTilingOptions {
   /// when using loop constructs that dont support such a mapping (like
   /// `scf.for`)
   SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
-    mappingVector = llvm::map_to_vector(
-        mapping, [](auto attr) -> Attribute { return attr; });
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
     return *this;
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..9fa463763068f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2917,7 +2917,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     TransformOpInterface transformOp, Operation *target,
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
-    linalg::ForallTilingResult &tilingResult) {
+    scf::SCFTilingResult &tilingResult) {
   // Transform all targets one by one.
   auto tileableOp = dyn_cast<TilingInterface>(target);
   if (!tileableOp) {
@@ -2928,18 +2928,22 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     return diag;
   }
   rewriter.setInsertionPoint(tileableOp);
-  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
   if (!mixedNumThreads.empty()) {
-    maybeTilingResult =
-        linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+    options.setMaxNumTiles(mixedNumThreads);
   } else {
-    maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
-        rewriter, tileableOp, mixedTileSizes, mapping);
+    options.setTileSizes(mixedTileSizes);
   }
+  if (mapping) {
+    options.setMapping(mapping.value().getValue());
+  }
+  FailureOr<scf::SCFTilingResult> maybeTilingResult =
+      scf::tileUsingSCF(rewriter, tileableOp, options);
 
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
-  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
 
   tilingResult = *maybeTilingResult;
   return DiagnosedSilenceableFailure::success();
@@ -2975,14 +2979,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
     return status;
 
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    linalg::ForallTilingResult tilingResult;
+    scf::SCFTilingResult tilingResult;
     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
         getMapping(), tilingResult);
     if (!diag.succeeded())
       return diag;
-    tileOps.push_back(tilingResult.tileOp);
-    tiledOps.push_back(tilingResult.tiledOp);
+    tileOps.push_back(tilingResult.loops.front());
+    tiledOps.append(tilingResult.tiledOps);
   }
 
   transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3460,7 +3464,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
 
   // OpBuilder only used to compute attributes.
   OpBuilder b(getContext());
-  linalg::ForallTilingResult tilingResult;
+  scf::SCFTilingResult tilingResult;
   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
       /*rewriter=*/rewriter,
       /*state=*/state,
@@ -3473,8 +3477,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   if (!diag.succeeded())
     return diag;
 
-  results.push_back(tilingResult.tileOp);
-  results.push_back(tilingResult.tiledOp);
+  results.push_back(tilingResult.loops.front());
+  for (auto op : tilingResult.tiledOps)
+    results.push_back(op);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69d..83bb8532a8152 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -41,6 +41,18 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
   return *this;
 }
 
+/// Sets `maxNumTilesComputationFunction` to a function that returns a copy of
+/// `numTiles`, independent of the operation being tiled.
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setMaxNumTiles(ArrayRef<OpFoldResult> numTiles) {
+  assert(!maxNumTilesComputationFunction && "max num tiles already set");
+  auto maxNumTiles = llvm::to_vector(numTiles);
+  maxNumTilesComputationFunction = [maxNumTiles](OpBuilder &b, Operation *op) {
+    return maxNumTiles;
+  };
+  return *this;
+}
+
 /// Helper method to adjust the interchange vector to match the iteration
 /// domain.
 static SmallVector<int64_t>
@@ -60,6 +70,101 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 // tileUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
+/// Verify the tile size options are set in a consistent manner.
+static LogicalResult
+verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
+                      const scf::SCFTilingOptions &options) {
+  if (!options.tileSizeComputationFunction &&
+      !options.maxNumTilesComputationFunction) {
+    return rewriter.notifyMatchFailure(
+        loc, "at least one of tile size computation function or max num tiles "
+             "computation must be specified.");
+  }
+  if (options.tileSizeComputationFunction &&
+      options.maxNumTilesComputationFunction) {
+    return rewriter.notifyMatchFailure(
+        loc, "only one of tile size computation function or max num tiles "
+             "computation function can be specified");
+  }
+
+  // If specified, check that the interchange vector is a permutation.
+  if (!options.interchangeVector.empty()) {
+    if (!isPermutationVector(options.interchangeVector)) {
+      return rewriter.notifyMatchFailure(
+          loc, "invalid interchange vector, not a permutation of the entire "
+               "iteration space");
+    }
+  }
+  return success();
+}
+
+/// Compute the tile sizes and num tiles values. The `numTiles`
+/// is empty if the `maxNumTilesComputationFunction` is not specified.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileSizesAndNumTiles(RewriterBase &rewriter, TilingInterface op,
+                        ArrayRef<Range> iterationDomain,
+                        const scf::SCFTilingOptions &options) {
+  SmallVector<OpFoldResult> tileSizes, numTiles;
+
+  // Enforce the convention that "tiling by zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  auto numLoops = iterationDomain.size();
+  if (options.tileSizeComputationFunction) {
+    tileSizes = options.tileSizeComputationFunction(rewriter, op);
+    tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
+    return {tileSizes, numTiles};
+  }
+
+  assert(options.maxNumTilesComputationFunction &&
+         "expected at least one of tile sizes computation function or max num "
+         "tiles computation function");
+  // Enforce the convention that "maxNumTiles to zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  SmallVector<OpFoldResult> maxNumTiles =
+      options.maxNumTilesComputationFunction(rewriter, op);
+  maxNumTiles.resize(numLoops, rewriter.getIndexAttr(0));
+
+  // Use the maxNumTiles to compute the tile sizes as
+  // - niters = ceilDiv(ub - lb, step)
+  // - tileSize = ceilDiv(niters, maxNumTiles)
+  AffineExpr s0, s1, s2, s3;
+  bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+  AffineExpr numIters = (s1 - s0).ceilDiv(s2);
+  AffineExpr tileSizeExpr = numIters.ceilDiv(s3);
+  tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
+  for (auto [index, maxNumTile] : llvm::enumerate(maxNumTiles)) {
+    if (isConstantIntValue(maxNumTile, 0))
+      continue;
+
+    tileSizes[index] = affine::makeComposedFoldedAffineApply(
+        rewriter, op.getLoc(), tileSizeExpr,
+        {iterationDomain[index].offset, iterationDomain[index].size,
+         iterationDomain[index].stride, maxNumTile});
+  }
+
+  // After computing the tile size, recompute the number of tiles as
+  // ceilDiv(niters, tileSize). E.g. for [lb, ub, step] = [0, 300, 1] and
+  // maxNumTiles = 21, tileSize = 15 and only 20 tiles are needed; using 21
+  // would create an empty trailing slice. Rounding up (not down) keeps the
+  // last partial tile: [0, 10, 1] with tileSize = 3 needs 4 tiles, not 3.
+  AffineExpr numTileExpr = numIters.ceilDiv(s3);
+  numTiles.resize(tileSizes.size(), rewriter.getIndexAttr(0));
+  for (auto [index, tileSize] : llvm::enumerate(tileSizes)) {
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    numTiles[index] = affine::makeComposedFoldedAffineApply(
+        rewriter, op.getLoc(), numTileExpr,
+        {iterationDomain[index].offset, iterationDomain[index].size,
+         iterationDomain[index].stride, tileSize});
+  }
+
+  return {tileSizes, numTiles};
+}
+
 // Check if `stride` evenly divides the trip count `size - offset`.
 static bool tileDividesIterationDomain(Range loopRange) {
   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
@@ -99,6 +204,57 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
 }
 
+/// Compute the tile offsets and sizes.
+/// - `ivs` holds one induction variable per loop with a non-zero tile
+///   size.
+/// - If `isLoopNormalized` is true, each induction variable is a tile index
+///   and the tile offset is computed as `lb + iv * step * tileSize`.
+///   Otherwise the induction variable is used as the tile offset directly.
+/// Loops with a zero tile size use the full range of the iteration domain.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
+                      ArrayRef<Range> iterationDomain,
+                      ArrayRef<OpFoldResult> tileSizes, bool isLoopNormalized) {
+  SmallVector<OpFoldResult> offsets, sizes;
+  int materializedLoopNum = 0;
+
+  AffineExpr d0, s0, s1, s2;
+  AffineExpr offsetExpr;
+  if (isLoopNormalized) {
+    bindDims(rewriter.getContext(), d0);
+    bindSymbols(rewriter.getContext(), s0, s1, s2);
+    offsetExpr = s0 + d0 * s1 * s2;
+  }
+
+  for (auto [tileSize, loopRange] :
+       llvm::zip_equal(tileSizes, iterationDomain)) {
+    if (isConstantIntValue(tileSize, 0)) {
+      offsets.push_back(loopRange.offset);
+      sizes.push_back(loopRange.size);
+      continue;
+    }
+    Value iv = ivs[materializedLoopNum++];
+    OpFoldResult offset;
+    // The tile size has to be clamped against the actual start of the tile.
+    // For normalized loops that is the computed offset, not the tile index.
+    Value tileStart = iv;
+    if (isLoopNormalized) {
+      // If loop is normalized, the offset is (lb + iv * step * tileSize).
+      offset = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, offsetExpr,
+          ArrayRef<OpFoldResult>{iv, loopRange.offset, loopRange.stride,
+                                 tileSize});
+      tileStart = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
+    } else {
+      offset = getAsOpFoldResult(iv);
+    }
+    offsets.push_back(offset);
+    sizes.push_back(
+        getBoundedTileSize(rewriter, loc, loopRange, tileStart, tileSize));
+  }
+  return {offsets, sizes};
+}
+
 /// A function that allows returning additional yielded values during
 /// `yieldTiledValuesAndReplace`.
 /// - `ivs` induction variable for the loop.
@@ -144,8 +289,8 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
 ///    populated.
 static LogicalResult generateLoopNestUsingForOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
-    YieldTiledValuesFn yieldTiledValuesFn,
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+    ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn,
     SmallVector<LoopLikeOpInterface> &loops) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
@@ -153,15 +298,35 @@ static LogicalResult generateLoopNestUsingForOp(
   OpBuilder::InsertionGuard guard(rewriter);
   SmallVector<Value> ivs;
 
-  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+  assert((numTiles.empty() || numTiles.size() == loopRanges.size()) &&
+         "expected number of tiles to be either empty or equal to number "
+         "of loops");
+
+  // When the number of tiles is provided, the loops are normalized to iterate
+  // over the tile indices `[0, numTiles)` with a unit step.
+  Value zero, one;
+  if (!numTiles.empty()) {
+    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  }
+
+  for (auto [index, loopRange, tileSize] :
+       llvm::enumerate(loopRanges, tileSizes)) {
     // No loops if tile size is zero. Set offset and size to the loop
     // offset and size.
     if (isConstantIntValue(tileSize, 0))
       continue;
 
-    Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
-    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
-    Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+    Value lb, ub, step;
+    if (numTiles.empty()) {
+      lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+      ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+      step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+    } else {
+      lb = zero;
+      ub = getValueOrCreateConstantIndexOp(rewriter, loc, numTiles[index]);
+      step = one;
+    }
     auto loop =
         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
@@ -220,32 +380,45 @@ static LogicalResult generateLoopNestUsingForOp(
 ///    populated.
 static LogicalResult generateLoopNestUsingForallOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
-    ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
-    SmallVector<LoopLikeOpInterface> &loops) {
-  SmallVector<OpFoldResult> lbs, ubs, steps;
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+    ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
+    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
+  assert((numTiles.empty() || numTiles.size() == loopRanges.size()) &&
+         "expected max number of tiles to be either empty or equal to number "
+         "of loops");
   OpBuilder::InsertionGuard guard(rewriter);
   SmallVector<OpFoldResult> offsets(loopRanges.size()),
       sizes(loopRanges.size());
 
-  for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
-    if (isConstantIntValue(tileSize, 0))
-      continue;
-    lbs.push_back(loopRange.offset);
-    ubs.push_back(loopRange.size);
-    steps.push_back(tileSize);
-  }
-  assert(!lbs.empty() && "Expected at least one loop range");
-
   std::optional<ArrayAttr> mappingAttr;
   if (!mappingVector.empty())
     mappingAttr = rewriter.getArrayAttr(mappingVector);
 
-  auto forallOp = rewriter.create<scf::ForallOp>(
-      loc, lbs, ubs, steps, destinationTensors, mappingAttr);
+  scf::ForallOp forallOp;
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  if (numTiles.empty()) {
+    for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
+      if (isConstantIntValue(tileSize, 0))
+        continue;
+      lbs.push_back(loopRange.offset);
+      ubs.push_back(loopRange.size);
+      steps.push_back(tileSize);
+    }
+    assert(!lbs.empty() && "Expected at least one loop range");
+    forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
+                                              destinationTensors, mappingAttr);
+  } else {
+    SmallVector<OpFoldResult> numThreads;
+    for (auto maxNumTile : numTiles) {
+      if (!isConstantIntValue(maxNumTile, 0))
+        numThreads.push_back(maxNumTile);
+    }
+    forallOp = rewriter.create<scf::ForallOp>(loc, numThreads,
+                                              destinationTensors, mappingAttr);
+  }
   loops.push_back(forallOp);
 
   rewriter.setInsertionPoint(forallOp.getTerminator());
@@ -282,13 +455,11 @@ static LogicalResult generateLoopNestUsingForallOp(
 ///    loop.
 /// - `loops` is an in-out parameter into which the generated loops are
 ///    populated.
-static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
-                                      const scf::SCFTilingOptions &options,
-                                      ArrayRef<Range> loopRanges,
-                                      ArrayRef<OpFoldResult> tileSizes,
-                                      ValueRange destinationTensors,
-                                      YieldTiledValuesFn tiledBodyFn,
-                                      SmallVector<LoopLikeOpInterface> &loops) {
+static LogicalResult generateLoopNest(
+    RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
+    ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<OpFoldResult> numTiles, ValueRange destinationTensors,
+    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   // If the tile sizes are all zero, no loops are generated. Just call the
   // callback function to handle untiled case.
   if (llvm::all_of(tileSizes, isZeroIndex)) {
@@ -299,11 +470,12 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
   }
   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
-                                      destinationTensors, tiledBodyFn, loops);
+                                      numTiles, destinationTensors, tiledBodyFn,
+                                      loops);
   }
   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
     return generateLoopNestUsingForallOp(
-        rewriter, loc, loopRanges, tileSizes, options.mappingVector,
+        rewriter, loc, loopRanges, tileSizes, numTiles, options.mappingVector,
         destinationTensors, tiledBodyFn, loops);
   }
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
@@ -527,28 +699,20 @@ static LogicalResult addInitOperandsToLoopNest(
 FailureOr<scf::SCFTilingResult>
 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                         const scf::SCFTilingOptions &options) {
+  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+    return failure();
+  }
+
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfter(op);
 
-  if (!options.tileSizeComputationFunction) {
-    return rewriter.notifyMatchFailure(
-        op, "missing tile size computation function");
-  }
-
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
-  size_t numLoops = iterationDomain.size();
 
-  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
-  // skips tiling a particular dimension. This convention is significantly
-  // simpler to handle instead of adjusting affine maps to account for missing
-  // dimensions.
-  SmallVector<OpFoldResult> tileSizes =
-      options.tileSizeComputationFunction(rewriter, op);
-  if (tileSizes.size() < iterationDomain.size()) {
-    auto zero = rewriter.getIndexAttr(0);
-    tileSizes.append(numLoops - tileSizes.size(), zero);
-  }
+  // 2. Materialize the tile sizes and, if requested, the number of tiles.
+  SmallVector<OpFoldResult> tileSizes, numTiles;
+  std::tie(tileSizes, numTiles) =
+      getTileSizesAndNumTiles(rewriter, op, iterationDomain, options);
 
   // 3. If there is an interchange specified, permute the iteration domain and
   // the tile sizes.
@@ -556,16 +720,19 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   if (!options.interchangeVector.empty()) {
     interchangeVector = fillInterchangeVector(options.interchangeVector,
                                               iterationDomain.size());
-  }
-  if (!interchangeVector.empty()) {
-    if (!isPermutationVector(interchangeVector)) {
-      return rewriter.notifyMatchFailure(
-          op, "invalid intechange vector, not a permutation of the entire "
-              "iteration space");
-    }
+    // `verifyTileSizeOptions` only validates the user-provided vector; the
+    // vector adjusted to the iteration domain size may still not be a
+    // permutation, so keep a runtime check rather than an assert.
+    if (!isPermutationVector(interchangeVector)) {
+      return rewriter.notifyMatchFailure(
+          op, "invalid interchange vector, not a permutation of the entire "
+              "iteration space");
+    }
 
     applyPermutationToVector(iterationDomain, interchangeVector);
     applyPermutationToVector(tileSizes, interchangeVector);
+    if (!numTiles.empty())
+      applyPermutationToVector(numTiles, interchangeVector);
   }
 
   FailureOr<TilingResult> tilingResult;
@@ -579,21 +740,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
       -> LogicalResult {
     // 4a. Compute the `offsets` and `sizes` to use for tiling.
     SmallVector<OpFoldResult> offsets, sizes;
-    {
-      int materializedLoopNum = 0;
-      for (auto [tileSize, loopRange] :
-           llvm::zip_equal(tileSizes, iterationDomain)) {
-        if (isConstantIntValue(tileSize, 0)) {
-          offsets.push_back(loopRange.offset);
-          sizes.push_back(loopRange.size);
-          continue;
-        }
-        Value iv = ivs[materializedLoopNum++];
-        offsets.push_back(iv);
-        sizes.push_back(
-            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
-      }
-    }
+    std::tie(offsets, sizes) = getTileOffsetAndSizes(
+        rewriter, loc, ivs, iterationDomain, tileSizes, !numTiles.empty());
 
     // 4b. If interchange was provided, apply inverse of the interchange
     //     to get back the offsets/sizes in the order to be specified.
@@ -661,7 +809,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
-                              tileSizes, destinationTensors,
+                              tileSizes, numTiles, destinationTensors,
                               innerYieldTiledValuesFn, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
@@ -774,6 +922,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+                              /*numTiles=*/ArrayRef<OpFoldResult>{},
                               destinationTensors, innerYieldTiledValuesFn,
                               loops)))
     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 8545dfd25eccf..f33739f119eaf 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -177,7 +177,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-
 // -----
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 335db1a61f476..d4126f04a2f35 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -182,11 +182,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
     scf::SCFTilingOptions tilingOptions;
     tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
     if (mapping) {
-      auto mappingAttrs =
-          llvm::map_to_vector(mapping.value(), [](Attribute attr) {
-            return cast<DeviceMappingAttrInterface>(attr);
-          });
-      tilingOptions.setMapping(mappingAttrs);
+      tilingOptions.setMapping(mapping.value().getValue());
     }
     tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
 



More information about the Mlir-commits mailing list