[Mlir-commits] [mlir] ea64828 - [mlir:PDL] Expand how native constraint/rewrite functions can be defined

River Riddle llvmlistbot at llvm.org
Wed Apr 6 17:42:09 PDT 2022


Author: River Riddle
Date: 2022-04-06T17:41:59-07:00
New Revision: ea64828a10e304f8131e40dfef062173fe606e6a

URL: https://github.com/llvm/llvm-project/commit/ea64828a10e304f8131e40dfef062173fe606e6a
DIFF: https://github.com/llvm/llvm-project/commit/ea64828a10e304f8131e40dfef062173fe606e6a.diff

LOG: [mlir:PDL] Expand how native constraint/rewrite functions can be defined

This commit refactors the expected form of native constraint and rewrite
functions, and greatly reduces the complexity required of users when
defining a native function. Namely, this commit adds automatic handling
of the necessary PDLValue glue code, and allows users to define
constraint/rewrite functions using the C++ types that they actually want to
use.

As an example, let's look at a simple rewrite as it would be defined today:

```
static void rewriteFn(PatternRewriter &rewriter, PDLResultList &results,
                      ArrayRef<PDLValue> args) {
  ValueRange operandValues = args[0].cast<ValueRange>();
  TypeRange typeValues = args[1].cast<TypeRange>();
  ...
  // Create an operation at some point and pass it back to PDL.
  Operation *op = rewriter.create<SomeOp>(...);
  results.push_back(op);
}
```

After this commit, that same rewrite could be defined as:

```
static Operation *rewriteFn(PatternRewriter &rewriter, ValueRange operandValues,
                            TypeRange typeValues) {
  ...
  // Create an operation at some point and pass it back to PDL.
  return rewriter.create<SomeOp>(...);
}
```

Differential Revision: https://reviews.llvm.org/D122086

Added: 
    

Modified: 
    llvm/include/llvm/ADT/STLExtras.h
    mlir/docs/PDLL.md
    mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/Rewrite/pdl-bytecode.mlir
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp
    mlir/test/mlir-pdll/CodeGen/CPP/general.pdll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index 75d7175453583..fc461f8fe7e50 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -129,7 +129,7 @@ struct function_traits<ReturnType (ClassType::*)(Args...) const, false> {
 /// Overload for class function types.
 template <typename ClassType, typename ReturnType, typename... Args>
 struct function_traits<ReturnType (ClassType::*)(Args...), false>
-    : function_traits<ReturnType (ClassType::*)(Args...) const> {};
+    : public function_traits<ReturnType (ClassType::*)(Args...) const> {};
 /// Overload for non-class function types.
 template <typename ReturnType, typename... Args>
 struct function_traits<ReturnType (*)(Args...), false> {
@@ -143,6 +143,9 @@ struct function_traits<ReturnType (*)(Args...), false> {
   template <size_t i>
   using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
 };
+template <typename ReturnType, typename... Args>
+struct function_traits<ReturnType (*const)(Args...), false>
+    : public function_traits<ReturnType (*)(Args...)> {};
 /// Overload for non-class function type references.
 template <typename ReturnType, typename... Args>
 struct function_traits<ReturnType (&)(Args...), false>

diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md
index ab24a680d37b8..33293505d3714 100644
--- a/mlir/docs/PDLL.md
+++ b/mlir/docs/PDLL.md
@@ -1006,17 +1006,11 @@ External constraints are those registered explicitly with the `RewritePatternSet
 the C++ PDL API. For example, the constraints above may be registered as:
 
 ```c++
-// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
-static LogicalResult hasOneUseImpl(PDLValue pdlValue, PatternRewriter &rewriter) {
-  Value value = pdlValue.cast<Value>();
-  
+static LogicalResult hasOneUseImpl(PatternRewriter &rewriter, Value value) {
   return success(value.hasOneUse());
 }
-static LogicalResult hasSameElementTypeImpl(ArrayRef<PDLValue> pdlValues,
-                                            PatternRewriter &rewriter) {
-  Value value1 = pdlValues[0].cast<Value>();
-  Value value2 = pdlValues[1].cast<Value>();
-      
+static LogicalResult hasSameElementTypeImpl(PatternRewriter &rewriter,
+                                            Value value1, Value value2) {
   return success(value1.getType().cast<ShapedType>().getElementType() ==
                  value2.getType().cast<ShapedType>().getElementType());
 }
@@ -1307,14 +1301,10 @@ External rewrites are those registered explicitly with the `RewritePatternSet` v
 the C++ PDL API. For example, the rewrite above may be registered as:
 
 ```c++
-// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
-static void buildOpImpl(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
-                        PDLResultList &results) {
-  Value value = args[0].cast<Value>();
-
+static Operation *buildOpImpl(PatternRewriter &rewriter, Value value) {
   // insert special rewrite logic here.
   Operation *resultOp = ...; 
-  results.push_back(resultOp);
+  return resultOp;
 }
 
 void registerNativeRewrite(RewritePatternSet &patterns) {

diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 7f0253e59a32b..4e43b2677f4a8 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -68,18 +68,14 @@ def PDL_ApplyNativeRewriteOp
 
     ```mlir
     // Apply a native rewrite method that returns an attribute.
-    %ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %arg1) : !pdl.attribute
+    %ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %attr1) : !pdl.attribute
     ```
 
     ```c++
     // The native rewrite as defined in C++:
-    static void myNativeFunc(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
-                             PDLResultList &results) {
-      Value arg0 = args[0].cast<Value>();
-      Value arg1 = args[1].cast<Value>();
-
-      // Just push back the first param attribute.
-      results.push_back(param0);
+    static Attribute myNativeFunc(PatternRewriter &rewriter, Value arg0, Attribute arg1) {
+      // Just return the second arg.
+      return arg1;
     }
 
     void registerNativeRewrite(PDLPatternModule &pdlModule) {

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 1bd6187a78d48..f4c8863624740 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -409,7 +409,8 @@ class OpBuilder : public Builder {
 
   /// Creates an operation with the given fields.
   Operation *create(Location loc, StringAttr opName, ValueRange operands,
-                    TypeRange types, ArrayRef<NamedAttribute> attributes = {},
+                    TypeRange types = {},
+                    ArrayRef<NamedAttribute> attributes = {},
                     BlockRange successors = {},
                     MutableArrayRef<std::unique_ptr<Region>> regions = {});
 

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 11f85ee38bef8..478fa2ae97b1c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -386,6 +386,222 @@ class OpTraitRewritePattern : public RewritePattern {
                        benefit, context) {}
 };
 
+//===----------------------------------------------------------------------===//
+// RewriterBase
+//===----------------------------------------------------------------------===//
+
+/// This class coordinates the application of a rewrite on a set of IR,
+/// providing a way for clients to track mutations and create new operations.
+/// This class serves as a common API for IR mutation between pattern rewrites
+/// and non-pattern rewrites, and facilitates the development of shared
+/// IR transformation utilities.
+class RewriterBase : public OpBuilder, public OpBuilder::Listener {
+public:
+  /// Move the blocks that belong to "region" before the given position in
+  /// another region "parent". The two regions must be different. The caller
+  /// is responsible for creating or updating the operation transferring flow
+  /// of control to the region and passing it the correct block arguments.
+  virtual void inlineRegionBefore(Region &region, Region &parent,
+                                  Region::iterator before);
+  void inlineRegionBefore(Region &region, Block *before);
+
+  /// Clone the blocks that belong to "region" before the given position in
+  /// another region "parent". The two regions must be different. The caller is
+  /// responsible for creating or updating the operation transferring flow of
+  /// control to the region and passing it the correct block arguments.
+  virtual void cloneRegionBefore(Region &region, Region &parent,
+                                 Region::iterator before,
+                                 BlockAndValueMapping &mapping);
+  void cloneRegionBefore(Region &region, Region &parent,
+                         Region::iterator before);
+  void cloneRegionBefore(Region &region, Block *before);
+
+  /// This method replaces the uses of the results of `op` with the values in
+  /// `newValues` when the provided `functor` returns true for a specific use.
+  /// The number of values in `newValues` is required to match the number of
+  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
+  /// the uses of `op` were replaced. Note that in some rewriters, the given
+  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
+  /// As such, the function should not capture by reference and instead use
+  /// value capture as necessary.
+  virtual void
+  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
+                  llvm::unique_function<bool(OpOperand &) const> functor);
+  void replaceOpWithIf(Operation *op, ValueRange newValues,
+                       llvm::unique_function<bool(OpOperand &) const> functor) {
+    replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
+                    std::move(functor));
+  }
+
+  /// This method replaces the uses of the results of `op` with the values in
+  /// `newValues` when a use is nested within the given `block`. The number of
+  /// values in `newValues` is required to match the number of results of `op`.
+  /// If all uses of this operation are replaced, the operation is erased.
+  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
+                            bool *allUsesReplaced = nullptr);
+
+  /// This method replaces the results of the operation with the specified list
+  /// of values. The number of provided values must match the number of results
+  /// of the operation.
+  virtual void replaceOp(Operation *op, ValueRange newValues);
+
+  /// Replaces the result op with a new op that is created without verification.
+  /// The result values of the two ops must be the same types.
+  template <typename OpTy, typename... Args>
+  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
+    auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+    replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
+    return newOp;
+  }
+
+  /// This method erases an operation that is known to have no uses.
+  virtual void eraseOp(Operation *op);
+
+  /// This method erases all operations in a block.
+  virtual void eraseBlock(Block *block);
+
+  /// Merge the operations of block 'source' into the end of block 'dest'.
+  /// 'source's predecessors must either be empty or only contain 'dest'.
+  /// 'argValues' is used to replace the block arguments of 'source' after
+  /// merging.
+  virtual void mergeBlocks(Block *source, Block *dest,
+                           ValueRange argValues = llvm::None);
+
+  /// Merge the operations of block 'source' before the operation 'op'. Source
+  /// block should not have existing predecessors or successors.
+  void mergeBlockBefore(Block *source, Operation *op,
+                        ValueRange argValues = llvm::None);
+
+  /// Split the operations starting at "before" (inclusive) out of the given
+  /// block into a new block, and return it.
+  virtual Block *splitBlock(Block *block, Block::iterator before);
+
+  /// This method is used to notify the rewriter that an in-place operation
+  /// modification is about to happen. A call to this function *must* be
+  /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
+  /// This is a minor efficiency win (it avoids creating a new operation and
+  /// removing the old one) but also often allows simpler code in the client.
+  virtual void startRootUpdate(Operation *op) {}
+
+  /// This method is used to signal the end of a root update on the given
+  /// operation. This can only be called on operations that were provided to a
+  /// call to `startRootUpdate`.
+  virtual void finalizeRootUpdate(Operation *op) {}
+
+  /// This method cancels a pending root update. This can only be called on
+  /// operations that were provided to a call to `startRootUpdate`.
+  virtual void cancelRootUpdate(Operation *op) {}
+
+  /// This method is a utility wrapper around a root update of an operation. It
+  /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
+  /// callable.
+  template <typename CallableT>
+  void updateRootInPlace(Operation *root, CallableT &&callable) {
+    startRootUpdate(root);
+    callable();
+    finalizeRootUpdate(root);
+  }
+
+  /// Used to notify the rewriter that the IR failed to be rewritten because of
+  /// a match failure, and provide a callback to populate a diagnostic with the
+  /// reason why the failure occurred. This method allows for derived rewriters
+  /// to optionally hook into the reason why a rewrite failed, and display it to
+  /// users.
+  template <typename CallbackT>
+  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
+  notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
+#ifndef NDEBUG
+    return notifyMatchFailure(loc,
+                              function_ref<void(Diagnostic &)>(reasonCallback));
+#else
+    return failure();
+#endif
+  }
+  template <typename CallbackT>
+  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
+  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
+    return notifyMatchFailure(op->getLoc(),
+                              function_ref<void(Diagnostic &)>(reasonCallback));
+  }
+  template <typename ArgT>
+  LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
+    return notifyMatchFailure(std::forward<ArgT>(arg),
+                              [&](Diagnostic &diag) { diag << msg; });
+  }
+  template <typename ArgT>
+  LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
+    return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
+  }
+
+protected:
+  /// Initialize the builder with this rewriter as the listener.
+  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
+  explicit RewriterBase(const OpBuilder &otherBuilder)
+      : OpBuilder(otherBuilder) {
+    setListener(this);
+  }
+  ~RewriterBase() override;
+
+  /// These are the callback methods that subclasses can choose to implement if
+  /// they would like to be notified about certain types of mutations.
+
+  /// Notify the rewriter that the specified operation is about to be replaced
+  /// with another set of operations. This is called before the uses of the
+  /// operation have been changed.
+  virtual void notifyRootReplaced(Operation *op) {}
+
+  /// This is called on an operation that a rewrite is removing, right before
+  /// the operation is deleted. At this point, the operation has zero uses.
+  virtual void notifyOperationRemoved(Operation *op) {}
+
+  /// Notify the rewriter that the pattern failed to match the given operation,
+  /// and provide a callback to populate a diagnostic with the reason why the
+  /// failure occurred. This method allows for derived rewriters to optionally
+  /// hook into the reason why a rewrite failed, and display it to users.
+  virtual LogicalResult
+  notifyMatchFailure(Location loc,
+                     function_ref<void(Diagnostic &)> reasonCallback) {
+    return failure();
+  }
+
+private:
+  void operator=(const RewriterBase &) = delete;
+  RewriterBase(const RewriterBase &) = delete;
+
+  /// 'op' and 'newOp' are known to have the same number of results, replace the
+  /// uses of op with uses of newOp.
+  void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
+};
+
+//===----------------------------------------------------------------------===//
+// IRRewriter
+//===----------------------------------------------------------------------===//
+
+/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
+/// providing a way to keep track of the mutations made to the IR. This class
+/// should only be used in situations where another `RewriterBase` instance,
+/// such as a `PatternRewriter`, is not available.
+class IRRewriter : public RewriterBase {
+public:
+  explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
+  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
+};
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter
+//===----------------------------------------------------------------------===//
+
+/// A special type of `RewriterBase` that coordinates the application of a
+/// rewrite pattern on the current IR being matched, providing a way to keep
+/// track of any mutations made. This class should be used to perform all
+/// necessary IR mutations within a rewrite pattern, as the pattern driver may
+/// be tracking various state that would be invalidated when a mutation takes
+/// place.
+class PatternRewriter : public RewriterBase {
+public:
+  using RewriterBase::RewriterBase;
+};
+
 //===----------------------------------------------------------------------===//
 // PDLPatternModule
 //===----------------------------------------------------------------------===//
@@ -587,291 +803,561 @@ class PDLResultList {
 /// constraint to a given set of opaque PDLValue entities. Returns success if
 /// the constraint successfully held, failure otherwise.
 using PDLConstraintFunction =
-    std::function<LogicalResult(ArrayRef<PDLValue>, PatternRewriter &)>;
+    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
 /// A native PDL rewrite function. This function performs a rewrite on the
 /// given set of values. Any results from this rewrite that should be passed
 /// back to PDL should be added to the provided result list. This method is only
 /// invoked when the corresponding match was successful.
 using PDLRewriteFunction =
-    std::function<void(ArrayRef<PDLValue>, PatternRewriter &, PDLResultList &)>;
-
-/// This class contains all of the necessary data for a set of PDL patterns, or
-/// pattern rewrites specified in the form of the PDL dialect. This PDL module
-/// contained by this pattern may contain any number of `pdl.pattern`
-/// operations.
-class PDLPatternModule {
-public:
-  PDLPatternModule() = default;
+    std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
 
-  /// Construct a PDL pattern with the given module.
-  PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
-      : pdlModule(std::move(pdlModule)) {}
+namespace detail {
+namespace pdl_function_builder {
+/// A utility variable that always resolves to false. This is useful for static
+/// asserts that are always false, but only should fire in certain templated
+/// constructs. For example, if a templated function should never be called, the
+/// function could be defined as:
+///
+/// template <typename T>
+/// void foo() {
+///  static_assert(always_false<T>, "This function should never be called");
+/// }
+///
+template <class... T>
+constexpr bool always_false = false;
 
-  /// Merge the state in `other` into this pattern module.
-  void mergeIn(PDLPatternModule &&other);
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Type Processing
+//===----------------------------------------------------------------------===//
 
-  /// Return the internal PDL module of this pattern.
-  ModuleOp getModule() { return pdlModule.get(); }
+/// This struct provides a convenient way to determine how to process a given
+/// type as either a PDL parameter, or a result value. This allows for
+/// supporting complex types in constraint and rewrite functions, without
+/// requiring the user to hand-write the necessary glue code themselves.
+/// Specializations of this class should implement the following methods to
+/// enable support as a PDL argument or result type:
+///
+///   static LogicalResult verifyAsArg(
+///     function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
+///     size_t argIdx);
+///
+///     * This method verifies that the given PDLValue is valid for use as a
+///       value of `T`.
+///
+///   static T processAsArg(PDLValue pdlValue);
+///
+///     *  This method processes the given PDLValue as a value of `T`.
+///
+///   static void processAsResult(PatternRewriter &, PDLResultList &results,
+///                               const T &value);
+///
+///     *  This method processes the given value of `T` as the result of a
+///        function invocation. The method should package the value into an
+///        appropriate form and append it to the given result list.
+///
+/// If the type `T` is based on a higher order value, consider using
+/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
+/// the implementation.
+///
+template <typename T, typename Enable = void>
+struct ProcessPDLValue;
+
+/// This struct provides a simplified model for processing types that are based
+/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
+/// allows for building the necessary processing functions on top of the base
+/// value instead of a PDLValue. Derived users should implement the following
+/// (which subsume the ProcessPDLValue variants):
+///
+///   static LogicalResult verifyAsArg(
+///     function_ref<LogicalResult(const Twine &)> errorFn,
+///     const BaseT &baseValue, size_t argIdx);
+///
+///     * This method verifies that the given PDLValue is valid for use as a
+///       value of `T`.
+///
+///   static T processAsArg(BaseT baseValue);
+///
+///     *  This method processes the given base value as a value of `T`.
+///
+template <typename T, typename BaseT>
+struct ProcessPDLValueBasedOn {
+  static LogicalResult
+  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+              PDLValue pdlValue, size_t argIdx) {
+    // Verify the base class before continuing.
+    if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
+      return failure();
+    return ProcessPDLValue<T>::verifyAsArg(
+        errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
+  }
+  static T processAsArg(PDLValue pdlValue) {
+    return ProcessPDLValue<T>::processAsArg(
+        ProcessPDLValue<BaseT>::processAsArg(pdlValue));
+  }
 
-  //===--------------------------------------------------------------------===//
-  // Function Registry
+  /// Explicitly add the expected parent API to ensure the parent class
+  /// implements the necessary API (and doesn't implicitly inherit it from
+  /// somewhere else).
+  static LogicalResult
+  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
+              size_t argIdx) {
+    return success();
+  }
+  static T processAsArg(BaseT baseValue);
+};
 
-  /// Register a constraint function.
-  void registerConstraintFunction(StringRef name,
-                                  PDLConstraintFunction constraintFn);
-  /// Register a single entity constraint function.
-  template <typename SingleEntityFn>
-  std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
-                                       PatternRewriter &>::value>
-  registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
-    registerConstraintFunction(
-        name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)](
-                  ArrayRef<PDLValue> values, PatternRewriter &rewriter) {
-          assert(values.size() == 1 &&
-                 "expected values to have a single entity");
-          return constraintFn(values[0], rewriter);
-        });
+/// This struct provides a simplified model for processing types that have
+/// "builtin" PDLValue support:
+///   * Attribute, Operation *, Type, TypeRange, ValueRange
+template <typename T>
+struct ProcessBuiltinPDLValue {
+  static LogicalResult
+  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+              PDLValue pdlValue, size_t argIdx) {
+    if (pdlValue)
+      return success();
+    return errorFn("expected a non-null value for argument " + Twine(argIdx) +
+                   " of type: " + llvm::getTypeName<T>());
   }
 
-  /// Register a rewrite function.
-  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
+  static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              T value) {
+    results.push_back(value);
+  }
+};
 
-  /// Return the set of the registered constraint functions.
-  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
-    return constraintFunctions;
+/// This struct provides a simplified model for processing types that inherit
+/// from builtin PDLValue types. For example, derived attributes like
+/// IntegerAttr, derived types like IntegerType, derived operations like
+/// ModuleOp, Interfaces, etc.
+template <typename T, typename BaseT>
+struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
+  static LogicalResult
+  verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
+              BaseT baseValue, size_t argIdx) {
+    return TypeSwitch<BaseT, LogicalResult>(baseValue)
+        .Case([&](T) { return success(); })
+        .Default([&](BaseT) {
+          return errorFn("expected argument " + Twine(argIdx) +
+                         " to be of type: " + llvm::getTypeName<T>());
+        });
   }
-  llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
-    return constraintFunctions;
+  static T processAsArg(BaseT baseValue) {
+    return baseValue.template cast<T>();
   }
-  /// Return the set of the registered rewrite functions.
-  const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
-    return rewriteFunctions;
-  }
-  llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
-    return rewriteFunctions;
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              T value) {
+    results.push_back(value);
   }
+};
 
-  /// Clear out the patterns and functions within this module.
-  void clear() {
-    pdlModule = nullptr;
-    constraintFunctions.clear();
-    rewriteFunctions.clear();
+//===----------------------------------------------------------------------===//
+// Attribute
+
+template <>
+struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
+template <typename T>
+struct ProcessPDLValue<T,
+                       std::enable_if_t<std::is_base_of<Attribute, T>::value>>
+    : public ProcessDerivedPDLValue<T, Attribute> {};
+
+/// Handling for various Attribute value types.
+template <>
+struct ProcessPDLValue<StringRef>
+    : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
+  static StringRef processAsArg(StringAttr value) { return value.getValue(); }
+  static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
+                              StringRef value) {
+    results.push_back(rewriter.getStringAttr(value));
   }
+};
+template <>
+struct ProcessPDLValue<std::string>
+    : public ProcessPDLValueBasedOn<std::string, StringAttr> {
+  template <typename T>
+  static std::string processAsArg(T value) {
+    static_assert(always_false<T>,
+                  "`std::string` arguments require a string copy, use "
+                  "`StringRef` for string-like arguments instead");
+  }
+  static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
+                              StringRef value) {
+    results.push_back(rewriter.getStringAttr(value));
+  }
+};
 
-private:
-  /// The module containing the `pdl.pattern` operations.
-  OwningOpRef<ModuleOp> pdlModule;
-
-  /// The external functions referenced from within the PDL module.
-  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
-  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
+//===----------------------------------------------------------------------===//
+// Operation
+
+template <>
+struct ProcessPDLValue<Operation *>
+    : public ProcessBuiltinPDLValue<Operation *> {};
+template <typename T>
+struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
+    : public ProcessDerivedPDLValue<T, Operation *> {
+  static T processAsArg(Operation *value) { return cast<T>(value); }
 };
 
 //===----------------------------------------------------------------------===//
-// RewriterBase
+// Type
+
+template <>
+struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
+template <typename T>
+struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
+    : public ProcessDerivedPDLValue<T, Type> {};
+
 //===----------------------------------------------------------------------===//
+// TypeRange
+
+template <>
+struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
+template <>
+struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              ValueTypeRange<OperandRange> types) {
+    results.push_back(types);
+  }
+};
+template <>
+struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              ValueTypeRange<ResultRange> types) {
+    results.push_back(types);
+  }
+};
 
-/// This class coordinates the application of a rewrite on a set of IR,
-/// providing a way for clients to track mutations and create new operations.
-/// This class serves as a common API for IR mutation between pattern rewrites
-/// and non-pattern rewrites, and facilitates the development of shared
-/// IR transformation utilities.
-class RewriterBase : public OpBuilder, public OpBuilder::Listener {
-public:
-  /// Move the blocks that belong to "region" before the given position in
-  /// another region "parent". The two regions must be different. The caller
-  /// is responsible for creating or updating the operation transferring flow
-  /// of control to the region and passing it the correct block arguments.
-  virtual void inlineRegionBefore(Region &region, Region &parent,
-                                  Region::iterator before);
-  void inlineRegionBefore(Region &region, Block *before);
+//===----------------------------------------------------------------------===//
+// Value
 
-  /// Clone the blocks that belong to "region" before the given position in
-  /// another region "parent". The two regions must be different. The caller is
-  /// responsible for creating or updating the operation transferring flow of
-  /// control to the region and passing it the correct block arguments.
-  virtual void cloneRegionBefore(Region &region, Region &parent,
-                                 Region::iterator before,
-                                 BlockAndValueMapping &mapping);
-  void cloneRegionBefore(Region &region, Region &parent,
-                         Region::iterator before);
-  void cloneRegionBefore(Region &region, Block *before);
+template <>
+struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
 
-  /// This method replaces the uses of the results of `op` with the values in
-  /// `newValues` when the provided `functor` returns true for a specific use.
-  /// The number of values in `newValues` is required to match the number of
-  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
-  /// the uses of `op` were replaced. Note that in some rewriters, the given
-  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
-  /// As such, the function should not capture by reference and instead use
-  /// value capture as necessary.
-  virtual void
-  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
-                  llvm::unique_function<bool(OpOperand &) const> functor);
-  void replaceOpWithIf(Operation *op, ValueRange newValues,
-                       llvm::unique_function<bool(OpOperand &) const> functor) {
-    replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
-                    std::move(functor));
+//===----------------------------------------------------------------------===//
+// ValueRange
+
+template <>
+struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
+};
+/// `OperandRange` and `ResultRange` only define `processAsResult`: they may be
+/// returned from a native rewrite (pushed into the result list as-is), but
+/// cannot be used as native function argument types.
+template <>
+struct ProcessPDLValue<OperandRange> {
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              OperandRange values) {
+    results.push_back(values);
   }
+};
+template <>
+struct ProcessPDLValue<ResultRange> {
+  static void processAsResult(PatternRewriter &, PDLResultList &results,
+                              ResultRange values) {
+    results.push_back(values);
+  }
+};
 
-  /// This method replaces the uses of the results of `op` with the values in
-  /// `newValues` when a use is nested within the given `block`. The number of
-  /// values in `newValues` is required to match the number of results of `op`.
-  /// If all uses of this operation are replaced, the operation is erased.
-  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
-                            bool *allUsesReplaced = nullptr);
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Argument Handling
+//===----------------------------------------------------------------------===//
 
-  /// This method replaces the results of the operation with the specified list
-  /// of values. The number of provided values must match the number of results
-  /// of the operation.
-  virtual void replaceOp(Operation *op, ValueRange newValues);
+/// Validate that the given PDLValues match the constraints defined by the
+/// argument types of the given function. In the case of failure, a match
+/// failure diagnostic is emitted.
+///
+/// Each `values[I]` is checked against parameter `I + 1` of `PDLFnT`
+/// (parameter 0 is the `PatternRewriter &`). Arguments are checked left to
+/// right and checking stops at the first failure.
+/// NOTE(review): `values` is indexed without comparing its size to
+/// `sizeof...(I)`; callers must supply at least that many values.
+/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
+/// does not currently preserve Constraint application ordering.
+template <typename PDLFnT, std::size_t... I>
+LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
+                           std::index_sequence<I...>) {
+  using FnTraitsT = llvm::function_traits<PDLFnT>;
+
+  // Report verification errors as match failures at an unknown location.
+  auto errorFn = [&](const Twine &msg) {
+    return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
+  };
+  LogicalResult result = success();
+  // Pack expansion inside a braced list (C++14 stand-in for a fold). The list
+  // is evaluated left to right, and once `result` has failed the remaining
+  // arguments are not verified.
+  (void)std::initializer_list<int>{
+      (result =
+           succeeded(result)
+               ? ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                     verifyAsArg(errorFn, values[I], I)
+               : failure(),
+       0)...};
+  return result;
+}
 
-  /// Replaces the result op with a new op that is created without verification.
-  /// The result values of the two ops must be the same types.
-  template <typename OpTy, typename... Args>
-  OpTy replaceOpWithNewOp(Operation *op, Args &&... args) {
-    auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
-    replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
-    return newOp;
-  }
+/// Assert that the given PDLValues match the constraints defined by the
+/// arguments of the given function. In the case of failure, a fatal error
+/// is emitted.
+/// NOTE(review): the checks are wrapped in `assert`, so they are compiled out
+/// when `NDEBUG` is defined even if `LLVM_ENABLE_ABI_BREAKING_CHECKS` is set
+/// (leaving `errorFn` unused). `rewriter` is not used by this function.
+template <typename PDLFnT, std::size_t... I>
+void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
+                std::index_sequence<I...>) {
+  using FnTraitsT = llvm::function_traits<PDLFnT>;
+
+  // We only want to do verification in debug builds, same as with `assert`.
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  // Relies on `report_fatal_error` not returning, so no value is returned.
+  auto errorFn = [&](const Twine &msg) -> LogicalResult {
+    llvm::report_fatal_error(msg);
+  };
+  // As in verifyAsArgs, argument `I` is checked against parameter `I + 1`.
+  (void)std::initializer_list<int>{
+      (assert(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<
+                            I + 1>>::verifyAsArg(errorFn, values[I], I))),
+       0)...};
+#endif
+}
 
-  /// This method erases an operation that is known to have no uses.
-  virtual void eraseOp(Operation *op);
+//===----------------------------------------------------------------------===//
+// PDL Function Builder: Results Handling
+//===----------------------------------------------------------------------===//
 
-  /// This method erases all operations in a block.
-  virtual void eraseBlock(Block *block);
+/// Store a single result within the result list.
+/// NOTE(review): `T` is deduced from a forwarding reference, so an lvalue
+/// argument would make `T` a reference type and look up `ProcessPDLValue<T &>`;
+/// this assumes `value` is passed as an rvalue (the pair/tuple overloads below
+/// `std::move` their elements).
+template <typename T>
+static void processResults(PatternRewriter &rewriter, PDLResultList &results,
+                           T &&value) {
+  ProcessPDLValue<T>::processAsResult(rewriter, results,
+                                      std::forward<T>(value));
+}
 
-  /// Merge the operations of block 'source' into the end of block 'dest'.
-  /// 'source's predecessors must either be empty or only contain 'dest`.
-  /// 'argValues' is used to replace the block arguments of 'source' after
-  /// merging.
-  virtual void mergeBlocks(Block *source, Block *dest,
-                           ValueRange argValues = llvm::None);
+/// Store a std::pair<> as individual results within the result list.
+/// `first` is appended before `second`.
+template <typename T1, typename T2>
+static void processResults(PatternRewriter &rewriter, PDLResultList &results,
+                           std::pair<T1, T2> &&pair) {
+  processResults(rewriter, results, std::move(pair.first));
+  processResults(rewriter, results, std::move(pair.second));
+}
 
-  // Merge the operations of block 'source' before the operation 'op'. Source
-  // block should not have existing predecessors or successors.
-  void mergeBlockBefore(Block *source, Operation *op,
-                        ValueRange argValues = llvm::None);
+/// Store a std::tuple<> as individual results within the result list.
+/// Elements are appended in tuple order.
+template <typename... Ts>
+static void processResults(PatternRewriter &rewriter, PDLResultList &results,
+                           std::tuple<Ts...> &&tuple) {
+  auto applyFn = [&](auto &&...args) {
+    // TODO: Use proper fold expressions when we have C++17. For now we use a
+    // bogus std::initializer_list to work around C++14 limitations.
+    (void)std::initializer_list<int>{
+        (processResults(rewriter, results, std::move(args)), 0)...};
+  };
+  llvm::apply_tuple(applyFn, std::move(tuple));
+}
 
-  /// Split the operations starting at "before" (inclusive) out of the given
-  /// block into a new block, and return it.
-  virtual Block *splitBlock(Block *block, Block::iterator before);
+//===----------------------------------------------------------------------===//
+// PDL Constraint Builder
+//===----------------------------------------------------------------------===//
 
-  /// This method is used to notify the rewriter that an in-place operation
-  /// modification is about to happen. A call to this function *must* be
-  /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
-  /// This is a minor efficiency win (it avoids creating a new operation and
-  /// removing the old one) but also often allows simpler code in the client.
-  virtual void startRootUpdate(Operation *op) {}
+/// Process the arguments of a native constraint and invoke it.
+/// Each `values[I]` is converted via `processAsArg` for parameter `I + 1` of
+/// `PDLFnT` (parameter 0 is the `PatternRewriter &`), and `fn`'s result is
+/// returned unchanged.
+template <typename PDLFnT, std::size_t... I,
+          typename FnTraitsT = llvm::function_traits<PDLFnT>>
+typename FnTraitsT::result_t
+processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
+                               ArrayRef<PDLValue> values,
+                               std::index_sequence<I...>) {
+  return fn(
+      rewriter,
+      (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
+          values[I]))...);
+}
 
-  /// This method is used to signal the end of a root update on the given
-  /// operation. This can only be called on operations that were provided to a
-  /// call to `startRootUpdate`.
-  virtual void finalizeRootUpdate(Operation *op) {}
+/// Build a constraint function from the given function `ConstraintFnT`. This
+/// allows users to define simpler, more direct constraint functions without
+/// needing to handle the low-level PDLValue glue code.
+///
+/// If the constraint function is already in the correct form, we just forward
+/// it directly.
+template <typename ConstraintFnT>
+std::enable_if_t<
+    std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
+    PDLConstraintFunction>
+buildConstraintFn(ConstraintFnT &&constraintFn) {
+  return std::forward<ConstraintFnT>(constraintFn);
+}
+/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
+/// we desire. The user function is forwarded (moved or copied) into the
+/// returned closure.
+template <typename ConstraintFnT>
+std::enable_if_t<
+    !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
+    PDLConstraintFunction>
+buildConstraintFn(ConstraintFnT &&constraintFn) {
+  return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
+             PatternRewriter &rewriter,
+             ArrayRef<PDLValue> values) -> LogicalResult {
+    // One index per argument, excluding the leading `PatternRewriter &`.
+    auto argIndices = std::make_index_sequence<
+        llvm::function_traits<ConstraintFnT>::num_args - 1>();
+    // On a mismatch verifyAsArgs has already reported the match failure, and
+    // the user function is not invoked.
+    if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
+      return failure();
+    return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
+                                          argIndices);
+  };
+}
 
-  /// This method cancels a pending root update. This can only be called on
-  /// operations that were provided to a call to `startRootUpdate`.
-  virtual void cancelRootUpdate(Operation *op) {}
+//===----------------------------------------------------------------------===//
+// PDL Rewrite Builder
+//===----------------------------------------------------------------------===//
 
-  /// This method is a utility wrapper around a root update of an operation. It
-  /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
-  /// callable.
-  template <typename CallableT>
-  void updateRootInPlace(Operation *root, CallableT &&callable) {
-    startRootUpdate(root);
-    callable();
-    finalizeRootUpdate(root);
-  }
+/// Process the arguments of a native rewrite and invoke it.
+/// This overload handles the case of no return values; the result list is
+/// left untouched.
+template <typename PDLFnT, std::size_t... I,
+          typename FnTraitsT = llvm::function_traits<PDLFnT>>
+std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
+processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
+                            PDLResultList &, ArrayRef<PDLValue> values,
+                            std::index_sequence<I...>) {
+  fn(rewriter,
+     (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
+         values[I]))...);
+}
+/// This overload handles the case of return values, which need to be packaged
+/// into the result list. The value returned by `fn` is handed to
+/// `processResults`, which unpacks pairs and tuples into separate results.
+template <typename PDLFnT, std::size_t... I,
+          typename FnTraitsT = llvm::function_traits<PDLFnT>>
+std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
+processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
+                            PDLResultList &results, ArrayRef<PDLValue> values,
+                            std::index_sequence<I...>) {
+  processResults(
+      rewriter, results,
+      fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
+                        processAsArg(values[I]))...));
+}
 
-  /// Used to notify the rewriter that the IR failed to be rewritten because of
-  /// a match failure, and provide a callback to populate a diagnostic with the
-  /// reason why the failure occurred. This method allows for derived rewriters
-  /// to optionally hook into the reason why a rewrite failed, and display it to
-  /// users.
-  template <typename CallbackT>
-  std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
-  notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
-#ifndef NDEBUG
-    return notifyMatchFailure(op,
-                              function_ref<void(Diagnostic &)>(reasonCallback));
-#else
-    return failure();
-#endif
-  }
-  LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
-    return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
-  }
-  LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
-    return notifyMatchFailure(op, Twine(msg));
-  }
+/// Build a rewrite function from the given function `RewriteFnT`. This
+/// allows users to define simpler, more direct rewrite functions without
+/// needing to handle the low-level PDLValue glue code.
+///
+/// If the rewrite function is already in the correct form, we just forward
+/// it directly.
+template <typename RewriteFnT>
+std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
+                 PDLRewriteFunction>
+buildRewriteFn(RewriteFnT &&rewriteFn) {
+  return std::forward<RewriteFnT>(rewriteFn);
+}
+/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
+/// we desire. The user function is forwarded (moved or copied) into the
+/// returned closure.
+template <typename RewriteFnT>
+std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
+                 PDLRewriteFunction>
+buildRewriteFn(RewriteFnT &&rewriteFn) {
+  return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
+             PatternRewriter &rewriter, PDLResultList &results,
+             ArrayRef<PDLValue> values) {
+    // One index per argument, excluding the leading `PatternRewriter &`.
+    auto argIndices =
+        std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
+                                 1>();
+    // Unlike constraints, rewrite arguments are only asserted here (see
+    // assertArgs), not verified with a recoverable failure.
+    assertArgs<RewriteFnT>(rewriter, values, argIndices);
+    processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
+                                argIndices);
+  };
+}
 
-protected:
-  /// Initialize the builder with this rewriter as the listener.
-  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
-  explicit RewriterBase(const OpBuilder &otherBuilder)
-      : OpBuilder(otherBuilder) {
-    setListener(this);
-  }
-  ~RewriterBase() override;
+} // namespace pdl_function_builder
+} // namespace detail
 
-  /// These are the callback methods that subclasses can choose to implement if
-  /// they would like to be notified about certain types of mutations.
+/// This class contains all of the necessary data for a set of PDL patterns, or
+/// pattern rewrites specified in the form of the PDL dialect. The PDL module
+/// held by this object may contain any number of `pdl.pattern`
+/// operations.
+class PDLPatternModule {
+public:
+  PDLPatternModule() = default;
 
-  /// Notify the rewriter that the specified operation is about to be replaced
-  /// with another set of operations. This is called before the uses of the
-  /// operation have been changed.
-  virtual void notifyRootReplaced(Operation *op) {}
+  /// Construct a PDL pattern module that takes ownership of the given module.
+  PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
+      : pdlModule(std::move(pdlModule)) {}
 
-  /// This is called on an operation that a rewrite is removing, right before
-  /// the operation is deleted. At this point, the operation has zero uses.
-  virtual void notifyOperationRemoved(Operation *op) {}
+  /// Merge the state in `other` into this pattern module.
+  void mergeIn(PDLPatternModule &&other);
 
-  /// Notify the rewriter that the pattern failed to match the given operation,
-  /// and provide a callback to populate a diagnostic with the reason why the
-  /// failure occurred. This method allows for derived rewriters to optionally
-  /// hook into the reason why a rewrite failed, and display it to users.
-  virtual LogicalResult
-  notifyMatchFailure(Operation *op,
-                     function_ref<void(Diagnostic &)> reasonCallback) {
-    return failure();
-  }
+  /// Return the internal PDL module of this pattern.
+  ModuleOp getModule() { return pdlModule.get(); }
 
-private:
-  void operator=(const RewriterBase &) = delete;
-  RewriterBase(const RewriterBase &) = delete;
+  //===--------------------------------------------------------------------===//
+  // Function Registry
 
-  /// 'op' and 'newOp' are known to have the same number of results, replace the
-  /// uses of op with uses of newOp.
-  void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
-};
+  /// Register a constraint function with PDL. A constraint function may be
+  /// specified in one of two ways:
+  ///
+  ///   * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
+  ///
+  ///   In this overload the arguments of the constraint function are passed via
+  ///   the low-level PDLValue form.
+  ///
+  ///   * `LogicalResult (PatternRewriter &, ValueTs... values)`
+  ///
+  ///   In this form the arguments of the constraint function are passed via the
+  ///   expected high level C++ type. In this form, the framework will
+  ///   automatically unwrap PDLValues and convert them to the expected ValueTs.
+  ///   For example, if the constraint function accepts an `Operation *`, the
+  ///   framework will automatically cast the input PDLValue. In the case of a
+  ///   `StringRef`, the framework will automatically unwrap the argument as a
+  ///   StringAttr and pass the underlying string value. To see the full list of
+  ///   supported types, or to see how to add handling for custom types, view
+  ///   the definition of `ProcessPDLValue` above.
+  void registerConstraintFunction(StringRef name,
+                                  PDLConstraintFunction constraintFn);
+  template <typename ConstraintFnT>
+  void registerConstraintFunction(StringRef name,
+                                  ConstraintFnT &&constraintFn) {
+    registerConstraintFunction(name,
+                               detail::pdl_function_builder::buildConstraintFn(
+                                   std::forward<ConstraintFnT>(constraintFn)));
+  }
 
-//===----------------------------------------------------------------------===//
-// IRRewriter
-//===----------------------------------------------------------------------===//
+  /// Register a rewrite function with PDL. A rewrite function may be specified
+  /// in one of two ways:
+  ///
+  ///   * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
+  ///
+  ///   In this overload the arguments of the rewrite function are passed via
+  ///   the low-level PDLValue form, and the results are manually appended to
+  ///   the given result list.
+  ///
+  ///   * `ResultT (PatternRewriter &, ValueTs... values)`
+  ///
+  ///   In this form the arguments and result of the rewrite function are passed
+  ///   via the expected high level C++ type. In this form, the framework will
+  ///   automatically unwrap the PDLValue arguments and convert them to the
+  ///   expected ValueTs. It will also automatically handle the processing and
+  ///   packaging of the result value to the result list. For example, if the
+  ///   rewrite function takes an `Operation *`, the framework will automatically
+  ///   cast the input PDLValue. In the case of a `StringRef`, the framework
+  ///   will automatically unwrap the argument as a StringAttr and pass the
+  ///   underlying string value. In the reverse case, if the rewrite returns a
+  ///   StringRef or std::string, it will automatically package this as a
+  ///   StringAttr and append it to the result list. To see the full list of
+  ///   supported types, or to see how to add handling for custom types, view
+  ///   the definition of `ProcessPDLValue` above.
+  void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
+  template <typename RewriteFnT>
+  void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
+    registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
+                                      std::forward<RewriteFnT>(rewriteFn)));
+  }
 
-/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
-/// providing a way to keep track of the mutations made to the IR. This class
-/// should only be used in situations where another `RewriterBase` instance,
-/// such as a `PatternRewriter`, is not available.
-class IRRewriter : public RewriterBase {
-public:
-  explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
-  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
-};
+  /// Return the set of the registered constraint functions.
+  const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
+    return constraintFunctions;
+  }
+  /// NOTE(review): despite the `take` prefix this returns a copy; the
+  /// functions remain registered in this module.
+  llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
+    return constraintFunctions;
+  }
+  /// Return the set of the registered rewrite functions.
+  const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
+    return rewriteFunctions;
+  }
+  /// NOTE(review): despite the `take` prefix this returns a copy; the
+  /// functions remain registered in this module.
+  llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
+    return rewriteFunctions;
+  }
 
-//===----------------------------------------------------------------------===//
-// PatternRewriter
-//===----------------------------------------------------------------------===//
+  /// Clear out the patterns and functions within this module.
+  void clear() {
+    pdlModule = nullptr;
+    constraintFunctions.clear();
+    rewriteFunctions.clear();
+  }
 
-/// A special type of `RewriterBase` that coordinates the application of a
-/// rewrite pattern on the current IR being matched, providing a way to keep
-/// track of any mutations made. This class should be used to perform all
-/// necessary IR mutations within a rewrite pattern, as the pattern driver may
-/// be tracking various state that would be invalidated when a mutation takes
-/// place.
-class PatternRewriter : public RewriterBase {
-public:
-  using RewriterBase::RewriterBase;
+private:
+  /// The module containing the `pdl.pattern` operations.
+  OwningOpRef<ModuleOp> pdlModule;
+
+  /// The external functions referenced from within the PDL module.
+  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
+  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
 };
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5b409d695baed..d0be98a307d70 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -629,7 +629,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
 
   /// PatternRewriter hook for notifying match failure reasons.
   LogicalResult
-  notifyMatchFailure(Operation *op,
+  notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
   using PatternRewriter::notifyMatchFailure;
 

diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 367f51ad601a8..c2dc41a81c6f8 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -1340,7 +1340,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
   });
 
   // Invoke the constraint and jump to the proper destination.
-  selectJump(succeeded(constraintFn(args, rewriter)));
+  selectJump(succeeded(constraintFn(rewriter, args)));
 }
 
 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@@ -1357,7 +1357,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
   // Execute the rewrite function.
   ByteCodeField numResults = read();
   ByteCodeRewriteResultList results(numResults);
-  rewriteFn(args, rewriter, results);
+  rewriteFn(rewriter, results, args);
 
   assert(results.getResults().size() == numResults &&
          "native PDL rewrite function returned unexpected number of results");

diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
index c6937aa736a11..c3b5c957007e8 100644
--- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
+++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
@@ -184,9 +184,9 @@ void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
         .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
   };
   os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
-     << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
-        "::mlir::PatternRewriter &rewriter"
-     << (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";
+     << "PDLFn(::mlir::PatternRewriter &rewriter, "
+     << (isConstraint ? "" : "::mlir::PDLResultList &results, ")
+     << "::llvm::ArrayRef<::mlir::PDLValue> values) {\n";
 
   const char *argumentInitStr = R"(
   {0} {1} = {{};

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bdeb0fa222b31..575b3cbd3c335 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1673,8 +1673,8 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
 }
 
 LogicalResult ConversionPatternRewriter::notifyMatchFailure(
-    Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
-  return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
+    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
+  // The hook now receives the location directly, so it is forwarded as-is to
+  // the conversion implementation.
+  return impl->notifyMatchFailure(loc, reasonCallback);
 }
 
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {

diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7fd46c711db01..81b57c420a726 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -76,7 +76,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
   /// PatternRewriter hook for notifying match failure reasons.
   LogicalResult
-  notifyMatchFailure(Operation *op,
+  notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
   /// The low-level pattern applicator.
@@ -348,9 +348,9 @@ void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
 }
 
 LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
-    Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
+    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
-    Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
+    Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
   });

diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index d06c500241b0b..e1a8c6081d4e5 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -181,8 +181,9 @@ module @patterns {
 
   module @rewriters {
     pdl_interp.func @success(%root : !pdl.operation) {
+      %attr = pdl_interp.apply_rewrite "str_creator" : !pdl.attribute
       %type = pdl_interp.apply_rewrite "type_creator" : !pdl.type
-      %newOp = pdl_interp.create_operation "test.success" -> (%type : !pdl.type)
+      %newOp = pdl_interp.create_operation "test.success" {"attr" = %attr} -> (%type : !pdl.type)
       pdl_interp.erase %root
       pdl_interp.finalize
     }
@@ -190,7 +191,7 @@ module @patterns {
 }
 
 // CHECK-LABEL: test.apply_rewrite_4
-// CHECK: "test.success"() : () -> f32
+// CHECK: "test.success"() {attr = "test.str"} : () -> f32
 module @ir attributes { test.apply_rewrite_4 } {
   "test.op"() : () -> ()
 }

diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 13465ba2865e0..daa1c371f27c9 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -14,53 +14,42 @@
 using namespace mlir;
 
 /// Custom constraint invoked from PDL.
-static LogicalResult customSingleEntityConstraint(PDLValue value,
-                                                  PatternRewriter &rewriter) {
-  Operation *rootOp = value.cast<Operation *>();
+static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
+                                                  Operation *rootOp) {
   return success(rootOp->getName().getStringRef() == "test.op");
 }
-static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
-                                                 PatternRewriter &rewriter) {
-  return customSingleEntityConstraint(values[1], rewriter);
+/// Only `rootCopy` is checked; `root` is accepted so that the function takes
+/// both PDL arguments.
+static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
+                                                 Operation *root,
+                                                 Operation *rootCopy) {
+  return customSingleEntityConstraint(rewriter, rootCopy);
 }
-static LogicalResult
-customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
-                                    PatternRewriter &rewriter) {
-  if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
-    return failure();
-  ValueRange operandValues = values[0].cast<ValueRange>();
-  TypeRange typeValues = values[1].cast<TypeRange>();
+/// Succeeds only when exactly two operand values and two types are provided.
+static LogicalResult customMultiEntityVariadicConstraint(
+    PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
   if (operandValues.size() != 2 || typeValues.size() != 2)
     return failure();
   return success();
 }
 
 // Custom creator invoked from PDL.
-static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
-                         PDLResultList &results) {
-  results.push_back(rewriter.create(
-      OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
+static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
+  return rewriter.create(OperationState(op->getLoc(), "test.success"));
 }
-
-static void customVariadicResultCreate(ArrayRef<PDLValue> args,
-                                       PatternRewriter &rewriter,
-                                       PDLResultList &results) {
-  Operation *root = args[0].cast<Operation *>();
-  results.push_back(root->getOperands());
-  results.push_back(root->getOperands().getTypes());
+/// Returns the operands of `root` and their types; the returned pair is
+/// appended to the result list as two separate results.
+static auto customVariadicResultCreate(PatternRewriter &rewriter,
+                                       Operation *root) {
+  return std::make_pair(root->getOperands(), root->getOperands().getTypes());
+}
+/// Returns the `f32` type as the rewrite result.
+static Type customCreateType(PatternRewriter &rewriter) {
+  return rewriter.getF32Type();
 }
-static void customCreateType(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
-                             PDLResultList &results) {
-  results.push_back(rewriter.getF32Type());
+/// Returns a std::string result, which the rewrite wrapper packages as a
+/// StringAttr.
+static std::string customCreateStrAttr(PatternRewriter &rewriter) {
+  return "test.str";
 }
 
 /// Custom rewriter invoked from PDL.
-static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
-                           PDLResultList &results) {
-  Operation *root = args[0].cast<Operation *>();
-  OperationState successOpState(root->getLoc(), "test.success");
-  successOpState.addOperands(args[1].cast<Value>());
-  rewriter.create(successOpState);
+static void customRewriter(PatternRewriter &rewriter, Operation *root,
+                           Value input) {
+  rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
+                  input);
   rewriter.eraseOp(root);
 }
 
@@ -117,6 +106,7 @@ struct TestPDLByteCodePass
     pdlPattern.registerRewriteFunction("var_creator",
                                        customVariadicResultCreate);
     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
+    pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
     patternList.add(std::move(pdlPattern));
 

diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
index 9f0ea1386322d..802958a3872f2 100644
--- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -43,8 +43,8 @@ Pattern => erase op<test.op3>;
 
 // Check the generation of native constraints and rewrites.
 
-// CHECK:      static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values,
-// CHECK-SAME:                                           ::mlir::PatternRewriter &rewriter) {
+// CHECK:      static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter,
+// CHECK-SAME:                                           ::llvm::ArrayRef<::mlir::PDLValue> values) {
 // CHECK:   ::mlir::Attribute attr = {};
 // CHECK:   if (values[0])
 // CHECK:     attr = values[0].cast<::mlir::Attribute>();
@@ -69,8 +69,8 @@ Pattern => erase op<test.op3>;
 
 // CHECK-NOT: TestUnusedCst
 
-// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values,
-// CHECK-SAME:                         ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) {
+// CHECK: static void TestRewritePDLFn(::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results,
+// CHECK-SAME:                         ::llvm::ArrayRef<::mlir::PDLValue> values) {
 // CHECK:   ::mlir::Attribute attr = {};
 // CHECK:   ::mlir::Operation * op = {};
 // CHECK:   ::mlir::Type type = {};


        


More information about the Mlir-commits mailing list