[Mlir-commits] [mlir] [MLIR] Add continuous tiling to TileUsingForOp (PR #82792)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 30 07:09:02 PDT 2024
================
@@ -2286,103 +2299,168 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
return OpFoldResult(op->getResult(0));
}));
} else {
- splitPoints = llvm::to_vector(
- llvm::map_range(state.getParams(getDynamicSplitPoint()),
+ chunkSizes = llvm::to_vector(
+ llvm::map_range(state.getParams(getDynamicChunkSizes()),
[](Attribute attr) { return OpFoldResult(attr); }));
}
if (diag.isSilenceableFailure())
return diag;
- if (splitPoints.size() != payload.size()) {
+ // For multiway split, a single payload is expected to have multiple
+ // split points.
+ if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
return emitDefiniteFailure()
<< "expected the dynamic split point handle to point to as "
"many operations ("
- << splitPoints.size() << ") as the target handle ("
+ << chunkSizes.size() << ") as the target handle ("
<< payload.size() << ")";
}
} else {
- splitPoints.resize(payload.size(),
- rewriter.getIndexAttr(getStaticSplitPoint()));
+ chunkSizes.resize(payload.size(),
+ rewriter.getIndexAttr(getStaticChunkSizes()));
}
- // Split each target operation.
- SmallVector<Operation *> first, second;
- Operation *noSecondPart = nullptr;
- for (const auto &pair : llvm::zip(payload, splitPoints)) {
- Operation *target = std::get<0>(pair);
- auto linalgOp = dyn_cast<LinalgOp>(target);
+ auto checkStructuredOpAndDimensions =
+ [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
if (!linalgOp) {
auto diag = emitSilenceableError() << "only applies to structured ops";
- diag.attachNote(target->getLoc()) << "target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
if (getDimension() >= linalgOp.getNumLoops()) {
auto diag = emitSilenceableError() << "dimension " << getDimension()
- << " does not exist in target op";
- diag.attachNote(target->getLoc()) << "target op";
+ << " does not exist in target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
+ return DiagnosedSilenceableFailure::success();
+ };
- rewriter.setInsertionPoint(linalgOp);
- std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
- rewriter, cast<TilingInterface>(linalgOp.getOperation()),
- getDimension(), std::get<1>(pair));
-
- // Propagate errors.
- if (!first.back() && !second.back()) {
+ auto checkFailureInSplitting =
+ [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
+ if (hasFailed) {
auto diag = emitDefiniteFailure() << "internal failure in splitting";
- diag.attachNote(target->getLoc()) << "target op";
+ diag.attachNote(loc) << "target op";
return diag;
}
+ return DiagnosedSilenceableFailure::success();
+ };
+
+ if (isMultiwaySplit) {
- // Do not add null second parts.
- if (!second.back()) {
- noSecondPart = target;
- second.pop_back();
+ // Split a single target operation at multiple points.
+ SmallVector<Operation *> opList;
+ Operation *head, *tail;
+ Operation *target = payload.front();
+
+ auto linalgOp = dyn_cast<LinalgOp>(target);
+ auto diag = checkStructuredOpAndDimensions(linalgOp, target->getLoc());
+
+ if (diag.isSilenceableFailure())
+ return diag;
+
+ for (const auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
----------------
muneebkhan85 wrote:
Fixed.
https://github.com/llvm/llvm-project/pull/82792
More information about the Mlir-commits
mailing list