[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Fri May 17 06:20:40 PDT 2024
================
@@ -2581,6 +2643,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+
+ SmallVector<Operation *> targetOps =
+ llvm::to_vector(state.getPayloadOps(getTarget()));
+
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "requires exactly one target (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+
+ auto target = dyn_cast<LinalgOp>(*targetOps.begin());
+
+ OpBuilder builder(target.getContext());
+
+ if (!target)
+ return emitDefiniteFailure() << "expected Linalg Op";
+
+ if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+ if (target.hasDynamicShape()) {
+ auto diag = emitSilenceableError()
+ << "cannot compute parametric tile sizes for dynamically "
+ "shaped payload op";
+ diag.attachNote(target->getLoc()) << "payload op";
+ return diag;
+ }
+
+ FailureOr<StaticContinuousTileSizeSpecification> spec =
+ computeStaticContinuousTileSizes(target, getDimension(),
+ getTargetSize());
+ if (failed(spec)) {
+ return emitSilenceableError()
+ << "failed to compute multi-size tiling sizes";
+ }
+
+ SmallVector<int64_t> splitPoints;
+
+ auto tileSizeTripCountPairs =
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+ for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs))
+ splitPoints.push_back(std::get<0>(pair) * std::get<1>(pair));
+
+ auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
+ return llvm::to_vector(
+ llvm::map_range(values, [&](int64_t value) -> Attribute {
+ return builder.getI64IntegerAttr(value);
+ }));
+ };
+ transformResults.setParams(cast<OpResult>(getTileSizes()),
+ makeI64AttrsFromI64(spec->tileSizes));
+ transformResults.setParams(cast<OpResult>(getSplitPoints()),
+ makeI64AttrsFromI64(splitPoints));
+
+ return DiagnosedSilenceableFailure::success();
+ }
+
+ builder.setInsertionPoint(target);
+
+ OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
+ unsigned dimension = getDimension();
+
+ FailureOr<ContinuousTileSizeSpecification> spec =
+ computeContinuousTileSizes(builder, target, dimension, targetSize, true);
+ if (failed(spec)) {
+ return emitSilenceableError() << "could not generate tile size computation";
+ }
+
+ auto tileSizeTripCountPairs =
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+ AffineExpr s0 = builder.getAffineSymbolExpr(0);
+ AffineExpr s1 = builder.getAffineSymbolExpr(1);
+ auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
+ return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
+ ofrs);
+ };
+
+ SmallVector<Value> splitPoints;
+ Value splitPoint;
+ for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs)) {
+ splitPoint = apply(s0 * s1, {std::get<0>(pair), std::get<1>(pair)});
+ splitPoints.push_back(splitPoint);
+ }
+
+ auto makeOpFromValue = [&](ArrayRef<Value> values) {
+ return llvm::to_vector(
+ llvm::map_range(values, [&](Value value) -> Operation * {
----------------
ftynse wrote:
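map_to_vector (i.e., use `llvm::map_to_vector(values, ...)` here instead of `llvm::to_vector(llvm::map_range(values, ...))`)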
https://github.com/llvm/llvm-project/pull/82792
More information about the Mlir-commits mailing list