[Mlir-commits] [mlir] acc2a12 - [mlir][Linalg] Expose the implementation of the tiling to scf.foreach_thread.
Mahesh Ravishankar
llvmlistbot at llvm.org
Thu Sep 22 15:19:49 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-09-22T22:19:19Z
New Revision: acc2a12c3318aad2ebdc161a0a7f4c2fec52e18d
URL: https://github.com/llvm/llvm-project/commit/acc2a12c3318aad2ebdc161a0a7f4c2fec52e18d
DIFF: https://github.com/llvm/llvm-project/commit/acc2a12c3318aad2ebdc161a0a7f4c2fec52e18d.diff
LOG: [mlir][Linalg] Expose the implementation of the tiling to scf.foreach_thread.
This allows downstream users to use the implementation of the tiling
itself, while performing other transformations that are necessary to
go with it.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D134335
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 45b0600adcbd..f7952db7e2a2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -15,6 +15,7 @@
namespace mlir {
class TilingInterface;
+class RewriterBase;
namespace linalg {
class GenericOp;
class LinalgOp;
@@ -33,6 +34,17 @@ class LinalgOp;
namespace mlir {
class DialectRegistry;
+namespace transform {
+
+/// Implementation of tiling operations using `scf.foreach_thread`.
+DiagnosedSilenceableFailure tileToForeachThreadOpImpl(
+ RewriterBase &rewriter, transform::TransformState &state,
+ TransformOpInterface transformOp, ArrayRef<Operation *> targets,
+ ArrayRef<OpFoldResult> mixedNumThreads,
+ ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
+ SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps);
+} // namespace transform
+
namespace linalg {
void registerTransformDialectExtension(DialectRegistry &registry);
} // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 108570f53843..ca3c932c77cc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1339,22 +1339,15 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
// TileToForeachThreadOp
//===----------------------------------------------------------------------===//
-DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
- transform::TransformResults &transformResults,
- transform::TransformState &state) {
- IRRewriter rewriter(getContext());
- ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
-
- // If there the target payload ops are empty, there is nothing to do.
- if (targets.empty()) {
- transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
- transformResults.set(getTiledOp().cast<OpResult>(), {});
+DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
+ RewriterBase &rewriter, transform::TransformState &state,
+ TransformOpInterface transformOp, ArrayRef<Operation *> targets,
+ ArrayRef<OpFoldResult> mixedNumThreads,
+ ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
+ SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
+
+ if (targets.empty())
return DiagnosedSilenceableFailure(success());
- }
-
- // Result payload ops.
- SmallVector<Operation *> tileOps;
- SmallVector<Operation *> tiledOps;
// Given a list of OpFoldResults that are either index attrs or op handles,
// return a list of OpFoldResults where all op handles are replaced with the
@@ -1372,7 +1365,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
state.getPayloadOps(ofr.get<Value>());
if (dynamicNumThreads.size() != 1) {
DiagnosedSilenceableFailure diag =
- emitSilenceableError()
+ transformOp.emitSilenceableError()
<< "handle must be mapped to exactly 1 payload op";
diag.attachNote(ofr.get<Value>().getLoc())
<< "mapped to " << dynamicNumThreads.size() << " ops";
@@ -1382,7 +1375,7 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
if (op->getNumResults() != 1 ||
!op->getResult(0).getType().isIndex()) {
DiagnosedSilenceableFailure diag =
- emitSilenceableError()
+ transformOp.emitSilenceableError()
<< "payload op must have exactly 1 index result";
diag.attachNote(op->getLoc())
<< "has " << op->getNumResults() << " results";
@@ -1398,14 +1391,14 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
// Convert to OpFoldResults[index attributes or payload op].
SmallVector<OpFoldResult> numThreads;
DiagnosedSilenceableFailure status =
- getOpResultsOrIndexAttrs(numThreads, getMixedNumThreads());
+ getOpResultsOrIndexAttrs(numThreads, mixedNumThreads);
if (!status.succeeded())
return status;
// getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
// Convert to OpFoldResults[index attributes or payload op].
SmallVector<OpFoldResult> tileSizes;
- status = getOpResultsOrIndexAttrs(tileSizes, getMixedTileSizes());
+ status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes);
if (!status.succeeded())
return status;
@@ -1414,19 +1407,20 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
auto tilableOp = dyn_cast<TilingInterface>(target);
if (!tilableOp) {
DiagnosedSilenceableFailure diag =
- emitSilenceableError() << "only TilingInterface ops are supported";
+ transformOp.emitSilenceableError()
+ << "only TilingInterface ops are supported";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
rewriter.setInsertionPoint(tilableOp);
- auto maybeThreadDimMappingAttr = getThreadDimMapping();
+ auto maybeThreadDimMappingAttr = threadDimMapping;
auto dimMapping = llvm::to_vector(
maybeThreadDimMappingAttr
? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
: ArrayRef<int64_t>{});
- FailureOr<ForeachThreadTilingResult> tilingResult = failure();
- if (!getMixedNumThreads().empty()) {
+ FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
+ if (!mixedNumThreads.empty()) {
tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
numThreads, dimMapping);
} else {
@@ -1435,12 +1429,32 @@ DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
}
if (failed(tilingResult))
- return emitDefaultSilenceableFailure(tilableOp);
+ return transformOp.emitDefaultSilenceableFailure(tilableOp);
rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
tileOps.push_back(tilingResult->tileOp);
tiledOps.push_back(tilingResult->tiledOp);
}
+ return DiagnosedSilenceableFailure(success());
+}
+
+DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
+ transform::TransformResults &transformResults,
+ transform::TransformState &state) {
+ IRRewriter rewriter(getContext());
+ ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+
+ // Result payload ops.
+ SmallVector<Operation *> tileOps;
+ SmallVector<Operation *> tiledOps;
+
+ DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl(
+ rewriter, state, cast<TransformOpInterface>(getOperation()), targets,
+ getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps,
+ tiledOps);
+
+ if (!diag.succeeded())
+ return diag;
transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
More information about the Mlir-commits
mailing list