[llvm-branch-commits] [mlir] [mlir][Transform] `apply_conversion_patterns`: Update handles (PR #83950)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Mar 4 19:05:25 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83950
Until now, `transform.apply_conversion_patterns` consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to `transform.apply_patterns`, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion.
This new functionality is hidden behind a `preserve_handles` attribute for now.
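The difference between invalidating and updating a handle can be shown with a small, standalone C++ sketch. It does not use MLIR. `Op`, `HandleState`, and its `notifyOperationReplaced` member are hypothetical stand-ins that only mimic the idea behind `TrackingListener`, with the transform state modeled as a map from handle names to payload ops. When an op is replaced, every handle that still maps to it is redirected to the replacement instead of being dropped. The exception is a handle for which a skip function (playing the role of `skipHandleFn`) returns true, such as a dead or consumed handle.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a payload op; only its name matters here.
struct Op {
  std::string name;
};

// Hypothetical model of the transform state: each handle (identified by its
// name) maps to the payload ops it currently tracks.
struct HandleState {
  std::map<std::string, std::vector<Op *>> handles;

  // Plays the role of TrackingListenerConfig::skipHandleFn: handles for which
  // it returns true (e.g., dead or consumed handles) are left untouched.
  std::function<bool(const std::string &)> skipHandleFn;

  // Called when `oldOp` is replaced by `newOp`. Instead of invalidating the
  // handles that track `oldOp`, redirect them to `newOp`.
  void notifyOperationReplaced(Op *oldOp, Op *newOp) {
    assert(oldOp && newOp && "replacement notification needs both ops");
    for (auto &[handle, ops] : handles) {
      if (skipHandleFn && skipHandleFn(handle))
        continue;
      std::replace(ops.begin(), ops.end(), oldOp, newOp);
    }
  }
};
```

For example, suppose handle `%1` is registered for a `test.foo` op and a replacement by a `test.new_op` op is then notified. Afterwards, `state.handles["%1"]` holds the new op. The test at the end of this patch checks the same behavior on real IR by annotating the op behind `%1`.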
From 270ade8d830bbed2ebee448c03b82b058707f6ff Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 5 Mar 2024 03:01:27 +0000
Subject: [PATCH] [mlir][Transform] `apply_conversion_patterns`: Update handles
---
.../Transform/IR/TransformInterfaces.h | 32 +++++++++----
.../mlir/Dialect/Transform/IR/TransformOps.td | 18 ++++++--
.../Transform/IR/TransformInterfaces.cpp | 39 ++++++++--------
.../lib/Dialect/Transform/IR/TransformOps.cpp | 45 ++++++++++++++++---
.../Transform/test-pattern-application.mlir | 30 +++++++++++++
5 files changed, 129 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 313cdc27f780a7..32724ff4b98e8e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -921,20 +921,36 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
return RegionScope(*this, region);
}
+/// A configuration object for customizing a `TrackingListener`.
+struct TrackingListenerConfig {
+ using SkipHandleFn = std::function<bool(Value)>;
+
+ /// An optional function that returns "true" for handles that do not have to
+ /// be updated. These are typically dead or consumed handles.
+ SkipHandleFn skipHandleFn = nullptr;
+
+ /// If set to "true", the name of a replacement op must match the name of the
+ /// original op. If set to "false", the names of the payload ops tracked in a
+ /// handle may change as the tracking listener updates the transform state.
+ bool requireMatchingReplacementOpName = true;
+
+ /// If set to "true", cast ops (that implement the CastOpInterface) are
+ /// skipped and the replacement op search continues with the operands of the
+ /// cast op.
+ bool skipCastOps = true;
+};
+
/// A listener that updates a TransformState based on IR modifications. This
/// listener can be used during a greedy pattern rewrite to keep the transform
/// state up-to-date.
class TrackingListener : public RewriterBase::Listener,
public TransformState::Extension {
public:
- /// A function that returns "true" for handles that do not have to be updated.
- using SkipHandleFn = std::function<bool(Value)>;
-
/// Create a new TrackingListener for usage in the specified transform op.
/// Optionally, a function can be specified to identify handles that should
/// do not have to be updated.
TrackingListener(TransformState &state, TransformOpInterface op,
- SkipHandleFn skipHandleFn = nullptr);
+ TrackingListenerConfig config = TrackingListenerConfig());
protected:
/// Return a replacement payload op for the given op, which is going to be
@@ -959,7 +975,8 @@ class TrackingListener : public RewriterBase::Listener,
/// same computation; e.g., there may be tiled "linalg.generic" inside the
/// loop body that represents the original computation. Therefore, the
/// TrackingListener is conservative by default: it drops the mapping and
- /// triggers the "payload replacement not found" notification.
+ /// triggers the "payload replacement not found" notification. This default
+ /// behavior can be customized in `TrackingListenerConfig`.
///
/// If no replacement op could be found according to the rules mentioned
/// above, this function tries to skip over cast-like ops that implement
@@ -1023,9 +1040,8 @@ class TrackingListener : public RewriterBase::Listener,
/// The handles that are consumed by the transform op.
DenseSet<Value> consumedHandles;
- /// Handles for which this function evaluates to "true" do not have to be
- /// updated. These are typically dead or consumed handles.
- SkipHandleFn skipHandleFn;
+ /// Tracking listener configuration.
+ TrackingListenerConfig config;
};
/// A specialized listener that keeps track of cases in which no replacement
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 9f513822ed0a4e..0e42d12a69a400 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -190,11 +190,20 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
attributes specify the conversion target.
- This transform consumes the `target` handle and modifies the payload. It
- does not produce any handles.
+ This transform modifies the payload. By default, it consumes the `target`
+ handle. It does not produce any handles.
+
+ If the `preserve_handles` attribute is set, this transform does not consume
+ the `target` handle and instead updates handles based on notifications from
+ a tracking listener that is attached to the dialect conversion, similar to
+ `transform.apply_patterns`. Only replacements via `RewriterBase::replaceOp`
+ or `replaceOpWithNewOp` are considered "payload op replacements". In
+ contrast to `transform.apply_patterns`, a replacement op is accepted even if
+ its name differs from the name of the replaced op. More details can be found
+ in the documentation of `TrackingListener`.
This transform produces a silenceable failure if the dialect conversion was
- unsuccessful.
+ unsuccessful or the tracking listener failed to find a replacement op.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
@@ -202,7 +211,8 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
OptionalAttr<StrArrayAttr>:$illegal_ops,
OptionalAttr<StrArrayAttr>:$legal_dialects,
OptionalAttr<StrArrayAttr>:$illegal_dialects,
- UnitAttr:$partial_conversion);
+ UnitAttr:$partial_conversion,
+ UnitAttr:$preserve_handles);
let results = (outs);
let regions = (region
MaxSizedRegion<1>:$patterns,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index bb9f6fec452986..71a9d61198e3fb 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -918,7 +918,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
// Prepare rewriter and listener.
- TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
+ TrackingListenerConfig config;
+ config.skipHandleFn = [&](Value handle) {
// Skip handle if it is dead.
auto scopeIt =
llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
@@ -935,7 +936,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
return true;
};
transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
- skipHandleFn);
+ config);
transform::TransformRewriter rewriter(transform->getContext(),
&trackingListener);
@@ -1184,9 +1185,8 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
transform::TrackingListener::TrackingListener(TransformState &state,
TransformOpInterface op,
- SkipHandleFn skipHandleFn)
- : TransformState::Extension(state), transformOp(op),
- skipHandleFn(skipHandleFn) {
+ TrackingListenerConfig config)
+ : TransformState::Extension(state), transformOp(op), config(config) {
if (op) {
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
consumedHandles.insert(opOperand->get());
@@ -1228,8 +1228,19 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
return diag;
}
- // If the defining op has the same type, we take it as a replacement.
- if (op->getName() == defOp->getName()) {
+ // Skip through ops that implement CastOpInterface.
+ if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
+ values.clear();
+ values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
+ diag.attachNote(defOp->getLoc())
+ << "using output of 'CastOpInterface' op";
+ continue;
+ }
+
+ // If the defining op has the same name or we do not care about the name of
+ // op replacements at all, we take it as a replacement.
+ if (!config.requireMatchingReplacementOpName ||
+ op->getName() == defOp->getName()) {
result = defOp;
return DiagnosedSilenceableFailure::success();
}
@@ -1251,14 +1262,6 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
"'FindPayloadReplacementOpInterface'";
continue;
}
-
- // Skip through ops that implement CastOpInterface.
- if (isa<CastOpInterface>(defOp)) {
- values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
- diag.attachNote(defOp->getLoc())
- << "using output of 'CastOpInterface' op";
- continue;
- }
} while (!values.empty());
diag.attachNote() << "ran out of suitable replacement values";
@@ -1318,9 +1321,9 @@ void transform::TrackingListener::notifyOperationReplaced(
// Check if there are any handles that must be updated.
Value aliveHandle;
- if (skipHandleFn) {
- auto it =
- llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
+ if (config.skipHandleFn) {
+ auto it = llvm::find_if(opHandles,
+ [&](Value v) { return !config.skipHandleFn(v); });
if (it != opHandles.end())
aliveHandle = *it;
} else if (!opHandles.empty()) {
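The reordering in `findReplacementOp` above is easiest to follow on a toy model. The sketch below is a simplified, standalone C++ model, not the MLIR implementation. `ToyOp`, `ToyConfig`, `commonDefiningOp`, and `findReplacementOp` are hypothetical, each op stands in for its own result values, and the `FindPayloadReplacementOpInterface` step is left out. Casts are now skipped before the name check, so a cast such as `builtin.unrealized_conversion_cast` is looked through rather than accepted as the replacement. This ordering matters once `requireMatchingReplacementOpName` can be turned off.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy IR: `isCast` stands in for implementing CastOpInterface,
// and `operandDefs` lists the ops that define this op's operands.
struct ToyOp {
  std::string name;
  bool isCast = false;
  std::vector<const ToyOp *> operandDefs;
};

// Mirrors the two TrackingListenerConfig flags that affect the search.
struct ToyConfig {
  bool requireMatchingReplacementOpName = true;
  bool skipCastOps = true;
};

// Returns the op that defines all values, or nullptr if they come from
// different ops (or if there are no values at all).
const ToyOp *commonDefiningOp(const std::vector<const ToyOp *> &values) {
  if (values.empty())
    return nullptr;
  for (const ToyOp *v : values)
    if (v != values.front())
      return nullptr;
  return values.front();
}

// Simplified search order after this patch: look through casts first, then
// accept the defining op if the names match or names are not required to match.
const ToyOp *findReplacementOp(const ToyOp &op,
                               std::vector<const ToyOp *> values,
                               const ToyConfig &config) {
  while (!values.empty()) {
    const ToyOp *defOp = commonDefiningOp(values);
    if (!defOp)
      return nullptr; // Replacement values belong to different ops.
    if (config.skipCastOps && defOp->isCast) {
      values = defOp->operandDefs; // Continue with the cast's operands.
      continue;
    }
    if (!config.requireMatchingReplacementOpName || op.name == defOp->name)
      return defOp;
    values.clear(); // No further way to look through `defOp` in this model.
  }
  return nullptr; // Ran out of suitable replacement values.
}
```

Consider replacing `test.foo` by a cast of a `test.new_op` result. With `requireMatchingReplacementOpName = false`, as `apply_conversion_patterns` configures it, the search returns the `test.new_op` op. With the default configuration, the same search fails because the names differ. With `skipCastOps = false` and no name requirement, the cast itself would be returned.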
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 180d11c30e65de..ca80899ab07341 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -563,6 +563,17 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
}
}
+ // Attach a tracking listener if handles should be preserved. We configure the
+ // listener to allow op replacements with different names, as conversion
+ // patterns typically replace ops with replacement ops that have a different
+ // name.
+ TrackingListenerConfig trackingConfig;
+ trackingConfig.requireMatchingReplacementOpName = false;
+ ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
+ ConversionConfig conversionConfig;
+ if (getPreserveHandles())
+ conversionConfig.listener = &trackingListener;
+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (Operation *target : state.getPayloadOps(getTarget())) {
// Make sure that this transform is not applied to itself. Modifying the
@@ -574,16 +585,36 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
LogicalResult status = failure();
if (getPartialConversion()) {
- status = applyPartialConversion(target, conversionTarget, frozenPatterns);
+ status = applyPartialConversion(target, conversionTarget, frozenPatterns,
+ conversionConfig);
} else {
- status = applyFullConversion(target, conversionTarget, frozenPatterns);
+ status = applyFullConversion(target, conversionTarget, frozenPatterns,
+ conversionConfig);
}
+ // Check dialect conversion state.
+ DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
if (failed(status)) {
- auto diag = emitSilenceableError() << "dialect conversion failed";
+ diag = emitSilenceableError() << "dialect conversion failed";
diag.attachNote(target->getLoc()) << "target op";
- return diag;
}
+
+ // Check tracking listener error state.
+ DiagnosedSilenceableFailure trackingFailure =
+ trackingListener.checkAndResetError();
+ if (!trackingFailure.succeeded()) {
+ if (diag.succeeded()) {
+ // Tracking failure is the only failure.
+ return trackingFailure;
+ } else {
+ diag.attachNote() << "tracking listener also failed: "
+ << trackingFailure.getMessage();
+ (void)trackingFailure.silence();
+ }
+ }
+
+ if (!diag.succeeded())
+ return diag;
}
return DiagnosedSilenceableFailure::success();
@@ -632,7 +663,11 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
void transform::ApplyConversionPatternsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::consumesHandle(getTarget(), effects);
+ if (!getPreserveHandles()) {
+ transform::consumesHandle(getTarget(), effects);
+ } else {
+ transform::onlyReadsHandle(getTarget(), effects);
+ }
transform::modifiesPayload(effects);
}
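The error handling added to `ApplyConversionPatternsOp::apply` above handles the two possible failures as follows. A dialect conversion failure always becomes the primary error, and a simultaneous tracking failure is demoted to a note. A tracking failure on its own is returned unchanged. The following is a minimal sketch of that decision logic. `ToyDiag` and `combineFailures` are hypothetical stand-ins for `DiagnosedSilenceableFailure` and the inline code in the patch, and an empty message means success.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for DiagnosedSilenceableFailure.
struct ToyDiag {
  std::optional<std::string> message; // Empty means success.
  std::vector<std::string> notes;
  bool succeeded() const { return !message.has_value(); }
};

// Combines the dialect conversion status with the tracking listener status
// in the same order of precedence as the patch.
ToyDiag combineFailures(bool conversionFailed, const ToyDiag &trackingFailure) {
  ToyDiag diag;
  if (conversionFailed) {
    diag.message = "dialect conversion failed";
    diag.notes.push_back("target op");
  }
  if (!trackingFailure.succeeded()) {
    // Tracking failure is the only failure: report it as is.
    if (diag.succeeded())
      return trackingFailure;
    // Otherwise, keep the conversion error and mention the tracking failure.
    diag.notes.push_back("tracking listener also failed: " +
                         *trackingFailure.message);
  }
  return diag;
}
```

After copying the tracking failure's message into a note, the patch also discards that failure explicitly via `silence()`. The toy type has nothing to silence, so that step has no counterpart here.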
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index 0c41e81b17b522..fa8a555af92188 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -417,3 +417,33 @@ module attributes { transform.with_named_sequence } {
transform.yield
}
}
+
+// -----
+
+// "test.foo" is tracked and replaced with "test.new_op" during a dialect
+// conversion. Make sure that the handle is updated accordingly.
+
+// CHECK-LABEL: func @dialect_conversion_tracking
+// CHECK-NEXT: %[[m:.*]] = "test.new_op"() {annotated} : () -> memref<5xf32>
+// CHECK-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32>
+// CHECK-NEXT: return %[[cast]]
+func.func @dialect_conversion_tracking() -> tensor<5xf32> {
+ %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>)
+ return %0 : tensor<5xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.foo"]} in %0 : (!transform.any_op) -> !transform.any_op
+ transform.apply_conversion_patterns to %0 {
+ transform.apply_conversion_patterns.transform.test_conversion_patterns
+ } with type_converter {
+ transform.apply_conversion_patterns.transform.test_type_converter
+ } {legal_ops = ["func.func", "func.return", "test.new_op"], preserve_handles}
+ : !transform.any_op
+ // Add an attribute to %1, which is now mapped to a new op.
+ transform.annotate %1 "annotated" : !transform.any_op
+ transform.yield
+ }
+}