[Mlir-commits] [mlir] [mlir][SCF] Deprecate `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes` (PR #91878)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 30 16:55:09 PDT 2024


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/91878

>From a6ba8036d29cafcb11bd6a5857f8fa59fd50a7dc Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Sat, 11 May 2024 13:38:36 -0700
Subject: [PATCH 1/7] [mlir][SCF] Allow tiling by specifying maximum number of
 tiles.

---
 .../Linalg/TransformOps/LinalgTransformOps.h  |   6 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  24 --
 .../SCF/Transforms/TileUsingInterface.h       |  35 ++-
 .../TransformOps/LinalgTransformOps.cpp       |  47 ++-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 182 ------------
 .../SCF/Transforms/TileUsingInterface.cpp     | 271 +++++++++++++-----
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  |   1 -
 .../TestTilingInterfaceTransformOps.cpp       |   6 +-
 8 files changed, 270 insertions(+), 302 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 3af642752724c..db25c9b241734 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -30,6 +30,10 @@ class GenericOp;
 class LinalgOp;
 } // namespace linalg
 
+namespace scf {
+struct SCFTilingResult;
+} // namespace scf
+
 namespace tensor {
 class InsertSliceOp;
 class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
                    ArrayRef<OpFoldResult> mixedNumThreads,
                    ArrayRef<OpFoldResult> mixedTileSizes,
                    std::optional<ArrayAttr> mapping,
-                   linalg::ForallTilingResult &tilingResult);
+                   scf::SCFTilingResult &tilingResult);
 
 } // namespace transform
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..e9b10de68bc44 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -846,30 +846,6 @@ FailureOr<StaticMultiSizeSpecification>
 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
-/// tiling by `numThreads`.
-/// If non-empty, the `mapping` is added as an attribute to the
-/// resulting `scf.forall`.
-/// Zero tile sizes indicate that the dimension is not tiled, and can be
-/// thought of as tiling by the full size of data. It is the user's
-/// responsibility to ensure that `numThreads` is a valid tiling specification
-/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
-struct ForallTilingResult {
-  Operation *tileOp;
-  Operation *tiledOp;
-};
-FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
-                                             TilingInterface op,
-                                             ArrayRef<OpFoldResult> numThreads,
-                                             std::optional<ArrayAttr> mapping);
-
-/// Same as `tileToForallOp`, but calculate the number of threads
-/// required using the given tileSizes.
-FailureOr<ForallTilingResult>
-tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
-                             ArrayRef<OpFoldResult> tileSizes,
-                             std::optional<ArrayAttr> mapping);
-
 /// Transformation information returned after reduction tiling.
 struct ForallReductionTilingResult {
   /// The partial reduction tiled op generated.
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 6d567171e185a..9c9e6dd6001a8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -31,9 +31,13 @@ using SCFTileSizeComputationFunction =
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
-  /// Computation function that returns the tile sizes for each operation.
-  /// Delayed construction of constant tile sizes should occur to interoperate
-  /// with folding.
+  /// Computation function that returns the tile sizes to use for each loop.
+  /// Returning a tile size of zero implies no tiling for that loop. If the
+  /// size of the returned vector is smaller than the number of loops, the inner
+  /// loops are not tiled. If the size of the returned vector is larger, then
+  /// the vector is truncated to number of loops.  Only one of
+  /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
+  /// be used.
   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
 
   SCFTilingOptions &
@@ -44,7 +48,25 @@ struct SCFTilingOptions {
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
-  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
+  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
+
+  /// Computation function that returns the maximum number of tiles to use for
+  /// each loop. Returning a value of zero implies no tiling for that loop.
+  /// If the size of the returned vector is smaller than the number of loops,
+  /// the inner loops are not tiled. If the size of the returned vector is
+  /// larger, then the vector is truncated to number of loops. Only one of
+  /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
+  /// be used.
+  SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
+
+  SCFTilingOptions &
+  setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
+    maxNumTilesComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Convenience function to set the `maxNumTilesComputationFunction` to a
+  /// function that returns the given maximum number of tiles.
+  SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
@@ -66,9 +88,8 @@ struct SCFTilingOptions {
   /// when using loop constructs that dont support such a mapping (like
   /// `scf.for`)
   SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
-    mappingVector = llvm::map_to_vector(
-        mapping, [](auto attr) -> Attribute { return attr; });
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
     return *this;
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..bce0d8a0f65db 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2919,7 +2919,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     TransformOpInterface transformOp, Operation *target,
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
-    linalg::ForallTilingResult &tilingResult) {
+    scf::SCFTilingResult &tilingResult) {
   // Transform all targets one by one.
   auto tileableOp = dyn_cast<TilingInterface>(target);
   if (!tileableOp) {
@@ -2930,18 +2930,38 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     return diag;
   }
   rewriter.setInsertionPoint(tileableOp);
-  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
   if (!mixedNumThreads.empty()) {
-    maybeTilingResult =
-        linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+    options.setMaxNumTiles(mixedNumThreads);
   } else {
-    maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
-        rewriter, tileableOp, mixedTileSizes, mapping);
+    SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
+    unsigned nLoops = loopRanges.size();
+    SmallVector<OpFoldResult> numThreads;
+    numThreads.reserve(nLoops);
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    AffineExpr divExpr = s0.ceilDiv(s1);
+    for (int i = 0, e = std::min(mixedTileSizes.size(), loopRanges.size());
+         i < e; ++i) {
+      OpFoldResult numTiles = mixedTileSizes[i];
+      if (!isConstantIntValue(numTiles, 0))
+        numTiles = affine::makeComposedFoldedAffineApply(
+            rewriter, tileableOp.getLoc(), divExpr,
+            {loopRanges[i].size, numTiles});
+      numThreads.push_back(numTiles);
+    }
+    options.setMaxNumTiles(numThreads);
+  }
+  if (mapping) {
+    options.setMapping(mapping.value().getValue());
   }
+  FailureOr<scf::SCFTilingResult> maybeTilingResult =
+      scf::tileUsingSCF(rewriter, tileableOp, options);
 
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
-  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
 
   tilingResult = *maybeTilingResult;
   return DiagnosedSilenceableFailure::success();
@@ -2977,14 +2997,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
     return status;
 
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    linalg::ForallTilingResult tilingResult;
+    scf::SCFTilingResult tilingResult;
     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
         getMapping(), tilingResult);
     if (!diag.succeeded())
       return diag;
-    tileOps.push_back(tilingResult.tileOp);
-    tiledOps.push_back(tilingResult.tiledOp);
+    tileOps.push_back(tilingResult.loops.front());
+    tiledOps.append(tilingResult.tiledOps);
   }
 
   transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3462,7 +3482,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
 
   // OpBuilder only used to compute attributes.
   OpBuilder b(getContext());
-  linalg::ForallTilingResult tilingResult;
+  scf::SCFTilingResult tilingResult;
   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
       /*rewriter=*/rewriter,
       /*state=*/state,
@@ -3475,8 +3495,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   if (!diag.succeeded())
     return diag;
 
-  results.push_back(tilingResult.tileOp);
-  results.push_back(tilingResult.tiledOp);
+  results.push_back(tilingResult.loops.front());
+  for (auto op : tilingResult.tiledOps)
+    results.push_back(op);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index fd314ef9f8134..5e9840dc551de 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
-/// Returns a vector of bools representing if, for each axis, `op` can be tiled
-/// without incurring in a race condition and thus it is thread-safe to do the
-/// tiling. This is checked by iterating over numThreads and ensuring that the
-/// corresponding iterator type is "parallel". If it is not, then we know that
-/// such dimension is unsafe to tile.
-SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
-                                     ArrayRef<OpFoldResult> numThreads) {
-  auto iterators = linalgOp.getIteratorTypesArray();
-  SmallVector<bool> safeToTile(numThreads.size(), true);
-
-  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
-    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
-      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
-        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-      }
-    } else {
-      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-    }
-  }
-  return safeToTile;
-}
-
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
-/// tiling is specified by the number of tiles/threads `numThreads` and the
-/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
-/// not specified, then  it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `mapping` is added as an
-/// attribute to the resulting `scf.forall`. A zero tile sizes indicate
-/// that the dimension is not tiled, and can be thought of as tiling by the full
-/// size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If the dimension is not parallelizable, a warning is issued to
-/// notify the user that the generated code is not safe to parallelize. If
-/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
-/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
-static FailureOr<ForallTilingResult> tileToForallOpImpl(
-    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
-    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-    std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
-  Location loc = op->getLoc();
-  OpBuilder::InsertionGuard g(b);
-
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  if (loopRanges.empty())
-    return op->emitOpError("expected non-empty loop ranges");
-  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
-  if (llvm::any_of(loopRanges, hasStrideOne))
-    return op->emitOpError("only stride-1 supported atm");
-
-  // Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-    return op->emitOpError("failed to get destination tensors");
-
-  SmallVector<OpFoldResult> nonZeroNumThreads =
-      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 0);
-      }));
-  SmallVector<Value> materializedNonZeroNumThreads =
-      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-        return getValueOrCreateConstantIndexOp(b, loc, ofr);
-      }));
-
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
-  if (linalgOp) {
-    // Check if tiling is thread safe and print a warning if not.
-    SmallVector<bool> tilingSafety =
-        safeToTileToForall(b.getContext(), linalgOp, numThreads);
-    for (size_t i = 0; i < tilingSafety.size(); i++)
-      if (!tilingSafety[i])
-        op.emitWarning() << "tiling is not thread safe at axis #" << i;
-  }
-
-  // 1. Create the ForallOp. We don't use the lambda body-builder
-  // version because we require the use of RewriterBase in the body, so we
-  // manually move the insertion point to the body below.
-  scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
-
-  // 2. Fill out the ForallOp body.
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
-                               omitTileOffsetBoundsCheck, nominalTileSizes,
-                               tiledOffsets, tiledSizes);
-
-  // 3. Clone the tileable op and update its destination operands to use the
-  // output bbArgs of the ForallOp.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
-  Operation *tiledOp = nullptr;
-  SmallVector<Value> tiledValues;
-  {
-    // 3.a. RAII guard, inserting within forallOp, before terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-    Operation *clonedOp = b.clone(*op.getOperation());
-    auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
-    if (destinationStyleOp) {
-      for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
-        // Swap tensor inits with the corresponding block argument of the
-        // scf.forall op. Memref inits remain as is.
-        if (isa<TensorType>(outOperand.get().getType())) {
-          auto *it = llvm::find(dest, outOperand.get());
-          assert(it != dest.end() && "could not find destination tensor");
-          unsigned destNum = std::distance(dest.begin(), it);
-          outOperand.set(destBbArgs[destNum]);
-        }
-      }
-    }
-
-    // 4. Tile the cloned op and delete the clone.
-    FailureOr<TilingResult> tilingResult =
-        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                               tiledSizes);
-    if (failed(tilingResult))
-      return clonedOp->emitError("Failed to tile op: ");
-    if (tilingResult->tiledOps.size() != 1) {
-      return clonedOp->emitError("expected a single produced tiled op, got ")
-             << tilingResult->tiledOps.size();
-    }
-
-    b.eraseOp(clonedOp);
-    tiledOp = tilingResult->tiledOps.front();
-    tiledValues = tilingResult->tiledValues;
-  }
-
-  // 5. Parallel insert back into the result tensor.
-  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tiledValues, destBbArgs)) {
-    // 5.a. Partial subset information is inserted just before the terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
-                                        tiledSizes, resultOffsets,
-                                        resultSizes)))
-      return op->emitOpError("output offsets couldn't be calculated");
-    SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
-
-    // 5.b. Parallel insertions are inserted at the end of the combining
-    // terminator.
-    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
-    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
-                                            std::get<2>(it), resultOffsets,
-                                            resultSizes, strides);
-  }
-  return ForallTilingResult{forallOp, tiledOp};
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
-                       ArrayRef<OpFoldResult> numThreads,
-                       std::optional<ArrayAttr> mapping) {
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/std::nullopt, mapping,
-                            /*omitTileOffsetBoundsCheck=*/false);
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
-                                     ArrayRef<OpFoldResult> tileSizes,
-                                     std::optional<ArrayAttr> mapping) {
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  unsigned nLoops = loopRanges.size();
-  SmallVector<OpFoldResult> numThreads;
-  numThreads.reserve(nLoops);
-  AffineExpr s0, s1;
-  bindSymbols(b.getContext(), s0, s1);
-  AffineExpr divExpr = s0.ceilDiv(s1);
-  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
-    OpFoldResult numTiles = std::get<0>(it);
-    if (!isConstantIntValue(numTiles, 0))
-      numTiles = makeComposedFoldedAffineApply(
-          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
-    numThreads.push_back(numTiles);
-  }
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/tileSizes, mapping,
-                            /*omitTileOffsetBoundsCheck=*/true);
-}
-
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a72dafe725177..47f74fb5e9cb2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -41,6 +41,16 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
   return *this;
 }
 
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setMaxNumTiles(ArrayRef<OpFoldResult> mnt) {
+  assert(!maxNumTilesComputationFunction && "max num tiles already set");
+  auto maxNumTiles = llvm::to_vector(mnt);
+  maxNumTilesComputationFunction = [maxNumTiles](OpBuilder &b, Operation *op) {
+    return maxNumTiles;
+  };
+  return *this;
+}
+
 /// Helper method to adjust the interchange vector to match the iteration
 /// domain.
 static SmallVector<int64_t>
@@ -60,6 +70,85 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 // tileUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
+/// Verify the tile size options are set in a consistent manner.
+static LogicalResult
+verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
+                      const scf::SCFTilingOptions &options) {
+  if (!options.tileSizeComputationFunction &&
+      !options.maxNumTilesComputationFunction) {
+    return rewriter.notifyMatchFailure(
+        loc, "at least one of tile size computation function or max num tiles "
+             "computation must be specified.");
+  }
+  if (options.tileSizeComputationFunction &&
+      options.maxNumTilesComputationFunction) {
+    return rewriter.notifyMatchFailure(
+        loc, "only one of tile size computation function or max num tiles "
+             "computation function can be specified");
+  }
+
+  // If specified, check that the interchange vector is a permutation.
+  if (!options.interchangeVector.empty()) {
+    if (!isPermutationVector(options.interchangeVector)) {
+      return rewriter.notifyMatchFailure(
+          loc, "invalid intechange vector, not a permutation of the entire "
+               "iteration space");
+    }
+  }
+  return success();
+}
+
+/// Compute the tile sizes and num tiles values. The `numTiles`
+/// is empty if the `maxNumTilesComputationFunction` is not specified.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileSizesAndNumTiles(RewriterBase &rewriter, TilingInterface op,
+                        ArrayRef<Range> iterationDomain,
+                        const scf::SCFTilingOptions &options) {
+  SmallVector<OpFoldResult> tileSizes, numTiles;
+
+  // Enforce the convention that "tiling by zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  auto numLoops = iterationDomain.size();
+  if (options.tileSizeComputationFunction) {
+    tileSizes = options.tileSizeComputationFunction(rewriter, op);
+    tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
+    return {tileSizes, numTiles};
+  }
+
+  assert(options.maxNumTilesComputationFunction &&
+         "expected at least one of tile sizes cpomputation function or max num "
+         "tiles computation function");
+  // Enforce the convention that "maxNumTiles to zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  SmallVector<OpFoldResult> maxNumTiles =
+      options.maxNumTilesComputationFunction(rewriter, op);
+  maxNumTiles.resize(numLoops, rewriter.getIndexAttr(0));
+
+  // Use the maxNumTiles to compute the tile sizes as
+  // - niters = ceilDiv(ub - lb, step)
+  // - tileSize = ceilDiv(niters, maxNumTiles)
+  AffineExpr s0, s1, s2, s3;
+  bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+  AffineExpr numIters = (s1 - s0).ceilDiv(s2);
+  AffineExpr tileSizeExpr = numIters.ceilDiv(s3);
+  tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
+  for (auto [index, maxNumTile] : llvm::enumerate(maxNumTiles)) {
+    if (isConstantIntValue(maxNumTile, 0))
+      continue;
+
+    tileSizes[index] = affine::makeComposedFoldedAffineApply(
+        rewriter, op.getLoc(), tileSizeExpr,
+        {iterationDomain[index].offset, iterationDomain[index].size,
+         iterationDomain[index].stride, maxNumTile});
+  }
+
+  return {tileSizes, maxNumTiles};
+}
+
 // Check if `stride` evenly divides the trip count `size - offset`.
 static bool tileDividesIterationDomain(Range loopRange) {
   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
@@ -99,6 +188,46 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
       b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
 }
 
+/// Compute the tile offsets and sizes.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
+                      ArrayRef<Range> iterationDomain,
+                      ArrayRef<OpFoldResult> tileSizes, bool isLoopNormalized) {
+  SmallVector<OpFoldResult> offsets, sizes;
+  int materializedLoopNum = 0;
+
+  AffineExpr d0, s0, s1, s2;
+  AffineExpr offsetExpr;
+  if (isLoopNormalized) {
+    bindDims(rewriter.getContext(), d0);
+    bindSymbols(rewriter.getContext(), s0, s1, s2);
+    offsetExpr = s0 + d0 * s1 * s2;
+  }
+
+  for (auto [tileSize, loopRange] :
+       llvm::zip_equal(tileSizes, iterationDomain)) {
+    if (isConstantIntValue(tileSize, 0)) {
+      offsets.push_back(loopRange.offset);
+      sizes.push_back(loopRange.size);
+      continue;
+    }
+    // If loop is normalized, the offset is (lb + iv * step * tileSize)
+    Value iv = ivs[materializedLoopNum++];
+    OpFoldResult offset;
+    if (isLoopNormalized) {
+      offset = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, offsetExpr,
+          ArrayRef<OpFoldResult>{iv, loopRange.offset, loopRange.stride,
+                                 tileSize});
+    } else {
+      offset = getAsOpFoldResult(iv);
+    }
+    offsets.push_back(offset);
+    sizes.push_back(getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+  }
+  return {offsets, sizes};
+}
+
 /// A function that allows returning additional yielded values during
 /// `yieldTiledValuesAndReplace`.
 /// - `ivs` induction variable for the loop.
@@ -144,8 +273,8 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
 ///    populated.
 static LogicalResult generateLoopNestUsingForOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
-    YieldTiledValuesFn yieldTiledValuesFn,
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+    ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn,
     SmallVector<LoopLikeOpInterface> &loops) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
@@ -153,15 +282,30 @@ static LogicalResult generateLoopNestUsingForOp(
   OpBuilder::InsertionGuard guard(rewriter);
   SmallVector<Value> ivs;
 
-  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+  Value zero, one;
+  if (!numTiles.empty()) {
+    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    ;
+    one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  }
+
+  for (auto [index, loopRange, tileSize] :
+       llvm::enumerate(loopRanges, tileSizes)) {
     // No loops if tile size is zero. Set offset and size to the loop
     // offset and size.
     if (isConstantIntValue(tileSize, 0))
       continue;
 
-    Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
-    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
-    Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+    Value lb, ub, step;
+    if (numTiles.empty()) {
+      lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+      ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+      step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+    } else {
+      lb = zero;
+      ub = getValueOrCreateConstantIndexOp(rewriter, loc, numTiles[index]);
+      step = one;
+    }
     auto loop =
         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
@@ -223,32 +367,45 @@ static LogicalResult generateLoopNestUsingForOp(
 ///    populated.
 static LogicalResult generateLoopNestUsingForallOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
-    ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
-    SmallVector<LoopLikeOpInterface> &loops) {
-  SmallVector<OpFoldResult> lbs, ubs, steps;
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+    ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
+    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
+  assert((numTiles.empty() || numTiles.size() == loopRanges.size()) &&
+         "expected max number of tiles to be either empty or equal to number "
+         "of loops");
   OpBuilder::InsertionGuard guard(rewriter);
   SmallVector<OpFoldResult> offsets(loopRanges.size()),
       sizes(loopRanges.size());
 
-  for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
-    if (isConstantIntValue(tileSize, 0))
-      continue;
-    lbs.push_back(loopRange.offset);
-    ubs.push_back(loopRange.size);
-    steps.push_back(tileSize);
-  }
-  assert(!lbs.empty() && "Expected at least one loop range");
-
   std::optional<ArrayAttr> mappingAttr;
   if (!mappingVector.empty())
     mappingAttr = rewriter.getArrayAttr(mappingVector);
 
-  auto forallOp = rewriter.create<scf::ForallOp>(
-      loc, lbs, ubs, steps, destinationTensors, mappingAttr);
+  scf::ForallOp forallOp;
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  if (numTiles.empty()) {
+    for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
+      if (isConstantIntValue(tileSize, 0))
+        continue;
+      lbs.push_back(loopRange.offset);
+      ubs.push_back(loopRange.size);
+      steps.push_back(tileSize);
+    }
+    assert(!lbs.empty() && "Expected at least one loop range");
+    forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
+                                              destinationTensors, mappingAttr);
+  } else {
+    SmallVector<OpFoldResult> numThreads;
+    for (auto maxNumTile : numTiles) {
+      if (!isConstantIntValue(maxNumTile, 0))
+        numThreads.push_back(maxNumTile);
+    }
+    forallOp = rewriter.create<scf::ForallOp>(loc, numThreads,
+                                              destinationTensors, mappingAttr);
+  }
   loops.push_back(forallOp);
 
   rewriter.setInsertionPoint(forallOp.getTerminator());
@@ -285,13 +442,11 @@ static LogicalResult generateLoopNestUsingForallOp(
 ///    loop.
 /// - `loops` is an in-out parameter into which the generated loops are
 ///    populated.
-static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
-                                      const scf::SCFTilingOptions &options,
-                                      ArrayRef<Range> loopRanges,
-                                      ArrayRef<OpFoldResult> tileSizes,
-                                      ValueRange destinationTensors,
-                                      YieldTiledValuesFn tiledBodyFn,
-                                      SmallVector<LoopLikeOpInterface> &loops) {
+static LogicalResult generateLoopNest(
+    RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
+    ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<OpFoldResult> numTiles, ValueRange destinationTensors,
+    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   // If the tile sizes are all zero, no loops are generated. Just call the
   // callback function to handle untiled case.
   if (llvm::all_of(tileSizes, isZeroIndex)) {
@@ -302,11 +457,12 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
   }
   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
-                                      destinationTensors, tiledBodyFn, loops);
+                                      numTiles, destinationTensors, tiledBodyFn,
+                                      loops);
   }
   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
     return generateLoopNestUsingForallOp(
-        rewriter, loc, loopRanges, tileSizes, options.mappingVector,
+        rewriter, loc, loopRanges, tileSizes, numTiles, options.mappingVector,
         destinationTensors, tiledBodyFn, loops);
   }
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
@@ -530,28 +686,20 @@ static LogicalResult addInitOperandsToLoopNest(
 FailureOr<scf::SCFTilingResult>
 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                         const scf::SCFTilingOptions &options) {
+  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+    return failure();
+  }
+
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfter(op);
 
-  if (!options.tileSizeComputationFunction) {
-    return rewriter.notifyMatchFailure(
-        op, "missing tile size computation function");
-  }
-
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
-  size_t numLoops = iterationDomain.size();
 
-  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
-  // skips tiling a particular dimension. This convention is significantly
-  // simpler to handle instead of adjusting affine maps to account for missing
-  // dimensions.
-  SmallVector<OpFoldResult> tileSizes =
-      options.tileSizeComputationFunction(rewriter, op);
-  if (tileSizes.size() < iterationDomain.size()) {
-    auto zero = rewriter.getIndexAttr(0);
-    tileSizes.append(numLoops - tileSizes.size(), zero);
-  }
+  // 2. Materialize the tile sizes or max num tiles;
+  SmallVector<OpFoldResult> tileSizes, numTiles;
+  std::tie(tileSizes, numTiles) =
+      getTileSizesAndNumTiles(rewriter, op, iterationDomain, options);
 
   // 3. If there is an interchange specified, permute the iteration domain and
   // the tile sizes.
@@ -559,16 +707,13 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   if (!options.interchangeVector.empty()) {
     interchangeVector = fillInterchangeVector(options.interchangeVector,
                                               iterationDomain.size());
-  }
-  if (!interchangeVector.empty()) {
-    if (!isPermutationVector(interchangeVector)) {
-      return rewriter.notifyMatchFailure(
-          op, "invalid intechange vector, not a permutation of the entire "
-              "iteration space");
-    }
+    assert(isPermutationVector(interchangeVector) &&
+           "expected interchange vector to be a permutation");
 
     applyPermutationToVector(iterationDomain, interchangeVector);
     applyPermutationToVector(tileSizes, interchangeVector);
+    if (!numTiles.empty())
+      applyPermutationToVector(numTiles, interchangeVector);
   }
 
   FailureOr<TilingResult> tilingResult;
@@ -582,21 +727,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
       -> LogicalResult {
     // 4a. Compute the `offsets` and `sizes` to use for tiling.
     SmallVector<OpFoldResult> offsets, sizes;
-    {
-      int materializedLoopNum = 0;
-      for (auto [tileSize, loopRange] :
-           llvm::zip_equal(tileSizes, iterationDomain)) {
-        if (isConstantIntValue(tileSize, 0)) {
-          offsets.push_back(loopRange.offset);
-          sizes.push_back(loopRange.size);
-          continue;
-        }
-        Value iv = ivs[materializedLoopNum++];
-        offsets.push_back(iv);
-        sizes.push_back(
-            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
-      }
-    }
+    std::tie(offsets, sizes) = getTileOffsetAndSizes(
+        rewriter, loc, ivs, iterationDomain, tileSizes, !numTiles.empty());
 
     // 4b. If interchange was provided, apply inverse of the interchange
     //     to get back the offsets/sizes in the order to be specified.
@@ -664,7 +796,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
-                              tileSizes, destinationTensors,
+                              tileSizes, numTiles, destinationTensors,
                               innerYieldTiledValuesFn, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
@@ -773,6 +905,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+                              /*numTiles=*/ArrayRef<OpFoldResult>{},
                               initTensors, innerYieldTiledValuesFn, loops)))
     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
 
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 8545dfd25eccf..f33739f119eaf 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -177,7 +177,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-
 // -----
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 335db1a61f476..d4126f04a2f35 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -182,11 +182,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
     scf::SCFTilingOptions tilingOptions;
     tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
     if (mapping) {
-      auto mappingAttrs =
-          llvm::map_to_vector(mapping.value(), [](Attribute attr) {
-            return cast<DeviceMappingAttrInterface>(attr);
-          });
-      tilingOptions.setMapping(mappingAttrs);
+      tilingOptions.setMapping(mapping.value().getValue());
     }
     tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
 

>From 58296c951459061c958c5aa46037db0267535461 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Mon, 20 May 2024 23:17:56 -0700
Subject: [PATCH 2/7] Allow specifying both numThreads and tileSizes to keep
 the same existing semantics of distribution using number of threads.

---
 .../SCF/Transforms/TileUsingInterface.h       |  28 +-
 .../TransformOps/LinalgTransformOps.cpp       |   5 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 266 ++++++++++--------
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  |  52 ++--
 .../Dialect/Linalg/transform-op-tile.mlir     |  29 +-
 .../tile-pad-using-interface.mlir             |  10 +-
 .../TilingInterface/tile-using-interface.mlir |  50 ++--
 7 files changed, 234 insertions(+), 206 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 9c9e6dd6001a8..1a799066b20d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -35,9 +35,7 @@ struct SCFTilingOptions {
   /// Returning a tile size of zero implies no tiling for that loop. If the
   /// size of the returned vector is smaller than the number of loops, the inner
   /// loops are not tiled. If the size of the returned vector is larger, then
-  /// the vector is truncated to number of loops.  Only one of
-  /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
-  /// be used.
+  /// the vector is truncated to number of loops.
   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
 
   SCFTilingOptions &
@@ -50,23 +48,25 @@ struct SCFTilingOptions {
   /// proper interaction with folding.
   SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
 
-  /// Computation function that returns the maximum number of tile to use for
-  /// each loop. Returning a tile size of zero implies no tiling for that loop.
-  /// If the size of the returned vector is smaller than the number of loops,
-  /// the inner loops are not tiled. If the size of the returned vector is
-  /// larger, then the vector is truncated to number of loops. Only one of
-  /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
-  /// be used.
-  SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
+  /// Computation function that returns the number of threads to use for
+  /// each loop. Returning a num threads of zero implies no tiling for that
+  /// loop. If the size of the returned vector is smaller than the number of
+  /// loops, the inner loops are not tiled. If the size of the returned vector
+  /// is larger, then the vector is truncated to number of loops. Note: This
+  /// option is only supported with loopType set to `LoopType::ForallOp`. If the
+  /// tile size function is not specified while the num threads computation is,
+  /// then the tile size is determined automatically to map at most one tile per
+  /// thread.
+  SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
 
   SCFTilingOptions &
-  setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
-    maxNumTilesComputationFunction = std::move(fun);
+  setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
+    numThreadsComputationFunction = std::move(fun);
     return *this;
   }
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed.
-  SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
+  SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bce0d8a0f65db..8bf7db2e15061 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2933,7 +2933,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
   if (!mixedNumThreads.empty()) {
-    options.setMaxNumTiles(mixedNumThreads);
+    options.setNumThreads(mixedNumThreads);
   } else {
     SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
     unsigned nLoops = loopRanges.size();
@@ -2951,7 +2951,8 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
             {loopRanges[i].size, numTiles});
       numThreads.push_back(numTiles);
     }
-    options.setMaxNumTiles(numThreads);
+    options.setNumThreads(numThreads);
+    options.setTileSizes(mixedTileSizes);
   }
   if (mapping) {
     options.setMapping(mapping.value().getValue());
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 47f74fb5e9cb2..c44f39fa40e4a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -42,11 +42,11 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
 }
 
 scf::SCFTilingOptions &
-scf::SCFTilingOptions::setMaxNumTiles(ArrayRef<OpFoldResult> mnt) {
-  assert(!maxNumTilesComputationFunction && "max num tiles already set");
-  auto maxNumTiles = llvm::to_vector(mnt);
-  maxNumTilesComputationFunction = [maxNumTiles](OpBuilder &b, Operation *op) {
-    return maxNumTiles;
+scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
+  assert(!numThreadsComputationFunction && "num tiles already set");
+  auto numThreads = llvm::to_vector(nt);
+  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
+    return numThreads;
   };
   return *this;
 }
@@ -74,17 +74,12 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 static LogicalResult
 verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
                       const scf::SCFTilingOptions &options) {
-  if (!options.tileSizeComputationFunction &&
-      !options.maxNumTilesComputationFunction) {
+  // Specifying the number of threads is only supported with `scf.forall`.
+  if (options.numThreadsComputationFunction &&
+      options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
     return rewriter.notifyMatchFailure(
-        loc, "at least one of tile size computation function or max num tiles "
-             "computation must be specified.");
-  }
-  if (options.tileSizeComputationFunction &&
-      options.maxNumTilesComputationFunction) {
-    return rewriter.notifyMatchFailure(
-        loc, "only one of tile size computation function or max num tiles "
-             "computation function can be specified");
+        loc, "number of tiles/threads can only by specified when loop type is "
+             "set to use `scf.forall`");
   }
 
   // If specified, check that the interchange vector is a permutation.
@@ -98,58 +93,94 @@ verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
   return success();
 }
 
-/// Compute the tile sizes and num tiles values. The `numTiles`
-/// is empty if the `maxNumTilesComputationFunction` is not specified.
+/// Compute the tile sizes and num threads values passed in.
 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
-getTileSizesAndNumTiles(RewriterBase &rewriter, TilingInterface op,
-                        ArrayRef<Range> iterationDomain,
-                        const scf::SCFTilingOptions &options) {
-  SmallVector<OpFoldResult> tileSizes, numTiles;
+getTileSizes(RewriterBase &rewriter, TilingInterface op,
+             ArrayRef<Range> iterationDomain,
+             const scf::SCFTilingOptions &options) {
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+  SmallVector<OpFoldResult> tileSizes, numThreads;
+  size_t numLoops = iterationDomain.size();
+
+  // Check whether the number of threads to use is specified.
+  if (options.numThreadsComputationFunction) {
+    numThreads = options.numThreadsComputationFunction(rewriter, op);
+    numThreads.resize(numLoops, zero);
+
+    // If the tile sizes are also specified, use them.
+    if (options.tileSizeComputationFunction) {
+      tileSizes = options.tileSizeComputationFunction(rewriter, op);
+    } else {
+      // Compute the tile sizes from the iteration domain and number
+      // of tiles as follows
+      // - niters = ceilDiv(ub - lb, step)
+      // - tileSize = ceilDiv(niters, numThreads)
+      AffineExpr s0, s1, s2, s3;
+      bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+      AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
+      AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
+      tileSizes.resize(numLoops, zero);
+      for (auto [index, range, nt] :
+           llvm::enumerate(iterationDomain, numThreads)) {
+        if (isConstantIntValue(nt, 0))
+          continue;
+
+        tileSizes[index] = affine::makeComposedFoldedAffineApply(
+            rewriter, op.getLoc(), tileSizeExpr,
+            {range.offset, range.size, range.stride, nt});
+      }
+    }
+    tileSizes.resize(numLoops, zero);
+    return {tileSizes, numThreads};
+  }
 
   // Enforce the convention that "tiling by zero"
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  auto numLoops = iterationDomain.size();
   if (options.tileSizeComputationFunction) {
     tileSizes = options.tileSizeComputationFunction(rewriter, op);
-    tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
-    return {tileSizes, numTiles};
   }
+  tileSizes.resize(numLoops, zero);
 
-  assert(options.maxNumTilesComputationFunction &&
-         "expected at least one of tile sizes cpomputation function or max num "
-         "tiles computation function");
-  // Enforce the convention that "maxNumTiles to zero"
-  // skips tiling a particular dimension. This convention is significantly
-  // simpler to handle instead of adjusting affine maps to account for missing
-  // dimensions.
-  SmallVector<OpFoldResult> maxNumTiles =
-      options.maxNumTilesComputationFunction(rewriter, op);
-  maxNumTiles.resize(numLoops, rewriter.getIndexAttr(0));
-
-  // Use the maxNumTiles to compute the tile sizes as
-  // - niters = ceilDiv(ub - lb, step)
-  // - tileSize = ceilDiv(niters, maxNumTiles)
-  AffineExpr s0, s1, s2, s3;
-  bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
-  AffineExpr numIters = (s1 - s0).ceilDiv(s2);
-  AffineExpr tileSizeExpr = numIters.ceilDiv(s3);
-  tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
-  for (auto [index, maxNumTile] : llvm::enumerate(maxNumTiles)) {
-    if (isConstantIntValue(maxNumTile, 0))
+  return {tileSizes, numThreads};
+}
+
+/// Checks if any of the tiled loops are not parallel.
+static void checkSafeToTileToForall(TilingInterface op,
+                                    ArrayRef<OpFoldResult> tileSizes,
+                                    ArrayRef<OpFoldResult> numThreads) {
+  auto iterators = op.getLoopIteratorTypes();
+  assert(iterators.size() == tileSizes.size() &&
+         "expected as many tile size values as number of loops");
+  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
+         "when specified, expected number of threads to use for each loop");
+
+  for (auto [index, iterator, tileSize] :
+       llvm::enumerate(iterators, tileSizes)) {
+    // If num threads is specified, check that it is greater than one only for
+    // parallel dimensions.
+    if (!numThreads.empty()) {
+      if (std::optional<int64_t> constNumThreads =
+              getConstantIntValue(numThreads[index])) {
+        if (constNumThreads.value() > 1 &&
+            iterator != utils::IteratorType::parallel) {
+          op.emitWarning() << "tiling is not thread safe at axis #" << index;
+        }
+      }
       continue;
+    }
 
-    tileSizes[index] = affine::makeComposedFoldedAffineApply(
-        rewriter, op.getLoc(), tileSizeExpr,
-        {iterationDomain[index].offset, iterationDomain[index].size,
-         iterationDomain[index].stride, maxNumTile});
+    if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
+      if (constTileSize.value() > 0 &&
+          iterator != utils::IteratorType::parallel) {
+        op.emitWarning() << "tiling is not thread safe at axis #" << index;
+      }
+    }
   }
-
-  return {tileSizes, maxNumTiles};
 }
 
-// Check if `stride` evenly divides the trip count `size - offset`.
+/// Check if `stride` evenly divides the trip count `size - offset`.
 static bool tileDividesIterationDomain(Range loopRange) {
   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
   if (!offsetAsInt)
@@ -163,10 +194,10 @@ static bool tileDividesIterationDomain(Range loopRange) {
   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
 }
 
-/// Returns the bounded tile size given the current `iv`, `loopRange` and
-/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
+/// Returns the bounded tile size given the current `offset`, `loopRange` and
+/// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
-                                       Range loopRange, Value iv,
+                                       Range loopRange, OpFoldResult offset,
                                        OpFoldResult tileSize) {
   std::optional<int64_t> ts = getConstantIntValue(tileSize);
   if (ts && ts.value() == 1)
@@ -185,7 +216,7 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
   return affine::makeComposedFoldedAffineMin(
-      b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
+      b, loc, minMap, SmallVector<OpFoldResult>{offset, tileSize, size});
 }
 
 /// Compute the tile offsets and sizes.
@@ -223,11 +254,29 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
       offset = getAsOpFoldResult(iv);
     }
     offsets.push_back(offset);
-    sizes.push_back(getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+    sizes.push_back(
+        getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize));
   }
   return {offsets, sizes};
 }
 
+/// Function to return the bounds of the loops to be generated.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+                  SmallVector<OpFoldResult>>
+getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+              ArrayRef<OpFoldResult> tileSizes) {
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+    // No loop if the tile size is 0.
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+  return {lbs, ubs, steps};
+}
+
 /// A function that allows returning additional yielded values during
 /// `yieldTiledValuesAndReplace`.
 /// - `ivs` induction variable for the loop.
@@ -273,39 +322,26 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
 ///    populated.
 static LogicalResult generateLoopNestUsingForOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
-    ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn,
+    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
+    YieldTiledValuesFn yieldTiledValuesFn,
     SmallVector<LoopLikeOpInterface> &loops) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
   OpBuilder::InsertionGuard guard(rewriter);
-  SmallVector<Value> ivs;
 
-  Value zero, one;
-  if (!numTiles.empty()) {
-    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    ;
-    one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-  }
-
-  for (auto [index, loopRange, tileSize] :
-       llvm::enumerate(loopRanges, tileSizes)) {
-    // No loops if tile size is zero. Set offset and size to the loop
-    // offset and size.
-    if (isConstantIntValue(tileSize, 0))
-      continue;
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  std::tie(lbs, ubs, steps) =
+      getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+  SmallVector<Value> lbVals =
+      getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
+  SmallVector<Value> ubVals =
+      getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
+  SmallVector<Value> stepVals =
+      getValueOrCreateConstantIndexOp(rewriter, loc, steps);
 
-    Value lb, ub, step;
-    if (numTiles.empty()) {
-      lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
-      ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
-      step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
-    } else {
-      lb = zero;
-      ub = getValueOrCreateConstantIndexOp(rewriter, loc, numTiles[index]);
-      step = one;
-    }
+  SmallVector<Value> ivs;
+  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
     auto loop =
         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
@@ -367,15 +403,12 @@ static LogicalResult generateLoopNestUsingForOp(
 ///    populated.
 static LogicalResult generateLoopNestUsingForallOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
     ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
-  assert((numTiles.empty() || numTiles.size() == loopRanges.size()) &&
-         "expected max number of tiles to be either empty or equal to number "
-         "of loops");
   OpBuilder::InsertionGuard guard(rewriter);
   SmallVector<OpFoldResult> offsets(loopRanges.size()),
       sizes(loopRanges.size());
@@ -385,25 +418,23 @@ static LogicalResult generateLoopNestUsingForallOp(
     mappingAttr = rewriter.getArrayAttr(mappingVector);
 
   scf::ForallOp forallOp;
-  SmallVector<OpFoldResult> lbs, ubs, steps;
-  if (numTiles.empty()) {
-    for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
-      if (isConstantIntValue(tileSize, 0))
+  bool useNumThreads = !numThreads.empty();
+
+  if (useNumThreads) {
+    // Prune the zero numthreads.
+    SmallVector<OpFoldResult> nonZeroNumThreads;
+    for (auto nt : numThreads) {
+      if (isConstantIntValue(nt, 0))
         continue;
-      lbs.push_back(loopRange.offset);
-      ubs.push_back(loopRange.size);
-      steps.push_back(tileSize);
+      nonZeroNumThreads.push_back(nt);
     }
-    assert(!lbs.empty() && "Expected at least one loop range");
-    forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
+    forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
                                               destinationTensors, mappingAttr);
   } else {
-    SmallVector<OpFoldResult> numThreads;
-    for (auto maxNumTile : numTiles) {
-      if (!isConstantIntValue(maxNumTile, 0))
-        numThreads.push_back(maxNumTile);
-    }
-    forallOp = rewriter.create<scf::ForallOp>(loc, numThreads,
+    SmallVector<OpFoldResult> lbs, ubs, steps;
+    std::tie(lbs, ubs, steps) =
+        getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+    forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
                                               destinationTensors, mappingAttr);
   }
   loops.push_back(forallOp);
@@ -445,7 +476,7 @@ static LogicalResult generateLoopNestUsingForallOp(
 static LogicalResult generateLoopNest(
     RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
     ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
-    ArrayRef<OpFoldResult> numTiles, ValueRange destinationTensors,
+    ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
     YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   // If the tile sizes are all zero, no loops are generated. Just call the
   // callback function to handle untiled case.
@@ -457,12 +488,11 @@ static LogicalResult generateLoopNest(
   }
   if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
     return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
-                                      numTiles, destinationTensors, tiledBodyFn,
-                                      loops);
+                                      destinationTensors, tiledBodyFn, loops);
   }
   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
     return generateLoopNestUsingForallOp(
-        rewriter, loc, loopRanges, tileSizes, numTiles, options.mappingVector,
+        rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
         destinationTensors, tiledBodyFn, loops);
   }
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
@@ -696,10 +726,16 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
 
-  // 2. Materialize the tile sizes or max num tiles;
-  SmallVector<OpFoldResult> tileSizes, numTiles;
-  std::tie(tileSizes, numTiles) =
-      getTileSizesAndNumTiles(rewriter, op, iterationDomain, options);
+  // 2. Materialize the tile sizes and/or number of threads;
+  SmallVector<OpFoldResult> tileSizes, numThreads;
+  std::tie(tileSizes, numThreads) =
+      getTileSizes(rewriter, op, iterationDomain, options);
+
+  // Check if it is safe to tile. This is a holdover from previous iterations
+  // of tile to for-all. Consider dropping it.
+  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
+    checkSafeToTileToForall(op, tileSizes, numThreads);
+  }
 
   // 3. If there is an interchange specified, permute the iteration domain and
   // the tile sizes.
@@ -712,8 +748,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
 
     applyPermutationToVector(iterationDomain, interchangeVector);
     applyPermutationToVector(tileSizes, interchangeVector);
-    if (!numTiles.empty())
-      applyPermutationToVector(numTiles, interchangeVector);
+    if (!numThreads.empty())
+      applyPermutationToVector(numThreads, interchangeVector);
   }
 
   FailureOr<TilingResult> tilingResult;
@@ -728,7 +764,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     // 4a. Compute the `offsets` and `sizes` to use for tiling.
     SmallVector<OpFoldResult> offsets, sizes;
     std::tie(offsets, sizes) = getTileOffsetAndSizes(
-        rewriter, loc, ivs, iterationDomain, tileSizes, !numTiles.empty());
+        rewriter, loc, ivs, iterationDomain, tileSizes, !numThreads.empty());
 
     // 4b. If interchange was provided, apply inverse of the interchange
     //     to get back the offsets/sizes in the order to be specified.
@@ -796,7 +832,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
-                              tileSizes, numTiles, destinationTensors,
+                              tileSizes, numThreads, destinationTensors,
                               innerYieldTiledValuesFn, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
@@ -905,7 +941,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
-                              /*numTiles=*/ArrayRef<OpFoldResult>{},
+                              /*numThreads=*/ArrayRef<OpFoldResult>{},
                               initTensors, innerYieldTiledValuesFn, loops)))
     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
 
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index f33739f119eaf..d1ed468fce323 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -3,9 +3,9 @@
 // Offset per thread:
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
 // Per thread tile size.
-// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 10, -(d0 * (s0 ceildiv 10)) + s0)>
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))>
-// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 20, -(d0 * (s0 ceildiv 20)) + s0)>
 
 module {
 // CHECK-LABEL: matmul(
@@ -96,7 +96,7 @@ module {
 // In this test case, matmul dims and tile size are dynamic.
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
 // CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
 
 // CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
@@ -140,7 +140,7 @@ module attributes {transform.with_named_sequence} {
 
 // Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
 
-// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (15, d0 * -15 + 300)>
 // CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
 // CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
@@ -176,30 +176,29 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
 // -----
 
-// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
-// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 * 20)>
 
-// CHECK-LABEL: matmul_tile_size_dynamic(
+//       CHECK: matmul_tile_size_dynamic(
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
 func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
   //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
   //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
-  //      CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
-  //      CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+  //      CHECK: %[[NT0:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+  //      CHECK: %[[NT1:.+]] = affine.apply #[[MAP1]]()[%[[N]]]
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
-  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
-  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
-  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
-  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+  //      CHECK:   %[[TS0:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS1:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[MAP4]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[MAP5]](%[[IV1]])
   //      CHECK:   tensor.extract_slice %[[A]]
   //      CHECK:   tensor.extract_slice %[[B]]
   //      CHECK:   tensor.extract_slice %[[C_BLK]]
@@ -219,26 +218,25 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
 // -----
 
 // Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
 
-// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 21)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 21)>
 
-// CHECK-LABEL: matmul_tile_size_static(
+//       CHECK: matmul_tile_size_static(
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor
 func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (10, 15) shared_outs(%[[C_BLK:.*]] = %[[C]])
-  //      CHECK:   %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
+  //      CHECK:   %[[TS:.+]] = affine.min #[[MAP0]](%[[IV1]])
   //  CHECK-NOT:   affine.max
   //  CHECK-NOT:   affine.min
-  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
-  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[MAP1]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[MAP2]](%[[IV1]])
   //      CHECK:   %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
   //      CHECK:   %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
   //      CHECK:   %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
@@ -298,7 +296,7 @@ module {
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 // CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
 // CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
 // CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
 // CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index d244670f73754..3467a539496b8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --transform-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics --cse %s | FileCheck %s
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -178,12 +178,11 @@ module {
 
 // CHECK-LABEL:   func.func @scalable_tile(
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>,
-// CHECK:           %[[C4:.*]] = arith.constant 0 : index
-// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C4]] : tensor<?xf32>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C0]] : tensor<?xf32>
 // CHECK:           %[[VEC_SIZE:.*]] = arith.constant 4 : index
 // CHECK:           %[[VS:.*]] = vector.vscale
 // CHECK:           %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor<?xf32>) {
 // CHECK:             %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%[[IV]])[%[[STEP]], %[[DIM]]]
 // CHECK:             %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
@@ -202,20 +201,14 @@ module {
 // -----
 
 // CHECK-LABEL:   func.func @scalable_and_fixed_length_tile
-// CHECK:           %[[C4:.*]] = arith.constant 4 : index
-// CHECK:           %[[VS:.*]] = vector.vscale
-// CHECK:           %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C128:.*]] = arith.constant 128 : index
-// CHECK:           %[[STEP_0:.*]] = arith.constant 4 : index
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
-// CHECK:             %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK:             %[[C128_1:.*]] = arith.constant 128 : index
-// CHECK:             %[[STEP_1:.*]] = arith.constant 4 : index
-// CHECK:             scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
-// CHECK:               %[[C0_2:.*]] = arith.constant 0 : index
-// CHECK:               %[[C128_2:.*]] = arith.constant 128 : index
-// CHECK:               scf.for %{{.*}} = %[[C0_2]] to %[[C128_2]] step %[[STEP_2]]
+//   CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:     %[[VS:.*]] = vector.vscale
+//   CHECK-DAG:     %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
+//   CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+//       CHECK:     scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[C4]]
+//       CHECK:       scf.for %[[VAL_16:.*]] = %[[C0]] to %[[C128]] step %[[C4]]
+//       CHECK:         scf.for %{{.*}} = %[[C0]] to %[[C128]] step %[[STEP_2]]
 
 func.func @scalable_and_fixed_length_tile(
   %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
diff --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
index 7d247aefcf6b1..ccf8e37c094f4 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
@@ -31,8 +31,8 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[DIM_IN1:.+]] = tensor.dim %[[IN]], %[[C1]]
 //   CHECK-DAG:   %[[DIM1:.+]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]]
 //   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //       CHECK:   %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[DIM0]] step %[[C2]]
-//       CHECK:     %[[C3:.+]] = arith.constant 3 : index
 //       CHECK:     scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
 //       CHECK:       %[[SWAP_RESULT:.*]] = scf.if
 //       CHECK:         tensor.generate
@@ -62,8 +62,8 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-//   CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
-//   CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 8)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 7)>
 //       CHECK: func @dynamic_2d_pad_tensor_inner_tiling(
 //  CHECK-SAME:     %[[IN:.*]]: tensor<?x?xf32>
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
@@ -107,9 +107,9 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
 //   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
 //       CHECK:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
-//   CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
-//   CHECK-DAG:     %[[C3:.*]] = arith.constant 3 : index
 //       CHECK:     scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
 //       CHECK:       %[[SWAP_RESULT:.*]] = scf.if
 //       CHECK:         tensor.generate
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 488a52e8e3e91..08be9737f4302 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -24,13 +24,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
 //  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
-//  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
 //  CHECK-DAG:       %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
@@ -77,14 +77,14 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
 //  CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
-//  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
-//  CHECK-DAG:       %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
 //  CHECK-DAG:         %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
 //  CHECK-DAG:         %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
@@ -130,15 +130,15 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
 // CHECK-LABEL: func.func @multi_result(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
-//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
-//   CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
 //   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty()
 //   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty()
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//   CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
+//   CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//   CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
 //       CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
 //  CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
-//   CHECK-DAG:     %[[C300:.+]] = arith.constant 300 : index
-//   CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //       CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
 //  CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
 //   CHECK-DAG:       %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
@@ -193,7 +193,6 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
-//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
 //  CHECK-DAG:   %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
 //  CHECK-DAG:   %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
@@ -201,12 +200,13 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
 //  CHECK-DAG:   %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
 //  CHECK-DAG:   %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[INIT]])
-//  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
-//  CHECK-DAG:       %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
 // CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
 //  CHECK-DAG:         %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
@@ -259,15 +259,15 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-//       CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
-// CHECK-LABEL: @indexed_semantics
-//       CHECK:   scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-//       CHECK:     scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-//       CHECK:       %[[INDEX0:.+]] = linalg.index 0
-//       CHECK:       %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
-//       CHECK:       %[[INDEX1:.+]] = linalg.index 1
-//       CHECK:       %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
-//       CHECK:       arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
+// CHECK: #[[MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK: @indexed_semantics
+// CHECK:   scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK:     scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK:       %[[INDEX0:.+]] = linalg.index 0
+// CHECK:       %[[INDEX0_AMENDED:.+]] = affine.apply #[[MAP_ADD]](%[[INDEX0]], %[[I0]])
+// CHECK:       %[[INDEX1:.+]] = linalg.index 1
+// CHECK:       %[[INDEX1_AMENDED:.+]] = affine.apply #[[MAP_ADD]](%[[INDEX1]], %[[I1]])
+// CHECK:       arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
 
 // -----
 
@@ -296,16 +296,16 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
 //  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
-//  CHECK-DAG:     %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:     %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
-//  CHECK-DAG:       %[[C10:.+]] = arith.constant 10 : index
 //      CHECK:       %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
 // CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
 //  CHECK-DAG:         %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]

>From 0b0005670c049840a44c99c7d5fe105bfb92bb5c Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 23 May 2024 12:00:55 -0700
Subject: [PATCH 3/7] Add logic to account for negative tile sizes.

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 102 +++++++++++++-----
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  |  49 +++++----
 2 files changed, 102 insertions(+), 49 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index c44f39fa40e4a..5b9bb713f19e0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -219,45 +219,93 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
       b, loc, minMap, SmallVector<OpFoldResult>{offset, tileSize, size});
 }
 
+/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
+/// less than `iterationSize`.
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+                                           OpFoldResult numThreads,
+                                           OpFoldResult iterationSize) {
+  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
+  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
+  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
+    return false;
+  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
+}
+
 /// Compute the tile offsets and sizes.
 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
 getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
                       ArrayRef<Range> iterationDomain,
-                      ArrayRef<OpFoldResult> tileSizes, bool isLoopNormalized) {
+                      ArrayRef<OpFoldResult> tileSizes,
+                      ArrayRef<OpFoldResult> numThreads) {
   SmallVector<OpFoldResult> offsets, sizes;
   int materializedLoopNum = 0;
 
-  AffineExpr d0, s0, s1, s2;
-  AffineExpr offsetExpr;
-  if (isLoopNormalized) {
-    bindDims(rewriter.getContext(), d0);
+  if (!numThreads.empty()) {
+    AffineExpr d0, d1, s0, s1, s2;
+    AffineExpr offsetExpr, residualTileSizeExpr;
+    bindDims(rewriter.getContext(), d0, d1);
     bindSymbols(rewriter.getContext(), s0, s1, s2);
-    offsetExpr = s0 + d0 * s1 * s2;
-  }
+    offsetExpr = d0 + d1 * s0 * s1;
+    residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1);
 
-  for (auto [tileSize, loopRange] :
-       llvm::zip_equal(tileSizes, iterationDomain)) {
-    if (isConstantIntValue(tileSize, 0)) {
-      offsets.push_back(loopRange.offset);
-      sizes.push_back(loopRange.size);
-      continue;
-    }
-    // If loop is normalized, the offset is (lb + iv * step * tileSize)
-    Value iv = ivs[materializedLoopNum++];
-    OpFoldResult offset;
-    if (isLoopNormalized) {
-      offset = affine::makeComposedFoldedAffineApply(
+    for (auto [nt, tileSize, loopRange] :
+         llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
+
+      if (isConstantIntValue(nt, 0) || isConstantIntValue(nt, 1)) {
+        offsets.push_back(loopRange.offset);
+        sizes.push_back(loopRange.size);
+        continue;
+      }
+
+      Value iv = ivs[materializedLoopNum++];
+      OpFoldResult offset = affine::makeComposedFoldedAffineApply(
           rewriter, loc, offsetExpr,
-          ArrayRef<OpFoldResult>{iv, loopRange.offset, loopRange.stride,
+          ArrayRef<OpFoldResult>{loopRange.offset, iv, loopRange.stride,
                                  tileSize});
-    } else {
-      offset = getAsOpFoldResult(iv);
+      OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, residualTileSizeExpr,
+          {loopRange.offset, nt, loopRange.stride, tileSize, loopRange.size});
+      OpFoldResult size = tileSize;
+      if (!isConstantIntValue(residualTileSize, 0)) {
+        OpFoldResult sizeMinusOffsetPerThread =
+            affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
+                                                  {offset, loopRange.size});
+        size = affine::makeComposedFoldedAffineMin(
+            rewriter, loc,
+            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
+            {sizeMinusOffsetPerThread, tileSize});
+      }
+      if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
+        AffineMap maxMap =
+            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+        size = affine::makeComposedFoldedAffineMax(
+            rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
+      }
+
+      offsets.push_back(offset);
+      sizes.push_back(size);
+    }
+    return {offsets, sizes};
+  } else {
+    for (auto [tileSize, loopRange] :
+         llvm::zip_equal(tileSizes, iterationDomain)) {
+
+      if (isConstantIntValue(tileSize, 0)) {
+        offsets.push_back(loopRange.offset);
+        sizes.push_back(loopRange.size);
+        continue;
+      }
+
+      Value iv = ivs[materializedLoopNum++];
+      OpFoldResult offset = getAsOpFoldResult(iv);
+      offsets.push_back(offset);
+      OpFoldResult size =
+          getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
+      sizes.push_back(size);
     }
-    offsets.push_back(offset);
-    sizes.push_back(
-        getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize));
+    return {offsets, sizes};
   }
-  return {offsets, sizes};
 }
 
 /// Function to return the bounds of the loops to be generated.
@@ -764,7 +812,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     // 4a. Compute the `offsets` and `sizes` to use for tiling.
     SmallVector<OpFoldResult> offsets, sizes;
     std::tie(offsets, sizes) = getTileOffsetAndSizes(
-        rewriter, loc, ivs, iterationDomain, tileSizes, !numThreads.empty());
+        rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
 
     // 4b. If interchange was provided, apply inverse of the interchange
     //     to get back the offsets/sizes in the order to be specified.
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index d1ed468fce323..c0ba5a8402d5f 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -3,9 +3,9 @@
 // Offset per thread:
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
 // Per thread tile size.
-// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 10, -(d0 * (s0 ceildiv 10)) + s0)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)>
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))>
-// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 20, -(d0 * (s0 ceildiv 20)) + s0)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)>
 
 module {
 // CHECK-LABEL: matmul(
@@ -96,7 +96,7 @@ module {
 // In this test case, matmul dims and tile size are dynamic.
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
 // CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
 
 // CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
@@ -140,7 +140,7 @@ module attributes {transform.with_named_sequence} {
 
 // Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
 
-// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (15, d0 * -15 + 300)>
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
 // CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
 // CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
@@ -176,6 +176,7 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
 // -----
 
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
@@ -296,7 +297,7 @@ module {
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 // CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
 // CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
 // CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
 // CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
@@ -339,7 +340,6 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 100, 15)>
-// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 15)>
 // CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0)>
 
@@ -352,8 +352,7 @@ module attributes {transform.with_named_sequence} {
                                          %OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
                                          -> (tensor<100xf32>, tensor<100xf32>) {
 //      CHECK: scf.forall (%[[IV0:.+]]) in (7) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
-//      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
-//      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
+//      CHECK:   %[[TS:.+]] = affine.min #[[$map0]](%[[IV0]])
 //  CHECK-NOT:   affine.min
 //  CHECK-NOT:   affine.max
 //      CHECK:   %[[LB:.+]] = affine.apply #[[$map2]](%[[IV0]])
@@ -453,9 +452,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
 // CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
 // CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
 
 // CHECK-LABEL: matmul_tile_size_dynamic(
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -470,10 +470,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
   //      CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
   //      CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
-  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
-  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
-  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
-  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  //      CHECK:   %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
+  //      CHECK:   %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
   //      CHECK:   tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
   //      CHECK:   tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
   //      CHECK:   tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
@@ -521,9 +523,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
 // CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
 // CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
 
 // CHECK-LABEL: matmul_tile_size_dynamic(
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -538,10 +541,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
   //      CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
   //      CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
-  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
-  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
-  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
-  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  //      CHECK:   %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
+  //      CHECK:   %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
   //      CHECK:   tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
   //      CHECK:   tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
   //      CHECK:   tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :

>From d62f3958b41f7d9e2a15e150c7d9f17c8edeb62e Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 23 May 2024 21:51:44 -0700
Subject: [PATCH 4/7] Put back CHECK-LABELs

---
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  | 42 +++++++++----------
 .../TilingInterface/tile-using-interface.mlir | 18 ++++----
 2 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index c0ba5a8402d5f..6e92deaf4cf0d 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -179,27 +179,27 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 * 20)>
-
-//       CHECK: matmul_tile_size_dynamic(
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
 func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
   //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
   //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
-  //      CHECK: %[[NT0:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
-  //      CHECK: %[[NT1:.+]] = affine.apply #[[MAP1]]()[%[[N]]]
+  //      CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+  //      CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
-  //      CHECK:   %[[TS0:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]]]
-  //      CHECK:   %[[TS1:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
-  //      CHECK:   %[[LB0:.+]] = affine.apply #[[MAP4]](%[[IV0]])
-  //      CHECK:   %[[LB1:.+]] = affine.apply #[[MAP5]](%[[IV1]])
+  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
   //      CHECK:   tensor.extract_slice %[[A]]
   //      CHECK:   tensor.extract_slice %[[B]]
   //      CHECK:   tensor.extract_slice %[[C_BLK]]
@@ -223,21 +223,21 @@ module attributes {transform.with_named_sequence} {
 
 // Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 21)>
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 21)>
 
-//       CHECK: matmul_tile_size_static(
+// CHECK-LABEL: matmul_tile_size_static(
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor
 //  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor
 //  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor
 func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (10, 15) shared_outs(%[[C_BLK:.*]] = %[[C]])
-  //      CHECK:   %[[TS:.+]] = affine.min #[[MAP0]](%[[IV1]])
+  //      CHECK:   %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
   //  CHECK-NOT:   affine.max
   //  CHECK-NOT:   affine.min
-  //      CHECK:   %[[LB0:.+]] = affine.apply #[[MAP1]](%[[IV0]])
-  //      CHECK:   %[[LB1:.+]] = affine.apply #[[MAP2]](%[[IV1]])
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
   //      CHECK:   %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
   //      CHECK:   %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
   //      CHECK:   %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 08be9737f4302..0a4d4c45f10be 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -259,15 +259,15 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-// CHECK: #[[MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
-// CHECK: @indexed_semantics
-// CHECK:   scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-// CHECK:     scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-// CHECK:       %[[INDEX0:.+]] = linalg.index 0
-// CHECK:       %[[INDEX0_AMENDED:.+]] = affine.apply #[[MAP_ADD]](%[[INDEX0]], %[[I0]])
-// CHECK:       %[[INDEX1:.+]] = linalg.index 1
-// CHECK:       %[[INDEX1_AMENDED:.+]] = affine.apply #[[MAP_ADD]](%[[INDEX1]], %[[I1]])
-// CHECK:       arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
+//       CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: @indexed_semantics
+//       CHECK:   scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+//       CHECK:     scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+//       CHECK:       %[[INDEX0:.+]] = linalg.index 0
+//       CHECK:       %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
+//       CHECK:       %[[INDEX1:.+]] = linalg.index 1
+//       CHECK:       %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
+//       CHECK:       arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
 
 // -----
 

>From 3de42124f4097319daf9771a41143987de1c1f98 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 23 May 2024 21:59:11 -0700
Subject: [PATCH 5/7] Address comments

---
 mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 4 ++--
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp        | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1a799066b20d1..175882cd54a63 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -64,8 +64,8 @@ struct SCFTilingOptions {
     numThreadsComputationFunction = std::move(fun);
     return *this;
   }
-  /// Convenience function to set the `tileSizeComputationFunction` to a
-  /// function that computes tile sizes at the point they are needed.
+  /// Convenience function to set the `numThreadsComputationFunction` to a
+  /// function that computes num threads at the point they are needed.
   SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
   /// The interchange vector to reorder the tiled loops.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 5b9bb713f19e0..1c7605bcfb0d6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -86,7 +86,7 @@ verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
   if (!options.interchangeVector.empty()) {
     if (!isPermutationVector(options.interchangeVector)) {
       return rewriter.notifyMatchFailure(
-          loc, "invalid intechange vector, not a permutation of the entire "
+          loc, "invalid interchange vector, not a permutation of the entire "
                "iteration space");
     }
   }

>From 715a509a0a7b5518968675d2f08dad92fe14af0e Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Sun, 26 May 2024 17:38:05 -0700
Subject: [PATCH 6/7] Next round of comments.

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 46 ++++++++++---------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1c7605bcfb0d6..0fe7f6296a654 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -74,11 +74,11 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 static LogicalResult
 verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
                       const scf::SCFTilingOptions &options) {
-  // Specifying number of tile is only supported on `scf.forall` op.
+  // Specifying number of threads is only supported on `scf.forall` op.
   if (options.numThreadsComputationFunction &&
       options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
     return rewriter.notifyMatchFailure(
-        loc, "number of tiles/threads can only by specified when loop type is "
+        loc, "number of threads can only by specified when loop type is "
              "set to use `scf.forall`");
   }
 
@@ -110,25 +110,27 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
     // If the number of tiles is also specified, use that.
     if (options.tileSizeComputationFunction) {
       tileSizes = options.tileSizeComputationFunction(rewriter, op);
-    } else {
-      // Compute the tile sizes from the iteration domain and number
-      // of tiles as follows
-      // - niters = ceilDiv(ub - lb, step)
-      // - tileSize = ceilDiv(niters, numThreads)
-      AffineExpr s0, s1, s2, s3;
-      bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
-      AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
-      AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
       tileSizes.resize(numLoops, zero);
-      for (auto [index, range, nt] :
-           llvm::enumerate(iterationDomain, numThreads)) {
-        if (isConstantIntValue(nt, 0))
-          continue;
+      return {tileSizes, numThreads};
+    }
 
-        tileSizes[index] = affine::makeComposedFoldedAffineApply(
-            rewriter, op.getLoc(), tileSizeExpr,
-            {range.offset, range.size, range.stride, nt});
-      }
+    // Compute the tile sizes from the iteration domain and number
+    // of tiles as follows
+    // - niters = ceilDiv(ub - lb, step)
+    // - tileSize = ceilDiv(niters, numThreads)
+    AffineExpr s0, s1, s2, s3;
+    bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+    AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
+    AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
+    tileSizes.resize(numLoops, zero);
+    for (auto [index, range, nt] :
+         llvm::enumerate(iterationDomain, numThreads)) {
+      if (isConstantIntValue(nt, 0))
+        continue;
+
+      tileSizes[index] = affine::makeComposedFoldedAffineApply(
+          rewriter, op.getLoc(), tileSizeExpr,
+          {range.offset, range.size, range.stride, nt});
     }
     tileSizes.resize(numLoops, zero);
     return {tileSizes, numThreads};
@@ -138,9 +140,9 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  if (options.tileSizeComputationFunction) {
-    tileSizes = options.tileSizeComputationFunction(rewriter, op);
-  }
+  assert(options.tileSizeComputationFunction &&
+         "expected tile sizes to be specified");
+  tileSizes = options.tileSizeComputationFunction(rewriter, op);
   tileSizes.resize(numLoops, zero);
 
   return {tileSizes, numThreads};

>From 4bfa8a32f3b22896d0572032cd3c42aa2f3935b6 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 30 May 2024 16:54:38 -0700
Subject: [PATCH 7/7] Drop support for non-unit strides, and assert that
 strides of iteration domain are 1.

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 42 +++++++++++++------
 1 file changed, 29 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0fe7f6296a654..aca27d85253f9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -98,6 +98,11 @@ static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
 getTileSizes(RewriterBase &rewriter, TilingInterface op,
              ArrayRef<Range> iterationDomain,
              const scf::SCFTilingOptions &options) {
+  assert(
+      llvm::all_of(iterationDomain,
+                   [](Range r) { return isConstantIntValue(r.stride, 1); }) &&
+      "tile size computation assumes that all dimensions of the iteration "
+      "domain have stride 1");
   OpFoldResult zero = rewriter.getIndexAttr(0);
   SmallVector<OpFoldResult> tileSizes, numThreads;
   size_t numLoops = iterationDomain.size();
@@ -118,10 +123,11 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
     // of tiles as follows
     // - niters = ceilDiv(ub - lb, step)
     // - tileSize = ceilDiv(niters, numThreads)
-    AffineExpr s0, s1, s2, s3;
-    bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
-    AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
-    AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
+    AffineExpr s0, s1, s2;
+    bindSymbols(rewriter.getContext(), s0, s1, s2);
+    // TODO: The step here is assumed to be 1.
+    AffineExpr numItersExpr = (s1 - s0);
+    AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
     tileSizes.resize(numLoops, zero);
     for (auto [index, range, nt] :
          llvm::enumerate(iterationDomain, numThreads)) {
@@ -129,8 +135,7 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
         continue;
 
       tileSizes[index] = affine::makeComposedFoldedAffineApply(
-          rewriter, op.getLoc(), tileSizeExpr,
-          {range.offset, range.size, range.stride, nt});
+          rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
     }
     tileSizes.resize(numLoops, zero);
     return {tileSizes, numThreads};
@@ -243,13 +248,19 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
   SmallVector<OpFoldResult> offsets, sizes;
   int materializedLoopNum = 0;
 
+  assert(
+      llvm::all_of(iterationDomain,
+                   [](Range r) { return isConstantIntValue(r.stride, 1); }) &&
+      "the offset and tile size computation assumes stride 1 for all "
+      "dimensions of the iteration domain");
+
   if (!numThreads.empty()) {
-    AffineExpr d0, d1, s0, s1, s2;
+    AffineExpr d0, d1, s0, s1;
     AffineExpr offsetExpr, residualTileSizeExpr;
     bindDims(rewriter.getContext(), d0, d1);
-    bindSymbols(rewriter.getContext(), s0, s1, s2);
-    offsetExpr = d0 + d1 * s0 * s1;
-    residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1);
+    bindSymbols(rewriter.getContext(), s0, s1);
+    offsetExpr = d0 + d1 * s0;
+    residualTileSizeExpr = s1 - (d0 + d1 * s0);
 
     for (auto [nt, tileSize, loopRange] :
          llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
@@ -263,11 +274,11 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
       Value iv = ivs[materializedLoopNum++];
       OpFoldResult offset = affine::makeComposedFoldedAffineApply(
           rewriter, loc, offsetExpr,
-          ArrayRef<OpFoldResult>{loopRange.offset, iv, loopRange.stride,
-                                 tileSize});
+          ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
       OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
           rewriter, loc, residualTileSizeExpr,
-          {loopRange.offset, nt, loopRange.stride, tileSize, loopRange.size});
+          {loopRange.offset, nt, tileSize, loopRange.size});
+
       OpFoldResult size = tileSize;
       if (!isConstantIntValue(residualTileSize, 0)) {
         OpFoldResult sizeMinusOffsetPerThread =
@@ -775,6 +786,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
 
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
+  if (llvm::any_of(iterationDomain,
+                   [](Range r) { return !isConstantIntValue(r.stride, 1); })) {
+    return rewriter.notifyMatchFailure(
+        op, "unhandled tiling of iteration domain with non-unit stride");
+  }
 
   // 2. Materialize the tile sizes and/or number of threads;
   SmallVector<OpFoldResult> tileSizes, numThreads;



More information about the Mlir-commits mailing list