[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 03:48:43 PDT 2024


https://github.com/muneebkhan85 updated https://github.com/llvm/llvm-project/pull/82792

>From 40d377e666c12a0bd7141407cf5392298504269c Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:15:18 +0800
Subject: [PATCH 1/6] [MLIR] Add continuous tiling to Transform dialect

Add continuous tiling op structured.continuous_tile_sizes
to the transform dialect that returns as results a list of
exponentially diminishing tile sizes and a list of split
points to do a multiway split of the target linalg op along
the specified dimension.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  46 ++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  20 +++
 .../TransformOps/LinalgTransformOps.cpp       | 151 ++++++++++++++++++
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 133 +++++++++++++++
 4 files changed, 350 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5585ba27fdad8..6e6bfab8d179e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1819,6 +1819,52 @@ def TileReductionUsingForallOp :
 
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    This transform takes a linalg op as target and a dimension and target size
+    as attributes to generate a list of (1) exponentially diminishing
+    tile sizes that are powers of 2; and (2) the corresponding chunk-sizes
+    the linalg op should be split into along the given dimension.
+
+    For example, for `target_size` 9, and `dimension` 0 for the following
+    linalg op as target
+
+    ```
+      %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                      outs(%arg2: tensor<25x25xf32>)
+    ```
+
+    the first result `tile_sizes` will be a list of diminishing tile sizes
+    9, 4, 2, 1; and the second result will be a list of chunk sizes
+    18, 4, 2, 1 that the corresponding dimension should be split into.
+
+    After the linalg has been split along the given dimension (for example using
+    multiway split), each chunk can be tiled with the corresponding tile size in
+    the `tile_sizes` list generated as a result of this op.
+
+    Specifying the output type as !transform.param<i64> will cause `tile_sizes`
+    and `split_points` to be computed statically and not dynamically. Both
+    results must have the same type.
+
+    The `target` handle must point to exactly one linalg op, `dimension` must
+    be a valid loop dimension of that op, and `target_size` must be strictly
+    positive.
+  }];
+
+  // `target_size` is used as a divisor when computing the trip counts, so a
+  // zero value is rejected by the attribute constraint.
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$target_size);
+  let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
+                      TransformAnyParamTypeOrAnyHandle:$split_points);
+  let hasVerifier = 1;
+  let assemblyFormat =
+    "$target attr-dict `:` custom<ContinuousTileSizeTypes>("
+    "type($target), type($tile_sizes), type($split_points))";
+
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f77c19ed0fcce..d902df458c2b7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
   /// Number of tiles associated with each size.
   T lowTripCount, highTripCount;
 };
+
+/// Description of a continuous tiling along one dimension: a sequence of tile
+/// sizes together with the number of tiles of each size. `tileSizes[i]` is
+/// paired with `tripCounts[i]`, so both vectors are expected to have the same
+/// length.
+template <typename T>
+struct ContinuousTileSizeSpecificationBase {
+  /// Tile sizes. As produced by the continuous tile-size computations, the
+  /// requested target size comes first, followed by smaller powers of two.
+  SmallVector<T> tileSizes;
+  /// Number of tiles associated with each size.
+  SmallVector<T> tripCounts;
+};
+
 } // namespace detail
 
 /// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
 struct StaticMultiSizeSpecification
     : public detail::MultiSizeSpecificationBase<int64_t> {};
 
+/// Continuous tile-size specification whose sizes and trip counts are IR
+/// values, for use when they are computed by emitted IR.
+struct ContinuousTileSizeSpecification
+    : public detail::ContinuousTileSizeSpecificationBase<Value> {};
+/// Continuous tile-size specification whose sizes and trip counts are known
+/// as compile-time integers.
+struct StaticContinuousTileSizeSpecification
+    : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
+
 /// Emits the IR computing the multi-sized tiling specification with two tile
 /// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
 /// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,12 @@ FailureOr<StaticMultiSizeSpecification>
 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
+/// Computes the continuous tile sizes of `op` along `dimension` from its
+/// static shape: `targetSize` first, followed by decreasing powers of two,
+/// each paired with the number of tiles of that size. `op` must have a static
+/// shape. Returns failure if the tiles do not cover the loop range exactly.
+FailureOr<StaticContinuousTileSizeSpecification>
+computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+                                 unsigned targetSize);
+/// Emits, at the insertion point of `builder`, the IR computing the
+/// continuous tile sizes of `op` along `dimension` starting at `targetSize`,
+/// together with the number of tiles of each size. When `emitAssertions` is
+/// set, a runtime assertion that `targetSize` is positive is also emitted.
+/// Returns failure if `dimension` is out of range. NOTE(review): the
+/// implementation currently expects `targetSize` to be a constant.
+FailureOr<ContinuousTileSizeSpecification>
+computeContinuousTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
+                           OpFoldResult targetSize, bool emitAssertions);
+
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
 /// tiling by `numThreads`.
 /// If non-empty, the `mapping` is added as an attribute to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..4b2d9485c16d1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2581,6 +2581,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+/// Applies ContinuousTileSizesOp to its single Linalg payload op. When the
+/// results are transform parameters, the tile sizes and split points are
+/// computed from the static shape and returned as i64 attributes. Otherwise,
+/// IR computing them is emitted right before the payload op and the ops
+/// defining those values are returned as handles.
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+
+  SmallVector<Operation *> targetOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one target (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+
+  // Check the cast before using the op: the builder below needs its context,
+  // and a failed dyn_cast yields a null LinalgOp.
+  auto target = dyn_cast<LinalgOp>(*targetOps.begin());
+  if (!target)
+    return emitDefiniteFailure() << "expected Linalg Op";
+
+  // Both computations index the loop ranges with `dimension`; reject an
+  // out-of-range value here rather than relying on an assertion later.
+  if (getDimension() >= target.getNumLoops()) {
+    auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                       << " does not exist in target op";
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  OpBuilder builder(target.getContext());
+
+  if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+    if (target.hasDynamicShape()) {
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticContinuousTileSizeSpecification> spec =
+        computeStaticContinuousTileSizes(target, getDimension(),
+                                         getTargetSize());
+    if (failed(spec)) {
+      return emitSilenceableError()
+             << "failed to compute multi-size tiling sizes";
+    }
+
+    // Each split point is the extent of one chunk: tile size * trip count.
+    SmallVector<int64_t> splitPoints;
+    splitPoints.reserve(spec->tileSizes.size());
+    for (auto [tileSize, tripCount] :
+         llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+      splitPoints.push_back(tileSize * tripCount);
+
+    auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+      return llvm::to_vector(
+          llvm::map_range(values, [&](int64_t value) -> Attribute {
+            return builder.getI64IntegerAttr(value);
+          }));
+    };
+    transformResults.setParams(cast<OpResult>(getTileSizes()),
+                               makeI64AttrsFromI64(spec->tileSizes));
+    transformResults.setParams(cast<OpResult>(getSplitPoints()),
+                               makeI64AttrsFromI64(splitPoints));
+
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Emit the size computation right before the payload op.
+  builder.setInsertionPoint(target);
+
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  unsigned dimension = getDimension();
+
+  FailureOr<ContinuousTileSizeSpecification> spec =
+      computeContinuousTileSizes(builder, target, dimension, targetSize, true);
+  if (failed(spec)) {
+    return emitSilenceableError() << "could not generate tile size computation";
+  }
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
+                                           ofrs);
+  };
+
+  // Each split point is the extent of one chunk: tile size * trip count.
+  SmallVector<Value> splitPoints;
+  splitPoints.reserve(spec->tileSizes.size());
+  for (auto [tileSize, tripCount] :
+       llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+    splitPoints.push_back(apply(s0 * s1, {tileSize, tripCount}));
+
+  // Handle results point at the ops defining the computed values.
+  auto makeOpFromValue = [&](ArrayRef<Value> values) {
+    return llvm::to_vector(
+        llvm::map_range(values, [&](Value value) -> Operation * {
+          return value.getDefiningOp();
+        }));
+  };
+
+  transformResults.set(cast<OpResult>(getTileSizes()),
+                       makeOpFromValue(spec->tileSizes));
+  transformResults.set(cast<OpResult>(getSplitPoints()),
+                       makeOpFromValue(splitPoints));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Checks that both results have the same type. The custom assembly format
+/// spells only one result type, so the two results cannot differ.
+LogicalResult transform::ContinuousTileSizesOp::verify() {
+  Type tileSizesType = getTileSizes().getType();
+  if (tileSizesType == getSplitPoints().getType())
+    return success();
+  return emitOpError() << "expects all results type to be the same";
+}
+
+/// Declares the memory effects of ContinuousTileSizesOp. The target handle is
+/// only read and both results are produced. The payload is only read when the
+/// results are parameters; it is modified when IR computing them is emitted.
+void transform::ContinuousTileSizesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  bool resultsAreParams =
+      isa<TransformParamTypeInterface>(getTileSizes().getType());
+  if (!resultsAreParams) {
+    modifiesPayload(effects);
+  } else {
+    onlyReadsPayload(effects);
+  }
+  onlyReadsHandle(getTarget(), effects);
+  producesHandle(getTileSizes(), effects);
+  producesHandle(getSplitPoints(), effects);
+}
+
+/// Prints the operand/result types of ContinuousTileSizesOp as a functional
+/// type `(target) -> result`. The split-points type is not printed because
+/// the verifier requires it to equal the tile-sizes type.
+static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
+                                         Type targetType, Type tileSizesType,
+                                         Type) {
+  printer.printFunctionalType(TypeRange{targetType},
+                              TypeRange{tileSizesType});
+}
+
+/// Parses the trailing `(target) -> result` functional type of
+/// ContinuousTileSizesOp. The single result type is assigned to both the
+/// tile-sizes and the split-points results.
+static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
+                                                Type &targetType,
+                                                Type &tileSizesType,
+                                                Type &splitPointsType) {
+  FunctionType funcType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseType<FunctionType>(funcType)))
+    return failure();
+
+  // Bail out on a malformed type: getInput(0)/getResult(0) below would
+  // otherwise index out of bounds, and parsing must report failure.
+  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
+    return parser.emitError(typeLoc)
+           << "expects a trailing functional type with one argument and one "
+              "result";
+  }
+  targetType = funcType.getInput(0);
+  tileSizesType = splitPointsType = funcType.getResult(0);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..8049ade591a70 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -107,6 +107,139 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
       b.getStringAttr("expected strictly positive tile size and divisor"));
 }
 
+/// Computes continuous tile sizes along `dimension` of the statically shaped
+/// `op`: `targetSize` first, then decreasing powers of two. Each size is
+/// paired with the number of full tiles of that size that fit into what is
+/// left of the loop range. After the first size, sizes whose trip count would
+/// be zero are skipped. Returns failure if the tiles do not cover the loop
+/// range exactly.
+FailureOr<StaticContinuousTileSizeSpecification>
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+                                               unsigned targetSize) {
+
+  assert(!op.hasDynamicShape() &&
+         "cannot compute static continuous tile sizes for an op with dynamic "
+         "shape");
+  assert(targetSize > 0 && "target size must be positive");
+  assert(dimension < op.getNumLoops() && "dimension overflow");
+
+  StaticContinuousTileSizeSpecification spec;
+  int64_t loopRange = op.getStaticLoopRanges()[dimension];
+  int64_t tripCount = loopRange / targetSize;
+
+  unsigned tileSize = targetSize;
+
+  spec.tileSizes.push_back(tileSize);
+  spec.tripCounts.push_back(tripCount);
+
+  int64_t remainderChunk = loopRange % targetSize;
+
+  while (tileSize > 1 && remainderChunk != 0) {
+    // Next size: the largest power of two strictly below the current size.
+    uint64_t maxPower = llvm::bit_floor(tileSize);
+    tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
+
+    tripCount = remainderChunk / tileSize;
+
+    if (tripCount > 0) {
+      spec.tileSizes.push_back(tileSize);
+      spec.tripCounts.push_back(tripCount);
+    }
+
+    remainderChunk = remainderChunk % tileSize;
+  }
+
+  // Sanity check: the chosen tiles must cover the loop range exactly.
+  auto tripCountCheck = [](ArrayRef<int64_t> tileSizes,
+                           ArrayRef<int64_t> tripCounts,
+                           int64_t range) -> bool {
+    int64_t computedRange = 0;
+    for (auto [size, count] : llvm::zip_equal(tileSizes, tripCounts))
+      computedRange += size * count;
+    return range == computedRange;
+  };
+
+  if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
+    return failure();
+
+  return spec;
+}
+
+/// Emits IR computing continuous tile sizes along `dimension` of `op`,
+/// starting at `targetSize` and continuing with decreasing powers of two,
+/// together with the number of tiles of each size. After the first size,
+/// sizes whose trip count folds to zero are dropped. Returns failure, without
+/// emitting any IR, if `dimension` is out of range or if `targetSize` is not
+/// a positive constant.
+FailureOr<ContinuousTileSizeSpecification>
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
+                                         unsigned dimension,
+                                         OpFoldResult targetSize,
+                                         bool emitAssertions) {
+
+  // Bail out on dimension overflow.
+  if (dimension >= op.getNumLoops())
+    return failure();
+
+  // The sequence of power-of-two sizes below is unrolled at compile time, so
+  // the target size must be a known positive constant. Check this before
+  // emitting any IR so that a failure leaves the payload untouched.
+  std::optional<int64_t> constantTargetSize = getConstantIntValue(targetSize);
+  if (!constantTargetSize || *constantTargetSize <= 0)
+    return failure();
+  uint64_t tileSizeInt = *constantTargetSize;
+
+  // The code below works only on values.
+  Location loc = op.getLoc();
+  ImplicitLocOpBuilder b(loc, builder);
+  if (emitAssertions) {
+    emitIsPositiveIndexAssertion(b, targetSize);
+  }
+  Value targetSizeValue =
+      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
+
+  // Find the trip count of the iteration space dimension for which the tile
+  // sizes are computed.
+  SmallVector<OpFoldResult> allShapes =
+      op.createFlatListOfOperandDims(b, b.getLoc());
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  SmallVector<OpFoldResult> loopRanges =
+      makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
+                                               allShapes);
+
+  Value loopRange =
+      getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
+
+  ContinuousTileSizeSpecification spec;
+
+  // Compute the tile sizes and the respective numbers of tiles.
+  AffineExpr s0 = b.getAffineSymbolExpr(0);
+  AffineExpr s1 = b.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
+  };
+
+  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
+  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
+
+  spec.tileSizes.push_back(targetSizeValue);
+  spec.tripCounts.push_back(tripCountValue);
+
+  while (tileSizeInt > 1) {
+    // Next size: the largest power of two strictly below the current size.
+    uint64_t maxPower = llvm::bit_floor(tileSizeInt);
+    tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
+    Value constStepOp =
+        builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
+
+    // Skip sizes whose trip count is statically known to be zero, and only
+    // materialize the trip count for sizes that are kept.
+    OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
+        b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+    if (!isConstantIntValue(tripCountSize, 0)) {
+      spec.tileSizes.push_back(constStepOp);
+      spec.tripCounts.push_back(
+          apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp}));
+    }
+
+    remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
+  }
+
+  return spec;
+}
+
 FailureOr<StaticMultiSizeSpecification>
 mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
                                           int64_t targetSize, int64_t divisor) {

>From 1f0c397de2cca7a98c4d74bb73705e790879efae Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:26:19 +0800
Subject: [PATCH 2/6] [MLIR] Add support for multiway split in SplitOp

Add functionality that enables SplitOp to do a multiway split of
a target linalg op along a given dimension. When the `multiway`
attribute is set, SplitOp takes a list of split points and applies
them to a single linalg op along the given dimension to generate
multiple linalg ops extracted from the target.
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  23 ++-
 .../TransformOps/LinalgTransformOps.cpp       | 150 +++++++++++++-----
 2 files changed, 123 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6e6bfab8d179e..51138104c055e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1396,7 +1396,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
      DeclareOpInterfaceMethods<TransformOpInterface>,
      ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    Indicates that the given `target` op should be split into two complementary
+    Splits the given `target` op into two or more complementary
     parts, which combined cover the entire iteration domain of the original op.
     The split is performed along the iteration space dimension provided as
     attribute. In case of dimension overflow, the transformation fails. The
@@ -1409,16 +1409,27 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
     operations pointed to by the target handle.
 
     The operation consumes the target handle, but preserves the split point
-    handle if provided. It produces two new handles pointing to the two parts
-    of the structured op after splitting, in the same order as the target
-    operand, with the first handle corresponding to the part with lower
-    iteration space indices.
+    handle if provided. Without the `multiway` attribute, it produces two
+    new handles pointing to the two parts of the structured op after splitting,
+    in the same order as the target operand, with the first handle
+    corresponding to the part with lower iteration space indices.
+
+    Multiway split mode is enabled by specifying the `multiway` attribute.
+    In this mode a single `target` op is split into multiple parts covering
+    the iteration space of the specified dimension. In this case the split
+    points (`static_split_point`, or the values of `dynamic_split_point`) are
+    interpreted as a list of chunk sizes that the given dimension should be
+    split into. With `multiway` it produces two handles; the first handle is a
+    list of the multiple parts of the structured op after splitting, where the
+    target dimension of each linalg op in the list corresponds to the chunk
+    sizes specified in the input split list. If the chunk sizes do not cover
+    the whole dimension, the remaining part is appended as the last element.
+    The second handle is empty.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
                        Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
-                       I64Attr:$static_split_point);
+                       I64Attr:$static_split_point,
+                       UnitAttr:$multiway);
   let results = (outs TransformHandleTypeInterface:$first,
                       TransformHandleTypeInterface:$second);
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4b2d9485c16d1..907a7bfb72a59 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2269,8 +2269,20 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
   // Collect the dynamic split points if provided.
   SmallVector<Operation *> payload =
       llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  bool isMultiwaySplit = getMultiway() ? true : false;
+
+  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
+    return emitDefiniteFailure() << "requires exactly one target when "
+                                    "multiway split is enabled (got "
+                                 << llvm::range_size(payload) << ")";
+  }
+
   SmallVector<OpFoldResult> splitPoints;
-  splitPoints.reserve(payload.size());
+
+  if (!isMultiwaySplit)
+    splitPoints.reserve(payload.size());
+
   if (getDynamicSplitPoint()) {
     auto diag = DiagnosedSilenceableFailure::success();
     if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
@@ -2293,7 +2305,9 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     if (diag.isSilenceableFailure())
       return diag;
 
-    if (splitPoints.size() != payload.size()) {
+    // For multiway split, a single payload is expected to have multiple
+    // split points.
+    if (!isMultiwaySplit && splitPoints.size() != payload.size()) {
       return emitDefiniteFailure()
              << "expected the dynamic split point handle to point to as "
                 "many operations ("
@@ -2305,57 +2319,105 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
                        rewriter.getIndexAttr(getStaticSplitPoint()));
   }
 
-  // Split each target operation.
-  SmallVector<Operation *> first, second;
-  Operation *noSecondPart = nullptr;
-  for (const auto &pair : llvm::zip(payload, splitPoints)) {
-    Operation *target = std::get<0>(pair);
-    auto linalgOp = dyn_cast<LinalgOp>(target);
-    if (!linalgOp) {
-      auto diag = emitSilenceableError() << "only applies to structured ops";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
-    }
+  if (isMultiwaySplit) {
 
-    if (getDimension() >= linalgOp.getNumLoops()) {
-      auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                         << " does not exist in target op";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
+    // Split a single target operation at multiple points.
+    SmallVector<Operation *> opList;
+    // Initialize both parts: with an empty split-point list the loop below
+    // never runs, yet `tail` is still read after it.
+    Operation *head = nullptr, *tail = nullptr;
+    for (const auto [idx, splitPoint] : llvm::enumerate(splitPoints)) {
+
+      Operation *target;
+      if (idx == 0)
+        target = payload.front();
+      else
+        target = tail;
+
+      if (!target)
+        break;
+
+      auto linalgOp = dyn_cast<LinalgOp>(target);
+
+      if (!linalgOp) {
+        auto diag = emitSilenceableError() << "only applies to structured ops";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      if (getDimension() >= linalgOp.getNumLoops()) {
+        auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                           << " does not exist in target op";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(head, tail) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), splitPoint);
+
+      opList.push_back(head);
     }
 
-    rewriter.setInsertionPoint(linalgOp);
-    std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
-        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
-        getDimension(), std::get<1>(pair));
+    // Append any leftover parts to the end of the result list.
+    if (tail)
+      opList.push_back(tail);
+    results.set(cast<OpResult>(getFirst()), opList);
+    results.set(cast<OpResult>(getSecond()), {});
 
-    // Propagate errors.
-    if (!first.back() && !second.back()) {
-      auto diag = emitDefiniteFailure() << "internal failure in splitting";
-      diag.attachNote(target->getLoc()) << "target op";
-      return diag;
+  } else {
+    // Split each target operation.
+    SmallVector<Operation *> first, second;
+    Operation *noSecondPart = nullptr;
+    for (const auto &pair : llvm::zip(payload, splitPoints)) {
+      Operation *target = std::get<0>(pair);
+      auto linalgOp = dyn_cast<LinalgOp>(target);
+      if (!linalgOp) {
+        auto diag = emitSilenceableError() << "only applies to structured ops";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      if (getDimension() >= linalgOp.getNumLoops()) {
+        auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                           << " does not exist in target op";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      rewriter.setInsertionPoint(linalgOp);
+      std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
+          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+          getDimension(), std::get<1>(pair));
+
+      // Propagate errors.
+      if (!first.back() && !second.back()) {
+        auto diag = emitDefiniteFailure() << "internal failure in splitting";
+        diag.attachNote(target->getLoc()) << "target op";
+        return diag;
+      }
+
+      // Do not add null second parts.
+      if (!second.back()) {
+        noSecondPart = target;
+        second.pop_back();
+      }
     }
 
-    // Do not add null second parts.
-    if (!second.back()) {
-      noSecondPart = target;
-      second.pop_back();
+    if (second.size() != first.size() && !second.empty()) {
+      auto diag = emitSilenceableError()
+                  << "splitting does not produce the second part for a subset "
+                     "of targets";
+      diag.attachNote()
+          << "expected splitting to produce the second part of all "
+             "or none of the targets";
+      diag.attachNote(noSecondPart->getLoc())
+          << "first target with no second part";
+      return diag;
     }
-  }
 
-  if (second.size() != first.size() && !second.empty()) {
-    auto diag = emitSilenceableError()
-                << "splitting does not produce the second part for a subset "
-                   "of targets";
-    diag.attachNote() << "expected splitting to produce the second part of all "
-                         "or none of the targets";
-    diag.attachNote(noSecondPart->getLoc())
-        << "first target with no second part";
-    return diag;
+    results.set(cast<OpResult>(getFirst()), first);
+    results.set(cast<OpResult>(getSecond()), second);
   }
-
-  results.set(cast<OpResult>(getFirst()), first);
-  results.set(cast<OpResult>(getSecond()), second);
   return DiagnosedSilenceableFailure::success();
 }
 

>From 3b59a7fa1e0e86bb3d94c02d59ce278d627f4c48 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 2 May 2024 19:10:49 +0800
Subject: [PATCH 3/6] [MLIR] Test multiway SplitOp

Tests SplitOp for multiway splitting of a linalg op using
the result of `continuous_tile_sizes` to specify multiple
split points for a single linalg op.
---
 .../continuous-tiling-multiway-split.mlir     | 100 ++++++++++++++++++
 1 file changed, 100 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir

diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
new file mode 100644
index 0000000000000..609766fbdc91f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s
+
+// This tests multiway SplitOp driven by the results of continuous_tile_sizes.
+// continuous_tile_sizes returns a list of tile sizes and a list of split points.
+// The list of split points is consumed by SplitOp to split the linalg.matmul op
+// along dimension 1 into one linalg.matmul op per split point.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
+    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK:      %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK       %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK       %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK       %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK       %[[SLICE1]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK       %[[SLICE2]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK       %[[SLICE3]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK       %[[SLICE4]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK       %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK       %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK       %[[SLICE5]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK       %[[SLICE6]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK       %[[SLICE7]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK       %[[SLICE8]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK       %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK       %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK       %[[SLICE9]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK       %[[SLICE10]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK       %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK       %[[INSLICE2]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK       %[[INSLICE3]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK       %[[INSLICE4]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK       return %[[INSLICE4]] : tensor<25x25xf32>
+
+// -----
+
+// Tests the same as above except that the !transform.param<i64> output type of
+// the continuous_tile_sizes op causes tile sizes and split points to be
+// computed statically rather than dynamically.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
+    %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
+
+func.func @continuous_tile_static_linalg_matmul(
+  %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                     outs(%arg2: tensor<25x25xf32>)
+    -> tensor<25x25xf32>
+
+  return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_static_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK:      %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK       %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK       %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK       %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK       %[[SLICE1]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK       %[[SLICE2]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK       %[[SLICE3]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK       %[[SLICE4]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK       %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK       %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK       %[[SLICE5]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK       %[[SLICE6]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK       %[[SLICE7]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK       %[[SLICE8]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK       %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK       %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK       %[[SLICE9]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK       %[[SLICE10]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK       %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK       %[[INSLICE2]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK       %[[INSLICE3]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK       %[[INSLICE4]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK       return %[[INSLICE4]] : tensor<25x25xf32>

>From 2ebfbeb73f6fcb961e1cde714a056e5ff2fcec6b Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Mon, 20 May 2024 23:27:12 +0800
Subject: [PATCH 4/6] fix for SplitOp; switch from split-point to chunk-sizes
 terminology. remove tautology. use emitSilenceableFailure. use references,
 use conditional operator. move common code out of conditional in a lambda
 function. check splitting operation was performed correctly. bug fix and
 refactoring for code duplication.

---
 .../Linalg/TransformOps/LinalgTransformOps.td |  22 +--
 .../TransformOps/LinalgTransformOps.cpp       | 164 ++++++++++--------
 .../Dialect/Linalg/transform-op-split.mlir    |   2 +-
 3 files changed, 102 insertions(+), 86 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 51138104c055e..f20deff0e1389 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1399,16 +1399,18 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
     Splits the given `target` op into two or more complementary
     parts, which combined cover the entire iteration domain of the original op.
     The split is performed along the iteration space dimension provided as
-    attribute. In case of dimension overflow, the transformation fails. The
-    split is performed at the dimension iterator value specified as either the
-    static split point attribute when it is known at transform IR construction
-    time or as the handle to an operation producing a single index-typed value
-    when it is computed by payload IR. In the latter case, the static split
+    attribute; the chunk size sets the size of the lower part and the remaining
+    range in the iteration space is assigned as the upper part. In case of
+    dimension overflow, the transformation fails. The split is performed at the
+    dimension iterator value specified as either the static chunk size
+    attribute when it is known at transform IR construction time or
+    as the handle to an operation producing a single index-typed value
+    when it is computed by payload IR. In the latter case, the chunk size
     point must be set to `ShapedType::kDynamic` and the dynamic size handle
     must point to as many value-producing operations as there are structured
     operations pointed to by the target handle.
 
-    The operation consumes the target handle, but preserves the split point
+    The operation consumes the target handle, but preserves the chunk size
     handle if provided. Without the `multiway` attribute, it produces two
     new handles pointing to the two parts of the structured op after splitting,
     in the same order as the target operand, with the first handle
@@ -1416,8 +1418,8 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
 
     Multiway split mode is enabled by specifying the `multiway` attribute.
     In this mode a single `target` op is split into multiple parts covering
-    the iteration space of the specified dimension. `static_split_point` and
-    `dynamic_split_point` in this case is a list of chunk sizes that the given
+    the iteration space of the specified dimension. `static_chunk_sizes` and
+    `dynamic_chunk_sizes` in this case give the list of chunk sizes the given
     dimension should be split into. With `multiway` it produces two handles;
     the first handle is a list of the multiple parts of the structured op
     after splitting, where the target dimensions for each linalg op in the
@@ -1427,8 +1429,8 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        I64Attr:$dimension,
-                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
-                       I64Attr:$static_split_point,
+                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
+                       I64Attr:$static_chunk_sizes,
                        UnitAttr:$multiway);
   let results = (outs TransformHandleTypeInterface:$first,
                       TransformHandleTypeInterface:$second);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 907a7bfb72a59..98f6bde8275f4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2270,24 +2270,25 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
   SmallVector<Operation *> payload =
       llvm::to_vector(state.getPayloadOps(getTarget()));
 
-  bool isMultiwaySplit = getMultiway() ? true : false;
+  bool isMultiwaySplit = getMultiway();
 
   if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
-    return emitDefiniteFailure() << "requires exactly one target when "
-                                    "multiway split is enabled (got "
-                                 << llvm::range_size(payload) << ")";
+    return mlir::emitSilenceableFailure(getLoc())
+           << "requires exactly one target when "
+              "multiway split is enabled (got "
+           << llvm::range_size(payload) << ")";
   }
 
-  SmallVector<OpFoldResult> splitPoints;
+  SmallVector<OpFoldResult> chunkSizes;
 
   if (!isMultiwaySplit)
-    splitPoints.reserve(payload.size());
+    chunkSizes.reserve(payload.size());
 
-  if (getDynamicSplitPoint()) {
+  if (getDynamicChunkSizes()) {
     auto diag = DiagnosedSilenceableFailure::success();
-    if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
-      splitPoints = llvm::to_vector(llvm::map_range(
-          state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+    if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
+      chunkSizes = llvm::to_vector(llvm::map_range(
+          state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
             if (op->getNumResults() != 1 ||
                 !op->getResult(0).getType().isIndex()) {
               diag = emitSilenceableError()
@@ -2298,8 +2299,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
             return OpFoldResult(op->getResult(0));
           }));
     } else {
-      splitPoints = llvm::to_vector(
-          llvm::map_range(state.getParams(getDynamicSplitPoint()),
+      chunkSizes = llvm::to_vector(
+          llvm::map_range(state.getParams(getDynamicChunkSizes()),
                           [](Attribute attr) { return OpFoldResult(attr); }));
     }
     if (diag.isSilenceableFailure())
@@ -2307,53 +2308,75 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 
     // For multiway split, a single payload is expected to have multiple
     // split points.
-    if (!isMultiwaySplit && splitPoints.size() != payload.size()) {
+    if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
       return emitDefiniteFailure()
              << "expected the dynamic split point handle to point to as "
                 "many operations ("
-             << splitPoints.size() << ") as the target handle ("
+             << chunkSizes.size() << ") as the target handle ("
              << payload.size() << ")";
     }
   } else {
-    splitPoints.resize(payload.size(),
-                       rewriter.getIndexAttr(getStaticSplitPoint()));
+    chunkSizes.resize(payload.size(),
+                       rewriter.getIndexAttr(getStaticChunkSizes()));
   }
 
+  auto checkStructuredOpAndDimensions = [&](LinalgOp linalgOp, Location loc) {
+    if (!linalgOp) {
+      auto diag = emitSilenceableError() << "only applies to structured ops";
+      diag.attachNote(loc) << "target op";
+      return diag;
+    }
+
+    if (getDimension() >= linalgOp.getNumLoops()) {
+      auto diag = emitSilenceableError() << "dimension " << getDimension()
+                                          << " does not exist in target op";
+      diag.attachNote(loc) << "target op";
+      return diag;
+    }
+    return DiagnosedSilenceableFailure::success();
+  };
+
+  auto checkFailureInSplitting = [&](bool hasFailed, Location loc) {
+    if (hasFailed) {
+      auto diag = emitDefiniteFailure() << "internal failure in splitting";
+      diag.attachNote(loc) << "target op";
+      return DiagnosedSilenceableFailure(diag);
+    }
+    return DiagnosedSilenceableFailure::success();
+  };
+
   if (isMultiwaySplit) {
 
     // Split a single target operation at multiple points.
     SmallVector<Operation *> opList;
     Operation *head, *tail;
-    for (const auto [idx, splitPoint] : llvm::enumerate(splitPoints)) {
+    Operation *target = payload.front();
+
+    auto linalgOp = dyn_cast<LinalgOp>(target);
+    auto diag = checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+
+    if (diag.isSilenceableFailure())
+      return diag;
 
-      Operation *target;
-      if (idx == 0)
-        target = payload.front();
-      else
+    for (const auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
+
+      if (idx > 0)
         target = tail;
 
       if (!target)
         break;
 
-      auto linalgOp = dyn_cast<LinalgOp>(target);
-
-      if (!linalgOp) {
-        auto diag = emitSilenceableError() << "only applies to structured ops";
-        diag.attachNote(target->getLoc()) << "target op";
-        return diag;
-      }
-
-      if (getDimension() >= linalgOp.getNumLoops()) {
-        auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                           << " does not exist in target op";
-        diag.attachNote(target->getLoc()) << "target op";
-        return diag;
-      }
+      linalgOp = dyn_cast<LinalgOp>(target);
 
       rewriter.setInsertionPoint(linalgOp);
       std::tie(head, tail) = linalg::splitOp(
           rewriter, cast<TilingInterface>(linalgOp.getOperation()),
-          getDimension(), splitPoint);
+          getDimension(), chunkSize);
+
+      // Propagate errors.
+      auto diag = checkFailureInSplitting(!head && !tail, target->getLoc());
+      if (diag.isDefiniteFailure())
+        return diag;
 
       opList.push_back(head);
     }
@@ -2368,21 +2391,13 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     // Split each target operation.
     SmallVector<Operation *> first, second;
     Operation *noSecondPart = nullptr;
-    for (const auto &pair : llvm::zip(payload, splitPoints)) {
+    for (const auto &pair : llvm::zip(payload, chunkSizes)) {
       Operation *target = std::get<0>(pair);
       auto linalgOp = dyn_cast<LinalgOp>(target);
-      if (!linalgOp) {
-        auto diag = emitSilenceableError() << "only applies to structured ops";
-        diag.attachNote(target->getLoc()) << "target op";
-        return diag;
-      }
+      auto diag = checkStructuredOpAndDimensions(linalgOp, target->getLoc());
 
-      if (getDimension() >= linalgOp.getNumLoops()) {
-        auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                           << " does not exist in target op";
-        diag.attachNote(target->getLoc()) << "target op";
+      if (diag.isSilenceableFailure())
         return diag;
-      }
 
       rewriter.setInsertionPoint(linalgOp);
       std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
@@ -2390,11 +2405,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
           getDimension(), std::get<1>(pair));
 
       // Propagate errors.
-      if (!first.back() && !second.back()) {
-        auto diag = emitDefiniteFailure() << "internal failure in splitting";
-        diag.attachNote(target->getLoc()) << "target op";
+      auto diagSplit = checkFailureInSplitting(!first.back() && !second.back(),
+                                     target->getLoc());
+      if (diagSplit.isDefiniteFailure())
         return diag;
-      }
 
       // Do not add null second parts.
       if (!second.back()) {
@@ -2424,27 +2438,27 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 void SplitOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTarget(), effects);
-  if (getDynamicSplitPoint())
-    onlyReadsHandle(getDynamicSplitPoint(), effects);
+  if (getDynamicChunkSizes())
+    onlyReadsHandle(getDynamicChunkSizes(), effects);
   producesHandle(getResults(), effects);
   modifiesPayload(effects);
 }
 
 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
-  IntegerAttr staticSplitPoint;
+  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
+  IntegerAttr staticChunkSizes;
   if (parser.parseOperand(target) || parser.parseKeyword("after"))
     return failure();
 
   OptionalParseResult dynamicPointParseResult =
-      parser.parseOptionalOperand(dynamicSplitPoint);
+      parser.parseOptionalOperand(dynamicChunkSizes);
   if (!dynamicPointParseResult.has_value()) {
-    int64_t staticSplitPointValue;
-    if (failed(parser.parseInteger(staticSplitPointValue)))
+    int64_t staticChunkSizesValue;
+    if (failed(parser.parseInteger(staticChunkSizesValue)))
       return failure();
 
-    staticSplitPoint =
-        parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+    staticChunkSizes =
+        parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
   }
 
   Type targetType;
@@ -2454,43 +2468,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
   }
   if (dynamicPointParseResult.has_value()) {
-    Type splitPointType;
+    Type ChunkSizesType;
     if (failed(*dynamicPointParseResult) || parser.parseComma() ||
-        parser.parseType(splitPointType) ||
-        parser.resolveOperand(dynamicSplitPoint, splitPointType,
+        parser.parseType(ChunkSizesType) ||
+        parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
                               result.operands)) {
       return failure();
     }
 
-    staticSplitPoint =
+    staticChunkSizes =
         parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
   }
 
   result.addAttribute(
-      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
-      staticSplitPoint);
+      SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
+      staticChunkSizes);
   result.addTypes({targetType, targetType});
   return success();
 }
 
 void SplitOp::print(OpAsmPrinter &printer) {
   printer << " " << getTarget() << " after ";
-  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
-  if (staticSplitSize != ShapedType::kDynamic)
-    printer << staticSplitSize;
+  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
+  if (staticChunkSize != ShapedType::kDynamic)
+    printer << staticChunkSize;
   else
-    printer << getDynamicSplitPoint();
+    printer << getDynamicChunkSizes();
   printer << " ";
   printer.printOptionalAttrDict(getOperation()->getAttrs(),
-                                {getStaticSplitPointAttrName()});
+                                {getStaticChunkSizesAttrName()});
   printer << " : " << getTarget().getType();
-  if (staticSplitSize == ShapedType::kDynamic)
-    printer << ", " << getDynamicSplitPoint().getType();
+  if (staticChunkSize == ShapedType::kDynamic)
+    printer << ", " << getDynamicChunkSizes().getType();
 }
 
 LogicalResult SplitOp::verify() {
-  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
-      (getDynamicSplitPoint() == nullptr)) {
+  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
+      (getDynamicChunkSizes() == nullptr)) {
     return emitOpError() << "expects either a dynamic or a static split "
                             "point to be provided";
   }
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index 566e517d69789..e072fff4c5d77 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -197,7 +197,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
     // expected-error @below {{expects either a dynamic or a static split point to be provided}}
-    %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }

>From 5de838cd25a65f33113755d405724322da37b3e9 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 21 May 2024 18:23:50 +0800
Subject: [PATCH 5/6] fix to continuous tile sizes; switch from split-points to
 chunk-sizes. fix for loops. rename lambda functions, use map_to_vector.

---
 .../Linalg/TransformOps/LinalgTransformOps.td | 19 +++---
 .../TransformOps/LinalgTransformOps.cpp       | 58 ++++++++-----------
 2 files changed, 34 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f20deff0e1389..66d2cc6f97f58 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1841,10 +1841,9 @@ def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_si
         DeclareOpInterfaceMethods<TransformOpInterface>,
         ReportTrackingListenerFailuresOpTrait]> {
   let description = [{
-    This transform takes a linalg as target and a dimension and target size
-    as attributes to generate a list of (1) exponentially diminishing
-    tile sizes that are powers of 2; and (2) the corresponding chunk-sizes
-    the linalg op should be split into along the given dimension.
+    This transform emits the IR computing the list of (1) exponentially
+    diminishing tile sizes that are powers of 2; and (2) the corresponding
+    chunk-sizes the target op should be split into along the given dimension.
 
     For example, for `target_size` 9, and `dimension` 0 for the following
     linalg op as target
@@ -1858,23 +1857,23 @@ def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_si
     9, 4, 2, 1; and the second result will be a list of chunk sizes
     18, 4, 2, 1 that the corresponding dimension should be split into.
 
-    After the linalg has been split along the given dimension (for example using
-    multiway split), each chunk can be tiled with the corresponding tile size in
-    the `tile_sizes` list generated as a result of this op.
+    After the target op has been split along the given dimension (for example
+    using multiway split), each chunk can be tiled with the corresponding tile
+    size in the `tile_sizes` list generated as a result of this op.
 
     Specifying the output type as !transform.param<i64> will cause `tile_sizes`
-    and `split_points` to be computed statically and not dynamically.
+    and `chunk_sizes` to be computed statically and not dynamically.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
                        ConfinedAttr<I64Attr, [IntNonNegative]>:$target_size);
   let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
-                      TransformAnyParamTypeOrAnyHandle:$split_points);
+                      TransformAnyParamTypeOrAnyHandle:$chunk_sizes);
   let hasVerifier = 1;
   let assemblyFormat =
     "$target attr-dict `:` custom<ContinuousTileSizeTypes>("
-    "type($target), type($tile_sizes), type($split_points))";
+    "type($target), type($tile_sizes), type($chunk_sizes))";
 
 }
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 98f6bde8275f4..e469b7a6d4c4c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2670,7 +2670,7 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
       llvm::to_vector(state.getPayloadOps(getTarget()));
 
   if (!llvm::hasSingleElement(targetOps)) {
-    return emitDefiniteFailure() << "requires exactly one target (got "
+    return mlir::emitSilenceableFailure(getLoc()) << "requires exactly one target (got "
                                  << llvm::range_size(targetOps) << ")";
   }
 
@@ -2681,7 +2681,7 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
   if (!target)
     return emitDefiniteFailure() << "expected Linalg Op";
 
-  if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
     if (target.hasDynamicShape()) {
       auto diag = emitSilenceableError()
                   << "cannot compute parametric tile sizes for dynamically "
@@ -2698,24 +2698,20 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
              << "failed to compute multi-size tiling sizes";
     }
 
-    SmallVector<int64_t> splitPoints;
+    SmallVector<int64_t> chunkSizes;
 
-    auto tileSizeTripCountPairs =
-        llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+    for (auto &&[tileSize, tripCount] : llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+      chunkSizes.push_back(tileSize * tripCount);
 
-    for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs))
-      splitPoints.push_back(std::get<0>(pair) * std::get<1>(pair));
-
-    auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
-      return llvm::to_vector(
-          llvm::map_range(values, [&](int64_t value) -> Attribute {
+    auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+      return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
             return builder.getI64IntegerAttr(value);
-          }));
+          });
     };
     transformResults.setParams(cast<OpResult>(getTileSizes()),
-                               makeI64AttrsFromI64(spec->tileSizes));
-    transformResults.setParams(cast<OpResult>(getSplitPoints()),
-                               makeI64AttrsFromI64(splitPoints));
+                               getI64AttrsFromI64(spec->tileSizes));
+    transformResults.setParams(cast<OpResult>(getChunkSizes()),
+                               getI64AttrsFromI64(chunkSizes));
 
     return DiagnosedSilenceableFailure::success();
   }
@@ -2731,9 +2727,6 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
     return emitSilenceableError() << "could not generate tile size computation";
   }
 
-  auto tileSizeTripCountPairs =
-      llvm::zip_equal(spec->tileSizes, spec->tripCounts);
-
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
   AffineExpr s1 = builder.getAffineSymbolExpr(1);
   auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
@@ -2741,31 +2734,30 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
                                            ofrs);
   };
 
-  SmallVector<Value> splitPoints;
+  SmallVector<Value> chunkSizes;
   Value splitPoint;
-  for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs)) {
-    splitPoint = apply(s0 * s1, {std::get<0>(pair), std::get<1>(pair)});
-    splitPoints.push_back(splitPoint);
+  for (auto &&[tileSize, tripCount] : llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
+    splitPoint = apply(s0 * s1, {tileSize, tripCount});
+    chunkSizes.push_back(splitPoint);
   }
 
-  auto makeOpFromValue = [&](ArrayRef<Value> values) {
-    return llvm::to_vector(
-        llvm::map_range(values, [&](Value value) -> Operation * {
+  auto getDefiningOps = [&](ArrayRef<Value> values) {
+        return llvm::map_to_vector(values, [&](Value value) -> Operation * {
           return value.getDefiningOp();
-        }));
+        });
   };
 
   transformResults.set(cast<OpResult>(getTileSizes()),
-                       makeOpFromValue(spec->tileSizes));
-  transformResults.set(cast<OpResult>(getSplitPoints()),
-                       makeOpFromValue(splitPoints));
+                       getDefiningOps(spec->tileSizes));
+  transformResults.set(cast<OpResult>(getChunkSizes()),
+                       getDefiningOps(chunkSizes));
 
   return DiagnosedSilenceableFailure::success();
 }
 
 LogicalResult transform::ContinuousTileSizesOp::verify() {
 
-  if (getTileSizes().getType() != getSplitPoints().getType()) {
+  if (getTileSizes().getType() != getChunkSizes().getType()) {
     return emitOpError() << "expects all results type to be the same";
   }
 
@@ -2780,7 +2772,7 @@ void transform::ContinuousTileSizesOp::getEffects(
     modifiesPayload(effects);
   onlyReadsHandle(getTarget(), effects);
   producesHandle(getTileSizes(), effects);
-  producesHandle(getSplitPoints(), effects);
+  producesHandle(getChunkSizes(), effects);
 }
 
 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
@@ -2792,7 +2784,7 @@ static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
                                                 Type &targetType,
                                                 Type &tileSizesType,
-                                                Type &splitPointsType) {
+                                                Type &chunkSizesType) {
   FunctionType funcType;
   llvm::SMLoc typeLoc = parser.getCurrentLocation();
   if (failed(parser.parseType<FunctionType>(funcType)))
@@ -2803,7 +2795,7 @@ static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
                                  "argument and one result";
   }
   targetType = funcType.getInput(0);
-  tileSizesType = splitPointsType = funcType.getResult(0);
+  tileSizesType = chunkSizesType = funcType.getResult(0);
 
   return success();
 }

>From 46f7c53431562a2b06a270223ac6a84c9abdb3e0 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 22 May 2024 22:11:16 +0800
Subject: [PATCH 6/6] Fix continuous_tile_sizes; adapt it to use TilingInterface.

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 +++--
 .../TransformOps/LinalgTransformOps.cpp       | 24 +++++++-------
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 31 +++++++++----------
 3 files changed, 32 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d902df458c2b7..415e3508eb45c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -861,11 +861,12 @@ computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
 FailureOr<StaticContinuousTileSizeSpecification>
-computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+computeStaticContinuousTileSizes(TilingInterface op, unsigned dimension,
                                  unsigned targetSize);
 FailureOr<ContinuousTileSizeSpecification>
-computeContinuousTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
-                           OpFoldResult targetSize, bool emitAssertions);
+computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+                           unsigned dimension, OpFoldResult targetSize,
+                           bool emitAssertions);
 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
 /// tiling by `numThreads`.
 /// If non-empty, the `mapping` is added as an attribute to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e469b7a6d4c4c..ce1ced63e9039 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2674,24 +2674,26 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
                                  << llvm::range_size(targetOps) << ")";
   }
 
-  auto target = dyn_cast<LinalgOp>(*targetOps.begin());
-
-  OpBuilder builder(target.getContext());
+  Operation *target = *targetOps.begin();
+  auto linalgOp = dyn_cast<LinalgOp>(target);
+  auto tileableOp = dyn_cast<TilingInterface>(target);
 
-  if (!target)
+  if (!linalgOp)
     return emitDefiniteFailure() << "expected Linalg Op";
 
+  OpBuilder builder(linalgOp.getContext());
+
   if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
-    if (target.hasDynamicShape()) {
+    if (linalgOp.hasDynamicShape()) {
       auto diag = emitSilenceableError()
                   << "cannot compute parametric tile sizes for dynamically "
                      "shaped payload op";
-      diag.attachNote(target->getLoc()) << "payload op";
+      diag.attachNote(linalgOp->getLoc()) << "payload op";
       return diag;
     }
 
     FailureOr<StaticContinuousTileSizeSpecification> spec =
-        computeStaticContinuousTileSizes(target, getDimension(),
+        computeStaticContinuousTileSizes(tileableOp, getDimension(),
                                          getTargetSize());
     if (failed(spec)) {
       return emitSilenceableError()
@@ -2716,13 +2718,13 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
     return DiagnosedSilenceableFailure::success();
   }
 
-  builder.setInsertionPoint(target);
+  builder.setInsertionPoint(linalgOp);
 
   OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
   unsigned dimension = getDimension();
 
-  FailureOr<ContinuousTileSizeSpecification> spec =
-      computeContinuousTileSizes(builder, target, dimension, targetSize, true);
+  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
+      builder, tileableOp, dimension, targetSize, true);
   if (failed(spec)) {
     return emitSilenceableError() << "could not generate tile size computation";
   }
@@ -2730,7 +2732,7 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
   AffineExpr s1 = builder.getAffineSymbolExpr(1);
   auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
-    return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
+    return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
                                            ofrs);
   };
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 8049ade591a70..8a06296626394 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -108,16 +108,19 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
 }
 
 FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(TilingInterface op,
+                                               unsigned dimension,
                                                unsigned targetSize) {
 
-  assert(!op.hasDynamicShape() &&
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
+
+  assert(!linalgOp.hasDynamicShape() &&
          "cannot compute static multi-tile sizes for an op with dynamic shape");
   assert(targetSize > 0 && "target size must be non-negative");
-  assert(dimension < op.getNumLoops() && "dimension overflow");
+  assert(dimension < linalgOp.getNumLoops() && "dimension overflow");
 
   StaticContinuousTileSizeSpecification spec;
-  int64_t loopRange = op.getStaticLoopRanges()[dimension];
+  int64_t loopRange = linalgOp.getStaticLoopRanges()[dimension];
   int64_t tripCount = loopRange / targetSize;
 
   unsigned tileSize = targetSize;
@@ -158,17 +161,19 @@ mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
 }
 
 FailureOr<ContinuousTileSizeSpecification>
-mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
                                          unsigned dimension,
                                          OpFoldResult targetSize,
                                          bool emitAssertions) {
 
+  unsigned numLoops = op.getIterationDomain(builder).size();
+
   // Bail out on dimension overflow.
-  if (dimension >= op.getNumLoops())
+  if (dimension >= numLoops)
     return failure();
 
   // The code below works only on values.
-  Location loc = op.getLoc();
+  Location loc = op->getLoc();
   ImplicitLocOpBuilder b(loc, builder);
   if (emitAssertions) {
     emitIsPositiveIndexAssertion(b, targetSize);
@@ -178,15 +183,9 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
 
   // Find the trip count of the iteration space dimension for which the tile
   // sizes are computed.
-  SmallVector<OpFoldResult> allShapes =
-      op.createFlatListOfOperandDims(b, b.getLoc());
-  AffineMap shapesToLoops = op.getShapesToLoopsMap();
-  SmallVector<OpFoldResult> loopRanges =
-      makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
-                                               allShapes);
-
-  Value loopRange =
-      getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
+  SmallVector<Range> loopRanges = op.getIterationDomain(builder);
+  Value loopRange = getValueOrCreateConstantIndexOp(b, op->getLoc(),
+                                                    loopRanges[dimension].size);
 
   ContinuousTileSizeSpecification spec;
 



More information about the Mlir-commits mailing list