[Mlir-commits] [mlir] 94d608d - [mlir] move PDL-related transform ops into an extension

Alex Zinenko llvmlistbot at llvm.org
Wed May 24 05:25:13 PDT 2023


Author: Alex Zinenko
Date: 2023-05-24T12:25:06Z
New Revision: 94d608d410267db693aa85070263e2b4ef0be913

URL: https://github.com/llvm/llvm-project/commit/94d608d410267db693aa85070263e2b4ef0be913
DIFF: https://github.com/llvm/llvm-project/commit/94d608d410267db693aa85070263e2b4ef0be913.diff

LOG: [mlir] move PDL-related transform ops into an extension

The initial bring-up of the Transform dialect relied on PDL to provide
the default handle type (`!pdl.operation`) and the matching capability.
Both are now provided natively by the Transform dialect, removing the
reason to have a hard dependency on the PDL dialect and its interpreter.
Move PDL-related transform operations into a separate extension.

This requires us to introduce a dialect state extension mechanism into
the Transform dialect so it no longer needs to know about PDL constraint
functions that may be injected by extensions similarly to operations and
types. This mechanism will be reused to connect pattern application
drivers and the Transform dialect.

This completes the restructuring of the Transform dialect to remove
overreliance on PDL.

Note to downstreams: flows that use `!pdl.operation` with Transform
dialect operations now require `transform::PDLExtension` to be
applied to the Transform dialect in order to provide the transform
handle type interface for `!pdl.operation`.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D151104

Added: 
    mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt
    mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
    mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
    mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
    mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt
    mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp
    mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
    mlir/python/mlir/dialects/TransformPDLExtensionOps.td
    mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
    mlir/python/mlir/dialects/transform/pdl.py
    mlir/test/Dialect/Transform/test-pdl-extension.mlir

Modified: 
    mlir/include/mlir/Dialect/Transform/CMakeLists.txt
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/Transform/CMakeLists.txt
    mlir/lib/Dialect/Transform/IR/CMakeLists.txt
    mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/python/CMakeLists.txt
    mlir/python/mlir/dialects/_transform_ops_ext.py
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/CMakeLists.txt
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/python/dialects/transform.py
    mlir/test/python/dialects/transform_structured_ext.py
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
index 9f57627c321fb..d9fbaee802398 100644
--- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index 36712add2eb05..e156602ea886b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -12,12 +12,52 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/TypeID.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringMap.h"
 #include <optional>
 
 namespace mlir {
 namespace transform {
+
+namespace detail {
+/// Concrete base class for CRTP TransformDialectDataBase. Must not be used
+/// directly.
+class TransformDialectDataBase {
+public:
+  virtual ~TransformDialectDataBase() = default;
+
+  /// Returns the dynamic type ID of the subclass.
+  TypeID getTypeID() const { return typeID; }
+
+protected:
+  /// Must be called by the subclass with the appropriate type ID.
+  explicit TransformDialectDataBase(TypeID typeID) : typeID(typeID) {}
+
+private:
+  /// The type ID of the subclass.
+  const TypeID typeID;
+};
+} // namespace detail
+
+/// Base class for additional data owned by the Transform dialect. Extensions
+/// may communicate with each other using this data. The data object is
+/// identified by the TypeID of the specific data subclass, querying the data of
+/// the same subclass returns a reference to the same object. When a Transform
+/// dialect extension is initialized, it can populate the data in the specific
+/// subclass. When a Transform op is applied, it can read (but not mutate) the
+/// data in the specific subclass, including the data provided by other
+/// extensions.
+///
+/// This follows CRTP: derived classes must list themselves as template
+/// argument.
+template <typename DerivedTy>
+class TransformDialectData : public detail::TransformDialectDataBase {
+protected:
+  /// Forward the TypeID of the derived class to the base.
+  TransformDialectData() : TransformDialectDataBase(TypeID::get<DerivedTy>()) {}
+};
+
 #ifndef NDEBUG
 namespace detail {
 /// Asserts that the operations provided as template arguments implement the
@@ -85,9 +125,8 @@ class TransformDialectExtension
       for (const DialectLoader &loader : generatedDialectLoaders)
         loader(context);
 
-    for (const Initializer &init : opInitializers)
+    for (const Initializer &init : initializers)
       init(transformDialect);
-    transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns));
   }
 
 protected:
@@ -100,6 +139,41 @@ class TransformDialectExtension
     static_cast<DerivedTy *>(this)->init();
   }
 
+  /// Registers a custom initialization step to be performed when the extension
+  /// is applied to the dialect while loading. This is discouraged in favor of
+  /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer`
+  /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It
+  /// will be called during the extension initialization and given the current
+  /// MLIR context. This may be used to attach additional interfaces that cannot
+  /// be attached elsewhere.
+  template <typename Func>
+  void addCustomInitializationStep(Func &&func) {
+    std::function<void(MLIRContext *)> initializer = func;
+    dialectLoaders.push_back(
+        [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); });
+  }
+
+  /// Registers the given function as one of the initializers for the
+  /// dialect-owned data of the kind specified as template argument. The
+  /// function must be convertible to the `void (DataTy &)` form. It will be
+  /// called during the extension initialization and will be given a mutable
+  /// reference to `DataTy`. The callback is expected to append data to the
+  /// given storage, and is not allowed to remove or destructively mutate the
+  /// existing data. The order in which callbacks from different extensions are
+  /// executed is unspecified so the callbacks may not rely on data being
+  /// already present. `DataTy` must be a class deriving `TransformDialectData`.
+  template <typename DataTy, typename Func>
+  void addDialectDataInitializer(Func &&func) {
+    static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
+                  "only classes deriving TransformDialectData are accepted");
+
+    std::function<void(DataTy &)> initializer = func;
+    initializers.push_back(
+        [init = std::move(initializer)](TransformDialect *transformDialect) {
+          init(transformDialect->getOrCreateExtraData<DataTy>());
+        });
+  }
+
   /// Hook for derived classes to inject constructor behavior.
   void init() {}
 
@@ -108,7 +182,7 @@ class TransformDialectExtension
   /// implementations must be already available when the operation is injected.
   template <typename... OpTys>
   void registerTransformOps() {
-    opInitializers.push_back([](TransformDialect *transformDialect) {
+    initializers.push_back([](TransformDialect *transformDialect) {
       transformDialect->addOperationsChecked<OpTys...>();
     });
   }
@@ -120,7 +194,7 @@ class TransformDialectExtension
   /// `StringRef` that is unique across all injected types.
   template <typename... TypeTys>
   void registerTypes() {
-    opInitializers.push_back([](TransformDialect *transformDialect) {
+    initializers.push_back([](TransformDialect *transformDialect) {
       transformDialect->addTypesChecked<TypeTys...>();
     });
   }
@@ -151,22 +225,10 @@ class TransformDialectExtension
         [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
   }
 
-  /// Injects the named constraint to make it available for use with the
-  /// PDLMatchOp in the transform dialect.
-  void registerPDLMatchConstraintFn(StringRef name,
-                                    PDLConstraintFunction &&fn) {
-    pdlMatchConstraintFns.try_emplace(name,
-                                      std::forward<PDLConstraintFunction>(fn));
-  }
-  template <typename ConstraintFnTy>
-  void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) {
-    pdlMatchConstraintFns.try_emplace(
-        name, ::mlir::detail::pdl_function_builder::buildConstraintFn(
-                  std::forward<ConstraintFnTy>(fn)));
-  }
-
 private:
-  SmallVector<Initializer> opInitializers;
+  /// Callbacks performing extension initialization, e.g., registering ops,
+  /// types and defining the additional data.
+  SmallVector<Initializer> initializers;
 
   /// Callbacks loading the dependent dialects, i.e. the dialect needed for the
   /// extension ops.
@@ -176,13 +238,6 @@ class TransformDialectExtension
   /// applying the transformations.
   SmallVector<DialectLoader> generatedDialectLoaders;
 
-  /// A list of constraints that should be made available to PDL patterns
-  /// processed by PDLMatchOp in the Transform dialect.
-  ///
-  /// Declared as mutable so its contents can be moved in the `apply` const
-  /// method, which is only called once.
-  mutable llvm::StringMap<PDLConstraintFunction> pdlMatchConstraintFns;
-
   /// Indicates that the extension is in build-only mode.
   bool buildOnly;
 };
@@ -232,6 +287,17 @@ void TransformDialect::addTypeIfNotRegistered() {
 #endif // NDEBUG
 }
 
+template <typename DataTy>
+DataTy &TransformDialect::getOrCreateExtraData() {
+  TypeID typeID = TypeID::get<DataTy>();
+  auto it = extraData.find(typeID);
+  if (it != extraData.end())
+    return static_cast<DataTy &>(*it->getSecond());
+
+  auto emplaced = extraData.try_emplace(typeID, std::make_unique<DataTy>());
+  return static_cast<DataTy &>(*emplaced.first->getSecond());
+}
+
 /// A wrapper for transform dialect extensions that forces them to be
 /// constructed in the build-only mode.
 template <typename DerivedTy>

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index 160f1ff6ec627..0539187256dde 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -18,36 +18,31 @@ def Transform_Dialect : Dialect {
   let name = "transform";
   let cppNamespace = "::mlir::transform";
 
-  let dependentDialects = [
-    "::mlir::pdl::PDLDialect",
-    "::mlir::pdl_interp::PDLInterpDialect",
-  ];
-
   let hasOperationAttrVerify = 1;
   let usePropertiesForAttributes = 1;
 
   let extraClassDeclaration = [{
       /// Name of the attribute attachable to the symbol table operation
       /// containing named sequences. This is used to trigger verification.
-      constexpr const static llvm::StringLiteral
+      constexpr const static ::llvm::StringLiteral
           kWithNamedSequenceAttrName = "transform.with_named_sequence";
 
       /// Names of the attribute attachable to an operation so it can be
       /// identified as root by the default interpreter pass.
-      constexpr const static llvm::StringLiteral
+      constexpr const static ::llvm::StringLiteral
           kTargetTagAttrName = "transform.target_tag";
 
       /// Names of the attributes indicating whether an argument of an external
       /// transform dialect symbol is consumed or only read.
-      constexpr const static llvm::StringLiteral
+      constexpr const static ::llvm::StringLiteral
           kArgConsumedAttrName = "transform.consumed";
-      constexpr const static llvm::StringLiteral
+      constexpr const static ::llvm::StringLiteral
           kArgReadOnlyAttrName = "transform.readonly";
 
-      /// Returns the named PDL constraint functions available in the dialect
-      /// as a map from their name to the function.
-      const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
-      getPDLConstraintHooks() const;
+      template <typename DataTy>
+      const DataTy &getExtraData() const {
+        return *static_cast<const DataTy *>(extraData.at(::mlir::TypeID::get<DataTy>()).get());
+      }
 
       /// Parses a type registered by this dialect or one of its extensions.
       ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
@@ -92,23 +87,27 @@ def Transform_Dialect : Dialect {
       /// mnemonic.
       [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
 
+      /// Registers dialect types with the context.
       void initializeTypes();
 
+      // Give extensions access to injection functions.
       template <typename, typename...>
       friend class TransformDialectExtension;
 
-      /// Takes ownership of the named PDL constraint function from the given
-      /// map and makes them available for use by the operations in the dialect.
-      void mergeInPDLMatchHooks(
-          ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns);
+      /// Gets a mutable reference to extra data of the kind specified as
+      /// template argument. Allocates the data on the first call.
+      template <typename DataTy>
+      DataTy &getOrCreateExtraData();
 
       //===----------------------------------------------------------------===//
       // Data fields
       //===----------------------------------------------------------------===//
 
-      /// A container for PDL constraint function that can be used by
-      /// operations in this dialect.
-      ::mlir::PDLPatternModule pdlMatchHooks;
+      /// Additional data associated with and owned by the dialect. Accessible
+      /// to extensions.
+      ::llvm::DenseMap<::mlir::TypeID, std::unique_ptr<
+            ::mlir::transform::detail::TransformDialectDataBase>>
+          extraData;
 
       /// A map from type mnemonic to its parsing function for the remainder of
       /// the syntax. The parser has access to the mnemonic, so it is used for

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 6730552c9c53a..77d0c7d99ce6a 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -38,6 +38,14 @@ mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
 /// Verification hook for PossibleTopLevelTransformOpTrait.
 LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
 
+/// Populates `effects` with side effects implied by
+/// PossibleTopLevelTransformOpTrait for the given operation. The operation may
+/// have an optional `root` operand, indicating it is not in fact top-level. It
+/// is also expected to have a single-block body.
+void getPotentialTopLevelEffects(
+    Operation *operation, Value root, Block &body,
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
+
 /// Verification hook for TransformOpInterface.
 LogicalResult verifyTransformOpInterface(Operation *op);
 
@@ -753,15 +761,16 @@ TransformState::make_isolated_region_scope(Region &region) {
 /// can be standalone top-level transforms. Such operations typically contain
 /// other Transform dialect operations that can be executed following some
 /// control flow logic specific to the current operation. The operations with
-/// this trait are expected to have at least one single-block region with one
-/// argument of PDL Operation type. The operations are also expected to be valid
-/// without operands, in which case they are considered top-level, and with one
-/// or more arguments, in which case they are considered nested. Top-level
-/// operations have the block argument of the entry block in the Transform IR
-/// correspond to the root operation of Payload IR. Nested operations have the
-/// block argument of the entry block in the Transform IR correspond to a list
-/// of Payload IR operations mapped to the first operand of the Transform IR
-/// operation. The operation must implement TransformOpInterface.
+/// this trait are expected to have at least one single-block region with at
+/// least one argument of type implementing TransformHandleTypeInterface. The
+/// operations are also expected to be valid without operands, in which case
+/// they are considered top-level, and with one or more arguments, in which case
+/// they are considered nested. Top-level operations have the block argument of
+/// the entry block in the Transform IR correspond to the root operation of
+/// Payload IR. Nested operations have the block argument of the entry block in
+/// the Transform IR correspond to a list of Payload IR operations mapped to the
+/// first operand of the Transform IR operation. The operation must implement
+/// TransformOpInterface.
 template <typename OpTy>
 class PossibleTopLevelTransformOpTrait
     : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
@@ -777,6 +786,14 @@ class PossibleTopLevelTransformOpTrait
     return &this->getOperation()->getRegion(region).front();
   }
 
+  /// Populates `effects` with side effects implied by this trait.
+  void getPotentialTopLevelEffects(
+      SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+    detail::getPotentialTopLevelEffects(
+        this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(),
+        *getBodyBlock(), effects);
+  }
+
   /// Sets up the mapping between the entry block of the given region of this op
   /// and the relevant list of Payload IR operations in the given state. The
   /// state is expected to be already scoped at the region of this operation.

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index 7a0f80200cc47..543eba9df7ab2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
 
-#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index aa88e49511062..a313d285492d7 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -575,37 +575,6 @@ def ParamConstantOp : Op<Transform_Dialect, "param.constant", [
   let assemblyFormat = "$value attr-dict `->` type($param)";
 }
 
-def PDLMatchOp : TransformDialectOp<"pdl_match",
-    [DeclareOpInterfaceMethods<TransformOpInterface>,
-     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
-  let summary = "Finds ops that match the named PDL pattern";
-  let description = [{
-    Find Payload IR ops nested within the Payload IR op associated with the
-    operand that match the PDL pattern identified by its name. The pattern is
-    expected to be defined in the closest surrounding `WithPDLPatternsOp`.
-
-    Produces a Transform IR value associated with the list of Payload IR ops
-    that matched the pattern. The order of results in the list is that of the
-    Operation::walk, clients are advised not to rely on a specific order though.
-    If the operand is associated with multiple Payload IR ops, finds matching
-    ops nested within each of those and produces a single list containing all
-    of the matched ops.
-
-    The transformation is considered successful regardless of whether some
-    Payload IR ops actually matched the pattern and only fails if the pattern
-    could not be looked up or compiled.
-  }];
-
-  let arguments = (ins
-    Arg<TransformHandleTypeInterface, "Payload IR scope to match within">:$root,
-    SymbolRefAttr:$pattern_name);
-  let results = (outs
-    Res<TransformHandleTypeInterface, "Handle to the matched Payload IR ops">:$matched);
-
-  let assemblyFormat = "$pattern_name `in` $root attr-dict `:` "
-                       "functional-type(operands, results)";
-}
-
 def PrintOp : TransformDialectOp<"print",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -753,61 +722,6 @@ def SequenceOp : TransformDialectOp<"sequence",
   let hasVerifier = 1;
 }
 
-def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
-    [DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
-     OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
-     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-     SymbolTable]> {
-  let summary = "Contains PDL patterns available for use in transforms";
-  let description = [{
-    This op contains a set of named PDL patterns that are available for the
-    Transform dialect operations to be used for pattern matching. For example,
-    PDLMatchOp can be used to produce a Transform IR value associated with all
-    Payload IR operations that match the pattern as follows:
-
-    ```mlir
-    transform.with_pdl_patterns {
-    ^bb0(%arg0: !transform.any_op):
-      pdl.pattern @my_pattern : benefit(1) {
-        %0 = pdl.operation //...
-        // Regular PDL goes here.
-        pdl.rewrite %0 with "transform.dialect"
-      }
-
-      sequence %arg0 failures(propagate) {
-      ^bb0(%arg1: !transform.any_op):
-        %1 = pdl_match @my_pattern in %arg1
-        // Use %1 as handle
-      }
-    }
-    ```
-
-    Note that the pattern is expected to finish with a `pdl.rewrite` terminator
-    that points to the custom rewriter named "transform.dialect". The rewriter
-    actually does nothing, but the transform application will keep track of the
-    operations that matched the pattern.
-
-    This op is expected to contain `pdl.pattern` operations and exactly one
-    another Transform dialect operation that gets executed with all patterns
-    available. This op is a possible top-level Transform IR op, the argument of
-    its entry block corresponds to either the root op of the payload IR or the
-    ops associated with its operand when provided.
-  }];
-
-  let arguments = (ins
-    Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR"
-        >:$root);
-  let regions = (region SizedRegion<1>:$body);
-  let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
-
-  let hasVerifier = 1;
-
-  let extraClassDeclaration = [{
-    /// Allow the dialect prefix to be omitted.
-    static StringRef getDefaultDialect() { return "transform"; }
-  }];
-}
-
 def YieldOp : TransformDialectOp<"yield",
     [Terminator, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Yields operation handles from a transform IR region";

diff  --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..6af6b838f266f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS PDLExtensionOps.td)
+mlir_tablegen(PDLExtensionOps.h.inc -gen-op-decls)
+mlir_tablegen(PDLExtensionOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTransformDialectPDLExtensionOpsIncGen)
+
+add_mlir_doc(PDLExtensionOps PDLExtensionOps Dialects/ -gen-op-doc)

diff  --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
new file mode 100644
index 0000000000000..08915213cd22c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h
@@ -0,0 +1,21 @@
+//===- PDLExtension.h - PDL extension for Transform dialect -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
+#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace transform {
+/// Registers the PDL extension of the Transform dialect in the given registry.
+void registerPDLExtension(DialectRegistry &dialectRegistry);
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H

diff  --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
new file mode 100644
index 0000000000000..a159c30df86d3
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h
@@ -0,0 +1,49 @@
+//===- PDLExtensionOps.h - PDL extension for Transform dialect --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H
+#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc"
+
+namespace mlir {
+namespace transform {
+/// PDL constraint callbacks that can be used by the PDL extension of the
+/// Transform dialect. These are owned by the Transform dialect and can be
+/// populated by extensions.
+class PDLMatchHooks : public TransformDialectData<PDLMatchHooks> {
+public:
+  /// Takes ownership of the named PDL constraint function from the given
+  /// map and makes them available for use by the operations in the dialect.
+  void
+  mergeInPDLMatchHooks(llvm::StringMap<PDLConstraintFunction> &&constraintFns);
+
+  /// Returns the named PDL constraint functions available in the dialect
+  /// as a map from their name to the function.
+  const llvm::StringMap<::mlir::PDLConstraintFunction> &
+  getPDLConstraintHooks() const;
+
+private:
+  /// A container for PDL constraint function that can be used by
+  /// operations in this dialect.
+  PDLPatternModule pdlMatchHooks;
+};
+} // namespace transform
+} // namespace mlir
+
+MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
+
+#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H

diff  --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
new file mode 100644
index 0000000000000..16107b3d0869f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td
@@ -0,0 +1,104 @@
+//===- PDLExtensionOps.td - Transform PDL extension ops ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS
+#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+def PDLMatchOp : TransformDialectOp<"pdl_match",
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let summary = "Finds ops that match the named PDL pattern";
+  let description = [{
+    Find Payload IR ops nested within the Payload IR op associated with the
+    operand that match the PDL pattern identified by its name. The pattern is
+    expected to be defined in the closest surrounding `WithPDLPatternsOp`.
+
+    Produces a Transform IR value associated with the list of Payload IR ops
+    that matched the pattern. The order of results in the list is that of the
+    Operation::walk, clients are advised not to rely on a specific order though.
+    If the operand is associated with multiple Payload IR ops, finds matching
+    ops nested within each of those and produces a single list containing all
+    of the matched ops.
+
+    The transformation is considered successful regardless of whether some
+    Payload IR ops actually matched the pattern and only fails if the pattern
+    could not be looked up or compiled.
+  }];
+
+  let arguments = (ins
+    Arg<TransformHandleTypeInterface, "Payload IR scope to match within">:$root,
+    SymbolRefAttr:$pattern_name);
+  let results = (outs
+    Res<TransformHandleTypeInterface, "Handle to the matched Payload IR ops">:$matched);
+
+  let assemblyFormat = "$pattern_name `in` $root attr-dict `:` "
+                       "functional-type(operands, results)";
+}
+
+def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
+    [DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
+     OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     SymbolTable]> {
+  let summary = "Contains PDL patterns available for use in transforms";
+  let description = [{
+    This op contains a set of named PDL patterns that are available for the
+    Transform dialect operations to be used for pattern matching. For example,
+    PDLMatchOp can be used to produce a Transform IR value associated with all
+    Payload IR operations that match the pattern as follows:
+
+    ```mlir
+    transform.with_pdl_patterns {
+    ^bb0(%arg0: !transform.any_op):
+      pdl.pattern @my_pattern : benefit(1) {
+        %0 = pdl.operation //...
+        // Regular PDL goes here.
+        pdl.rewrite %0 with "transform.dialect"
+      }
+
+      sequence %arg0 failures(propagate) {
+      ^bb0(%arg1: !transform.any_op):
+        %1 = pdl_match @my_pattern in %arg1
+        // Use %1 as handle
+      }
+    }
+    ```
+
+    Note that the pattern is expected to finish with a `pdl.rewrite` terminator
+    that points to the custom rewriter named "transform.dialect". The rewriter
+    actually does nothing, but the transform application will keep track of the
+    operations that matched the pattern.
+
+    This op is expected to contain `pdl.pattern` operations and exactly one
+    another Transform dialect operation that gets executed with all patterns
+    available. This op is a possible top-level Transform IR op, the argument of
+    its entry block corresponds to either the root op of the payload IR or the
+    ops associated with its operand when provided.
+  }];
+
+  let arguments = (ins
+    Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR"
+        >:$root);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Allow the dialect prefix to be omitted.
+    static StringRef getDefaultDialect() { return "transform"; }
+  }];
+}
+
+#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index b00de3f0a2002..e307b236b39a5 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -76,6 +76,7 @@
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
@@ -135,6 +136,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   memref::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
   tensor::registerTransformDialectExtension(registry);
+  transform::registerPDLExtension(registry);
   vector::registerTransformDialectExtension(registry);
 
   // Register all external models.

diff  --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt
index 31167e6af908b..9e144eba25710 100644
--- a/mlir/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
+add_subdirectory(PDLExtension)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)

diff  --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index 2fed20f927380..4fb27512c4907 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -14,8 +14,6 @@ add_mlir_dialect_library(MLIRTransformDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRParser
-  MLIRPDLDialect
-  MLIRPDLInterpDialect
   MLIRRewrite
   MLIRSideEffectInterfaces
   MLIRTransforms

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index 6780c3bad9685..d0759941f1ad3 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -8,8 +8,6 @@
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Analysis/CallGraph.h"
-#include "mlir/Dialect/PDL/IR/PDL.h"
-#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
@@ -51,18 +49,6 @@ void transform::detail::checkImplementsTransformHandleTypeInterface(
 }
 #endif // NDEBUG
 
-namespace {
-struct PDLOperationTypeTransformHandleTypeInterfaceImpl
-    : public transform::TransformHandleTypeInterface::ExternalModel<
-          PDLOperationTypeTransformHandleTypeInterfaceImpl,
-          pdl::OperationType> {
-  DiagnosedSilenceableFailure
-  checkPayload(Type type, Location loc, ArrayRef<Operation *> payload) const {
-    return DiagnosedSilenceableFailure::success();
-  }
-};
-} // namespace
-
 void transform::TransformDialect::initialize() {
   // Using the checked versions to enable the same assertions as for the ops
   // from extensions.
@@ -71,21 +57,6 @@ void transform::TransformDialect::initialize() {
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();
   initializeTypes();
-
-  pdl::OperationType::attachInterface<
-      PDLOperationTypeTransformHandleTypeInterfaceImpl>(*getContext());
-}
-
-void transform::TransformDialect::mergeInPDLMatchHooks(
-    llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
-  // Steal the constraint functions from the given map.
-  for (auto &it : constraintFns)
-    pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
-}
-
-const llvm::StringMap<PDLConstraintFunction> &
-transform::TransformDialect::getPDLConstraintHooks() const {
-  return pdlMatchHooks.getConstraintFunctions();
 }
 
 Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 5685187e853f5..37caa60edadaf 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1242,6 +1242,61 @@ void transform::detail::forwardTerminatorOperands(
 // Utilities for PossibleTopLevelTransformOpTrait.
 //===----------------------------------------------------------------------===//
 
+/// Appends to `effects` one memory effect instance on `target` for every
+/// effect that the operation `iface` has on `source`, keeping the effect kind
+/// and the resource of each instance unchanged.
+static void
+remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
+             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // Query into a fresh vector so that only the effects reported for `source`
+  // are re-associated with `target`.
+  SmallVector<MemoryEffects::EffectInstance> onSource;
+  iface.getEffectsOnValue(source, onSource);
+  for (const MemoryEffects::EffectInstance &inst : onSource) {
+    effects.emplace_back(inst.getEffect(), target, inst.getResource());
+  }
+}
+
+/// Appends to `effects`, for each operation in `block` implementing
+/// MemoryEffectOpInterface:
+///   - its effects on the block arguments, re-associated with the values of
+///     `operands` paired positionally with those arguments;
+///   - its effects on the Payload IR resource, forwarded as-is.
+static void
+remapArgumentEffects(Block &block, ValueRange operands,
+                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  for (Operation &nestedOp : block) {
+    auto iface = dyn_cast<MemoryEffectOpInterface>(&nestedOp);
+    if (!iface)
+      continue;
+
+    // `llvm::zip` stops at the end of the shorter of the two ranges.
+    for (auto [arg, operand] : llvm::zip(block.getArguments(), operands))
+      remapEffects(iface, arg, operand, effects);
+
+    // Collect into a separate vector first, then append, so that the query
+    // only sees this op's effects.
+    SmallVector<MemoryEffects::EffectInstance> payloadEffects;
+    iface.getEffectsOnResource(transform::PayloadIRResource::get(),
+                               payloadEffects);
+    effects.append(payloadEffects.begin(), payloadEffects.end());
+  }
+}
+
+/// Populates `effects` with the Transform dialect memory effects of the
+/// potential top-level operation `operation`, whose body block is `body`.
+/// The operation only reads its operands and produces its results. When
+/// `root` is null, the effects of all ops in `body` are propagated unchanged.
+/// Otherwise, effects of nested ops on the entry block arguments are remapped
+/// to the corresponding operands, and their effects on the Payload IR
+/// resource are carried over.
+void transform::detail::getPotentialTopLevelEffects(
+    Operation *operation, Value root, Block &body,
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(operation->getOperands(), effects);
+  transform::producesHandle(operation->getResults(), effects);
+
+  if (!root) {
+    // Without a root operand there is nothing to remap: forward nested
+    // effects directly into `effects`.
+    for (Operation &op : body) {
+      if (auto iface = dyn_cast<MemoryEffectOpInterface>(&op))
+        iface.getEffects(effects);
+    }
+    return;
+  }
+
+  // Carry over all effects on arguments of the entry block as those on the
+  // operands, this is the same value just remapped.
+  remapArgumentEffects(body, operation->getOperands(), effects);
+}
+
 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
     TransformState &state, Operation *op, Region &region) {
   SmallVector<Operation *> targets;

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index ad001707ddd64..a3b55a45dd96e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
-#include "mlir/Dialect/PDL/IR/PDLOps.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -17,8 +16,6 @@
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
-#include "mlir/Rewrite/FrozenRewritePatternSet.h"
-#include "mlir/Rewrite/PatternApplicator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -52,99 +49,6 @@ static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
 
-//===----------------------------------------------------------------------===//
-// PatternApplicatorExtension
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// A TransformState extension that keeps track of compiled PDL pattern sets.
-/// This is intended to be used along the WithPDLPatterns op. The extension
-/// can be constructed given an operation that has a SymbolTable trait and
-/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
-/// by one when requested; this behavior is subject to change.
-class PatternApplicatorExtension : public transform::TransformState::Extension {
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
-
-  /// Creates the extension for patterns contained in `patternContainer`.
-  explicit PatternApplicatorExtension(transform::TransformState &state,
-                                      Operation *patternContainer)
-      : Extension(state), patterns(patternContainer) {}
-
-  /// Appends to `results` the operations contained in `root` that matched the
-  /// PDL pattern with the given name. Note that `root` may or may not be the
-  /// operation that contains PDL patterns. Reports an error if the pattern
-  /// cannot be found. Note that when no operations are matched, this still
-  /// succeeds as long as the pattern exists.
-  LogicalResult findAllMatches(StringRef patternName, Operation *root,
-                               SmallVectorImpl<Operation *> &results);
-
-private:
-  /// Map from the pattern name to a singleton set of rewrite patterns that only
-  /// contains the pattern with this name. Populated when the pattern is first
-  /// requested.
-  // TODO: reconsider the efficiency of this storage when more usage data is
-  // available. Storing individual patterns in a set and triggering compilation
-  // for each of them has overhead. So does compiling a large set of patterns
-  // only to apply a handlful of them.
-  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
-
-  /// A symbol table operation containing the relevant PDL patterns.
-  SymbolTable patterns;
-};
-
-LogicalResult PatternApplicatorExtension::findAllMatches(
-    StringRef patternName, Operation *root,
-    SmallVectorImpl<Operation *> &results) {
-  auto it = compiledPatterns.find(patternName);
-  if (it == compiledPatterns.end()) {
-    auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
-    if (!patternOp)
-      return failure();
-
-    // Copy the pattern operation into a new module that is compiled and
-    // consumed by the PDL interpreter.
-    OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
-    auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
-    builder.clone(*patternOp);
-    PDLPatternModule patternModule(std::move(pdlModuleOp));
-
-    // Merge in the hooks owned by the dialect. Make a copy as they may be
-    // also used by the following operations.
-    auto *dialect =
-        root->getContext()->getLoadedDialect<transform::TransformDialect>();
-    for (const auto &[name, constraintFn] : dialect->getPDLConstraintHooks())
-      patternModule.registerConstraintFunction(name, constraintFn);
-
-    // Register a noop rewriter because PDL requires patterns to end with some
-    // rewrite call.
-    patternModule.registerRewriteFunction(
-        "transform.dialect", [](PatternRewriter &, Operation *) {});
-
-    it = compiledPatterns
-             .try_emplace(patternOp.getName(), std::move(patternModule))
-             .first;
-  }
-
-  PatternApplicator applicator(it->second);
-  // We want to discourage direct use of PatternRewriter in APIs but In this
-  // very specific case, an IRRewriter is not enough.
-  struct TrivialPatternRewriter : public PatternRewriter {
-  public:
-    explicit TrivialPatternRewriter(MLIRContext *context)
-        : PatternRewriter(context) {}
-  };
-  TrivialPatternRewriter rewriter(root->getContext());
-  applicator.applyDefaultCostModel();
-  root->walk([&](Operation *op) {
-    if (succeeded(applicator.matchAndRewrite(op, rewriter)))
-      results.push_back(op);
-  });
-
-  return success();
-}
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // TrackingListener
 //===----------------------------------------------------------------------===//
@@ -420,10 +324,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   assert(outputs.size() == 1 && "expected one output");
   return llvm::all_of(
       std::initializer_list<Type>{inputs.front(), outputs.front()},
-      [](Type ty) {
-        return llvm::isa<pdl::OperationType,
-                         transform::TransformHandleTypeInterface>(ty);
-      });
+      [](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1031,38 +932,6 @@ transform::IncludeOp::apply(transform::TransformResults &results,
   return result;
 }
 
-/// Appends to `effects` the memory effect instances on `target` with the same
-/// resource and effect as the ones the operation `iface` having on `source`.
-static void
-remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
-             SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  SmallVector<MemoryEffects::EffectInstance> nestedEffects;
-  iface.getEffectsOnValue(source, nestedEffects);
-  for (const auto &effect : nestedEffects)
-    effects.emplace_back(effect.getEffect(), target, effect.getResource());
-}
-
-/// Appends to `effects` the same effects as the operations of `block` have on
-/// block arguments but associated with `operands.`
-static void
-remapArgumentEffects(Block &block, ValueRange operands,
-                     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  for (Operation &op : block) {
-    auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
-    if (!iface)
-      continue;
-
-    for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
-      remapEffects(iface, source, target, effects);
-    }
-
-    SmallVector<MemoryEffects::EffectInstance> nestedEffects;
-    iface.getEffectsOnResource(transform::PayloadIRResource::get(),
-                               nestedEffects);
-    llvm::append_range(effects, nestedEffects);
-  }
-}
-
 static DiagnosedSilenceableFailure
 verifyNamedSequenceOp(transform::NamedSequenceOp op);
 
@@ -1474,8 +1343,7 @@ LogicalResult transform::NamedSequenceOp::verify() {
 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
                                      Value target, int64_t numResultHandles) {
   result.addOperands(target);
-  auto pdlOpType = pdl::OperationType::get(builder.getContext());
-  result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
+  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
 }
 
 DiagnosedSilenceableFailure
@@ -1535,35 +1403,6 @@ LogicalResult transform::SplitHandleOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// PDLMatchOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::PDLMatchOp::apply(transform::TransformResults &results,
-                             transform::TransformState &state) {
-  auto *extension = state.getExtension<PatternApplicatorExtension>();
-  assert(extension &&
-         "expected PatternApplicatorExtension to be attached by the parent op");
-  SmallVector<Operation *> targets;
-  for (Operation *root : state.getPayloadOps(getRoot())) {
-    if (failed(extension->findAllMatches(
-            getPatternName().getLeafReference().getValue(), root, targets))) {
-      emitDefiniteFailure()
-          << "could not find pattern '" << getPatternName() << "'";
-    }
-  }
-  results.set(llvm::cast<OpResult>(getResult()), targets);
-  return DiagnosedSilenceableFailure::success();
-}
-
-void transform::PDLMatchOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  onlyReadsHandle(getRoot(), effects);
-  producesHandle(getMatched(), effects);
-  onlyReadsPayload(effects);
-}
-
 //===----------------------------------------------------------------------===//
 // ReplicateOp
 //===----------------------------------------------------------------------===//
@@ -1776,37 +1615,9 @@ LogicalResult transform::SequenceOp::verify() {
   return success();
 }
 
-/// Populate `effects` with transform dialect memory effects for the potential
-/// top-level operation. Such operations have recursive effects from nested
-/// operations. When they have an operand, we can additionally remap effects on
-/// the block argument to be effects on the operand.
-template <typename OpTy>
-static void getPotentialTopLevelEffects(
-    OpTy operation, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  transform::onlyReadsHandle(operation->getOperands(), effects);
-  transform::producesHandle(operation->getResults(), effects);
-
-  if (!operation.getRoot()) {
-    for (Operation &op : *operation.getBodyBlock()) {
-      auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
-      if (!iface)
-        continue;
-
-      SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
-      iface.getEffects(effects);
-    }
-    return;
-  }
-
-  // Carry over all effects on arguments of the entry block as those on the
-  // operands, this is the same value just remapped.
-  remapArgumentEffects(*operation.getBodyBlock(), operation->getOperands(),
-                       effects);
-}
-
 void transform::SequenceOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  getPotentialTopLevelEffects(*this, effects);
+  getPotentialTopLevelEffects(effects);
 }
 
 OperandRange transform::SequenceOp::getSuccessorEntryOperands(
@@ -1908,77 +1719,6 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
   buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
 }
 
-//===----------------------------------------------------------------------===//
-// WithPDLPatternsOp
-//===----------------------------------------------------------------------===//
-
-DiagnosedSilenceableFailure
-transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
-                                    transform::TransformState &state) {
-  TransformOpInterface transformOp = nullptr;
-  for (Operation &nested : getBody().front()) {
-    if (!isa<pdl::PatternOp>(nested)) {
-      transformOp = cast<TransformOpInterface>(nested);
-      break;
-    }
-  }
-
-  state.addExtension<PatternApplicatorExtension>(getOperation());
-  auto guard = llvm::make_scope_exit(
-      [&]() { state.removeExtension<PatternApplicatorExtension>(); });
-
-  auto scope = state.make_region_scope(getBody());
-  if (failed(mapBlockArguments(state)))
-    return DiagnosedSilenceableFailure::definiteFailure();
-  return state.applyTransform(transformOp);
-}
-
-void transform::WithPDLPatternsOp::getEffects(
-    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-  getPotentialTopLevelEffects(*this, effects);
-}
-
-LogicalResult transform::WithPDLPatternsOp::verify() {
-  Block *body = getBodyBlock();
-  Operation *topLevelOp = nullptr;
-  for (Operation &op : body->getOperations()) {
-    if (isa<pdl::PatternOp>(op))
-      continue;
-
-    if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
-      if (topLevelOp) {
-        InFlightDiagnostic diag =
-            emitOpError() << "expects only one non-pattern op in its body";
-        diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
-        diag.attachNote(op.getLoc()) << "second non-pattern op";
-        return diag;
-      }
-      topLevelOp = &op;
-      continue;
-    }
-
-    InFlightDiagnostic diag =
-        emitOpError()
-        << "expects only pattern and top-level transform ops in its body";
-    diag.attachNote(op.getLoc()) << "offending op";
-    return diag;
-  }
-
-  if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
-    InFlightDiagnostic diag = emitOpError() << "cannot be nested";
-    diag.attachNote(parent.getLoc()) << "parent operation";
-    return diag;
-  }
-
-  if (!topLevelOp) {
-    InFlightDiagnostic diag = emitOpError()
-                              << "expects at least one non-pattern op";
-    return diag;
-  }
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // PrintOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt
new file mode 100644
index 0000000000000..4a60ed48a1343
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRTransformPDLExtension
+  PDLExtension.cpp
+  PDLExtensionOps.cpp
+
+  DEPENDS
+  MLIRTransformDialectPDLExtensionOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRTransformDialect
+  MLIRPDLDialect
+  MLIRPDLInterpDialect
+)

diff  --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp
new file mode 100644
index 0000000000000..2c770abd56d52
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp
@@ -0,0 +1,69 @@
+//===- PDLExtension.cpp - PDL extension for the Transform dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+namespace {
+/// External model attaching TransformHandleTypeInterface to
+/// `pdl::OperationType`, so that `!pdl.operation` can be used as a Transform
+/// dialect handle type without restricting the associated payload.
+struct PDLOperationTypeTransformHandleTypeInterfaceImpl
+    : public transform::TransformHandleTypeInterface::ExternalModel<
+          PDLOperationTypeTransformHandleTypeInterfaceImpl,
+          pdl::OperationType> {
+  /// Unconditionally succeeds: every list of payload operations is accepted.
+  DiagnosedSilenceableFailure
+  checkPayload(Type /*type*/, Location /*loc*/,
+               ArrayRef<Operation *> /*payload*/) const {
+    return DiagnosedSilenceableFailure::success();
+  }
+};
+} // namespace
+
+namespace {
+/// PDL extension of the Transform dialect. This provides transform operations
+/// that connect to PDL matching as well as interfaces for PDL types to be used
+/// with Transform dialect operations.
+class PDLExtension : public transform::TransformDialectExtension<PDLExtension> {
+public:
+  /// Declares everything this extension contributes to the Transform dialect:
+  /// its ops, the PDLMatchHooks dialect data, dialect dependencies, and the
+  /// attachment of the handle type interface to `pdl::OperationType`.
+  void init() {
+    // Register every op listed in the generated PDLExtensionOps op list.
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
+        >();
+
+    // Declare PDLMatchHooks as extra data of the Transform dialect. The
+    // initializer itself does nothing. NOTE(review): it appears to exist only
+    // so that the data object is created and can later be retrieved via
+    // `getExtraData<transform::PDLMatchHooks>()`; confirm.
+    addDialectDataInitializer<transform::PDLMatchHooks>(
+        [](transform::PDLMatchHooks &) {});
+
+    // Declare PDL as dependent so we can attach an interface to its type in the
+    // later step.
+    declareDependentDialect<pdl::PDLDialect>();
+
+    // PDLInterp is only relevant if we actually apply the transform IR so
+    // declare it as generated.
+    declareGeneratedDialect<pdl_interp::PDLInterpDialect>();
+
+    // Make PDL OperationType usable as a transform dialect type.
+    addCustomInitializationStep([](MLIRContext *context) {
+      pdl::OperationType::attachInterface<
+          PDLOperationTypeTransformHandleTypeInterfaceImpl>(*context);
+    });
+  }
+};
+} // namespace
+
+/// Adds the PDL extension of the Transform dialect to `dialectRegistry`.
+void mlir::transform::registerPDLExtension(DialectRegistry &dialectRegistry) {
+  dialectRegistry.addExtensions<PDLExtension>();
+}

diff  --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
new file mode 100644
index 0000000000000..5126d79adb189
--- /dev/null
+++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
@@ -0,0 +1,234 @@
+//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
+
+using namespace mlir;
+
+MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// PatternApplicatorExtension
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A TransformState extension that keeps track of compiled PDL pattern sets.
+/// This is intended to be used along with the WithPDLPatterns op. The
+/// extension can be constructed given an operation that has a SymbolTable
+/// trait and contains pdl::PatternOp instances. The patterns are compiled
+/// lazily and one by one when requested; this behavior is subject to change.
+class PatternApplicatorExtension : public transform::TransformState::Extension {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
+
+  /// Creates the extension for patterns contained in `patternContainer`.
+  explicit PatternApplicatorExtension(transform::TransformState &state,
+                                      Operation *patternContainer)
+      : Extension(state), patterns(patternContainer) {}
+
+  /// Appends to `results` the operations nested in `root` (including `root`
+  /// itself) that matched the PDL pattern with the given name. Note that
+  /// `root` may or may not be the operation that contains PDL patterns.
+  /// Returns failure, without emitting a diagnostic, if no pattern with this
+  /// name exists in the container; reporting the error is left to the caller.
+  /// When no operations are matched, this still succeeds as long as the
+  /// pattern exists.
+  LogicalResult findAllMatches(StringRef patternName, Operation *root,
+                               SmallVectorImpl<Operation *> &results);
+
+private:
+  /// Map from the pattern name to a singleton set of rewrite patterns that only
+  /// contains the pattern with this name. Populated when the pattern is first
+  /// requested.
+  // TODO: reconsider the efficiency of this storage when more usage data is
+  // available. Storing individual patterns in a set and triggering compilation
+  // for each of them has overhead. So does compiling a large set of patterns
+  // only to apply a handful of them.
+  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
+
+  /// A symbol table operation containing the relevant PDL patterns.
+  SymbolTable patterns;
+};
+
+LogicalResult PatternApplicatorExtension::findAllMatches(
+    StringRef patternName, Operation *root,
+    SmallVectorImpl<Operation *> &results) {
+  // Compile the requested pattern on first use and cache it by name.
+  auto it = compiledPatterns.find(patternName);
+  if (it == compiledPatterns.end()) {
+    auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
+    if (!patternOp)
+      return failure();
+
+    // Copy the pattern operation into a new module that is compiled and
+    // consumed by the PDL interpreter.
+    OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
+    auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
+    builder.clone(*patternOp);
+    PDLPatternModule patternModule(std::move(pdlModuleOp));
+
+    // Merge in the constraint hooks stored as PDLMatchHooks data on the
+    // Transform dialect. Copy them rather than moving, as patterns compiled
+    // later may need them too.
+    auto *dialect =
+        root->getContext()->getLoadedDialect<transform::TransformDialect>();
+    for (const auto &[name, constraintFn] :
+         dialect->getExtraData<transform::PDLMatchHooks>()
+             .getPDLConstraintHooks()) {
+      patternModule.registerConstraintFunction(name, constraintFn);
+    }
+
+    // Register a noop rewriter because PDL requires patterns to end with some
+    // rewrite call.
+    patternModule.registerRewriteFunction(
+        "transform.dialect", [](PatternRewriter &, Operation *) {});
+
+    it = compiledPatterns
+             .try_emplace(patternOp.getName(), std::move(patternModule))
+             .first;
+  }
+
+  PatternApplicator applicator(it->second);
+  // We want to discourage direct use of PatternRewriter in APIs but in this
+  // very specific case, an IRRewriter is not enough.
+  struct TrivialPatternRewriter : public PatternRewriter {
+  public:
+    explicit TrivialPatternRewriter(MLIRContext *context)
+        : PatternRewriter(context) {}
+  };
+  TrivialPatternRewriter rewriter(root->getContext());
+  applicator.applyDefaultCostModel();
+  // Record every visited op on which the pattern succeeds. The
+  // "transform.dialect" rewrite registered above does nothing.
+  root->walk([&](Operation *op) {
+    if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+      results.push_back(op);
+  });
+
+  return success();
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// PDLMatchHooks
+//===----------------------------------------------------------------------===//
+
+/// Takes the constraint functions out of `constraintFns`, moving each of them
+/// into the hook registry under its map key.
+void transform::PDLMatchHooks::mergeInPDLMatchHooks(
+    llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
+  for (auto &entry : constraintFns) {
+    auto &fn = entry.getValue();
+    pdlMatchHooks.registerConstraintFunction(entry.getKey(), std::move(fn));
+  }
+}
+
+/// Returns the constraint functions registered so far, keyed by name.
+const llvm::StringMap<PDLConstraintFunction> &
+transform::PDLMatchHooks::getPDLConstraintHooks() const {
+  return pdlMatchHooks.getConstraintFunctions();
+}
+
+//===----------------------------------------------------------------------===//
+// PDLMatchOp
+//===----------------------------------------------------------------------===//
+
+/// Collects, for every payload op associated with the `root` handle, the ops
+/// nested in it that match the named PDL pattern, and associates them with the
+/// result handle. Relies on the PatternApplicatorExtension attached to the
+/// transform state by the parent `with_pdl_patterns` op. If the pattern does
+/// not exist, this is a definite failure.
+DiagnosedSilenceableFailure
+transform::PDLMatchOp::apply(transform::TransformResults &results,
+                             transform::TransformState &state) {
+  auto *extension = state.getExtension<PatternApplicatorExtension>();
+  assert(extension &&
+         "expected PatternApplicatorExtension to be attached by the parent op");
+  SmallVector<Operation *> targets;
+  for (Operation *root : state.getPayloadOps(getRoot())) {
+    if (failed(extension->findAllMatches(
+            getPatternName().getLeafReference().getValue(), root, targets))) {
+      // Propagate the failure. Emitting the error and then returning success
+      // would let the interpreter carry on as if nothing had gone wrong.
+      return emitDefiniteFailure()
+             << "could not find pattern '" << getPatternName() << "'";
+    }
+  }
+  results.set(llvm::cast<OpResult>(getResult()), targets);
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// The op only reads its `root` handle and the Payload IR, and produces the
+/// `matched` handle.
+void transform::PDLMatchOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getRoot(), effects);
+  producesHandle(getMatched(), effects);
+  onlyReadsPayload(effects);
+}
+
+//===----------------------------------------------------------------------===//
+// WithPDLPatternsOp
+//===----------------------------------------------------------------------===//
+
+/// Applies the first non-pattern op of the body. While it runs, a
+/// PatternApplicatorExtension for this op's patterns is attached to the
+/// transform state, and it is removed again on every exit path.
+DiagnosedSilenceableFailure
+transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  // Skip over PDL patterns; the first other op is the transform to execute.
+  TransformOpInterface transformOp = nullptr;
+  for (Operation &nested : getBody().front()) {
+    if (isa<pdl::PatternOp>(nested))
+      continue;
+    transformOp = cast<TransformOpInterface>(nested);
+    break;
+  }
+
+  // Make the patterns available to nested ops for the duration of this call.
+  state.addExtension<PatternApplicatorExtension>(getOperation());
+  auto removeExtension = llvm::make_scope_exit(
+      [&]() { state.removeExtension<PatternApplicatorExtension>(); });
+
+  auto regionScope = state.make_region_scope(getBody());
+  if (succeeded(mapBlockArguments(state)))
+    return state.applyTransform(transformOp);
+  return DiagnosedSilenceableFailure::definiteFailure();
+}
+
+/// Delegates to `getPotentialTopLevelEffects`.
+void transform::WithPDLPatternsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  getPotentialTopLevelEffects(effects);
+}
+
+/// Checks four conditions, in this order:
+///   - the body holds only `pdl.pattern` ops and possible top-level transform
+///     ops;
+///   - there is at most one op of the latter kind;
+///   - this op is not nested in another WithPDLPatternsOp;
+///   - there is at least one non-pattern op.
+LogicalResult transform::WithPDLPatternsOp::verify() {
+  Operation *soleTopLevelOp = nullptr;
+  for (Operation &op : getBodyBlock()->getOperations()) {
+    // Any number of PDL patterns is allowed.
+    if (isa<pdl::PatternOp>(op))
+      continue;
+
+    // Every other op must be able to act as a top-level transform...
+    bool isTopLevelCandidate =
+        op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>();
+    if (!isTopLevelCandidate) {
+      InFlightDiagnostic diag =
+          emitOpError()
+          << "expects only pattern and top-level transform ops in its body";
+      diag.attachNote(op.getLoc()) << "offending op";
+      return diag;
+    }
+
+    // ...and there may be only one such op.
+    if (soleTopLevelOp) {
+      InFlightDiagnostic diag =
+          emitOpError() << "expects only one non-pattern op in its body";
+      diag.attachNote(soleTopLevelOp->getLoc()) << "first non-pattern op";
+      diag.attachNote(op.getLoc()) << "second non-pattern op";
+      return diag;
+    }
+    soleTopLevelOp = &op;
+  }
+
+  if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
+    InFlightDiagnostic diag = emitOpError() << "cannot be nested";
+    diag.attachNote(parent.getLoc()) << "parent operation";
+    return diag;
+  }
+
+  if (!soleTopLevelOp)
+    return emitOpError() << "expects at least one non-pattern op";
+
+  return success();
+}

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b0b4ed94fc759..39dd7b006c89a 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -114,6 +114,16 @@ declare_mlir_dialect_python_bindings(
   DIALECT_NAME linalg
   DEPENDS LinalgOdsGen)
 
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/TransformPDLExtensionOps.td
+  SOURCES
+    dialects/_transform_pdl_extension_ops_ext.py
+    dialects/transform/pdl.py
+  DIALECT_NAME transform
+  EXTENSION_NAME transform_pdl_extension)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

diff  --git a/mlir/python/mlir/dialects/TransformPDLExtensionOps.td b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td
new file mode 100644
index 0000000000000..e3e5daf18d738
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td
@@ -0,0 +1,20 @@
+//===-- TransformPDLExtensionOps.td - Binding entry point --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the generated Python bindings for the PDL extension of the
+// Transform dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS
+#define PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS

diff  --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index 8651c76ea7dfc..cc4428ea5b115 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -60,26 +60,6 @@ def __init__(
     )
 
 
-class PDLMatchOp:
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      pattern_name: Union[Attribute, str],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        pattern_name,
-        loc=loc,
-        ip=ip,
-    )
-
-
 class ReplicateOp:
 
   def __init__(
@@ -152,28 +132,6 @@ def bodyExtraArgs(self) -> BlockArgumentList:
     return self.body.arguments[1:]
 
 
-class WithPDLPatternsOp:
-
-  def __init__(self,
-               target: Union[Operation, Value, Type],
-               *,
-               loc=None,
-               ip=None):
-    root = _get_op_result_or_value(target) if not isinstance(target,
-                                                             Type) else None
-    root_type = target if isinstance(target, Type) else root.type
-    super().__init__(root=root, loc=loc, ip=ip)
-    self.regions[0].blocks.append(root_type)
-
-  @property
-  def body(self) -> Block:
-    return self.regions[0].blocks[0]
-
-  @property
-  def bodyTarget(self) -> Value:
-    return self.body.arguments[0]
-
-
 class YieldOp:
 
   def __init__(

diff  --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
new file mode 100644
index 0000000000000..c4e4b4b4254b0
--- /dev/null
+++ b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py
@@ -0,0 +1,55 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+try:
+  from ..ir import *
+  from ._ods_common import (
+      get_op_result_or_value as _get_op_result_or_value,
+      get_op_results_or_values as _get_op_results_or_values,
+  )
+except ImportError as e:
+  raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+class PDLMatchOp:
+
+  def __init__(
+      self,
+      result_type: Type,
+      target: Union[Operation, Value],
+      pattern_name: Union[Attribute, str],
+      *,
+      loc=None,
+      ip=None,
+  ):
+    super().__init__(
+        result_type,
+        _get_op_result_or_value(target),
+        pattern_name,
+        loc=loc,
+        ip=ip,
+    )
+
+
+class WithPDLPatternsOp:
+
+  def __init__(self,
+               target: Union[Operation, Value, Type],
+               *,
+               loc=None,
+               ip=None):
+    root = _get_op_result_or_value(target) if not isinstance(target,
+                                                             Type) else None
+    root_type = target if isinstance(target, Type) else root.type
+    super().__init__(root=root, loc=loc, ip=ip)
+    self.regions[0].blocks.append(root_type)
+
+  @property
+  def body(self) -> Block:
+    return self.regions[0].blocks[0]
+
+  @property
+  def bodyTarget(self) -> Value:
+    return self.body.arguments[0]

diff  --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py
new file mode 100644
index 0000000000000..b1515287a3f1f
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/pdl.py
@@ -0,0 +1,5 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._transform_pdl_extension_ops_gen import *

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 11a556877f5f6..a885c89af0317 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -83,33 +83,6 @@ transform.sequence failures(propagate) {
 
 // -----
 
-transform.with_pdl_patterns {
-^bb0(%arg0: !transform.any_op):
-  sequence %arg0 : !transform.any_op failures(propagate) {
-  ^bb0(%arg1: !transform.any_op):
-    %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
-    test_print_remark_at_operand %0, "matched" : !transform.any_op
-  }
-
-  pdl.pattern @some : benefit(1) {
-    %0 = pdl.operation "test.some_op"
-    pdl.rewrite %0 with "transform.dialect"
-  }
-
-  pdl.pattern @other : benefit(1) {
-    %0 = pdl.operation "test.other_op"
-    pdl.rewrite %0 with "transform.dialect"
-  }
-}
-
-// expected-remark @below {{matched}}
-"test.some_op"() : () -> ()
-"test.other_op"() : () -> ()
-// expected-remark @below {{matched}}
-"test.some_op"() : () -> ()
-
-// -----
-
 // expected-remark @below {{parent function}}
 func.func @foo() {
   %0 = arith.constant 0 : i32

diff  --git a/mlir/test/Dialect/Transform/test-pdl-extension.mlir b/mlir/test/Dialect/Transform/test-pdl-extension.mlir
new file mode 100644
index 0000000000000..b5f9fbf451291
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-pdl-extension.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !transform.any_op):
+  sequence %arg0 : !transform.any_op failures(propagate) {
+  ^bb0(%arg1: !transform.any_op):
+    %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
+    test_print_remark_at_operand %0, "matched" : !transform.any_op
+  }
+
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operation "test.some_op"
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  pdl.pattern @other : benefit(1) {
+    %0 = pdl.operation "test.other_op"
+    pdl.rewrite %0 with "transform.dialect"
+  }
+}
+
+// expected-remark @below {{matched}}
+"test.some_op"() : () -> ()
+"test.other_op"() : () -> ()
+// expected-remark @below {{matched}}
+"test.some_op"() : () -> ()
+
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !transform.any_op):
+  sequence %arg0 : !transform.any_op failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
+  }
+
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operation "test.some_op"
+    pdl.apply_native_constraint "verbose_constraint"(%0 : !pdl.operation)
+    pdl.rewrite %0 with "transform.dialect"
+  }
+}
+
+// expected-warning @below {{from PDL constraint}}
+"test.some_op"() : () -> ()
+"test.other_op"() : () -> ()

diff  --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index b86b8f56ba6c4..c7e83d3a7128b 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -21,4 +21,5 @@ add_mlir_library(MLIRTestTransformDialect
   MLIRPDLDialect
   MLIRTransformDialect
   MLIRTransformDialectTransforms
+  MLIRTransformPDLExtension
 )

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 50a4c92da9aa4..2b23b88f40bef 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -17,7 +17,9 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Compiler.h"
@@ -754,6 +756,23 @@ class TestTransformDialectExtension
 #define GET_TYPEDEF_LIST
 #include "TestTransformDialectExtensionTypes.cpp.inc"
         >();
+
+    auto verboseConstraint = [](PatternRewriter &rewriter,
+                                ArrayRef<PDLValue> pdlValues) {
+      for (const PDLValue &pdlValue : pdlValues) {
+        if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
+          op->emitWarning() << "from PDL constraint";
+        }
+      }
+      return success();
+    };
+
+    addDialectDataInitializer<transform::PDLMatchHooks>(
+        [&](transform::PDLMatchHooks &hooks) {
+          llvm::StringMap<PDLConstraintFunction> constraints;
+          constraints.try_emplace("verbose_constraint", verboseConstraint);
+          hooks.mergeInPDLMatchHooks(std::move(constraints));
+        });
   }
 };
 } // namespace

diff  --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 5b64582dcd6de..6b36c025eafa3 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -2,7 +2,7 @@
 
 from mlir.ir import *
 from mlir.dialects import transform
-from mlir.dialects import pdl
+from mlir.dialects.transform import pdl as transform_pdl
 
 
 def run(f):
@@ -103,13 +103,13 @@ def testNestedSequenceOpWithExtras():
 
 @run
 def testTransformPDLOps():
-  withPdl = transform.WithPDLPatternsOp(transform.AnyOpType.get())
+  withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
   with InsertionPoint(withPdl.body):
     sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
                                     [transform.AnyOpType.get()],
                                     withPdl.bodyTarget)
     with InsertionPoint(sequence.body):
-      match = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
+      match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
       transform.YieldOp(match)
   # CHECK-LABEL: TEST: testTransformPDLOps
   # CHECK: transform.with_pdl_patterns {
@@ -148,13 +148,13 @@ def testMergeHandlesOp():
 
 @run
 def testReplicateOp():
-  with_pdl = transform.WithPDLPatternsOp(transform.AnyOpType.get())
+  with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
   with InsertionPoint(with_pdl.body):
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget)
     with InsertionPoint(sequence.body):
-      m1 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
-      m2 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
+      m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
+      m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
       transform.ReplicateOp(m1, [m2])
       transform.YieldOp()
   # CHECK-LABEL: TEST: testReplicateOp

diff  --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 9684bfb47f1b0..d2a82b8218f25 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -4,6 +4,7 @@
 from mlir.dialects import transform
 from mlir.dialects import pdl
 from mlir.dialects.transform import structured
+from mlir.dialects.transform import pdl as transform_pdl
 
 
 def run(f):
@@ -151,13 +152,13 @@ def testTileZero():
 
 @run
 def testTileDynamic():
-  with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())
+  with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
   with InsertionPoint(with_pdl.body):
     sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [],
                                     with_pdl.bodyTarget)
     with InsertionPoint(sequence.body):
-      m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
-      m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
+      m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
+      m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
       structured.TileOp(sequence.bodyTarget,
                         sizes=[m1, 3, m2, 0])
       transform.YieldOp()

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c9ba6519b6504..b36bdf931a282 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7495,6 +7495,7 @@ cc_library(
         ":TosaToLinalg",
         ":TransformDialect",
         ":TransformDialectTransforms",
+        ":TransformPDLExtension",
         ":Transforms",
         ":TransformsPassIncGen",
         ":VectorDialect",
@@ -9732,7 +9733,6 @@ td_library(
         ":ControlFlowInterfacesTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
-        ":PDLDialectTdFiles",
         ":SideEffectInterfacesTdFiles",
     ],
 )
@@ -9889,8 +9889,6 @@ cc_library(
         ":CallOpInterfaces",
         ":ControlFlowInterfaces",
         ":IR",
-        ":PDLDialect",
-        ":PDLInterpDialect",
         ":Rewrite",
         ":SideEffectInterfaces",
         ":Support",
@@ -9906,6 +9904,54 @@ cc_library(
     ],
 )
 
+# TableGen sources of the PDL extension of the Transform dialect; they depend
+# on the PDL and Transform dialect TableGen files.
+td_library(
+    name = "TransformPDLExtensionTdFiles",
+    srcs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.td"]),
+    deps = [
+        ":PDLDialectTdFiles",
+        ":TransformDialectTdFiles",
+    ],
+)
+
+# Generates the op declarations (.h.inc) and definitions (.cpp.inc) from
+# PDLExtensionOps.td.
+gentbl_cc_library(
+    name = "TransformPDLExtensionOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td",
+    deps = [":TransformPDLExtensionTdFiles"],
+)
+
+cc_library(
+    name = "TransformPDLExtension",
+    srcs = glob(["lib/Dialect/Transform/PDLExtension/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.h"]),
+    deps = [
+        ":IR",
+        ":PDLDialect",
+        ":PDLInterpDialect",
+        ":Rewrite",
+        ":SideEffectInterfaces",
+        ":Support",
+        ":TransformDialect",
+        ":TransformPDLExtensionOpsIncGen",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "TransformDialectTransformsTdFiles",
     srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]),

diff  --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index 06a97f4c921ee..f6c87ea23291e 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -927,6 +927,26 @@ gentbl_filegroup(
     ],
 )
 
+gentbl_filegroup(
+    name = "PDLTransformOpsPyGen",
+    tbl_outs = [
+        (
+            [
+                "-gen-python-op-bindings",
+                "-bind-dialect=transform",
+                "-dialect-extension=transform_pdl_extension",
+            ],
+            "mlir/dialects/_transform_pdl_extension_ops_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/TransformPDLExtensionOps.td",
+    deps = [
+        ":TransformOpsPyTdFiles",
+        "//mlir:TransformPDLExtensionTdFiles",
+    ],
+)
+
 filegroup(
     name = "TransformOpsPyFiles",
     srcs = [
@@ -934,6 +954,7 @@ filegroup(
         "mlir/dialects/_structured_transform_ops_ext.py",
         "mlir/dialects/_transform_ops_ext.py",
         ":LoopTransformOpsPyGen",
+        ":PDLTransformOpsPyGen",
         ":StructuredTransformOpsPyGen",
         ":TransformOpsPyGen",
     ],

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index b4f8ca79f22c2..c95aea56bb4ce 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -317,6 +317,7 @@ gentbl_cc_library(
         ":TransformDialectTdFiles",
         "//mlir:PDLDialectTdFiles",
         "//mlir:TransformDialectTdFiles",
+        "//mlir:TransformPDLExtension",
     ],
 )
 
@@ -333,6 +334,7 @@ cc_library(
         "//mlir:Pass",
         "//mlir:TransformDialect",
         "//mlir:TransformDialectTransforms",
+        "//mlir:TransformPDLExtension",
     ],
 )
 


        


More information about the Mlir-commits mailing list