[Mlir-commits] [mlir] d61e68f - [mlir][linalg][transform] linalg::tileToForallOpImpl processes only one target
Matthias Springer
llvmlistbot at llvm.org
Fri May 5 04:35:19 PDT 2023
Author: Matthias Springer
Date: 2023-05-05T20:35:10+09:00
New Revision: d61e68ff0e0c8ad91cdbab8bb25cdd40930afb6f
URL: https://github.com/llvm/llvm-project/commit/d61e68ff0e0c8ad91cdbab8bb25cdd40930afb6f
DIFF: https://github.com/llvm/llvm-project/commit/d61e68ff0e0c8ad91cdbab8bb25cdd40930afb6f.diff
LOG: [mlir][linalg][transform] linalg::tileToForallOpImpl processes only one target
`tileToForallOpImpl` takes only one target instead of a list of all targets.
This is in preparation for D149847.
Differential Revision: https://reviews.llvm.org/D149929
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 775377e5e5bce..95133e296c757 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -23,6 +23,7 @@ class TilingInterface;
class RewriterBase;
namespace linalg {
+struct ForallTilingResult;
class GenericOp;
class LinalgOp;
} // namespace linalg
@@ -48,12 +49,13 @@ class DialectRegistry;
namespace transform {
/// Implementation of tiling operations using `scf.forall`.
-DiagnosedSilenceableFailure tileToForallOpImpl(
- RewriterBase &rewriter, transform::TransformState &state,
- TransformOpInterface transformOp, ArrayRef<Operation *> targets,
- ArrayRef<OpFoldResult> mixedNumThreads,
- ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
- SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
+DiagnosedSilenceableFailure
+tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
+ TransformOpInterface transformOp, Operation *target,
+ ArrayRef<OpFoldResult> mixedNumThreads,
+ ArrayRef<OpFoldResult> mixedTileSizes,
+ std::optional<ArrayAttr> mapping,
+ linalg::ForallTilingResult &tilingResult);
} // namespace transform
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9c7236e21842a..74eb3a2df0f96 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2506,15 +2506,11 @@ void transform::TileToForallOp::build(OpBuilder &builder,
DiagnosedSilenceableFailure transform::tileToForallOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
- TransformOpInterface transformOp, ArrayRef<Operation *> targets,
+ TransformOpInterface transformOp, Operation *target,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
- SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
- if (targets.empty())
- return DiagnosedSilenceableFailure::success();
-
+ linalg::ForallTilingResult &tilingResult) {
// Transform all targets one by one.
- for (Operation *target : targets) {
auto tileableOp = dyn_cast<TilingInterface>(target);
if (!tileableOp) {
DiagnosedSilenceableFailure diag =
@@ -2524,23 +2520,21 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
return diag;
}
rewriter.setInsertionPoint(tileableOp);
- FailureOr<linalg::ForallTilingResult> tilingResult = failure();
+ FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
if (!mixedNumThreads.empty()) {
- tilingResult = linalg::tileToForallOp(rewriter, tileableOp,
- mixedNumThreads, mapping);
+ maybeTilingResult = linalg::tileToForallOp(rewriter, tileableOp,
+ mixedNumThreads, mapping);
} else {
- tilingResult = linalg::tileToForallOpUsingTileSizes(
+ maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
rewriter, tileableOp, mixedTileSizes, mapping);
}
- if (failed(tilingResult))
+ if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
- rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults());
+ rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
- tileOps.push_back(tilingResult->tileOp);
- tiledOps.push_back(tilingResult->tiledOp);
- }
- return DiagnosedSilenceableFailure::success();
+ tilingResult = *maybeTilingResult;
+ return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure
@@ -2577,12 +2571,16 @@ transform::TileToForallOp::apply(transform::TransformResults &transformResults,
if (!status.succeeded())
return status;
- DiagnosedSilenceableFailure diag =
- tileToForallOpImpl(rewriter, state, transformOp, targets, mixedNumThreads,
- mixedTileSizes, getMapping(), tileOps, tiledOps);
-
- if (!diag.succeeded())
- return diag;
+ for (Operation *target : targets) {
+ linalg::ForallTilingResult tilingResult;
+ DiagnosedSilenceableFailure diag = tileToForallOpImpl(
+ rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
+ getMapping(), tilingResult);
+ if (!diag.succeeded())
+ return diag;
+ tileOps.push_back(tilingResult.tileOp);
+ tiledOps.push_back(tilingResult.tiledOp);
+ }
transformResults.set(getForallOp().cast<OpResult>(), tileOps);
transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
More information about the Mlir-commits mailing list