[Mlir-commits] [mlir] 99833cd - [mlir][linalg] Add reduction tiling using scf.foreachthread

Thomas Raoux llvmlistbot at llvm.org
Mon Nov 14 10:06:06 PST 2022


Author: Thomas Raoux
Date: 2022-11-14T18:05:40Z
New Revision: 99833cd8188c1ddbf8c76f4683813060553af79a

URL: https://github.com/llvm/llvm-project/commit/99833cd8188c1ddbf8c76f4683813060553af79a
DIFF: https://github.com/llvm/llvm-project/commit/99833cd8188c1ddbf8c76f4683813060553af79a.diff

LOG: [mlir][linalg] Add reduction tiling using scf.foreachthread

This adds a transformation that tiles reduction operations into partial
reductions using `scf.foreach_thread`. It uses
PartialReductionOpInterface to create the operation that merges the
partial tiles.

Differential Revision: https://reviews.llvm.org/D137912

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 4cfec70b0c07e..d6fae79ecc9b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -626,7 +626,6 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
   }];
 }
 
-
 def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformEachOpTrait, TransformOpInterface]> {
@@ -714,6 +713,89 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
   }];
 }
 
+def TileReductionUsingForeachThreadOp : 
+  Op<Transform_Dialect, "structured.tile_reduction_using_foreach_thread",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Tile a PartialReductionOpInterface op to a tiled `scf.foreach_thread` doing
+    partial reduction.
+
+    This transformation tiles the `target` along the reduction dimensions. It
+    creates a tensor initialized with the identity value. Then it creates an
+    `scf.foreach_thread` loop with the number of threads given by `num_threads`.
+    The op is tiled with a tile size equal to `ceildiv(size, num_threads)`.
+    All the partial reduction values are inserted in parallel to create a new
+    tensor. After the loop, a merge operation is created to do a final reduction
+    over the tensor of partial reductions.
+    
+    #### Return modes
+
+    The 3 returned handles point to:
+      - the fill op used to initialize the neutral element, 
+      - the parallel tiled op and 
+      - the result-combining op.
+
+    #### Example:
+    
+    ```
+      %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                              affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%out : tensor<?xf32>) {
+        ^bb0(%arg7: f32, %arg9: f32):
+        %1 = arith.addf %arg7, %arg9 : f32
+        linalg.yield %1 : f32
+      } -> tensor<?xf32>
+      return %red : tensor<?xf32>
+    ```
+  
+    is transformed into:
+    
+    ```
+      %0 = tensor.empty(%dim_1) : tensor<?x5xf32>
+      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x5xf32>) -> tensor<?x5xf32>
+      %2 = scf.foreach_thread (%arg2) in (%c5) shared_outs(%arg3 = %1) -> (tensor<?x5xf32>) {
+        %4 = affine.min #map(%arg2)[%dim_0]
+        %5 = affine.max #map1(%4)
+        %extracted_slice = tensor.extract_slice %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+        %6 = affine.apply #map2(%arg2)[%dim_0]
+        %extracted_slice_2 = tensor.extract_slice %arg0[0, %6] [%dim, %5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+        %extracted_slice_3 = tensor.extract_slice %extracted_slice[0] [%dim] [1] : tensor<?xf32> to tensor<?xf32>
+        %7 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2 : tensor<?x?xf32>) outs(%extracted_slice_3 : tensor<?xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %9 = arith.addf %in, %out : f32
+          linalg.yield %9 : f32
+        } -> tensor<?xf32>
+        scf.foreach_thread.perform_concurrently {
+          tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
+        }
+      } {thread_dim_mapping = []}
+      %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor<?x5xf32>) outs(%arg1 : tensor<?xf32>) {
+      ^bb0(%in: f32, %out: f32):
+        %4 = arith.addf %in, %out : f32
+        linalg.yield %4 : f32
+      } -> tensor<?xf32>
+    ```
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads);
+  let results = (outs PDL_Operation:$fill_op,
+                      PDL_Operation:$split_linalg_op,
+                      PDL_Operation:$combining_linalg_op);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::linalg::LinalgOp target, 
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results, 
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 758386fbc5bc1..dcdc532312460 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -445,6 +445,47 @@ tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
                                     ArrayRef<OpFoldResult> tileSizes,
                                     Optional<ArrayAttr> mapping);
 
+/// Transformation information returned after reduction tiling.
+struct ForeachThreadReductionTilingResult {
+  /// The partial reduction tiled op generated.
+  Operation *parallelTiledOp;
+  /// The final reduction operation merging all the partial reductions.
+  Operation *mergeOp;
+  /// The op initializing the tensor used for partial reductions.
+  Operation *initialOp;
+  /// The `scf.foreach_thread` operation that iterates over the tiles.
+  scf::ForeachThreadOp loops;
+};
+
+/// Method to tile a reduction to parallel iterations computing partial
+/// reductions. After the loop, all the partial reductions are merged into a
+/// final reduction. For example, the following sequence
+///
+/// ```mlir
+/// %0 = linalg.generic %in ["parallel", "reduction"]
+///   : tensor<7x9xf32> -> tensor<7xf32>
+/// ```
+///
+/// is transformed into:
+///
+/// ```mlir
+/// %0 = linalg.fill ... : tensor<7x4xf32>
+/// %1 = scf.foreach_thread (%iv) in (%c4) shared_outs(%arg0 = %0)
+///   -> (tensor<7x4xf32>) {
+///   %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> to tensor<7xf32>
+///   %3 = tensor.extract_slice %in : tensor<7x9xf32> to tensor<7x?xf32>
+///   %4 = linalg.generic %2, %3 ["parallel", "reduction"]
+///     : tensor<7x?xf32> -> tensor<7xf32>
+///   %5 = tensor.insert_slice %4, %arg0[0, %iv] : tensor<7x4xf32>
+/// }
+/// %6 = linalg.generic %1 ["parallel", "reduction"]
+///   : tensor<7x4xf32> -> tensor<7xf32>
+/// ```
+FailureOr<ForeachThreadReductionTilingResult>
+tileReductionUsingForeachThread(RewriterBase &b, PartialReductionOpInterface op,
+                                ArrayRef<OpFoldResult> numThreads,
+                                Optional<ArrayAttr> mapping);
+
 /// All indices returned by IndexOp should be invariant with respect to
 /// tiling. Therefore, if an operation is tiled, we have to transform the
 /// indices accordingly, i.e. offset them by the values of the corresponding

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 960fabee05b59..ff8549253cff5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1165,6 +1165,39 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
   return DiagnosedSilenceableFailure(success());
 }
 
+//===----------------------------------------------------------------------===//
+// TileReductionUsingForeachThreadOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::TileReductionUsingForeachThreadOp::applyToOne(
+    linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
+  SmallVector<OpFoldResult> numThreadResults;
+  for (int64_t num : numThreads) {
+    numThreadResults.push_back(rewriter.getIndexAttr(num));
+  }
+
+  FailureOr<linalg::ForeachThreadReductionTilingResult> result =
+      linalg::tileReductionUsingForeachThread(
+          rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
+          numThreadResults, /*mapping=*/llvm::None);
+
+  if (failed(result)) {
+    results.assign(3, nullptr);
+    Diagnostic diag(target->getLoc(), DiagnosticSeverity::Remark);
+    diag << "could not tile reduction in target.";
+    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+  }
+  results.push_back(result->initialOp);
+  results.push_back(result->parallelTiledOp);
+  results.push_back(result->mergeOp);
+  return DiagnosedSilenceableFailure(success());
+}
+
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 284109e6bf3f2..a63ddc496117c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -211,58 +211,21 @@ static OpFoldResult buildMin(OpBuilder &b, Location loc,
       vals);
 }
 
-/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
-/// tiling is specified by the number of tiles/threads `numThreads` and the
-/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
-/// not specified, then  it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `mapping` is added as an
-/// attribute to the resulting `scf.foreach_thread`. A zero tile sizes indicate
-/// that the dimension is not tiled, and can be thought of as tiling by the full
-/// size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
-/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
-static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
-    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
+/// number of threads.
+static void calculateTileOffsetsAndSizes(
+    RewriterBase &b, Location loc, scf::ForeachThreadOp foreachThreadOp,
+    ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
+    bool omitTileOffsetBoundsCheck,
     Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-    Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
-  Location loc = op->getLoc();
-  OpBuilder::InsertionGuard g(b);
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  if (loopRanges.empty())
-    return op->emitOpError("expected non-empty loop ranges");
-  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
-  if (llvm::any_of(loopRanges, hasStrideOne))
-    return op->emitOpError("only stride-1 supported atm");
-
-  // Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-    return op->emitOpError("failed to get destination tensors");
-
+    SmallVector<OpFoldResult> &tiledOffsets,
+    SmallVector<OpFoldResult> &tiledSizes) {
+  ValueRange threadIds = foreachThreadOp.getThreadIndices();
   SmallVector<OpFoldResult> nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
         return !isConstantIntValue(ofr, 0);
       }));
-  SmallVector<Value> materializedNonZeroNumThreads =
-      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-        return getValueOrCreateConstantIndexOp(b, loc, ofr);
-      }));
-
-  Operation *tiledOp = nullptr;
-
-  // Create the ForeachThreadOp. We don't use the lambda body-builder
-  // version because we require the use of RewriterBase in the body, so we
-  // manually move the insertion point to the body below.
-  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
-
-  // Fill out the ForeachThreadOp body.
-  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
-  ValueRange threadIds = foreachThreadOp.getThreadIndices();
   int64_t nLoops = loopRanges.size();
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
   tiledOffsets.reserve(nLoops);
   tiledSizes.reserve(nLoops);
   for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
@@ -316,6 +279,61 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
     tiledSizes.push_back(tileSizePerThread);
     ++threadIdIdx;
   }
+}
+
+/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The
+/// tiling is specified by the number of tiles/threads `numThreads` and the
+/// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
+/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
+/// numThreads[i])`. If non-empty, the `mapping` is added as an
+/// attribute to the resulting `scf.foreach_thread`. A zero tile size indicates
+/// that the dimension is not tiled, and can be thought of as tiling by the full
+/// size of data.
+/// It is the user's responsibility to ensure that `numThreads` is a valid
+/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
+/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will
+/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
+static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
+    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
+    Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
+    Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(b);
+  SmallVector<Range> loopRanges = op.getIterationDomain(b);
+  if (loopRanges.empty())
+    return op->emitOpError("expected non-empty loop ranges");
+  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
+  if (llvm::any_of(loopRanges, hasStrideOne))
+    return op->emitOpError("only stride-1 supported atm");
+
+  // Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+    return op->emitOpError("failed to get destination tensors");
+
+  SmallVector<OpFoldResult> nonZeroNumThreads =
+      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 0);
+      }));
+  SmallVector<Value> materializedNonZeroNumThreads =
+      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
+        return getValueOrCreateConstantIndexOp(b, loc, ofr);
+      }));
+
+  Operation *tiledOp = nullptr;
+
+  // Create the ForeachThreadOp. We don't use the lambda body-builder
+  // version because we require the use of RewriterBase in the body, so we
+  // manually move the insertion point to the body below.
+  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+      loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
+
+  // Fill out the ForeachThreadOp body.
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  calculateTileOffsetsAndSizes(b, loc, foreachThreadOp, numThreads, loopRanges,
+                               omitTileOffsetBoundsCheck, nominalTileSizes,
+                               tiledOffsets, tiledSizes);
 
   // Clone the tileable op and update its destination operands to use the output
   // bbArgs of the ForeachThreadOp.
@@ -392,6 +410,140 @@ linalg::tileToForeachThreadOpUsingTileSizes(RewriterBase &b, TilingInterface op,
                                    /*omitTileOffsetBoundsCheck=*/true);
 }
 
+FailureOr<linalg::ForeachThreadReductionTilingResult>
+linalg::tileReductionUsingForeachThread(RewriterBase &b,
+                                        PartialReductionOpInterface op,
+                                        ArrayRef<OpFoldResult> numThreads,
+                                        Optional<ArrayAttr> mapping) {
+  Location loc = op.getLoc();
+  OpBuilder::InsertionGuard g(b);
+  // Ops implementing PartialReductionOpInterface are expected to implement
+  // TilingInterface.
+  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+  if (op->getNumResults() != 1)
+    return b.notifyMatchFailure(
+        op, "don't support ops with multiple results for now");
+  SmallVector<utils::IteratorType> iterators =
+      tilingInterfaceOp.getLoopIteratorTypes();
+  SmallVector<unsigned> redDims;
+  cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
+  if (redDims.size() != 1)
+    return b.notifyMatchFailure(
+        op, "only support ops with one reduction dimension.");
+  int reductionDim = static_cast<int>(redDims.front());
+  // 1. Create the initial tensor value.
+  FailureOr<Operation *> identityTensor =
+      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
+                                                  reductionDim);
+  if (failed(identityTensor))
+    return b.notifyMatchFailure(op,
+                                "cannot create a tensor of identity value.");
+
+  // Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+    return b.notifyMatchFailure(op, "failed to get destination tensors");
+
+  Operation *tiledOp = nullptr;
+
+  SmallVector<OpFoldResult> nonZeroNumThreads =
+      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 0);
+      }));
+  SmallVector<Value> materializedNonZeroNumThreads =
+      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
+        return getValueOrCreateConstantIndexOp(b, loc, ofr);
+      }));
+
+  // 2. Create the ForeachThreadOp with an empty region.
+  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+      loc, identityTensor.value()->getResults(),
+      ValueRange(materializedNonZeroNumThreads), mapping);
+
+  // 3. calculate the tile offsets and sizes.
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  calculateTileOffsetsAndSizes(
+      b, loc, foreachThreadOp, numThreads, iterationDomain,
+      /*omitTileOffsetBoundsCheck =*/false,
+      /*nominalTileSizes=*/llvm::None, tiledOffsets, tiledSizes);
+
+  // 4. Clone the tileable op and update its destination operands to use the
+  // output bbArgs of the ForeachThreadOp.
+  ArrayRef<BlockArgument> destBbArgs =
+      foreachThreadOp.getOutputBlockArguments();
+  Operation *clonedOp = b.clone(*op.getOperation());
+  auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
+  for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
+    auto *it = llvm::find(dest, outOperand->get());
+    assert(it != dest.end() && "dest operand not found in dest");
+    unsigned destNum = std::distance(dest.begin(), it);
+    SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
+    SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes = tiledSizes;
+    sizes[reductionDim] = b.getIndexAttr(1);
+    outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+    // TODO: use SubsetExtractOpInterface once it is available.
+    Value patial = b.create<tensor::ExtractSliceOp>(
+        loc, outOperand->get().getType().cast<RankedTensorType>(),
+        destBbArgs[destNum], outOffsets, sizes, strides);
+    outOperand->set(patial);
+  }
+
+  // 5. Tile the cloned op and delete the clone.
+  SmallVector<Operation *> tiledOps =
+      cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+                                                             tiledSizes);
+  b.eraseOp(clonedOp);
+  assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+  tiledOp = tiledOps.front();
+
+  // 6. Insert the partial reductions back into a new tensor.
+  auto tiledInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
+  assert(tiledInterfaceOp && "Tiled op does not implement TilingInterface");
+  OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
+  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
+                           tiledInterfaceOp->getResults(), destBbArgs)) {
+    b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(tilingInterfaceOp.getResultTilePosition(
+            b, std::get<0>(it), tiledOffsets, tiledSizes, resultOffsets,
+            resultSizes)))
+      return op->emitOpError("output offsets couldn't be calculated");
+    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
+    int64_t offIdx = 0;
+    int64_t sizeIdx = 0;
+    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
+      if (i == reductionDim) {
+        resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
+        resultSizesRank.push_back(b.getIndexAttr(1));
+        continue;
+      }
+      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
+      resultSizesRank.push_back(resultSizes[sizeIdx++]);
+    }
+
+    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
+                                      b.getIndexAttr(1));
+    b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
+    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
+                                            std::get<2>(it), resultOffsetsRank,
+                                            resultSizesRank, strides);
+  }
+  // 7. Merge the partial reductions.
+  b.setInsertionPointAfter(foreachThreadOp);
+  Operation *mergeOp =
+      op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
+  b.replaceOp(op, mergeOp->getResults());
+  ForeachThreadReductionTilingResult results;
+  results.initialOp = identityTensor.value();
+  results.loops = foreachThreadOp;
+  results.parallelTiledOp = tiledOp;
+  results.mergeOp = mergeOp;
+  return results;
+}
+
 // Insert a tile `source` into the destination tensor `dest`. The position at
 // which the tile is inserted (as well as size of tile) is taken from a given
 // ExtractSliceOp `sliceOp`.

diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index dad2f8476d1ff..131c7a6193bde 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize -cse | FileCheck %s
 
 func.func @reduction_tile(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
   %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -86,3 +86,114 @@ transform.sequence failures(propagate) {
 //     CHECK:   }
 //     CHECK:   linalg.generic 
 //     CHECK:   return
+
+// -----
+
+func.func @reduction_tile_parallel(
+  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%arg0 : tensor<?x?xf32>)
+   outs(%out : tensor<?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %1 = arith.mulf %arg7, %arg7 : f32
+      %2 = arith.addf %1, %arg9 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?xf32>
+  return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1) -> (d0)>
+//     CHECK: func @reduction_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
+// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
+//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+//     CHECK:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+//     CHECK:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+//     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+//     CHECK:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
+//     CHECK:     %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor<?xf32> to tensor<?xf32>
+//     CHECK:     %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
+//     CHECK:       arith.mulf
+//     CHECK:       arith.addf
+//     CHECK:       linalg.yield
+//     CHECK:     } -> tensor<?xf32>
+//     CHECK:     scf.foreach_thread.perform_concurrently {
+//     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
+//     CHECK:     }
+//     CHECK:   }
+//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:     arith.addf
+//     CHECK:     linalg.yield
+//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   return %[[R]] : tensor<?xf32>
+
+// -----
+
+func.func @matmul_tile_parallel(
+  %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %out: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %matmul = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %matmul : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 0, 5] }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//     CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
+// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
+//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
+//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
+//     CHECK:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+//     CHECK:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+//     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
+//     CHECK:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
+//     CHECK:     %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0, 0] [%[[D0]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+//     CHECK:     scf.foreach_thread.perform_concurrently {
+//     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
+//     CHECK:     }
+//     CHECK:   }
+//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) {
+//     CHECK:     arith.addf
+//     CHECK:     linalg.yield
+//     CHECK:   } -> tensor<?x?xf32>
+//     CHECK:   return %[[R]] : tensor<?x?xf32>


        


More information about the Mlir-commits mailing list