[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 21 08:31:33 PDT 2024


https://github.com/muneebkhan85 updated https://github.com/llvm/llvm-project/pull/82792

>From 37f4d823597a060707a0cb85af88d1df87b120b6 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:15:18 +0800
Subject: [PATCH 1/3] [MLIR] Add continuous tiling to Transform dialect

Add continuous tiling op structured.continuous_tile_sizes
to the transform dialect that returns as result a list of
exponentially diminishing tile sizes and a list of split
points to do a multiway split of the target linalg op along
the specified dimension.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  46 ++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  20 +++
 .../TransformOps/LinalgTransformOps.cpp       | 151 ++++++++++++++++++
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 133 +++++++++++++++
 4 files changed, 350 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5585ba27fdad8..6e6bfab8d179e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1819,6 +1819,52 @@ def TileReductionUsingForallOp :
 
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    This transform takes a linalg op as target and a dimension and target size
+    as attributes to generate a list of (1) tile sizes: the (positive)
+    `target_size` followed by exponentially diminishing powers of 2; and (2)
+    the corresponding chunk sizes to split the op into along that dimension.
+
+    For example, for `target_size` 9, and `dimension` 0 for the following
+    linalg op as target
+
+    ```
+      %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                      outs(%arg2: tensor<25x25xf32>)
+    ```
+
+    the first result `tile_sizes` will be a list of diminishing tile sizes
+    9, 4, 2, 1; and the second result will be a list of chunk sizes
+    18, 4, 2, 1 that the corresponding dimension should be split into.
+
+    After the linalg has been split along the given dimension (for example using
+    multiway split), each chunk can be tiled with the corresponding tile size in
+    the `tile_sizes` list generated as a result of this op.
+
+    Specifying the output type as !transform.param<i64> will cause `tile_sizes`
+    and `split_points` to be computed statically and not dynamically.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$target_size);
+  let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
+                      TransformAnyParamTypeOrAnyHandle:$split_points);
+  let hasVerifier = 1;
+  let assemblyFormat =
+    "$target attr-dict `:` custom<ContinuousTileSizeTypes>("
+    "type($target), type($tile_sizes), type($split_points))";
+
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5ecf84fa9c701..f60ff85d4ea6d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
   /// Number of tiles associated with each size.
   T lowTripCount, highTripCount;
 };
+
+template <typename T>
+struct ContinuousTileSizeSpecificationBase {
+  /// Tile sizes; `tileSizes[i]` is paired with the trip count `tripCounts[i]`.
+  SmallVector<T> tileSizes;
+  /// Number of tiles associated with each size.
+  SmallVector<T> tripCounts;
+};
+
 } // namespace detail
 
 /// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
 struct StaticMultiSizeSpecification
     : public detail::MultiSizeSpecificationBase<int64_t> {};
 
+/// Continuous tile sizes as IR values (dynamic) or as integers (static).
+struct ContinuousTileSizeSpecification
+    : public detail::ContinuousTileSizeSpecificationBase<Value> {};
+struct StaticContinuousTileSizeSpecification
+    : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
 /// Emits the IR computing the multi-sized tiling specification with two tile
 /// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
 /// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,20 @@ FailureOr<StaticMultiSizeSpecification>
 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
+/// Computes continuous tile sizes along `dimension` of `op`, which must have
+/// a static shape: `targetSize` followed by diminishing powers of two, each
+/// with its trip count. Fails if the chunks do not cover the loop range.
+FailureOr<StaticContinuousTileSizeSpecification>
+computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+                                 unsigned targetSize);
+/// Emits the IR computing continuous tile sizes and their trip counts along
+/// `dimension` of `op`, starting from `targetSize`. If `emitAssertions` is
+/// set, also emits an assertion that `targetSize` is positive. Returns
+/// failure if `dimension` is out of range.
+FailureOr<ContinuousTileSizeSpecification>
+computeContinuousTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
+                           OpFoldResult targetSize, bool emitAssertions);
+
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
 /// tiling by `numThreads`.
 /// If non-empty, the `mapping` is added as an attribute to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..4b2d9485c16d1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2581,6 +2581,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+  SmallVector<Operation *> targetOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+  if (!llvm::hasSingleElement(targetOps))
+    return emitDefiniteFailure() << "requires exactly one target (got "
+                                 << llvm::range_size(targetOps) << ")";
+
+  // Validate the payload before dereferencing it (e.g. for its context).
+  auto target = dyn_cast<LinalgOp>(*targetOps.begin());
+  if (!target)
+    return emitDefiniteFailure() << "expected Linalg Op";
+
+  // Both paths below need a valid dimension (the static one asserts on it).
+  if (getDimension() >= target.getNumLoops())
+    return emitSilenceableError() << "dimension " << getDimension()
+                                  << " does not exist in target op";
+
+  OpBuilder builder(target.getContext());
+
+  // Param results: compute tile sizes and split points as constants.
+  if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+    if (target.hasDynamicShape()) {
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticContinuousTileSizeSpecification> spec =
+        computeStaticContinuousTileSizes(target, getDimension(),
+                                         getTargetSize());
+    if (failed(spec))
+      return emitSilenceableError()
+             << "failed to compute continuous tile sizes";
+
+    // Chunk i covers `tileSizes[i] * tripCounts[i]` elements of the dimension.
+    SmallVector<int64_t> splitPoints;
+    splitPoints.reserve(spec->tileSizes.size());
+    for (auto [tileSize, tripCount] :
+         llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+      splitPoints.push_back(tileSize * tripCount);
+
+    auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+      return llvm::to_vector(
+          llvm::map_range(values, [&](int64_t value) -> Attribute {
+            return builder.getI64IntegerAttr(value);
+          }));
+    };
+    transformResults.setParams(cast<OpResult>(getTileSizes()),
+                               makeI64AttrsFromI64(spec->tileSizes));
+    transformResults.setParams(cast<OpResult>(getSplitPoints()),
+                               makeI64AttrsFromI64(splitPoints));
+
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Handle results: emit IR computing the sizes right before the target.
+  builder.setInsertionPoint(target);
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  unsigned dimension = getDimension();
+  FailureOr<ContinuousTileSizeSpecification> spec =
+      computeContinuousTileSizes(builder, target, dimension, targetSize, true);
+  if (failed(spec))
+    return emitSilenceableError()
+           << "could not generate tile size computation";
+
+  auto tileSizeTripCountPairs =
+      llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
+                                           ofrs);
+  };
+
+  // Split point i is the size of chunk i: tileSizes[i] * tripCounts[i].
+  SmallVector<Value> splitPoints;
+  for (auto [tileSize, tripCount] : tileSizeTripCountPairs) {
+    Value splitPoint = apply(s0 * s1, {tileSize, tripCount});
+    splitPoints.push_back(splitPoint);
+  }
+
+  auto makeOpFromValue = [&](ArrayRef<Value> values) {
+    return llvm::to_vector(
+        llvm::map_range(values, [&](Value value) -> Operation * {
+          return value.getDefiningOp();
+        }));
+  };
+
+  transformResults.set(cast<OpResult>(getTileSizes()),
+                       makeOpFromValue(spec->tileSizes));
+  transformResults.set(cast<OpResult>(getSplitPoints()),
+                       makeOpFromValue(splitPoints));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::ContinuousTileSizesOp::verify() {
+  // The two results must share one type (e.g. both params or both handles),
+  // since they are produced from the same tile size specification.
+  Type tileSizesType = getTileSizes().getType();
+  if (tileSizesType == getSplitPoints().getType())
+    return success();
+  return emitOpError() << "expects all results type to be the same";
+}
+
+void transform::ContinuousTileSizesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Param results need no new IR; handle results insert the computation.
+  if (!isa<TransformParamTypeInterface>(getTileSizes().getType()))
+    modifiesPayload(effects);
+  else
+    onlyReadsPayload(effects);
+  onlyReadsHandle(getTarget(), effects);
+  producesHandle(getOperation()->getResults(), effects);
+}
+
+/// Prints `(target type) -> result type`; both results share one type.
+static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
+                                         Type targetType, Type tileSizesType,
+                                         Type /*splitPointsType*/) {
+  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizesType});
+}
+
+/// Parses `(target type) -> result type` and uses it for both results.
+static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
+                                                Type &targetType,
+                                                Type &tileSizesType,
+                                                Type &splitPointsType) {
+  FunctionType funcType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseType<FunctionType>(funcType)))
+    return failure();
+  // Bail out before indexing the input/result lists if the arity is wrong.
+  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1)
+    return parser.emitError(typeLoc) << "expects a trailing functional type "
+                                        "with one argument and one result";
+  targetType = funcType.getInput(0);
+  tileSizesType = splitPointsType = funcType.getResult(0);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..8049ade591a70 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -107,6 +107,139 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
       b.getStringAttr("expected strictly positive tile size and divisor"));
 }
 
+FailureOr<StaticContinuousTileSizeSpecification>
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+                                               unsigned targetSize) {
+  // Tile sizes are `targetSize`, then diminishing powers of two, each paired
+  // with the number of tiles of that size; together they cover the range.
+  assert(!op.hasDynamicShape() && "cannot compute static continuous tile "
+                                  "sizes for an op with dynamic shape");
+  assert(targetSize > 0 && "target size must be positive");
+  assert(dimension < op.getNumLoops() && "dimension overflow");
+
+  StaticContinuousTileSizeSpecification spec;
+  int64_t loopRange = op.getStaticLoopRanges()[dimension];
+  int64_t tripCount = loopRange / targetSize;
+
+  // The first chunk always uses `targetSize`, even if it fits zero times.
+  unsigned tileSize = targetSize;
+  spec.tileSizes.push_back(tileSize);
+  spec.tripCounts.push_back(tripCount);
+
+  int64_t remainderChunk = loopRange % targetSize;
+  while (tileSize > 1 && remainderChunk != 0) {
+    // Next size: the largest power of two strictly below the current one.
+    uint64_t maxPower = llvm::bit_floor(tileSize);
+    tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
+
+    // Sizes that do not fit into the remainder at all are skipped.
+    tripCount = remainderChunk / tileSize;
+    if (tripCount > 0) {
+      spec.tileSizes.push_back(tileSize);
+      spec.tripCounts.push_back(tripCount);
+    }
+    remainderChunk = remainderChunk % tileSize;
+  }
+
+  // Sanity check: the chunks must exactly cover the loop range.
+  auto tripCountCheck = [](ArrayRef<int64_t> tileSizes,
+                           ArrayRef<int64_t> tripCounts,
+                           int64_t range) -> bool {
+    int64_t computedRange = 0;
+    for (auto [size, count] : llvm::zip_equal(tileSizes, tripCounts))
+      computedRange += size * count;
+    return range == computedRange;
+  };
+
+  if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
+    return failure();
+
+  return spec;
+}
+
+FailureOr<ContinuousTileSizeSpecification>
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
+                                         unsigned dimension,
+                                         OpFoldResult targetSize,
+                                         bool emitAssertions) {
+  // Emits IR computing a continuous tiling of `dimension`: `targetSize`
+  // followed by diminishing powers of two, each paired with its trip count.
+
+  // Bail out on dimension overflow.
+  if (dimension >= op.getNumLoops())
+    return failure();
+
+  // The sequence of tile sizes is derived from `targetSize` at compile time,
+  // so it must be a known positive constant. Check this before emitting any
+  // IR so that a failure leaves no partial computation behind.
+  std::optional<int64_t> maybeTargetSize = getConstantIntValue(targetSize);
+  if (!maybeTargetSize || *maybeTargetSize <= 0)
+    return failure();
+  uint64_t tileSizeInt = *maybeTargetSize;
+
+  // The code below works only on values.
+  Location loc = op.getLoc();
+  ImplicitLocOpBuilder b(loc, builder);
+  if (emitAssertions) {
+    emitIsPositiveIndexAssertion(b, targetSize);
+  }
+  Value targetSizeValue =
+      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
+
+  // Find the trip count of the iteration space dimension for which the tile
+  // sizes are computed.
+  SmallVector<OpFoldResult> allShapes =
+      op.createFlatListOfOperandDims(b, b.getLoc());
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  SmallVector<OpFoldResult> loopRanges =
+      makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
+                                               allShapes);
+
+  Value loopRange =
+      getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
+
+  ContinuousTileSizeSpecification spec;
+
+  // Compute the tile sizes and the respective numbers of tiles.
+  AffineExpr s0 = b.getAffineSymbolExpr(0);
+  AffineExpr s1 = b.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
+  };
+
+  // The first chunk is tiled with `targetSize` itself; its remainder is
+  // covered by the smaller tile sizes computed in the loop below.
+  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
+  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
+  spec.tileSizes.push_back(targetSizeValue);
+  spec.tripCounts.push_back(tripCountValue);
+
+  while (tileSizeInt > 1) {
+    // Next size: the largest power of two strictly below the current one.
+    uint64_t maxPower = llvm::bit_floor(tileSizeInt);
+    tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
+    auto constStepOp =
+        builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
+    tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+    // Folded form of the same trip count, used to detect statically empty
+    // chunks.
+    OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
+        b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+    // Optimization: drop tile sizes whose trip count folds to zero. A trip
+    // count that does not fold to a constant is always kept.
+    if (!isConstantIntValue(tripCountSize, 0)) {
+      spec.tileSizes.push_back(constStepOp);
+      spec.tripCounts.push_back(tripCountValue);
+    }
+
+    remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
+  }
+
+  return spec;
+}
+
 FailureOr<StaticMultiSizeSpecification>
 mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
                                           int64_t targetSize, int64_t divisor) {

>From e1adf1ed387c34c1f226a0c1b460f3e8a538dbc4 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:26:19 +0800
Subject: [PATCH 2/3] [MLIR] Add support for multiway split in SplitOp

Add functionality that enables SplitOp to do a multiway split of
a target linalg along a given dimension. When the `multiway` attribute
is set, the SplitOp takes a list of split points and applies
them to a single linalg along the given dimension to generate
multiple linalgs extracted from the target.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  23 ++-
 .../TransformOps/LinalgTransformOps.cpp       | 150 +++++++++++++-----
 2 files changed, 123 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6e6bfab8d179e..51138104c055e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1396,7 +1396,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
      DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Indicates that the given `target` op should be split into two complementary
+    Splits the given `target` op into two or more complementary
     parts, which combined cover the entire iteration domain of the original op.
     The split is performed along the iteration space dimension provided as
     attribute. In case of dimension overflow, the transformation fails. The
@@ -1409,16 +1409,27 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
     operations pointed to by the target handle.
 
     The operation consumes the target handle, but preserves the split point
-    handle if provided. It produces two new handles pointing to the two parts
-    of the structured op after splitting, in the same order as the target
-    operand, with the first handle corresponding to the part with lower
-    iteration space indices.
+    handle if provided. Without the `multiway` attribute, it produces two
+    new handles pointing to the two parts of the structured op after splitting,
+    in the same order as the target operand, with the first handle
+    corresponding to the part with lower iteration space indices.
+
+    Multiway split mode is enabled by specifying the `multiway` attribute.
+    In this mode a single `target` op is split into multiple parts covering
+    the iteration space of the specified dimension. The split points (a list
+    when provided through `dynamic_split_point`) are interpreted as chunk sizes
+    that the given dimension should be split into. With `multiway` it produces
+    two handles; the first handle is a list of the parts of the structured op
+    after splitting, in order, one per chunk size, followed by any leftover
+    part not covered by the chunks.
+    The second handle is empty.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
                        Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
-                       I64Attr:$static_split_point);
+                       I64Attr:$static_split_point,
+                       UnitAttr:$multiway);
   let results = (outs TransformHandleTypeInterface:$first,
                       TransformHandleTypeInterface:$second);
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4b2d9485c16d1..907a7bfb72a59 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2269,8 +2269,20 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
   // Collect the dynamic split points if provided.
   SmallVector<Operation *> payload =
       llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  bool isMultiwaySplit = getMultiway() ? true : false;
+
+  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
+    return emitDefiniteFailure() << "requires exactly one target when "
+                                    "multiway split is enabled (got "
+                                 << llvm::range_size(payload) << ")";
+  }
+
   SmallVector<OpFoldResult> splitPoints;
-  splitPoints.reserve(payload.size());
+
+  if (!isMultiwaySplit)
+    splitPoints.reserve(payload.size());
+
   if (getDynamicSplitPoint()) {
     auto diag = DiagnosedSilenceableFailure::success();
     if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
@@ -2293,7 +2305,9 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     if (diag.isSilenceableFailure())
       return diag;
 
-    if (splitPoints.size() != payload.size()) {
+    // For multiway split, a single payload is expected to have multiple
+    // split points.
+    if (!isMultiwaySplit && splitPoints.size() != payload.size()) {
       return emitDefiniteFailure()
              << "expected the dynamic split point handle to point to as "
                 "many operations ("
@@ -2305,57 +2319,105 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
                        rewriter.getIndexAttr(getStaticSplitPoint()));
   }
 
-  // Split each target operation.
-  SmallVector<Operation *> first, second;
-  Operation *noSecondPart = nullptr;
-  for (const auto &pair : llvm::zip(payload, splitPoints)) {
-    Operation *target = std::get<0>(pair);
-    auto linalgOp = dyn_cast<LinalgOp>(target);
-    if (!linalgOp) {
-      auto diag = emitSilenceableError() << "only applies to structured ops";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
-    }
+  if (isMultiwaySplit) {
 
-    if (getDimension() >= linalgOp.getNumLoops()) {
-      auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                         << " does not exist in target op";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
+    // Split a single target at multiple points, re-splitting the tail.
+    SmallVector<Operation *> opList;
+    Operation *head = nullptr, *tail = payload.front();
+    for (const OpFoldResult &splitPoint : splitPoints) {
+      // The previous remainder is the next target; stop once it is empty.
+      Operation *target = tail;
+      if (!target)
+        break;
+
+      auto linalgOp = dyn_cast<LinalgOp>(target);
+      if (!linalgOp) {
+        auto diag = emitSilenceableError() << "only applies to structured ops";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      if (getDimension() >= linalgOp.getNumLoops()) {
+        auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                           << " does not exist in target op";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(head, tail) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), splitPoint);
+
+      // Propagate errors.
+      if (!head && !tail) {
+        auto diag = emitDefiniteFailure() << "internal failure in splitting";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+      opList.push_back(head);
     }
 
-    rewriter.setInsertionPoint(linalgOp);
-    std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
-        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
-        getDimension(), std::get<1>(pair));
+    // Append any leftover parts to the end of the result list.
+    if (tail)
+      opList.push_back(tail);
+    results.set(cast<OpResult>(getFirst()), opList);
+    results.set(cast<OpResult>(getSecond()), {});
 
-    // Propagate errors.
-    if (!first.back() && !second.back()) {
-      auto diag = emitDefiniteFailure() << "internal failure in splitting";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
+  } else {
+    // Split each target operation.
+    SmallVector<Operation *> first, second;
+    Operation *noSecondPart = nullptr;
+    for (const auto &pair : llvm::zip(payload, splitPoints)) {
+      Operation *target = std::get<0>(pair);
+      auto linalgOp = dyn_cast<LinalgOp>(target);
+      if (!linalgOp) {
+        auto diag = emitSilenceableError() << "only applies to structured ops";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      if (getDimension() >= linalgOp.getNumLoops()) {
+        auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                           << " does not exist in target op";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), std::get<1>(pair));
+
+      // Propagate errors.
+      if (!first.back() && !second.back()) {
+        auto diag = emitDefiniteFailure() << "internal failure in splitting";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      // Do not add null second parts.
+      if (!second.back()) {
+        noSecondPart = target;
+        second.pop_back();
+      }
     }
 
-    // Do not add null second parts.
-    if (!second.back()) {
-      noSecondPart = target;
-      second.pop_back();
+    if (second.size() != first.size() && !second.empty()) {
+      auto diag = emitSilenceableError()
+                  << "splitting does not produce the second part for a subset "
+                     "of targets";
+      diag.attachNote()
+          << "expected splitting to produce the second part of all "
+             "or none of the targets";
+      diag.attachNote(noSecondPart->getLoc())
+          << "first target with no second part";
+      return diag;
     }
-  }
 
-  if (second.size() != first.size() && !second.empty()) {
-    auto diag = emitSilenceableError()
-                << "splitting does not produce the second part for a subset "
-                   "of targets";
-    diag.attachNote() << "expected splitting to produce the second part of all "
-                         "or none of the targets";
-    diag.attachNote(noSecondPart->getLoc())
-        << "first target with no second part";
-    return diag;
+    results.set(cast<OpResult>(getFirst()), first);
+    results.set(cast<OpResult>(getSecond()), second);
   }
-
-  results.set(cast<OpResult>(getFirst()), first);
-  results.set(cast<OpResult>(getSecond()), second);
   return DiagnosedSilenceableFailure::success();
 }
 

>From 646cc2a932e448fc5550ae077961fc0fc85a7642 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 2 May 2024 19:10:49 +0800
Subject: [PATCH 3/3] [MLIR] Test multiway SplitOp

Tests SplitOp for multiway splitting of a linalg op using
the result of `continuous_tile_sizes` to specify multiple
split points for a single linalg op.
---
 .../continuous-tiling-multiway-split.mlir     | 100 ++++++++++++++++++
 1 file changed, 100 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir

diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
new file mode 100644
index 0000000000000..609766fbdc91f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s
+
+// This tests feeding the results of continuous_tile_sizes into a multiway splitOp.
+// continuous_tile_sizes returns a list of tile sizes and a list of split points.
+// The list of split points is consumed by splitOp to split the linalg.matmul op
+// along dimension 1 into several smaller linalg.matmul ops.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
+    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK:      %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK:      %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK:      %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK:      %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK:      %[[SLICE1:.+]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK:      %[[SLICE2:.+]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK:      %[[SLICE3:.+]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK:      %[[SLICE4:.+]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK:      %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK:      %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK:      %[[SLICE5:.+]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK:      %[[SLICE6:.+]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK:      %[[SLICE7:.+]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK:      %[[SLICE8:.+]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK:      %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK:      %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK:      %[[SLICE9:.+]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK:      %[[SLICE10:.+]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK:      %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK:      %[[INSLICE2:.+]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK:      %[[INSLICE3:.+]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK:      %[[INSLICE4:.+]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK:      return %[[INSLICE4]] : tensor<25x25xf32>
+
+// -----
+
+// Tests the same as above except that the !transform.param<i64> output type in
+// continuous_tile_sizes op causes tile sizes and split points to be computed
+// statically rather than dynamically.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
+    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_static_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_static_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK:      %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK:      %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK:      %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK:      %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK:      %[[SLICE1:.+]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK:      %[[SLICE2:.+]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK:      %[[SLICE3:.+]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK:      %[[SLICE4:.+]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK:      %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK:      %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK:      %[[SLICE5:.+]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK:      %[[SLICE6:.+]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK:      %[[SLICE7:.+]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK:      %[[SLICE8:.+]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK:      %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK:      %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK:      %[[SLICE9:.+]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK:      %[[SLICE10:.+]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK:      %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK:      %[[INSLICE2:.+]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK:      %[[INSLICE3:.+]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK:      %[[INSLICE4:.+]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK:      return %[[INSLICE4]] : tensor<25x25xf32>



More information about the Mlir-commits mailing list