[Mlir-commits] [mlir] 9437bf4 - [mlir][Linalg][Transform] Fix effect on RewriteInDestinationPassingStyleOp that did not consume its operand

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 21 02:17:51 PDT 2023


Author: Nicolas Vasilache
Date: 2023-03-21T02:17:45-07:00
New Revision: 9437bf418a7fdb9a1079f416dd28bb7107161d74

URL: https://github.com/llvm/llvm-project/commit/9437bf418a7fdb9a1079f416dd28bb7107161d74
DIFF: https://github.com/llvm/llvm-project/commit/9437bf418a7fdb9a1079f416dd28bb7107161d74.diff

LOG: [mlir][Linalg][Transform] Fix effect on RewriteInDestinationPassingStyleOp that did not consume its operand

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 712abf341f460..c58e955cb7951 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -83,8 +83,10 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
 //===----------------------------------------------------------------------===//
 
 def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformOpInterface, TransformEachOpTrait]> {
+    [FunctionalStyleTransformOpTrait, 
+     MemoryEffectsOpInterface,
+     TransformOpInterface, 
+     TransformEachOpTrait]> {
   let description = [{
     Decomposes named complex operations, such as higher-dimensional
     (depthwise) convolutions, into combinations of lower-dimensional equivalents
@@ -932,9 +934,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
 
 def RewriteInDestinationPassingStyleOp : Op<
     Transform_Dialect, "structured.rewrite_in_destination_passing_style",
-    [MemoryEffectsOpInterface,
-     NavigationTransformOpTrait,
-     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+    [FunctionalStyleTransformOpTrait, 
+     MemoryEffectsOpInterface,
+     TransformOpInterface, 
+     TransformEachOpTrait]> {
   let description = [{
     Rewrite a supported tensor operation that is not in destination-passing style
     into a form that is in destination-passing style.
@@ -963,6 +966,13 @@ def RewriteInDestinationPassingStyleOp : Op<
     $target attr-dict
     `:` functional-type($target, results)
   }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 407b8d213de1c..d98eb3b781fc5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2000,24 +2000,21 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::RewriteInDestinationPassingStyleOp::apply(
-    transform::TransformResults &results, transform::TransformState &state) {
+transform::RewriteInDestinationPassingStyleOp::applyToOne(
+    Operation *target, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
   SmallVector<Operation *> res;
-  ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
-  for (Operation *target : targetOps) {
-    IRRewriter rewriter(target->getContext());
-    rewriter.setInsertionPoint(target);
-    FailureOr<Operation *> maybeResult =
-        TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-            .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
-                [&rewriter](auto op) {
-                  return rewriteInDestinationPassingStyle(rewriter, op);
-                });
-    if (failed(maybeResult))
-      return emitDefaultSilenceableFailure(target);
-    res.push_back(*maybeResult);
-  }
-  results.set(getResult().cast<OpResult>(), res);
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<Operation *> maybeResult =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+              [&rewriter](auto op) {
+                return rewriteInDestinationPassingStyle(rewriter, op);
+              });
+  if (failed(maybeResult))
+    return emitDefaultSilenceableFailure(target);
+  results.push_back(*maybeResult);
   return DiagnosedSilenceableFailure::success();
 }
 


        


More information about the Mlir-commits mailing list