[Mlir-commits] [mlir] 9cc8e45 - [mlir][transform] Add notifyPayloadOperationReplaced to TransformRewriter

Matthias Springer llvmlistbot at llvm.org
Mon Jun 26 08:28:53 PDT 2023


Author: Matthias Springer
Date: 2023-06-26T17:28:45+02:00
New Revision: 9cc8e45898ed3acb0893c260872c1ee50d566738

URL: https://github.com/llvm/llvm-project/commit/9cc8e45898ed3acb0893c260872c1ee50d566738
DIFF: https://github.com/llvm/llvm-project/commit/9cc8e45898ed3acb0893c260872c1ee50d566738.diff

LOG: [mlir][transform] Add notifyPayloadOperationReplaced to TransformRewriter

This function allows users to update the mapping between handles and payload ops in cases where the update cannot be performed automatically by the rewriter/listener interface.

Differential Revision: https://reviews.llvm.org/D153764

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 9b45c15777040..d54f9c404fba4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -923,9 +923,12 @@ class TrackingListener : public RewriterBase::Listener,
   TransformOpInterface getTransformOp() const { return transformOp; }
 
 private:
+  friend class TransformRewriter;
+
   void notifyOperationRemoved(Operation *op) override;
 
   void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
+  using Listener::notifyOperationReplaced;
 
   /// The transform op in which this TrackingListener is used.
   TransformOpInterface transformOp;
@@ -981,6 +984,19 @@ class TransformRewriter : public RewriterBase {
   /// Silence all tracking failures that have been encountered so far.
   void silenceTrackingFailure();
 
+  /// Notify the transform dialect interpreter that the given op has been
+  /// replaced with another op and that the mapping between handles and payload
+  /// ops/values should be updated. This function should be called before the
+  /// original op is erased. It fails if the operation could not be replaced,
+  /// e.g., because the original operation is not tracked.
+  ///
+  /// Note: As long as IR modifications are performed through this rewriter,
+  /// the transform state is usually updated automatically. Use this function
+  /// when an unsupported rewriter API is used; e.g., when updating all uses of
+  /// a tracked operation one-by-one instead of using `RewriterBase::replaceOp`.
+  LogicalResult notifyPayloadOperationReplaced(Operation *op,
+                                               Operation *replacement);
+
 private:
   ErrorCheckingTrackingListener *const listener;
 };

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 042476390c280..2148624c2963d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -688,36 +688,6 @@ bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
   return true;
 }
 
-namespace {
-/// Unsafely exposes an internal protected method of TransformState::Extension
-/// as public.
-///
-/// MUST NOT be used directly.
-class UnsafeOpReplacementStateExtension : public TransformState::Extension {
-public:
-  UnsafeOpReplacementStateExtension(TransformState &state)
-      : TransformState::Extension(state) {}
-
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      UnsafeOpReplacementStateExtension)
-
-  LogicalResult doReplacePayloadOp(Operation *op, Operation *replacement) {
-    return replacePayloadOp(op, replacement);
-  }
-};
-} // namespace
-
-/// Replaces `payload` with `replacement` in all handles stored in the state.
-/// MUST NOT be used except for the case immediately below.
-static void forciblyReplaceReferencedPayloadOperation(TransformState &state,
-                                                      Operation *payload,
-                                                      Operation *replacement) {
-  UnsafeOpReplacementStateExtension extension(state);
-  // This may return failure if the payload is not associated with any handle,
-  // ignore that.
-  (void)extension.doReplacePayloadOp(payload, replacement);
-}
-
 DiagnosedSilenceableFailure
 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
@@ -787,6 +757,19 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
       LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
       fusedOps.append(tiledOps);
       if (newContainingOp) {
+        // Update handles associated with the containing op so we don't need to
+        // invalidate them. This is a hack to support better composability
+        // between tiling and fusion while a proper mechanism is being
+        // investigated.
+        //
+        // DO NOT replicate this elsewhere unless you understand what you are
+        // doing.
+        LogicalResult replacementStatus =
+            rewriter.notifyPayloadOperationReplaced(containingOp,
+                                                    newContainingOp);
+        (void)replacementStatus;
+        assert(succeeded(replacementStatus) &&
+               "unable to update transform state mapping");
         rewriter.eraseOp(containingOp);
         containingOp = newContainingOp;
       }
@@ -813,14 +796,6 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 
-  // Update handles associated with the containing op so we don't need to
-  // invalidate them. This is a hack to support better composability between
-  // tiling and fusion while a proper mechanism is being investigated.
-  //
-  // DO NOT replicate this elsewhere unless you understand what you are doing.
-  forciblyReplaceReferencedPayloadOperation(state, *containingOps.begin(),
-                                            containingOp);
-
   results.set(cast<OpResult>(getFusedOp()), fusedOps);
   results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
   return DiagnosedSilenceableFailure::success();

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 013d0e29ef5ce..27794e6636959 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1452,6 +1452,11 @@ void transform::TransformRewriter::silenceTrackingFailure() {
   }
 }
 
+LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
+    Operation *op, Operation *replacement) {
+  return listener->replacePayloadOp(op, replacement);
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for TransformEachOpTrait.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 3854cceb6273d..773405e61a4f8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s -verify-diagnostics | FileCheck %s
 
 #map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
 #map1 = affine_map<(d0)[s0] -> (d0 * s0)>
@@ -323,6 +323,7 @@ module {
 
     // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
     // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    // expected-remark @below{{new containing op}}
     %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
       // CHECK: %[[I0:.*]] = affine.apply {{.*}}
       %3 = affine.apply #map1(%i)[%idx]
@@ -350,8 +351,9 @@ module {
     %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
 
     // linalg.generic is tileable. The op is tiled and fused.
-    transform.structured.fuse_into_containing_op %0 into %1
+    %fused, %containing = transform.structured.fuse_into_containing_op %0 into %1
       : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+    test_print_remark_at_operand %containing, "new containing op" : !transform.any_op
   }
 }
 

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index bfe61df4e4043..4ab9d65e6475b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1710,3 +1710,20 @@ transform.sequence failures(propagate) {
   transform.annotate %0 "broadcast_attr" = %2 : !transform.any_op, !transform.param<i64>
   transform.annotate %0 "unit_attr" : !transform.any_op
 }
+
+// -----
+
+func.func @notify_payload_op_replaced(%arg0: index, %arg1: index) {
+  %0 = arith.muli %arg0, %arg1 {original} : index
+  // expected-remark @below{{updated handle}}
+  %1 = arith.muli %arg0, %arg1 {replacement} : index
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match attributes{original} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match attributes{replacement} in %arg1 : (!transform.any_op) -> !transform.any_op
+  test_notify_payload_op_replaced %0, %1 : (!transform.any_op, !transform.any_op) -> ()
+  test_print_remark_at_operand %0, "updated handle" : !transform.any_op
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 51c0615932b61..f3e80060f5cec 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -872,6 +872,34 @@ LogicalResult mlir::test::TestReEnterRegionOp::verify() {
   return success();
 }
 
+DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto originalOps = state.getPayloadOps(getOriginal());
+  auto replacementOps = state.getPayloadOps(getReplacement());
+  if (llvm::range_size(originalOps) != llvm::range_size(replacementOps))
+    return emitSilenceableError() << "expected same number of original and "
+                                     "replacement payload operations";
+  for (const auto &[original, replacement] :
+       llvm::zip(originalOps, replacementOps)) {
+    if (failed(
+            rewriter.notifyPayloadOperationReplaced(original, replacement))) {
+      auto diag = emitSilenceableError()
+                  << "unable to replace payload op in transform mapping";
+      diag.attachNote(original->getLoc()) << "original payload op";
+      diag.attachNote(replacement->getLoc()) << "replacement payload op";
+      return diag;
+    }
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getOriginal(), effects);
+  transform::onlyReadsHandle(getReplacement(), effects);
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 594c32d165d43..02f4955c30cc6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -564,4 +564,15 @@ def TestReEnterRegionOp
   let hasVerifier = 1;
 }
 
+def TestNotifyPayloadOpReplacedOp
+  : Op<Transform_Dialect, "test_notify_payload_op_replaced",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins TransformHandleTypeInterface:$original,
+                       TransformHandleTypeInterface:$replacement);
+  let results = (outs);
+  let assemblyFormat = "$original `,` $replacement attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD


        


More information about the Mlir-commits mailing list