[Mlir-commits] [mlir] [MLIR] Add continuous tiling to transform dialect (PR #82792)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 20 04:35:36 PDT 2024
https://github.com/muneebkhan85 updated https://github.com/llvm/llvm-project/pull/82792
>From 77a93d682b08c465e337f18ac1181e818eed88c6 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:15:18 +0800
Subject: [PATCH 1/5] [mlir][Transform] Add continuous tiling to Transform
dialect
Add continuous tiling op `structured.continuous_tile_sizes`
to the transform dialect that returns as result (1) a list of
exponentially diminishing tile sizes, and (2) a list of chunk
sizes -- along the specified dimension of the target --
where the corresponding tile sizes from (1) can be applied.
The list of chunk sizes from (2) cover the entire iteration
space along the given dimension of the target.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 45 ++++++
.../Dialect/Linalg/Transforms/Transforms.h | 21 +++
.../TransformOps/LinalgTransformOps.cpp | 147 ++++++++++++++++++
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 131 ++++++++++++++++
4 files changed, 344 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..e46dc1565964a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1819,6 +1819,51 @@ def TileReductionUsingForallOp :
}
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ This transform emits the IR computing the list of (1) exponentially
+ diminishing tile sizes that are powers of 2; and (2) the corresponding
+ chunk-sizes the target op should be split into along the given dimension.
+
+ For example, for `target_size` 9, and `dimension` 0 for the following
+ linalg op as target
+
+ ```
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+ outs(%arg2: tensor<25x25xf32>)
+ ```
+
+ the first result `tile_sizes` will be a list of diminishing tile sizes
+ 9, 4, 2, 1; and the second result will be a list of chunk sizes
+ 18, 4, 2, 1 that the corresponding dimension should be split into.
+
+ After the target op has been split along the given dimension (for example
+ using multiway split), each chunk can be tiled with the corresponding tile
+ size in the `tile_sizes` list generated as a result of this op.
+
+ Specifying the output type as !transform.param<i64> will cause `tile_sizes`
+ and `chunk_sizes` to be computed statically and not dynamically.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
+ ConfinedAttr<I64Attr, [IntNonNegative]>:$target_size);
+ let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
+ TransformAnyParamTypeOrAnyHandle:$chunk_sizes);
+ let hasVerifier = 1;
+ let assemblyFormat =
+ "$target attr-dict `:` custom<ContinuousTileSizeTypes>("
+ "type($target), type($tile_sizes), type($chunk_sizes))";
+
+}
+
//===----------------------------------------------------------------------===//
// TileUsingForOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..8424207ea47e5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
/// Number of tiles associated with each size.
T lowTripCount, highTripCount;
};
+
+template <typename T>
+struct ContinuousTileSizeSpecificationBase {
+ /// Tile sizes.
+ SmallVector<T> tileSizes;
+ /// Number of tiles associated with each size.
+ SmallVector<T> tripCounts;
+};
+
} // namespace detail
/// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
struct StaticMultiSizeSpecification
: public detail::MultiSizeSpecificationBase<int64_t> {};
+struct ContinuousTileSizeSpecification
+ : public detail::ContinuousTileSizeSpecificationBase<Value> {};
+struct StaticContinuousTileSizeSpecification
+ : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
+
/// Emits the IR computing the multi-sized tiling specification with two tile
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
/// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,13 @@ FailureOr<StaticMultiSizeSpecification>
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
int64_t divisor);
+FailureOr<StaticContinuousTileSizeSpecification>
+computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+ unsigned targetSize);
+FailureOr<ContinuousTileSizeSpecification>
+computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+ unsigned dimension, OpFoldResult targetSize,
+ bool emitAssertions);
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
/// tiling by `numThreads`.
/// If non-empty, the `mapping` is added as an attribute to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..114e71035ce44 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2584,6 +2584,153 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+
+ SmallVector<Operation *> targetOps =
+ llvm::to_vector(state.getPayloadOps(getTarget()));
+
+ if (!llvm::hasSingleElement(targetOps)) {
+ return mlir::emitSilenceableFailure(getLoc())
+ << "requires exactly one target (got " << llvm::range_size(targetOps)
+ << ")";
+ }
+
+ Operation *target = *targetOps.begin();
+ auto linalgOp = dyn_cast<LinalgOp>(target);
+ auto tileableOp = dyn_cast<TilingInterface>(target);
+
+ if (!linalgOp)
+ return emitDefiniteFailure() << "expected Linalg Op";
+
+ OpBuilder builder(linalgOp.getContext());
+
+ if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
+ if (linalgOp.hasDynamicShape()) {
+ auto diag = emitSilenceableError()
+ << "cannot compute parametric tile sizes for dynamically "
+ "shaped payload op";
+ diag.attachNote(linalgOp->getLoc()) << "payload op";
+ return diag;
+ }
+
+ FailureOr<StaticContinuousTileSizeSpecification> spec =
+ computeStaticContinuousTileSizes(linalgOp, getDimension(),
+ getTargetSize());
+ if (failed(spec)) {
+ return emitSilenceableError()
+ << "failed to compute multi-size tiling sizes";
+ }
+
+ SmallVector<int64_t> chunkSizes;
+
+ for (auto &&[tileSize, tripCount] :
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+ chunkSizes.push_back(tileSize * tripCount);
+
+ auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+ return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
+ return builder.getI64IntegerAttr(value);
+ });
+ };
+ transformResults.setParams(cast<OpResult>(getTileSizes()),
+ getI64AttrsFromI64(spec->tileSizes));
+ transformResults.setParams(cast<OpResult>(getChunkSizes()),
+ getI64AttrsFromI64(chunkSizes));
+
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ builder.setInsertionPoint(linalgOp);
+
+ OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+ unsigned dimension = getDimension();
+
+ FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
+ builder, tileableOp, dimension, targetSize, true);
+ if (failed(spec)) {
+ return emitSilenceableError() << "could not generate tile size computation";
+ }
+
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ AffineExpr s1 = builder.getAffineSymbolExpr(1);
+ auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+ return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
+ ofrs);
+ };
+
+ SmallVector<Value> chunkSizes;
+ Value splitPoint;
+ for (auto &&[tileSize, tripCount] :
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
+ splitPoint = apply(s0 * s1, {tileSize, tripCount});
+ chunkSizes.push_back(splitPoint);
+ }
+
+ auto getDefiningOps = [&](ArrayRef<Value> values) {
+ return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+ return value.getDefiningOp();
+ });
+ };
+
+ transformResults.set(cast<OpResult>(getTileSizes()),
+ getDefiningOps(spec->tileSizes));
+ transformResults.set(cast<OpResult>(getChunkSizes()),
+ getDefiningOps(chunkSizes));
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::ContinuousTileSizesOp::verify() {
+
+ if (getTileSizes().getType() != getChunkSizes().getType()) {
+ return emitOpError() << "expects all results type to be the same";
+ }
+
+ return success();
+}
+
+void transform::ContinuousTileSizesOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
+ onlyReadsPayload(effects);
+ else
+ modifiesPayload(effects);
+ onlyReadsHandle(getTargetMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+}
+
+static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
+ Type targetType, Type tile_sizes,
+ Type) {
+ printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
+}
+
+static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
+ Type &targetType,
+ Type &tileSizesType,
+ Type &chunkSizesType) {
+ FunctionType funcType;
+ llvm::SMLoc typeLoc = parser.getCurrentLocation();
+ if (failed(parser.parseType<FunctionType>(funcType)))
+ return failure();
+
+ if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
+ parser.emitError(typeLoc) << "expects a trailing functional type with one "
+ "argument and one result";
+ }
+ targetType = funcType.getInput(0);
+ tileSizesType = chunkSizesType = funcType.getResult(0);
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TileUsingForOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index d8dee82237156..8ef8651646829 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -107,6 +107,137 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
b.getStringAttr("expected strictly positive tile size and divisor"));
}
+FailureOr<StaticContinuousTileSizeSpecification>
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
+ unsigned dimension,
+ unsigned targetSize) {
+
+ assert(!op.hasDynamicShape() &&
+ "cannot compute static multi-tile sizes for an op with dynamic shape");
+ assert(targetSize > 0 && "target size must be non-negative");
+ assert(dimension < op.getNumLoops() && "dimension overflow");
+
+ StaticContinuousTileSizeSpecification spec;
+ int64_t loopRange = op.getStaticLoopRanges()[dimension];
+ int64_t tripCount = loopRange / targetSize;
+
+ unsigned tileSize = targetSize;
+
+ spec.tileSizes.push_back(tileSize);
+ spec.tripCounts.push_back(tripCount);
+
+ int64_t remainderChunk = loopRange % targetSize;
+
+ while (tileSize > 1 && remainderChunk != 0) {
+
+ uint64_t maxPower = llvm::bit_floor(tileSize);
+ tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
+
+ tripCount = remainderChunk / tileSize;
+
+ if (tripCount > 0) {
+ spec.tileSizes.push_back(tileSize);
+ spec.tripCounts.push_back(tripCount);
+ }
+
+ remainderChunk = remainderChunk % tileSize;
+ }
+
+ auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
+ SmallVector<int64_t> tripCounts,
+ int64_t range) -> bool {
+ int64_t computedRange = 0;
+ for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
+ computedRange += tileSize * tripCount;
+ return range == computedRange;
+ };
+
+ if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
+ return failure();
+
+ return spec;
+}
+
+FailureOr<ContinuousTileSizeSpecification>
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+ unsigned dimension,
+ OpFoldResult targetSize,
+ bool emitAssertions) {
+
+ SmallVector<Range> loopRanges = op.getIterationDomain(builder);
+ unsigned numLoops = loopRanges.size();
+
+ // Bail out on dimension overflow.
+ if (dimension >= numLoops)
+ return failure();
+
+ // The code below works only on values.
+ Location loc = op->getLoc();
+ ImplicitLocOpBuilder b(loc, builder);
+ if (emitAssertions) {
+ emitIsPositiveIndexAssertion(b, targetSize);
+ }
+ Value targetSizeValue =
+ getValueOrCreateConstantIndexOp(builder, loc, targetSize);
+
+ // Find the trip count of the iteration space dimension for which the tile
+ // sizes are computed.
+ Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
+ loopRanges[dimension].size);
+ ContinuousTileSizeSpecification spec;
+
+ // Compute the tile sizes and the respective numbers of tiles.
+ AffineExpr s0 = b.getAffineSymbolExpr(0);
+ AffineExpr s1 = b.getAffineSymbolExpr(1);
+ auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+ return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
+ };
+
+ Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
+ Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
+
+ OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
+ b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
+
+ // emitAssertions above already asserts that targetSize is
+ // a positive integer.
+ uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
+
+ assert(tileSizeInt > 0 && "target size must be non-negative");
+
+ spec.tileSizes.push_back(targetSizeValue);
+ spec.tripCounts.push_back(tripCountValue);
+
+ while (tileSizeInt > 1) {
+ uint64_t maxPower = llvm::bit_floor(tileSizeInt);
+ tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
+ auto constStepOp =
+ builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
+ tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+ tripCountSize = affine::makeComposedFoldedAffineApply(
+ b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+ // Optimization if tripCount can be determined to be zero.
+ if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
+ auto intAttr = cast<IntegerAttr>(attr);
+ bool isTripCountZero = intAttr.getValue().isZero();
+
+ if (!isTripCountZero) {
+ spec.tileSizes.push_back(constStepOp);
+ spec.tripCounts.push_back(tripCountValue);
+ }
+ } else {
+ spec.tileSizes.push_back(constStepOp);
+ spec.tripCounts.push_back(tripCountValue);
+ }
+
+ remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
+ }
+
+ return spec;
+}
+
FailureOr<StaticMultiSizeSpecification>
mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
int64_t targetSize, int64_t divisor) {
>From a1f1c72445e0581c6a0fca830d9c7a95e88e20bd Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:26:19 +0800
Subject: [PATCH 2/5] [mlir][Transform] Add support for multiway split in
SplitOp
Add functionality that enables SplitOp to do a multiway split of
a traget op along a given dimension. With multiway attribute,
SplitOp takes a list of chunk sizes and applies it to a single
target along the given dimension to generate multiple
structured ops extracted from the target.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 40 +++-
.../TransformOps/LinalgTransformOps.cpp | 222 ++++++++++++------
.../mlir/dialects/transform/structured.py | 16 +-
.../Dialect/Linalg/transform-op-split.mlir | 2 +-
.../dialects/transform_structured_ext.py | 4 +-
5 files changed, 190 insertions(+), 94 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e46dc1565964a..866275cedf68b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1396,29 +1396,43 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Indicates that the given `target` op should be split into two complementary
+ Splits the given `target` op into two or more complementary
parts, which combined cover the entire iteration domain of the original op.
The split is performed along the iteration space dimension provided as
- attribute. In case of dimension overflow, the transformation fails. The
- split is performed at the dimension iterator value specified as either the
- static split point attribute when it is known at transform IR construction
- time or as the handle to an operation producing a single index-typed value
- when it is computed by payload IR. In the latter case, the static split
+ chunk size attribute specifying the size of the lower part; the remaining
+ range in the iteration space is assigned as the upper part. In case of
+ dimension overflow, the transformation fails. The split is performed at the
+ dimension iterator value specified as either the static chunk size
+ attribute when it is known at transform IR construction time or
+ as the handle to an operation producing a single index-typed value
+ when it is computed by payload IR. In the latter case, the chunk size
point must be set to `ShapedType::kDynamic` and the dynamic size handle
must point to as many value-producing operations as there are structured
operations pointed to by the target handle.
- The operation consumes the target handle, but preserves the split point
- handle if provided. It produces two new handles pointing to the two parts
- of the structured op after splitting, in the same order as the target
- operand, with the first handle corresponding to the part with lower
- iteration space indices.
+ The operation consumes the target handle, but preserves the chunk size
+ handle if provided. Without the `multiway` attribute, it produces two
+ new handles pointing to the two parts of the structured op after splitting,
+ in the same order as the target operand, with the first handle
+ corresponding to the part with lower iteration space indices.
+
+ Multiway split mode is enabled by specifying the `multiway` attribute.
+ In this mode a single `target` op is split into multiple parts covering
+ the iteration space of the specified dimension. `static_chunk_sizes` and
+ `dynamic_chunk_sizes` in this case specify a list of chunk sizes that the given
+ dimension should be split into. With `multiway` it produces two handles;
+ the first handle is a list of the multiple parts of the structured op
+ after splitting, where the target dimensions for each linalg op in the
+ list corresponds to the chunk sizes specified in the input split list.
+ If the chunk sizes do not cover the entire iteration space, the leftover
+ chunk is the last payload in the first handle. The second handle is empty.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$dimension,
- Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
- I64Attr:$static_split_point);
+ Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
+ I64Attr:$static_chunk_sizes,
+ UnitAttr:$multiway);
let results = (outs TransformHandleTypeInterface:$first,
TransformHandleTypeInterface:$second);
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 114e71035ce44..37467db568c27 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2266,13 +2266,26 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Collect the dynamic split points if provided.
SmallVector<Operation *> payload =
llvm::to_vector(state.getPayloadOps(getTarget()));
- SmallVector<OpFoldResult> splitPoints;
- splitPoints.reserve(payload.size());
- if (getDynamicSplitPoint()) {
+
+ bool isMultiwaySplit = getMultiway();
+
+ if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
+ return mlir::emitSilenceableFailure(getLoc())
+ << "requires exactly one target when "
+ "multiway split is enabled (got "
+ << llvm::range_size(payload) << ")";
+ }
+
+ SmallVector<OpFoldResult> chunkSizes;
+
+ if (!isMultiwaySplit)
+ chunkSizes.reserve(payload.size());
+
+ if (getDynamicChunkSizes()) {
auto diag = DiagnosedSilenceableFailure::success();
- if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
- splitPoints = llvm::to_vector(llvm::map_range(
- state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+ if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
+ chunkSizes = llvm::to_vector(llvm::map_range(
+ state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
if (op->getNumResults() != 1 ||
!op->getResult(0).getType().isIndex()) {
diag = emitSilenceableError()
@@ -2283,103 +2296,172 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
return OpFoldResult(op->getResult(0));
}));
} else {
- splitPoints = llvm::to_vector(
- llvm::map_range(state.getParams(getDynamicSplitPoint()),
+ chunkSizes = llvm::to_vector(
+ llvm::map_range(state.getParams(getDynamicChunkSizes()),
[](Attribute attr) { return OpFoldResult(attr); }));
}
if (diag.isSilenceableFailure())
return diag;
- if (splitPoints.size() != payload.size()) {
+ // For multiway split, a single payload is expected to have multiple
+ // split points.
+ if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
return emitDefiniteFailure()
<< "expected the dynamic split point handle to point to as "
"many operations ("
- << splitPoints.size() << ") as the target handle ("
+ << chunkSizes.size() << ") as the target handle ("
<< payload.size() << ")";
}
} else {
- splitPoints.resize(payload.size(),
- rewriter.getIndexAttr(getStaticSplitPoint()));
+ chunkSizes.resize(payload.size(),
+ rewriter.getIndexAttr(getStaticChunkSizes()));
}
- // Split each target operation.
- SmallVector<Operation *> first, second;
- Operation *noSecondPart = nullptr;
- for (const auto &pair : llvm::zip(payload, splitPoints)) {
- Operation *target = std::get<0>(pair);
- auto linalgOp = dyn_cast<LinalgOp>(target);
+ auto checkStructuredOpAndDimensions =
+ [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
if (!linalgOp) {
auto diag = emitSilenceableError() << "only applies to structured ops";
- diag.attachNote(target->getLoc()) << "target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
if (getDimension() >= linalgOp.getNumLoops()) {
auto diag = emitSilenceableError() << "dimension " << getDimension()
- << " does not exist in target op";
- diag.attachNote(target->getLoc()) << "target op";
+ << " does not exist in target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
+ return DiagnosedSilenceableFailure::success();
+ };
- rewriter.setInsertionPoint(linalgOp);
- std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
- rewriter, cast<TilingInterface>(linalgOp.getOperation()),
- getDimension(), std::get<1>(pair));
-
- // Propagate errors.
- if (!first.back() && !second.back()) {
+ auto checkFailureInSplitting =
+ [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
+ if (hasFailed) {
auto diag = emitDefiniteFailure() << "internal failure in splitting";
- diag.attachNote(target->getLoc()) << "target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
+ return DiagnosedSilenceableFailure::success();
+ };
+
+ if (isMultiwaySplit) {
+
+ // Split a single target operation at multiple points.
+ SmallVector<Operation *> opList;
+ TilingInterface head, tail;
+ Operation *target = payload.front();
+
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+
+ // Check that the target is a valid LinalgOp with correct dimensions.
+ DiagnosedSilenceableFailure diag =
+ checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+ if (diag.isSilenceableFailure())
+ return diag;
+
+ for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
+
+ if (idx > 0)
+ target = tail.getOperation();
+
+ if (!target)
+ break;
- // Do not add null second parts.
- if (!second.back()) {
- noSecondPart = target;
- second.pop_back();
+ linalgOp = cast<LinalgOp>(target);
+
+ rewriter.setInsertionPoint(linalgOp);
+ std::tie(head, tail) = linalg::splitOp(
+ rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+ getDimension(), chunkSize);
+
+ // Propagate errors.
+ DiagnosedSilenceableFailure diag =
+ checkFailureInSplitting(!head && !tail, target->getLoc());
+ if (diag.isDefiniteFailure())
+ return diag;
+
+ opList.push_back(head.getOperation());
}
- }
- if (second.size() != first.size() && !second.empty()) {
- auto diag = emitSilenceableError()
- << "splitting does not produce the second part for a subset "
- "of targets";
- diag.attachNote() << "expected splitting to produce the second part of all "
- "or none of the targets";
- diag.attachNote(noSecondPart->getLoc())
- << "first target with no second part";
- return diag;
- }
+ // Append any leftover parts to the end of the result list.
+ if (tail)
+ opList.push_back(tail.getOperation());
+ results.set(cast<OpResult>(getFirst()), opList);
+ results.set(cast<OpResult>(getSecond()), {});
+
+ } else {
+ // Split each target operation.
+ SmallVector<Operation *> first, second;
+ Operation *noSecondPart = nullptr;
+ for (const auto &pair : llvm::zip(payload, chunkSizes)) {
+ Operation *target = std::get<0>(pair);
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+ DiagnosedSilenceableFailure diag =
+ checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+
+ if (diag.isSilenceableFailure())
+ return diag;
+
+ rewriter.setInsertionPoint(linalgOp);
+ std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
+ rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+ getDimension(), std::get<1>(pair));
+
+ // Propagate errors.
+ DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
+ !first.back() && !second.back(), target->getLoc());
+ if (diagSplit.isDefiniteFailure())
+ return diag;
+
+ // Do not add null second parts.
+ if (!second.back()) {
+ noSecondPart = target;
+ second.pop_back();
+ }
+ }
+
+ if (second.size() != first.size() && !second.empty()) {
+ auto diag = emitSilenceableError()
+ << "splitting does not produce the second part for a subset "
+ "of targets";
+ diag.attachNote()
+ << "expected splitting to produce the second part of all "
+ "or none of the targets";
+ diag.attachNote(noSecondPart->getLoc())
+ << "first target with no second part";
+ return diag;
+ }
- results.set(cast<OpResult>(getFirst()), first);
- results.set(cast<OpResult>(getSecond()), second);
+ results.set(cast<OpResult>(getFirst()), first);
+ results.set(cast<OpResult>(getSecond()), second);
+ }
return DiagnosedSilenceableFailure::success();
}
void SplitOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
- if (getDynamicSplitPoint())
- onlyReadsHandle(getDynamicSplitPointMutable(), effects);
+ if (getDynamicChunkSizes())
+ onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
- IntegerAttr staticSplitPoint;
+ OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
+ IntegerAttr staticChunkSizes;
if (parser.parseOperand(target) || parser.parseKeyword("after"))
return failure();
OptionalParseResult dynamicPointParseResult =
- parser.parseOptionalOperand(dynamicSplitPoint);
+ parser.parseOptionalOperand(dynamicChunkSizes);
if (!dynamicPointParseResult.has_value()) {
- int64_t staticSplitPointValue;
- if (failed(parser.parseInteger(staticSplitPointValue)))
+ int64_t staticChunkSizesValue;
+ if (failed(parser.parseInteger(staticChunkSizesValue)))
return failure();
- staticSplitPoint =
- parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+ staticChunkSizes =
+ parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
}
Type targetType;
@@ -2389,43 +2471,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
if (dynamicPointParseResult.has_value()) {
- Type splitPointType;
+ Type ChunkSizesType;
if (failed(*dynamicPointParseResult) || parser.parseComma() ||
- parser.parseType(splitPointType) ||
- parser.resolveOperand(dynamicSplitPoint, splitPointType,
+ parser.parseType(ChunkSizesType) ||
+ parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
result.operands)) {
return failure();
}
- staticSplitPoint =
+ staticChunkSizes =
parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
}
result.addAttribute(
- SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
- staticSplitPoint);
+ SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
+ staticChunkSizes);
result.addTypes({targetType, targetType});
return success();
}
void SplitOp::print(OpAsmPrinter &printer) {
printer << " " << getTarget() << " after ";
- int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
- if (staticSplitSize != ShapedType::kDynamic)
- printer << staticSplitSize;
+ int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
+ if (staticChunkSize != ShapedType::kDynamic)
+ printer << staticChunkSize;
else
- printer << getDynamicSplitPoint();
+ printer << getDynamicChunkSizes();
printer << " ";
printer.printOptionalAttrDict(getOperation()->getAttrs(),
- {getStaticSplitPointAttrName()});
+ {getStaticChunkSizesAttrName()});
printer << " : " << getTarget().getType();
- if (staticSplitSize == ShapedType::kDynamic)
- printer << ", " << getDynamicSplitPoint().getType();
+ if (staticChunkSize == ShapedType::kDynamic)
+ printer << ", " << getDynamicChunkSizes().getType();
}
LogicalResult SplitOp::verify() {
- if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
- (getDynamicSplitPoint() == nullptr)) {
+ if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
+ (getDynamicChunkSizes() == nullptr)) {
return emitOpError() << "expects either a dynamic or a static split "
"point to be provided";
}
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 2c49ef0960c75..41051c0d5b2ff 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -432,25 +432,25 @@ def __init__(
self,
target: Union[Operation, Value],
dimension: Union[int, Attribute],
- split_point: Union[int, Operation, Value, Attribute],
+ chunk_sizes: Union[int, Operation, Value, Attribute],
*,
loc=None,
ip=None,
):
- if isinstance(split_point, int):
- static_split_point = split_point
- dynamic_split_point = None
+ if isinstance(chunk_sizes, int):
+ static_chunk_sizes = chunk_sizes
+ dynamic_chunk_sizes = None
else:
- static_split_point = ShapedType.get_dynamic_size()
- dynamic_split_point = split_point
+ static_chunk_sizes = ShapedType.get_dynamic_size()
+ dynamic_chunk_sizes = chunk_sizes
super().__init__(
target.type,
target.type,
target,
dimension=dimension,
- static_split_point=static_split_point,
- dynamic_split_point=dynamic_split_point,
+ static_chunk_sizes=static_chunk_sizes,
+ dynamic_chunk_sizes=dynamic_chunk_sizes,
loc=loc,
ip=ip,
)
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index 566e517d69789..e072fff4c5d77 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -197,7 +197,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
// expected-error @below {{expects either a dynamic or a static split point to be provided}}
- %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index f97017b7a2c75..3ea73e8beea36 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -361,8 +361,8 @@ def testScalarize(target):
@run
@create_sequence
def testSplit(target):
- split = structured.SplitOp(target, dimension=1, split_point=42)
- structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
+ split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
+ structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
# CHECK-LABEL: TEST: testSplit
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
>From 5caad6d32145faea198e467178ad85a3efcc7b33 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 2 May 2024 19:10:49 +0800
Subject: [PATCH 3/5] [mlir][test] Test multiway SplitOp
Tests SplitOp for multiway splitting of a structured op using
the result of `continuous_tile_sizes` to specify the
chunk sizes that the target's specified dimension should be
split into.
---
.../continuous-tiling-multiway-split.mlir | 100 ++++++++++++++++++
1 file changed, 100 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
new file mode 100644
index 0000000000000..609766fbdc91f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s
+
+// This tests the results of continuous_tile_sizes on multiway splitOp.
+// continuous_tile_sizes returns a list of tile-sizes and a list of split points.
+// The list of split points is consumed by splitOp to split the linalg.matmul op
+// along dimension 1 to produce as many split-up linalg.matmul ops.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
+ %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
+ transform.yield
+ }
+}
+
+func.func @continuous_tile_linalg_matmul(
+ %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+ outs(%arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32>
+
+ return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK: %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK: %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK: %[[SLICE4:.+]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK: %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK: %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK: %[[SLICE5:.+]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK: %[[SLICE6:.+]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK: %[[SLICE7:.+]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK: %[[SLICE8:.+]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK: %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK: %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK: %[[SLICE9:.+]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK: %[[SLICE10:.+]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK: %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK: %[[INSLICE2:.+]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK: %[[INSLICE3:.+]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK: %[[INSLICE4:.+]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK: return %[[INSLICE4]] : tensor<25x25xf32>
+
+// -----
+
+// Tests the same as above except that the !transform.param<i64> output type in
+// continuous_tile_sizes op triggers tile sizes and split points to be computed
+// statically and not dynamically.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
+ %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
+
+func.func @continuous_tile_static_linalg_matmul(
+ %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+ outs(%arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32>
+
+ return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_static_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK: %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK: %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK: %[[SLICE4:.+]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK: %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK: %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK: %[[SLICE5:.+]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK: %[[SLICE6:.+]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK: %[[SLICE7:.+]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK: %[[SLICE8:.+]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK: %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK: %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK: %[[SLICE9:.+]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK: %[[SLICE10:.+]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK: %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK: %[[INSLICE2:.+]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK: %[[INSLICE3:.+]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK: %[[INSLICE4:.+]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK: return %[[INSLICE4]] : tensor<25x25xf32>
>From 3f6359ae7970700002a15ff410b68e7e5f548bba Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Mon, 17 Jun 2024 23:11:38 +0800
Subject: [PATCH 4/5] [mlir][Transform] Add `zip_shortest` to foreach
Adds `zip_shortest` functionality to `foreach` so that when it takes
multiple handles of varying lengths - instead of failing - it shrinks
the size of all payloads to that of the shortest payload.
---
.../mlir/Dialect/Transform/IR/TransformOps.td | 10 +++++++---
mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 13 +++++++++++++
2 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 3bb297cbf91d2..7a661c663e010 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -624,7 +624,10 @@ def ForeachOp : TransformDialectOp<"foreach",
Each iteration gets executed by co-indexing the payloads of the arguments
and mapping the body's arguments to these tuples, as though iterating over
the zipped together `targets`. As such, in each iteration, the size of the
- payload of each of the body's block arguments is exactly one.
+ payload of each of the body's block arguments is exactly one. The attribute
+ `zip_shortest` can be used if the targets vary in their number of payloads;
+ this will limit the iterations to only the number of payloads found in the
+ shortest target.
This op always reads the target handles. Furthermore, it consumes a handle
if there is a transform op in the body that consumes the corresponding
@@ -645,11 +648,12 @@ def ForeachOp : TransformDialectOp<"foreach",
rollback capabilities.
}];
- let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
+ let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
+ UnitAttr:$zip_shortest);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
- "$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
+ "$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
let hasVerifier = 1;
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 1efe708a5e349..f16f9fb3a0d99 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -1396,6 +1396,19 @@ transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
SmallVector<SmallVector<MappedValue>> payloads;
detail::prepareValueMappings(payloads, getTargets(), state);
size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
+ bool isZipShortest = getZipShortest();
+
+ if (isZipShortest) {
+ size_t smallestNumIterations =
+ llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
+ const SmallVector<MappedValue> &B) {
+ return A.size() < B.size();
+ })->size();
+
+ for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
+ payloads[argIdx].resize(smallestNumIterations);
+ numIterations = smallestNumIterations;
+ }
// As we will be "zipping" over them, check all payloads have the same size.
for (size_t argIdx = 1; argIdx < payloads.size(); argIdx++) {
>From 1a8d0b1c3c958fd9af9ee9fef02d5dc612638a67 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Fri, 14 Jun 2024 23:29:35 +0800
Subject: [PATCH 5/5] [mlir][test] Test full continuous tiling
Introduce a full test case demonstrating the continuous tiling
idea end-to-end.
---
.../Linalg/continuous-tiling-full.mlir | 180 ++++++++++++++++++
1 file changed, 180 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
new file mode 100644
index 0000000000000..b61c727f9cffc
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-full.mlir
@@ -0,0 +1,180 @@
+// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
+ %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
+ transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.any_op {
+ ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
+ %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+func.func @continuous_tile_linalg_matmul(
+ %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+ outs(%arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32>
+
+ return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_linalg_matmul
+// CHECK-SAME: (%[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>) -> tensor<25x25xf32> {
+// CHECK: %[[C18:.+]] = arith.constant 18 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C9:.+]] = arith.constant 9 : index
+// CHECK: %[[XSIN18:.+]] = tensor.extract_slice %[[IN1]][0, 0] [18, 34] [1, 1] : tensor<25x34xf32> to tensor<18x34xf32>
+// CHECK: %[[XSOUT18:.+]] = tensor.extract_slice %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<25x25xf32> to tensor<18x25xf32>
+// CHECK: %[[R0:.+]] = scf.for %[[IDX:.+]] = %[[C0]] to %[[C18]] step %[[C9]] iter_args(%[[XSOUT18ARG:.+]] = %[[XSOUT18]]) -> (tensor<18x25xf32>) {
+// CHECK: %[[XSIN19:.+]] = tensor.extract_slice %[[XSIN18]][%[[IDX]], 0] [9, 34] [1, 1] : tensor<18x34xf32> to tensor<9x34xf32>
+// CHECK: %[[XSOUT9:.+]] = tensor.extract_slice %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<18x25xf32> to tensor<9x25xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[XSIN19]], %[[IN2]] : tensor<9x34xf32>, tensor<34x25xf32>) outs(%[[XSOUT9]] : tensor<9x25xf32>) -> tensor<9x25xf32>
+// CHECK: %[[INS9:.+]] = tensor.insert_slice %[[MATMUL]] into %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<9x25xf32> into tensor<18x25xf32>
+// CHECK: scf.yield %[[INS9]] : tensor<18x25xf32>
+// CHECK: }
+// CHECK: %[[INS:.+]] = tensor.insert_slice %[[R0]] into %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<18x25xf32> into tensor<25x25xf32>
+// CHECK: %[[XS1:.+]] = tensor.extract_slice %[[IN1]][18, 0] [7, 34] [1, 1] : tensor<25x34xf32> to tensor<7x34xf32>
+// CHECK: %[[XS2:.+]] = tensor.extract_slice %[[INS]][18, 0] [7, 25] [1, 1] : tensor<25x25xf32> to tensor<7x25xf32>
+// CHECK: %[[XS3:.+]] = tensor.extract_slice %[[XS1]][0, 0] [4, 34] [1, 1] : tensor<7x34xf32> to tensor<4x34xf32>
+// CHECK: %[[XS4:.+]] = tensor.extract_slice %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<7x25xf32> to tensor<4x25xf32>
+// CHECK: %[[R1:.+]] = linalg.matmul ins(%[[XS3]], %[[IN2]] : tensor<4x34xf32>, tensor<34x25xf32>) outs(%[[XS4]] : tensor<4x25xf32>) -> tensor<4x25xf32>
+// CHECK: %[[INS5:.+]] = tensor.insert_slice %[[R1]] into %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<4x25xf32> into tensor<7x25xf32>
+// CHECK: %[[XS6:.+]] = tensor.extract_slice %[[XS1]][4, 0] [3, 34] [1, 1] : tensor<7x34xf32> to tensor<3x34xf32>
+// CHECK: %[[XS7:.+]] = tensor.extract_slice %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<7x25xf32> to tensor<3x25xf32>
+// CHECK: %[[XS8:.+]] = tensor.extract_slice %[[XS6]][0, 0] [2, 34] [1, 1] : tensor<3x34xf32> to tensor<2x34xf32>
+// CHECK: %[[XS9:.+]] = tensor.extract_slice %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<3x25xf32> to tensor<2x25xf32>
+// CHECK: %[[R2:.+]] = linalg.matmul ins(%[[XS8]], %[[IN2]] : tensor<2x34xf32>, tensor<34x25xf32>) outs(%[[XS9]] : tensor<2x25xf32>) -> tensor<2x25xf32>
+// CHECK: %[[INS10:.+]] = tensor.insert_slice %[[R2]] into %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<2x25xf32> into tensor<3x25xf32>
+// CHECK: %[[XS11:.+]] = tensor.extract_slice %[[XS6]][2, 0] [1, 34] [1, 1] : tensor<3x34xf32> to tensor<1x34xf32>
+// CHECK: %[[XS12:.+]] = tensor.extract_slice %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<3x25xf32> to tensor<1x25xf32>
+// CHECK: %[[R3:.+]] = linalg.matmul ins(%[[XS11]], %[[IN2]] : tensor<1x34xf32>, tensor<34x25xf32>) outs(%[[XS12]] : tensor<1x25xf32>) -> tensor<1x25xf32>
+// CHECK: %[[INS13:.+]] = tensor.insert_slice %[[R3]] into %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<1x25xf32> into tensor<3x25xf32>
+// CHECK: %[[INS14:.+]] = tensor.insert_slice %[[INS13]] into %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<3x25xf32> into tensor<7x25xf32>
+// CHECK: %[[INS15:.+]] = tensor.insert_slice %[[INS14]] into %[[INS]][18, 0] [7, 25] [1, 1] : tensor<7x25xf32> into tensor<25x25xf32>
+// CHECK: return %[[INS15]] : tensor<25x25xf32>
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.param<i64>
+ %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.param<i64>
+ transform.foreach %linalg_splits, %tile_sizes : !transform.any_op, !transform.param<i64> {
+ ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.param<i64>):
+ %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+func.func @continuous_tile_static_linalg_matmul(
+ %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+ outs(%arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32>
+
+ return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_static_linalg_matmul
+// CHECK-SAME: (%[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>) -> tensor<25x25xf32> {
+// CHECK: %[[C9:.+]] = arith.constant 9 : index
+// CHECK: %[[C18:.+]] = arith.constant 18 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[XSIN18:.+]] = tensor.extract_slice %[[IN1]][0, 0] [18, 34] [1, 1] : tensor<25x34xf32> to tensor<18x34xf32>
+// CHECK: %[[XSOUT18:.+]] = tensor.extract_slice %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<25x25xf32> to tensor<18x25xf32>
+// CHECK: %[[R0:.+]] = scf.for %[[IDX:.+]] = %[[C0]] to %[[C18]] step %[[C9]] iter_args(%[[XSOUT18ARG:.+]] = %[[XSOUT18]]) -> (tensor<18x25xf32>) {
+// CHECK: %[[XSIN19:.+]] = tensor.extract_slice %[[XSIN18]][%[[IDX]], 0] [9, 34] [1, 1] : tensor<18x34xf32> to tensor<9x34xf32>
+// CHECK: %[[XSOUT9:.+]] = tensor.extract_slice %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<18x25xf32> to tensor<9x25xf32>
+// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[XSIN19]], %[[IN2]] : tensor<9x34xf32>, tensor<34x25xf32>) outs(%[[XSOUT9]] : tensor<9x25xf32>) -> tensor<9x25xf32>
+// CHECK: %[[INS9:.+]] = tensor.insert_slice %[[MATMUL]] into %[[XSOUT18ARG]][%[[IDX]], 0] [9, 25] [1, 1] : tensor<9x25xf32> into tensor<18x25xf32>
+// CHECK: scf.yield %[[INS9]] : tensor<18x25xf32>
+// CHECK: }
+// CHECK: %[[INS:.+]] = tensor.insert_slice %[[R0]] into %[[OUT]][0, 0] [18, 25] [1, 1] : tensor<18x25xf32> into tensor<25x25xf32>
+// CHECK: %[[XS1:.+]] = tensor.extract_slice %[[IN1]][18, 0] [7, 34] [1, 1] : tensor<25x34xf32> to tensor<7x34xf32>
+// CHECK: %[[XS2:.+]] = tensor.extract_slice %[[INS]][18, 0] [7, 25] [1, 1] : tensor<25x25xf32> to tensor<7x25xf32>
+// CHECK: %[[XS3:.+]] = tensor.extract_slice %[[XS1]][0, 0] [4, 34] [1, 1] : tensor<7x34xf32> to tensor<4x34xf32>
+// CHECK: %[[XS4:.+]] = tensor.extract_slice %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<7x25xf32> to tensor<4x25xf32>
+// CHECK: %[[R1:.+]] = linalg.matmul ins(%[[XS3]], %[[IN2]] : tensor<4x34xf32>, tensor<34x25xf32>) outs(%[[XS4]] : tensor<4x25xf32>) -> tensor<4x25xf32>
+// CHECK: %[[INS5:.+]] = tensor.insert_slice %[[R1]] into %[[XS2]][0, 0] [4, 25] [1, 1] : tensor<4x25xf32> into tensor<7x25xf32>
+// CHECK: %[[XS6:.+]] = tensor.extract_slice %[[XS1]][4, 0] [3, 34] [1, 1] : tensor<7x34xf32> to tensor<3x34xf32>
+// CHECK: %[[XS7:.+]] = tensor.extract_slice %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<7x25xf32> to tensor<3x25xf32>
+// CHECK: %[[XS8:.+]] = tensor.extract_slice %[[XS6]][0, 0] [2, 34] [1, 1] : tensor<3x34xf32> to tensor<2x34xf32>
+// CHECK: %[[XS9:.+]] = tensor.extract_slice %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<3x25xf32> to tensor<2x25xf32>
+// CHECK: %[[R2:.+]] = linalg.matmul ins(%[[XS8]], %[[IN2]] : tensor<2x34xf32>, tensor<34x25xf32>) outs(%[[XS9]] : tensor<2x25xf32>) -> tensor<2x25xf32>
+// CHECK: %[[INS10:.+]] = tensor.insert_slice %[[R2]] into %[[XS7]][0, 0] [2, 25] [1, 1] : tensor<2x25xf32> into tensor<3x25xf32>
+// CHECK: %[[XS11:.+]] = tensor.extract_slice %[[XS6]][2, 0] [1, 34] [1, 1] : tensor<3x34xf32> to tensor<1x34xf32>
+// CHECK: %[[XS12:.+]] = tensor.extract_slice %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<3x25xf32> to tensor<1x25xf32>
+// CHECK: %[[R3:.+]] = linalg.matmul ins(%[[XS11]], %[[IN2]] : tensor<1x34xf32>, tensor<34x25xf32>) outs(%[[XS12]] : tensor<1x25xf32>) -> tensor<1x25xf32>
+// CHECK: %[[INS13:.+]] = tensor.insert_slice %[[R3]] into %[[INS10]][2, 0] [1, 25] [1, 1] : tensor<1x25xf32> into tensor<3x25xf32>
+// CHECK: %[[INS14:.+]] = tensor.insert_slice %[[INS13]] into %[[INS5]][4, 0] [3, 25] [1, 1] : tensor<3x25xf32> into tensor<7x25xf32>
+// CHECK: %[[INS15:.+]] = tensor.insert_slice %[[INS14]] into %[[INS]][18, 0] [7, 25] [1, 1] : tensor<7x25xf32> into tensor<25x25xf32>
+// CHECK: return %[[INS15]] : tensor<25x25xf32>
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %tile_sizes, %chunk_sizes = transform.structured.continuous_tile_sizes %0 { dimension = 0, target_size = 9 } : (!transform.any_op) -> !transform.any_op
+ %linalg_splits, %empty = transform.structured.split %0 after %chunk_sizes { dimension = 0, multiway } : !transform.any_op, !transform.any_op
+ transform.foreach %linalg_splits, %tile_sizes {zip_shortest} : !transform.any_op, !transform.any_op {
+ ^bb1(%linalg_split: !transform.any_op, %tile_size: !transform.any_op):
+ %tiled_linalg_split, %dim0_loop = transform.structured.tile_using_for %linalg_split tile_sizes [%tile_size] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ transform.yield
+ }
+}
+
+func.func @continuous_tile_dynamic_linalg_matmul(
+ %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%arg2: tensor<?x?xf32>)
+ -> tensor<?x?xf32>
+
+ return %0 : tensor<?x?xf32>
+}
+
+// CHECK: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 9) * 9, s1)>
+// CHECK: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2] -> (((s0 mod 9) floordiv 8) * 8, s1 - s2)>
+// CHECK: #[[$MAP6:.*]] = affine_map<()[s0, s1, s2, s3] -> ((((s0 mod 9) mod 8) floordiv 4) * 4, s1 - s2 - s3)>
+// CHECK: #[[$MAP9:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 mod 9) mod 4) floordiv 2) * 2, s1 - s2 - s3 - s4)>
+// CHECK: #[[$MAP12:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> ((s0 mod 9) mod 2, s1 - s2 - s3 - s4 - s5)>
+// CHECK-LABEL: @continuous_tile_dynamic_linalg_matmul
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[AM0:.*]] = affine.min #[[$MAP0]]()[%{{.*}}, %{{.*}}]
+// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM0]] step %[[C9]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[AM4:.*]] = affine.min #[[$MAP3]]()[%{{.*}}, %{{.*}}, %[[AM0]]]
+// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM4]] step %[[C8]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[AM8:.*]] = affine.min #[[$MAP6]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]]]
+// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM8]] step %[[C4]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[AM12:.*]] = affine.min #[[$MAP9]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]]]
+// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM12]] step %[[C2]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[AM16:.*]] = affine.min #[[$MAP12]]()[%{{.*}}, %{{.*}}, %[[AM0]], %[[AM4]], %[[AM8]], %[[AM12]]]
+// CHECK: %{{.*}} = scf.for %[[IDX:.+]] = %[[C0]] to %[[AM16]] step %[[C1]] iter_args(%[[OUT:.+]] = %{{.*}}) -> (tensor<?x?xf32>) {
+// CHECK: %[[MM:.+]] = linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<1x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<1x?xf32>) -> tensor<1x?xf32>
+// CHECK: %{{.*}} = tensor.insert_slice %[[MM]] into %[[OUT]][%[[IDX]], 0] [1, %{{.*}}] [1, 1] : tensor<1x?xf32> into tensor<?x?xf32>
\ No newline at end of file
More information about the Mlir-commits
mailing list