[Mlir-commits] [mlir] 297ba16 - [mlir][linalg] Add tile_size option to `structured.tile_to_foreach_thread_op`

Christopher Bate llvmlistbot at llvm.org
Thu Jul 21 09:36:27 PDT 2022


Author: Christopher Bate
Date: 2022-07-21T10:32:01-06:00
New Revision: 297ba167ded073a47dd9ea7e408aa95acdfcedf1

URL: https://github.com/llvm/llvm-project/commit/297ba167ded073a47dd9ea7e408aa95acdfcedf1
DIFF: https://github.com/llvm/llvm-project/commit/297ba167ded073a47dd9ea7e408aa95acdfcedf1.diff

LOG: [mlir][linalg] Add tile_size option to `structured.tile_to_foreach_thread_op`

This change modifies `structured.tile_to_foreach_thread_op` so that
it accepts either `tile_sizes` or `num_threads` parameters. If
`tile_sizes` are specified, then the number of threads required is
derived from the tile sizes rather than the other way around. In both cases,
more aggressive folding of loop parameters is enabled during the
transformation, allowing for the potential elimination of `affine.min`
and `affine.max` operations in the static shape case when calculating
the final adjusted tile size.

Differential Revision: https://reviews.llvm.org/D130139

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 9f1cd681315e3..b8bcf136ee383 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -608,15 +608,17 @@ def TileToForeachThreadOp :
        TransformEachOpTrait,
        TransformOpInterface]> {
   let description = [{
-    Tile a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
-    tiling by `num_threads`.
-    If non-empty, the `thread_dim_mapping` is added as an attribute to the 
+    Tile a TilingInterface op to a tiled `scf.foreach_thread`, specifying
+    either `num_threads` or `tile_sizes`. With `num_threads`, the tile size of
+    dimension `i` is `ceilDiv(dimSize[i], num_threads[i])`; with `tile_sizes`,
+    the number of threads is `ceilDiv(dimSize[i], tile_sizes[i])`.
+    If non-empty, the `thread_dim_mapping` is added as an attribute to the
     resulting `scf.foreach_thread`.
-    Zero tile sizes indicate that the dimension is not tiled, and can be thought
-    of as tiling by the full size of data.
-    It is the user's responsibility to ensure that `num_threads` is a valid 
-    tiling specification (i.e. that only tiles parallel dimensions, e.g. in the 
-    Linalg case).
+    Zero tile sizes indicate that the dimension is not tiled and can be
+    thought of as tiling by the full size of data.
+    It is the user's responsibility to ensure that `num_threads`/`tile_sizes`
+    is a valid tiling specification (i.e. that only tiles parallel dimensions,
+    e.g. in the Linalg case).
     
     #### Return modes
     
@@ -627,24 +629,39 @@ def TileToForeachThreadOp :
     successfully, the transform succeeds.
     Otherwise the transform silently fails.
     
-    The 2 returned handles point to only the subset of successfully produced 
+    The two returned handles point to only the subset of successfully produced
     tiled operations, which can all be empty.
     
-    These 2 returned handles point to:
+    These two returned handles point to:
       - the new scf.foreach_thread op,
       - the tiled op that implements TilingInterface.
+
+    ### Example using `num_threads`
+
+    ```
+    %0 = pdl_match @match_matmul in %arg1    
+    %3:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
+    ```
+
+    ### Example using `tile_sizes`
+    
+    ```
+    %0 = pdl_match @match_matmul in %arg1    
+    %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20, 0]
+    ```
   }];
 
   let arguments = (ins PDL_Operation:$target,
                    // TODO: dynamic number of threads.
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
+                   OptionalAttr<DefaultValuedAttr<I64ArrayAttr, "{}">>:$num_threads,
+                   OptionalAttr<DefaultValuedAttr<I64ArrayAttr, "{}">>:$tile_sizes,
                    OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$tiled_op);
 
   let assemblyFormat = [{
-    $target $num_threads (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? 
-      attr-dict
+    $target  (`num_threads` $num_threads^) : (`tile_sizes` $tile_sizes)?
+      (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
   }];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c81a7ee2ac326..569faf383a0ee 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -466,10 +466,17 @@ struct ForeachThreadTilingResult {
   Operation *tiledOp;
 };
 FailureOr<ForeachThreadTilingResult>
-tileToForeachThreadOp(OpBuilder &builder, TilingInterface op,
+tileToForeachThreadOp(RewriterBase &builder, TilingInterface op,
                       ArrayRef<OpFoldResult> numThreads,
                       ArrayRef<int64_t> threadDimMapping = {});
 
+/// Same as `tileToForeachThreadOp`, but derives the number of threads from
+/// `tileSizes` (one thread per tile; a zero tile size leaves a dim untiled).
+FailureOr<ForeachThreadTilingResult>
+tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
+                                    ArrayRef<OpFoldResult> tileSizes,
+                                    ArrayRef<int64_t> threadDimMapping = {});
+
 /// All indices returned by IndexOp should be invariant with respect to tiling.
 /// Therefore, if an operation is tiled, we have to transform the indices
 /// accordingly, i.e. offset them by the values of the corresponding induction

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d2dc66d610789..345e27742fbbd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -721,8 +721,13 @@ static void materializeConstants(OpBuilder &b, Location loc,
       actualValues.push_back(value);
       continue;
     }
-    constants.push_back(dialect->materializeConstant(b, ofr.get<Attribute>(),
-                                                     b.getIndexType(), loc));
+    // Since we are directly specifying `index` as the result type, we need to
+    // ensure the provided attribute is also an index type. Otherwise, the
+    // AffineDialect materializer will create invalid `arith.constant`
+    // operations if the provided Attribute is any other kind of integer.
+    constants.push_back(dialect->materializeConstant(
+        b, b.getIndexAttr(ofr.get<Attribute>().cast<IntegerAttr>().getInt()),
+        b.getIndexType(), loc));
     actualValues.push_back(constants.back()->getResult(0));
   }
 }

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index adab8da3d518b..070b1fc4eb821 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -909,12 +909,20 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
   IRRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   auto maybeThreadDimMappingAttr = getThreadDimMapping();
-  FailureOr<ForeachThreadTilingResult> tilingResult =
-      linalg::tileToForeachThreadOp(
-          rewriter, target, getAsOpFoldResult(getNumThreads()),
-          maybeThreadDimMappingAttr
-              ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
-              : ArrayRef<int64_t>{});
+  auto dimMapping =
+      llvm::to_vector(maybeThreadDimMappingAttr
+                          ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+                          : ArrayRef<int64_t>{});
+  Optional<ArrayAttr> numThreads = getNumThreads();
+  Optional<ArrayAttr> tileSizes = getTileSizes();
+  // num_threads/tile_sizes are mutually exclusive: fail if both or neither.
+  FailureOr<ForeachThreadTilingResult> tilingResult = failure();
+  if (numThreads && !tileSizes)
+    tilingResult = linalg::tileToForeachThreadOp(
+        rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
+  else if (tileSizes && !numThreads)
+    tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
+        rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
   if (failed(tilingResult))
     return emitDefaultSilenceableFailure(target);
   rewriter.replaceOp(target, tilingResult->tileOp->getResults());

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index a8112dbe50b84..6d97dfc6d84fb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -41,9 +41,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRAnalysis
   MLIRArithmeticDialect
   MLIRArithmeticTransforms
+  MLIRArithmeticUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRComplexDialect
+  MLIRDialectUtils
   MLIRFuncDialect
   MLIRFuncToLLVM
   MLIRFuncTransforms

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 0571ff5432afb..1dfaf69efa728 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
 #include <utility>
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -182,23 +183,43 @@ createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
 }
 
 /// Build an `affine_max` of all the `vals`.
-static Value buildMax(OpBuilder &b, Location loc, ValueRange vals) {
+static OpFoldResult buildMax(OpBuilder &b, Location loc,
+                             ArrayRef<OpFoldResult> vals) {
+  SmallVector<Value> args = getValueOrCreateConstantIndexOp(b, loc, vals);
   return b.createOrFold<AffineMaxOp>(
       loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
-      vals);
+      args);
 }
 
-/// Build an `affine_min` of all the `vals`.
-static Value buildMin(OpBuilder &b, Location loc, ValueRange vals) {
-  return b.createOrFold<AffineMinOp>(
-      loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
-      vals);
+/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
+/// provably less than `iterationSize`; returns false unless all are constant.
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+                                           OpFoldResult numThreads,
+                                           OpFoldResult iterationSize) {
+  Optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+  Optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
+  Optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
+  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
+    return false;
+  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
 }
 
-FailureOr<ForeachThreadTilingResult>
-linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
-                              ArrayRef<OpFoldResult> numThreads,
-                              ArrayRef<int64_t> threadDimMapping) {
+/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
+/// tiling is specified by the number of tiles/threads `numThreads` and the
+/// optional nominal tile sizes `nominalTileSizes`. If `nominalTileSizes` is
+/// not specified, it is derived from `numThreads` as `ceilDiv(dimSize[i],
+/// numThreads[i])`. If non-empty, the `threadDimMapping` is added as an
+/// attribute to the resulting `scf.foreach_thread`. A zero `numThreads[i]`
+/// means dimension `i` is not tiled, and can be thought of as tiling by the
+/// full size of data.
+/// It is the user's responsibility to ensure that `numThreads` is a valid
+/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
+/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
+/// assume that `tileSize[i] * (numThreads[i] - 1) <= dimSize[i]` holds.
+static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
+    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+    Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+    ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) {
   Location loc = op->getLoc();
   OpBuilder::InsertionGuard g(b);
   SmallVector<Range> loopRanges = op.getIterationDomain(b);
@@ -224,80 +245,128 @@ linalg::tileToForeachThreadOp(OpBuilder &b, TilingInterface op,
 
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
   Operation *tiledOp = nullptr;
+
+  // Create the ForeachThreadOp. We don't use the lambda body-builder
+  // version because we require the use of RewriterBase in the body, so we
+  // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, materializedNonZeroNumThreads, threadDimMapping,
-      [&](OpBuilder &b, Location loc, ValueRange threadIds) {
-        int64_t nLoops = loopRanges.size();
-        SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-        tiledOffsets.reserve(nLoops);
-        tiledSizes.reserve(nLoops);
-        for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops;
-             ++loopIdx) {
-          bool overflow = loopIdx >= numThreads.size();
-          bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
-          // Degenerate case: take the whole domain.
-          if (overflow || isZero) {
-            tiledOffsets.push_back(loopRanges[loopIdx].offset);
-            tiledSizes.push_back(loopRanges[loopIdx].size);
-            continue;
-          }
-
-          // Tiled case: compute the offset and size.
-          AffineExpr i, j, M, N, O;
-          bindDims(b.getContext(), i, j);
-          bindSymbols(b.getContext(), M, N, O);
-          Value size = loopRanges[loopIdx].size;
-          Value offset = loopRanges[loopIdx].offset;
-          Value threadId = threadIds[threadIdIdx];
-          // TODO: more aggressive foldings.
-          // Symbolic fixed max size per thread.
-          // TODO: floor + 0/1 depending on case for better load-balancing.
-          Value maxSizePerThread = b.createOrFold<AffineApplyOp>(
-              loc, M.ceilDiv(N),
-              ValueRange{size, materializedNonZeroNumThreads[threadIdIdx]});
-          // Dynamic offset shifted by threadId * maxSizePerThread.
-          Value offsetPerThread = b.createOrFold<AffineApplyOp>(
-              loc, i + j * M, ValueRange{offset, threadId, maxSizePerThread});
-          // Dynamic upper-bound depending on the threadId.
-          Value sizeMinusOffsetPerThread = b.createOrFold<AffineApplyOp>(
-              loc, -i + M, ValueRange{offsetPerThread, size});
-          Value tileSizePerThread = buildMin(
-              b, loc, ValueRange{sizeMinusOffsetPerThread, maxSizePerThread});
-          tiledOffsets.push_back(offsetPerThread);
-          // TODO: if tileSizePerThread <= 0 early exit.
-          tiledSizes.push_back(
-              buildMax(b, loc, ValueRange{zero, tileSizePerThread}));
-          ++threadIdIdx;
-        }
-
-        SmallVector<Operation *> tiledOps =
-            op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
-                                      /*tileDestOperands=*/true);
-        assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-        tiledOp = tiledOps.front();
-
-        auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
-        assert(tilingInterfaceOp &&
-               "Tiled op does not implement TilingInterface");
-
-        auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
-
-        // Create terminator with parallel subset insert operations.
-        auto performConcurrentlyOp = b.create<scf::PerformConcurrentlyOp>(loc);
-        OpBuilder::InsertionGuard g(b);
-        b.setInsertionPointToStart(performConcurrentlyOp.getBody());
-        for (auto it :
-             llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
-                       destOperands)) {
-          createMatchingParallelSubsetInsertOp(
-              b, loc,
-              cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
-              std::get<1>(it), std::get<2>(it));
-        }
-      });
+      loc, op->getResultTypes(), ValueRange(materializedNonZeroNumThreads),
+      threadDimMapping);
+
+  // Fill out the ForeachThreadOp body.
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  ValueRange threadIds = foreachThreadOp.getThreadIndices();
+  int64_t nLoops = loopRanges.size();
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  tiledOffsets.reserve(nLoops);
+  tiledSizes.reserve(nLoops);
+  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
+    bool overflow = loopIdx >= numThreads.size();
+    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+    // Degenerate case: take the whole domain.
+    if (overflow || isZero) {
+      tiledOffsets.push_back(loopRanges[loopIdx].offset);
+      tiledSizes.push_back(loopRanges[loopIdx].size);
+      continue;
+    }
+
+    // Tiled case: compute the offset and size.
+    AffineExpr i, j, M, N, O;
+    bindDims(b.getContext(), i, j);
+    bindSymbols(b.getContext(), M, N, O);
+    Value size = loopRanges[loopIdx].size;
+    Value offset = loopRanges[loopIdx].offset;
+    Value threadId = threadIds[threadIdIdx];
+    // Symbolic fixed max size per thread.
+    // TODO: floor + 0/1 depending on case for better load-balancing.
+    OpFoldResult tileSizePerThread =
+        nominalTileSizes.hasValue()
+            ? (*nominalTileSizes)[loopIdx]
+            : makeComposedFoldedAffineApply(
+                  b, loc, M.ceilDiv(N),
+                  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
+
+    // Dynamic offset shifted by threadId * maxSizePerThread.
+    OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
+        b, loc, i + j * M, {offset, threadId, tileSizePerThread});
+    // Dynamic upper-bound depending on the threadId.
+    OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
+        b, loc, i + j * M - N,
+        {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
+    if (!isConstantIntValue(residualTileSize, 0)) {
+      OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
+          b, loc, -i + M, {offsetPerThread, size});
+      tileSizePerThread = makeComposedFoldedAffineMin(
+          b, loc, AffineMap::getMultiDimIdentityMap(2, b.getContext()),
+          ArrayRef<OpFoldResult>{sizeMinusOffsetPerThread, tileSizePerThread});
+    }
+
+    tiledOffsets.push_back(offsetPerThread);
+    // TODO: if tileSizePerThread <= 0 early exit.
+    if (!omitTileOffsetBoundsCheck &&
+        !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
+                                        nonZeroNumThreads[threadIdIdx], size))
+      tileSizePerThread = buildMax(b, loc, {zero, tileSizePerThread});
+
+    tiledSizes.push_back(tileSizePerThread);
+    ++threadIdIdx;
+  }
+
+  SmallVector<Operation *> tiledOps =
+      op.getTiledImplementation(b, destOperands, tiledOffsets, tiledSizes,
+                                /*tileDestOperands=*/true);
+  assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+  tiledOp = tiledOps.front();
+
+  auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+  assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
+
+  auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
+
+  // Create terminator with parallel subset insert operations.
+  b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
+  for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
+                           destOperands)) {
+    createMatchingParallelSubsetInsertOp(
+        b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
+        std::get<1>(it), std::get<2>(it));
+  }
   return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
 }
 
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op,
+                              ArrayRef<OpFoldResult> numThreads,
+                              ArrayRef<int64_t> threadDimMapping) {
+  return tileToForeachThreadOpImpl(b, op, numThreads, /*nominalTileSizes=*/None,
+                                   threadDimMapping,
+                                   /*omitTileOffsetBoundsCheck=*/false);
+}
+
+FailureOr<ForeachThreadTilingResult>
+linalg::tileToForeachThreadOpUsingTileSizes(
+    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<int64_t> threadDimMapping) {
+  SmallVector<Range> loopRanges = op.getIterationDomain(b);
+  AffineExpr s0, s1;
+  bindSymbols(b.getContext(), s0, s1);
+  AffineExpr divExpr = s0.ceilDiv(s1);
+  // numThreads[i] = ceilDiv(size[i], tileSizes[i]); zero stays zero (untiled).
+  SmallVector<OpFoldResult> numThreads;
+  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
+    OpFoldResult tileSize = std::get<0>(it);
+    numThreads.push_back(isConstantIntValue(tileSize, 0)
+                             ? tileSize
+                             : makeComposedFoldedAffineApply(
+                                   b, op.getLoc(), divExpr,
+                                   {std::get<1>(it).size, tileSize}));
+  }
+  return tileToForeachThreadOpImpl(b, op, numThreads,
+                                   /*nominalTileSizes=*/tileSizes,
+                                   threadDimMapping,
+                                   /*omitTileOffsetBoundsCheck=*/true);
+}
+
 // Insert a tile `source` into the destination tensor `dest`. The position at
 // which the tile is inserted (as well as size of tile) is taken from a given
 // ExtractSliceOp `sliceOp`.

diff  --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
index 2529f4a50556f..ab97df3611571 100644
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s
 
 // Offset per thread:
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -37,7 +37,143 @@ module {
     transform.sequence %arg0 {
     ^bb1(%arg1: !pdl.operation):
       %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-      %1:2 = transform.structured.tile_to_foreach_thread_op %0 [10, 20] (mapped to dims [1, 0])
+      %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0])
     }
   }
 }
+
+// -----
+
+// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
+
+// CHECK-LABEL: matmul_static(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor
+func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {  
+  //  CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
+  //  CHECK-DAG: %[[c21:.+]] = arith.constant 21 : index
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c21]])
+  //      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV1]])
+  //      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
+  //  CHECK-NOT:   affine.min
+  //  CHECK-NOT:   affine.max
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  //      CHECK:   %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  //      CHECK:   %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  //      CHECK:   %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<100x200xf32>, tensor<200x300xf32>)
+                    outs(%C : tensor<100x300xf32>) -> (tensor<100x300xf32>)
+  return %0 : tensor<100x300xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 21]
+  }
+}
+
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+  //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  //      CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+  //      CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+  //      CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
+  //      CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
+  //      CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  //      CHECK:   %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   tensor.extract_slice %[[A]]
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+  //      CHECK:   tensor.extract_slice %[[B]]
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+  //      CHECK:   tensor.extract_slice %[[C]]
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):    
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20]
+  }
+}
+
+// -----
+
+// Tests that dimension 0 can eliminate affine.min/max, dimension 1 only needs affine.min.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 21)>
+
+// CHECK-LABEL: matmul_tile_size_static(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor
+func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
+  //  CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
+  //  CHECK-DAG: %[[c15:.+]] = arith.constant 15 :
+  //      CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[c10]], %[[c15]])
+  //      CHECK:   %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])  
+  //  CHECK-NOT:   affine.max
+  //  CHECK-NOT:   affine.min
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  //      CHECK:   %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  //      CHECK:   %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
+  //      CHECK:   %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+  //      CHECK:   %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+  //      CHECK:   %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
+  //      CHECK:   linalg.matmul
+  //      CHECK:   scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT:    tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<100x200xf32>, tensor<200x300xf32>)
+                    outs(%C : tensor<100x300xf32>) -> (tensor<100x300xf32>)
+  return %0 : tensor<100x300xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 21]
+  }
+}


        


More information about the Mlir-commits mailing list