[llvm-branch-commits] [mlir] [mlir][Transform] `apply_conversion_patterns`: Update handles (PR #83950)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Mar 4 19:05:25 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83950

Until now, `transform.apply_conversion_patterns` consumed the target handle and could therefore invalidate other handles as well. This commit adds tracking functionality similar to that of `transform.apply_patterns`, so that handles are no longer invalidated but are instead updated based on the op replacements performed by the dialect conversion.
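
As a rough mental model of the difference, the standalone C++ sketch below contrasts consuming a handle with updating handles on replacement. It does not use MLIR: `Op`, `TransformStateModel`, `consume` and `notifyOperationReplaced` are hypothetical stand-ins (the last one only borrows its name from the `RewriterBase::Listener` callback). The invalidation rule it assumes is the transform dialect's usual one: consuming a handle also invalidates handles to the same or nested payload ops.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a payload op; `parent` models nesting.
struct Op {
  std::string name;
  Op *parent = nullptr;
};

// Returns true if `op` is `ancestor` itself or is nested inside it.
bool isNestedIn(const Op *op, const Op *ancestor) {
  for (; op != nullptr; op = op->parent)
    if (op == ancestor)
      return true;
  return false;
}

// Hypothetical model of the transform state: handle name -> payload ops.
struct TransformStateModel {
  std::map<std::string, std::vector<Op *>> handles;

  // Consuming behavior (sketch): the consumed handle and every handle that
  // points to the same or nested payload ops are dropped.
  void consume(const std::string &handle) {
    const std::vector<Op *> targets = handles.at(handle);
    for (auto it = handles.begin(); it != handles.end();) {
      bool affected = std::any_of(
          it->second.begin(), it->second.end(), [&](const Op *op) {
            return std::any_of(targets.begin(), targets.end(),
                               [&](const Op *t) { return isNestedIn(op, t); });
          });
      it = affected ? handles.erase(it) : std::next(it);
    }
  }

  // Tracking behavior (sketch): a replacement notification swaps the old op
  // for the new one in every handle, so all handles stay usable.
  void notifyOperationReplaced(Op *oldOp, Op *newOp) {
    for (auto &entry : handles)
      std::replace(entry.second.begin(), entry.second.end(), oldOp, newOp);
  }
};
```

In this model, after `consume("%0")` on a handle to a function, a handle such as `%1` that points to an op nested in that function is gone; after `notifyOperationReplaced`, the same handle simply refers to the replacement op. The latter is the situation exercised by the `@dialect_conversion_tracking` test in the patch below.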

For now, this new functionality is opt-in: it is enabled only when the `preserve_handles` attribute is set.
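
The attribute acts as a plain on/off switch. A minimal sketch of what it toggles, based on the `getEffects` and `apply` changes in the patch, might look like the following. `Effect`, `ConversionSetup` and `configureConversion` are hypothetical names used only for illustration; they are not MLIR API.

```cpp
#include <cassert>
#include <vector>

// Hypothetical effect kinds; they only mirror the roles of
// transform::consumesHandle, transform::onlyReadsHandle and
// transform::modifiesPayload and are not the MLIR API.
enum class Effect { ConsumesHandle, OnlyReadsHandle, ModifiesPayload };

// What the op sets up for a given value of `preserve_handles` (sketch).
struct ConversionSetup {
  std::vector<Effect> effects;
  bool trackingListenerAttached = false;
};

ConversionSetup configureConversion(bool preserveHandles) {
  ConversionSetup setup;
  // Without the attribute the target handle is consumed (old behavior);
  // with it the handle is only read and kept up to date by the listener.
  setup.effects.push_back(preserveHandles ? Effect::OnlyReadsHandle
                                          : Effect::ConsumesHandle);
  // The payload is modified in both modes.
  setup.effects.push_back(Effect::ModifiesPayload);
  // The listener is only hooked into the conversion when handles are kept.
  setup.trackingListenerAttached = preserveHandles;
  return setup;
}
```

Keeping the consuming behavior as the default means existing transform scripts see no change unless they opt in.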

From 270ade8d830bbed2ebee448c03b82b058707f6ff Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 5 Mar 2024 03:01:27 +0000
Subject: [PATCH] [mlir][Transform] `apply_conversion_patterns`: Update handles

Until now, `transform.apply_conversion_patterns` consumed the target handle and could therefore invalidate other handles as well. This commit adds tracking functionality similar to that of `transform.apply_patterns`, so that handles are no longer invalidated but are instead updated based on the op replacements performed by the dialect conversion.

For now, this new functionality is opt-in: it is enabled only when the `preserve_handles` attribute is set.
---
 .../Transform/IR/TransformInterfaces.h        | 32 +++++++++----
 .../mlir/Dialect/Transform/IR/TransformOps.td | 18 ++++++--
 .../Transform/IR/TransformInterfaces.cpp      | 39 ++++++++--------
 .../lib/Dialect/Transform/IR/TransformOps.cpp | 45 ++++++++++++++++---
 .../Transform/test-pattern-application.mlir   | 30 +++++++++++++
 5 files changed, 129 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 313cdc27f780a7..32724ff4b98e8e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -921,20 +921,36 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
   return RegionScope(*this, region);
 }
 
+/// A configuration object for customizing a `TrackingListener`.
+struct TrackingListenerConfig {
+  using SkipHandleFn = std::function<bool(Value)>;
+
+  /// An optional function that returns "true" for handles that do not have to
+  /// be updated. These are typically dead or consumed handles.
+  SkipHandleFn skipHandleFn = nullptr;
+
+  /// If set to "true", the name of a replacement op must match the name of the
+  /// original op. If set to "false", the names of the payload ops tracked in a
+  /// handle may change as the tracking listener updates the transform state.
+  bool requireMatchingReplacementOpName = true;
+
+  /// If set to "true", cast ops (that implement the CastOpInterface) are
+  /// skipped and the replacement op search continues with the operands of the
+  /// cast op.
+  bool skipCastOps = true;
+};
+
 /// A listener that updates a TransformState based on IR modifications. This
 /// listener can be used during a greedy pattern rewrite to keep the transform
 /// state up-to-date.
 class TrackingListener : public RewriterBase::Listener,
                          public TransformState::Extension {
 public:
-  /// A function that returns "true" for handles that do not have to be updated.
-  using SkipHandleFn = std::function<bool(Value)>;
-
   /// Create a new TrackingListener for usage in the specified transform op.
   /// Optionally, a function can be specified to identify handles that should
   /// do not have to be updated.
   TrackingListener(TransformState &state, TransformOpInterface op,
-                   SkipHandleFn skipHandleFn = nullptr);
+                   TrackingListenerConfig config = TrackingListenerConfig());
 
 protected:
   /// Return a replacement payload op for the given op, which is going to be
@@ -959,7 +975,8 @@ class TrackingListener : public RewriterBase::Listener,
   /// same computation; e.g., there may be tiled "linalg.generic" inside the
   /// loop body that represents the original computation. Therefore, the
   /// TrackingListener is conservative by default: it drops the mapping and
-  /// triggers the "payload replacement not found" notification.
+  /// triggers the "payload replacement not found" notification. This default
+  /// behavior can be customized in `TrackingListenerConfig`.
   ///
   /// If no replacement op could be found according to the rules mentioned
   /// above, this function tries to skip over cast-like ops that implement
@@ -1023,9 +1040,8 @@ class TrackingListener : public RewriterBase::Listener,
   /// The handles that are consumed by the transform op.
   DenseSet<Value> consumedHandles;
 
-  /// Handles for which this function evaluates to "true" do not have to be
-  /// updated. These are typically dead or consumed handles.
-  SkipHandleFn skipHandleFn;
+  /// Tracking listener configuration.
+  TrackingListenerConfig config;
 };
 
 /// A specialized listener that keeps track of cases in which no replacement
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9f513822ed0a4e..0e42d12a69a400 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -190,11 +190,20 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
     The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
     attributes specify the conversion target.
 
-    This transform consumes the `target` handle and modifies the payload. It
-    does not produce any handles.
+    This transform modifies the payload. By default, it consumes the `target`
+    handle. It does not produce any handles.
+
+    If the `preserve_handles` attribute is set, this transform does not consume
+    the `target` handle and instead updates handles based on notifications from
+    a tracking listener that is attached to the dialect conversion, similar to
+    `transform.apply_patterns`. Only replacements via `RewriterBase::replaceOp`
+    or `replaceOpWithNewOp` are considered "payload op replacements". In
+    contrast to `transform.apply_patterns`, we allow replacement ops even if the
+    op name has changed. More details can be found at the documentation site of
+    `TrackingListener`.
 
     This transform produces a silenceable failure if the dialect conversion was
-    unsuccessful.
+    unsuccessful or the tracking listener failed to find a replacement op.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
@@ -202,7 +211,8 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
                        OptionalAttr<StrArrayAttr>:$illegal_ops,
                        OptionalAttr<StrArrayAttr>:$legal_dialects,
                        OptionalAttr<StrArrayAttr>:$illegal_dialects,
-                       UnitAttr:$partial_conversion);
+                       UnitAttr:$partial_conversion,
+                       UnitAttr:$preserve_handles);
   let results = (outs);
   let regions = (region
       MaxSizedRegion<1>:$patterns,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index bb9f6fec452986..71a9d61198e3fb 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -918,7 +918,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   }
 
   // Prepare rewriter and listener.
-  TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
+  TrackingListenerConfig config;
+  config.skipHandleFn = [&](Value handle) {
     // Skip handle if it is dead.
     auto scopeIt =
         llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
@@ -935,7 +936,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
     return true;
   };
   transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
-                                                            skipHandleFn);
+                                                            config);
   transform::TransformRewriter rewriter(transform->getContext(),
                                         &trackingListener);
 
@@ -1184,9 +1185,8 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
 
 transform::TrackingListener::TrackingListener(TransformState &state,
                                               TransformOpInterface op,
-                                              SkipHandleFn skipHandleFn)
-    : TransformState::Extension(state), transformOp(op),
-      skipHandleFn(skipHandleFn) {
+                                              TrackingListenerConfig config)
+    : TransformState::Extension(state), transformOp(op), config(config) {
   if (op) {
     for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
       consumedHandles.insert(opOperand->get());
@@ -1228,8 +1228,19 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
       return diag;
     }
 
-    // If the defining op has the same type, we take it as a replacement.
-    if (op->getName() == defOp->getName()) {
+    // Skip through ops that implement CastOpInterface.
+    if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
+      values.clear();
+      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
+      diag.attachNote(defOp->getLoc())
+          << "using output of 'CastOpInterface' op";
+      continue;
+    }
+
+    // If the defining op has the same name or we do not care about the name of
+    // op replacements at all, we take it as a replacement.
+    if (!config.requireMatchingReplacementOpName ||
+        op->getName() == defOp->getName()) {
       result = defOp;
       return DiagnosedSilenceableFailure::success();
     }
@@ -1251,14 +1262,6 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
                                           "'FindPayloadReplacementOpInterface'";
       continue;
     }
-
-    // Skip through ops that implement CastOpInterface.
-    if (isa<CastOpInterface>(defOp)) {
-      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
-      diag.attachNote(defOp->getLoc())
-          << "using output of 'CastOpInterface' op";
-      continue;
-    }
   } while (!values.empty());
 
   diag.attachNote() << "ran out of suitable replacement values";
@@ -1318,9 +1321,9 @@ void transform::TrackingListener::notifyOperationReplaced(
 
   // Check if there are any handles that must be updated.
   Value aliveHandle;
-  if (skipHandleFn) {
-    auto it =
-        llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
+  if (config.skipHandleFn) {
+    auto it = llvm::find_if(opHandles,
+                            [&](Value v) { return !config.skipHandleFn(v); });
     if (it != opHandles.end())
       aliveHandle = *it;
   } else if (!opHandles.empty()) {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 180d11c30e65de..ca80899ab07341 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -563,6 +563,17 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
     }
   }
 
+  // Attach a tracking listener if handles should be preserved. We configure the
+  // listener to allow op replacements with different names, as conversion
+  // patterns typically replace ops with replacement ops that have a different
+  // name.
+  TrackingListenerConfig trackingConfig;
+  trackingConfig.requireMatchingReplacementOpName = false;
+  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
+  ConversionConfig conversionConfig;
+  if (getPreserveHandles())
+    conversionConfig.listener = &trackingListener;
+
   FrozenRewritePatternSet frozenPatterns(std::move(patterns));
   for (Operation *target : state.getPayloadOps(getTarget())) {
     // Make sure that this transform is not applied to itself. Modifying the
@@ -574,16 +585,36 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
 
     LogicalResult status = failure();
     if (getPartialConversion()) {
-      status = applyPartialConversion(target, conversionTarget, frozenPatterns);
+      status = applyPartialConversion(target, conversionTarget, frozenPatterns,
+                                      conversionConfig);
     } else {
-      status = applyFullConversion(target, conversionTarget, frozenPatterns);
+      status = applyFullConversion(target, conversionTarget, frozenPatterns,
+                                   conversionConfig);
     }
 
+    // Check dialect conversion state.
+    DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
     if (failed(status)) {
-      auto diag = emitSilenceableError() << "dialect conversion failed";
+      diag = emitSilenceableError() << "dialect conversion failed";
       diag.attachNote(target->getLoc()) << "target op";
-      return diag;
     }
+
+    // Check tracking listener error state.
+    DiagnosedSilenceableFailure trackingFailure =
+        trackingListener.checkAndResetError();
+    if (!trackingFailure.succeeded()) {
+      if (diag.succeeded()) {
+        // Tracking failure is the only failure.
+        return trackingFailure;
+      } else {
+        diag.attachNote() << "tracking listener also failed: "
+                          << trackingFailure.getMessage();
+        (void)trackingFailure.silence();
+      }
+    }
+
+    if (!diag.succeeded())
+      return diag;
   }
 
   return DiagnosedSilenceableFailure::success();
@@ -632,7 +663,11 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
 
 void transform::ApplyConversionPatternsOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::consumesHandle(getTarget(), effects);
+  if (!getPreserveHandles()) {
+    transform::consumesHandle(getTarget(), effects);
+  } else {
+    transform::onlyReadsHandle(getTarget(), effects);
+  }
   transform::modifiesPayload(effects);
 }
 
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 0c41e81b17b522..fa8a555af92188 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -417,3 +417,33 @@ module attributes { transform.with_named_sequence } {
     transform.yield
   }
 }
+
+// -----
+
+// "test.foo" is tracked and replaced with "test.new_op" during a dialect
+// conversion. Make sure that the handle is updated accordingly.
+
+// CHECK-LABEL: func @dialect_conversion_tracking
+//  CHECK-NEXT:   %[[m:.*]] = "test.new_op"() {annotated} : () -> memref<5xf32>
+//  CHECK-NEXT:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32>
+//  CHECK-NEXT:   return %[[cast]]
+func.func @dialect_conversion_tracking() -> tensor<5xf32> {
+  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["test.foo"]} in %0 : (!transform.any_op) -> !transform.any_op
+    transform.apply_conversion_patterns to %0 {
+      transform.apply_conversion_patterns.transform.test_conversion_patterns
+    } with type_converter {
+      transform.apply_conversion_patterns.transform.test_type_converter
+    } {legal_ops = ["func.func", "func.return", "test.new_op"], preserve_handles}
+        : !transform.any_op
+    // Add an attribute to %1, which is now mapped to a new op.
+    transform.annotate %1 "annotated" : !transform.any_op
+    transform.yield
+  }
+}
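
The reordering in `findReplacementOp` is easiest to see in isolation. The standalone C++ sketch below is a simplified model, not the MLIR implementation: `Op`, `Value`, `TrackingConfigModel` and the `isCast` flag (standing in for `CastOpInterface`) are hypothetical, and the `FindPayloadReplacementOpInterface` step is omitted. It illustrates one consequence of the new order: with `requireMatchingReplacementOpName = false`, cast ops must be looked through before the name check, otherwise the cast itself would be accepted as the replacement.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct Op;

// Hypothetical SSA value: only records its defining op (nullptr would model
// a block argument).
struct Value {
  Op *definingOp;
};

// Hypothetical payload op; `isCast` stands in for implementing
// CastOpInterface.
struct Op {
  std::string name;
  bool isCast = false;
  std::vector<Value> operands;
};

// Mirrors the two new TrackingListenerConfig flags (sketch only).
struct TrackingConfigModel {
  bool requireMatchingReplacementOpName = true;
  bool skipCastOps = true;
};

// Returns the single op defining all `values`, or nullptr if they come from
// different ops or any value has no defining op.
Op *commonDefiningOp(const std::vector<Value> &values) {
  Op *def = nullptr;
  for (const Value &v : values) {
    if (v.definingOp == nullptr || (def != nullptr && def != v.definingOp))
      return nullptr;
    def = v.definingOp;
  }
  return def;
}

// Simplified replacement search in the order used after this patch: look
// through casts first, then apply the (optional) name check.
Op *findReplacementOp(const Op &original, std::vector<Value> values,
                      const TrackingConfigModel &config) {
  while (!values.empty()) {
    Op *defOp = commonDefiningOp(values);
    if (defOp == nullptr)
      return nullptr;
    if (config.skipCastOps && defOp->isCast) {
      values = defOp->operands; // Continue with the cast's operands.
      continue;
    }
    if (!config.requireMatchingReplacementOpName ||
        defOp->name == original.name)
      return defOp;
    // The real listener would consult FindPayloadReplacementOpInterface here.
    return nullptr;
  }
  return nullptr;
}
```

For a `test.foo` whose replacement value is produced by a cast of a `test.new_op` result (an assumed shape, loosely inspired by the test above), the default configuration finds no replacement, the relaxed configuration that `ApplyConversionPatternsOp::apply` sets up returns the `test.new_op`, and disabling `skipCastOps` as well would return the cast.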


