[Mlir-commits] [mlir] 0b52fa9 - [mlir][transform] Add ApplyPatternsOp and PatternRegistry
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 2 06:04:07 PDT 2023
Author: Matthias Springer
Date: 2023-06-02T14:58:20+02:00
New Revision: 0b52fa900aa3dca7b6b1873cb6ed78bf3ab42b18
URL: https://github.com/llvm/llvm-project/commit/0b52fa900aa3dca7b6b1873cb6ed78bf3ab42b18
DIFF: https://github.com/llvm/llvm-project/commit/0b52fa900aa3dca7b6b1873cb6ed78bf3ab42b18.diff
LOG: [mlir][transform] Add ApplyPatternsOp and PatternRegistry
Add a new transform op that applies patterns to a targeted payload op. Patterns can be registered by transform dialect extensions in a pattern registry.
Differential Revision: https://reviews.llvm.org/D151983
Added:
mlir/test/Dialect/Transform/test-pattern-application.mlir
Modified:
mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index e156602ea886b..db27f2c6fc49b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -32,11 +32,18 @@ class TransformDialectDataBase {
protected:
/// Must be called by the subclass with the appropriate type ID.
- explicit TransformDialectDataBase(TypeID typeID) : typeID(typeID) {}
+ explicit TransformDialectDataBase(TypeID typeID, MLIRContext *ctx)
+ : typeID(typeID), ctx(ctx) {}
+
+ /// Return the MLIR context.
+ MLIRContext *getContext() const { return ctx; }
private:
/// The type ID of the subclass.
const TypeID typeID;
+
+ /// The MLIR context.
+ MLIRContext *ctx;
};
} // namespace detail
@@ -55,7 +62,8 @@ template <typename DerivedTy>
class TransformDialectData : public detail::TransformDialectDataBase {
protected:
/// Forward the TypeID of the derived class to the base.
- TransformDialectData() : TransformDialectDataBase(TypeID::get<DerivedTy>()) {}
+ /// The MLIR context `ctx` is forwarded to the base as well; subclasses can
+ /// retrieve it through the base's protected `getContext()`.
+ TransformDialectData(MLIRContext *ctx)
+ : TransformDialectDataBase(TypeID::get<DerivedTy>(), ctx) {}
};
#ifndef NDEBUG
@@ -294,7 +302,8 @@ DataTy &TransformDialect::getOrCreateExtraData() {
if (it != extraData.end())
return static_cast<DataTy &>(*it->getSecond());
- auto emplaced = extraData.try_emplace(typeID, std::make_unique<DataTy>());
+ auto emplaced =
+ extraData.try_emplace(typeID, std::make_unique<DataTy>(getContext()));
return static_cast<DataTy &>(*emplaced.first->getSecond());
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 56353a295c6ed..3e3461bb14f6e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -12,6 +12,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/FunctionInterfaces.h"
@@ -25,6 +26,8 @@
namespace mlir {
namespace transform {
+class ApplyPatternsOp;
+
enum class FailurePropagationMode : uint32_t;
class FailurePropagationModeAttr;
@@ -120,9 +123,71 @@ class TrackingListener : public RewriterBase::Listener,
TransformOpInterface transformOp;
};
+/// A specialized listener that keeps track of cases in which no replacement
+/// payload could be found. The error state of this listener must be checked
+/// (and thereby reset) via `checkAndResetError` before the end of its
+/// lifetime; the destructor asserts that no unchecked error remains.
+class ErrorCheckingTrackingListener : public TrackingListener {
+public:
+ using transform::TrackingListener::TrackingListener;
+
+ /// Asserts that any recorded error was consumed via `checkAndResetError`.
+ ~ErrorCheckingTrackingListener() override;
+
+ /// Return the current error state of this listener and reset it to
+ /// "success". The error counter used for numbering notes is reset as well.
+ DiagnosedSilenceableFailure checkAndResetError();
+
+ /// Return "true" if this tracking listener has an unchecked failure.
+ bool failed() const;
+
+protected:
+ /// Record that no replacement payload op was found for `op`. The first such
+ /// event creates a silenceable failure; every event attaches numbered notes
+ /// for the replaced op and each of its replacement `values`.
+ void notifyPayloadReplacementNotFound(Operation *op,
+ ValueRange values) override;
+
+private:
+ /// The error state of this listener. "Success" indicates that no error
+ /// happened so far (or since the last `checkAndResetError`).
+ DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
+
+ /// The number of errors that have been encountered since the last reset;
+ /// used to number the attached notes.
+ int64_t errorCounter = 0;
+};
+
+/// The PatternRegistry stores callbacks to functions that populate a
+/// `RewritePatternSet`. Registered patterns can be applied with the
+/// "transform.apply_patterns" op.
+class PatternRegistry : public TransformDialectData<PatternRegistry> {
+public:
+  /// Creates an empty registry for the given context. Explicit so that an
+  /// `MLIRContext *` never converts to a registry implicitly.
+  explicit PatternRegistry(MLIRContext *ctx)
+      : TransformDialectData(ctx), builder(ctx) {}
+
+  /// A function that populates a `RewritePatternSet`.
+  using PopulatePatternsFn = std::function<void(RewritePatternSet &)>;
+
+  /// Registers patterns with the specified identifier. The identifier should
+  /// be prefixed with the dialect to which the patterns belong. Each
+  /// identifier may be registered only once (checked with an assertion).
+  void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn);
+
+protected:
+  friend class ApplyPatternsOp;
+
+  /// Returns "true" if patterns are registered with the specified identifier.
+  bool hasPatterns(StringAttr identifier) const;
+
+  /// Populates the given pattern set with the specified patterns. The
+  /// identifier must be registered (checked with an assertion).
+  void populatePatterns(StringAttr identifier,
+                        RewritePatternSet &patternSet) const;
+
+private:
+  /// A builder for creating StringAttrs.
+  Builder builder;
+
+  /// Maps each registered identifier to its pattern-population callback.
+  DenseMap<StringAttr, PopulatePatternsFn> patterns;
+};
+
} // namespace transform
} // namespace mlir
+MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry)
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.h.inc"
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 6036687017a55..57a7bd33acfc5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -126,6 +126,49 @@ def AnnotateOp : TransformDialectOp<"annotate",
"`:` type($target) (`,` type($param)^)?";
}
+def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Greedily applies patterns to the body of the targeted op";
+  let description = [{
+    This transform greedily applies the specified patterns to the body of the
+    targeted op until a fixpoint is reached. Patterns are not applied to the
+    targeted op itself.
+
+    Only patterns that were registered in the transform dialect's
+    `PatternRegistry` are available. Additional patterns can be registered as
+    part of transform dialect extensions.
+
+    This transform only reads the target handle and modifies the payload. If a
+    pattern erases or replaces a tracked op, the mapping is updated accordingly.
+
+    Only replacements via `RewriterBase::replaceOp` or `replaceOpWithNewOp` are
+    considered "payload op replacements". Furthermore, the mapping is updated
+    only if the replacement values are defined by the same op and that op has
+    the same type as the original op. Otherwise, this transform produces a
+    silenceable failure, unless `fail_on_payload_replacement_not_found` is set
+    to "false". More details can be found in the documentation of
+    `TrackingListener`.
+
+    This transform also produces a silenceable failure if the pattern
+    application did not converge within the default number of
+    iterations/rewrites of the greedy pattern rewrite driver.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$target, ArrayAttr:$patterns,
+    DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_replacement_not_found);
+  let results = (outs);
+  let assemblyFormat = "$patterns `to` $target attr-dict `:` type($target)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
def CastOp : TransformDialectOp<"cast",
[TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<CastOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
index 9077e9fc9ffbd..5172bcf204e5f 100644
--- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
@@ -27,6 +27,8 @@ namespace transform {
/// populated by extensions.
class PDLMatchHooks : public TransformDialectData<PDLMatchHooks> {
public:
+ PDLMatchHooks(MLIRContext *ctx) : TransformDialectData(ctx) {}
+
/// Takes ownership of the named PDL constraint function from the given
/// map and makes them available for use by the operations in the dialect.
void
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index f1a57f7087272..c076a8cab89ea 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -31,6 +32,8 @@
using namespace mlir;
+MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry)
+
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
Type &rootType,
@@ -175,6 +178,62 @@ void transform::TrackingListener::notifyOperationReplaced(
(void)replacePayloadOp(op, replacement);
}
+transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
+  // Errors recorded by this listener must be consumed through
+  // checkAndResetError() before destruction, so they cannot be dropped by
+  // accident.
+  assert(!failed() && "listener state was not checked");
+}
+
+DiagnosedSilenceableFailure
+transform::ErrorCheckingTrackingListener::checkAndResetError() {
+  // Hand the accumulated diagnostic over to the caller and start again from a
+  // clean state.
+  errorCounter = 0;
+  DiagnosedSilenceableFailure previous = std::move(status);
+  status = DiagnosedSilenceableFailure::success();
+  return previous;
+}
+
+bool transform::ErrorCheckingTrackingListener::failed() const {
+  const bool clean = status.succeeded();
+  return !clean;
+}
+
+void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
+    Operation *op, ValueRange values) {
+  // The first failure creates the diagnostic; later ones only add notes to it.
+  if (!failed())
+    status = emitSilenceableFailure(
+        getTransformOp(), "tracking listener failed to find replacement op");
+
+  // Number the notes so that the notes belonging to one failure can be
+  // correlated with each other.
+  const int64_t errorId = errorCounter++;
+  status.attachNote(op->getLoc()) << "[" << errorId << "] replaced op";
+  for (auto &&[valueIdx, replacement] : llvm::enumerate(values))
+    status.attachNote(replacement.getLoc())
+        << "[" << errorId << "] replacement value " << valueIdx;
+}
+
+//===----------------------------------------------------------------------===//
+// PatternRegistry
+//===----------------------------------------------------------------------===//
+
+void transform::PatternRegistry::registerPatterns(StringRef identifier,
+                                                   PopulatePatternsFn &&fn) {
+  // Identifiers are interned as StringAttrs so that lookups from the op's
+  // `patterns` array attribute need no string comparison.
+  StringAttr key = builder.getStringAttr(identifier);
+  [[maybe_unused]] bool inserted =
+      patterns.try_emplace(key, std::move(fn)).second;
+  assert(inserted && "patterns identifier is already in use");
+}
+
+void transform::PatternRegistry::populatePatterns(
+    StringAttr identifier, RewritePatternSet &patternSet) const {
+  auto entry = patterns.find(identifier);
+  assert(entry != patterns.end() && "patterns not registered in registry");
+  const PopulatePatternsFn &populateFn = entry->second;
+  populateFn(patternSet);
+}
+
+bool transform::PatternRegistry::hasPatterns(StringAttr identifier) const {
+  return patterns.count(identifier) != 0;
+}
+
//===----------------------------------------------------------------------===//
// AlternativesOp
//===----------------------------------------------------------------------===//
@@ -356,6 +415,77 @@ void transform::AnnotateOp::getEffects(
modifiesPayload(effects);
}
+//===----------------------------------------------------------------------===//
+// ApplyPatternsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ApplyPatternsOp::applyToOne(Operation *target,
+                                       ApplyToEachResultList &results,
+                                       transform::TransformState &state) {
+  // Gather all specified patterns. The verifier checks that every entry is a
+  // StringAttr naming registered patterns.
+  MLIRContext *ctx = target->getContext();
+  RewritePatternSet patterns(ctx);
+  const auto &registry = getContext()
+                             ->getLoadedDialect<transform::TransformDialect>()
+                             ->getExtraData<transform::PatternRegistry>();
+  for (Attribute attr : getPatterns())
+    registry.populatePatterns(attr.cast<StringAttr>(), patterns);
+
+  // Configure the GreedyPatternRewriteDriver. The tracking listener updates
+  // the handle mapping when tracked payload ops are replaced or erased.
+  ErrorCheckingTrackingListener listener(state, *this);
+  GreedyRewriteConfig config;
+  config.listener = &listener;
+
+  // Manually gather the list of ops because the other
+  // GreedyPatternRewriteDriver overloads only accept ops that are isolated
+  // from above. This way, patterns can be applied to ops that are not
+  // isolated from above. The target op itself is excluded.
+  SmallVector<Operation *> ops;
+  target->walk([&](Operation *nestedOp) {
+    if (target != nestedOp)
+      ops.push_back(nestedOp);
+  });
+  LogicalResult result =
+      applyOpPatternsAndFold(ops, std::move(patterns), config);
+
+  // Consume the listener's error state before any other early return:
+  // ErrorCheckingTrackingListener asserts on destruction if a recorded error
+  // was never checked (e.g., when the driver also failed to converge).
+  if (listener.failed()) {
+    DiagnosedSilenceableFailure status = listener.checkAndResetError();
+    if (getFailOnPayloadReplacementNotFound())
+      return status;
+    (void)status.silence();
+  }
+
+  // A failure typically indicates that the pattern application did not
+  // converge.
+  if (failed(result)) {
+    return emitSilenceableFailure(target)
+           << "greedy pattern application failed";
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::ApplyPatternsOp::verify() {
+  // NOTE(review): getExtraData presumably requires that some extension has
+  // already created the PatternRegistry; confirm the behavior when none has.
+  const transform::PatternRegistry &registry =
+      getContext()
+          ->getLoadedDialect<transform::TransformDialect>()
+          ->getExtraData<transform::PatternRegistry>();
+  // Every entry of the `patterns` array must be a string naming a registered
+  // set of patterns.
+  for (Attribute entry : getPatterns()) {
+    auto name = entry.dyn_cast<StringAttr>();
+    if (!name) {
+      return emitOpError() << "expected " << getPatternsAttrName()
+                           << " to be an array of strings";
+    }
+    if (registry.hasPatterns(name))
+      continue;
+    return emitOpError() << "patterns not registered: " << name.strref();
+  }
+  return success();
+}
+
+void transform::ApplyPatternsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The target handle is only read (not consumed); the payload is rewritten.
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 74c101f79075a..6436c7d860c37 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -672,3 +672,19 @@ module attributes { transform.with_named_sequence } {
@match -> @action : (!transform.any_op) -> !transform.any_op
}
}
+
+// -----
+
+// Every pattern identifier must be registered in the PatternRegistry.
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{patterns not registered: transform.invalid_pattern_identifier}}
+ transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 : !transform.any_op
+}
+
+// -----
+
+// Every entry of the "patterns" array must be a string.
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected "patterns" to be an array of strings}}
+ transform.apply_patterns [3, 9] to %arg0 : !transform.any_op
+}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
new file mode 100644
index 0000000000000..0df76d808f880
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -0,0 +1,123 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+// A tracked op replaced by a new op of the same name stays tracked: the
+// handle %1 is remapped to the new "test.foo" op, which then gets annotated.
+// CHECK-LABEL: func @update_tracked_op_mapping()
+// CHECK: "test.container"() ({
+// CHECK: %0 = "test.foo"() {annotated} : () -> i32
+// CHECK: }) : () -> ()
+func.func @update_tracked_op_mapping() {
+ "test.container"() ({
+ %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+ }) : () -> ()
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+ // Add an attribute to %1, which is now mapped to a new op.
+ transform.annotate %1 "annotated" : !transform.any_op
+}
+
+// -----
+
+// Replacing a tracked "test.foo" op with a "test.bar" op leaves no suitable
+// replacement payload op; since the handle is still used afterwards, the
+// transform reports a tracking failure with numbered notes.
+func.func @replacement_op_not_found() {
+ "test.container"() ({
+ // expected-note @below {{[0] replaced op}}
+ // expected-note @below {{[0] replacement value 0}}
+ %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+ }) : () -> ()
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{tracking listener failed to find replacement op}}
+ transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+ // %1 must be used afterwards: a missing replacement payload op is reported
+ // only for handles that are still live.
+ transform.annotate %1 "annotated" : !transform.any_op
+}
+
+// -----
+
+// A tracked op is replaced by an op of a different name, but the handle %1 is
+// never used afterwards, so no tracking failure is reported and the rewrite
+// is kept.
+// CHECK-LABEL: func @replacement_op_for_dead_handle_not_found()
+// CHECK: "test.container"() ({
+// CHECK: %0 = "test.bar"() : () -> i32
+// CHECK: }) : () -> ()
+func.func @replacement_op_for_dead_handle_not_found() {
+ "test.container"() ({
+ %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+ }) : () -> ()
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // No error because %1 is dead.
+ transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+}
+
+// -----
+
+// With fail_on_payload_replacement_not_found = false, a missing replacement
+// payload op for a live handle does not make the transform fail.
+// CHECK-LABEL: func @replacement_op_not_found_silenced()
+// CHECK: "test.container"() ({
+// CHECK: %0 = "test.bar"() : () -> i32
+// CHECK: }) : () -> ()
+func.func @replacement_op_not_found_silenced() {
+ "test.container"() ({
+ %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+ }) : () -> ()
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 {fail_on_payload_replacement_not_found = false}: !transform.any_op
+ transform.annotate %1 "annotated" : !transform.any_op
+}
+
+// -----
+
+// Patterns are applied only to ops nested inside the target, never to the
+// target itself, so the targeted "test.foo" op is left unchanged.
+// CHECK-LABEL: func @patterns_apply_only_to_target_body()
+// CHECK: %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> i32
+func.func @patterns_apply_only_to_target_body() {
+ %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+}
+
+// -----
+
+// Erasing a tracked op removes it from its handle: after the pattern
+// application %1 maps to no op, so the second remark is never emitted.
+// CHECK-LABEL: func @erase_tracked_op()
+// CHECK: "test.container"() ({
+// CHECK-NEXT: ^bb0:
+// CHECK-NEXT: }) : () -> ()
+func.func @erase_tracked_op() {
+ "test.container"() ({
+ // expected-remark @below {{matched op}}
+ %0 = "test.erase_op"() {replace_with_new_op = "test.foo"} : () -> (i32)
+ }) : () -> ()
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op
+ transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+ transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index f3b6c19d90b16..9af4c53cb1c86 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -746,6 +746,41 @@ mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
}
namespace {
+// Test pattern to replace an operation with a new op. The new op is named by
+// the "replace_with_new_op" attribute and has the same operands and result
+// types but no attributes, so this pattern does not match it again.
+class ReplaceWithNewOp : public RewritePattern {
+public:
+  ReplaceWithNewOp(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op");
+    if (!newName)
+      return failure();
+    Operation *newOp = rewriter.create(
+        op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(),
+        op->getOperands(), op->getResultTypes());
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+// Test pattern to erase a "test.erase_op" operation. Ops whose results still
+// have uses are left alone: `eraseOp` requires the op to be use-free.
+class EraseOp : public RewritePattern {
+public:
+  EraseOp(MLIRContext *context)
+      : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!op->use_empty())
+      return failure();
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Adds all test patterns to `patterns`; registered under "transform.test".
+void populateTestPatterns(RewritePatternSet &patterns) {
+  patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext());
+}
+
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
/// types for operands and results.
@@ -783,6 +818,11 @@ class TestTransformDialectExtension
constraints.try_emplace("verbose_constraint", verboseConstraint);
hooks.mergeInPDLMatchHooks(std::move(constraints));
});
+
+ addDialectDataInitializer<transform::PatternRegistry>(
+ [&](transform::PatternRegistry ®istry) {
+ registry.registerPatterns("transform.test", populateTestPatterns);
+ });
}
};
} // namespace
More information about the Mlir-commits
mailing list