[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 23 05:01:44 PDT 2024
================
@@ -2581,6 +2643,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// ContinuousTileSizesOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+
+ SmallVector<Operation *> targetOps =
+ llvm::to_vector(state.getPayloadOps(getTarget()));
+
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "requires exactly one target (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+
+ auto target = dyn_cast<LinalgOp>(*targetOps.begin());
+
+ OpBuilder builder(target.getContext());
+
+ if (!target)
+ return emitDefiniteFailure() << "expected Linalg Op";
+
+ if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
+ if (target.hasDynamicShape()) {
+ auto diag = emitSilenceableError()
+ << "cannot compute parametric tile sizes for dynamically "
+ "shaped payload op";
+ diag.attachNote(target->getLoc()) << "payload op";
+ return diag;
+ }
+
+ FailureOr<StaticContinuousTileSizeSpecification> spec =
+ computeStaticContinuousTileSizes(target, getDimension(),
+ getTargetSize());
+ if (failed(spec)) {
+ return emitSilenceableError()
+ << "failed to compute multi-size tiling sizes";
+ }
+
+ SmallVector<int64_t> splitPoints;
+
+ auto tileSizeTripCountPairs =
+ llvm::zip_equal(spec->tileSizes, spec->tripCounts);
+
+ for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs))
----------------
muneebkhan85 wrote:
Fixed.
https://github.com/llvm/llvm-project/pull/82792
More information about the Mlir-commits mailing list