[Mlir-commits] [mlir] 9cc8e45 - [mlir][transform] Add notifyPayloadOperationReplaced to TransformRewriter
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 26 08:28:53 PDT 2023
Author: Matthias Springer
Date: 2023-06-26T17:28:45+02:00
New Revision: 9cc8e45898ed3acb0893c260872c1ee50d566738
URL: https://github.com/llvm/llvm-project/commit/9cc8e45898ed3acb0893c260872c1ee50d566738
DIFF: https://github.com/llvm/llvm-project/commit/9cc8e45898ed3acb0893c260872c1ee50d566738.diff
LOG: [mlir][transform] Add notifyPayloadOperationReplaced to TransformRewriter
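This function allows users to update payload op mappings in cases where such replacements cannot be performed automatically by the rewriter/listener interface. `transform.structured.fuse_into_containing_op` now uses it instead of the previous workaround (`UnsafeOpReplacementStateExtension` / `forciblyReplaceReferencedPayloadOperation`), which exposed the protected `TransformState::Extension::replacePayloadOp`.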
Differential Revision: https://reviews.llvm.org/D153764
Added:
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 9b45c15777040..d54f9c404fba4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -923,9 +923,12 @@ class TrackingListener : public RewriterBase::Listener,
TransformOpInterface getTransformOp() const { return transformOp; }
private:
+ friend class TransformRewriter;
+
void notifyOperationRemoved(Operation *op) override;
void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
+ using Listener::notifyOperationReplaced;
/// The transform op in which this TrackingListener is used.
TransformOpInterface transformOp;
@@ -981,6 +984,19 @@ class TransformRewriter : public RewriterBase {
/// Silence all tracking failures that have been encountered so far.
void silenceTrackingFailure();
+ /// Notify the transform dialect interpreter that the given op has been
+ /// replaced with another op and that the mapping between handles and payload
+ /// ops/values should be updated. This function should be called before the
+ /// original op is erased. It fails if the operation could not be replaced,
+ /// e.g., because the original operation is not tracked.
+ ///
+ /// Note: As long as IR modifications are performed through this rewriter,
+ /// the transform state is usually updated automatically. This function should
+ /// be used when unsupported rewriter API is used; e.g., updating all uses of
+ /// a tracked operation one-by-one instead of using `RewriterBase::replaceOp`.
+ LogicalResult notifyPayloadOperationReplaced(Operation *op,
+ Operation *replacement);
+
private:
ErrorCheckingTrackingListener *const listener;
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 042476390c280..2148624c2963d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -688,36 +688,6 @@ bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
return true;
}
-namespace {
-/// Unsafely exposes an internal protected method of TransformState::Extension
-/// as public.
-///
-/// MUST NOT be used directly.
-class UnsafeOpReplacementStateExtension : public TransformState::Extension {
-public:
- UnsafeOpReplacementStateExtension(TransformState &state)
- : TransformState::Extension(state) {}
-
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- UnsafeOpReplacementStateExtension)
-
- LogicalResult doReplacePayloadOp(Operation *op, Operation *replacement) {
- return replacePayloadOp(op, replacement);
- }
-};
-} // namespace
-
-/// Replaces `payload` with `replacement` in all handles stored in the state.
-/// MUST NOT be used except for the case immediately below.
-static void forciblyReplaceReferencedPayloadOperation(TransformState &state,
- Operation *payload,
- Operation *replacement) {
- UnsafeOpReplacementStateExtension extension(state);
- // This may return failure if the payload is not associated with any handle,
- // ignore that.
- (void)extension.doReplacePayloadOp(payload, replacement);
-}
-
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -787,6 +757,19 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
fusedOps.append(tiledOps);
if (newContainingOp) {
+ // Update handles associated with the containing op so we don't need to
+ // invalidate them. This is a hack to support better composability
+ // between tiling and fusion while a proper mechanism is being
+ // investigated.
+ //
+ // DO NOT replicate this elsewhere unless you understand what you are
+ // doing.
+ LogicalResult replacementStatus =
+ rewriter.notifyPayloadOperationReplaced(containingOp,
+ newContainingOp);
+ (void)replacementStatus;
+ assert(succeeded(replacementStatus) &&
+ "unable to update transform state mapping");
rewriter.eraseOp(containingOp);
containingOp = newContainingOp;
}
@@ -813,14 +796,6 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
- // Update handles associated with the containing op so we don't need to
- // invalidate them. This is a hack to support better composability between
- // tiling and fusion while a proper mechanism is being investigated.
- //
- // DO NOT replicate this elsewhere unless you understand what you are doing.
- forciblyReplaceReferencedPayloadOperation(state, *containingOps.begin(),
- containingOp);
-
results.set(cast<OpResult>(getFusedOp()), fusedOps);
results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
return DiagnosedSilenceableFailure::success();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 013d0e29ef5ce..27794e6636959 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1452,6 +1452,11 @@ void transform::TransformRewriter::silenceTrackingFailure() {
}
}
+LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
+ Operation *op, Operation *replacement) {
+ return listener->replacePayloadOp(op, replacement);
+}
+
//===----------------------------------------------------------------------===//
// Utilities for TransformEachOpTrait.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 3854cceb6273d..773405e61a4f8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s -verify-diagnostics | FileCheck %s
#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
@@ -323,6 +323,7 @@ module {
// CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
// CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+ // expected-remark @below{{new containing op}}
%2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
// CHECK: %[[I0:.*]] = affine.apply {{.*}}
%3 = affine.apply #map1(%i)[%idx]
@@ -350,8 +351,9 @@ module {
%1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
// linalg.generic is tileable. The op is tiled and fused.
- transform.structured.fuse_into_containing_op %0 into %1
+ %fused, %containing = transform.structured.fuse_into_containing_op %0 into %1
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
+ test_print_remark_at_operand %containing, "new containing op" : !transform.any_op
}
}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index bfe61df4e4043..4ab9d65e6475b 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1710,3 +1710,20 @@ transform.sequence failures(propagate) {
transform.annotate %0 "broadcast_attr" = %2 : !transform.any_op, !transform.param<i64>
transform.annotate %0 "unit_attr" : !transform.any_op
}
+
+// -----
+
+func.func @notify_payload_op_replaced(%arg0: index, %arg1: index) {
+ %0 = arith.muli %arg0, %arg1 {original} : index
+ // expected-remark @below{{updated handle}}
+ %1 = arith.muli %arg0, %arg1 {replacement} : index
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match attributes{original} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match attributes{replacement} in %arg1 : (!transform.any_op) -> !transform.any_op
+ test_notify_payload_op_replaced %0, %1 : (!transform.any_op, !transform.any_op) -> ()
+ test_print_remark_at_operand %0, "updated handle" : !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 51c0615932b61..f3e80060f5cec 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -872,6 +872,34 @@ LogicalResult mlir::test::TestReEnterRegionOp::verify() {
return success();
}
+DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply(
+ transform::TransformRewriter &rewriter,
+ transform::TransformResults &results, transform::TransformState &state) {
+ auto originalOps = state.getPayloadOps(getOriginal());
+ auto replacementOps = state.getPayloadOps(getReplacement());
+ if (llvm::range_size(originalOps) != llvm::range_size(replacementOps))
+ return emitSilenceableError() << "expected same number of original and "
+ "replacement payload operations";
+ for (const auto &[original, replacement] :
+ llvm::zip(originalOps, replacementOps)) {
+ if (failed(
+ rewriter.notifyPayloadOperationReplaced(original, replacement))) {
+ auto diag = emitSilenceableError()
+ << "unable to replace payload op in transform mapping";
+ diag.attachNote(original->getLoc()) << "original payload op";
+ diag.attachNote(replacement->getLoc()) << "replacement payload op";
+ return diag;
+ }
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ transform::onlyReadsHandle(getOriginal(), effects);
+ transform::onlyReadsHandle(getReplacement(), effects);
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 594c32d165d43..02f4955c30cc6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -564,4 +564,15 @@ def TestReEnterRegionOp
let hasVerifier = 1;
}
+def TestNotifyPayloadOpReplacedOp
+ : Op<Transform_Dialect, "test_notify_payload_op_replaced",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins TransformHandleTypeInterface:$original,
+ TransformHandleTypeInterface:$replacement);
+ let results = (outs);
+ let assemblyFormat = "$original `,` $replacement attr-dict `:` functional-type(operands, results)";
+ let cppNamespace = "::mlir::test";
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits
mailing list