[Mlir-commits] [mlir] a9efcbf - [MLIR] Add continuous tiling to transform dialect (#82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 21 07:39:46 PDT 2024


Author: muneebkhan85
Date: 2024-06-21T16:39:43+02:00
New Revision: a9efcbf490d9b8f46ec37062ca8653b4068000e5

URL: https://github.com/llvm/llvm-project/commit/a9efcbf490d9b8f46ec37062ca8653b4068000e5
DIFF: https://github.com/llvm/llvm-project/commit/a9efcbf490d9b8f46ec37062ca8653b4068000e5.diff

LOG: [MLIR] Add continuous tiling to transform dialect (#82792)

This patch enables continuous tiling of a target structured op using
diminishing tile sizes. In cases where the tensor dimensions are not
exactly divisible by the tile size, we are left with leftover tensor
chunks that are irregularly tiled. This approach enables tiling of the
leftover chunk with a smaller tile size and repeats this process
recursively using exponentially diminishing tile sizes. This eventually
generates a chain of loops that apply tiling using diminishing tile
sizes.
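
As a rough illustration of the resulting loop chain, the sketch below walks a
single dimension with one loop per (tile size, chunk size) pair. The values in
the usage note (extent 25, tile sizes 9, 4, 2, 1 and chunk sizes 18, 4, 2, 1)
come from the op description in this patch. `Tile` and `chainOfLoops` are
hypothetical names used only for this example; the patch itself emits
`scf.for` loops via the transform dialect rather than running anything like
this.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// One tile along the tiled dimension: (offset, size).
using Tile = std::pair<int64_t, int64_t>;

// Hypothetical helper, not part of the patch: models the chain of loops.
// Loop i covers a chunk of chunkSizes[i] elements in steps of tileSizes[i].
std::vector<Tile> chainOfLoops(const std::vector<int64_t> &tileSizes,
                               const std::vector<int64_t> &chunkSizes) {
  assert(tileSizes.size() == chunkSizes.size() && "one chunk per tile size");
  std::vector<Tile> tiles;
  int64_t offset = 0;
  for (std::size_t i = 0; i < tileSizes.size(); ++i) {
    assert(tileSizes[i] > 0 && "tile size must be positive");
    int64_t end = offset + chunkSizes[i];
    // Each chunk is a multiple of its tile size, so every tile here is full.
    for (int64_t iv = offset; iv < end; iv += tileSizes[i])
      tiles.emplace_back(iv, tileSizes[i]);
    offset = end;
  }
  return tiles;
}
```

Calling `chainOfLoops({9, 4, 2, 1}, {18, 4, 2, 1})` yields five full tiles at
offsets 0, 9, 18, 22 and 24, so no loop needs a partial (irregular) tile.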

Adds the `transform.structured.continuous_tile_sizes` op to the transform
dialect. Given a target tile size and a dimension, this op computes a series
of diminishing tile sizes that can be used to tile the target along that
dimension. It also produces a matching series of chunk sizes, one per tile
size, giving the extent along the dimension that each tile size should be
applied to.
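
A minimal standalone sketch of this computation for a statically known extent
is given below. It is modeled on `computeStaticContinuousTileSizes` from this
patch, but `ContinuousSizes`, `floorPow2` and `continuousSizes` are
hypothetical names, and the chunk size is folded in directly as tile size
times trip count.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical result type for this example only.
struct ContinuousSizes {
  std::vector<int64_t> tileSizes;
  std::vector<int64_t> chunkSizes;
};

// Largest power of two that is <= x, for x > 0.
static int64_t floorPow2(int64_t x) {
  int64_t p = 1;
  while (p * 2 <= x)
    p *= 2;
  return p;
}

// Computes diminishing tile sizes and the chunk each one should cover.
ContinuousSizes continuousSizes(int64_t extent, int64_t targetSize) {
  assert(extent >= 0 && targetSize > 0 && "sizes must be valid");
  ContinuousSizes result;
  int64_t tile = targetSize;
  // The first chunk is the largest multiple of the target size that fits.
  result.tileSizes.push_back(tile);
  result.chunkSizes.push_back((extent / tile) * tile);
  int64_t remainder = extent % tile;
  while (tile > 1 && remainder != 0) {
    // Step down to the next smaller power of two.
    int64_t p = floorPow2(tile);
    tile = (p == tile) ? p / 2 : p;
    int64_t tripCount = remainder / tile;
    // Tile sizes that do not fit in the remainder are skipped.
    if (tripCount > 0) {
      result.tileSizes.push_back(tile);
      result.chunkSizes.push_back(tripCount * tile);
    }
    remainder %= tile;
  }
  return result;
}
```

For an extent of 25 and a target size of 9, this returns tile sizes
9, 4, 2, 1 and chunk sizes 18, 4, 2, 1, matching the example in the op
description. Tile size 8 is skipped because the remainder after the first
chunk is only 7.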

Adds `multiway` attribute to `transform.structured.split` that enables
multiway splitting of a single target op along the given dimension, as
specified in a list enumerating the chunk sizes.
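
The effect of a multiway split on the iteration range can be pictured with
the simplified model below. It operates on plain integer ranges rather than
payload ops; `Range` and `multiwaySplit` are hypothetical names, and clamping
a chunk to the end of the range is an assumption of this sketch. Placing the
leftover part last follows the op documentation in this patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Half-open range [begin, end) along the split dimension.
using Range = std::pair<int64_t, int64_t>;

// Hypothetical model of a multiway split: peel one part per chunk size off
// the front of [0, extent), then append whatever is left as the last part.
std::vector<Range> multiwaySplit(int64_t extent,
                                 const std::vector<int64_t> &chunkSizes) {
  assert(extent >= 0 && "extent must be non-negative");
  std::vector<Range> parts;
  int64_t begin = 0;
  for (int64_t chunk : chunkSizes) {
    assert(chunk >= 0 && "chunk sizes must be non-negative");
    if (begin >= extent)
      break; // Nothing left to split.
    // Assumption of this sketch: an oversized chunk is clamped to the extent.
    int64_t end = std::min(begin + chunk, extent);
    parts.emplace_back(begin, end);
    begin = end;
  }
  // Leftover iterations form the final part.
  if (begin < extent)
    parts.emplace_back(begin, extent);
  return parts;
}
```

With an extent of 25, the chunk sizes 18, 4, 2, 1 give the parts [0, 18),
[18, 22), [22, 24) and [24, 25). Passing only 18 and 4 gives [0, 18),
[18, 22) and a leftover part [22, 25).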

Added: 
    mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
    mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/python/mlir/dialects/transform/structured.py
    mlir/test/Dialect/Linalg/transform-op-split.mlir
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..866275cedf68b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1396,29 +1396,43 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
      DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Indicates that the given `target` op should be split into two complementary
+    Splits the given `target` op into two or more complementary
     parts, which combined cover the entire iteration domain of the original op.
     The split is performed along the iteration space dimension provided as
-    attribute. In case of dimension overflow, the transformation fails. The
-    split is performed at the dimension iterator value specified as either the
-    static split point attribute when it is known at transform IR construction
-    time or as the handle to an operation producing a single index-typed value
-    when it is computed by payload IR. In the latter case, the static split
+    chunk size attribute specifying the size of the lower part; the remaining
+    range in the iteration space is assigned as the upper part. In case of
+    dimension overflow, the transformation fails. The split is performed at the
+    dimension iterator value specified as either the static chunk size
+    attribute when it is known at transform IR construction time or
+    as the handle to an operation producing a single index-typed value
+    when it is computed by payload IR. In the latter case, the chunk size
     point must be set to `ShapedType::kDynamic` and the dynamic size handle
     must point to as many value-producing operations as there are structured
     operations pointed to by the target handle.
 
-    The operation consumes the target handle, but preserves the split point
-    handle if provided. It produces two new handles pointing to the two parts
-    of the structured op after splitting, in the same order as the target
-    operand, with the first handle corresponding to the part with lower
-    iteration space indices.
+    The operation consumes the target handle, but preserves the chunk size
+    handle if provided. Without the `multiway` attribute, it produces two
+    new handles pointing to the two parts of the structured op after splitting,
+    in the same order as the target operand, with the first handle
+    corresponding to the part with lower iteration space indices.
+
+    Multiway split mode is enabled by specifying the `multiway` attribute.
+    In this mode a single `target` op is split into multiple parts covering
+    the iteration space of the specified dimension. `static_chunk_sizes` and
+    `dynamic_chunk_sizes` in this case are lists of chunk sizes that the given
+    dimension should be split into. With `multiway` it produces two handles;
+    the first handle is a list of the multiple parts of the structured op
+    after splitting, where the target dimensions for each linalg op in the
+    list correspond to the chunk sizes specified in the input split list.
+    If the chunk sizes do not cover the entire iteration space, the leftover
+    chunk is the last payload in the first handle. The second handle is empty.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
-                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
-                       I64Attr:$static_split_point);
+                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
+                       I64Attr:$static_chunk_sizes,
+                       UnitAttr:$multiway);
   let results = (outs TransformHandleTypeInterface:$first,
                       TransformHandleTypeInterface:$second);
   let hasCustomAssemblyFormat = 1;
@@ -1819,6 +1833,51 @@ def TileReductionUsingForallOp :
 
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    This transform emits the IR computing the list of (1) exponentially
+    diminishing tile sizes that are powers of 2; and (2) the corresponding
+    chunk-sizes the target op should be split into along the given dimension.
+
+    For example, for `target_size` 9, and `dimension` 0 for the following
+    linalg op as target
+
+    ```
+      %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                      outs(%arg2: tensor<25x25xf32>)
+    ```
+
+    the first result `tile_sizes` will be a list of diminishing tile sizes
+    9, 4, 2, 1; and the second result will be a list of chunk sizes
+    18, 4, 2, 1 that the corresponding dimension should be split into.
+
+    After the target op has been split along the given dimension (for example
+    using multiway split), each chunk can be tiled with the corresponding tile
+    size in the `tile_sizes` list generated as a result of this op.
+
+    Specifying the output type as !transform.param<i64> will cause `tile_sizes`
+    and `chunk_sizes` to be computed statically and not dynamically.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
+                       ConfinedAttr<I64Attr, [IntNonNegative]>:$target_size);
+  let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
+                      TransformAnyParamTypeOrAnyHandle:$chunk_sizes);
+  let hasVerifier = 1;
+  let assemblyFormat =
+    "$target attr-dict `:` custom<ContinuousTileSizeTypes>("
+    "type($target), type($tile_sizes), type($chunk_sizes))";
+
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingForOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..8424207ea47e5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
   /// Number of tiles associated with each size.
   T lowTripCount, highTripCount;
 };
+
+template <typename T>
+struct ContinuousTileSizeSpecificationBase {
+  /// Tile sizes.
+  SmallVector<T> tileSizes;
+  /// Number of tiles associated with each size.
+  SmallVector<T> tripCounts;
+};
+
 } // namespace detail
 
 /// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
 struct StaticMultiSizeSpecification
     : public detail::MultiSizeSpecificationBase<int64_t> {};
 
+struct ContinuousTileSizeSpecification
+    : public detail::ContinuousTileSizeSpecificationBase<Value> {};
+struct StaticContinuousTileSizeSpecification
+    : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
+
 /// Emits the IR computing the multi-sized tiling specification with two tile
 /// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
 /// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,13 @@ FailureOr<StaticMultiSizeSpecification>
 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
+FailureOr<StaticContinuousTileSizeSpecification>
+computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+                                 unsigned targetSize);
+FailureOr<ContinuousTileSizeSpecification>
+computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+                           unsigned dimension, OpFoldResult targetSize,
+                           bool emitAssertions);
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
 /// tiling by `numThreads`.
 /// If non-empty, the `mapping` is added as an attribute to the

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 3bb297cbf91d2..7a661c663e010 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -624,7 +624,10 @@ def ForeachOp : TransformDialectOp<"foreach",
     Each iteration gets executed by co-indexing the payloads of the arguments
     and mapping the body's arguments to these tuples, as though iterating over
     the zipped together `targets`. As such, in each iteration, the size of the
-    payload of each of the body's block arguments is exactly one.
+    payload of each of the body's block arguments is exactly one. The attribute
+    `zip_shortest` can be used if the targets vary in their number of payloads;
+    this will limit the iterations to only the number of payloads found in the
+    shortest target.
 
     This op always reads the target handles. Furthermore, it consumes a handle
     if there is a transform op in the body that consumes the corresponding
@@ -645,11 +648,12 @@ def ForeachOp : TransformDialectOp<"foreach",
     rollback capabilities.
   }];
 
-  let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
+  let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
+                       UnitAttr:$zip_shortest);
   let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-    "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
+    "$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..37467db568c27 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2266,13 +2266,26 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
   // Collect the dynamic split points if provided.
   SmallVector<Operation *> payload =
       llvm::to_vector(state.getPayloadOps(getTarget()));
-  SmallVector<OpFoldResult> splitPoints;
-  splitPoints.reserve(payload.size());
-  if (getDynamicSplitPoint()) {
+
+  bool isMultiwaySplit = getMultiway();
+
+  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
+    return mlir::emitSilenceableFailure(getLoc())
+           << "requires exactly one target when "
+              "multiway split is enabled (got "
+           << llvm::range_size(payload) << ")";
+  }
+
+  SmallVector<OpFoldResult> chunkSizes;
+
+  if (!isMultiwaySplit)
+    chunkSizes.reserve(payload.size());
+
+  if (getDynamicChunkSizes()) {
     auto diag = DiagnosedSilenceableFailure::success();
-    if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
-      splitPoints = llvm::to_vector(llvm::map_range(
-          state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+    if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
+      chunkSizes = llvm::to_vector(llvm::map_range(
+          state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
             if (op->getNumResults() != 1 ||
                 !op->getResult(0).getType().isIndex()) {
               diag = emitSilenceableError()
@@ -2283,103 +2296,172 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
             return OpFoldResult(op->getResult(0));
           }));
     } else {
-      splitPoints = llvm::to_vector(
-          llvm::map_range(state.getParams(getDynamicSplitPoint()),
+      chunkSizes = llvm::to_vector(
+          llvm::map_range(state.getParams(getDynamicChunkSizes()),
                           [](Attribute attr) { return OpFoldResult(attr); }));
     }
     if (diag.isSilenceableFailure())
       return diag;
 
-    if (splitPoints.size() != payload.size()) {
+    // For multiway split, a single payload is expected to have multiple
+    // split points.
+    if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
       return emitDefiniteFailure()
              << "expected the dynamic split point handle to point to as "
                 "many operations ("
-             << splitPoints.size() << ") as the target handle ("
+             << chunkSizes.size() << ") as the target handle ("
              << payload.size() << ")";
     }
   } else {
-    splitPoints.resize(payload.size(),
-                       rewriter.getIndexAttr(getStaticSplitPoint()));
+    chunkSizes.resize(payload.size(),
+                       rewriter.getIndexAttr(getStaticChunkSizes()));
   }
 
-  // Split each target operation.
-  SmallVector<Operation *> first, second;
-  Operation *noSecondPart = nullptr;
-  for (const auto &pair : llvm::zip(payload, splitPoints)) {
-    Operation *target = std::get<0>(pair);
-    auto linalgOp = dyn_cast<LinalgOp>(target);
+  auto checkStructuredOpAndDimensions =
+      [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
     if (!linalgOp) {
       auto diag = emitSilenceableError() << "only applies to structured ops";
-      diag.attachNote(target->getLoc()) << "target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
 
     if (getDimension() >= linalgOp.getNumLoops()) {
       auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                         << " does not exist in target op";
-      diag.attachNote(target->getLoc()) << "target op";
+                                          << " does not exist in target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
+    return DiagnosedSilenceableFailure::success();
+  };
 
-    rewriter.setInsertionPoint(linalgOp);
-    std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
-        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
-        getDimension(), std::get<1>(pair));
-
-    // Propagate errors.
-    if (!first.back() && !second.back()) {
+  auto checkFailureInSplitting =
+      [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
+    if (hasFailed) {
       auto diag = emitDefiniteFailure() << "internal failure in splitting";
-      diag.attachNote(target->getLoc()) << "target op";
+      diag.attachNote(loc) << "target op";
       return diag;
     }
+    return DiagnosedSilenceableFailure::success();
+  };
+
+  if (isMultiwaySplit) {
+
+    // Split a single target operation at multiple points.
+    SmallVector<Operation *> opList;
+    TilingInterface head, tail;
+    Operation *target = payload.front();
+
+    LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
 
-    // Do not add null second parts.
-    if (!second.back()) {
-      noSecondPart = target;
-      second.pop_back();
+    // Check that the target is a valid LinalgOp with correct dimensions.
+    DiagnosedSilenceableFailure diag =
+        checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+    if (diag.isSilenceableFailure())
+      return diag;
+
+    for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
+
+      if (idx > 0)
+        target = tail.getOperation();
+
+      if (!target)
+        break;
+
+      linalgOp = cast<LinalgOp>(target);
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(head, tail) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), chunkSize);
+
+      // Propagate errors.
+      DiagnosedSilenceableFailure diag =
+          checkFailureInSplitting(!head && !tail, target->getLoc());
+      if (diag.isDefiniteFailure())
+        return diag;
+
+      opList.push_back(head.getOperation());
     }
-  }
 
-  if (second.size() != first.size() && !second.empty()) {
-    auto diag = emitSilenceableError()
-                << "splitting does not produce the second part for a subset "
-                   "of targets";
-    diag.attachNote() << "expected splitting to produce the second part of all "
-                         "or none of the targets";
-    diag.attachNote(noSecondPart->getLoc())
-        << "first target with no second part";
-    return diag;
-  }
+    // Append any leftover parts to the end of the result list.
+    if (tail)
+      opList.push_back(tail.getOperation());
+    results.set(cast<OpResult>(getFirst()), opList);
+    results.set(cast<OpResult>(getSecond()), {});
 
-  results.set(cast<OpResult>(getFirst()), first);
-  results.set(cast<OpResult>(getSecond()), second);
+  } else {
+    // Split each target operation.
+    SmallVector<Operation *> first, second;
+    Operation *noSecondPart = nullptr;
+    for (const auto &pair : llvm::zip(payload, chunkSizes)) {
+      Operation *target = std::get<0>(pair);
+      LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+      DiagnosedSilenceableFailure diag =
+          checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+
+      if (diag.isSilenceableFailure())
+        return diag;
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), std::get<1>(pair));
+
+      // Propagate errors.
+      DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
+          !first.back() && !second.back(), target->getLoc());
+      if (diagSplit.isDefiniteFailure())
+        return diagSplit;
+
+      // Do not add null second parts.
+      if (!second.back()) {
+        noSecondPart = target;
+        second.pop_back();
+      }
+    }
+
+    if (second.size() != first.size() && !second.empty()) {
+      auto diag = emitSilenceableError()
+                  << "splitting does not produce the second part for a subset "
+                     "of targets";
+      diag.attachNote()
+          << "expected splitting to produce the second part of all "
+             "or none of the targets";
+      diag.attachNote(noSecondPart->getLoc())
+          << "first target with no second part";
+      return diag;
+    }
+
+    results.set(cast<OpResult>(getFirst()), first);
+    results.set(cast<OpResult>(getSecond()), second);
+  }
   return DiagnosedSilenceableFailure::success();
 }
 
 void SplitOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTargetMutable(), effects);
-  if (getDynamicSplitPoint())
-    onlyReadsHandle(getDynamicSplitPointMutable(), effects);
+  if (getDynamicChunkSizes())
+    onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
   producesHandle(getOperation()->getOpResults(), effects);
   modifiesPayload(effects);
 }
 
 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
-  IntegerAttr staticSplitPoint;
+  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
+  IntegerAttr staticChunkSizes;
   if (parser.parseOperand(target) || parser.parseKeyword("after"))
     return failure();
 
   OptionalParseResult dynamicPointParseResult =
-      parser.parseOptionalOperand(dynamicSplitPoint);
+      parser.parseOptionalOperand(dynamicChunkSizes);
   if (!dynamicPointParseResult.has_value()) {
-    int64_t staticSplitPointValue;
-    if (failed(parser.parseInteger(staticSplitPointValue)))
+    int64_t staticChunkSizesValue;
+    if (failed(parser.parseInteger(staticChunkSizesValue)))
       return failure();
 
-    staticSplitPoint =
-        parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+    staticChunkSizes =
+        parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
   }
 
   Type targetType;
@@ -2389,43 +2471,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
   }
   if (dynamicPointParseResult.has_value()) {
-    Type splitPointType;
+    Type ChunkSizesType;
     if (failed(*dynamicPointParseResult) || parser.parseComma() ||
-        parser.parseType(splitPointType) ||
-        parser.resolveOperand(dynamicSplitPoint, splitPointType,
+        parser.parseType(ChunkSizesType) ||
+        parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
                               result.operands)) {
       return failure();
     }
 
-    staticSplitPoint =
+    staticChunkSizes =
         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
   }
 
   result.addAttribute(
-      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
-      staticSplitPoint);
+      SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
+      staticChunkSizes);
   result.addTypes({targetType, targetType});
   return success();
 }
 
 void SplitOp::print(OpAsmPrinter &printer) {
   printer << " " << getTarget() << " after ";
-  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
-  if (staticSplitSize != ShapedType::kDynamic)
-    printer << staticSplitSize;
+  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
+  if (staticChunkSize != ShapedType::kDynamic)
+    printer << staticChunkSize;
   else
-    printer << getDynamicSplitPoint();
+    printer << getDynamicChunkSizes();
   printer << " ";
   printer.printOptionalAttrDict(getOperation()->getAttrs(),
-                                {getStaticSplitPointAttrName()});
+                                {getStaticChunkSizesAttrName()});
   printer << " : " << getTarget().getType();
-  if (staticSplitSize == ShapedType::kDynamic)
-    printer << ", " << getDynamicSplitPoint().getType();
+  if (staticChunkSize == ShapedType::kDynamic)
+    printer << ", " << getDynamicChunkSizes().getType();
 }
 
 LogicalResult SplitOp::verify() {
-  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
-      (getDynamicSplitPoint() == nullptr)) {
+  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
+      (getDynamicChunkSizes() == nullptr)) {
     return emitOpError() << "expects either a dynamic or a static split "
                             "point to be provided";
   }
@@ -2584,6 +2666,153 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+
+  SmallVector<Operation *> targetOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  if (!llvm::hasSingleElement(targetOps)) {
+    return mlir::emitSilenceableFailure(getLoc())
+           << "requires exactly one target (got " << llvm::range_size(targetOps)
+           << ")";
+  }
+
+  Operation *target = *targetOps.begin();
+  auto linalgOp = dyn_cast<LinalgOp>(target);
+  auto tileableOp = dyn_cast<TilingInterface>(target);
+
+  if (!linalgOp)
+    return emitDefiniteFailure() << "expected Linalg Op";
+
+  OpBuilder builder(linalgOp.getContext());
+
+  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
+    if (linalgOp.hasDynamicShape()) {
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(linalgOp->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticContinuousTileSizeSpecification> spec =
+        computeStaticContinuousTileSizes(linalgOp, getDimension(),
+                                         getTargetSize());
+    if (failed(spec)) {
+      return emitSilenceableError()
+             << "failed to compute multi-size tiling sizes";
+    }
+
+    SmallVector<int64_t> chunkSizes;
+
+    for (auto &&[tileSize, tripCount] :
+         llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+      chunkSizes.push_back(tileSize * tripCount);
+
+    auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+      return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
+            return builder.getI64IntegerAttr(value);
+          });
+    };
+    transformResults.setParams(cast<OpResult>(getTileSizes()),
+                               getI64AttrsFromI64(spec->tileSizes));
+    transformResults.setParams(cast<OpResult>(getChunkSizes()),
+                               getI64AttrsFromI64(chunkSizes));
+
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  builder.setInsertionPoint(linalgOp);
+
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  unsigned dimension = getDimension();
+
+  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
+      builder, tileableOp, dimension, targetSize, true);
+  if (failed(spec)) {
+    return emitSilenceableError() << "could not generate tile size computation";
+  }
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
+                                           ofrs);
+  };
+
+  SmallVector<Value> chunkSizes;
+  Value splitPoint;
+  for (auto &&[tileSize, tripCount] :
+       llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
+    splitPoint = apply(s0 * s1, {tileSize, tripCount});
+    chunkSizes.push_back(splitPoint);
+  }
+
+  auto getDefiningOps = [&](ArrayRef<Value> values) {
+        return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+          return value.getDefiningOp();
+        });
+  };
+
+  transformResults.set(cast<OpResult>(getTileSizes()),
+                       getDefiningOps(spec->tileSizes));
+  transformResults.set(cast<OpResult>(getChunkSizes()),
+                       getDefiningOps(chunkSizes));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::ContinuousTileSizesOp::verify() {
+
+  if (getTileSizes().getType() != getChunkSizes().getType()) {
+    return emitOpError() << "expects all results type to be the same";
+  }
+
+  return success();
+}
+
+void transform::ContinuousTileSizesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
+    onlyReadsPayload(effects);
+  else
+    modifiesPayload(effects);
+  onlyReadsHandle(getTargetMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+}
+
+static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
+                                         Type targetType, Type tile_sizes,
+                                         Type) {
+  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
+}
+
+static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
+                                                Type &targetType,
+                                                Type &tileSizesType,
+                                                Type &chunkSizesType) {
+  FunctionType funcType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseType<FunctionType>(funcType)))
+    return failure();
+
+  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
+    return parser.emitError(typeLoc) << "expects a trailing functional type "
+                                        "with one argument and one result";
+  }
+  targetType = funcType.getInput(0);
+  tileSizesType = chunkSizesType = funcType.getResult(0);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingForOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index d8dee82237156..8ef8651646829 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -107,6 +107,137 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
       b.getStringAttr("expected strictly positive tile size and divisor"));
 }
 
+FailureOr<StaticContinuousTileSizeSpecification>
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
+                                               unsigned dimension,
+                                               unsigned targetSize) {
+
+  assert(!op.hasDynamicShape() &&
+         "cannot compute static multi-tile sizes for an op with dynamic shape");
+  assert(targetSize > 0 && "target size must be positive");
+  assert(dimension < op.getNumLoops() && "dimension overflow");
+
+  StaticContinuousTileSizeSpecification spec;
+  int64_t loopRange = op.getStaticLoopRanges()[dimension];
+  int64_t tripCount = loopRange / targetSize;
+
+  unsigned tileSize = targetSize;
+
+  spec.tileSizes.push_back(tileSize);
+  spec.tripCounts.push_back(tripCount);
+
+  int64_t remainderChunk = loopRange % targetSize;
+
+  while (tileSize > 1 && remainderChunk != 0) {
+
+    uint64_t maxPower = llvm::bit_floor(tileSize);
+    tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
+
+    tripCount = remainderChunk / tileSize;
+
+    if (tripCount > 0) {
+      spec.tileSizes.push_back(tileSize);
+      spec.tripCounts.push_back(tripCount);
+    }
+
+    remainderChunk = remainderChunk % tileSize;
+  }
+
+  auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
+                            SmallVector<int64_t> tripCounts,
+                            int64_t range) -> bool {
+    int64_t computedRange = 0;
+    for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
+      computedRange += tileSize * tripCount;
+    return range == computedRange;
+  };
+
+  if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
+    return failure();
+
+  return spec;
+}
+
+FailureOr<ContinuousTileSizeSpecification>
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+                                         unsigned dimension,
+                                         OpFoldResult targetSize,
+                                         bool emitAssertions) {
+
+  SmallVector<Range> loopRanges = op.getIterationDomain(builder);
+  unsigned numLoops = loopRanges.size();
+
+  // Bail out on dimension overflow.
+  if (dimension >= numLoops)
+    return failure();
+
+  // The code below works only on values.
+  Location loc = op->getLoc();
+  ImplicitLocOpBuilder b(loc, builder);
+  if (emitAssertions) {
+    emitIsPositiveIndexAssertion(b, targetSize);
+  }
+  Value targetSizeValue =
+      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
+
+  // Find the trip count of the iteration space dimension for which the tile
+  // sizes are computed.
+  Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
+                                                    loopRanges[dimension].size);
+  ContinuousTileSizeSpecification spec;
+
+  // Compute the tile sizes and the respective numbers of tiles.
+  AffineExpr s0 = b.getAffineSymbolExpr(0);
+  AffineExpr s1 = b.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
+  };
+
+  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
+  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
+
+  OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
+      b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
+
+  // emitIsPositiveIndexAssertion above (when `emitAssertions` is set) already
+  // asserts that targetSize is a positive integer.
+  uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
+
+  assert(tileSizeInt > 0 && "target size must be strictly positive");
+
+  spec.tileSizes.push_back(targetSizeValue);
+  spec.tripCounts.push_back(tripCountValue);
+
+  while (tileSizeInt > 1) {
+    uint64_t maxPower = llvm::bit_floor(tileSizeInt);
+    tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
+    auto constStepOp =
+        builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
+    tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+    tripCountSize = affine::makeComposedFoldedAffineApply(
+        b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+    // Optimization if tripCount can be determined to be zero.
+    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
+      auto intAttr = cast<IntegerAttr>(attr);
+      bool isTripCountZero = intAttr.getValue().isZero();
+
+      if (!isTripCountZero) {
+        spec.tileSizes.push_back(constStepOp);
+        spec.tripCounts.push_back(tripCountValue);
+      }
+    } else {
+      spec.tileSizes.push_back(constStepOp);
+      spec.tripCounts.push_back(tripCountValue);
+    }
+
+    remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
+  }
+
+  return spec;
+}
+
 FailureOr<StaticMultiSizeSpecification>
 mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
                                           int64_t targetSize, int64_t divisor) {

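As a reading aid for computeStaticContinuousTileSizes above, here is a minimal standalone sketch of the same arithmetic in plain C++, without any MLIR dependencies. The names TileSpec, bitFloor and continuousTileSizes are hypothetical stand-ins chosen for illustration (bitFloor mimics llvm::bit_floor); this is not the code added by the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for StaticContinuousTileSizeSpecification.
struct TileSpec {
  std::vector<int64_t> tileSizes;
  std::vector<int64_t> tripCounts;
};

// Largest power of two that is <= x, for x > 0 (what llvm::bit_floor yields).
static uint64_t bitFloor(uint64_t x) {
  uint64_t p = 1;
  while (p <= x / 2)
    p *= 2;
  return p;
}

// Start with the target tile size, then cover the leftover with diminishing
// power-of-two tile sizes, skipping sizes that would run zero times.
static TileSpec continuousTileSizes(int64_t loopRange, int64_t targetSize) {
  assert(targetSize > 0 && "target size must be strictly positive");
  TileSpec spec;
  int64_t tileSize = targetSize;
  spec.tileSizes.push_back(tileSize);
  spec.tripCounts.push_back(loopRange / tileSize);
  int64_t remainder = loopRange % tileSize;

  while (tileSize > 1 && remainder != 0) {
    int64_t maxPower = static_cast<int64_t>(bitFloor(tileSize));
    // If tileSize is already a power of two, halve it; otherwise round down.
    tileSize = (maxPower == tileSize) ? (maxPower >> 1) : maxPower;
    int64_t tripCount = remainder / tileSize;
    if (tripCount > 0) {
      spec.tileSizes.push_back(tileSize);
      spec.tripCounts.push_back(tripCount);
    }
    remainder %= tileSize;
  }
  return spec;
}
```

Calling continuousTileSizes(25, 9) gives tile sizes {9, 4, 2, 1} with trip counts {2, 1, 1, 1}, i.e. chunks of 18, 4, 2 and 1. That is consistent with the 18/7, 4/3 and 2/1 slices in the static tests of this patch. Tile size 8 is skipped because the remainder of 7 cannot hold a full tile of 8.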
diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 1efe708a5e349..3d6a9a18d9f41 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1396,9 +1396,26 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
   SmallVector<SmallVector<MappedValue>> payloads;
   detail::prepareValueMappings(payloads, getTargets(), state);
   size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
+  bool isZipShortest = getZipShortest();
+
+  // In case of `zip_shortest`, set the number of iterations to the size of
+  // the smallest payload among the targets.
+  if (isZipShortest) {
+    numIterations =
+        llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
+                                        const SmallVector<MappedValue> &B) {
+          return A.size() < B.size();
+        })->size();
+
+    for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
+      payloads[argIdx].resize(numIterations);
+  }
 
   // As we will be "zipping" over them, check all payloads have the same size.
-  for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) {
+  // `zip_shortest` adjusts all payloads to the same size, so skip this check
+  // when true.
+  for (size_t argIdx = 1; !isZipShortest && argIdx < payloads.size();
+       argIdx++) {
     if (payloads[argIdx].size() != numIterations) {
       return emitSilenceableError()
              << "prior targets' payload size (" << numIterations

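The `zip_shortest` handling in the hunk above amounts to truncating every payload list to the length of the shortest one before iterating. A minimal sketch of that idea on plain std::vector follows. The helper name zipShortest is hypothetical, and the early return for an empty list of payloads is an assumption of this sketch rather than something the patch does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Truncates all payload lists to the size of the shortest one and returns
// that size, which is the number of iterations a zipped loop would run.
template <typename T>
std::size_t zipShortest(std::vector<std::vector<T>> &payloads) {
  // Assumption of this sketch: no payloads means zero iterations.
  if (payloads.empty())
    return 0;

  auto shortest = std::min_element(
      payloads.begin(), payloads.end(),
      [](const std::vector<T> &a, const std::vector<T> &b) {
        return a.size() < b.size();
      });
  assert(shortest != payloads.end());

  // Copy the size first: resizing below may touch the element it points to.
  std::size_t numIterations = shortest->size();
  for (std::vector<T> &payload : payloads)
    payload.resize(numIterations);
  return numIterations;
}
```

With payload lists of sizes 4, 2 and 3, the helper returns 2 and leaves each list with its first two elements, so the equal-size check that follows in the hunk becomes unnecessary.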
diff  --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 2c49ef0960c75..41051c0d5b2ff 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -432,25 +432,25 @@ def __init__(
         self,
         target: Union[Operation, Value],
         dimension: Union[int, Attribute],
-        split_point: Union[int, Operation, Value, Attribute],
+        chunk_sizes: Union[int, Operation, Value, Attribute],
         *,
         loc=None,
         ip=None,
     ):
-        if isinstance(split_point, int):
-            static_split_point = split_point
-            dynamic_split_point = None
+        if isinstance(chunk_sizes, int):
+            static_chunk_sizes = chunk_sizes
+            dynamic_chunk_sizes = None
         else:
-            static_split_point = ShapedType.get_dynamic_size()
-            dynamic_split_point = split_point
+            static_chunk_sizes = ShapedType.get_dynamic_size()
+            dynamic_chunk_sizes = chunk_sizes
 
         super().__init__(
             target.type,
             target.type,
             target,
             dimension=dimension,
-            static_split_point=static_split_point,
-            dynamic_split_point=dynamic_split_point,
+            static_chunk_sizes=static_chunk_sizes,
+            dynamic_chunk_sizes=dynamic_chunk_sizes,
             loc=loc,
             ip=ip,
         )

diff  --git a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
new file mode 100644
index 0000000000000..b61c727f9cffc
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
@@ -0,0 +1,180 @@
+// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
+    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
+    transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.any_op {
+    ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
+      %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_linalg_matmul
+// CHECK-SAME:  (%[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>) -> tensor<25x25xf32> {
+// CHECK:         %[[C18:.+]] = arith.constant 18 : index
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[C9:.+]] = arith.constant 9 : index
+// CHECK:         %[[XSIN18:.+]] = tensor.extract_slice %[[IN1]][0, 0] [18, 34] [1, 1] : tensor<25x34xf32> to tensor<18x34xf32>
+// CHECK:         %[[XSOUT18:.+]] = tensor.extract_slice %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<25x25xf32> to tensor<18x25xf32>
+// CHECK:         %[[R0:.+]] = scf.for %[[IDX:.+]] = %[[C0]] to %[[C18]] step %[[C9]] iter_args(%[[XSOUT18ARG:.+]] = %[[XSOUT18]]) -> (tensor<18x25xf32>) {
+// CHECK:           %[[XSIN19:.+]] = tensor.extract_slice %[[XSIN18]][%[[IDX]], 0] [9, 34] [1, 1] : tensor<18x34xf32> to tensor<9x34xf32>
+// CHECK:           %[[XSOUT9:.+]] = tensor.extract_slice %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<18x25xf32> to tensor<9x25xf32>
+// CHECK:           %[[MATMUL:.+]] = linalg.matmul ins(%[[XSIN19]], %[[IN2]] : tensor<9x34xf32>, tensor<34x25xf32>) outs(%[[XSOUT9]] : tensor<9x25xf32>) -> tensor<9x25xf32>
+// CHECK:           %[[INS9:.+]] = tensor.insert_slice %[[MATMUL]] into %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<9x25xf32> into tensor<18x25xf32>
+// CHECK:           scf.yield %[[INS9]] : tensor<18x25xf32>
+// CHECK:         }
+// CHECK:         %[[INS:.+]] = tensor.insert_slice %[[R0]] into %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<18x25xf32> into tensor<25x25xf32>
+// CHECK:         %[[XS1:.+]] = tensor.extract_slice %[[IN1]][18, 0] [7, 34] [1, 1] : tensor<25x34xf32> to tensor<7x34xf32>
+// CHECK:         %[[XS2:.+]] = tensor.extract_slice %[[INS]][18, 0] [7, 25] [1, 1] : tensor<25x25xf32> to tensor<7x25xf32>
+// CHECK:         %[[XS3:.+]] = tensor.extract_slice %[[XS1]][0, 0] [4, 34] [1, 1] : tensor<7x34xf32> to tensor<4x34xf32>
+// CHECK:         %[[XS4:.+]] = tensor.extract_slice %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<7x25xf32> to tensor<4x25xf32>
+// CHECK:         %[[R1:.+]] = linalg.matmul ins(%[[XS3]], %[[IN2]] : tensor<4x34xf32>, tensor<34x25xf32>) outs(%[[XS4]] : tensor<4x25xf32>) -> tensor<4x25xf32>
+// CHECK:         %[[INS5:.+]] = tensor.insert_slice %[[R1]] into %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<4x25xf32> into tensor<7x25xf32>
+// CHECK:         %[[XS6:.+]] = tensor.extract_slice %[[XS1]][4, 0] [3, 34] [1, 1] : tensor<7x34xf32> to tensor<3x34xf32>
+// CHECK:         %[[XS7:.+]] = tensor.extract_slice %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<7x25xf32> to tensor<3x25xf32>
+// CHECK:         %[[XS8:.+]] = tensor.extract_slice %[[XS6]][0, 0] [2, 34] [1, 1] : tensor<3x34xf32> to tensor<2x34xf32>
+// CHECK:         %[[XS9:.+]] = tensor.extract_slice %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<3x25xf32> to tensor<2x25xf32>
+// CHECK:         %[[R2:.+]] = linalg.matmul ins(%[[XS8]], %[[IN2]] : tensor<2x34xf32>, tensor<34x25xf32>) outs(%[[XS9]] : tensor<2x25xf32>) -> tensor<2x25xf32>
+// CHECK:         %[[INS10:.+]] = tensor.insert_slice %[[R2]] into %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<2x25xf32> into tensor<3x25xf32>
+// CHECK:         %[[XS11:.+]] = tensor.extract_slice %[[XS6]][2, 0] [1, 34] [1, 1] : tensor<3x34xf32> to tensor<1x34xf32>
+// CHECK:         %[[XS12:.+]] = tensor.extract_slice %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<3x25xf32> to tensor<1x25xf32>
+// CHECK:         %[[R3:.+]] = linalg.matmul ins(%[[XS11]], %[[IN2]] : tensor<1x34xf32>, tensor<34x25xf32>) outs(%[[XS12]] : tensor<1x25xf32>) -> tensor<1x25xf32>
+// CHECK:         %[[INS13:.+]] = tensor.insert_slice %[[R3]] into %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<1x25xf32> into tensor<3x25xf32>
+// CHECK:         %[[INS14:.+]] = tensor.insert_slice %[[INS13]] into %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<3x25xf32> into tensor<7x25xf32>
+// CHECK:         %[[INS15:.+]] = tensor.insert_slice %[[INS14]] into %[[INS]][18, 0] [7, 25] [1, 1] : tensor<7x25xf32> into tensor<25x25xf32>
+// CHECK:         return %[[INS15]] : tensor<25x25xf32>
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.param<i64>
+    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
+    transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.param<i64> {
+    ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.param<i64>):
+      %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_static_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_static_linalg_matmul
+// CHECK-SAME:  (%[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>) -> tensor<25x25xf32> {
+// CHECK:         %[[C9:.+]] = arith.constant 9 : index
+// CHECK:         %[[C18:.+]] = arith.constant 18 : index
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[XSIN18:.+]] = tensor.extract_slice %[[IN1]][0, 0] [18, 34] [1, 1] : tensor<25x34xf32> to tensor<18x34xf32>
+// CHECK:         %[[XSOUT18:.+]] = tensor.extract_slice %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<25x25xf32> to tensor<18x25xf32>
+// CHECK:         %[[R0:.+]] = scf.for %[[IDX:.+]] = %[[C0]] to %[[C18]] step %[[C9]] iter_args(%[[XSOUT18ARG:.+]] = %[[XSOUT18]]) -> (tensor<18x25xf32>) {
+// CHECK:           %[[XSIN19:.+]] = tensor.extract_slice %[[XSIN18]][%[[IDX]], 0] [9, 34] [1, 1] : tensor<18x34xf32> to tensor<9x34xf32>
+// CHECK:           %[[XSOUT9:.+]] = tensor.extract_slice %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<18x25xf32> to tensor<9x25xf32>
+// CHECK:           %[[MATMUL:.+]] = linalg.matmul ins(%[[XSIN19]], %[[IN2]] : tensor<9x34xf32>, tensor<34x25xf32>) outs(%[[XSOUT9]] : tensor<9x25xf32>) -> tensor<9x25xf32>
+// CHECK:           %[[INS9:.+]] = tensor.insert_slice %[[MATMUL]] into %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<9x25xf32> into tensor<18x25xf32>
+// CHECK:           scf.yield %[[INS9]] : tensor<18x25xf32>
+// CHECK:         }
+// CHECK:         %[[INS:.+]] = tensor.insert_slice %[[R0]] into %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<18x25xf32> into tensor<25x25xf32>
+// CHECK:         %[[XS1:.+]] = tensor.extract_slice %[[IN1]][18, 0] [7, 34] [1, 1] : tensor<25x34xf32> to tensor<7x34xf32>
+// CHECK:         %[[XS2:.+]] = tensor.extract_slice %[[INS]][18, 0] [7, 25] [1, 1] : tensor<25x25xf32> to tensor<7x25xf32>
+// CHECK:         %[[XS3:.+]] = tensor.extract_slice %[[XS1]][0, 0] [4, 34] [1, 1] : tensor<7x34xf32> to tensor<4x34xf32>
+// CHECK:         %[[XS4:.+]] = tensor.extract_slice %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<7x25xf32> to tensor<4x25xf32>
+// CHECK:         %[[R1:.+]] = linalg.matmul ins(%[[XS3]], %[[IN2]] : tensor<4x34xf32>, tensor<34x25xf32>) outs(%[[XS4]] : tensor<4x25xf32>) -> tensor<4x25xf32>
+// CHECK:         %[[INS5:.+]] = tensor.insert_slice %[[R1]] into %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<4x25xf32> into tensor<7x25xf32>
+// CHECK:         %[[XS6:.+]] = tensor.extract_slice %[[XS1]][4, 0] [3, 34] [1, 1] : tensor<7x34xf32> to tensor<3x34xf32>
+// CHECK:         %[[XS7:.+]] = tensor.extract_slice %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<7x25xf32> to tensor<3x25xf32>
+// CHECK:         %[[XS8:.+]] = tensor.extract_slice %[[XS6]][0, 0] [2, 34] [1, 1] : tensor<3x34xf32> to tensor<2x34xf32>
+// CHECK:         %[[XS9:.+]] = tensor.extract_slice %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<3x25xf32> to tensor<2x25xf32>
+// CHECK:         %[[R2:.+]] = linalg.matmul ins(%[[XS8]], %[[IN2]] : tensor<2x34xf32>, tensor<34x25xf32>) outs(%[[XS9]] : tensor<2x25xf32>) -> tensor<2x25xf32>
+// CHECK:         %[[INS10:.+]] = tensor.insert_slice %[[R2]] into %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<2x25xf32> into tensor<3x25xf32>
+// CHECK:         %[[XS11:.+]] = tensor.extract_slice %[[XS6]][2, 0] [1, 34] [1, 1] : tensor<3x34xf32> to tensor<1x34xf32>
+// CHECK:         %[[XS12:.+]] = tensor.extract_slice %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<3x25xf32> to tensor<1x25xf32>
+// CHECK:         %[[R3:.+]] = linalg.matmul ins(%[[XS11]], %[[IN2]] : tensor<1x34xf32>, tensor<34x25xf32>) outs(%[[XS12]] : tensor<1x25xf32>) -> tensor<1x25xf32>
+// CHECK:         %[[INS13:.+]] = tensor.insert_slice %[[R3]] into %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<1x25xf32> into tensor<3x25xf32>
+// CHECK:         %[[INS14:.+]] = tensor.insert_slice %[[INS13]] into %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<3x25xf32> into tensor<7x25xf32>
+// CHECK:         %[[INS15:.+]] = tensor.insert_slice %[[INS14]] into %[[INS]][18, 0] [7, 25] [1, 1] : tensor<7x25xf32> into tensor<25x25xf32>
+// CHECK:         return %[[INS15]] : tensor<25x25xf32>
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
+    %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
+    transform.foreach %linalg_splits, %tile_sizes {zip_shortest} : !transform.any_op, !transform.any_op {
+    ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
+      %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_dynamic_linalg_matmul(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%arg2: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK:     #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 9) * 9, s1)>
+// CHECK:     #[[$MAP3:.*]] = affine_map<()[s0, s1, s2] -> (((s0 mod 9) floordiv 8) * 8, s1 - s2)>
+// CHECK:     #[[$MAP6:.*]] = affine_map<()[s0, s1, s2, s3] -> ((((s0 mod 9) mod 8) floordiv 4) * 4, s1 - s2 - s3)>
+// CHECK:     #[[$MAP9:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 mod 9) mod 4) floordiv 2) * 2, s1 - s2 - s3 - s4)>
+// CHECK:     #[[$MAP12:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> ((s0 mod 9) mod 2, s1 - s2 - s3 - s4 - s5)>
+// CHECK-LABEL: @continuous_tile_dynamic_linalg_matmul
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK:     %[[AM0:.*]] = affine.min #[[$MAP0]]()[%{{.*}}, %{{.*}}]
+// CHECK:     %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM0]] step %[[C9]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK:       %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:       %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK:     %[[AM4:.*]] = affine.min #[[$MAP3]]()[%{{.*}}, %{{.*}}, %[[AM0]]]
+// CHECK:     %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM4]] step %[[C8]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK:       %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:       %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK:     %[[AM8:.*]] = affine.min #[[$MAP6]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]]]
+// CHECK:     %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM8]] step %[[C4]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK:       %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:       %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK:     %[[AM12:.*]] = affine.min #[[$MAP9]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]]]
+// CHECK:     %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM12]] step %[[C2]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK:       %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:       %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK:     %[[AM16:.*]] = affine.min #[[$MAP12]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]], %[[AM12]]]
+// CHECK:     %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM16]] step %[[C1]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK:       %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<1x?xf32>) -> tensor<1x?xf32>
+// CHECK:       %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>
\ No newline at end of file

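In the dynamic-shape test above, the loop range is unknown at transform time, so no trip count folds to a constant zero and every candidate tile size appears to get its own loop (steps 9, 8, 4, 2 and 1). The sketch below lists those candidates for a given target size. The helper names bitFloor and candidateTileSizes are hypothetical, and the sketch is a reading aid rather than the patch's implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Largest power of two that is <= x, for x > 0 (what llvm::bit_floor yields).
static uint64_t bitFloor(uint64_t x) {
  uint64_t p = 1;
  while (p <= x / 2)
    p *= 2;
  return p;
}

// Lists the target size followed by the diminishing power-of-two tile sizes
// down to 1, in the order in which the loops would be generated.
static std::vector<uint64_t> candidateTileSizes(uint64_t targetSize) {
  assert(targetSize > 0 && "target size must be strictly positive");
  std::vector<uint64_t> sizes{targetSize};
  uint64_t tileSize = targetSize;
  while (tileSize > 1) {
    uint64_t maxPower = bitFloor(tileSize);
    // If tileSize is already a power of two, halve it; otherwise round down.
    tileSize = (maxPower == tileSize) ? (maxPower >> 1) : maxPower;
    sizes.push_back(tileSize);
  }
  return sizes;
}
```

For a target size of 9 this yields {9, 8, 4, 2, 1}, matching the step constants checked in the dynamic test.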
diff  --git a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
new file mode 100644
index 0000000000000..609766fbdc91f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s
+
+// Tests multiway splitting driven by continuous_tile_sizes.
+// continuous_tile_sizes returns a list of tile sizes and a list of chunk sizes.
+// The split op consumes the chunk sizes to split the linalg.matmul op along
+// dimension 1 into one linalg.matmul op per chunk.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
+    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK:      %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK       %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK       %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK       %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK       %[[SLICE1]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK       %[[SLICE2]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK       %[[SLICE3]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK       %[[SLICE4]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK       %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK       %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK       %[[SLICE5]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK       %[[SLICE6]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK       %[[SLICE7]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK       %[[SLICE8]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK       %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK       %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK       %[[SLICE9]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK       %[[SLICE10]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK       %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK       %[[INSLICE2]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK       %[[INSLICE3]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK       %[[INSLICE4]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK       return %[[INSLICE4]] : tensor<25x25xf32>
+
+// -----
+
+// Tests the same as above except that the !transform.param<i64> output type in
+// continuous_tile_sizes op triggers tile sizes and split points to be computed
+// statically and not dynamically.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
+    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_static_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_static_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK:      %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK       %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK       %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK       %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK       %[[SLICE1]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK       %[[SLICE2]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK       %[[SLICE3]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK       %[[SLICE4]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK       %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK       %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK       %[[SLICE5]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK       %[[SLICE6]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK       %[[SLICE7]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK       %[[SLICE8]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK       %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK       %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK       %[[SLICE9]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK       %[[SLICE10]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK       %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK       %[[INSLICE2]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK       %[[INSLICE3]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK       %[[INSLICE4]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK       return %[[INSLICE4]] : tensor<25x25xf32>

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index 566e517d69789..e072fff4c5d77 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -197,7 +197,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
     // expected-error @below {{expects either a dynamic or a static split point to be provided}}
-    %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index f97017b7a2c75..3ea73e8beea36 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -361,8 +361,8 @@ def testScalarize(target):
 @run
 @create_sequence
 def testSplit(target):
-    split = structured.SplitOp(target, dimension=1, split_point=42)
-    structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
+    split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
+    structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
     # CHECK-LABEL: TEST: testSplit
     # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
     # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
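
The nested slice sizes expected by the multiway-split CHECK lines earlier in this diff (25 into 18 + 7, then 7 into 4 + 3, then 3 into 2 + 1) can be reproduced with a short sketch. This is only an illustration and not code from the patch: `nested_splits` is a hypothetical helper, and the chunk sizes 18, 4, 2, 1 are read off the expected slices rather than computed by the transform ops.

```python
def nested_splits(total, chunk_sizes):
    """Model successive two-way splits of one dimension.

    Hypothetical helper, not part of MLIR. Each level peels one chunk
    off the front of what is left and records (chunk, remainder).
    """
    if sum(chunk_sizes) != total:
        raise ValueError("chunk sizes must cover the whole dimension")
    remaining = total
    levels = []
    # The last chunk is simply the remainder of the final split.
    for size in chunk_sizes[:-1]:
        remaining -= size
        levels.append((size, remaining))
    return levels


# Dimension of size 25, chunk sizes as seen in the expected slices.
levels = nested_splits(25, [18, 4, 2, 1])
print(levels)
```

Each pair corresponds to one level of `tensor.extract_slice` nesting in the expected output: the first element is the size of the chunk that is computed directly, the second is the size of the slice that is split again.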


        

