[Mlir-commits] [mlir] [mlir][Linalg] Deprecate `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes` (PR #91878)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 11 15:59:48 PDT 2024


https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/91878

>From b66c66809dfa2d5a7501762d5b9ea25192a36a2b Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Sat, 11 May 2024 13:38:36 -0700
Subject: [PATCH] [mlir][Linalg] Deprecate linalg::tileToForallOp and
 linalg::tileToForallOpUsingTileSizes.

The implementation of these methods is legacy, and they are removed in
favor of the `scf::tileUsingSCF` methods. To bring the latter on par
with the deprecated methods, `scf::SCFTilingOptions` now allows one to
specify the maximum number of tiles (i.e. the number of threads)
instead of the tile sizes. When tiling to `scf.forall`, this
specification is used to generate the `num_threads` version of the
operation.

One slight deviation from the previous implementation: the deprecated
method always generated the `num_threads` variant of the `scf.forall`
operation, whereas now the variant is driven by the tiling options
specified. This reduces the indexing math generated when tile sizes
are specified.
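The indexing-math difference can be illustrated with the standalone C++ sketch below. It uses plain integers instead of `OpFoldResult`s and affine ops, and the names `Tile`, `tilesFromTileSize` and `tilesFromNumThreads` are hypothetical. The sketch computes per-tile offsets and sizes for a single unit-stride loop `[0, n)`, not the exact affine expressions MLIR emits:

- With tile sizes, only the last tile can be partial, so one `min` suffices.
- With `num_threads`, each thread derives its offset from its id. In this sketch the size must also be clamped at zero, because trailing threads can start past the end of the loop.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// One tile of a 1-D loop: first iteration and number of iterations.
struct Tile {
  int64_t offset;
  int64_t size;
  bool operator==(const Tile &other) const {
    return offset == other.offset && size == other.size;
  }
};

// Ceiling division for a non-negative dividend and a positive divisor.
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(b > 0 && "divisor must be positive");
  return (a + b - 1) / b;
}

// Tile-size form, i.e. `scf.forall (%iv) = (0) to (n) step (tileSize)`:
// only the last tile can be partial, so a single min() is enough.
std::vector<Tile> tilesFromTileSize(int64_t n, int64_t tileSize) {
  assert(tileSize > 0 && "tile size must be positive");
  std::vector<Tile> tiles;
  for (int64_t offset = 0; offset < n; offset += tileSize)
    tiles.push_back({offset, std::min(tileSize, n - offset)});
  return tiles;
}

// num_threads form, i.e. `scf.forall (%tid) in (numThreads)`: every thread
// computes its own offset, and trailing threads may start at or past `n`,
// so the size is clamped to the range [0, tileSize].
std::vector<Tile> tilesFromNumThreads(int64_t n, int64_t numThreads) {
  int64_t tileSize = ceilDiv(n, numThreads);
  std::vector<Tile> tiles;
  for (int64_t tid = 0; tid < numThreads; ++tid) {
    int64_t offset = tid * tileSize;
    int64_t size = std::max<int64_t>(0, std::min(tileSize, n - offset));
    tiles.push_back({offset, size});
  }
  return tiles;
}
```

For example, `tilesFromNumThreads(7, 6)` uses a tile size of 2, and the last two threads receive empty tiles.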

**Moving from `linalg::tileToForallOp` to `scf::tileUsingSCF`**

```
RewriterBase &b = ...; // e.g. an IRRewriter or PatternRewriter
TilingInterface op;
ArrayRef<OpFoldResult> numThreads;
ArrayAttr mapping;
FailureOr<linalg::ForallTilingResult> result =
    linalg::tileToForallOp(b, op, numThreads, mapping);
```

can be replaced by

```
scf::SCFTilingOptions options;
options.setNumThreads(numThreads);
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
// Note: setMapping takes an ArrayRef<Attribute>, hence getValue().
options.setMapping(mapping.getValue());
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options);
```

This generates the `numThreads` version of the `scf.forall` for the
inter-tile loops, i.e.

```
... = scf.forall (%arg0, %arg1) in (%nt0, %nt1) shared_outs(...)
```

**Moving from `linalg::tileToForallOpUsingTileSizes` to `scf::tileUsingSCF`**

```
RewriterBase &b = ...; // e.g. an IRRewriter or PatternRewriter
TilingInterface op;
ArrayRef<OpFoldResult> tileSizes;
ArrayAttr mapping;
FailureOr<linalg::ForallTilingResult> result =
    linalg::tileToForallOpUsingTileSizes(b, op, tileSizes, mapping);
```

can be replaced by

```
scf::SCFTilingOptions options;
options.setTileSizes(tileSizes);
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
// Note: setMapping takes an ArrayRef<Attribute>, hence getValue().
options.setMapping(mapping.getValue());
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options);
```

Also note that `linalg::tileToForallOpUsingTileSizes` effectively
computed `numThreads` from the op and `tileSizes`, called
`linalg::tileToForallOp`, and thus generated the `numThreads` version
of the `scf.forall`. That is no longer the case. Instead, this directly
generates the tile-sizes version of the `scf.forall` op:

```
... = scf.forall(%arg0, %arg1) = (%lb0, %lb1) to (%ub0, %ub1) step(%step0, %step1) shared_outs(...)
```

If you want the `numThreads` version instead, it is up to the caller
to compute `numThreads` and call `options.setNumThreads` instead of
`options.setTileSizes`.
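For static shapes, that caller-side computation is a per-loop ceiling division. The sketch below is a hypothetical helper (`numThreadsFromTileSizes`) that works on plain integers rather than `OpFoldResult`s. The transform op in this patch does the equivalent with `affine::makeComposedFoldedAffineApply` and `s0.ceilDiv(s1)`, so that dynamic sizes are handled too. A zero tile size maps to zero threads, i.e. the loop stays untiled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Ceiling division for a non-negative dividend and a positive divisor.
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(b > 0 && "divisor must be positive");
  return (a + b - 1) / b;
}

// Hypothetical helper: number of threads per loop so that each thread
// handles one tile of the requested size. Loops without a tile size entry
// are left out, which the tiling options treat as "not tiled".
std::vector<int64_t> numThreadsFromTileSizes(
    const std::vector<int64_t> &loopSizes,
    const std::vector<int64_t> &tileSizes) {
  size_t n = std::min(loopSizes.size(), tileSizes.size());
  std::vector<int64_t> numThreads;
  numThreads.reserve(n);
  for (size_t i = 0; i < n; ++i)
    numThreads.push_back(
        tileSizes[i] == 0 ? 0 : ceilDiv(loopSizes[i], tileSizes[i]));
  return numThreads;
}
```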

**Changes to `transform.structured.tile_using_forall`**

The transform dialect op that called into `linalg::tileToForallOp` and
`linalg::tileToForallOpUsingTileSizes` has been modified to call
`scf::tileUsingSCF`. It always generates the `numThreads` version of
the `scf.forall` op: when `tile_sizes` are specified for the transform
dialect op, `numThreads` is computed from them. So there is no
functional change to `transform.structured.tile_using_forall`; it
still generates the `numThreads` version of the `scf.forall` op, as it
did before this change.
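A small standalone sketch shows why this preserves behavior. It uses plain integers, hypothetical names and a unit-stride loop `[0, n)`. Deriving `numThreads = ceilDiv(n, tileSize)` while keeping the user's tile size yields exactly the tiles that direct tiling by `tileSize` would produce. Moreover, `tileSize * (numThreads - 1) < n` always holds, so no thread starts out of bounds. This is the property `canOmitTileOffsetInBoundsCheck` tests for when all values are constants.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// One tile of a 1-D loop: first iteration and number of iterations.
struct Tile {
  int64_t offset;
  int64_t size;
  bool operator==(const Tile &other) const {
    return offset == other.offset && size == other.size;
  }
};

// Ceiling division for a non-negative dividend and a positive divisor.
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(b > 0 && "divisor must be positive");
  return (a + b - 1) / b;
}

// Tiles of `[0, n)` when tiling directly by `tileSize`.
std::vector<Tile> tileBySize(int64_t n, int64_t tileSize) {
  std::vector<Tile> tiles;
  for (int64_t offset = 0; offset < n; offset += tileSize)
    tiles.push_back({offset, std::min(tileSize, n - offset)});
  return tiles;
}

// Tiles when numThreads is derived from the tile size and both are passed
// to the tiling, as the transform op does: thread `tid` starts at
// `tid * tileSize`.
std::vector<Tile> tileByDerivedThreads(int64_t n, int64_t tileSize) {
  int64_t numThreads = ceilDiv(n, tileSize);
  // Every thread starts inside the loop, so no empty tiles are produced.
  assert(tileSize * (numThreads - 1) < n);
  std::vector<Tile> tiles;
  for (int64_t tid = 0; tid < numThreads; ++tid) {
    int64_t offset = tid * tileSize;
    tiles.push_back({offset, std::min(tileSize, n - offset)});
  }
  return tiles;
}
```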
---
 .../Linalg/TransformOps/LinalgTransformOps.h  |   6 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  24 --
 .../SCF/Transforms/TileUsingInterface.h       |  35 +-
 .../TransformOps/LinalgTransformOps.cpp       |  48 ++-
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 182 ---------
 .../SCF/Transforms/TileUsingInterface.cpp     | 386 ++++++++++++++----
 mlir/test/Dialect/Linalg/tile-to-forall.mlir  |  44 +-
 .../Dialect/Linalg/transform-op-tile.mlir     |  29 +-
 .../tile-pad-using-interface.mlir             |  10 +-
 .../TilingInterface/tile-using-interface.mlir |  32 +-
 .../TestTilingInterfaceTransformOps.cpp       |   6 +-
 11 files changed, 435 insertions(+), 367 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 3af642752724c..db25c9b241734 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -30,6 +30,10 @@ class GenericOp;
 class LinalgOp;
 } // namespace linalg
 
+namespace scf {
+struct SCFTilingResult;
+} // namespace scf
+
 namespace tensor {
 class InsertSliceOp;
 class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
                    ArrayRef<OpFoldResult> mixedNumThreads,
                    ArrayRef<OpFoldResult> mixedTileSizes,
                    std::optional<ArrayAttr> mapping,
-                   linalg::ForallTilingResult &tilingResult);
+                   scf::SCFTilingResult &tilingResult);
 
 } // namespace transform
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..e9b10de68bc44 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -846,30 +846,6 @@ FailureOr<StaticMultiSizeSpecification>
 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
-/// tiling by `numThreads`.
-/// If non-empty, the `mapping` is added as an attribute to the
-/// resulting `scf.forall`.
-/// Zero tile sizes indicate that the dimension is not tiled, and can be
-/// thought of as tiling by the full size of data. It is the user's
-/// responsibility to ensure that `numThreads` is a valid tiling specification
-/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
-struct ForallTilingResult {
-  Operation *tileOp;
-  Operation *tiledOp;
-};
-FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
-                                             TilingInterface op,
-                                             ArrayRef<OpFoldResult> numThreads,
-                                             std::optional<ArrayAttr> mapping);
-
-/// Same as `tileToForallOp`, but calculate the number of threads
-/// required using the given tileSizes.
-FailureOr<ForallTilingResult>
-tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
-                             ArrayRef<OpFoldResult> tileSizes,
-                             std::optional<ArrayAttr> mapping);
-
 /// Transformation information returned after reduction tiling.
 struct ForallReductionTilingResult {
   /// The partial reduction tiled op generated.
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..0f30b149e9a08 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -32,9 +32,11 @@ using SCFTileSizeComputationFunction =
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
-  /// Computation function that returns the tile sizes for each operation.
-  /// Delayed construction of constant tile sizes should occur to interoperate
-  /// with folding.
+  /// Computation function that returns the tile sizes to use for each loop.
+  /// Returning a tile size of zero implies no tiling for that loop. If the
+  /// size of the returned vector is smaller than the number of loops, the inner
+  /// loops are not tiled. If the size of the returned vector is larger, then
+  /// the vector is truncated to number of loops.
   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
 
   SCFTilingOptions &
@@ -45,7 +47,27 @@ struct SCFTilingOptions {
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
-  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
+  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
+
+  /// Computation function that returns the number of threads to use for
+  /// each loop. Returning a num threads of zero implies no tiling for that
+  /// loop. If the size of the returned vector is smaller than the number of
+  /// loops, the inner loops are not tiled. If the size of the returned vector
+  /// is larger, then the vector is truncated to number of loops. Note: This
+  /// option is only supported with loopType set to `LoopType::ForallOp`. If the
+  /// tile size function is not specified while the num threads computation is,
+  /// then the tile size is determined automatically to map at most one tile per
+  /// thread.
+  SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
+
+  SCFTilingOptions &
+  setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
+    numThreadsComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Convenience function to set the `numThreadsComputationFunction` to a
+  /// function that computes num threads at the point they are needed.
+  SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
@@ -67,9 +89,8 @@ struct SCFTilingOptions {
   /// when using loop constructs that dont support such a mapping (like
   /// `scf.for`)
   SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
-    mappingVector = llvm::map_to_vector(
-        mapping, [](auto attr) -> Attribute { return attr; });
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
     return *this;
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..8bf7db2e15061 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2919,7 +2919,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     TransformOpInterface transformOp, Operation *target,
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
-    linalg::ForallTilingResult &tilingResult) {
+    scf::SCFTilingResult &tilingResult) {
   // Transform all targets one by one.
   auto tileableOp = dyn_cast<TilingInterface>(target);
   if (!tileableOp) {
@@ -2930,18 +2930,39 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     return diag;
   }
   rewriter.setInsertionPoint(tileableOp);
-  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
   if (!mixedNumThreads.empty()) {
-    maybeTilingResult =
-        linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+    options.setNumThreads(mixedNumThreads);
   } else {
-    maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
-        rewriter, tileableOp, mixedTileSizes, mapping);
+    SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
+    unsigned nLoops = loopRanges.size();
+    SmallVector<OpFoldResult> numThreads;
+    numThreads.reserve(nLoops);
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    AffineExpr divExpr = s0.ceilDiv(s1);
+    for (int i = 0, e = std::min(mixedTileSizes.size(), loopRanges.size());
+         i < e; ++i) {
+      OpFoldResult numTiles = mixedTileSizes[i];
+      if (!isConstantIntValue(numTiles, 0))
+        numTiles = affine::makeComposedFoldedAffineApply(
+            rewriter, tileableOp.getLoc(), divExpr,
+            {loopRanges[i].size, numTiles});
+      numThreads.push_back(numTiles);
+    }
+    options.setNumThreads(numThreads);
+    options.setTileSizes(mixedTileSizes);
+  }
+  if (mapping) {
+    options.setMapping(mapping.value().getValue());
   }
+  FailureOr<scf::SCFTilingResult> maybeTilingResult =
+      scf::tileUsingSCF(rewriter, tileableOp, options);
 
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
-  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
 
   tilingResult = *maybeTilingResult;
   return DiagnosedSilenceableFailure::success();
@@ -2977,14 +2998,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
     return status;
 
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    linalg::ForallTilingResult tilingResult;
+    scf::SCFTilingResult tilingResult;
     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
         getMapping(), tilingResult);
     if (!diag.succeeded())
       return diag;
-    tileOps.push_back(tilingResult.tileOp);
-    tiledOps.push_back(tilingResult.tiledOp);
+    tileOps.push_back(tilingResult.loops.front());
+    tiledOps.append(tilingResult.tiledOps);
   }
 
   transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3462,7 +3483,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
 
   // OpBuilder only used to compute attributes.
   OpBuilder b(getContext());
-  linalg::ForallTilingResult tilingResult;
+  scf::SCFTilingResult tilingResult;
   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
       /*rewriter=*/rewriter,
       /*state=*/state,
@@ -3475,8 +3496,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   if (!diag.succeeded())
     return diag;
 
-  results.push_back(tilingResult.tileOp);
-  results.push_back(tilingResult.tiledOp);
+  results.push_back(tilingResult.loops.front());
+  for (auto op : tilingResult.tiledOps)
+    results.push_back(op);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index fd314ef9f8134..5e9840dc551de 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
-/// Returns a vector of bools representing if, for each axis, `op` can be tiled
-/// without incurring in a race condition and thus it is thread-safe to do the
-/// tiling. This is checked by iterating over numThreads and ensuring that the
-/// corresponding iterator type is "parallel". If it is not, then we know that
-/// such dimension is unsafe to tile.
-SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
-                                     ArrayRef<OpFoldResult> numThreads) {
-  auto iterators = linalgOp.getIteratorTypesArray();
-  SmallVector<bool> safeToTile(numThreads.size(), true);
-
-  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
-    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
-      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
-        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-      }
-    } else {
-      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-    }
-  }
-  return safeToTile;
-}
-
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
-/// tiling is specified by the number of tiles/threads `numThreads` and the
-/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
-/// not specified, then  it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `mapping` is added as an
-/// attribute to the resulting `scf.forall`. A zero tile sizes indicate
-/// that the dimension is not tiled, and can be thought of as tiling by the full
-/// size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If the dimension is not parallelizable, a warning is issued to
-/// notify the user that the generated code is not safe to parallelize. If
-/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
-/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
-static FailureOr<ForallTilingResult> tileToForallOpImpl(
-    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
-    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-    std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
-  Location loc = op->getLoc();
-  OpBuilder::InsertionGuard g(b);
-
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  if (loopRanges.empty())
-    return op->emitOpError("expected non-empty loop ranges");
-  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
-  if (llvm::any_of(loopRanges, hasStrideOne))
-    return op->emitOpError("only stride-1 supported atm");
-
-  // Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-    return op->emitOpError("failed to get destination tensors");
-
-  SmallVector<OpFoldResult> nonZeroNumThreads =
-      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 0);
-      }));
-  SmallVector<Value> materializedNonZeroNumThreads =
-      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-        return getValueOrCreateConstantIndexOp(b, loc, ofr);
-      }));
-
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
-  if (linalgOp) {
-    // Check if tiling is thread safe and print a warning if not.
-    SmallVector<bool> tilingSafety =
-        safeToTileToForall(b.getContext(), linalgOp, numThreads);
-    for (size_t i = 0; i < tilingSafety.size(); i++)
-      if (!tilingSafety[i])
-        op.emitWarning() << "tiling is not thread safe at axis #" << i;
-  }
-
-  // 1. Create the ForallOp. We don't use the lambda body-builder
-  // version because we require the use of RewriterBase in the body, so we
-  // manually move the insertion point to the body below.
-  scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
-
-  // 2. Fill out the ForallOp body.
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
-                               omitTileOffsetBoundsCheck, nominalTileSizes,
-                               tiledOffsets, tiledSizes);
-
-  // 3. Clone the tileable op and update its destination operands to use the
-  // output bbArgs of the ForallOp.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
-  Operation *tiledOp = nullptr;
-  SmallVector<Value> tiledValues;
-  {
-    // 3.a. RAII guard, inserting within forallOp, before terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-    Operation *clonedOp = b.clone(*op.getOperation());
-    auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
-    if (destinationStyleOp) {
-      for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
-        // Swap tensor inits with the corresponding block argument of the
-        // scf.forall op. Memref inits remain as is.
-        if (isa<TensorType>(outOperand.get().getType())) {
-          auto *it = llvm::find(dest, outOperand.get());
-          assert(it != dest.end() && "could not find destination tensor");
-          unsigned destNum = std::distance(dest.begin(), it);
-          outOperand.set(destBbArgs[destNum]);
-        }
-      }
-    }
-
-    // 4. Tile the cloned op and delete the clone.
-    FailureOr<TilingResult> tilingResult =
-        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                               tiledSizes);
-    if (failed(tilingResult))
-      return clonedOp->emitError("Failed to tile op: ");
-    if (tilingResult->tiledOps.size() != 1) {
-      return clonedOp->emitError("expected a single produced tiled op, got ")
-             << tilingResult->tiledOps.size();
-    }
-
-    b.eraseOp(clonedOp);
-    tiledOp = tilingResult->tiledOps.front();
-    tiledValues = tilingResult->tiledValues;
-  }
-
-  // 5. Parallel insert back into the result tensor.
-  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tiledValues, destBbArgs)) {
-    // 5.a. Partial subset information is inserted just before the terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
-                                        tiledSizes, resultOffsets,
-                                        resultSizes)))
-      return op->emitOpError("output offsets couldn't be calculated");
-    SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
-
-    // 5.b. Parallel insertions are inserted at the end of the combining
-    // terminator.
-    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
-    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
-                                            std::get<2>(it), resultOffsets,
-                                            resultSizes, strides);
-  }
-  return ForallTilingResult{forallOp, tiledOp};
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
-                       ArrayRef<OpFoldResult> numThreads,
-                       std::optional<ArrayAttr> mapping) {
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/std::nullopt, mapping,
-                            /*omitTileOffsetBoundsCheck=*/false);
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
-                                     ArrayRef<OpFoldResult> tileSizes,
-                                     std::optional<ArrayAttr> mapping) {
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  unsigned nLoops = loopRanges.size();
-  SmallVector<OpFoldResult> numThreads;
-  numThreads.reserve(nLoops);
-  AffineExpr s0, s1;
-  bindSymbols(b.getContext(), s0, s1);
-  AffineExpr divExpr = s0.ceilDiv(s1);
-  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
-    OpFoldResult numTiles = std::get<0>(it);
-    if (!isConstantIntValue(numTiles, 0))
-      numTiles = makeComposedFoldedAffineApply(
-          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
-    numThreads.push_back(numTiles);
-  }
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/tileSizes, mapping,
-                            /*omitTileOffsetBoundsCheck=*/true);
-}
-
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..8eb2ab59ac81d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -42,6 +42,16 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
   return *this;
 }
 
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
+  assert(!numThreadsComputationFunction && "num tiles already set");
+  auto numThreads = llvm::to_vector(nt);
+  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
+    return numThreads;
+  };
+  return *this;
+}
+
 /// Helper method to adjust the interchange vector to match the iteration
 /// domain.
 static SmallVector<int64_t>
@@ -61,7 +71,120 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 // tileUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
-// Check if `stride` evenly divides the trip count `size - offset`.
+/// Verify the tile size options are set in a consistent manner.
+static LogicalResult
+verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
+                      const scf::SCFTilingOptions &options) {
+  // Specifying number of threads is only supported on `scf.forall` op.
+  if (options.numThreadsComputationFunction &&
+      options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
+    return rewriter.notifyMatchFailure(
+        loc, "number of threads can only be specified when loop type is "
+             "set to use `scf.forall`");
+  }
+
+  // If specified, check that the interchange vector is a permutation.
+  if (!options.interchangeVector.empty()) {
+    if (!isPermutationVector(options.interchangeVector)) {
+      return rewriter.notifyMatchFailure(
+          loc, "invalid interchange vector, not a permutation of the entire "
+               "iteration space");
+    }
+  }
+  return success();
+}
+
+/// Method to instantiate the tile sizes and/or number of threads specified
+/// by the user.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
+                              ArrayRef<Range> iterationDomain,
+                              const scf::SCFTilingOptions &options) {
+  OpFoldResult zero = rewriter.getIndexAttr(0);
+  SmallVector<OpFoldResult> tileSizes, numThreads;
+  size_t numLoops = iterationDomain.size();
+
+  // Check whether the number of tiles to use is specified.
+  if (options.numThreadsComputationFunction) {
+    numThreads = options.numThreadsComputationFunction(rewriter, op);
+    numThreads.resize(numLoops, zero);
+
+    // If the number of tiles is also specified, use that.
+    if (options.tileSizeComputationFunction) {
+      tileSizes = options.tileSizeComputationFunction(rewriter, op);
+      tileSizes.resize(numLoops, zero);
+      return {tileSizes, numThreads};
+    }
+
+    // Compute the tile sizes from the iteration domain and number
+    // of tiles as follows
+    // - niters = ceilDiv(ub - lb, step)
+    // - tileSize = ceilDiv(niters, numThreads)
+    AffineExpr s0, s1, s2;
+    bindSymbols(rewriter.getContext(), s0, s1, s2);
+    // TODO: The step here is assumed to be 1.
+    AffineExpr numItersExpr = (s1 - s0);
+    AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
+    tileSizes.resize(numLoops, zero);
+    for (auto [index, range, nt] :
+         llvm::enumerate(iterationDomain, numThreads)) {
+      if (isConstantIntValue(nt, 0))
+        continue;
+
+      tileSizes[index] = affine::makeComposedFoldedAffineApply(
+          rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
+    }
+    tileSizes.resize(numLoops, zero);
+    return {tileSizes, numThreads};
+  }
+
+  // Enforce the convention that "tiling by zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  assert(options.tileSizeComputationFunction &&
+         "expected tile sizes to be specified");
+  tileSizes = options.tileSizeComputationFunction(rewriter, op);
+  tileSizes.resize(numLoops, zero);
+
+  return {tileSizes, numThreads};
+}
+
+/// Checks if any of the tiled loops are not parallel.
+static void checkSafeToTileToForall(TilingInterface op,
+                                    ArrayRef<OpFoldResult> tileSizes,
+                                    ArrayRef<OpFoldResult> numThreads) {
+  auto iterators = op.getLoopIteratorTypes();
+  assert(iterators.size() == tileSizes.size() &&
+         "expected as many tile size values as number of loops");
+  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
+         "when specified, expected number of threads to use for each loop");
+
+  for (auto [index, iterator, tileSize] :
+       llvm::enumerate(iterators, tileSizes)) {
+    // If num threads is specified, check that it is greater than one only for
+    // parallel dimensions.
+    if (!numThreads.empty()) {
+      if (std::optional<int64_t> constNumThreads =
+              getConstantIntValue(numThreads[index])) {
+        if (constNumThreads.value() > 1 &&
+            iterator != utils::IteratorType::parallel) {
+          op.emitWarning() << "tiling is not thread safe at axis #" << index;
+        }
+      }
+      continue;
+    }
+
+    if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
+      if (constTileSize.value() > 0 &&
+          iterator != utils::IteratorType::parallel) {
+        op.emitWarning() << "tiling is not thread safe at axis #" << index;
+      }
+    }
+  }
+}
+
+/// Check if `stride` evenly divides the trip count `size - offset`.
 static bool tileDividesIterationDomain(Range loopRange) {
   std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
   if (!offsetAsInt)
@@ -75,10 +198,10 @@ static bool tileDividesIterationDomain(Range loopRange) {
   return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
 }
 
-/// Returns the bounded tile size given the current `iv`, `loopRange` and
-/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
+/// Returns the bounded tile size given the current `offset`, `loopRange` and
+/// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
-                                       Range loopRange, Value iv,
+                                       Range loopRange, OpFoldResult offset,
                                        OpFoldResult tileSize) {
   std::optional<int64_t> ts = getConstantIntValue(tileSize);
   if (ts && ts.value() == 1)
@@ -97,7 +220,129 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
   AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
   return affine::makeComposedFoldedAffineMin(
-      b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
+      b, loc, minMap, SmallVector<OpFoldResult>{offset, tileSize, size});
+}
+
+/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
+/// less than `iterationSize`.
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+                                           OpFoldResult numThreads,
+                                           OpFoldResult iterationSize) {
+  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
+  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
+  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
+    return false;
+  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
+}
+
+/// Compute the `OpFoldResult`s that represent the multi-dimensional
+/// `offset`s and `size`s of the tile of the iteration space that the
+/// innermost loop body of the generated tiled loops corresponds to.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
+                      ArrayRef<Range> iterationDomain,
+                      ArrayRef<OpFoldResult> tileSizes,
+                      ArrayRef<OpFoldResult> numThreads) {
+  SmallVector<OpFoldResult> offsets, sizes;
+  int materializedLoopNum = 0;
+
+  if (!numThreads.empty()) {
+    AffineExpr d0, d1, s0, s1;
+    AffineExpr offsetExpr, residualTileSizeExpr;
+    bindDims(rewriter.getContext(), d0, d1);
+    bindSymbols(rewriter.getContext(), s0, s1);
+    offsetExpr = d0 + d1 * s0;
+    residualTileSizeExpr = s1 - (d0 + d1 * s0);
+
+    for (auto [nt, tileSize, loopRange] :
+         llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
+
+      // For untiled dimensions, set the offset and size to
+      // `loopRange.offset` and `loopRange.size`.
+      if (isConstantIntValue(nt, 0)) {
+        offsets.push_back(loopRange.offset);
+        sizes.push_back(loopRange.size);
+        continue;
+      }
+
+      Value iv = ivs[materializedLoopNum++];
+      OpFoldResult offset = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, offsetExpr,
+          ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
+      OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, residualTileSizeExpr,
+          {loopRange.offset, nt, tileSize, loopRange.size});
+
+      OpFoldResult size = tileSize;
+      if (!isConstantIntValue(residualTileSize, 0)) {
+        OpFoldResult sizeMinusOffsetPerThread =
+            affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
+                                                  {offset, loopRange.size});
+        size = affine::makeComposedFoldedAffineMin(
+            rewriter, loc,
+            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
+            {sizeMinusOffsetPerThread, tileSize});
+      }
+
+      // Consider the case where the original loop was `[0, 5)`.
+      // If the number of threads is `4`, the tile size would be computed as
+      // `ceilDiv(5, 4) = 2`. For the last thread (thread_id = 3)
+      // - `offset = 0 + 3 * 2 = 6`
+      // - `tileSize = min(2, 5 - 6) = -1`
+      // To avoid negative tile sizes, we need to do a further
+      // `nonNegativeTileSize = affine.max(0, tileSize)`.
+      // This `max` can be avoided if the last thread starts in bounds, i.e.
+      //  `tileSize * (numThreads - 1) < (ub - lb)`
+      if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
+        AffineMap maxMap =
+            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+        size = affine::makeComposedFoldedAffineMax(
+            rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
+      }
+
+      offsets.push_back(offset);
+      sizes.push_back(size);
+    }
+    return {offsets, sizes};
+  } else {
+    for (auto [tileSize, loopRange] :
+         llvm::zip_equal(tileSizes, iterationDomain)) {
+
+      // For untiled dimensions, set the offset and size to
+      // `loopRange.offset` and `loopRange.size`.
+      if (isConstantIntValue(tileSize, 0)) {
+        offsets.push_back(loopRange.offset);
+        sizes.push_back(loopRange.size);
+        continue;
+      }
+
+      Value iv = ivs[materializedLoopNum++];
+      OpFoldResult offset = getAsOpFoldResult(iv);
+      offsets.push_back(offset);
+      OpFoldResult size =
+          getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
+      sizes.push_back(size);
+    }
+    return {offsets, sizes};
+  }
+}
+
+/// Function to return the bounds of the loops to be generated.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+                  SmallVector<OpFoldResult>>
+getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+              ArrayRef<OpFoldResult> tileSizes) {
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+    // No loop if the tile size is 0.
+    if (isConstantIntValue(tileSize, 0))
+      continue;
+    lbs.push_back(loopRange.offset);
+    ubs.push_back(loopRange.size);
+    steps.push_back(tileSize);
+  }
+  return {lbs, ubs, steps};
 }
 
 /// A function that allows returning additional yielded values during
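For constant inputs, the `numThreads` branch of `getTileOffsetAndSizes` reduces to plain integer arithmetic per dimension. The following is a minimal C++ sketch of that arithmetic. It assumes every value is a known integer and that, as in this file, the range's `size` field acts as the upper bound. `TileSlice` and `threadSlice` are hypothetical names used only for illustration; they are not MLIR APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Offset and size of the slice handled by one thread along one dimension.
struct TileSlice {
  int64_t offset;
  int64_t size;
};

// Hypothetical integer model of the numThreads branch of
// getTileOffsetAndSizes. The real code builds affine.apply/min/max ops on
// OpFoldResults; here every input is assumed to be a known constant and `ub`
// plays the role of `loopRange.size`.
TileSlice threadSlice(int64_t lb, int64_t ub, int64_t tileSize,
                      int64_t threadId) {
  assert(tileSize > 0 && threadId >= 0 && "expected positive tile size");
  // offset = lb + threadId * tileSize  (offsetExpr: d0 + d1 * s0)
  int64_t offset = lb + threadId * tileSize;
  // size = min(ub - offset, tileSize)
  int64_t size = std::min(ub - offset, tileSize);
  // size = max(0, size): a thread that starts past the end gets an empty
  // slice instead of a negative size.
  size = std::max<int64_t>(0, size);
  return {offset, size};
}
```

For example, `threadSlice(0, 5, 2, 3)` yields offset 6 and size 0. Without the final clamp, the size would be `min(5 - 6, 2) = -1`.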
@@ -152,17 +397,19 @@ static LogicalResult generateLoopNestUsingForOp(
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
   OpBuilder::InsertionGuard guard(rewriter);
-  SmallVector<Value> ivs;
 
-  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
-    // No loops if tile size is zero. Set offset and size to the loop
-    // offset and size.
-    if (isConstantIntValue(tileSize, 0))
-      continue;
+  SmallVector<OpFoldResult> lbs, ubs, steps;
+  std::tie(lbs, ubs, steps) =
+      getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+  SmallVector<Value> lbVals =
+      getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
+  SmallVector<Value> ubVals =
+      getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
+  SmallVector<Value> stepVals =
+      getValueOrCreateConstantIndexOp(rewriter, loc, steps);
 
-    Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
-    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
-    Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+  SmallVector<Value> ivs;
+  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
     auto loop =
         rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
                                     [](OpBuilder &bodyBuilder, Location bodyLoc,
@@ -224,10 +471,9 @@ static LogicalResult generateLoopNestUsingForOp(
 ///    populated.
 static LogicalResult generateLoopNestUsingForallOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
-    ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
-    SmallVector<LoopLikeOpInterface> &loops) {
-  SmallVector<OpFoldResult> lbs, ubs, steps;
+    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
+    ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
+    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
@@ -235,21 +481,30 @@ static LogicalResult generateLoopNestUsingForallOp(
   SmallVector<OpFoldResult> offsets(loopRanges.size()),
       sizes(loopRanges.size());
 
-  for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
-    if (isConstantIntValue(tileSize, 0))
-      continue;
-    lbs.push_back(loopRange.offset);
-    ubs.push_back(loopRange.size);
-    steps.push_back(tileSize);
-  }
-  assert(!lbs.empty() && "Expected at least one loop range");
-
   std::optional<ArrayAttr> mappingAttr;
   if (!mappingVector.empty())
     mappingAttr = rewriter.getArrayAttr(mappingVector);
 
-  auto forallOp = rewriter.create<scf::ForallOp>(
-      loc, lbs, ubs, steps, destinationTensors, mappingAttr);
+  scf::ForallOp forallOp;
+  bool useNumThreads = !numThreads.empty();
+
+  if (useNumThreads) {
+    // Drop the dimensions whose number of threads is zero (untiled).
+    SmallVector<OpFoldResult> nonZeroNumThreads;
+    for (auto nt : numThreads) {
+      if (isConstantIntValue(nt, 0))
+        continue;
+      nonZeroNumThreads.push_back(nt);
+    }
+    forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
+                                              destinationTensors, mappingAttr);
+  } else {
+    SmallVector<OpFoldResult> lbs, ubs, steps;
+    std::tie(lbs, ubs, steps) =
+        getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+    forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
+                                              destinationTensors, mappingAttr);
+  }
   loops.push_back(forallOp);
 
   rewriter.setInsertionPoint(forallOp.getTerminator());
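After this change, `generateLoopNestUsingForallOp` chooses between two `scf.forall` forms. The first is the `num_threads` form built from the non-zero thread counts. The second is the `lb/ub/step` form built from `getLoopBounds`. A hypothetical sketch of that decision on plain integers follows; `LoopRange`, `ForallShape` and `forallShape` are illustrative names only, not MLIR types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// One loop dimension; as in getLoopBounds, `size` is used as the upper bound.
struct LoopRange {
  int64_t offset;
  int64_t size;
};

// Hypothetical description of the scf.forall that would be created.
struct ForallShape {
  bool usesNumThreads = false;
  std::vector<int64_t> numThreads;      // `in (%nt0, ...)` form
  std::vector<int64_t> lbs, ubs, steps; // `(%lb) to (%ub) step (%s)` form
};

ForallShape forallShape(const std::vector<LoopRange> &ranges,
                        const std::vector<int64_t> &tileSizes,
                        const std::vector<int64_t> &numThreads) {
  assert(ranges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  ForallShape shape;
  shape.usesNumThreads = !numThreads.empty();
  if (shape.usesNumThreads) {
    // Drop dimensions with zero threads; they are not distributed.
    for (int64_t nt : numThreads)
      if (nt != 0)
        shape.numThreads.push_back(nt);
    return shape;
  }
  // Otherwise mirror getLoopBounds: one (lb, ub, step) per non-zero tile size.
  for (std::size_t i = 0; i < ranges.size(); ++i) {
    if (tileSizes[i] == 0)
      continue;
    shape.lbs.push_back(ranges[i].offset);
    shape.ubs.push_back(ranges[i].size);
    shape.steps.push_back(tileSizes[i]);
  }
  return shape;
}
```

When tile sizes alone drive the tiling, the loop bounds come straight from the tile sizes. No per-thread offset arithmetic is needed, which is where the reduction in generated indexing math comes from.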
@@ -286,13 +541,11 @@ static LogicalResult generateLoopNestUsingForallOp(
 ///    loop.
 /// - `loops` is an in-out parameter into which the generated loops are
 ///    populated.
-static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
-                                      const scf::SCFTilingOptions &options,
-                                      ArrayRef<Range> loopRanges,
-                                      ArrayRef<OpFoldResult> tileSizes,
-                                      ValueRange destinationTensors,
-                                      YieldTiledValuesFn tiledBodyFn,
-                                      SmallVector<LoopLikeOpInterface> &loops) {
+static LogicalResult generateLoopNest(
+    RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
+    ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
+    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
   // If the tile sizes are all zero, no loops are generated. Just call the
   // callback function to handle untiled case.
   if (llvm::all_of(tileSizes, isZeroIndex)) {
@@ -307,7 +560,7 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
   }
   if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
     return generateLoopNestUsingForallOp(
-        rewriter, loc, loopRanges, tileSizes, options.mappingVector,
+        rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
         destinationTensors, tiledBodyFn, loops);
   }
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
@@ -531,27 +784,25 @@ static LogicalResult addInitOperandsToLoopNest(
 FailureOr<scf::SCFTilingResult>
 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                         const scf::SCFTilingOptions &options) {
+  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+    return failure();
+  }
+
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfter(op);
 
-  if (!options.tileSizeComputationFunction) {
-    return rewriter.notifyMatchFailure(
-        op, "missing tile size computation function");
-  }
-
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
-  size_t numLoops = iterationDomain.size();
 
-  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
-  // skips tiling a particular dimension. This convention is significantly
-  // simpler to handle instead of adjusting affine maps to account for missing
-  // dimensions.
-  SmallVector<OpFoldResult> tileSizes =
-      options.tileSizeComputationFunction(rewriter, op);
-  if (tileSizes.size() < iterationDomain.size()) {
-    auto zero = rewriter.getIndexAttr(0);
-    tileSizes.append(numLoops - tileSizes.size(), zero);
+  // 2. Materialize the tile sizes and/or the number of threads.
+  SmallVector<OpFoldResult> tileSizes, numThreads;
+  std::tie(tileSizes, numThreads) =
+      getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
+
+  // Check if it is safe to tile. This is a holdover from previous iterations
+  // of tile-to-forall. Consider dropping it.
+  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
+    checkSafeToTileToForall(op, tileSizes, numThreads);
   }
 
   // 3. If there is an interchange specified, permute the iteration domain and
@@ -560,16 +811,13 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   if (!options.interchangeVector.empty()) {
     interchangeVector = fillInterchangeVector(options.interchangeVector,
                                               iterationDomain.size());
-  }
-  if (!interchangeVector.empty()) {
-    if (!isPermutationVector(interchangeVector)) {
-      return rewriter.notifyMatchFailure(
-          op, "invalid intechange vector, not a permutation of the entire "
-              "iteration space");
-    }
+    assert(isPermutationVector(interchangeVector) &&
+           "expected interchange vector to be a permutation");
 
     applyPermutationToVector(iterationDomain, interchangeVector);
     applyPermutationToVector(tileSizes, interchangeVector);
+    if (!numThreads.empty())
+      applyPermutationToVector(numThreads, interchangeVector);
   }
 
   FailureOr<TilingResult> tilingResult;
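With interchange, `numThreads` must now be permuted together with `tileSizes` so that both stay aligned with the permuted iteration domain. The sketch below assumes the convention documented for MLIR's `applyPermutationToVector` (`result[i] = vec[perm[i]]`, so `['a', 'b', 'c']` with `[2, 0, 1]` becomes `['c', 'a', 'b']`). `permute` and `invertPermutation` are hypothetical helpers written for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for applyPermutationToVector: result[i] = vec[perm[i]].
template <typename T>
void permute(std::vector<T> &vec, const std::vector<int64_t> &perm) {
  assert(vec.size() == perm.size() && "permutation size mismatch");
  std::vector<T> result(vec.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    result[i] = vec[perm[i]];
  vec = std::move(result);
}

// inv[perm[i]] = i, so permuting by `inv` undoes permuting by `perm`.
std::vector<int64_t> invertPermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}
```

Applying the inverse permutation restores the original order. Step 4b needs exactly that to report the computed offsets and sizes in the operation's own loop order.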
@@ -583,21 +831,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
       -> LogicalResult {
     // 4a. Compute the `offsets` and `sizes` to use for tiling.
     SmallVector<OpFoldResult> offsets, sizes;
-    {
-      int materializedLoopNum = 0;
-      for (auto [tileSize, loopRange] :
-           llvm::zip_equal(tileSizes, iterationDomain)) {
-        if (isConstantIntValue(tileSize, 0)) {
-          offsets.push_back(loopRange.offset);
-          sizes.push_back(loopRange.size);
-          continue;
-        }
-        Value iv = ivs[materializedLoopNum++];
-        offsets.push_back(iv);
-        sizes.push_back(
-            getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
-      }
-    }
+    std::tie(offsets, sizes) = getTileOffsetAndSizes(
+        rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
 
     // 4b. If interchange was provided, apply inverse of the interchange
     //     to get back the offsets/sizes in the order to be specified.
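In the tile-size (non-`numThreads`) branch, each `scf.for` iteration produces a slice at offset `iv` with size `min(tileSize, size - iv)`, as computed by `getBoundedTileSize`. The following is a hypothetical integer sketch over `[0, size)`. The `tiles` helper is illustrative only and assumes a zero lower bound.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical model of the tile-size branch for one dimension over
// [0, size): the scf.for steps by `tileSize`, and each slice is
// (offset = iv, size = min(tileSize, size - iv)), as in getBoundedTileSize.
// A tile size of 0 means "not tiled": one slice covering the whole range.
std::vector<std::pair<int64_t, int64_t>> tiles(int64_t size, int64_t tileSize) {
  std::vector<std::pair<int64_t, int64_t>> result;
  if (tileSize == 0) {
    result.emplace_back(0, size);
    return result;
  }
  assert(tileSize > 0 && "expected a positive tile size");
  for (int64_t iv = 0; iv < size; iv += tileSize)
    result.emplace_back(iv, std::min(tileSize, size - iv));
  return result;
}
```

For `tiles(100, 15)`, this produces seven tiles, and the last one starts at offset 90 with size 10. Because the loop never starts an iteration past the end, this branch needs no `affine.max` guard.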
@@ -665,7 +900,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
   if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
-                              tileSizes, destinationTensors,
+                              tileSizes, numThreads, destinationTensors,
                               innerYieldTiledValuesFn, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
@@ -774,6 +1009,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   scf::SCFTilingOptions options;
   options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
   if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+                              /*numThreads=*/ArrayRef<OpFoldResult>{},
                               initTensors, innerYieldTiledValuesFn, loops)))
     return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
 
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 8545dfd25eccf..6e92deaf4cf0d 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -177,7 +177,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-
 // -----
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
@@ -194,8 +193,8 @@ module attributes {transform.with_named_sequence} {
 func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
   //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
   //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
-  //      CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
-  //      CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+  //      CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+  //      CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
   //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
   //      CHECK:   %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
@@ -220,7 +219,6 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
 // -----
 
 // Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
@@ -342,7 +340,6 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 100, 15)>
-// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
 // CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 15)>
 // CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0)>
 
@@ -355,8 +352,7 @@ module attributes {transform.with_named_sequence} {
                                          %OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
                                          -> (tensor<100xf32>, tensor<100xf32>) {
 //      CHECK: scf.forall (%[[IV0:.+]]) in (7) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
-//      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
-//      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
+//      CHECK:   %[[TS:.+]] = affine.min #[[$map0]](%[[IV0]])
 //  CHECK-NOT:   affine.min
 //  CHECK-NOT:   affine.max
 //      CHECK:   %[[LB:.+]] = affine.apply #[[$map2]](%[[IV0]])
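The CHECK update above follows from `canOmitTileOffsetInBoundsCheck`. With tile size 15, 7 threads and 100 iterations, the last thread starts at `15 * 6 = 90 < 100`, so the `affine.max` is no longer emitted. The `matmul_tile_size_dynamic` tests in this file now gain an `affine.max` instead, because their thread counts and sizes are not constants. The sketch below is a hypothetical restatement that uses `std::optional<int64_t>` to stand in for a possibly dynamic `OpFoldResult`; `canOmitMax` is an illustrative name.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical restatement of canOmitTileOffsetInBoundsCheck;
// std::nullopt stands in for a dynamic (non-constant) OpFoldResult.
bool canOmitMax(std::optional<int64_t> tileSize,
                std::optional<int64_t> numThreads,
                std::optional<int64_t> iterationSize) {
  // Any dynamic value: the bound cannot be proven, keep the affine.max.
  if (!tileSize || !numThreads || !iterationSize)
    return false;
  // The last thread starts at tileSize * (numThreads - 1).
  return *tileSize * (*numThreads - 1) < *iterationSize;
}
```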
@@ -456,9 +452,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
 // CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
 // CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
 
 // CHECK-LABEL: matmul_tile_size_dynamic(
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -473,10 +470,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
   //      CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
   //      CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
-  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
-  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
-  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
-  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  //      CHECK:   %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
+  //      CHECK:   %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
   //      CHECK:   tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
   //      CHECK:   tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
   //      CHECK:   tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
@@ -524,9 +523,10 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
 // CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
 // CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
 
 // CHECK-LABEL: matmul_tile_size_dynamic(
 //  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -541,10 +541,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
   //      CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
   //      CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
   //      CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
-  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
-  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
-  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
-  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  //      CHECK:   %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
+  //      CHECK:   %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
   //      CHECK:   tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
   //      CHECK:   tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
   //      CHECK:   tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index d244670f73754..3467a539496b8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --transform-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics --cse %s | FileCheck %s
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -178,12 +178,11 @@ module {
 
 // CHECK-LABEL:   func.func @scalable_tile(
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>,
-// CHECK:           %[[C4:.*]] = arith.constant 0 : index
-// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C4]] : tensor<?xf32>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C0]] : tensor<?xf32>
 // CHECK:           %[[VEC_SIZE:.*]] = arith.constant 4 : index
 // CHECK:           %[[VS:.*]] = vector.vscale
 // CHECK:           %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor<?xf32>) {
 // CHECK:             %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%[[IV]])[%[[STEP]], %[[DIM]]]
 // CHECK:             %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
@@ -202,20 +201,14 @@ module {
 // -----
 
 // CHECK-LABEL:   func.func @scalable_and_fixed_length_tile
-// CHECK:           %[[C4:.*]] = arith.constant 4 : index
-// CHECK:           %[[VS:.*]] = vector.vscale
-// CHECK:           %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C128:.*]] = arith.constant 128 : index
-// CHECK:           %[[STEP_0:.*]] = arith.constant 4 : index
-// CHECK:           scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
-// CHECK:             %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK:             %[[C128_1:.*]] = arith.constant 128 : index
-// CHECK:             %[[STEP_1:.*]] = arith.constant 4 : index
-// CHECK:             scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
-// CHECK:               %[[C0_2:.*]] = arith.constant 0 : index
-// CHECK:               %[[C128_2:.*]] = arith.constant 128 : index
-// CHECK:               scf.for %{{.*}} = %[[C0_2]] to %[[C128_2]] step %[[STEP_2]]
+//   CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:     %[[VS:.*]] = vector.vscale
+//   CHECK-DAG:     %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
+//   CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+//       CHECK:     scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[C4]]
+//       CHECK:       scf.for %[[VAL_16:.*]] = %[[C0]] to %[[C128]] step %[[C4]]
+//       CHECK:         scf.for %{{.*}} = %[[C0]] to %[[C128]] step %[[STEP_2]]
 
 func.func @scalable_and_fixed_length_tile(
   %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
diff --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
index 7d247aefcf6b1..ccf8e37c094f4 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
@@ -31,8 +31,8 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[DIM_IN1:.+]] = tensor.dim %[[IN]], %[[C1]]
 //   CHECK-DAG:   %[[DIM1:.+]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]]
 //   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //       CHECK:   %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[DIM0]] step %[[C2]]
-//       CHECK:     %[[C3:.+]] = arith.constant 3 : index
 //       CHECK:     scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
 //       CHECK:       %[[SWAP_RESULT:.*]] = scf.if
 //       CHECK:         tensor.generate
@@ -62,8 +62,8 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-//   CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
-//   CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 8)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 7)>
 //       CHECK: func @dynamic_2d_pad_tensor_inner_tiling(
 //  CHECK-SAME:     %[[IN:.*]]: tensor<?x?xf32>
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
@@ -107,9 +107,9 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
 //   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
 //       CHECK:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
-//   CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
-//   CHECK-DAG:     %[[C3:.*]] = arith.constant 3 : index
 //       CHECK:     scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
 //       CHECK:       %[[SWAP_RESULT:.*]] = scf.if
 //       CHECK:         tensor.generate
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 488a52e8e3e91..0a4d4c45f10be 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -24,13 +24,13 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
 //  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
-//  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
 //  CHECK-DAG:       %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
@@ -77,14 +77,14 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
 //  CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //  CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
-//  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
-//  CHECK-DAG:       %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
 //  CHECK-DAG:         %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
 //  CHECK-DAG:         %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
@@ -130,15 +130,15 @@ module attributes {transform.with_named_sequence} {
 //   CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
 // CHECK-LABEL: func.func @multi_result(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
-//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
-//   CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
 //   CHECK-DAG:   %[[INIT0:.+]] = tensor.empty()
 //   CHECK-DAG:   %[[INIT1:.+]] = tensor.empty()
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C128:.+]] = arith.constant 128 : index
+//   CHECK-DAG:   %[[C300:.+]] = arith.constant 300 : index
+//   CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//   CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
 //       CHECK:   %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
 //  CHECK-SAME:       iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
-//   CHECK-DAG:     %[[C300:.+]] = arith.constant 300 : index
-//   CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //       CHECK:     %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
 //  CHECK-SAME:         iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
 //   CHECK-DAG:       %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
@@ -193,7 +193,6 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
 //  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
-//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
 //  CHECK-DAG:   %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
 //  CHECK-DAG:   %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
@@ -201,12 +200,13 @@ module attributes {transform.with_named_sequence} {
 //  CHECK-DAG:   %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
 //  CHECK-DAG:   %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
 //  CHECK-DAG:   %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[INIT]])
-//  CHECK-DAG:     %[[C20:.+]] = arith.constant 20 : index
 //      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
-//  CHECK-DAG:       %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:       scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
 // CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
 //  CHECK-DAG:         %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
@@ -296,16 +296,16 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
 //  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 //  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 //  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
 // CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
-//  CHECK-DAG:     %[[C30:.+]] = arith.constant 30 : index
 //      CHECK:     %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
 // CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
-//  CHECK-DAG:       %[[C10:.+]] = arith.constant 10 : index
 //      CHECK:       %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
 // CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
 //  CHECK-DAG:         %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 833fb3cc65b81..abe41782c9b60 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -235,11 +235,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
     scf::SCFTilingOptions tilingOptions;
     tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
     if (mapping) {
-      auto mappingAttrs =
-          llvm::map_to_vector(mapping.value(), [](Attribute attr) {
-            return cast<DeviceMappingAttrInterface>(attr);
-          });
-      tilingOptions.setMapping(mappingAttrs);
+      tilingOptions.setMapping(mapping.value().getValue());
     }
     tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
 

More information about the Mlir-commits mailing list