[Mlir-commits] [mlir] 9437bf4 - [mlir][Linalg][Transform] Fix effect on RewriteInDestinationPassingStyleOp that did not consume its operand
Nicolas Vasilache
llvmlistbot at llvm.org
Tue Mar 21 02:17:51 PDT 2023
Author: Nicolas Vasilache
Date: 2023-03-21T02:17:45-07:00
New Revision: 9437bf418a7fdb9a1079f416dd28bb7107161d74
URL: https://github.com/llvm/llvm-project/commit/9437bf418a7fdb9a1079f416dd28bb7107161d74
DIFF: https://github.com/llvm/llvm-project/commit/9437bf418a7fdb9a1079f416dd28bb7107161d74.diff
LOG: [mlir][Linalg][Transform] Fix effect on RewriteInDestinationPassingStyleOp that did not consume its operand
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 712abf341f460..c58e955cb7951 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -83,8 +83,10 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
//===----------------------------------------------------------------------===//
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
- [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait]> {
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait]> {
let description = [{
Decomposes named complex operations, such as higher-dimensional
(depthwise) convolutions, into combinations of lower-dimensional equivalents
@@ -932,9 +934,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
def RewriteInDestinationPassingStyleOp : Op<
Transform_Dialect, "structured.rewrite_in_destination_passing_style",
- [MemoryEffectsOpInterface,
- NavigationTransformOpTrait,
- DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait]> {
let description = [{
Rewrite a supported tensor operation that is not in destination-passing style
into a form that is in destination-passing style.
@@ -963,6 +966,13 @@ def RewriteInDestinationPassingStyleOp : Op<
$target attr-dict
`:` functional-type($target, results)
}];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 407b8d213de1c..d98eb3b781fc5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2000,24 +2000,28 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
//===----------------------------------------------------------------------===//
+/// Rewrites a single payload op into destination-passing style and records
+/// the op produced by the rewrite in `results`. Only tensor.from_elements,
+/// tensor.generate and tensor.pad are supported; any other op, or a failed
+/// rewrite, yields a silenceable failure.
DiagnosedSilenceableFailure
-transform::RewriteInDestinationPassingStyleOp::apply(
- transform::TransformResults &results, transform::TransformState &state) {
+transform::RewriteInDestinationPassingStyleOp::applyToOne(
+ Operation *target, transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
- SmallVector<Operation *> res;
- ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
- for (Operation *target : targetOps) {
- IRRewriter rewriter(target->getContext());
- rewriter.setInsertionPoint(target);
- FailureOr<Operation *> maybeResult =
- TypeSwitch<Operation *, FailureOr<Operation *>>(target)
- .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
- [&rewriter](auto op) {
- return rewriteInDestinationPassingStyle(rewriter, op);
- });
- if (failed(maybeResult))
- return emitDefaultSilenceableFailure(target);
- res.push_back(*maybeResult);
- }
- results.set(getResult().cast<OpResult>(), res);
+ IRRewriter rewriter(target->getContext());
+ rewriter.setInsertionPoint(target);
+ FailureOr<Operation *> maybeResult =
+ TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+ .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+ [&rewriter](auto op) {
+ return rewriteInDestinationPassingStyle(rewriter, op);
+ })
+ .Default([](Operation *) -> FailureOr<Operation *> {
+ // Without a default, an unsupported op would hit TypeSwitch's
+ // "fell off the end" assertion (UB in release builds).
+ return failure();
+ });
+ // Unsupported op or failed rewrite: report a silenceable failure.
+ if (failed(maybeResult))
+ return emitDefaultSilenceableFailure(target);
+ results.push_back(*maybeResult);
return DiagnosedSilenceableFailure::success();
}
More information about the Mlir-commits mailing list