[Mlir-commits] [mlir] [mlir][SCF] Deprecate `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes` (PR #91878)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 21 23:59:14 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

The implementations of these methods are legacy, and they are removed in favor of the `scf::tileUsingSCF` methods. To bring the latter on par with the requirements of the deprecated methods, the tiling now allows one to specify the maximum number of tiles to use instead of specifying the tile sizes. When tiling to `scf.forall`, this specification is used to generate the `num_threads` version of the operation.

A slight deviation from the previous implementation is that the deprecated method always generated the `num_threads` variant of the `scf.forall` operation. Now this is instead driven by the tiling options specified, which reduces the indexing math generated when the tile sizes are specified.

---

Patch is 59.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91878.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h (+5-1) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (-24) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+28-7) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+35-13) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (-182) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+244-75) 
- (modified) mlir/test/Dialect/Linalg/tile-to-forall.mlir (+25-28) 
- (modified) mlir/test/Dialect/Linalg/transform-op-tile.mlir (+11-18) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir (+5-5) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir (+25-25) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+1-5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 3af642752724c..db25c9b241734 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -30,6 +30,10 @@ class GenericOp;
 class LinalgOp;
 } // namespace linalg
 
+namespace scf {
+struct SCFTilingResult;
+} // namespace scf
+
 namespace tensor {
 class InsertSliceOp;
 class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
                    ArrayRef<OpFoldResult> mixedNumThreads,
                    ArrayRef<OpFoldResult> mixedTileSizes,
                    std::optional<ArrayAttr> mapping,
-                   linalg::ForallTilingResult &tilingResult);
+                   scf::SCFTilingResult &tilingResult);
 
 } // namespace transform
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f77c19ed0fcce..73fd6a469d0e7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -846,30 +846,6 @@ FailureOr<StaticMultiSizeSpecification>
 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
-/// tiling by `numThreads`.
-/// If non-empty, the `mapping` is added as an attribute to the
-/// resulting `scf.forall`.
-/// Zero tile sizes indicate that the dimension is not tiled, and can be
-/// thought of as tiling by the full size of data. It is the user's
-/// responsibility to ensure that `numThreads` is a valid tiling specification
-/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
-struct ForallTilingResult {
-  Operation *tileOp;
-  Operation *tiledOp;
-};
-FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
-                                             TilingInterface op,
-                                             ArrayRef<OpFoldResult> numThreads,
-                                             std::optional<ArrayAttr> mapping);
-
-/// Same as `tileToForallOp`, but calculate the number of threads
-/// required using the given tileSizes.
-FailureOr<ForallTilingResult>
-tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
-                             ArrayRef<OpFoldResult> tileSizes,
-                             std::optional<ArrayAttr> mapping);
-
 /// Transformation information returned after reduction tiling.
 struct ForallReductionTilingResult {
   /// The partial reduction tiled op generated.
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..f6858c3cddf46 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -31,9 +31,11 @@ using SCFTileSizeComputationFunction =
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
-  /// Computation function that returns the tile sizes for each operation.
-  /// Delayed construction of constant tile sizes should occur to interoperate
-  /// with folding.
+  /// Computation function that returns the tile sizes to use for each loop.
+  /// Returning a tile size of zero implies no tiling for that loop. If the
+  /// size of the returned vector is smaller than the number of loops, the inner
+  /// loops are not tiled. If the size of the returned vector is larger, then
+  /// the vector is truncated to number of loops.
   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
 
   SCFTilingOptions &
@@ -44,7 +46,27 @@ struct SCFTilingOptions {
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
-  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
+  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
+
+  /// Computation function that returns the number of threads to use for
+  /// each loop. Returning a num threads of zero implies no tiling for that
+  /// loop. If the size of the returned vector is smaller than the number of
+  /// loops, the inner loops are not tiled. If the size of the returned vector
+  /// is larger, then the vector is truncated to number of loops. Note: This
+  /// option is only supported with loopType set to `LoopType::ForallOp`. If the
+  /// tile size function is not specified while the num threads computation is,
+  /// then the tile size is determined automatically to map at most one tile per
+  /// thread.
+  SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
+
+  SCFTilingOptions &
+  setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
+    numThreadsComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Convenience function to set the `numThreadsComputationFunction` to a
+  /// function that returns the given number of threads for each loop.
+  SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
@@ -66,9 +88,8 @@ struct SCFTilingOptions {
   /// when using loop constructs that dont support such a mapping (like
   /// `scf.for`)
   SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
-    mappingVector = llvm::map_to_vector(
-        mapping, [](auto attr) -> Attribute { return attr; });
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
     return *this;
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..e9bcbee5c1a8a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2917,7 +2917,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     TransformOpInterface transformOp, Operation *target,
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
-    linalg::ForallTilingResult &tilingResult) {
+    scf::SCFTilingResult &tilingResult) {
   // Transform all targets one by one.
   auto tileableOp = dyn_cast<TilingInterface>(target);
   if (!tileableOp) {
@@ -2928,18 +2928,39 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     return diag;
   }
   rewriter.setInsertionPoint(tileableOp);
-  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
   if (!mixedNumThreads.empty()) {
-    maybeTilingResult =
-        linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+    options.setNumThreads(mixedNumThreads);
   } else {
-    maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
-        rewriter, tileableOp, mixedTileSizes, mapping);
+    SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
+    unsigned nLoops = loopRanges.size();
+    SmallVector<OpFoldResult> numThreads;
+    numThreads.reserve(nLoops);
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    AffineExpr divExpr = s0.ceilDiv(s1);
+    for (int i = 0, e = std::min(mixedTileSizes.size(), loopRanges.size());
+         i < e; ++i) {
+      OpFoldResult numTiles = mixedTileSizes[i];
+      if (!isConstantIntValue(numTiles, 0))
+        numTiles = affine::makeComposedFoldedAffineApply(
+            rewriter, tileableOp.getLoc(), divExpr,
+            {loopRanges[i].size, numTiles});
+      numThreads.push_back(numTiles);
+    }
+    options.setNumThreads(numThreads);
+    options.setTileSizes(mixedTileSizes);
+  }
+  if (mapping) {
+    options.setMapping(mapping.value().getValue());
   }
+  FailureOr<scf::SCFTilingResult> maybeTilingResult =
+      scf::tileUsingSCF(rewriter, tileableOp, options);
 
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
-  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
 
   tilingResult = *maybeTilingResult;
   return DiagnosedSilenceableFailure::success();
@@ -2975,14 +2996,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
     return status;
 
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    linalg::ForallTilingResult tilingResult;
+    scf::SCFTilingResult tilingResult;
     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
         getMapping(), tilingResult);
     if (!diag.succeeded())
       return diag;
-    tileOps.push_back(tilingResult.tileOp);
-    tiledOps.push_back(tilingResult.tiledOp);
+    tileOps.push_back(tilingResult.loops.front());
+    tiledOps.append(tilingResult.tiledOps);
   }
 
   transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3460,7 +3481,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
 
   // OpBuilder only used to compute attributes.
   OpBuilder b(getContext());
-  linalg::ForallTilingResult tilingResult;
+  scf::SCFTilingResult tilingResult;
   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
       /*rewriter=*/rewriter,
       /*state=*/state,
@@ -3473,8 +3494,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   if (!diag.succeeded())
     return diag;
 
-  results.push_back(tilingResult.tileOp);
-  results.push_back(tilingResult.tiledOp);
+  results.push_back(tilingResult.loops.front());
+  for (auto op : tilingResult.tiledOps)
+    results.push_back(op);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..30031a443f7d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
-/// Returns a vector of bools representing if, for each axis, `op` can be tiled
-/// without incurring in a race condition and thus it is thread-safe to do the
-/// tiling. This is checked by iterating over numThreads and ensuring that the
-/// corresponding iterator type is "parallel". If it is not, then we know that
-/// such dimension is unsafe to tile.
-SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
-                                     ArrayRef<OpFoldResult> numThreads) {
-  auto iterators = linalgOp.getIteratorTypesArray();
-  SmallVector<bool> safeToTile(numThreads.size(), true);
-
-  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
-    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
-      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
-        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-      }
-    } else {
-      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-    }
-  }
-  return safeToTile;
-}
-
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
-/// tiling is specified by the number of tiles/threads `numThreads` and the
-/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
-/// not specified, then  it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `mapping` is added as an
-/// attribute to the resulting `scf.forall`. A zero tile sizes indicate
-/// that the dimension is not tiled, and can be thought of as tiling by the full
-/// size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If the dimension is not parallelizable, a warning is issued to
-/// notify the user that the generated code is not safe to parallelize. If
-/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
-/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
-static FailureOr<ForallTilingResult> tileToForallOpImpl(
-    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
-    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-    std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
-  Location loc = op->getLoc();
-  OpBuilder::InsertionGuard g(b);
-
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  if (loopRanges.empty())
-    return op->emitOpError("expected non-empty loop ranges");
-  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
-  if (llvm::any_of(loopRanges, hasStrideOne))
-    return op->emitOpError("only stride-1 supported atm");
-
-  // Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-    return op->emitOpError("failed to get destination tensors");
-
-  SmallVector<OpFoldResult> nonZeroNumThreads =
-      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 0);
-      }));
-  SmallVector<Value> materializedNonZeroNumThreads =
-      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-        return getValueOrCreateConstantIndexOp(b, loc, ofr);
-      }));
-
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
-  if (linalgOp) {
-    // Check if tiling is thread safe and print a warning if not.
-    SmallVector<bool> tilingSafety =
-        safeToTileToForall(b.getContext(), linalgOp, numThreads);
-    for (size_t i = 0; i < tilingSafety.size(); i++)
-      if (!tilingSafety[i])
-        op.emitWarning() << "tiling is not thread safe at axis #" << i;
-  }
-
-  // 1. Create the ForallOp. We don't use the lambda body-builder
-  // version because we require the use of RewriterBase in the body, so we
-  // manually move the insertion point to the body below.
-  scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
-
-  // 2. Fill out the ForallOp body.
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
-                               omitTileOffsetBoundsCheck, nominalTileSizes,
-                               tiledOffsets, tiledSizes);
-
-  // 3. Clone the tileable op and update its destination operands to use the
-  // output bbArgs of the ForallOp.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
-  Operation *tiledOp = nullptr;
-  SmallVector<Value> tiledValues;
-  {
-    // 3.a. RAII guard, inserting within forallOp, before terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-    Operation *clonedOp = b.clone(*op.getOperation());
-    auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
-    if (destinationStyleOp) {
-      for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
-        // Swap tensor inits with the corresponding block argument of the
-        // scf.forall op. Memref inits remain as is.
-        if (isa<TensorType>(outOperand.get().getType())) {
-          auto *it = llvm::find(dest, outOperand.get());
-          assert(it != dest.end() && "could not find destination tensor");
-          unsigned destNum = std::distance(dest.begin(), it);
-          outOperand.set(destBbArgs[destNum]);
-        }
-      }
-    }
-
-    // 4. Tile the cloned op and delete the clone.
-    FailureOr<TilingResult> tilingResult =
-        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                               tiledSizes);
-    if (failed(tilingResult))
-      return clonedOp->emitError("Failed to tile op: ");
-    if (tilingResult->tiledOps.size() != 1) {
-      return clonedOp->emitError("expected a single produced tiled op, got ")
-             << tilingResult->tiledOps.size();
-    }
-
-    b.eraseOp(clonedOp);
-    tiledOp = tilingResult->tiledOps.front();
-    tiledValues = tilingResult->tiledValues;
-  }
-
-  // 5. Parallel insert back into the result tensor.
-  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tiledValues, destBbArgs)) {
-    // 5.a. Partial subset information is inserted just before the terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
-                                        tiledSizes, resultOffsets,
-                                        resultSizes)))
-      return op->emitOpError("output offsets couldn't be calculated");
-    SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
-
-    // 5.b. Parallel insertions are inserted at the end of the combining
-    // terminator.
-    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
-    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
-                                            std::get<2>(it), resultOffsets,
-                                            resultSizes, strides);
-  }
-  return ForallTilingResult{forallOp, tiledOp};
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
-                       ArrayRef<OpFoldResult> numThreads,
-                       std::optional<ArrayAttr> mapping) {
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/std::nullopt, mapping,
-                            /*omitTileOffsetBoundsCheck=*/false);
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
-                                     ArrayRef<OpFoldResult> tileSizes,
-                                     std::optional<ArrayAttr> mapping) {
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  unsigned nLoops = loopRanges.size();
-  SmallVector<OpFoldResult> numThreads;
-  numThreads.reserve(nLoops);
-  AffineExpr s0, s1;
-  bindSymbols(b.getContext(), s0, s1);
-  AffineExpr divExpr = s0.ceilDiv(s1);
-  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
-    OpFoldResult numTiles = std::get<0>(it);
-    if (!isConstantIntValue(numTiles, 0))
-      numTiles = makeComposedFoldedAffineApply(
-          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
-    numThreads.push_back(numTiles);
-  }
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/tileSizes, mapping,
-                            /*omitTileOffsetBoundsCheck=*/true);
-}
-
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69d..c0878c42da3b1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -41,6 +41,16 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
   return *this;
 }
 
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
+  assert(!numThreadsComputationFunction && "num tiles already set");
+  auto numThreads = llvm::to_vector(nt);
+  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
+    return numThreads;
+  };
+  return *this;
+}
+
 /// Helper method to adjust the interchange vector to match the iteration
 /// domain.
 static SmallVector<int64_t>
@@ -60,7 +70,117 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 // tileUsingSCF implementation.
 //===...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/91878


More information about the Mlir-commits mailing list