[Mlir-commits] [mlir] 0b52fa9 - [mlir][transform] Add ApplyPatternsOp and PatternRegistry

Matthias Springer llvmlistbot at llvm.org
Fri Jun 2 06:04:07 PDT 2023


Author: Matthias Springer
Date: 2023-06-02T14:58:20+02:00
New Revision: 0b52fa900aa3dca7b6b1873cb6ed78bf3ab42b18

URL: https://github.com/llvm/llvm-project/commit/0b52fa900aa3dca7b6b1873cb6ed78bf3ab42b18
DIFF: https://github.com/llvm/llvm-project/commit/0b52fa900aa3dca7b6b1873cb6ed78bf3ab42b18.diff

LOG: [mlir][transform] Add ApplyPatternsOp and PatternRegistry

Add a new transform op that applies patterns to a targeted payload op. Patterns can be registered by transform dialect extensions in a pattern registry.

Differential Revision: https://reviews.llvm.org/D151983

Added: 
    mlir/test/Dialect/Transform/test-pattern-application.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index e156602ea886b..db27f2c6fc49b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -32,11 +32,18 @@ class TransformDialectDataBase {
 
 protected:
   /// Must be called by the subclass with the appropriate type ID.
-  explicit TransformDialectDataBase(TypeID typeID) : typeID(typeID) {}
+  explicit TransformDialectDataBase(TypeID typeID, MLIRContext *ctx)
+      : typeID(typeID), ctx(ctx) {}
+
+  /// Return the MLIR context.
+  MLIRContext *getContext() const { return ctx; }
 
 private:
   /// The type ID of the subclass.
   const TypeID typeID;
+
+  /// The MLIR context.
+  MLIRContext *ctx;
 };
 } // namespace detail
 
@@ -55,7 +62,8 @@ template <typename DerivedTy>
 class TransformDialectData : public detail::TransformDialectDataBase {
 protected:
   /// Forward the TypeID of the derived class to the base.
-  TransformDialectData() : TransformDialectDataBase(TypeID::get<DerivedTy>()) {}
+  /// The MLIRContext is forwarded as well: TransformDialect::
+  /// getOrCreateExtraData constructs every DataTy with the dialect's context,
+  /// so derived data classes must accept and pass it through.
+  TransformDialectData(MLIRContext *ctx)
+      : TransformDialectDataBase(TypeID::get<DerivedTy>(), ctx) {}
 };
 
 #ifndef NDEBUG
@@ -294,7 +302,8 @@ DataTy &TransformDialect::getOrCreateExtraData() {
   if (it != extraData.end())
     return static_cast<DataTy &>(*it->getSecond());
 
-  auto emplaced = extraData.try_emplace(typeID, std::make_unique<DataTy>());
+  auto emplaced =
+      extraData.try_emplace(typeID, std::make_unique<DataTy>(getContext()));
   return static_cast<DataTy &>(*emplaced.first->getSecond());
 }
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 56353a295c6ed..3e3461bb14f6e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/FunctionInterfaces.h"
@@ -25,6 +26,8 @@
 
 namespace mlir {
 namespace transform {
+/// Forward declaration so that PatternRegistry (below) can befriend
+/// ApplyPatternsOp before the generated op classes are included.
+class ApplyPatternsOp;
+
 enum class FailurePropagationMode : uint32_t;
 class FailurePropagationModeAttr;
 
@@ -120,9 +123,71 @@ class TrackingListener : public RewriterBase::Listener,
   TransformOpInterface transformOp;
 };
 
+/// A specialized tracking listener that records every case in which no
+/// replacement payload op could be found. All such cases are accumulated into
+/// a single silenceable failure (one numbered set of notes per case). The
+/// error state must be consumed with `checkAndResetError()` before the end of
+/// the listener's lifetime; the destructor asserts otherwise.
+class ErrorCheckingTrackingListener : public TrackingListener {
+public:
+  using transform::TrackingListener::TrackingListener;
+
+  /// Asserts that the error state was checked and reset beforehand.
+  ~ErrorCheckingTrackingListener() override;
+
+  /// Returns the accumulated error state ("success" if no replacement was
+  /// missing) and resets this listener to "success", including the error
+  /// counter used to number diagnostic notes.
+  DiagnosedSilenceableFailure checkAndResetError();
+
+  /// Return "true" if at least one replacement could not be found since the
+  /// last reset.
+  bool failed() const;
+
+protected:
+  /// Records a missing replacement. The first one creates a silenceable
+  /// failure attached to the transform op; every one attaches notes for the
+  /// replaced op and for each replacement value.
+  void notifyPayloadReplacementNotFound(Operation *op,
+                                        ValueRange values) override;
+
+private:
+  /// The error state of this listener. "Success" indicates that no error
+  /// happened so far.
+  DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success();
+
+  /// The number of missing replacements recorded since the last reset; used
+  /// as the "[N]" prefix of the attached notes.
+  int64_t errorCounter = 0;
+};
+
+/// The PatternRegistry stores callbacks to functions that populate a
+/// `RewritePatternSet`. Registered patterns can be applied with the
+/// "transform.apply_patterns" op, which refers to them by string identifier.
+/// Transform dialect extensions can register patterns through
+/// `addDialectDataInitializer<PatternRegistry>`.
+class PatternRegistry : public TransformDialectData<PatternRegistry> {
+public:
+  explicit PatternRegistry(MLIRContext *ctx)
+      : TransformDialectData(ctx), builder(ctx) {}
+
+  /// A function that populates a `RewritePatternSet`.
+  using PopulatePatternsFn = std::function<void(RewritePatternSet &)>;
+
+  /// Registers patterns with the specified identifier. The identifier should
+  /// be prefixed with the dialect to which the patterns belong (e.g.
+  /// "transform.test"). Each identifier may be registered only once;
+  /// registering it again is a programming error (asserts in debug builds).
+  void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn);
+
+protected:
+  /// ApplyPatternsOp queries the registry during verification and application.
+  friend class ApplyPatternsOp;
+
+  /// Returns "true" if patterns are registered with the specified identifier.
+  bool hasPatterns(StringAttr identifier) const;
+
+  /// Populates the given pattern set with the patterns registered under
+  /// `identifier`. The identifier must have been registered.
+  void populatePatterns(StringAttr identifier,
+                        RewritePatternSet &patternSet) const;
+
+private:
+  /// A builder for creating StringAttrs from identifiers.
+  Builder builder;
+
+  /// Registered population callbacks, keyed by identifier.
+  DenseMap<StringAttr, PopulatePatternsFn> patterns;
+};
+
 } // namespace transform
 } // namespace mlir
 
+// PatternRegistry is stored as transform dialect extra data, which is keyed by
+// TypeID; the matching MLIR_DEFINE_EXPLICIT_TYPE_ID is in TransformOps.cpp.
+MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry)
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.h.inc"
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 6036687017a55..57a7bd33acfc5 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -126,6 +126,49 @@ def AnnotateOp : TransformDialectOp<"annotate",
     "`:` type($target) (`,` type($param)^)?";
 }
 
+def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Greedily applies patterns to the body of the targeted op";
+  let description = [{
+    This transform greedily applies the specified patterns to the body of the
+    targeted op until a fixpoint is reached. Patterns are not applied to the
+    targeted op itself.
+
+    `patterns` is an array of string identifiers. Only patterns that were
+    registered under these identifiers in the transform dialect's
+    `PatternRegistry` are available; additional patterns can be registered as
+    part of transform dialect extensions.
+
+    This transform only reads the target handle and modifies the payload. If a
+    pattern erases or replaces a tracked op, the mapping is updated accordingly.
+
+    Only replacements via `RewriterBase::replaceOp` or `replaceOpWithNewOp` are
+    considered "payload op replacements". Furthermore, the mapping is only
+    updated if the replacement values are defined by the same op and that op
+    has the same type as the original op. Otherwise, this transform produces a
+    silenceable failure. If `fail_on_payload_replacement_not_found` is set to
+    "false" (it defaults to "true"), that failure is silenced instead.
+    More details can be found in the documentation of `TrackingListener`.
+
+    This transform also produces a silenceable failure if the pattern
+    application did not converge within the default number of
+    iterations/rewrites of the greedy pattern rewrite driver.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$target, ArrayAttr:$patterns,
+    DefaultValuedAttr<BoolAttr, "true">:$fail_on_payload_replacement_not_found);
+  let results = (outs);
+  let assemblyFormat = "$patterns `to` $target attr-dict `:` type($target)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::Operation *target,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 def CastOp : TransformDialectOp<"cast",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<CastOpInterface>,

diff  --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
index 9077e9fc9ffbd..5172bcf204e5f 100644
--- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
@@ -27,6 +27,8 @@ namespace transform {
 /// populated by extensions.
 class PDLMatchHooks : public TransformDialectData<PDLMatchHooks> {
 public:
+  PDLMatchHooks(MLIRContext *ctx) : TransformDialectData(ctx) {}
+
   /// Takes ownership of the named PDL constraint function from the given
   /// map and makes them available for use by the operations in the dialect.
   void

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index f1a57f7087272..c076a8cab89ea 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -31,6 +32,8 @@
 
 using namespace mlir;
 
+// Out-of-line TypeID definition matching the MLIR_DECLARE_EXPLICIT_TYPE_ID in
+// TransformOps.h.
+MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry)
+
 static ParseResult parseSequenceOpOperands(
     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
     Type &rootType,
@@ -175,6 +178,62 @@ void transform::TrackingListener::notifyOperationReplaced(
   (void)replacePayloadOp(op, replacement);
 }
 
+transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
+  // Recorded tracking errors must be consumed via checkAndResetError() before
+  // the listener goes away; otherwise they would be dropped unreported.
+  assert(!failed() && "listener state was not checked");
+}
+
+DiagnosedSilenceableFailure
+transform::ErrorCheckingTrackingListener::checkAndResetError() {
+  // Hand the accumulated state to the caller and start over from "success".
+  DiagnosedSilenceableFailure previous = std::move(status);
+  errorCounter = 0;
+  status = DiagnosedSilenceableFailure::success();
+  return previous;
+}
+
+bool transform::ErrorCheckingTrackingListener::failed() const {
+  // Any state other than "success" means a replacement was not found.
+  return !status.succeeded();
+}
+
+void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
+    Operation *op, ValueRange values) {
+  // Only the first missing replacement creates the failure; later ones merely
+  // contribute notes to it.
+  if (!failed())
+    status = emitSilenceableFailure(
+        getTransformOp(), "tracking listener failed to find replacement op");
+
+  // All notes of one occurrence share the same "[N]" tag so that the replaced
+  // op and its replacement values can be correlated.
+  const int64_t errorIndex = errorCounter++;
+  status.attachNote(op->getLoc()) << "[" << errorIndex << "] replaced op";
+  for (size_t idx = 0, e = values.size(); idx < e; ++idx)
+    status.attachNote(values[idx].getLoc())
+        << "[" << errorIndex << "] replacement value " << idx;
+}
+
+//===----------------------------------------------------------------------===//
+// PatternRegistry
+//===----------------------------------------------------------------------===//
+
+void transform::PatternRegistry::registerPatterns(StringRef identifier,
+                                                  PopulatePatternsFn &&fn) {
+  StringAttr attr = builder.getStringAttr(identifier);
+  // A single try_emplace both inserts and detects duplicates, avoiding a
+  // second hash lookup. Registering an identifier twice is a programming
+  // error; in release builds the first registration is kept.
+  [[maybe_unused]] bool inserted =
+      patterns.try_emplace(attr, std::move(fn)).second;
+  assert(inserted && "patterns identifier is already in use");
+}
+
+void transform::PatternRegistry::populatePatterns(
+    StringAttr identifier, RewritePatternSet &patternSet) const {
+  auto it = patterns.find(identifier);
+  assert(it != patterns.end() && "patterns not registered in registry");
+  // ApplyPatternsOp::verify rejects unknown identifiers. Guard anyway so that a
+  // release build never dereferences end().
+  if (it == patterns.end())
+    return;
+  it->second(patternSet);
+}
+
+bool transform::PatternRegistry::hasPatterns(StringAttr identifier) const {
+  return patterns.contains(identifier);
+}
+
 //===----------------------------------------------------------------------===//
 // AlternativesOp
 //===----------------------------------------------------------------------===//
@@ -356,6 +415,77 @@ void transform::AnnotateOp::getEffects(
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyPatternsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ApplyPatternsOp::applyToOne(Operation *target,
+                                       ApplyToEachResultList &results,
+                                       transform::TransformState &state) {
+  // Gather all specified patterns. The verifier guarantees that every entry of
+  // `patterns` is a StringAttr registered in the PatternRegistry.
+  MLIRContext *ctx = target->getContext();
+  RewritePatternSet patterns(ctx);
+  const auto &registry = getContext()
+                             ->getLoadedDialect<transform::TransformDialect>()
+                             ->getExtraData<transform::PatternRegistry>();
+  for (Attribute attr : getPatterns())
+    registry.populatePatterns(attr.cast<StringAttr>(), patterns);
+
+  // Configure the GreedyPatternRewriteDriver. The listener updates the payload
+  // mapping on replacements and records replacements it could not track.
+  ErrorCheckingTrackingListener listener(state, *this);
+  GreedyRewriteConfig config;
+  config.listener = &listener;
+
+  // Manually gather the list of ops because the other
+  // GreedyPatternRewriteDriver overloads only accept ops that are isolated from
+  // above. This way, patterns can be applied to ops that are not isolated from
+  // above. The target itself is excluded: patterns apply only to its body.
+  SmallVector<Operation *> ops;
+  target->walk([&](Operation *nestedOp) {
+    if (target != nestedOp)
+      ops.push_back(nestedOp);
+  });
+  LogicalResult result =
+      applyOpPatternsAndFold(ops, std::move(patterns), config);
+  // A failure typically indicates that the pattern application did not
+  // converge.
+  if (failed(result)) {
+    // Tracking errors may have been recorded before the driver gave up. The
+    // listener asserts in its destructor if its error state was not consumed,
+    // so drop them here; the convergence failure is reported instead.
+    if (listener.failed()) {
+      DiagnosedSilenceableFailure trackingStatus =
+          listener.checkAndResetError();
+      (void)trackingStatus.silence();
+    }
+    return emitSilenceableFailure(target)
+           << "greedy pattern application failed";
+  }
+
+  // Check listener state for tracking errors.
+  if (listener.failed()) {
+    DiagnosedSilenceableFailure status = listener.checkAndResetError();
+    if (getFailOnPayloadReplacementNotFound())
+      return status;
+    (void)status.silence();
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::ApplyPatternsOp::verify() {
+  const transform::PatternRegistry &patternRegistry =
+      getContext()
+          ->getLoadedDialect<transform::TransformDialect>()
+          ->getExtraData<transform::PatternRegistry>();
+  // Every entry must be a string that names a registered pattern set.
+  for (Attribute element : getPatterns()) {
+    StringAttr identifier = element.dyn_cast<StringAttr>();
+    if (!identifier) {
+      return emitOpError() << "expected " << getPatternsAttrName()
+                           << " to be an array of strings";
+    }
+    if (patternRegistry.hasPatterns(identifier))
+      continue;
+    return emitOpError() << "patterns not registered: " << identifier.strref();
+  }
+  return success();
+}
+
+void transform::ApplyPatternsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The target handle is read but not consumed, while the payload IR is
+  // rewritten in place.
+  onlyReadsHandle(getTarget(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 74c101f79075a..6436c7d860c37 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -672,3 +672,19 @@ module attributes { transform.with_named_sequence } {
       @match -> @action : (!transform.any_op) -> !transform.any_op
   }
 }
+
+// -----
+
+// The verifier rejects identifiers that are not in the PatternRegistry.
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{patterns not registered: transform.invalid_pattern_identifier}}
+  transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 : !transform.any_op
+}
+
+// -----
+
+// The verifier rejects non-string entries in the patterns array.
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected "patterns" to be an array of strings}}
+  transform.apply_patterns [3, 9] to %arg0 : !transform.any_op
+}

diff  --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
new file mode 100644
index 0000000000000..0df76d808f880
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -0,0 +1,123 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+// Tests for transform.apply_patterns using the "transform.test" patterns: ops
+// with a "replace_with_new_op" attribute are replaced by an op of that name,
+// and "test.erase_op" ops are erased.
+
+// The replacement op has the same name as the original op, so the handle %1
+// is remapped to the new op and can still be annotated.
+// CHECK-LABEL: func @update_tracked_op_mapping()
+//       CHECK:   "test.container"() ({
+//       CHECK:     %0 = "test.foo"() {annotated} : () -> i32
+//       CHECK:   }) : () -> ()
+func.func @update_tracked_op_mapping() {
+  "test.container"() ({
+    %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+  // Add an attribute to %1, which is now mapped to a new op.
+  transform.annotate %1 "annotated" : !transform.any_op
+}
+
+// -----
+
+// The replacement op has a different name, so no replacement payload op is
+// found for the still-used handle %1 and apply_patterns fails.
+func.func @replacement_op_not_found() {
+  "test.container"() ({
+    // expected-note @below {{[0] replaced op}}
+    // expected-note @below {{[0] replacement value 0}}
+    %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // expected-error @below {{tracking listener failed to find replacement op}}
+  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+  // %1 must be used in some way. If no replacement payload op could be found,
+  // an error is thrown only if the handle is not dead.
+  transform.annotate %1 "annotated" : !transform.any_op
+}
+
+// -----
+
+// Same rewrite as above, but %1 is never used afterwards, so the missing
+// replacement is not reported and the rewrite is kept.
+// CHECK-LABEL: func @replacement_op_for_dead_handle_not_found()
+//       CHECK:   "test.container"() ({
+//       CHECK:     %0 = "test.bar"() : () -> i32
+//       CHECK:   }) : () -> ()
+func.func @replacement_op_for_dead_handle_not_found() {
+  "test.container"() ({
+    %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  // No error because %1 is dead.
+  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+}
+
+// -----
+
+// With fail_on_payload_replacement_not_found = false, the tracking failure is
+// silenced and the rewrite is kept.
+// CHECK-LABEL: func @replacement_op_not_found_silenced()
+//       CHECK:   "test.container"() ({
+//       CHECK:     %0 = "test.bar"() : () -> i32
+//       CHECK:   }) : () -> ()
+func.func @replacement_op_not_found_silenced() {
+  "test.container"() ({
+    %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 {fail_on_payload_replacement_not_found = false}: !transform.any_op
+  transform.annotate %1 "annotated" : !transform.any_op
+}
+
+// -----
+
+// The targeted op itself is not rewritten; only ops nested in its body are.
+// CHECK-LABEL: func @patterns_apply_only_to_target_body()
+//       CHECK:   %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> i32
+func.func @patterns_apply_only_to_target_body() {
+  %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+}
+
+// -----
+
+// Erasing a tracked op removes it from the handle %1, so the second remark is
+// not emitted.
+// CHECK-LABEL: func @erase_tracked_op()
+//       CHECK:   "test.container"() ({
+//  CHECK-NEXT:   ^bb0:
+//  CHECK-NEXT:   }) : () -> ()
+func.func @erase_tracked_op() {
+  "test.container"() ({
+    // expected-remark @below {{matched op}}
+    %0 = "test.erase_op"() {replace_with_new_op = "test.foo"} : () -> (i32)
+  }) : () -> ()
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op
+  transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
+  transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
+}

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index f3b6c19d90b16..9af4c53cb1c86 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -746,6 +746,41 @@ mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results,
 }
 
 namespace {
+// Test pattern that replaces any op carrying a "replace_with_new_op" string
+// attribute with a new op of that name, reusing the operands and result
+// types. The new op is created without attributes, so the pattern does not
+// match it again.
+class ReplaceWithNewOp : public RewritePattern {
+public:
+  ReplaceWithNewOp(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    StringAttr targetName =
+        op->getAttrOfType<StringAttr>("replace_with_new_op");
+    if (!targetName)
+      return failure();
+    OperationName targetOpName(targetName, op->getContext());
+    Operation *replacement =
+        rewriter.create(op->getLoc(), targetOpName.getIdentifier(),
+                        op->getOperands(), op->getResultTypes());
+    rewriter.replaceOp(op, replacement->getResults());
+    return success();
+  }
+};
+
+// Test pattern that unconditionally erases every "test.erase_op" op.
+class EraseOp : public RewritePattern {
+public:
+  EraseOp(MLIRContext *context)
+      : RewritePattern("test.erase_op", /*benefit=*/1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+// Populates `patterns` with the test patterns that the extension registers
+// under "transform.test".
+void populateTestPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ReplaceWithNewOp, EraseOp>(ctx);
+}
+
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
 /// types for operands and results.
@@ -783,6 +818,11 @@ class TestTransformDialectExtension
           constraints.try_emplace("verbose_constraint", verboseConstraint);
           hooks.mergeInPDLMatchHooks(std::move(constraints));
         });
+
+    addDialectDataInitializer<transform::PatternRegistry>(
+        [&](transform::PatternRegistry &registry) {
+          registry.registerPatterns("transform.test", populateTestPatterns);
+        });
   }
 };
 } // namespace


        


More information about the Mlir-commits mailing list