[Mlir-commits] [mlir] [MLIR] Add continuous tiling to transform dialect (PR #82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 31 03:35:30 PDT 2024


https://github.com/muneebkhan85 updated https://github.com/llvm/llvm-project/pull/82792

>From fd216134a523d75671549fa0459b53284cf566b0 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:15:18 +0800
Subject: [PATCH 1/3] [MLIR] Add continuous tiling to Transform dialect

Add continuous tiling op `structured.continuous_tile_sizes`
to the transform dialect that returns as result (1) a list of
exponentially diminishing tile sizes, and (2) a list of chunk
sizes -- along the specified dimension of the target --
where the corresponding tile sizes from (1) can be applied.
The list of chunk sizes from (2) covers the entire iteration
space along the given dimension of the target.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  45 ++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  21 +++
 .../TransformOps/LinalgTransformOps.cpp       | 148 ++++++++++++++++++
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 131 ++++++++++++++++
 4 files changed, 345 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..e46dc1565964a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1819,6 +1819,51 @@ def TileReductionUsingForallOp :
 
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    This transform emits the IR computing (1) a list of diminishing tile sizes:
+    `target_size` followed by smaller powers of 2; and (2) the corresponding
+    chunk sizes the target op should be split into along the given dimension.
+
+    For example, for `target_size` 9, and `dimension` 0 for the following
+    linalg op as target
+
+    ```
+      %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                      outs(%arg2: tensor<25x25xf32>)
+    ```
+
+    the first result `tile_sizes` will be a list of diminishing tile sizes
+    9, 4, 2, 1; and the second result will be a list of chunk sizes
+    18, 4, 2, 1 that the corresponding dimension should be split into.
+
+    After the target op has been split along the given dimension (for example
+    using multiway split), each chunk can be tiled with the corresponding tile
+    size in the `tile_sizes` list generated as a result of this op.
+
+    Specifying the output type as !transform.param<i64> will cause `tile_sizes`
+    and `chunk_sizes` to be computed statically and not dynamically.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$target_size);
+  let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
+                      TransformAnyParamTypeOrAnyHandle:$chunk_sizes);
+  let hasVerifier = 1;
+  let assemblyFormat =
+    "$target attr-dict `:` custom<ContinuousTileSizeTypes>("
+    "type($target), type($tile_sizes), type($chunk_sizes))";
+
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..da9300499d096 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
   /// Number of tiles associated with each size.
   T lowTripCount, highTripCount;
 };
+
+template <typename T>
+struct ContinuousTileSizeSpecificationBase {
+  /// Tile sizes.
+  SmallVector<T> tileSizes;
+  /// Number of tiles associated with each size.
+  SmallVector<T> tripCounts;
+};
+
 } // namespace detail
 
 /// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
 struct StaticMultiSizeSpecification
     : public detail::MultiSizeSpecificationBase<int64_t> {};
 
+struct ContinuousTileSizeSpecification
+    : public detail::ContinuousTileSizeSpecificationBase<Value> {};
+struct StaticContinuousTileSizeSpecification
+    : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
+
 /// Emits the IR computing the multi-sized tiling specification with two tile
 /// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
 /// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,13 @@ FailureOr<StaticMultiSizeSpecification>
 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
+FailureOr<StaticContinuousTileSizeSpecification>
+computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+                                 unsigned targetSize);
+FailureOr<ContinuousTileSizeSpecification>
+computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+                           unsigned dimension, OpFoldResult targetSize,
+                           bool emitAssertions);
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
 /// tiling by `numThreads`.
 /// If non-empty, the `mapping` is added as an attribute to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..01238147527b7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2583,6 +2583,154 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+  // The sizes are computed for a single payload op only.
+  SmallVector<Operation *> targetOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  if (!llvm::hasSingleElement(targetOps)) {
+    return mlir::emitSilenceableFailure(getLoc())
+           << "requires exactly one target (got " << llvm::range_size(targetOps)
+           << ")";
+  }
+
+  Operation *target = *targetOps.begin();
+  auto linalgOp = dyn_cast<LinalgOp>(target);
+  auto tileableOp = dyn_cast<TilingInterface>(target);
+
+  if (!linalgOp)
+    return emitDefiniteFailure() << "expected Linalg Op";
+
+  OpBuilder builder(linalgOp.getContext());
+
+  // Param-typed results: compute the sizes statically, no IR is created.
+  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
+    if (linalgOp.hasDynamicShape()) {
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(linalgOp->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticContinuousTileSizeSpecification> spec =
+        computeStaticContinuousTileSizes(linalgOp, getDimension(),
+                                         getTargetSize());
+    if (failed(spec)) {
+      return emitSilenceableError()
+             << "failed to compute continuous tile sizes";
+    }
+
+    // Each chunk is covered by `tripCount` tiles of size `tileSize`.
+    SmallVector<int64_t> chunkSizes;
+    for (auto &&[tileSize, tripCount] :
+         llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+      chunkSizes.push_back(tileSize * tripCount);
+
+    auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+      return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
+        return builder.getI64IntegerAttr(value);
+      });
+    };
+    transformResults.setParams(cast<OpResult>(getTileSizes()),
+                               getI64AttrsFromI64(spec->tileSizes));
+    transformResults.setParams(cast<OpResult>(getChunkSizes()),
+                               getI64AttrsFromI64(chunkSizes));
+
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Handle-typed results: emit payload IR, which requires TilingInterface.
+  if (!tileableOp)
+    return emitDefiniteFailure() << "expected TilingInterface Op";
+  builder.setInsertionPoint(linalgOp);
+
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  unsigned dimension = getDimension();
+
+  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
+      builder, tileableOp, dimension, targetSize, true);
+  if (failed(spec))
+    return emitSilenceableError() << "could not generate tile size computation";
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
+                                           ofrs);
+  };
+
+  SmallVector<Value> chunkSizes;
+  for (auto &&[tileSize, tripCount] :
+       llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+    chunkSizes.push_back(apply(s0 * s1, {tileSize, tripCount}));
+
+  auto getDefiningOps = [&](ArrayRef<Value> values) {
+    return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+      return value.getDefiningOp();
+    });
+  };
+
+  transformResults.set(cast<OpResult>(getTileSizes()),
+                       getDefiningOps(spec->tileSizes));
+  transformResults.set(cast<OpResult>(getChunkSizes()),
+                       getDefiningOps(chunkSizes));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::ContinuousTileSizesOp::verify() {
+
+  if (getTileSizes().getType() != getChunkSizes().getType()) {
+    return emitOpError() << "expects all results type to be the same";
+  }
+
+  return success();
+}
+
+void transform::ContinuousTileSizesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
+    onlyReadsPayload(effects);
+  else
+    modifiesPayload(effects);
+  onlyReadsHandle(getTarget(), effects);
+  producesHandle(getTileSizes(), effects);
+  producesHandle(getChunkSizes(), effects);
+}
+
+static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
+                                         Type targetType, Type tile_sizes,
+                                         Type) {
+  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
+}
+
+static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
+                                                Type &targetType,
+                                                Type &tileSizesType,
+                                                Type &chunkSizesType) {
+  FunctionType funcType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseType<FunctionType>(funcType)))
+    return failure();
+  // Bail out on arity mismatch before indexing inputs/results below.
+  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
+    return parser.emitError(typeLoc) << "expects a trailing functional type "
+                                        "with one argument and one result";
+  }
+  targetType = funcType.getInput(0);
+  // Both results share the single result type of the functional type.
+  tileSizesType = chunkSizesType = funcType.getResult(0);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index fd314ef9f8134..b8e90135dee6e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -107,6 +107,137 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
       b.getStringAttr("expected strictly positive tile size and divisor"));
 }
 
+FailureOr<StaticContinuousTileSizeSpecification>
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
+                                               unsigned dimension,
+                                               unsigned targetSize) {
+  assert(!op.hasDynamicShape() &&
+         "static continuous tile sizes require a statically shaped op");
+  // Public API: reject bad arguments even when assertions are compiled out,
+  // otherwise a zero target size divides by zero below.
+  if (targetSize == 0 || dimension >= op.getNumLoops())
+    return failure();
+
+  StaticContinuousTileSizeSpecification spec;
+  int64_t loopRange = op.getStaticLoopRanges()[dimension];
+  int64_t tripCount = loopRange / targetSize;
+
+  // The first chunk is tiled with the requested target size itself.
+  unsigned tileSize = targetSize;
+  spec.tileSizes.push_back(tileSize);
+  spec.tripCounts.push_back(tripCount);
+
+  // Cover the remainder with successively smaller power-of-two tile sizes.
+  int64_t remainderChunk = loopRange % targetSize;
+  while (tileSize > 1 && remainderChunk != 0) {
+    // Largest power of two strictly smaller than the current tile size.
+    uint64_t maxPower = llvm::bit_floor(tileSize);
+    tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
+    tripCount = remainderChunk / tileSize;
+
+    // Sizes that do not fit into the remainder are skipped.
+    if (tripCount > 0) {
+      spec.tileSizes.push_back(tileSize);
+      spec.tripCounts.push_back(tripCount);
+    }
+    remainderChunk = remainderChunk % tileSize;
+  }
+
+  // Sanity check: the chunks must add up to the full iteration range.
+  auto tripCountCheck = [&](ArrayRef<int64_t> tileSizes,
+                            ArrayRef<int64_t> tripCounts,
+                            int64_t range) -> bool {
+    int64_t computedRange = 0;
+    for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
+      computedRange += tileSize * tripCount;
+    return range == computedRange;
+  };
+  if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
+    return failure();
+
+  return spec;
+}
+
+FailureOr<ContinuousTileSizeSpecification>
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+                                         unsigned dimension,
+                                         OpFoldResult targetSize,
+                                         bool emitAssertions) {
+  // The diminishing tile sizes are enumerated while emitting IR, so the target
+  // size must be a known, strictly positive constant. Reject anything else
+  // up front instead of dereferencing an empty optional further down.
+  std::optional<int64_t> constTargetSize = getConstantIntValue(targetSize);
+  if (!constTargetSize || *constTargetSize <= 0)
+    return failure();
+
+  SmallVector<Range> loopRanges = op.getIterationDomain(builder);
+
+  // Bail out on dimension overflow.
+  if (dimension >= loopRanges.size())
+    return failure();
+
+  // The code below works only on values.
+  Location loc = op->getLoc();
+  ImplicitLocOpBuilder b(loc, builder);
+  if (emitAssertions)
+    emitIsPositiveIndexAssertion(b, targetSize);
+  Value targetSizeValue =
+      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
+
+  // Find the trip count of the iteration space dimension for which the tile
+  // sizes are computed.
+  Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
+                                                    loopRanges[dimension].size);
+  ContinuousTileSizeSpecification spec;
+
+  // Compute the tile sizes and the respective numbers of tiles.
+  AffineExpr s0 = b.getAffineSymbolExpr(0);
+  AffineExpr s1 = b.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
+  };
+
+  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
+  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
+
+  // The first chunk is tiled with the requested target size itself.
+  spec.tileSizes.push_back(targetSizeValue);
+  spec.tripCounts.push_back(tripCountValue);
+
+  uint64_t tileSizeInt = *constTargetSize;
+  while (tileSizeInt > 1) {
+    // Next tile size: the largest power of two strictly below the current one.
+    uint64_t maxPower = llvm::bit_floor(tileSizeInt);
+    tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
+    auto constStepOp =
+        builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
+    tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+    // The same trip count in folded form; it holds an attribute when the
+    // count is known at compile time.
+    OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
+        b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+    // Optimization if tripCount can be determined to be zero.
+    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
+      auto intAttr = cast<IntegerAttr>(attr);
+      bool isTripCountZero = intAttr.getValue().isZero();
+
+      if (!isTripCountZero) {
+        spec.tileSizes.push_back(constStepOp);
+        spec.tripCounts.push_back(tripCountValue);
+      }
+    } else {
+      spec.tileSizes.push_back(constStepOp);
+      spec.tripCounts.push_back(tripCountValue);
+    }
+
+    remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
+  }
+
+  return spec;
+}
+
 FailureOr<StaticMultiSizeSpecification>
 mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
                                           int64_t targetSize, int64_t divisor) {

>From c3f759d6722fc30c27bb452a8a120d00e5dcd508 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:26:19 +0800
Subject: [PATCH 2/3] [MLIR] Add support for multiway split in SplitOp

Add functionality that enables SplitOp to do a multiway split of
a target op along a given dimension. With the `multiway` attribute,
SplitOp takes a list of chunk sizes and applies it to a single
target along the given dimension to generate multiple
structured ops extracted from the target.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  40 ++--
 .../TransformOps/LinalgTransformOps.cpp       | 221 ++++++++++++------
 .../mlir/dialects/transform/structured.py     |  16 +-
 .../Dialect/Linalg/transform-op-split.mlir    |   2 +-
 .../dialects/transform_structured_ext.py      |   4 +-
 5 files changed, 189 insertions(+), 94 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e46dc1565964a..866275cedf68b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1396,29 +1396,43 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
      DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Indicates that the given `target` op should be split into two complementary
+    Splits the given `target` op into two or more complementary
     parts, which combined cover the entire iteration domain of the original op.
     The split is performed along the iteration space dimension provided as
-    attribute. In case of dimension overflow, the transformation fails. The
-    split is performed at the dimension iterator value specified as either the
-    static split point attribute when it is known at transform IR construction
-    time or as the handle to an operation producing a single index-typed value
-    when it is computed by payload IR. In the latter case, the static split
+    chunk size attribute specifying the size of the lower part; the remaining
+    range in the iteration space is assigned as the upper part. In case of
+    dimension overflow, the transformation fails. The split is performed at the
+    dimension iterator value specified as either the static chunk size
+    attribute when it is known at transform IR construction time or
+    as the handle to an operation producing a single index-typed value
+    when it is computed by payload IR. In the latter case, the chunk size
     point must be set to `ShapedType::kDynamic` and the dynamic size handle
     must point to as many value-producing operations as there are structured
     operations pointed to by the target handle.
 
-    The operation consumes the target handle, but preserves the split point
-    handle if provided. It produces two new handles pointing to the two parts
-    of the structured op after splitting, in the same order as the target
-    operand, with the first handle corresponding to the part with lower
-    iteration space indices.
+    The operation consumes the target handle, but preserves the chunk size
+    handle if provided. Without the `multiway` attribute, it produces two
+    new handles pointing to the two parts of the structured op after splitting,
+    in the same order as the target operand, with the first handle
+    corresponding to the part with lower iteration space indices.
+
+    Multiway split mode is enabled by specifying the `multiway` attribute.
+    In this mode a single `target` op is split into multiple parts covering
+    the iteration space of the specified dimension. `static_chunk_sizes` and
+    `dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
+    dimension should be split into. With `multiway` it produces two handles;
+    the first handle is a list of the multiple parts of the structured op
+    after splitting, where the target dimension of each linalg op in the
+    list corresponds to the chunk size specified in the input split list.
+    If the chunk sizes do not cover the entire iteration space, the leftover
+    chunk is the last payload in the first handle. The second handle is empty.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
-                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
-                       I64Attr:$static_split_point);
+                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
+                       I64Attr:$static_chunk_sizes,
+                       UnitAttr:$multiway);
   let results = (outs TransformHandleTypeInterface:$first,
                       TransformHandleTypeInterface:$second);
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 01238147527b7..7347e487a899e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2269,13 +2269,26 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
   // Collect the dynamic split points if provided.
   SmallVector<Operation *> payload =
       llvm::to_vector(state.getPayloadOps(getTarget()));
-  SmallVector<OpFoldResult> splitPoints;
-  splitPoints.reserve(payload.size());
-  if (getDynamicSplitPoint()) {
+
+  bool isMultiwaySplit = getMultiway();
+
+  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
+    return mlir::emitSilenceableFailure(getLoc())
+           << "requires exactly one target when "
+              "multiway split is enabled (got "
+           << llvm::range_size(payload) << ")";
+  }
+
+  SmallVector<OpFoldResult> chunkSizes;
+
+  if (!isMultiwaySplit)
+    chunkSizes.reserve(payload.size());
+
+  if (getDynamicChunkSizes()) {
     auto diag = DiagnosedSilenceableFailure::success();
-    if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
-      splitPoints = llvm::to_vector(llvm::map_range(
-          state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+    if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
+      chunkSizes = llvm::to_vector(llvm::map_range(
+          state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
             if (op->getNumResults() != 1 ||
                 !op->getResult(0).getType().isIndex()) {
               diag = emitSilenceableError()
@@ -2286,103 +2299,171 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
             return OpFoldResult(op->getResult(0));
           }));
     } else {
-      splitPoints = llvm::to_vector(
-          llvm::map_range(state.getParams(getDynamicSplitPoint()),
+      chunkSizes = llvm::to_vector(
+          llvm::map_range(state.getParams(getDynamicChunkSizes()),
                           [](Attribute attr) { return OpFoldResult(attr); }));
     }
     if (diag.isSilenceableFailure())
       return diag;
 
-    if (splitPoints.size() != payload.size()) {
+    // For multiway split, a single payload is expected to have multiple
+    // split points.
+    if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
       return emitDefiniteFailure()
              << "expected the dynamic split point handle to point to as "
                 "many operations ("
-             << splitPoints.size() << ") as the target handle ("
+             << chunkSizes.size() << ") as the target handle ("
              << payload.size() << ")";
     }
   } else {
-    splitPoints.resize(payload.size(),
-                       rewriter.getIndexAttr(getStaticSplitPoint()));
+    chunkSizes.resize(payload.size(),
+                       rewriter.getIndexAttr(getStaticChunkSizes()));
   }
 
-  // Split each target operation.
-  SmallVector<Operation *> first, second;
-  Operation *noSecondPart = nullptr;
-  for (const auto &pair : llvm::zip(payload, splitPoints)) {
-    Operation *target = std::get<0>(pair);
-    auto linalgOp = dyn_cast<LinalgOp>(target);
+  auto checkStructuredOpAndDimensions =
+      [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
     if (!linalgOp) {
       auto diag = emitSilenceableError() << "only applies to structured ops";
-      diag.attachNote(target->getLoc()) << "target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
 
     if (getDimension() >= linalgOp.getNumLoops()) {
       auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                         << " does not exist in target op";
-      diag.attachNote(target->getLoc()) << "target op";
+                                          << " does not exist in target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
+    return DiagnosedSilenceableFailure::success();
+  };
 
-    rewriter.setInsertionPoint(linalgOp);
-    std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
-        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
-        getDimension(), std::get<1>(pair));
-
-    // Propagate errors.
-    if (!first.back() && !second.back()) {
+  auto checkFailureInSplitting =
+      [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
+    if (hasFailed) {
       auto diag = emitDefiniteFailure() << "internal failure in splitting";
-      diag.attachNote(target->getLoc()) << "target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
+    return DiagnosedSilenceableFailure::success();
+  };
+
+  if (isMultiwaySplit) {
+
+    // Split a single target operation at multiple points.
+    SmallVector<Operation *> opList;
+    TilingInterface head, tail;
+    Operation *target = payload.front();
+
+    LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+    DiagnosedSilenceableFailure diag =
+        checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+
+    if (diag.isSilenceableFailure())
+      return diag;
+
+    for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
+
+      if (idx > 0)
+        target = tail.getOperation();
+
+      if (!target)
+        break;
 
-    // Do not add null second parts.
-    if (!second.back()) {
-      noSecondPart = target;
-      second.pop_back();
+      linalgOp = cast<LinalgOp>(target);
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(head, tail) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), chunkSize);
+
+      // Propagate errors.
+      DiagnosedSilenceableFailure diag =
+          checkFailureInSplitting(!head && !tail, target->getLoc());
+      if (diag.isDefiniteFailure())
+        return diag;
+
+      opList.push_back(head.getOperation());
     }
-  }
 
-  if (second.size() != first.size() && !second.empty()) {
-    auto diag = emitSilenceableError()
-                << "splitting does not produce the second part for a subset "
-                   "of targets";
-    diag.attachNote() << "expected splitting to produce the second part of all "
-                         "or none of the targets";
-    diag.attachNote(noSecondPart->getLoc())
-        << "first target with no second part";
-    return diag;
-  }
+    // Append any leftover parts to the end of the result list.
+    if (tail)
+      opList.push_back(tail);
+    results.set(cast<OpResult>(getFirst()), opList);
+    results.set(cast<OpResult>(getSecond()), {});
+
+  } else {
+    // Split each target operation.
+    SmallVector<Operation *> first, second;
+    Operation *noSecondPart = nullptr;
+    for (const auto &pair : llvm::zip(payload, chunkSizes)) {
+      Operation *target = std::get<0>(pair);
+      LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+      DiagnosedSilenceableFailure diag =
+          checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+
+      if (diag.isSilenceableFailure())
+        return diag;
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), std::get<1>(pair));
+
+      // Propagate errors; `diag` above holds a success at this point.
+      DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
+          !first.back() && !second.back(), target->getLoc());
+      if (diagSplit.isDefiniteFailure())
+        return diagSplit;
+
+      // Do not add null second parts.
+      if (!second.back()) {
+        noSecondPart = target;
+        second.pop_back();
+      }
+    }
+
+    if (second.size() != first.size() && !second.empty()) {
+      auto diag = emitSilenceableError()
+                  << "splitting does not produce the second part for a subset "
+                     "of targets";
+      diag.attachNote()
+          << "expected splitting to produce the second part of all "
+             "or none of the targets";
+      diag.attachNote(noSecondPart->getLoc())
+          << "first target with no second part";
+      return diag;
+    }
 
-  results.set(cast<OpResult>(getFirst()), first);
-  results.set(cast<OpResult>(getSecond()), second);
+    results.set(cast<OpResult>(getFirst()), first);
+    results.set(cast<OpResult>(getSecond()), second);
+  }
   return DiagnosedSilenceableFailure::success();
 }
 
 void SplitOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTarget(), effects);
-  if (getDynamicSplitPoint())
-    onlyReadsHandle(getDynamicSplitPoint(), effects);
+  if (getDynamicChunkSizes())
+    onlyReadsHandle(getDynamicChunkSizes(), effects);
   producesHandle(getResults(), effects);
   modifiesPayload(effects);
 }
 
 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
-  IntegerAttr staticSplitPoint;
+  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
+  IntegerAttr staticChunkSizes;
   if (parser.parseOperand(target) || parser.parseKeyword("after"))
     return failure();
 
   OptionalParseResult dynamicPointParseResult =
-      parser.parseOptionalOperand(dynamicSplitPoint);
+      parser.parseOptionalOperand(dynamicChunkSizes);
   if (!dynamicPointParseResult.has_value()) {
-    int64_t staticSplitPointValue;
-    if (failed(parser.parseInteger(staticSplitPointValue)))
+    int64_t staticChunkSizesValue;
+    if (failed(parser.parseInteger(staticChunkSizesValue)))
       return failure();
 
-    staticSplitPoint =
-        parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+    staticChunkSizes =
+        parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
   }
 
   Type targetType;
@@ -2392,43 +2473,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
   }
   if (dynamicPointParseResult.has_value()) {
-    Type splitPointType;
+    Type ChunkSizesType;
     if (failed(*dynamicPointParseResult) || parser.parseComma() ||
-        parser.parseType(splitPointType) ||
-        parser.resolveOperand(dynamicSplitPoint, splitPointType,
+        parser.parseType(ChunkSizesType) ||
+        parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
                               result.operands)) {
       return failure();
     }
 
-    staticSplitPoint =
+    staticChunkSizes =
         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
   }
 
   result.addAttribute(
-      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
-      staticSplitPoint);
+      SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
+      staticChunkSizes);
   result.addTypes({targetType, targetType});
   return success();
 }
 
 void SplitOp::print(OpAsmPrinter &printer) {
   printer << " " << getTarget() << " after ";
-  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
-  if (staticSplitSize != ShapedType::kDynamic)
-    printer << staticSplitSize;
+  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
+  if (staticChunkSize != ShapedType::kDynamic)
+    printer << staticChunkSize;
   else
-    printer << getDynamicSplitPoint();
+    printer << getDynamicChunkSizes();
   printer << " ";
   printer.printOptionalAttrDict(getOperation()->getAttrs(),
-                                {getStaticSplitPointAttrName()});
+                                {getStaticChunkSizesAttrName()});
   printer << " : " << getTarget().getType();
-  if (staticSplitSize == ShapedType::kDynamic)
-    printer << ", " << getDynamicSplitPoint().getType();
+  if (staticChunkSize == ShapedType::kDynamic)
+    printer << ", " << getDynamicChunkSizes().getType();
 }
 
 LogicalResult SplitOp::verify() {
-  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
-      (getDynamicSplitPoint() == nullptr)) {
+  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
+      (getDynamicChunkSizes() == nullptr)) {
     return emitOpError() << "expects either a dynamic or a static split "
                             "point to be provided";
   }
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 2c49ef0960c75..41051c0d5b2ff 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -432,25 +432,25 @@ def __init__(
         self,
         target: Union[Operation, Value],
         dimension: Union[int, Attribute],
-        split_point: Union[int, Operation, Value, Attribute],
+        chunk_sizes: Union[int, Operation, Value, Attribute],
         *,
         loc=None,
         ip=None,
     ):
-        if isinstance(split_point, int):
-            static_split_point = split_point
-            dynamic_split_point = None
+        if isinstance(chunk_sizes, int):
+            static_chunk_sizes = chunk_sizes
+            dynamic_chunk_sizes = None
         else:
-            static_split_point = ShapedType.get_dynamic_size()
-            dynamic_split_point = split_point
+            static_chunk_sizes = ShapedType.get_dynamic_size()
+            dynamic_chunk_sizes = chunk_sizes
 
         super().__init__(
             target.type,
             target.type,
             target,
             dimension=dimension,
-            static_split_point=static_split_point,
-            dynamic_split_point=dynamic_split_point,
+            static_chunk_sizes=static_chunk_sizes,
+            dynamic_chunk_sizes=dynamic_chunk_sizes,
             loc=loc,
             ip=ip,
         )
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index 566e517d69789..e072fff4c5d77 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -197,7 +197,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
     // expected-error @below {{expects either a dynamic or a static split point to be provided}}
-    %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index f97017b7a2c75..3ea73e8beea36 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -361,8 +361,8 @@ def testScalarize(target):
 @run
 @create_sequence
 def testSplit(target):
-    split = structured.SplitOp(target, dimension=1, split_point=42)
-    structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
+    split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
+    structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
     # CHECK-LABEL: TEST: testSplit
     # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
     # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3

>From f4f657a3690c2f8cb8272bbf60105f61637de966 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 2 May 2024 19:10:49 +0800
Subject: [PATCH 3/3] [MLIR] Test multiway SplitOp

Tests SplitOp for multiway splitting of a structured op using
the result of `continuous_tile_sizes` to specify the
chunk sizes that the target's specified dimension should be
split into.
---
 .../continuous-tiling-multiway-split.mlir     | 100 ++++++++++++++++++
 1 file changed, 100 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir

diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
new file mode 100644
index 0000000000000..609766fbdc91f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s
+
+// This tests the results of continuous_tile_sizes on multiway splitOp.
+// continuous_tile_sizes returns a list of tile-sizes and a list of chunk sizes.
+// The list of chunk sizes is consumed by splitOp to split the linalg.matmul op
+// along dimension 1 into one linalg.matmul op per chunk.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
+    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK:      %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK       %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK       %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK       %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK       %[[SLICE1]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK       %[[SLICE2]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK       %[[SLICE3]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK       %[[SLICE4]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK       %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK       %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK       %[[SLICE5]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK       %[[SLICE6]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK       %[[SLICE7]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK       %[[SLICE8]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK       %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK       %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK       %[[SLICE9]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK       %[[SLICE10]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK       %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK       %[[INSLICE2]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK       %[[INSLICE3]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK       %[[INSLICE4]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK       return %[[INSLICE4]] : tensor<25x25xf32>
+
+// -----
+
+// Tests the same as above except that the !transform.param<i64> output type in
+// continuous_tile_sizes op triggers tile sizes and chunk sizes to be computed
+// statically and not dynamically.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
+    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_static_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_static_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK:      %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK       %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK       %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK       %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK       %[[SLICE1]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK       %[[SLICE2]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK       %[[SLICE3]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK       %[[SLICE4]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK       %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK       %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK       %[[SLICE5]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK       %[[SLICE6]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK       %[[SLICE7]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK       %[[SLICE8]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK       %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK       %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK       %[[SLICE9]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK       %[[SLICE10]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK       %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK       %[[INSLICE2]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK       %[[INSLICE3]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK       %[[INSLICE4]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK       return %[[INSLICE4]] : tensor<25x25xf32>



More information about the Mlir-commits mailing list