[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 30 06:24:52 PDT 2024
https://github.com/muneebkhan85 updated https://github.com/llvm/llvm-project/pull/82792
>From 5f00f20c882a269014087e4ec3bd410615af020f Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:15:18 +0800
Subject: [PATCH 1/8] [MLIR] Add continuous tiling to Transform dialect
Add continuous tiling op structured.continuous_tile
to the transform dialect that returns as result a list of
exponentially diminishing tile sizes and a list of split
points to do a multiway split of the target linalg op along
the specified dimension.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 46 ++++++
.../Dialect/Linalg/Transforms/Transforms.h | 20 +++
.../TransformOps/LinalgTransformOps.cpp | 151 ++++++++++++++++++
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 133 +++++++++++++++
4 files changed, 350 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..cce423b09617e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1819,6 +1819,52 @@ def TileReductionUsingForallOp :
}
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    This transform takes a linalg as target and a dimension and target size
+    as attributes to generate (1) a list of diminishing tile sizes, starting
+    with `target_size` and followed by decreasing powers of 2; and (2) the
+    corresponding chunk-sizes the linalg op should be split into along the
+    given dimension. `target_size` must be strictly positive.
+
+    For example, for `target_size` 9, and `dimension` 0 for the following
+    linalg op as target
+
+    ```
+      %0 = linalg.matmul  ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+                      outs(%arg2: tensor<25x25xf32>)
+    ```
+
+    the first result `tile_sizes` will be a list of diminishing tile sizes
+    9, 4, 2, 1; and the second result `split_points` will be a list of chunk
+    sizes 18, 4, 2, 1 that the corresponding dimension should be split into.
+
+    After the linalg has been split along the given dimension (for example using
+    multiway split), each chunk can be tiled with the corresponding tile size in
+    the `tile_sizes` list generated as a result of this op.
+
+    Specifying the output type as !transform.param<i64> will cause `tile_sizes`
+    and `split_points` to be computed statically and not dynamically.
+  }];
+
+  // `target_size` is used as a divisor when computing trip counts, so zero
+  // must be rejected by the verifier rather than reaching the implementation.
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$target_size);
+  let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
+                      TransformAnyParamTypeOrAnyHandle:$split_points);
+  let hasVerifier = 1;
+  let assemblyFormat =
+    "$target attr-dict `:` custom<ContinuousTileSizeTypes>("
+    "type($target), type($tile_sizes), type($split_points))";
+
+}
+
//===----------------------------------------------------------------------===//
// TileUsingForOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..ef3656c334ea6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
/// Number of tiles associated with each size.
T lowTripCount, highTripCount;
};
+
+/// Description of a continuous tiling as two lists: the tile sizes and the
+/// number of tiles for each size. `T` is the type used to represent each
+/// size and count.
+template <typename T>
+struct ContinuousTileSizeSpecificationBase {
+ /// Tile sizes.
+ SmallVector<T> tileSizes;
+ /// Number of tiles associated with each size.
+ SmallVector<T> tripCounts;
+};
+
} // namespace detail
/// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
struct StaticMultiSizeSpecification
: public detail::MultiSizeSpecificationBase<int64_t> {};
+/// Continuous tile size specification whose tile sizes and trip counts are
+/// represented as `Value`s.
+struct ContinuousTileSizeSpecification
+ : public detail::ContinuousTileSizeSpecificationBase<Value> {};
+/// Continuous tile size specification whose tile sizes and trip counts are
+/// represented as `int64_t` constants.
+struct StaticContinuousTileSizeSpecification
+ : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
+
/// Emits the IR computing the multi-sized tiling specification with two tile
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
/// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,12 @@ FailureOr<StaticMultiSizeSpecification>
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
int64_t divisor);
+/// Computes the continuous tile sizes and trip counts as integer constants for
+/// the loop `dimension` of a statically shaped `op`. The first tile size is
+/// `targetSize`; the following ones are decreasing powers of two applied to
+/// the remaining part of the loop range.
+FailureOr<StaticContinuousTileSizeSpecification>
+computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+ unsigned targetSize);
+/// Emits, using `builder`, the IR computing the continuous tile sizes and trip
+/// counts for the loop `dimension` of `op`. Returns failure if `dimension` is
+/// not a loop of `op`. If `emitAssertions` is set, also emits a check that
+/// `targetSize` is strictly positive.
+FailureOr<ContinuousTileSizeSpecification>
+computeContinuousTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
+ OpFoldResult targetSize, bool emitAssertions);
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
/// tiling by `numThreads`.
/// If non-empty, the `mapping` is added as an attribute to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..1e9e3163ef996 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2583,6 +2583,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+/// Computes the continuous tile sizes and the matching chunk sizes for the
+/// single payload op associated with the `target` handle.
+///
+/// When the results are parameters, the sizes are computed statically and
+/// returned as i64 attributes; this requires a statically shaped payload op.
+/// Otherwise, IR computing the sizes is inserted before the payload op and the
+/// results are handles to the ops defining those values.
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+                                        TransformResults &transformResults,
+                                        TransformState &state) {
+
+  SmallVector<Operation *> targetOps =
+      llvm::to_vector(state.getPayloadOps(getTarget()));
+
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one target (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+
+  auto target = dyn_cast<LinalgOp>(*targetOps.begin());
+
+  // Check the cast result before using `target`: the builder below is
+  // constructed from the payload op's context.
+  if (!target)
+    return emitDefiniteFailure() << "expected Linalg Op";
+
+  OpBuilder builder(target.getContext());
+
+  // Parameter results: compute everything statically.
+  if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+    if (target.hasDynamicShape()) {
+      auto diag = emitSilenceableError()
+                  << "cannot compute parametric tile sizes for dynamically "
+                     "shaped payload op";
+      diag.attachNote(target->getLoc()) << "payload op";
+      return diag;
+    }
+
+    FailureOr<StaticContinuousTileSizeSpecification> spec =
+        computeStaticContinuousTileSizes(target, getDimension(),
+                                         getTargetSize());
+    if (failed(spec)) {
+      return emitSilenceableError()
+             << "failed to compute multi-size tiling sizes";
+    }
+
+    // Each chunk covers `tileSize * tripCount` iterations.
+    SmallVector<int64_t> splitPoints;
+    for (auto [tileSize, tripCount] :
+         llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+      splitPoints.push_back(tileSize * tripCount);
+
+    auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+      return llvm::to_vector(
+          llvm::map_range(values, [&](int64_t value) -> Attribute {
+            return builder.getI64IntegerAttr(value);
+          }));
+    };
+    transformResults.setParams(cast<OpResult>(getTileSizes()),
+                               makeI64AttrsFromI64(spec->tileSizes));
+    transformResults.setParams(cast<OpResult>(getSplitPoints()),
+                               makeI64AttrsFromI64(splitPoints));
+
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Handle results: emit the size computation right before the payload op.
+  builder.setInsertionPoint(target);
+
+  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+  unsigned dimension = getDimension();
+
+  FailureOr<ContinuousTileSizeSpecification> spec =
+      computeContinuousTileSizes(builder, target, dimension, targetSize, true);
+  if (failed(spec)) {
+    return emitSilenceableError() << "could not generate tile size computation";
+  }
+
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
+                                           ofrs);
+  };
+
+  // Each chunk covers `tileSize * tripCount` iterations.
+  SmallVector<Value> splitPoints;
+  for (auto [tileSize, tripCount] :
+       llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+    splitPoints.push_back(apply(s0 * s1, {tileSize, tripCount}));
+
+  // The result handles refer to the ops defining the computed values.
+  auto makeOpFromValue = [&](ArrayRef<Value> values) {
+    return llvm::to_vector(
+        llvm::map_range(values, [&](Value value) -> Operation * {
+          return value.getDefiningOp();
+        }));
+  };
+
+  transformResults.set(cast<OpResult>(getTileSizes()),
+                       makeOpFromValue(spec->tileSizes));
+  transformResults.set(cast<OpResult>(getSplitPoints()),
+                       makeOpFromValue(splitPoints));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Verifies that the tile sizes and split points results have the same type.
+LogicalResult transform::ContinuousTileSizesOp::verify() {
+  Type tileSizesType = getTileSizes().getType();
+  Type splitPointsType = getSplitPoints().getType();
+  if (tileSizesType == splitPointsType)
+    return success();
+
+  return emitOpError() << "expects all results type to be the same";
+}
+
+/// Declares the effects of this op: the payload is only read when the results
+/// are parameters and modified otherwise; the target handle is only read and
+/// both results are produced.
+void transform::ContinuousTileSizesOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  const bool producesParams =
+      isa<TransformParamTypeInterface>(getTileSizes().getType());
+  if (!producesParams)
+    modifiesPayload(effects);
+  else
+    onlyReadsPayload(effects);
+
+  onlyReadsHandle(getTarget(), effects);
+  producesHandle(getTileSizes(), effects);
+  producesHandle(getSplitPoints(), effects);
+}
+
+/// Prints the types as a functional type `(targetType) -> tile_sizes`. The
+/// third, unnamed type parameter is not printed.
+static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
+ Type targetType, Type tile_sizes,
+ Type) {
+ printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
+}
+
+/// Parses a trailing functional type with exactly one input and one result.
+/// The input becomes the target type; the single result type is used for both
+/// the tile sizes and the split points.
+static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
+                                                Type &targetType,
+                                                Type &tileSizesType,
+                                                Type &splitPointsType) {
+  FunctionType funcType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseType<FunctionType>(funcType)))
+    return failure();
+
+  // Return the error: indexing input/result 0 below is only valid when the
+  // functional type has exactly one input and one result.
+  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
+    return parser.emitError(typeLoc)
+           << "expects a trailing functional type with one "
+              "argument and one result";
+  }
+  targetType = funcType.getInput(0);
+  tileSizesType = splitPointsType = funcType.getResult(0);
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TileUsingForOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index fd314ef9f8134..d88bcaf142b87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -107,6 +107,139 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
b.getStringAttr("expected strictly positive tile size and divisor"));
}
+/// Computes the continuous tile sizes and trip counts for loop `dimension` of
+/// the statically shaped `op`.
+///
+/// The first tile size is `targetSize`. While part of the loop range remains
+/// uncovered, the tile size is lowered to the next smaller power of two and
+/// recorded together with its trip count whenever at least one such tile
+/// fits. Returns failure if `targetSize` is zero, if `dimension` is not a loop
+/// of `op`, or if the computed tiles do not cover the loop range exactly.
+FailureOr<StaticContinuousTileSizeSpecification>
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+                                               unsigned targetSize) {
+
+  assert(!op.hasDynamicShape() &&
+         "cannot compute static multi-tile sizes for an op with dynamic shape");
+
+  // Fail gracefully instead of only asserting: `targetSize` is used as a
+  // divisor and `dimension` as an index into the loop ranges below.
+  if (targetSize == 0 || dimension >= op.getNumLoops())
+    return failure();
+
+  StaticContinuousTileSizeSpecification spec;
+  int64_t loopRange = op.getStaticLoopRanges()[dimension];
+  int64_t tripCount = loopRange / targetSize;
+
+  unsigned tileSize = targetSize;
+
+  spec.tileSizes.push_back(tileSize);
+  spec.tripCounts.push_back(tripCount);
+
+  int64_t remainderChunk = loopRange % targetSize;
+
+  while (tileSize > 1 && remainderChunk != 0) {
+
+    // Next smaller power of two: halve if already a power of two, otherwise
+    // round down to the largest power of two below the current size.
+    uint64_t maxPower = llvm::bit_floor(tileSize);
+    tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
+
+    tripCount = remainderChunk / tileSize;
+
+    // Only record tile sizes that are actually used.
+    if (tripCount > 0) {
+      spec.tileSizes.push_back(tileSize);
+      spec.tripCounts.push_back(tripCount);
+    }
+
+    remainderChunk = remainderChunk % tileSize;
+  }
+
+  // Sanity check: the tiles must cover the loop range exactly.
+  int64_t coveredRange = 0;
+  for (auto [size, count] : llvm::zip(spec.tileSizes, spec.tripCounts))
+    coveredRange += size * count;
+
+  if (coveredRange != loopRange)
+    return failure();
+
+  return spec;
+}
+
+/// Emits the IR computing the continuous tile sizes and trip counts for loop
+/// `dimension` of `op`.
+///
+/// The first tile size is `targetSize`; the following ones are decreasing
+/// powers of two, each with a trip count computed from what remains of the
+/// loop range. Tile sizes whose trip count folds to zero are omitted.
+/// Returns failure if `dimension` is not a loop of `op` or if `targetSize` is
+/// not a strictly positive constant. If `emitAssertions` is set, a check that
+/// `targetSize` is strictly positive is also emitted.
+FailureOr<ContinuousTileSizeSpecification>
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
+                                         unsigned dimension,
+                                         OpFoldResult targetSize,
+                                         bool emitAssertions) {
+
+  // Bail out on dimension overflow.
+  if (dimension >= op.getNumLoops())
+    return failure();
+
+  // The sequence of tile sizes is derived from the integer value of the
+  // target size, so it must be a known, strictly positive constant. Check
+  // this before emitting any IR.
+  std::optional<int64_t> constTargetSize = getConstantIntValue(targetSize);
+  if (!constTargetSize || *constTargetSize <= 0)
+    return failure();
+
+  // The code below works only on values.
+  Location loc = op.getLoc();
+  ImplicitLocOpBuilder b(loc, builder);
+  if (emitAssertions) {
+    emitIsPositiveIndexAssertion(b, targetSize);
+  }
+  Value targetSizeValue =
+      getValueOrCreateConstantIndexOp(builder, loc, targetSize);
+
+  // Find the trip count of the iteration space dimension for which the tile
+  // sizes are computed.
+  SmallVector<OpFoldResult> allShapes =
+      op.createFlatListOfOperandDims(b, b.getLoc());
+  AffineMap shapesToLoops = op.getShapesToLoopsMap();
+  SmallVector<OpFoldResult> loopRanges =
+      makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
+                                               allShapes);
+
+  Value loopRange =
+      getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
+
+  ContinuousTileSizeSpecification spec;
+
+  // Compute the tile sizes and the respective numbers of tiles.
+  AffineExpr s0 = b.getAffineSymbolExpr(0);
+  AffineExpr s1 = b.getAffineSymbolExpr(1);
+  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+    return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
+  };
+
+  Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
+  Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
+
+  // NOTE(review): this initial value is overwritten inside the loop before it
+  // is read; kept so the emitted IR is unchanged.
+  OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
+      b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
+
+  uint64_t tileSizeInt = *constTargetSize;
+
+  spec.tileSizes.push_back(targetSizeValue);
+  spec.tripCounts.push_back(tripCountValue);
+
+  while (tileSizeInt > 1) {
+    // Next smaller power of two: halve if already a power of two, otherwise
+    // round down to the largest power of two below the current size.
+    uint64_t maxPower = llvm::bit_floor(tileSizeInt);
+    tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
+    auto constStepOp =
+        builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
+    tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+    tripCountSize = affine::makeComposedFoldedAffineApply(
+        b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
+
+    // Optimization if tripCount can be determined to be zero: such a tile
+    // size is never used, so it is not recorded.
+    Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize);
+    bool isTripCountZero = attr && cast<IntegerAttr>(attr).getValue().isZero();
+    if (!isTripCountZero) {
+      spec.tileSizes.push_back(constStepOp);
+      spec.tripCounts.push_back(tripCountValue);
+    }
+
+    remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
+  }
+
+  return spec;
+}
+
FailureOr<StaticMultiSizeSpecification>
mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
int64_t targetSize, int64_t divisor) {
>From 752f6c361383e45f44c25a3427c8f47afc134f51 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 1 May 2024 22:26:19 +0800
Subject: [PATCH 2/8] [MLIR] Add support for multiway split in SplitOp
Add functionality that enables SplitOp to do a multiway split of
a target linalg along a given dimension. When multiway attribute
is `true`, the SplitOp takes a list of split points and applies
it to a single linalg along the given dimension to generate
multiple linalgs extracted from the target.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 23 ++-
.../TransformOps/LinalgTransformOps.cpp | 150 +++++++++++++-----
2 files changed, 123 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index cce423b09617e..aed686c7c56b3 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1396,7 +1396,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Indicates that the given `target` op should be split into two complementary
+ Splits the given `target` op into two or more complementary
parts, which combined cover the entire iteration domain of the original op.
The split is performed along the iteration space dimension provided as
attribute. In case of dimension overflow, the transformation fails. The
@@ -1409,16 +1409,27 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
operations pointed to by the target handle.
The operation consumes the target handle, but preserves the split point
- handle if provided. It produces two new handles pointing to the two parts
- of the structured op after splitting, in the same order as the target
- operand, with the first handle corresponding to the part with lower
- iteration space indices.
+ handle if provided. Without the `multiway` attribute, it produces two
+ new handles pointing to the two parts of the structured op after splitting,
+ in the same order as the target operand, with the first handle
+ corresponding to the part with lower iteration space indices.
+
+ Multiway split mode is enabled by specifying the `multiway` attribute.
+ In this mode a single `target` op is split into multiple parts covering
+ the iteration space of the specified dimension. `static_split_point` and
+ `dynamic_split_point` in this case is a list of chunk sizes that the given
+ dimension should be split into. With `multiway` it produces two handles;
+ the first handle is a list of the multiple parts of the structured op
+ after splitting, where the target dimensions for each linalg op in the
+ list corresponds to the chunk sizes specified in the input split list.
+ The second handle is empty.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$dimension,
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
- I64Attr:$static_split_point);
+ I64Attr:$static_split_point,
+ UnitAttr:$multiway);
let results = (outs TransformHandleTypeInterface:$first,
TransformHandleTypeInterface:$second);
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1e9e3163ef996..68c8c52df3760 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2269,8 +2269,20 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Collect the dynamic split points if provided.
SmallVector<Operation *> payload =
llvm::to_vector(state.getPayloadOps(getTarget()));
+
+ bool isMultiwaySplit = getMultiway() ? true : false;
+
+ if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
+ return emitDefiniteFailure() << "requires exactly one target when "
+ "multiway split is enabled (got "
+ << llvm::range_size(payload) << ")";
+ }
+
SmallVector<OpFoldResult> splitPoints;
- splitPoints.reserve(payload.size());
+
+ if (!isMultiwaySplit)
+ splitPoints.reserve(payload.size());
+
if (getDynamicSplitPoint()) {
auto diag = DiagnosedSilenceableFailure::success();
if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
@@ -2293,7 +2305,9 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
if (diag.isSilenceableFailure())
return diag;
- if (splitPoints.size() != payload.size()) {
+ // For multiway split, a single payload is expected to have multiple
+ // split points.
+ if (!isMultiwaySplit && splitPoints.size() != payload.size()) {
return emitDefiniteFailure()
<< "expected the dynamic split point handle to point to as "
"many operations ("
@@ -2305,57 +2319,105 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
rewriter.getIndexAttr(getStaticSplitPoint()));
}
- // Split each target operation.
- SmallVector<Operation *> first, second;
- Operation *noSecondPart = nullptr;
- for (const auto &pair : llvm::zip(payload, splitPoints)) {
- Operation *target = std::get<0>(pair);
- auto linalgOp = dyn_cast<LinalgOp>(target);
- if (!linalgOp) {
- auto diag = emitSilenceableError() << "only applies to structured ops";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
- }
+ if (isMultiwaySplit) {
- if (getDimension() >= linalgOp.getNumLoops()) {
- auto diag = emitSilenceableError() << "dimension " << getDimension()
- << " does not exist in target op";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
+ // Split a single target operation at multiple points.
+ SmallVector<Operation *> opList;
+ Operation *head, *tail;
+ for (const auto [idx, splitPoint] : llvm::enumerate(splitPoints)) {
+
+ Operation *target;
+ if (idx == 0)
+ target = payload.front();
+ else
+ target = tail;
+
+ if (!target)
+ break;
+
+ auto linalgOp = dyn_cast<LinalgOp>(target);
+
+ if (!linalgOp) {
+ auto diag = emitSilenceableError() << "only applies to structured ops";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ if (getDimension() >= linalgOp.getNumLoops()) {
+ auto diag = emitSilenceableError() << "dimension " << getDimension()
+ << " does not exist in target op";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ rewriter.setInsertionPoint(linalgOp);
+ std::tie(head, tail) = linalg::splitOp(
+ rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+ getDimension(), splitPoint);
+
+ opList.push_back(head);
}
- rewriter.setInsertionPoint(linalgOp);
- std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
- rewriter, cast<TilingInterface>(linalgOp.getOperation()),
- getDimension(), std::get<1>(pair));
+ // Append any leftover parts to the end of the result list.
+ if (tail)
+ opList.push_back(tail);
+ results.set(cast<OpResult>(getFirst()), opList);
+ results.set(cast<OpResult>(getSecond()), {});
- // Propagate errors.
- if (!first.back() && !second.back()) {
- auto diag = emitDefiniteFailure() << "internal failure in splitting";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
+ } else {
+ // Split each target operation.
+ SmallVector<Operation *> first, second;
+ Operation *noSecondPart = nullptr;
+ for (const auto &pair : llvm::zip(payload, splitPoints)) {
+ Operation *target = std::get<0>(pair);
+ auto linalgOp = dyn_cast<LinalgOp>(target);
+ if (!linalgOp) {
+ auto diag = emitSilenceableError() << "only applies to structured ops";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ if (getDimension() >= linalgOp.getNumLoops()) {
+ auto diag = emitSilenceableError() << "dimension " << getDimension()
+ << " does not exist in target op";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ rewriter.setInsertionPoint(linalgOp);
+ std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
+ rewriter, cast<TilingInterface>(linalgOp.getOperation()),
+ getDimension(), std::get<1>(pair));
+
+ // Propagate errors.
+ if (!first.back() && !second.back()) {
+ auto diag = emitDefiniteFailure() << "internal failure in splitting";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ // Do not add null second parts.
+ if (!second.back()) {
+ noSecondPart = target;
+ second.pop_back();
+ }
}
- // Do not add null second parts.
- if (!second.back()) {
- noSecondPart = target;
- second.pop_back();
+ if (second.size() != first.size() && !second.empty()) {
+ auto diag = emitSilenceableError()
+ << "splitting does not produce the second part for a subset "
+ "of targets";
+ diag.attachNote()
+ << "expected splitting to produce the second part of all "
+ "or none of the targets";
+ diag.attachNote(noSecondPart->getLoc())
+ << "first target with no second part";
+ return diag;
}
- }
- if (second.size() != first.size() && !second.empty()) {
- auto diag = emitSilenceableError()
- << "splitting does not produce the second part for a subset "
- "of targets";
- diag.attachNote() << "expected splitting to produce the second part of all "
- "or none of the targets";
- diag.attachNote(noSecondPart->getLoc())
- << "first target with no second part";
- return diag;
+ results.set(cast<OpResult>(getFirst()), first);
+ results.set(cast<OpResult>(getSecond()), second);
}
-
- results.set(cast<OpResult>(getFirst()), first);
- results.set(cast<OpResult>(getSecond()), second);
return DiagnosedSilenceableFailure::success();
}
>From 668b9912a39accc141028ea4463529efeb53b95e Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Thu, 2 May 2024 19:10:49 +0800
Subject: [PATCH 3/8] [MLIR] Test multiway SplitOp
Tests SplitOp for multiway splitting of a linalg op using
the result of `continuous_tile_sizes` to specify multiple
split points for a single linalg op.
---
.../continuous-tiling-multiway-split.mlir | 100 ++++++++++++++++++
1 file changed, 100 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
diff --git a/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
new file mode 100644
index 0000000000000..609766fbdc91f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/continuous-tiling-multiway-split.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt --transform-interpreter --canonicalize --split-input-file %s | FileCheck %s
+
+// This tests the results of continuous_tile_sizes on multiway splitOp.
+// continuous_tile_sizes returns a list of tile-sizes and a list of split points.
+// The list of split points is consumed by splitOp to split the linalg.matmul op
+// along dimension 1 to produce as many split-up linalg.matmul ops.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.any_op
+ %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.any_op
+ transform.yield
+ }
+}
+
+func.func @continuous_tile_linalg_matmul(
+ %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+ outs(%arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32>
+
+ return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK %[[SLICE1]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK %[[SLICE2]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK %[[SLICE3]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK %[[SLICE4]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK %[[SLICE5]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK %[[SLICE6]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK %[[SLICE7]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK %[[SLICE8]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK %[[SLICE9]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK %[[SLICE10]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK %[[INSLICE2]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK %[[INSLICE3]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK %[[INSLICE4]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK return %[[INSLICE4]] : tensor<25x25xf32>
+
+// -----
+
+// Tests the same as above except that the !transform.param<i64> output type in
+// continuous_tile_sizes op triggers tile sizes and split points to be computed
+// statically and not dynamically.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %tiles, %splits = transform.structured.continuous_tile_sizes %0 { dimension = 1, target_size = 9} : (!transform.any_op) -> !transform.param<i64>
+ %low, %high = transform.structured.split %0 after %splits { dimension = 1, multiway } : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
+
+func.func @continuous_tile_static_linalg_matmul(
+ %arg0: tensor<25x34xf32>, %arg1: tensor<34x25xf32>, %arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+ outs(%arg2: tensor<25x25xf32>)
+ -> tensor<25x25xf32>
+
+ return %0 : tensor<25x25xf32>
+}
+
+// CHECK-LABEL: @continuous_tile_static_linalg_matmul
+// CHECK-SAME: %[[IN1:.+]]: tensor<25x34xf32>, %[[IN2:.+]]: tensor<34x25xf32>, %[[OUT:.+]]: tensor<25x25xf32>
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[IN2]][0, 0] [34, 18] [1, 1] : tensor<34x25xf32> to tensor<34x18xf32>
+// CHECK %[[SLICE0:.+]] = tensor.extract_slice %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x25xf32> to tensor<25x18xf32>
+// CHECK %[[MM0:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE]] : tensor<25x34xf32>, tensor<34x18xf32>) outs(%[[SLICE0]] : tensor<25x18xf32>) -> tensor<25x18xf32>
+// CHECK %[[INSLICE:.+]] = tensor.insert_slice %[[MM0]] into %[[OUT]][0, 0] [25, 18] [1, 1] : tensor<25x18xf32> into tensor<25x25xf32>
+// CHECK %[[SLICE1]] = tensor.extract_slice %[[IN2]][0, 18] [34, 7] [1, 1] : tensor<34x25xf32> to tensor<34x7xf32>
+// CHECK %[[SLICE2]] = tensor.extract_slice %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x25xf32> to tensor<25x7xf32>
+// CHECK %[[SLICE3]] = tensor.extract_slice %[[SLICE1]][0, 0] [34, 4] [1, 1] : tensor<34x7xf32> to tensor<34x4xf32>
+// CHECK %[[SLICE4]] = tensor.extract_slice %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x7xf32> to tensor<25x4xf32>
+// CHECK %[[MM1:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE3]] : tensor<25x34xf32>, tensor<34x4xf32>) outs(%[[SLICE4]] : tensor<25x4xf32>) -> tensor<25x4xf32>
+// CHECK %[[INSLICE0:.+]] = tensor.insert_slice %[[MM1]] into %[[SLICE2]][0, 0] [25, 4] [1, 1] : tensor<25x4xf32> into tensor<25x7xf32>
+// CHECK %[[SLICE5]] = tensor.extract_slice %[[SLICE1]][0, 4] [34, 3] [1, 1] : tensor<34x7xf32> to tensor<34x3xf32>
+// CHECK %[[SLICE6]] = tensor.extract_slice %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x7xf32> to tensor<25x3xf32>
+// CHECK %[[SLICE7]] = tensor.extract_slice %[[SLICE5]][0, 0] [34, 2] [1, 1] : tensor<34x3xf32> to tensor<34x2xf32>
+// CHECK %[[SLICE8]] = tensor.extract_slice %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x3xf32> to tensor<25x2xf32>
+// CHECK %[[MM2:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE7]] : tensor<25x34xf32>, tensor<34x2xf32>) outs(%[[SLICE8]] : tensor<25x2xf32>) -> tensor<25x2xf32>
+// CHECK %[[INSLICE1:.+]] = tensor.insert_slice %[[MM2]] into %[[SLICE6]][0, 0] [25, 2] [1, 1] : tensor<25x2xf32> into tensor<25x3xf32>
+// CHECK %[[SLICE9]] = tensor.extract_slice %[[SLICE5]][0, 2] [34, 1] [1, 1] : tensor<34x3xf32> to tensor<34x1xf32>
+// CHECK %[[SLICE10]] = tensor.extract_slice %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x3xf32> to tensor<25x1xf32>
+// CHECK %[[MM3:.+]] = linalg.matmul ins(%[[IN1]], %[[SLICE9]] : tensor<25x34xf32>, tensor<34x1xf32>) outs(%[[SLICE10]] : tensor<25x1xf32>) -> tensor<25x1xf32>
+// CHECK %[[INSLICE2]] = tensor.insert_slice %[[MM3]] into %[[INSLICE1]][0, 2] [25, 1] [1, 1] : tensor<25x1xf32> into tensor<25x3xf32>
+// CHECK %[[INSLICE3]] = tensor.insert_slice %[[INSLICE2]] into %[[INSLICE0]][0, 4] [25, 3] [1, 1] : tensor<25x3xf32> into tensor<25x7xf32>
+// CHECK %[[INSLICE4]] = tensor.insert_slice %[[INSLICE3]] into %[[INSLICE]][0, 18] [25, 7] [1, 1] : tensor<25x7xf32> into tensor<25x25xf32>
+// CHECK return %[[INSLICE4]] : tensor<25x25xf32>
>From 0e49e915c64f45e15a702483a13e8c9d80e92cd2 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Mon, 20 May 2024 23:27:12 +0800
Subject: [PATCH 4/8] fix for SplitOp; switch from split-point to chunk-sizes
terminology. remove tautology. use emitSilenceableFailure. use references,
use conditional operator. move common code out of conditional in a lambda
function. check splitting operation was performed correctly. bug fix and
refactoring for code duplication.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 25 +--
.../TransformOps/LinalgTransformOps.cpp | 164 ++++++++++--------
.../Dialect/Linalg/transform-op-split.mlir | 2 +-
3 files changed, 104 insertions(+), 87 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index aed686c7c56b3..a7a63ead07b1f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1399,16 +1399,18 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
Splits the given `target` op into two or more complementary
parts, which combined cover the entire iteration domain of the original op.
The split is performed along the iteration space dimension provided as
- attribute. In case of dimension overflow, the transformation fails. The
- split is performed at the dimension iterator value specified as either the
- static split point attribute when it is known at transform IR construction
- time or as the handle to an operation producing a single index-typed value
- when it is computed by payload IR. In the latter case, the static split
+ attribute, with the chunk size specifying the size of the lower part; the remaining
+ range in the iteration space is assigned as the upper part. In case of
+ dimension overflow, the transformation fails. The split is performed at the
+ dimension iterator value specified as either the static chunk size
+ attribute when it is known at transform IR construction time or
+ as the handle to an operation producing a single index-typed value
+ when it is computed by payload IR. In the latter case, the chunk size
point must be set to `ShapedType::kDynamic` and the dynamic size handle
must point to as many value-producing operations as there are structured
operations pointed to by the target handle.
- The operation consumes the target handle, but preserves the split point
+ The operation consumes the target handle, but preserves the chunk size
handle if provided. Without the `multiway` attribute, it produces two
new handles pointing to the two parts of the structured op after splitting,
in the same order as the target operand, with the first handle
@@ -1416,19 +1418,20 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
Multiway split mode is enabled by specifying the `multiway` attribute.
In this mode a single `target` op is split into multiple parts covering
- the iteration space of the specified dimension. `static_split_point` and
- `dynamic_split_point` in this case is a list of chunk sizes that the given
+ the iteration space of the specified dimension. `static_chunk_sizes` and
+ `dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
dimension should be split into. With `multiway` it produces two handles;
the first handle is a list of the multiple parts of the structured op
after splitting, where the target dimensions for each linalg op in the
list corresponds to the chunk sizes specified in the input split list.
- The second handle is empty.
+ If the chunk sizes do not cover the entire iteration space, the leftover
+ chunk is the last payload in the first handle. The second handle is empty.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$dimension,
- Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
- I64Attr:$static_split_point,
+ Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
+ I64Attr:$static_chunk_sizes,
UnitAttr:$multiway);
let results = (outs TransformHandleTypeInterface:$first,
TransformHandleTypeInterface:$second);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 68c8c52df3760..cf2ef3bb9a16e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2270,24 +2270,25 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
SmallVector<Operation *> payload =
llvm::to_vector(state.getPayloadOps(getTarget()));
- bool isMultiwaySplit = getMultiway() ? true : false;
+ bool isMultiwaySplit = getMultiway();
if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
- return emitDefiniteFailure() << "requires exactly one target when "
- "multiway split is enabled (got "
- << llvm::range_size(payload) << ")";
+ return mlir::emitSilenceableFailure(getLoc())
+ << "requires exactly one target when "
+ "multiway split is enabled (got "
+ << llvm::range_size(payload) << ")";
}
- SmallVector<OpFoldResult> splitPoints;
+ SmallVector<OpFoldResult> chunkSizes;
if (!isMultiwaySplit)
- splitPoints.reserve(payload.size());
+ chunkSizes.reserve(payload.size());
- if (getDynamicSplitPoint()) {
+ if (getDynamicChunkSizes()) {
auto diag = DiagnosedSilenceableFailure::success();
- if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
- splitPoints = llvm::to_vector(llvm::map_range(
- state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
+ if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
+ chunkSizes = llvm::to_vector(llvm::map_range(
+ state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
if (op->getNumResults() != 1 ||
!op->getResult(0).getType().isIndex()) {
diag = emitSilenceableError()
@@ -2298,8 +2299,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
return OpFoldResult(op->getResult(0));
}));
} else {
- splitPoints = llvm::to_vector(
- llvm::map_range(state.getParams(getDynamicSplitPoint()),
+ chunkSizes = llvm::to_vector(
+ llvm::map_range(state.getParams(getDynamicChunkSizes()),
[](Attribute attr) { return OpFoldResult(attr); }));
}
if (diag.isSilenceableFailure())
@@ -2307,53 +2308,75 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// For multiway split, a single payload is expected to have multiple
// split points.
- if (!isMultiwaySplit && splitPoints.size() != payload.size()) {
+ if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
return emitDefiniteFailure()
<< "expected the dynamic split point handle to point to as "
"many operations ("
- << splitPoints.size() << ") as the target handle ("
+ << chunkSizes.size() << ") as the target handle ("
<< payload.size() << ")";
}
} else {
- splitPoints.resize(payload.size(),
- rewriter.getIndexAttr(getStaticSplitPoint()));
+ chunkSizes.resize(payload.size(),
+ rewriter.getIndexAttr(getStaticChunkSizes()));
}
+ auto checkStructuredOpAndDimensions = [&](LinalgOp linalgOp, Location loc) {
+ if (!linalgOp) {
+ auto diag = emitSilenceableError() << "only applies to structured ops";
+ diag.attachNote(loc) << "target op";
+ return diag;
+ }
+
+ if (getDimension() >= linalgOp.getNumLoops()) {
+ auto diag = emitSilenceableError() << "dimension " << getDimension()
+ << " does not exist in target op";
+ diag.attachNote(loc) << "target op";
+ return diag;
+ }
+ return DiagnosedSilenceableFailure::success();
+ };
+
+ auto checkFailureInSplitting = [&](bool hasFailed, Location loc) {
+ if (hasFailed) {
+ auto diag = emitDefiniteFailure() << "internal failure in splitting";
+ diag.attachNote(loc) << "target op";
+ return DiagnosedSilenceableFailure(diag);
+ }
+ return DiagnosedSilenceableFailure::success();
+ };
+
if (isMultiwaySplit) {
// Split a single target operation at multiple points.
SmallVector<Operation *> opList;
Operation *head, *tail;
- for (const auto [idx, splitPoint] : llvm::enumerate(splitPoints)) {
+ Operation *target = payload.front();
+
+ auto linalgOp = dyn_cast<LinalgOp>(target);
+ auto diag = checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+
+ if (diag.isSilenceableFailure())
+ return diag;
- Operation *target;
- if (idx == 0)
- target = payload.front();
- else
+ for (const auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
+
+ if (idx > 0)
target = tail;
if (!target)
break;
- auto linalgOp = dyn_cast<LinalgOp>(target);
-
- if (!linalgOp) {
- auto diag = emitSilenceableError() << "only applies to structured ops";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
- }
-
- if (getDimension() >= linalgOp.getNumLoops()) {
- auto diag = emitSilenceableError() << "dimension " << getDimension()
- << " does not exist in target op";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
- }
+ linalgOp = dyn_cast<LinalgOp>(target);
rewriter.setInsertionPoint(linalgOp);
std::tie(head, tail) = linalg::splitOp(
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
- getDimension(), splitPoint);
+ getDimension(), chunkSize);
+
+ // Propagate errors.
+ auto diag = checkFailureInSplitting(!head && !tail, target->getLoc());
+ if (diag.isDefiniteFailure())
+ return diag;
opList.push_back(head);
}
@@ -2368,21 +2391,13 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Split each target operation.
SmallVector<Operation *> first, second;
Operation *noSecondPart = nullptr;
- for (const auto &pair : llvm::zip(payload, splitPoints)) {
+ for (const auto &pair : llvm::zip(payload, chunkSizes)) {
Operation *target = std::get<0>(pair);
auto linalgOp = dyn_cast<LinalgOp>(target);
- if (!linalgOp) {
- auto diag = emitSilenceableError() << "only applies to structured ops";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
- }
+ auto diag = checkStructuredOpAndDimensions(linalgOp, target->getLoc());
- if (getDimension() >= linalgOp.getNumLoops()) {
- auto diag = emitSilenceableError() << "dimension " << getDimension()
- << " does not exist in target op";
- diag.attachNote(target->getLoc()) << "target op";
+ if (diag.isSilenceableFailure())
return diag;
- }
rewriter.setInsertionPoint(linalgOp);
std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
@@ -2390,11 +2405,10 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
getDimension(), std::get<1>(pair));
// Propagate errors.
- if (!first.back() && !second.back()) {
- auto diag = emitDefiniteFailure() << "internal failure in splitting";
- diag.attachNote(target->getLoc()) << "target op";
+ auto diagSplit = checkFailureInSplitting(!first.back() && !second.back(),
+ target->getLoc());
+ if (diagSplit.isDefiniteFailure())
return diag;
- }
// Do not add null second parts.
if (!second.back()) {
@@ -2424,27 +2438,27 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
void SplitOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
- if (getDynamicSplitPoint())
- onlyReadsHandle(getDynamicSplitPoint(), effects);
+ if (getDynamicChunkSizes())
+ onlyReadsHandle(getDynamicChunkSizes(), effects);
producesHandle(getResults(), effects);
modifiesPayload(effects);
}
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
- IntegerAttr staticSplitPoint;
+ OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
+ IntegerAttr staticChunkSizes;
if (parser.parseOperand(target) || parser.parseKeyword("after"))
return failure();
OptionalParseResult dynamicPointParseResult =
- parser.parseOptionalOperand(dynamicSplitPoint);
+ parser.parseOptionalOperand(dynamicChunkSizes);
if (!dynamicPointParseResult.has_value()) {
- int64_t staticSplitPointValue;
- if (failed(parser.parseInteger(staticSplitPointValue)))
+ int64_t staticChunkSizesValue;
+ if (failed(parser.parseInteger(staticChunkSizesValue)))
return failure();
- staticSplitPoint =
- parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
+ staticChunkSizes =
+ parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
}
Type targetType;
@@ -2454,43 +2468,43 @@ ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
}
if (dynamicPointParseResult.has_value()) {
- Type splitPointType;
+ Type ChunkSizesType;
if (failed(*dynamicPointParseResult) || parser.parseComma() ||
- parser.parseType(splitPointType) ||
- parser.resolveOperand(dynamicSplitPoint, splitPointType,
+ parser.parseType(ChunkSizesType) ||
+ parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
result.operands)) {
return failure();
}
- staticSplitPoint =
+ staticChunkSizes =
parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
}
result.addAttribute(
- SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
- staticSplitPoint);
+ SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
+ staticChunkSizes);
result.addTypes({targetType, targetType});
return success();
}
void SplitOp::print(OpAsmPrinter &printer) {
printer << " " << getTarget() << " after ";
- int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
- if (staticSplitSize != ShapedType::kDynamic)
- printer << staticSplitSize;
+ int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
+ if (staticChunkSize != ShapedType::kDynamic)
+ printer << staticChunkSize;
else
- printer << getDynamicSplitPoint();
+ printer << getDynamicChunkSizes();
printer << " ";
printer.printOptionalAttrDict(getOperation()->getAttrs(),
- {getStaticSplitPointAttrName()});
+ {getStaticChunkSizesAttrName()});
printer << " : " << getTarget().getType();
- if (staticSplitSize == ShapedType::kDynamic)
- printer << ", " << getDynamicSplitPoint().getType();
+ if (staticChunkSize == ShapedType::kDynamic)
+ printer << ", " << getDynamicChunkSizes().getType();
}
LogicalResult SplitOp::verify() {
- if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
- (getDynamicSplitPoint() == nullptr)) {
+ if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
+ (getDynamicChunkSizes() == nullptr)) {
return emitOpError() << "expects either a dynamic or a static split "
"point to be provided";
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
index 566e517d69789..e072fff4c5d77 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -197,7 +197,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.consumed}) {
// expected-error @below {{expects either a dynamic or a static split point to be provided}}
- %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_chunk_sizes = -9223372036854775808 } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
>From a814d48fb01e96aaac036e2c5f1c86845e6e569e Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Tue, 21 May 2024 18:23:50 +0800
Subject: [PATCH 5/8] fix to continuous tile sizes; switch from split-points to
chunk-sizes. fix for loops. rename lambda functions, use map_to_vector.
---
.../Linalg/TransformOps/LinalgTransformOps.td | 19 +++---
.../TransformOps/LinalgTransformOps.cpp | 63 +++++++++----------
2 files changed, 38 insertions(+), 44 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index a7a63ead07b1f..866275cedf68b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1842,10 +1842,9 @@ def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_si
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- This transform takes a linalg as target and a dimension and target size
- as attributes to generate a list of (1) exponentially diminishing
- tile sizes that are powers of 2; and (2) the corresponding chunk-sizes
- the linalg op should be split into along the given dimension.
+ This transform emits the IR computing the list of (1) exponentially
+ diminishing tile sizes that are powers of 2; and (2) the corresponding
+ chunk-sizes the target op should be split into along the given dimension.
For example, for `target_size` 9, and `dimension` 0 for the following
linalg op as target
@@ -1859,23 +1858,23 @@ def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_si
9, 4, 2, 1; and the second result will be a list of chunk sizes
18, 4, 2, 1 that the corresponding dimension should be split into.
- After the linalg has been split along the given dimension (for example using
- multiway split), each chunk can be tiled with the corresponding tile size in
- the `tile_sizes` list generated as a result of this op.
+ After the target op has been split along the given dimension (for example
+ using multiway split), each chunk can be tiled with the corresponding tile
+ size in the `tile_sizes` list generated as a result of this op.
Specifying the output type as !transform.param<i64> will cause `tile_sizes`
- and `split_points` to be computed statically and not dynamically.
+ and `chunk_sizes` to be computed statically and not dynamically.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
ConfinedAttr<I64Attr, [IntNonNegative]>:$target_size);
let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
- TransformAnyParamTypeOrAnyHandle:$split_points);
+ TransformAnyParamTypeOrAnyHandle:$chunk_sizes);
let hasVerifier = 1;
let assemblyFormat =
"$target attr-dict `:` custom<ContinuousTileSizeTypes>("
- "type($target), type($tile_sizes), type($split_points))";
+ "type($target), type($tile_sizes), type($chunk_sizes))";
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index cf2ef3bb9a16e..01998b080d2b4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2672,8 +2672,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
llvm::to_vector(state.getPayloadOps(getTarget()));
if (!llvm::hasSingleElement(targetOps)) {
- return emitDefiniteFailure() << "requires exactly one target (got "
- << llvm::range_size(targetOps) << ")";
+ return mlir::emitSilenceableFailure(getLoc())
+ << "requires exactly one target (got " << llvm::range_size(targetOps)
+ << ")";
}
auto target = dyn_cast<LinalgOp>(*targetOps.begin());
@@ -2683,7 +2684,7 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
if (!target)
return emitDefiniteFailure() << "expected Linalg Op";
- if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+ if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
if (target.hasDynamicShape()) {
auto diag = emitSilenceableError()
<< "cannot compute parametric tile sizes for dynamically "
@@ -2700,24 +2701,21 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
<< "failed to compute multi-size tiling sizes";
}
- SmallVector<int64_t> splitPoints;
+ SmallVector<int64_t> chunkSizes;
- auto tileSizeTripCountPairs =
- llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+ for (auto &&[tileSize, tripCount] :
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts))
+ chunkSizes.push_back(tileSize * tripCount);
- for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs))
- splitPoints.push_back(std::get<0>(pair) * std::get<1>(pair));
-
- auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
- return llvm::to_vector(
- llvm::map_range(values, [&](int64_t value) -> Attribute {
+ auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+ return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
return builder.getI64IntegerAttr(value);
- }));
+ });
};
transformResults.setParams(cast<OpResult>(getTileSizes()),
- makeI64AttrsFromI64(spec->tileSizes));
- transformResults.setParams(cast<OpResult>(getSplitPoints()),
- makeI64AttrsFromI64(splitPoints));
+ getI64AttrsFromI64(spec->tileSizes));
+ transformResults.setParams(cast<OpResult>(getChunkSizes()),
+ getI64AttrsFromI64(chunkSizes));
return DiagnosedSilenceableFailure::success();
}
@@ -2733,9 +2731,6 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
return emitSilenceableError() << "could not generate tile size computation";
}
- auto tileSizeTripCountPairs =
- llvm::zip_equal(spec->tileSizes, spec->tripCounts);
-
AffineExpr s0 = builder.getAffineSymbolExpr(0);
AffineExpr s1 = builder.getAffineSymbolExpr(1);
auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
@@ -2743,31 +2738,31 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
ofrs);
};
- SmallVector<Value> splitPoints;
+ SmallVector<Value> chunkSizes;
Value splitPoint;
- for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs)) {
- splitPoint = apply(s0 * s1, {std::get<0>(pair), std::get<1>(pair)});
- splitPoints.push_back(splitPoint);
+ for (auto &&[tileSize, tripCount] :
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
+ splitPoint = apply(s0 * s1, {tileSize, tripCount});
+ chunkSizes.push_back(splitPoint);
}
- auto makeOpFromValue = [&](ArrayRef<Value> values) {
- return llvm::to_vector(
- llvm::map_range(values, [&](Value value) -> Operation * {
+ auto getDefiningOps = [&](ArrayRef<Value> values) {
+ return llvm::map_to_vector(values, [&](Value value) -> Operation * {
return value.getDefiningOp();
- }));
+ });
};
transformResults.set(cast<OpResult>(getTileSizes()),
- makeOpFromValue(spec->tileSizes));
- transformResults.set(cast<OpResult>(getSplitPoints()),
- makeOpFromValue(splitPoints));
+ getDefiningOps(spec->tileSizes));
+ transformResults.set(cast<OpResult>(getChunkSizes()),
+ getDefiningOps(chunkSizes));
return DiagnosedSilenceableFailure::success();
}
LogicalResult transform::ContinuousTileSizesOp::verify() {
- if (getTileSizes().getType() != getSplitPoints().getType()) {
+ if (getTileSizes().getType() != getChunkSizes().getType()) {
return emitOpError() << "expects all results type to be the same";
}
@@ -2782,7 +2777,7 @@ void transform::ContinuousTileSizesOp::getEffects(
modifiesPayload(effects);
onlyReadsHandle(getTarget(), effects);
producesHandle(getTileSizes(), effects);
- producesHandle(getSplitPoints(), effects);
+ producesHandle(getChunkSizes(), effects);
}
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
@@ -2794,7 +2789,7 @@ static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
Type &targetType,
Type &tileSizesType,
- Type &splitPointsType) {
+ Type &chunkSizesType) {
FunctionType funcType;
llvm::SMLoc typeLoc = parser.getCurrentLocation();
if (failed(parser.parseType<FunctionType>(funcType)))
@@ -2805,7 +2800,7 @@ static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
"argument and one result";
}
targetType = funcType.getInput(0);
- tileSizesType = splitPointsType = funcType.getResult(0);
+ tileSizesType = chunkSizesType = funcType.getResult(0);
return success();
}
>From 904d1124f41c07061acd3901973c579bd42ef70b Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Wed, 22 May 2024 22:11:16 +0800
Subject: [PATCH 6/8] fix for cts; Adapt to use TilingInterface.
---
.../Dialect/Linalg/Transforms/Transforms.h | 5 ++--
.../TransformOps/LinalgTransformOps.cpp | 24 +++++++++--------
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 26 +++++++++----------
3 files changed, 28 insertions(+), 27 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ef3656c334ea6..da9300499d096 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -864,8 +864,9 @@ FailureOr<StaticContinuousTileSizeSpecification>
computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
unsigned targetSize);
FailureOr<ContinuousTileSizeSpecification>
-computeContinuousTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
- OpFoldResult targetSize, bool emitAssertions);
+computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+ unsigned dimension, OpFoldResult targetSize,
+ bool emitAssertions);
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
/// tiling by `numThreads`.
/// If non-empty, the `mapping` is added as an attribute to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 01998b080d2b4..1adb2571703c4 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2677,24 +2677,26 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
<< ")";
}
- auto target = dyn_cast<LinalgOp>(*targetOps.begin());
-
- OpBuilder builder(target.getContext());
+ Operation *target = *targetOps.begin();
+ auto linalgOp = dyn_cast<LinalgOp>(target);
+ auto tileableOp = dyn_cast<TilingInterface>(target);
- if (!target)
+ if (!linalgOp)
return emitDefiniteFailure() << "expected Linalg Op";
+ OpBuilder builder(linalgOp.getContext());
+
if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
- if (target.hasDynamicShape()) {
+ if (linalgOp.hasDynamicShape()) {
auto diag = emitSilenceableError()
<< "cannot compute parametric tile sizes for dynamically "
"shaped payload op";
- diag.attachNote(target->getLoc()) << "payload op";
+ diag.attachNote(linalgOp->getLoc()) << "payload op";
return diag;
}
FailureOr<StaticContinuousTileSizeSpecification> spec =
- computeStaticContinuousTileSizes(target, getDimension(),
+ computeStaticContinuousTileSizes(linalgOp, getDimension(),
getTargetSize());
if (failed(spec)) {
return emitSilenceableError()
@@ -2720,13 +2722,13 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
- builder.setInsertionPoint(target);
+ builder.setInsertionPoint(linalgOp);
OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
unsigned dimension = getDimension();
- FailureOr<ContinuousTileSizeSpecification> spec =
- computeContinuousTileSizes(builder, target, dimension, targetSize, true);
+ FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
+ builder, tileableOp, dimension, targetSize, true);
if (failed(spec)) {
return emitSilenceableError() << "could not generate tile size computation";
}
@@ -2734,7 +2736,7 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
AffineExpr s0 = builder.getAffineSymbolExpr(0);
AffineExpr s1 = builder.getAffineSymbolExpr(1);
auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
- return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
+ return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
ofrs);
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index d88bcaf142b87..b8e90135dee6e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -108,7 +108,8 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
}
FailureOr<StaticContinuousTileSizeSpecification>
-mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
+ unsigned dimension,
unsigned targetSize) {
assert(!op.hasDynamicShape() &&
@@ -158,17 +159,20 @@ mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
}
FailureOr<ContinuousTileSizeSpecification>
-mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
+mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
unsigned dimension,
OpFoldResult targetSize,
bool emitAssertions) {
+ SmallVector<Range> loopRanges = op.getIterationDomain(builder);
+ unsigned numLoops = loopRanges.size();
+
// Bail out on dimension overflow.
- if (dimension >= op.getNumLoops())
+ if (dimension >= numLoops)
return failure();
// The code below works only on values.
- Location loc = op.getLoc();
+ Location loc = op->getLoc();
ImplicitLocOpBuilder b(loc, builder);
if (emitAssertions) {
emitIsPositiveIndexAssertion(b, targetSize);
@@ -178,16 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
// Find the trip count of the iteration space dimension for which the tile
// sizes are computed.
- SmallVector<OpFoldResult> allShapes =
- op.createFlatListOfOperandDims(b, b.getLoc());
- AffineMap shapesToLoops = op.getShapesToLoopsMap();
- SmallVector<OpFoldResult> loopRanges =
- makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
- allShapes);
-
- Value loopRange =
- getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
-
+ Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
+ loopRanges[dimension].size);
ContinuousTileSizeSpecification spec;
// Compute the tile sizes and the respective numbers of tiles.
@@ -203,6 +199,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
+ // emitAssertions above already asserts that targetSize is
+ // a positive integer.
uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
assert(tileSizeInt > 0 && "target size must be non-negative");
>From 5e14f7630c4370857a97a880dda8aa21e6016162 Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Fri, 24 May 2024 19:33:46 +0800
Subject: [PATCH 7/8] fix to SplitOp; modify python test scripts to chunk
sizes.
---
.../python/mlir/dialects/transform/structured.py | 16 ++++++++--------
.../python/dialects/transform_structured_ext.py | 4 ++--
2 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py
index 2c49ef0960c75..41051c0d5b2ff 100644
--- a/mlir/python/mlir/dialects/transform/structured.py
+++ b/mlir/python/mlir/dialects/transform/structured.py
@@ -432,25 +432,25 @@ def __init__(
self,
target: Union[Operation, Value],
dimension: Union[int, Attribute],
- split_point: Union[int, Operation, Value, Attribute],
+ chunk_sizes: Union[int, Operation, Value, Attribute],
*,
loc=None,
ip=None,
):
- if isinstance(split_point, int):
- static_split_point = split_point
- dynamic_split_point = None
+ if isinstance(chunk_sizes, int):
+ static_chunk_sizes = chunk_sizes
+ dynamic_chunk_sizes = None
else:
- static_split_point = ShapedType.get_dynamic_size()
- dynamic_split_point = split_point
+ static_chunk_sizes = ShapedType.get_dynamic_size()
+ dynamic_chunk_sizes = chunk_sizes
super().__init__(
target.type,
target.type,
target,
dimension=dimension,
- static_split_point=static_split_point,
- dynamic_split_point=dynamic_split_point,
+ static_chunk_sizes=static_chunk_sizes,
+ dynamic_chunk_sizes=dynamic_chunk_sizes,
loc=loc,
ip=ip,
)
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index f97017b7a2c75..3ea73e8beea36 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -361,8 +361,8 @@ def testScalarize(target):
@run
@create_sequence
def testSplit(target):
- split = structured.SplitOp(target, dimension=1, split_point=42)
- structured.SplitOp(split.results[0], dimension=3, split_point=split.results[1])
+ split = structured.SplitOp(target, dimension=1, chunk_sizes=42)
+ structured.SplitOp(split.results[0], dimension=3, chunk_sizes=split.results[1])
# CHECK-LABEL: TEST: testSplit
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
>From 49d1286b731f3771ceedbc4cf20f7a5c51e60eed Mon Sep 17 00:00:00 2001
From: Muneeb Khan <muneeb.khan at huawei.com>
Date: Fri, 24 May 2024 20:51:56 +0800
Subject: [PATCH 8/8] fix to SplitOp; bug fixes for Windows build: expand auto
to actual types
---
.../TransformOps/LinalgTransformOps.cpp | 35 +++++++++++--------
1 file changed, 20 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 1adb2571703c4..7347e487a899e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2320,7 +2320,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
rewriter.getIndexAttr(getStaticChunkSizes()));
}
- auto checkStructuredOpAndDimensions = [&](LinalgOp linalgOp, Location loc) {
+ auto checkStructuredOpAndDimensions =
+ [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
if (!linalgOp) {
auto diag = emitSilenceableError() << "only applies to structured ops";
diag.attachNote(loc) << "target op";
@@ -2336,11 +2337,12 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
};
- auto checkFailureInSplitting = [&](bool hasFailed, Location loc) {
+ auto checkFailureInSplitting =
+ [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
if (hasFailed) {
auto diag = emitDefiniteFailure() << "internal failure in splitting";
diag.attachNote(loc) << "target op";
- return DiagnosedSilenceableFailure(diag);
+ return diag;
}
return DiagnosedSilenceableFailure::success();
};
@@ -2349,24 +2351,25 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Split a single target operation at multiple points.
SmallVector<Operation *> opList;
- Operation *head, *tail;
+ TilingInterface head, tail;
Operation *target = payload.front();
- auto linalgOp = dyn_cast<LinalgOp>(target);
- auto diag = checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+ DiagnosedSilenceableFailure diag =
+ checkStructuredOpAndDimensions(linalgOp, target->getLoc());
if (diag.isSilenceableFailure())
return diag;
- for (const auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
+ for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
if (idx > 0)
- target = tail;
+ target = tail.getOperation();
if (!target)
break;
- linalgOp = dyn_cast<LinalgOp>(target);
+ linalgOp = cast<LinalgOp>(target);
rewriter.setInsertionPoint(linalgOp);
std::tie(head, tail) = linalg::splitOp(
@@ -2374,11 +2377,12 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
getDimension(), chunkSize);
// Propagate errors.
- auto diag = checkFailureInSplitting(!head && !tail, target->getLoc());
+ DiagnosedSilenceableFailure diag =
+ checkFailureInSplitting(!head && !tail, target->getLoc());
if (diag.isDefiniteFailure())
return diag;
- opList.push_back(head);
+ opList.push_back(head.getOperation());
}
// Append any leftover parts to the end of the result list.
@@ -2393,8 +2397,9 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
Operation *noSecondPart = nullptr;
for (const auto &pair : llvm::zip(payload, chunkSizes)) {
Operation *target = std::get<0>(pair);
- auto linalgOp = dyn_cast<LinalgOp>(target);
- auto diag = checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+ LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
+ DiagnosedSilenceableFailure diag =
+ checkStructuredOpAndDimensions(linalgOp, target->getLoc());
if (diag.isSilenceableFailure())
return diag;
@@ -2405,8 +2410,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
getDimension(), std::get<1>(pair));
// Propagate errors.
- auto diagSplit = checkFailureInSplitting(!first.back() && !second.back(),
- target->getLoc());
+ DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
+ !first.back() && !second.back(), target->getLoc());
if (diagSplit.isDefiniteFailure())
return diag;
More information about the Mlir-commits
mailing list