[Mlir-commits] [mlir] f7fda6b - [mlir][linalg] Add extra parameter to tiling reduction to foreach_thread

Thomas Raoux llvmlistbot at llvm.org
Wed Dec 7 10:37:23 PST 2022


Author: Thomas Raoux
Date: 2022-12-07T18:37:05Z
New Revision: f7fda6ba4a7bca12bc7ef62d27895aa24482eda5

URL: https://github.com/llvm/llvm-project/commit/f7fda6ba4a7bca12bc7ef62d27895aa24482eda5
DIFF: https://github.com/llvm/llvm-project/commit/f7fda6ba4a7bca12bc7ef62d27895aa24482eda5.diff

LOG: [mlir][linalg] Add extra parameter to tiling reduction to foreach_thread

This adds a `tile_sizes` parameter; when it is used, the tiles are
cyclically distributed onto the threads of the scf.foreach_thread op.

Differential Revision: https://reviews.llvm.org/D139474
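
A usage sketch, adapted from the test added in this commit (a 0 entry
leaves the corresponding dimension untouched, in both num_threads and
tile_sizes):

  transform.sequence failures(propagate) {
  ^bb0(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    // Distribute the reduction dimension over 5 threads; each thread
    // cyclically walks tiles of size 3.
    %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
      { num_threads = [0, 5], tile_sizes = [0, 3] }
  }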

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 3ea0a66625776..f7b0c03ca2f07 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -751,6 +751,8 @@ def TileReductionUsingForeachThreadOp :
     All the partial reduction values are parallel-inserted to create a new
     tensor. After the loop, a merge operation is created to do a final
     reduction with the partial reduction tensor.
+    If an extra `tile_sizes` parameter is passed, the tiles are cyclically
+    distributed onto the threads of the `scf.foreach_thread` loop.
 
     #### Return modes
 
@@ -804,7 +806,8 @@ def TileReductionUsingForeachThreadOp :
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);

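To make the description concrete, with num_threads = [0, 5] the transform
rewrites a ?x? reduction into IR of roughly the following shape (a condensed
sketch assembled from the CHECK lines of transform-tile-reduction.mlir; maps,
dimension computations, and the tiled linalg.generic body are elided):

  // Accumulator holding one partial-reduction column per thread.
  %e = tensor.empty(%d0) : tensor<?x5xf32>
  %f = linalg.fill ins(%id : f32) outs(%e : tensor<?x5xf32>) -> tensor<?x5xf32>
  %l = scf.foreach_thread (%iv) in (%c5) shared_outs(%acc = %f) -> (tensor<?x5xf32>) {
    // Each thread reduces its chunk of the iteration space into its own
    // column of the accumulator.
    %et = tensor.extract_slice %acc[0, %iv] [%d0, 1] [1, 1]
        : tensor<?x5xf32> to tensor<?xf32>
    // ... tiled linalg.generic reducing into %et, producing %partial ...
    scf.foreach_thread.perform_concurrently {
      tensor.parallel_insert_slice %partial into %acc[0, %iv] [%d0, 1] [1, 1]
          : tensor<?xf32> into tensor<?x5xf32>
    }
  }
  // Final merge of the per-thread partial reductions into the original output.
  %r = linalg.generic ... ins(%l : tensor<?x5xf32>) outs(%out : tensor<?xf32>) ...
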
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a58c9dc23c1fc..d7603d2c3dd1b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -496,7 +496,8 @@ struct ForeachThreadReductionTilingResult {
 FailureOr<ForeachThreadReductionTilingResult>
 tileReductionUsingForeachThread(RewriterBase &b, PartialReductionOpInterface op,
                                 ArrayRef<OpFoldResult> numThreads,
-                                Optional<ArrayAttr> mapping);
+                                ArrayRef<OpFoldResult> tileSizes = {},
+                                Optional<ArrayAttr> mapping = llvm::None);
 
 /// All indices returned by IndexOp should be invariant with respect to
 /// tiling. Therefore, if an operation is tiled, we have to transform the

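Both new parameters are defaulted, so existing C++ callers and transform
scripts are unaffected. Omitting tile_sizes keeps the previous behavior in
which the reduction dimension is split into exactly num_threads tiles, one
per thread; a minimal invocation of that form would be:

  // No tile_sizes: one tile per thread, no cyclic distribution.
  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
    { num_threads = [0, 5] }
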
diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b46349874bd4f..9e94f101349a2 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1217,16 +1217,12 @@ transform::TileReductionUsingForeachThreadOp::applyToOne(
     transform::TransformState &state) {
   TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
-  SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
-  SmallVector<OpFoldResult> numThreadResults;
-  for (int64_t num : numThreads) {
-    numThreadResults.push_back(rewriter.getIndexAttr(num));
-  }
-
+  SmallVector<OpFoldResult> numThreads = getAsOpFoldResult(getNumThreads());
+  SmallVector<OpFoldResult> tileSizes = getAsOpFoldResult(getTileSizes());
   FailureOr<linalg::ForeachThreadReductionTilingResult> result =
       linalg::tileReductionUsingForeachThread(
           rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
-          numThreadResults, /*mapping=*/std::nullopt);
+          numThreads, tileSizes, /*mapping=*/std::nullopt);
 
   if (failed(result)) {
     results.assign(3, nullptr);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index cde33ce740e7d..8c34c42ea3ff9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -410,152 +411,6 @@ linalg::tileToForeachThreadOpUsingTileSizes(RewriterBase &b, TilingInterface op,
                                    /*omitTileOffsetBoundsCheck=*/true);
 }
 
-FailureOr<linalg::ForeachThreadReductionTilingResult>
-linalg::tileReductionUsingForeachThread(RewriterBase &b,
-                                        PartialReductionOpInterface op,
-                                        ArrayRef<OpFoldResult> numThreads,
-                                        Optional<ArrayAttr> mapping) {
-  Location loc = op.getLoc();
-  OpBuilder::InsertionGuard g(b);
-  // Ops implementing PartialReductionOpInterface are expected to implement
-  // TilingInterface.
-  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
-  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  if (op->getNumResults() != 1)
-    return b.notifyMatchFailure(
-        op, "don't support ops with multiple results for now");
-  SmallVector<utils::IteratorType> iterators =
-      tilingInterfaceOp.getLoopIteratorTypes();
-  SmallVector<unsigned> redDims;
-  cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
-  if (redDims.size() != 1)
-    return b.notifyMatchFailure(
-        op, "only support ops with one reduction dimension.");
-  int reductionDim = static_cast<int>(redDims.front());
-  // 1. Create the initial tensor value.
-  FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
-  if (failed(identityTensor))
-    return b.notifyMatchFailure(op,
-                                "cannot create a tensor of identity value.");
-
-  // Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-    return b.notifyMatchFailure(op, "failed to get destination tensors");
-
-  Operation *tiledOp = nullptr;
-
-  SmallVector<OpFoldResult> nonZeroNumThreads =
-      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 0);
-      }));
-  SmallVector<Value> materializedNonZeroNumThreads =
-      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-        return getValueOrCreateConstantIndexOp(b, loc, ofr);
-      }));
-
-  // 2. Create the ForeachThreadOp with an empty region.
-  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, identityTensor.value()->getResults(),
-      ValueRange(materializedNonZeroNumThreads), mapping);
-
-  // 3. Calculate the tile offsets and sizes.
-  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  calculateTileOffsetsAndSizes(
-      b, loc, foreachThreadOp, numThreads, iterationDomain,
-      /*omitTileOffsetBoundsCheck =*/false,
-      /*nominalTileSizes=*/std::nullopt, tiledOffsets, tiledSizes);
-
-  // 4. Clone the tileable op and update its destination operands to use the
-  // output bbArgs of the ForeachThreadOp.
-  ArrayRef<BlockArgument> destBbArgs =
-      foreachThreadOp.getOutputBlockArguments();
-  Operation *clonedOp = b.clone(*op.getOperation());
-  auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
-  for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
-    auto *it = llvm::find(dest, outOperand->get());
-    assert(it != dest.end() && "dest operand not found in dest");
-    unsigned destNum = std::distance(dest.begin(), it);
-    SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
-    SmallVector<OpFoldResult> sizes = tiledSizes;
-    sizes[reductionDim] = b.getIndexAttr(1);
-    outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
-    // TODO: use SubsetExtractOpInterface once it is available.
-    Value partial = b.create<tensor::ExtractSliceOp>(
-        loc, outOperand->get().getType().cast<RankedTensorType>(),
-        destBbArgs[destNum], outOffsets, sizes, strides);
-    outOperand->set(partial);
-  }
-
-  // 5. Tile the cloned op and delete the clone.
-  SmallVector<Operation *> tiledOps =
-      cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                             tiledSizes);
-  b.eraseOp(clonedOp);
-  assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-  tiledOp = tiledOps.front();
-
-  // 6. Insert the partial reductions back into a new tensor.
-  auto tiledInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
-  assert(tiledInterfaceOp && "Tiled op does not implement TilingInterface");
-  OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
-  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tiledInterfaceOp->getResults(), destBbArgs)) {
-    b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(tilingInterfaceOp.getResultTilePosition(
-            b, std::get<0>(it), tiledOffsets, tiledSizes, resultOffsets,
-            resultSizes)))
-      return op->emitOpError("output offsets couldn't be calculated");
-    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
-    int64_t offIdx = 0;
-    int64_t sizeIdx = 0;
-    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
-      if (i == reductionDim) {
-        resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
-        resultSizesRank.push_back(b.getIndexAttr(1));
-        continue;
-      }
-      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
-      resultSizesRank.push_back(resultSizes[sizeIdx++]);
-    }
-
-    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
-                                      b.getIndexAttr(1));
-    b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
-    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
-                                            std::get<2>(it), resultOffsetsRank,
-                                            resultSizesRank, strides);
-  }
-  // 7. Merge the partial reductions.
-  b.setInsertionPointAfter(foreachThreadOp);
-  Operation *mergeOp =
-      op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
-  b.replaceOp(op, mergeOp->getResults());
-  ForeachThreadReductionTilingResult results;
-  results.initialOp = identityTensor.value();
-  results.loops = foreachThreadOp;
-  results.parallelTiledOp = tiledOp;
-  results.mergeOp = mergeOp;
-  return results;
-}
-
-// Insert a tile `source` into the destination tensor `dest`. The position at
-// which the tile is inserted (as well as the size of the tile) is taken from
-// a given ExtractSliceOp `sliceOp`.
-static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
-                                   tensor::ExtractSliceOp sliceOp, Value source,
-                                   Value dest) {
-  return b.create<tensor::InsertSliceOp>(
-      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
-      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
-      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
-}
-
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
@@ -707,6 +562,165 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
       res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
 }
 
+FailureOr<linalg::ForeachThreadReductionTilingResult>
+linalg::tileReductionUsingForeachThread(RewriterBase &b,
+                                        PartialReductionOpInterface op,
+                                        ArrayRef<OpFoldResult> numThreads,
+                                        ArrayRef<OpFoldResult> tileSizes,
+                                        Optional<ArrayAttr> mapping) {
+  Location loc = op.getLoc();
+  OpBuilder::InsertionGuard g(b);
+  // Ops implementing PartialReductionOpInterface are expected to implement
+  // TilingInterface.
+  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+  if (op->getNumResults() != 1)
+    return b.notifyMatchFailure(
+        op, "don't support ops with multiple results for now");
+  SmallVector<utils::IteratorType> iterators =
+      tilingInterfaceOp.getLoopIteratorTypes();
+  SmallVector<unsigned> redDims;
+  cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
+  if (redDims.size() != 1)
+    return b.notifyMatchFailure(
+        op, "only support ops with one reduction dimension.");
+  if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
+    return b.notifyMatchFailure(op, "if tile sizes are present, they must have "
+                                    "as many elements as the number of threads");
+  int reductionDim = static_cast<int>(redDims.front());
+  // 1. Create the initial tensor value.
+  FailureOr<Operation *> identityTensor =
+      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
+                                                  reductionDim);
+  if (failed(identityTensor))
+    return b.notifyMatchFailure(op,
+                                "cannot create a tensor of identity value.");
+
+  // Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+    return b.notifyMatchFailure(op, "failed to get destination tensors");
+
+  Operation *tiledOp = nullptr;
+
+  SmallVector<OpFoldResult> nonZeroNumThreads =
+      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 0);
+      }));
+  SmallVector<Value> materializedNonZeroNumThreads =
+      getAsValues(b, loc, nonZeroNumThreads);
+
+  // 2. Create the ForeachThreadOp with an empty region.
+  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+      loc, identityTensor.value()->getResults(),
+      ValueRange(materializedNonZeroNumThreads), mapping);
+
+  // 3. Calculate the tile offsets and sizes.
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  calculateTileOffsetsAndSizes(
+      b, loc, foreachThreadOp, numThreads, iterationDomain,
+      /*omitTileOffsetBoundsCheck =*/false,
+      /*nominalTileSizes=*/std::nullopt, tiledOffsets, tiledSizes);
+
+  // 4. Clone the tileable op and update its destination operands to use the
+  // output bbArgs of the ForeachThreadOp.
+  ArrayRef<BlockArgument> destBbArgs =
+      foreachThreadOp.getOutputBlockArguments();
+  Operation *clonedOp = b.clone(*op.getOperation());
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
+  for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
+    auto *it = llvm::find(dest, initOperand->get());
+    assert(it != dest.end() && "dest operand not found in dest");
+    unsigned destNum = std::distance(dest.begin(), it);
+    SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
+    SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes = tiledSizes;
+    sizes[reductionDim] = b.getIndexAttr(1);
+    outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+    // TODO: use SubsetExtractOpInterface once it is available.
+    Value partial = b.create<tensor::ExtractSliceOp>(
+        loc, initOperand->get().getType().cast<RankedTensorType>(),
+        destBbArgs[destNum], outOffsets, sizes, strides);
+    initOperand->set(partial);
+  }
+  b.setInsertionPoint(clonedOp);
+
+  // 5. Tile the cloned op and delete the clone.
+  if (tileSizes.empty()) {
+    SmallVector<Operation *> tiledOps =
+        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+                                                               tiledSizes);
+    assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+    tiledOp = tiledOps.front();
+  } else {
+    LinalgTilingOptions options;
+    auto tiled = tileLinalgOpImpl<scf::ForOp>(b, cast<LinalgOp>(clonedOp),
+                                              tileSizes, options);
+    SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+    mapLoopToProcessorIds(cast<scf::ForOp>(tiled->loops.back()), ids,
+                          materializedNonZeroNumThreads);
+    assert(tiled->loops.size() == 1 && "expected a single produced loop");
+    tiledOp = tiled->loops.front();
+  }
+  b.eraseOp(clonedOp);
+
+  // 6. Insert the partial reductions back into a new tensor.
+  b.setInsertionPointAfter(tiledOp);
+  OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
+  for (auto [index, result, bbArg] :
+       llvm::zip(llvm::seq<unsigned>(0, dest.size()), tiledOp->getResults(),
+                 destBbArgs)) {
+    b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(tilingInterfaceOp.getResultTilePosition(
+            b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
+      return op->emitOpError("output offsets couldn't be calculated");
+    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
+    int64_t offIdx = 0;
+    int64_t sizeIdx = 0;
+    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
+      if (i == reductionDim) {
+        resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
+        resultSizesRank.push_back(b.getIndexAttr(1));
+        continue;
+      }
+      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
+      resultSizesRank.push_back(resultSizes[sizeIdx++]);
+    }
+
+    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
+                                      b.getIndexAttr(1));
+    b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
+    b.create<tensor::ParallelInsertSliceOp>(
+        loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
+  }
+  // 7. Merge the partial reductions.
+  b.setInsertionPointAfter(foreachThreadOp);
+  Operation *mergeOp =
+      op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
+  b.replaceOp(op, mergeOp->getResults());
+  ForeachThreadReductionTilingResult results;
+  results.initialOp = identityTensor.value();
+  results.loops = foreachThreadOp;
+  results.parallelTiledOp = tiledOp;
+  results.mergeOp = mergeOp;
+  return results;
+}
+
+// Insert a tile `source` into the destination tensor `dest`. The position at
+// which the tile is inserted (as well as the size of the tile) is taken from
+// a given ExtractSliceOp `sliceOp`.
+static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
+                                   tensor::ExtractSliceOp sliceOp, Value source,
+                                   Value dest) {
+  return b.create<tensor::InsertSliceOp>(
+      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
+      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
+      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
+}
+
 template <typename LoopTy>
 FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
     RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {

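The cyclic distribution in step 5 comes from tiling the cloned op with an
scf.for loop and then remapping it with mapLoopToProcessorIds: thread i
starts at i * tile_size and strides by num_threads * tile_size. A condensed
sketch of the resulting loop for num_threads = [0, 5] and tile_sizes =
[0, 3], following the CHECK lines of the new test (the loop body, which
reduces each chunk into the accumulator, is elided):

  %c15 = arith.constant 15 : index  // step = num_threads (5) * tile_size (3)
  %lb = affine.apply affine_map<()[s0] -> (s0 * 3)>()[%iv]
  %carry = scf.for %iv1 = %lb to %d1 step %c15
      iter_args(%acc = %et) -> (tensor<?xf32>) {
    // ... reduce the chunk starting at %iv1 (at most 3 elements of the
    // reduction dimension) into %acc ...
    scf.yield %acc : tensor<?xf32>  // placeholder yield for this sketch
  }
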
diff  --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 82925906e78c0..ad2dc0a4124d8 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -126,9 +126,9 @@ transform.sequence failures(propagate) {
 //     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
-//     CHECK:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
-//     CHECK:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
-//     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+// CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+// CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+// CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
 //     CHECK:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
 //     CHECK:     %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 //     CHECK:     %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor<?xf32> to tensor<?xf32>
@@ -180,9 +180,9 @@ transform.sequence failures(propagate) {
 //     CHECK:   %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
-//     CHECK:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
-//     CHECK:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
-//     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
+// CHECK-DAG:     %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+// CHECK-DAG:     %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+// CHECK-DAG:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
 //     CHECK:     %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
 //     CHECK:     %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 //     CHECK:     %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
@@ -197,3 +197,68 @@ transform.sequence failures(propagate) {
 //     CHECK:     linalg.yield
 //     CHECK:   } -> tensor<?x?xf32>
 //     CHECK:   return %[[R]] : tensor<?x?xf32>
+
+// -----
+
+func.func @reduction_tile_parallel_cyclic_dist(
+  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%arg0 : tensor<?x?xf32>)
+   outs(%out : tensor<?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %1 = arith.mulf %arg7, %arg7 : f32
+      %2 = arith.addf %1, %arg9 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?xf32>
+  return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5], tile_sizes = [0, 3] }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0)>
+
+//     CHECK: func @reduction_tile_parallel_cyclic_dist(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
+// CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
+// CHECK-DAG:   %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG:   %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
+//     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+//     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
+//     CHECK:   %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+//     CHECK:     %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+//     CHECK:     %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+//     CHECK:     %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
+//     CHECK:     %[[CARRY:.+]] = scf.for %[[IV1:.+]] = %[[LB]] to %[[D1]] step %[[C15]] iter_args(%[[ACC:.+]] = %[[ET]]) -> (tensor<?xf32>) {
+//     CHECK:       %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]]
+//     CHECK:       %[[D3:.+]] = tensor.dim %[[ACC]], %[[C0]] : tensor<?xf32>
+//     CHECK:       %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [%[[D0]], %[[TS0]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+//     CHECK:       %[[TEMPEXT:.+]] = tensor.extract_slice %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> to tensor<?xf32>
+//     CHECK:       %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
+//     CHECK:         arith.mulf
+//     CHECK:         arith.addf
+//     CHECK:         linalg.yield
+//     CHECK:       } -> tensor<?xf32>
+//     CHECK:       %[[INS:.+]] = tensor.insert_slice %[[PARTIAL]] into %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> into tensor<?xf32>
+//     CHECK:       scf.yield %[[INS]] : tensor<?xf32>
+//     CHECK:     }
+//     CHECK:     scf.foreach_thread.perform_concurrently {
+//     CHECK:       tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
+//     CHECK:     }
+//     CHECK:   }
+//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:     arith.addf
+//     CHECK:     linalg.yield
+//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   return %[[R]] : tensor<?xf32>
