[Mlir-commits] [mlir] 30f2242 - [mlir] Connect Transform dialect to PDL

Alex Zinenko llvmlistbot at llvm.org
Thu Apr 21 07:23:18 PDT 2022


Author: Alex Zinenko
Date: 2022-04-21T16:23:10+02:00
New Revision: 30f22429d38944e126db75296a1ffc6c12c7b87a

URL: https://github.com/llvm/llvm-project/commit/30f22429d38944e126db75296a1ffc6c12c7b87a
DIFF: https://github.com/llvm/llvm-project/commit/30f22429d38944e126db75296a1ffc6c12c7b87a.diff

LOG: [mlir] Connect Transform dialect to PDL

This introduces a pair of ops to the Transform dialect that connect it to PDL
patterns. The Transform dialect relies on PDL for matching the Payload IR ops
that are about to be transformed. For this purpose, it provides a container op
for patterns, a "pdl_match" op, and transform interface implementations that
call into the pattern matching infrastructure.

To enable the caching of compiled patterns, this also provides the extension
mechanism for TransformState. Extensions allow one to store additional
information in the TransformState and thus communicate it between different
Transform dialect operations when they are applied. They can be added and
removed when applying transform ops. An extension containing a symbol table in
which the pattern names are resolved and a pattern compilation cache is
introduced as the first client.

Depends On D123664

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D124007

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
    mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/CMakeLists.txt
    mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir
    mlir/test/Dialect/Transform/ops.mlir
    mlir/test/Dialect/Transform/test-interpreter.mlir
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
index 628a46535f338..d1607b57622f9 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h
@@ -9,9 +9,13 @@
 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
 
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringMap.h"
 
 #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
 
@@ -57,6 +61,7 @@ class TransformDialectExtension
       loader(context);
     for (const Initializer &init : opInitializers)
       init(transformDialect);
+    transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns));
   }
 
 protected:
@@ -88,9 +93,30 @@ class TransformDialectExtension
         [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
   }
 
+  /// Injects the named constraint to make it available for use with the
+  /// PDLMatchOp in the transform dialect.
+  void registerPDLMatchConstraintFn(StringRef name,
+                                    PDLConstraintFunction &&fn) {
+    pdlMatchConstraintFns.try_emplace(name,
+                                      std::forward<PDLConstraintFunction>(fn));
+  }
+  template <typename ConstraintFnTy>
+  void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) {
+    pdlMatchConstraintFns.try_emplace(
+        name, ::mlir::detail::pdl_function_builder::buildConstraintFn(
+                  std::forward<ConstraintFnTy>(fn)));
+  }
+
 private:
   SmallVector<Initializer> opInitializers;
   SmallVector<DialectLoader> dialectLoaders;
+
+  /// A map of named constraints that should be made available to PDL patterns
+  /// processed by PDLMatchOp in the Transform dialect.
+  ///
+  /// Declared as mutable so its contents can be moved in the `apply` const
+  /// method, which is expected to be called only once.
+  mutable llvm::StringMap<PDLConstraintFunction> pdlMatchConstraintFns;
 };
 
 } // namespace transform

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
index aca6497bcb9c1..d695b850474a4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
@@ -84,6 +84,13 @@ def Transform_Dialect : Dialect {
     `LoopTransformDialectExtension` in the cases above. Unprefixed operation
     names are reserved for ops defined directly in the Transform dialect.
 
+    Overall, Transform IR ops are expected to be contained in a single top-level
+    op. Such top-level ops specify how to apply the transformations described
+    by operations they contain, e.g., `transform.sequence` executes
+    transformations one by one and fails if any of them fails. Such ops are
+    expected to have the `PossibleTopLevelTransformOpTrait` and may be used
+    without operands.
+
     ## Intended Use and Integrations
 
     The transformation control infrastructure provided by this dialect is
@@ -163,13 +170,32 @@ def Transform_Dialect : Dialect {
   let cppNamespace = "::mlir::transform";
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
 
+  let dependentDialects = [
+    "::mlir::pdl::PDLDialect",
+    "::mlir::pdl_interp::PDLInterpDialect",
+  ];
+
   let extraClassDeclaration = [{
-    // Make addOperations available to the TransformDialectExtension class.
+      /// Returns the named PDL constraint functions available in the dialect
+      /// as a map from their name to the function.
+      const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
+      getPDLConstraintHooks() const;
+
     private:
+      // Make addOperations available to the TransformDialectExtension class.
       using ::mlir::Dialect::addOperations;
 
       template <typename, typename...>
       friend class TransformDialectExtension;
+
+      /// Takes ownership of the named PDL constraint functions from the given
+      /// map and makes them available for use by the operations in the dialect.
+      void mergeInPDLMatchHooks(
+          ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns);
+
+      /// A container for PDL constraint functions that can be used by
+      /// operations in this dialect.
+      PDLPatternModule pdlMatchHooks;
   }];
 }
 
@@ -178,4 +204,12 @@ def Transform_Dialect : Dialect {
 class TransformDialectOp<string mnemonic, list<Trait> traits = []>
     : Op<Transform_Dialect, mnemonic, traits>;
 
+// Trait for operations that may be top-level operations in Transform IR.
+// Operations must have one single-block region and must be usable without
+// operands. See the C++ definition of the trait for more information.
+def PossibleTopLevelTransformOpTrait
+    : NativeOpTrait<"PossibleTopLevelTransformOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index f109ad599b841..49d3bcd8be454 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -140,6 +140,89 @@ class TransformState {
   };
   friend class RegionScope;
 
+  /// Base class for TransformState extensions that allow TransformState to
+  /// contain user-specified information in the state object. Clients are
+  /// expected to derive from this class, add the desired fields, and make the
+  /// derived class compatible with the MLIR TypeID mechanism:
+  ///
+  /// ```c++
+  /// class MyExtension final : public TransformState::Extension {
+  /// public:
+  ///   MyExtension(TransformState &state, int myData)
+  ///     : Extension(state) {...}
+  /// private:
+  ///   int mySupplementaryData;
+  /// };
+  /// ```
+  ///
+  /// Instances of this and derived classes are not expected to be created by
+  /// the user; instead, they are directly constructed within a TransformState.
+  /// A TransformState can only contain one extension with the given TypeID.
+  /// Extensions can be obtained from a TransformState instance, and can be
+  /// removed when they are no longer required.
+  ///
+  /// ```c++
+  /// transformState.addExtension<MyExtension>(/*myData=*/42);
+  /// MyExtension *ext = transformState.getExtension<MyExtension>();
+  /// ext->doSomething();
+  /// ```
+  class Extension {
+    // Allow TransformState to allocate Extensions.
+    friend class TransformState;
+
+  public:
+    /// Base virtual destructor.
+    // Out-of-line definition ensures symbols are emitted in a single object
+    // file.
+    virtual ~Extension();
+
+  protected:
+    /// Constructs an extension of the given TransformState object.
+    Extension(TransformState &state) : state(state) {}
+
+  private:
+    /// Back-reference to the state that is being extended.
+    TransformState &state;
+  };
+
+  /// Adds a new Extension of the type specified as template parameter,
+  /// constructing it with the arguments provided, and returns a reference to
+  /// it. The extension is owned by the TransformState. It is expected that the
+  /// state does not already have an extension of the same type. Extension
+  /// constructors are expected to take a reference to TransformState as first
+  /// argument, automatically supplied by this call.
+  template <typename Ty, typename... Args>
+  Ty &addExtension(Args &&...args) {
+    static_assert(
+        std::is_base_of<Extension, Ty>::value,
+        "only a class derived from TransformState::Extension is allowed here");
+    auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
+    auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
+    assert(result.second && "extension already added");
+    return *static_cast<Ty *>(result.first->second.get());
+  }
+
+  /// Returns the extension of the specified type, or nullptr if there is none.
+  template <typename Ty>
+  Ty *getExtension() {
+    static_assert(
+        std::is_base_of<Extension, Ty>::value,
+        "only a class derived from TransformState::Extension is allowed here");
+    auto iter = extensions.find(TypeID::get<Ty>());
+    if (iter == extensions.end())
+      return nullptr;
+    return static_cast<Ty *>(iter->second.get());
+  }
+
+  /// Removes the extension of the specified type; does nothing if absent.
+  template <typename Ty>
+  void removeExtension() {
+    static_assert(
+        std::is_base_of<Extension, Ty>::value,
+        "only a class derived from TransformState::Extension is allowed here");
+    extensions.erase(TypeID::get<Ty>());
+  }
+
 private:
   /// Identifier for storing top-level value in the `operations` mapping.
   static constexpr Value kTopLevelValue = Value();
@@ -196,6 +279,10 @@ class TransformState {
   /// the region in which the transform IR values are defined.
   llvm::SmallDenseMap<Region *, Mappings> mappings;
 
+  /// Extensions attached to the TransformState, identified by the TypeID of
+  /// their type. Only one extension of any given type is allowed.
+  DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
+
   /// The top-level operation that contains all payload IR, typically a module.
   Operation *topLevel;
 
@@ -241,6 +328,54 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
   return RegionScope(*this, region);
 }
 
+namespace detail {
+/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
+/// to either the list of operations associated with its first operand or the
+/// root of the payload IR, depending on whether the op has operands.
+LogicalResult
+mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
+                                             Operation *op);
+
+/// Verification hook for PossibleTopLevelTransformOpTrait.
+LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
+} // namespace detail
+
+/// This trait is supposed to be attached to Transform dialect operations that
+/// can be standalone top-level transforms. Such operations typically contain
+/// other Transform dialect operations that can be executed following some
+/// control flow logic specific to the current operation. The operations with
+/// this trait are expected to have exactly one single-block region with one
+/// argument of PDL Operation type. The operations are also expected to be valid
+/// without operands, in which case they are considered top-level, and with one
+/// or more operands, in which case they are considered nested. Top-level
+/// operations have the block argument of the entry block in the Transform IR
+/// correspond to the root operation of Payload IR. Nested operations have the
+/// block argument of the entry block in the Transform IR correspond to a list
+/// of Payload IR operations mapped to the first operand of the Transform IR
+/// operation. The operation must implement TransformOpInterface.
+template <typename OpTy>
+class PossibleTopLevelTransformOpTrait
+    : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
+public:
+  /// Verifies that `op` satisfies the invariants of this trait. Not expected to
+  /// be called directly.
+  static LogicalResult verifyTrait(Operation *op) {
+    return detail::verifyPossibleTopLevelTransformOpTrait(op);
+  }
+
+  /// Returns the single block of the op's only region.
+  Block *getBodyBlock() { return &this->getOperation()->getRegion(0).front(); }
+
+  /// Sets up the mapping between the entry block of the only region of this op
+  /// and the relevant list of Payload IR operations in the given state. The
+  /// state is expected to be already scoped at the region of this operation.
+  /// Returns failure if the mapping failed, e.g., the value is already mapped.
+  LogicalResult mapBlockArguments(TransformState &state) {
+    return detail::mapPossibleTopLevelTransformOpBlockArguments(
+        state, this->getOperation());
+  }
+};
+
 } // namespace transform
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index a12b5abd8ffc8..9714b77ad9683 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.h.inc"

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 246de281568b1..489197fe46f60 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -10,12 +10,40 @@
 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
 
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 
+def PDLMatchOp : TransformDialectOp<"pdl_match",
+    [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Finds ops that match the named PDL pattern";
+  let description = [{
+    Finds ops matching the PDL pattern identified by its name among the Payload
+    IR op associated with the operand and the ops nested in it. The pattern is
+    expected to be defined in the closest surrounding `WithPDLPatternsOp`.
+
+    Produces a Transform IR value associated with the list of Payload IR ops
+    that matched the pattern. The order of results in the list is that of
+    Operation::walk; clients are advised not to rely on a specific order though.
+    If the operand is associated with multiple Payload IR ops, finds matching
+    ops nested within each of those and produces a single list containing all
+    of the matched ops.
+
+    The transformation is considered successful regardless of whether some
+    Payload IR ops actually matched the pattern and only fails if the pattern
+    could not be looked up or compiled.
+  }];
+
+  let arguments = (ins PDL_Operation:$root, SymbolRefAttr:$pattern_name);
+  let results = (outs PDL_Operation:$matched);
+
+  let assemblyFormat = "$pattern_name `in` $root attr-dict";
+}
+
 def SequenceOp : TransformDialectOp<"sequence",
     [DeclareOpInterfaceMethods<TransformOpInterface>, OpAsmOpInterface,
+     PossibleTopLevelTransformOpTrait,
      SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
   let summary = "Contains a sequence of other transform ops to apply";
   let description = [{
@@ -48,13 +76,60 @@ def SequenceOp : TransformDialectOp<"sequence",
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
     static StringRef getDefaultDialect() { return "transform"; }
+  }];
+
+  let hasVerifier = 1;
+}
 
-    Block *getBodyBlock() {
-      return &getBody().front();
+def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
+    [DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
+     OpAsmOpInterface, PossibleTopLevelTransformOpTrait, SymbolTable]> {
+  let summary = "Contains PDL patterns available for use in transforms";
+  let description = [{
+    This op contains a set of named PDL patterns that are available for the
+    Transform dialect operations to be used for pattern matching. For example,
+    PDLMatchOp can be used to produce a Transform IR value associated with all
+    Payload IR operations that match the pattern as follows:
+
+    ```mlir
+    transform.with_pdl_patterns {
+    ^bb0(%arg0: !pdl.operation):
+      pdl.pattern @my_pattern : benefit(1) {
+        %0 = pdl.operation //...
+        // Regular PDL goes here.
+        pdl.rewrite %0 with "transform.dialect"
+      }
+
+      sequence %arg0 {
+      ^bb0(%arg1: !pdl.operation):
+        %1 = pdl_match @my_pattern in %arg1
+        // Use %1 as handle
+      }
     }
+    ```
+
+    Note that the pattern is expected to finish with a `pdl.rewrite` terminator
+    that points to the custom rewriter named "transform.dialect". The rewriter
+    actually does nothing, but the transform application will keep track of the
+    operations that matched the pattern.
+
+    This op is expected to contain `pdl.pattern` operations and exactly one
+    other Transform dialect operation that gets executed with all patterns
+    available. This op is a possible top-level Transform IR op: the argument of
+    its entry block corresponds to either the root op of the payload IR or the
+    ops associated with its operand when provided.
   }];
 
+  let arguments = (ins Optional<PDL_Operation>:$root);
+  let regions = (region SizedRegion<1>:$body);
+  let assemblyFormat = "($root^)? attr-dict-with-keyword regions";
+
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Allow the dialect prefix to be omitted.
+    static StringRef getDefaultDialect() { return "transform"; }
+  }];
 }
 
 def YieldOp : TransformDialectOp<"yield", [Terminator]> {

diff  --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
index 760ce9364b0aa..a5ac053c91195 100644
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -11,4 +11,5 @@ add_mlir_dialect_library(MLIRTransformDialect
   MLIRIR
   MLIRPDL
   MLIRPDLInterp
+  MLIRRewrite
   )

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
index a566cb91ee750..513f8736237a4 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -19,3 +19,15 @@ void transform::TransformDialect::initialize() {
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();
 }
+
+void transform::TransformDialect::mergeInPDLMatchHooks(
+    llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
+  // Steal the constraint functions from the given map (values are moved).
+  for (auto &it : constraintFns)
+    pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
+}
+
+const llvm::StringMap<PDLConstraintFunction> &
+transform::TransformDialect::getPDLConstraintHooks() const {
+  return pdlMatchHooks.getConstraintFunctions();
+}

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 7df299a94cfb6..2c9a2870dd616 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/ScopeExit.h"
@@ -117,6 +118,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   return success();
 }
 
+transform::TransformState::Extension::~Extension() = default;
+
 //===----------------------------------------------------------------------===//
 // TransformResults
 //===----------------------------------------------------------------------===//
@@ -145,6 +148,61 @@ transform::TransformResults::get(unsigned resultNumber) const {
   return segments[resultNumber];
 }
 
+//===----------------------------------------------------------------------===//
+// Utilities for PossibleTopLevelTransformOpTrait.
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
+    TransformState &state, Operation *op) {
+  SmallVector<Operation *> targets;
+  if (op->getNumOperands() != 0)
+    llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
+  else
+    targets.push_back(state.getTopLevel());
+
+  return state.mapBlockArguments(op->getRegion(0).front().getArgument(0),
+                                 targets);
+}
+
+LogicalResult
+transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
+  // Attaching this trait without the interface is a misuse of the API, but it
+  // cannot be caught via a static_assert because interface registration is
+  // dynamic.
+  assert(isa<TransformOpInterface>(op) &&
+         "should implement TransformOpInterface to have "
+         "PossibleTopLevelTransformOpTrait");
+
+  if (op->getNumRegions() != 1)
+    return op->emitOpError() << "expects one region";
+
+  Region *bodyRegion = &op->getRegion(0);
+  if (!llvm::hasNItems(*bodyRegion, 1))
+    return op->emitOpError() << "expects a single-block region";
+
+  Block *body = &bodyRegion->front();
+  if (body->getNumArguments() != 1 ||
+      !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
+    return op->emitOpError()
+           << "expects the entry block to have one argument of type "
+           << pdl::OperationType::get(op->getContext());
+  }
+
+  if (auto *parent =
+          op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
+    if (op->getNumOperands() == 0) {
+      InFlightDiagnostic diag =
+          op->emitOpError()
+          << "expects the root operation to be provided for a nested op";
+      diag.attachNote(parent->getLoc())
+          << "nested in another possible top-level op";
+      return diag;
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Generated interface implementation.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 3018e3b5b68bf..c68ba11e3a06f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -7,26 +7,143 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/Builders.h"
-
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
 
 using namespace mlir;
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
 
-LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
+//===----------------------------------------------------------------------===//
+// PatternApplicatorExtension
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A simple pattern rewriter that can be constructed from a context. This is
+/// necessary to apply patterns to a specific op locally.
+class TrivialPatternRewriter : public PatternRewriter {
+public:
+  explicit TrivialPatternRewriter(MLIRContext *context)
+      : PatternRewriter(context) {}
+};
+
+/// A TransformState extension that keeps track of compiled PDL pattern sets.
+/// This is intended to be used along with the WithPDLPatterns op. It can be
+/// constructed given an operation that has a SymbolTable trait and contains
+/// pdl::PatternOp instances, which are cloned (the container is left intact)
+/// and compiled lazily, one by one, when requested; this is subject to change.
+class PatternApplicatorExtension : public transform::TransformState::Extension {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
+
+  /// Creates the extension for patterns contained in `patternContainer`.
+  explicit PatternApplicatorExtension(transform::TransformState &state,
+                                      Operation *patternContainer)
+      : Extension(state), patterns(patternContainer) {}
+
+  /// Appends to `results` the operations nested in `root`, including `root`
+  /// itself, that matched the PDL pattern with the given name. `root` may or
+  /// may not be the operation that contains PDL patterns. Returns failure if
+  /// the pattern cannot be found; succeeds if the pattern exists, even when
+  /// no operations are matched.
+  LogicalResult findAllMatches(StringRef patternName, Operation *root,
+                               SmallVectorImpl<Operation *> &results);
+
+private:
+  /// Map from the pattern name to a singleton set of rewrite patterns that only
+  /// contains the pattern with this name. Populated when the pattern is first
+  /// requested.
+  // TODO: reconsider the efficiency of this storage when more usage data is
+  // available. Storing individual patterns in a set and triggering compilation
+  // for each of them has overhead. So does compiling a large set of patterns
+  // only to apply a handful of them.
+  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
+
+  /// A symbol table operation containing the relevant PDL patterns.
+  SymbolTable patterns;
+};
+
+LogicalResult PatternApplicatorExtension::findAllMatches(
+    StringRef patternName, Operation *root,
+    SmallVectorImpl<Operation *> &results) {
+  auto it = compiledPatterns.find(patternName);
+  if (it == compiledPatterns.end()) {
+    auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
+    if (!patternOp)
+      return failure();
+
+    OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
+    // Clone rather than move so that the Transform IR is left intact.
+    OpBuilder::atBlockEnd(pdlModuleOp->getBody()).clone(*patternOp);
+    PDLPatternModule patternModule(std::move(pdlModuleOp));
+
+    // Merge in the hooks owned by the dialect. Copy them, as they may also be
+    // used by subsequent operations.
+    auto *dialect =
+        root->getContext()->getLoadedDialect<transform::TransformDialect>();
+    for (const auto &pair : dialect->getPDLConstraintHooks())
+      patternModule.registerConstraintFunction(pair.first(), pair.second);
+
+    // Register a noop rewriter because PDL requires patterns to end with some
+    // rewrite call.
+    patternModule.registerRewriteFunction(
+        "transform.dialect", [](PatternRewriter &, Operation *) {});
+
+    it = compiledPatterns
+             .try_emplace(patternOp.getName(), std::move(patternModule))
+             .first;
+  }
+
+  PatternApplicator applicator(it->second);
+  TrivialPatternRewriter rewriter(root->getContext());
+  applicator.applyDefaultCostModel();
+  root->walk([&](Operation *op) {
+    if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+      results.push_back(op);
+  });
+
+  return success();
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// PDLMatchOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
                                            transform::TransformState &state) {
+  auto *extension = state.getExtension<PatternApplicatorExtension>();
+  assert(extension &&
+         "expected PatternApplicatorExtension to be attached by the parent op");
   SmallVector<Operation *> targets;
-  if (getRoot())
-    llvm::append_range(targets, state.getPayloadOps(getRoot()));
-  else
-    targets.push_back(state.getTopLevel());
+  for (Operation *root : state.getPayloadOps(getRoot())) {
+    if (failed(extension->findAllMatches(
+            getPatternName().getLeafReference().getValue(), root, targets))) {
+      return emitOpError() << "could not find pattern '" << getPatternName()
+                           << "'";
+    }
+  }
+  results.set(getResult().cast<OpResult>(), targets);
+  return success();
+}
 
+//===----------------------------------------------------------------------===//
+// SequenceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
+                                           transform::TransformState &state) {
   // Map the entry block argument to the list of operations.
   auto scope = state.make_region_scope(*getBodyBlock()->getParent());
-  if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets)))
+  if (failed(mapBlockArguments(state)))
     return failure();
 
   // Apply the sequenced ops one by one.
@@ -48,23 +165,6 @@ LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
 }
 
 LogicalResult transform::SequenceOp::verify() {
-  if (getBodyBlock()->getNumArguments() != 1 ||
-      !getBodyBlock()->getArgumentTypes()[0].isa<pdl::OperationType>()) {
-    return emitOpError()
-           << "expected the entry block to have one argument of type "
-           << pdl::OperationType::get(getContext());
-  }
-
-  if (auto parent = getOperation()->getParentOfType<transform::SequenceOp>()) {
-    if (!getRoot()) {
-      InFlightDiagnostic diag =
-          emitOpError()
-          << "expected the root operation to be provided for a nested sequence";
-      diag.attachNote(parent.getLoc()) << "nested in another sequence";
-      return diag;
-    }
-  }
-
   for (Operation &child : *getBodyBlock()) {
     if (!isa<TransformOpInterface>(child) &&
         &child != &getBodyBlock()->back()) {
@@ -99,3 +199,65 @@ LogicalResult transform::SequenceOp::verify() {
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// WithPDLPatternsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
+                                    transform::TransformState &state) {
+  // The body holds PDL patterns plus, as enforced by the verifier, at most one
+  // other (top-level transform) op. That op is the one to execute.
+  TransformOpInterface transformOp = nullptr;
+  for (Operation &nested : getBody().front()) {
+    if (!isa<pdl::PatternOp>(nested)) {
+      transformOp = cast<TransformOpInterface>(nested);
+      break;
+    }
+  }
+
+  // Register the pattern-applicator extension for this op so that nested
+  // transforms (e.g. pdl_match) can look up patterns. It is removed again when
+  // this function returns, on every path.
+  state.addExtension<PatternApplicatorExtension>(getOperation());
+  auto guard = llvm::make_scope_exit(
+      [&]() { state.removeExtension<PatternApplicatorExtension>(); });
+
+  auto scope = state.make_region_scope(getBody());
+  if (failed(mapBlockArguments(state)))
+    return failure();
+
+  // The verifier accepts a body that contains only patterns (or nothing at
+  // all). There is nothing to apply then, and passing a null op to
+  // applyTransform would dereference it.
+  if (!transformOp)
+    return success();
+  return state.applyTransform(transformOp);
+}
+
+LogicalResult transform::WithPDLPatternsOp::verify() {
+  // Single pass over the body:
+  // - PDL pattern ops are always allowed.
+  // - At most one op carrying PossibleTopLevelTransformOpTrait is allowed.
+  // - Any other op is rejected.
+  Operation *firstTransform = nullptr;
+  for (Operation &nestedOp : getBodyBlock()->getOperations()) {
+    if (isa<pdl::PatternOp>(nestedOp))
+      continue;
+
+    bool isTopLevelTransform =
+        nestedOp.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>();
+    if (!isTopLevelTransform) {
+      InFlightDiagnostic diag =
+          emitOpError()
+          << "expects only pattern and top-level transform ops in its body";
+      diag.attachNote(nestedOp.getLoc()) << "offending op";
+      return diag;
+    }
+
+    if (firstTransform != nullptr) {
+      InFlightDiagnostic diag =
+          emitOpError() << "expects only one non-pattern op in its body";
+      diag.attachNote(firstTransform->getLoc()) << "first non-pattern op";
+      diag.attachNote(nestedOp.getLoc()) << "second non-pattern op";
+      return diag;
+    }
+    firstTransform = &nestedOp;
+  }
+
+  // Pattern-carrying scopes must not appear inside one another.
+  WithPDLPatternsOp enclosing =
+      getOperation()->getParentOfType<WithPDLPatternsOp>();
+  if (enclosing) {
+    InFlightDiagnostic diag = emitOpError() << "cannot be nested";
+    diag.attachNote(enclosing.getLoc()) << "parent operation";
+    return diag;
+  }
+  return success();
+}

diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 614628107834e..61ed760d700f1 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -1,15 +1,15 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
-// expected-error @below {{expected the entry block to have one argument of type '!pdl.operation'}}
+// expected-error @below {{expects the entry block to have one argument of type '!pdl.operation'}}
 transform.sequence {
 }
 
 // -----
 
-// expected-note @below {{nested in another sequence}}
+// expected-note @below {{nested in another possible top-level op}}
 transform.sequence {
 ^bb0(%arg0: !pdl.operation):
-  // expected-error @below {{expected the root operation to be provided for a nested sequence}}
+  // expected-error @below {{expects the root operation to be provided for a nested op}}
   transform.sequence {
   ^bb1(%arg1: !pdl.operation):
   }
@@ -50,3 +50,64 @@ transform.sequence {
   // expected-note @below {{terminator}}
   transform.yield
 } : !pdl.operation
+
+// -----
+
+// expected-note @below {{nested in another possible top-level op}}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  // expected-error @below {{expects the root operation to be provided for a nested op}}
+  transform.sequence {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}
+
+// -----
+
+// expected-error @below {{expects only one non-pattern op in its body}}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  // expected-note @below {{first non-pattern op}}
+  transform.sequence {
+  ^bb1(%arg1: !pdl.operation):
+  }
+  // expected-note @below {{second non-pattern op}}
+  transform.sequence {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}
+
+// -----
+
+// expected-error @below {{expects only pattern and top-level transform ops in its body}}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  // expected-note @below {{offending op}}
+  "test.something"() : () -> ()
+}
+
+// -----
+
+// expected-note @below {{parent operation}}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+   // expected-error @below {{op cannot be nested}}
+  transform.with_pdl_patterns %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}
+
+// -----
+
+// expected-error @below {{expects one region}}
+"transform.test_transform_unrestricted_op_no_interface"() : () -> ()
+
+// -----
+
+// expected-error @below {{expects a single-block region}}
+"transform.test_transform_unrestricted_op_no_interface"() ({
+^bb0(%arg0: !pdl.operation):
+  "test.potential_terminator"() : () -> ()
+^bb1:
+  "test.potential_terminator"() : () -> ()
+}) : () -> ()

diff  --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index c3aab426aad26..34ee62e0bbc75 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -10,3 +10,23 @@ transform.sequence {
   ^bb1(%arg1: !pdl.operation):
   }
 }
+
+// CHECK: transform.with_pdl_patterns
+// CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation):
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  // CHECK: sequence %[[ARG]]
+  sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}
+
+// CHECK: transform.sequence
+// CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation):
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+  // CHECK: with_pdl_patterns %[[ARG]]
+  with_pdl_patterns %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+  }
+}

diff  --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index a6ceeea82a5c8..2b2416480af12 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -69,3 +69,31 @@ transform.sequence {
   // expected-remark @below {{succeeded}}
   test_consume_operand_if_matches_param_or_fail %0[42]
 }
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @some in %arg1
+    test_print_remark_at_operand %0, "matched"
+  }
+
+  pdl.pattern @some : benefit(1) {
+    %0 = pdl.operation "test.some_op"
+    pdl.rewrite %0 with "transform.dialect"
+  }
+
+  pdl.pattern @other : benefit(1) {
+    %0 = pdl.operation "test.other_op"
+    pdl.rewrite %0 with "transform.dialect"
+  }
+}
+
+// expected-remark @below {{matched}}
+"test.some_op"() : () -> ()
+"test.other_op"() : () -> ()
+// expected-remark @below {{matched}}
+"test.some_op"() : () -> ()
+

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 4aed0aae1e776..c3bbb5a66c61e 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -22,7 +22,8 @@ using namespace mlir;
 
 namespace {
 /// Simple transform op defined outside of the dialect. Just emits a remark when
-/// applied.
+/// applied. This op is defined in C++ to test that C++ definitions also work
+/// for op injection into the Transform dialect.
 class TestTransformOp
     : public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
 public:
@@ -63,6 +64,33 @@ class TestTransformOp
       printer << " " << getMessage();
   }
 };
+
+/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
+/// in cases where it is attached to ops that do not comply with the trait
+/// requirements. This op cannot be defined in ODS because ODS generates strict
+/// verifiers that overlap with those in the trait and run earlier.
+class TestTransformUnrestrictedOpNoInterface
+    : public Op<TestTransformUnrestrictedOpNoInterface,
+                transform::PossibleTopLevelTransformOpTrait,
+                transform::TransformOpInterface::Trait> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestTransformUnrestrictedOpNoInterface)
+
+  using Op::Op;
+
+  /// The op declares no attributes.
+  static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+  /// Name under which the op is registered in the Transform dialect.
+  static constexpr llvm::StringLiteral getOperationName() {
+    return llvm::StringLiteral(
+        "transform.test_transform_unrestricted_op_no_interface");
+  }
+
+  /// Transform interface hook. It intentionally does nothing and always
+  /// succeeds, because the op exists only to exercise verification.
+  LogicalResult apply(transform::TransformResults &results,
+                      transform::TransformState &state) {
+    return success();
+  }
+};
 } // namespace
 
 LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
@@ -97,6 +125,15 @@ LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
   return success();
 }
 
+LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  // Attach this op's message as a remark to every payload op that the
+  // operand handle is associated with.
+  for (Operation *payloadOp : state.getPayloadOps(getOperand()))
+    payloadOp->emitRemark() << getMessage();
+  return success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
@@ -108,6 +145,7 @@ class TestTransformDialectExtension
   TestTransformDialectExtension() {
     declareDependentDialect<pdl::PDLDialect>();
     registerTransformOps<TestTransformOp,
+                         TestTransformUnrestrictedOpNoInterface,
 #define GET_OP_LIST
 #include "TestTransformDialectExtension.cpp.inc"
                          >();

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index c263409c618d1..4596780ac131e 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -38,4 +38,12 @@ def TestConsumeOperandIfMatchesParamOrFail
   let cppNamespace = "::mlir::test";
 }
 
+// Test op that, when applied, emits a remark carrying `$message` at every
+// payload op associated with the `$operand` handle. It always succeeds.
+def TestPrintRemarkAtOperandOp
+  : Op<Transform_Dialect, "test_print_remark_at_operand",
+       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins PDL_Operation:$operand, StrAttr:$message);
+  let assemblyFormat = "$operand `,` $message attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 346e7f7d16a7d..2b70cd7afb037 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7772,6 +7772,8 @@ cc_library(
     deps = [
         ":IR",
         ":PDLDialect",
+        ":PDLInterpDialect",
+        ":Rewrite",
         ":Support",
         ":TransformDialectIncGen",
         ":TransformDialectInterfacesIncGen",


        


More information about the Mlir-commits mailing list