[Mlir-commits] [mlir] 02c4c0d - [mlir][pdl] Remove CreateNativeOp in favor of a more general ApplyNativeRewriteOp.

River Riddle llvmlistbot at llvm.org
Tue Mar 16 13:20:30 PDT 2021


Author: River Riddle
Date: 2021-03-16T13:20:18-07:00
New Revision: 02c4c0d5b2adc79c122bd2662a4458f75771aecf

URL: https://github.com/llvm/llvm-project/commit/02c4c0d5b2adc79c122bd2662a4458f75771aecf
DIFF: https://github.com/llvm/llvm-project/commit/02c4c0d5b2adc79c122bd2662a4458f75771aecf.diff

LOG: [mlir][pdl] Remove CreateNativeOp in favor of a more general ApplyNativeRewriteOp.

This has numerous benefits, given the overly clunky nature of CreateNativeOp:
* Users can now call into arbitrary rewrite functions from inside of PDL, allowing for more natural interleaving of PDL and C++ and enabling more of the pattern to be written in PDL.
* Removes the need for an additional set of C++ functions/registry/etc. The new ApplyNativeRewriteOp will use the same PDLRewriteFunction as the existing RewriteOp. This reduces the API surface area exposed to users.

This revision also introduces a new PDLResultList class, which is used to pass the results of native rewrite functions back to PDL. We introduce a new class instead of using a SmallVector to simplify the work necessary for variadics, given that ranges will require some changes to the structure of PDLValue.

Differential Revision: https://reviews.llvm.org/D95720

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
    mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
    mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
    mlir/lib/Dialect/PDL/IR/PDL.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/lib/Rewrite/ByteCode.h
    mlir/lib/Rewrite/FrozenRewritePatternList.cpp
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
    mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
    mlir/test/Dialect/PDL/invalid.mlir
    mlir/test/Rewrite/pdl-bytecode.mlir
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 76e4c5d022a4..74f3fce08933 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -29,17 +29,17 @@ class PDL_Op<string mnemonic, list<OpTrait> traits = []>
 }
 
 //===----------------------------------------------------------------------===//
-// pdl::ApplyConstraintOp
+// pdl::ApplyNativeConstraintOp
 //===----------------------------------------------------------------------===//
 
-def PDL_ApplyConstraintOp
-    : PDL_Op<"apply_constraint", [HasParent<"pdl::PatternOp">]> {
-  let summary = "Apply a generic constraint to a set of provided entities";
+def PDL_ApplyNativeConstraintOp
+    : PDL_Op<"apply_native_constraint", [HasParent<"pdl::PatternOp">]> {
+  let summary = "Apply a native constraint to a set of provided entities";
   let description = [{
-    `apply_constraint` operations apply a generic constraint, that has been
-    registered externally with the consumer of PDL, to a given set of entities.
-    The constraint is permitted to accept any number of constant valued
-    parameters.
+    `pdl.apply_native_constraint` operations apply a native C++ constraint, that
+    has been registered externally with the consumer of PDL, to a given set of
+    entities. The constraint is permitted to accept any number of constant
+    valued parameters.
 
     Example:
 
@@ -47,7 +47,7 @@ def PDL_ApplyConstraintOp
     // Apply `myConstraint` to the entities defined by `input`, `attr`, and
     // `op`. `42`, `"abc"`, and `i32` are constant parameters passed to the
     // constraint.
-    pdl.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+    pdl.apply_native_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
     ```
   }];
 
@@ -67,6 +67,58 @@ def PDL_ApplyConstraintOp
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::ApplyNativeRewriteOp
+//===----------------------------------------------------------------------===//
+
+def PDL_ApplyNativeRewriteOp
+    : PDL_Op<"apply_native_rewrite", [HasParent<"pdl::RewriteOp">]> {
+  let summary = "Apply a native rewrite method inside of pdl.rewrite region";
+  let description = [{
+    `pdl.apply_native_rewrite` operations apply a native C++ function, that has
+    been registered externally with the consumer of PDL, to perform a rewrite
+    and optionally return a number of values. The native function may accept any
+    number of arguments and constant attribute parameters. This operation is
+    used within a pdl.rewrite region to enable the interleaving of native
+    rewrite methods with other pdl constructs.
+
+    Example:
+
+    ```mlir
+    // Apply a native rewrite method that returns an attribute.
+    %ret = pdl.apply_native_rewrite "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute
+    ```
+
+    ```c++
+    // The native rewrite as defined in C++:
+    static void myNativeFunc(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+                             PatternRewriter &rewriter,
+                             PDLResultList &results) {
+      Value arg0 = args[0].cast<Value>();
+      Value arg1 = args[1].cast<Value>();
+      IntegerAttr param0 = constantParams[0].cast<IntegerAttr>();
+      StringAttr param1 = constantParams[1].cast<StringAttr>();
+
+      // Just push back the first param attribute.
+      results.push_back(param0);
+    }
+
+    void registerNativeRewrite(PDLPatternModule &pdlModule) {
+      pdlModule.registerRewriteFunction("myNativeFunc", myNativeFunc);
+    }
+    ```
+  }];
+
+  let arguments = (ins StrAttr:$name,
+                       Variadic<PDL_AnyType>:$args,
+                       OptionalAttr<ArrayAttr>:$constParams);
+  let results = (outs Variadic<PDL_AnyType>:$results);
+  let assemblyFormat = [{
+    $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($results)
+    attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // pdl::AttributeOp
 //===----------------------------------------------------------------------===//
@@ -113,39 +165,6 @@ def PDL_AttributeOp : PDL_Op<"attribute"> {
   ];
 }
 
-//===----------------------------------------------------------------------===//
-// pdl::CreateNativeOp
-//===----------------------------------------------------------------------===//
-
-def PDL_CreateNativeOp
-    : PDL_Op<"create_native", [HasParent<"pdl::RewriteOp">]> {
-  let summary = "Call a native creation method to construct an `Attribute`, "
-                "`Operation`, `Type`, or `Value`";
-  let description = [{
-    `pdl.create_native` operations invoke a native C++ function, that has been
-    registered externally with the consumer of PDL, to create an `Attribute`,
-    `Operation`, `Type`, or `Value`. The native function must produce a value
-    of the specified return type, and may accept any number of positional
-    arguments and constant attribute parameters.
-
-    Example:
-
-    ```mlir
-    %ret = pdl.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute
-    ```
-  }];
-
-  let arguments = (ins StrAttr:$name,
-                       Variadic<PDL_AnyType>:$args,
-                       OptionalAttr<ArrayAttr>:$constParams);
-  let results = (outs PDL_AnyType:$result);
-  let assemblyFormat = [{
-    $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
-    attr-dict
-  }];
-  let verifier = ?;
-}
-
 //===----------------------------------------------------------------------===//
 // pdl::EraseOp
 //===----------------------------------------------------------------------===//
@@ -233,9 +252,10 @@ def PDL_OperationOp
     `pdl.rewrite`, all of the result types must be "inferable". This means that
     the type must be attributable to either a constant type value or the result
     type of another entity, such as an attribute, the result of a
-    `createNative`, or the result type of another operation. If the result type
-    value does not meet any of these criteria, the operation must provide the
-    `InferTypeOpInterface` to ensure that the result types can be inferred.
+    `apply_native_rewrite`, or the result type of another operation. If the
+    result type value does not meet any of these criteria, the operation must
+    implement the `InferTypeOpInterface` to ensure that the result types can be
+    inferred.
 
     Example:
 
@@ -416,13 +436,14 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
   let summary = "Specify the rewrite of a matched pattern";
   let description = [{
     `pdl.rewrite` operations terminate the region of a `pdl.pattern` and specify
-    the rewrite of a `pdl.pattern`, on the specified root operation. The
+    the main rewrite of a `pdl.pattern`, on the specified root operation. The
     rewrite is specified either via a string name (`name`) to an external
     rewrite function, or via the region body. The rewrite region, if specified,
     must contain a single block and terminate via the `pdl.rewrite_end`
     operation. If the rewrite is external, it also takes a set of constant
     parameters and a set of additional positional values defined within the
-    matcher as arguments.
+    matcher as arguments. If the rewrite is external, the root operation is
+    passed to the native function as the first argument.
 
     Example:
 

diff  --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 517a0f4f0af0..8f8a5b130175 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -130,32 +130,35 @@ def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
   let description = [{
     `pdl_interp.apply_rewrite` operations invoke an external rewriter that has
     been registered with the interpreter to perform the rewrite after a
-    successful match. The rewrite is passed the root operation being matched, a
-    set of additional positional arguments generated within the matcher, and a
-    set of constant parameters.
+    successful match. The rewrite is passed a set of positional arguments
+    and a set of constant parameters. The rewrite function may return any
+    number of results.
 
     Example:
 
     ```mlir
     // Rewriter operating solely on the root operation.
-    pdl_interp.apply_rewrite "rewriter" on %root
+    pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation)
+
+    // Rewriter operating solely on the root operation and returning an attribute.
+    %attr = pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation) : !pdl.attribute
 
     // Rewriter operating on the root operation along with additional arguments
     // from the matcher.
-    pdl_interp.apply_rewrite "rewriter"(%value : !pdl.value) on %root
+    pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation, %value : !pdl.value)
 
     // Rewriter operating on the root operation along with additional arguments
     // and constant parameters.
-    pdl_interp.apply_rewrite "rewriter"[42](%value : !pdl.value) on %root
+    pdl_interp.apply_rewrite "rewriter"[42](%root : !pdl.operation, %value : !pdl.value)
     ```
   }];
   let arguments = (ins StrAttr:$name,
-                       PDL_Operation:$root,
                        Variadic<PDL_AnyType>:$args,
                        OptionalAttr<ArrayAttr>:$constParams);
+  let results = (outs Variadic<PDL_AnyType>:$results);
   let assemblyFormat = [{
-    $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `on` $root
-    attr-dict
+    $name ($constParams^)? (`(` $args^ `:` type($args) `)`)?
+    (`:` type($results)^)? attr-dict
   }];
 }
 
@@ -351,38 +354,6 @@ def PDLInterp_CreateAttributeOp
     }]>];
 }
 
-//===----------------------------------------------------------------------===//
-// pdl_interp::CreateNativeOp
-//===----------------------------------------------------------------------===//
-
-def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> {
-  let summary = "Call a native creation method to construct an `Attribute`, "
-                "`Operation`, `Type`, or `Value`";
-  let description = [{
-    `pdl_interp.create_native` operations invoke a native C++ function, that has
-    been registered externally with the consumer of PDL, to create an
-    `Attribute`, `Operation`, `Type`, or `Value`. The native function must
-    produce a value of the specified return type, and may accept any number of
-    positional arguments and constant attribute parameters.
-
-    Example:
-
-    ```mlir
-    %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute
-    ```
-  }];
-
-  let arguments = (ins StrAttr:$name,
-                       Variadic<PDL_AnyType>:$args,
-                       OptionalAttr<ArrayAttr>:$constParams);
-  let results = (outs PDL_AnyType:$result);
-  let assemblyFormat = [{
-    $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
-    attr-dict
-  }];
-  let verifier = ?;
-}
-
 //===----------------------------------------------------------------------===//
 // pdl_interp::CreateOperationOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8e1a5b98c318..56da9b870948 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -302,6 +302,33 @@ inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
   return os;
 }
 
+//===----------------------------------------------------------------------===//
+// PDLResultList
+
+/// This class represents a list of PDL results, returned by a native rewrite
+/// method. It provides the mechanism with which to pass PDLValues back to the
+/// PDL bytecode.
+class PDLResultList {
+public:
+  /// Push a new Attribute value onto the result list.
+  void push_back(Attribute value) { results.push_back(value); }
+
+  /// Push a new Operation onto the result list.
+  void push_back(Operation *value) { results.push_back(value); }
+
+  /// Push a new Type onto the result list.
+  void push_back(Type value) { results.push_back(value); }
+
+  /// Push a new Value onto the result list.
+  void push_back(Value value) { results.push_back(value); }
+
+protected:
+  PDLResultList() = default;
+
+  /// The PDL results held by this list.
+  SmallVector<PDLValue> results;
+};
+
 //===----------------------------------------------------------------------===//
 // PDLPatternModule
 
@@ -311,16 +338,13 @@ inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
 /// success if the constraint successfully held, failure otherwise.
 using PDLConstraintFunction = std::function<LogicalResult(
     ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
-/// A native PDL creation function. This function creates a new PDLValue given
-/// a set of existing PDL values, a set of constant parameters specified in
-/// Attribute form, and a PatternRewriter. Returns the newly created PDLValue.
-using PDLCreateFunction =
-    std::function<PDLValue(ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
-/// A native PDL rewrite function. This function rewrites the given root
-/// operation using the provided PatternRewriter. This method is only invoked
-/// when the corresponding match was successful.
-using PDLRewriteFunction = std::function<void(Operation *, ArrayRef<PDLValue>,
-                                              ArrayAttr, PatternRewriter &)>;
+/// A native PDL rewrite function. This function performs a rewrite on the
+/// given set of values and constant parameters. Any results from this rewrite
+/// that should be passed back to PDL should be added to the provided result
+/// list. This method is only invoked when the corresponding match was
+/// successful.
+using PDLRewriteFunction = std::function<void(
+    ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &, PDLResultList &)>;
 /// A generic PDL pattern constraint function. This function applies a
 /// constraint to a given opaque PDLValue entity. The second parameter is a set
 /// of constant value parameters specified in Attribute form. Returns success if
@@ -367,9 +391,6 @@ class PDLPatternModule {
         });
   }
 
-  /// Register a creation function.
-  void registerCreateFunction(StringRef name, PDLCreateFunction createFn);
-
   /// Register a rewrite function.
   void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
 
@@ -380,13 +401,6 @@ class PDLPatternModule {
   llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
     return constraintFunctions;
   }
-  /// Return the set of the registered create functions.
-  const llvm::StringMap<PDLCreateFunction> &getCreateFunctions() const {
-    return createFunctions;
-  }
-  llvm::StringMap<PDLCreateFunction> takeCreateFunctions() {
-    return createFunctions;
-  }
   /// Return the set of the registered rewrite functions.
   const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
     return rewriteFunctions;
@@ -399,7 +413,6 @@ class PDLPatternModule {
   void clear() {
     pdlModule = nullptr;
     constraintFunctions.clear();
-    createFunctions.clear();
     rewriteFunctions.clear();
   }
 
@@ -409,7 +422,6 @@ class PDLPatternModule {
 
   /// The external functions referenced from within the PDL module.
   llvm::StringMap<PDLConstraintFunction> constraintFunctions;
-  llvm::StringMap<PDLCreateFunction> createFunctions;
   llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
 };
 

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 3368ceb9be88..d1da22671d95 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -70,6 +70,9 @@ struct PatternLowering {
                                  SmallVectorImpl<Position *> &usedMatchValues);
 
   /// Generate the rewriter code for the given operation.
+  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
+                        DenseMap<Value, Value> &rewriteValues,
+                        function_ref<Value(Value)> mapRewriteValue);
   void generateRewriter(pdl::AttributeOp attrOp,
                         DenseMap<Value, Value> &rewriteValues,
                         function_ref<Value(Value)> mapRewriteValue);
@@ -79,9 +82,6 @@ struct PatternLowering {
   void generateRewriter(pdl::OperationOp operationOp,
                         DenseMap<Value, Value> &rewriteValues,
                         function_ref<Value(Value)> mapRewriteValue);
-  void generateRewriter(pdl::CreateNativeOp createNativeOp,
-                        DenseMap<Value, Value> &rewriteValues,
-                        function_ref<Value(Value)> mapRewriteValue);
   void generateRewriter(pdl::ReplaceOp replaceOp,
                         DenseMap<Value, Value> &rewriteValues,
                         function_ref<Value(Value)> mapRewriteValue);
@@ -449,17 +449,17 @@ SymbolRefAttr PatternLowering::generateRewriter(
   // method.
   pdl::RewriteOp rewriter = pattern.getRewriter();
   if (StringAttr rewriteName = rewriter.nameAttr()) {
-    Value root = mapRewriteValue(rewriter.root());
-    SmallVector<Value, 4> args = llvm::to_vector<4>(
-        llvm::map_range(rewriter.externalArgs(), mapRewriteValue));
+    auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
+    SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root()));
+    args.append(mappedArgs.begin(), mappedArgs.end());
     builder.create<pdl_interp::ApplyRewriteOp>(
-        rewriter.getLoc(), rewriteName, root, args,
+        rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
         rewriter.externalConstParamsAttr());
   } else {
     // Otherwise this is a dag rewriter defined using PDL operations.
     for (Operation &rewriteOp : *rewriter.getBody()) {
       llvm::TypeSwitch<Operation *>(&rewriteOp)
-          .Case<pdl::AttributeOp, pdl::CreateNativeOp, pdl::EraseOp,
+          .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
                 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::TypeOp>(
               [&](auto op) {
                 this->generateRewriter(op, rewriteValues, mapRewriteValue);
@@ -478,6 +478,19 @@ SymbolRefAttr PatternLowering::generateRewriter(
       builder.getSymbolRefAttr(rewriterFunc));
 }
 
+void PatternLowering::generateRewriter(
+    pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
+    function_ref<Value(Value)> mapRewriteValue) {
+  SmallVector<Value, 2> arguments;
+  for (Value argument : rewriteOp.args())
+    arguments.push_back(mapRewriteValue(argument));
+  auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
+      rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
+      arguments, rewriteOp.constParamsAttr());
+  for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
+    rewriteValues[std::get<0>(it)] = std::get<1>(it);
+}
+
 void PatternLowering::generateRewriter(
     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
     function_ref<Value(Value)> mapRewriteValue) {
@@ -527,18 +540,6 @@ void PatternLowering::generateRewriter(
   }
 }
 
-void PatternLowering::generateRewriter(
-    pdl::CreateNativeOp createNativeOp, DenseMap<Value, Value> &rewriteValues,
-    function_ref<Value(Value)> mapRewriteValue) {
-  SmallVector<Value, 2> arguments;
-  for (Value argument : createNativeOp.args())
-    arguments.push_back(mapRewriteValue(argument));
-  Value result = builder.create<pdl_interp::CreateNativeOp>(
-      createNativeOp.getLoc(), createNativeOp.result().getType(),
-      createNativeOp.nameAttr(), arguments, createNativeOp.constParamsAttr());
-  rewriteValues[createNativeOp] = result;
-}
-
 void PatternLowering::generateRewriter(
     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
     function_ref<Value(Value)> mapRewriteValue) {

diff  --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 0db35f050515..885fbad0f976 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -153,7 +153,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
 
 /// Collect all of the predicates related to constraints within the given
 /// pattern operation.
-static void getConstraintPredicates(pdl::ApplyConstraintOp op,
+static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
                                     std::vector<PositionalPredicate> &predList,
                                     PredicateBuilder &builder,
                                     DenseMap<Value, Position *> &inputs) {
@@ -192,7 +192,7 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
                                  PredicateBuilder &builder,
                                  DenseMap<Value, Position *> &inputs) {
   for (Operation &op : pattern.body().getOps()) {
-    if (auto constraintOp = dyn_cast<pdl::ApplyConstraintOp>(&op))
+    if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(&op))
       getConstraintPredicates(constraintOp, predList, builder, inputs);
     else if (auto resultOp = dyn_cast<pdl::ResultOp>(&op))
       getResultPredicates(resultOp, predList, builder, inputs);

diff  --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index d35aab41ba8f..dc1f501825bd 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -64,15 +64,25 @@ verifyHasBindingUseInMatcher(Operation *op,
 }
 
 //===----------------------------------------------------------------------===//
-// pdl::ApplyConstraintOp
+// pdl::ApplyNativeConstraintOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(ApplyConstraintOp op) {
+static LogicalResult verify(ApplyNativeConstraintOp op) {
   if (op.getNumOperands() == 0)
     return op.emitOpError("expected at least one argument");
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// pdl::ApplyNativeRewriteOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ApplyNativeRewriteOp op) {
+  if (op.getNumOperands() == 0 && op.getNumResults() == 0)
+    return op.emitOpError("expected at least one argument or result");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // pdl::AttributeOp
 //===----------------------------------------------------------------------===//
@@ -162,9 +172,9 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
     Operation *resultTypeOp = it.value().getDefiningOp();
     assert(resultTypeOp && "expected valid result type operation");
 
-    // If the op was defined by a `create_native`, it is guaranteed to be
+    // If the op was defined by an `apply_native_rewrite`, it is guaranteed to be
     // usable.
-    if (isa<CreateNativeOp>(resultTypeOp))
+    if (isa<ApplyNativeRewriteOp>(resultTypeOp))
       continue;
 
     // If the type is already constrained, there is nothing to do.

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 90e89a536405..034698d85cb1 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -102,7 +102,6 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
   // Steal the other state if we have no patterns.
   if (!pdlModule) {
     constraintFunctions = std::move(other.constraintFunctions);
-    createFunctions = std::move(other.createFunctions);
     rewriteFunctions = std::move(other.rewriteFunctions);
     pdlModule = std::move(other.pdlModule);
     return;
@@ -110,8 +109,6 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
   // Steal the functions of the other module.
   for (auto &it : constraintFunctions)
     registerConstraintFunction(it.first(), std::move(it.second));
-  for (auto &it : createFunctions)
-    registerCreateFunction(it.first(), std::move(it.second));
   for (auto &it : rewriteFunctions)
     registerRewriteFunction(it.first(), std::move(it.second));
 
@@ -132,13 +129,7 @@ void PDLPatternModule::registerConstraintFunction(
   assert(it.second &&
          "constraint with the given name has already been registered");
 }
-void PDLPatternModule::registerCreateFunction(StringRef name,
-                                              PDLCreateFunction createFn) {
-  auto it = createFunctions.try_emplace(name, std::move(createFn));
-  (void)it;
-  assert(it.second && "native create function with the given name has "
-                      "already been registered");
-}
+
 void PDLPatternModule::registerRewriteFunction(StringRef name,
                                                PDLRewriteFunction rewriteFn) {
   auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));

diff  --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 1986b3f87d96..c09892caec1b 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -80,8 +80,6 @@ enum OpCode : ByteCodeField {
   CheckOperationName,
   /// Compare the result count of an operation with a constant.
   CheckResultCount,
-  /// Invoke a native creation method.
-  CreateNative,
   /// Create an operation.
   CreateOperation,
   /// Erase an operation.
@@ -148,15 +146,12 @@ class Generator {
             SmallVectorImpl<PDLByteCodePattern> &patterns,
             ByteCodeField &maxValueMemoryIndex,
             llvm::StringMap<PDLConstraintFunction> &constraintFns,
-            llvm::StringMap<PDLCreateFunction> &createFns,
             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
         rewriterByteCode(rewriterByteCode), patterns(patterns),
         maxValueMemoryIndex(maxValueMemoryIndex) {
     for (auto it : llvm::enumerate(constraintFns))
       constraintToMemIndex.try_emplace(it.value().first(), it.index());
-    for (auto it : llvm::enumerate(createFns))
-      nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
     for (auto it : llvm::enumerate(rewriteFns))
       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
   }
@@ -203,7 +198,6 @@ class Generator {
   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
-  void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
@@ -235,10 +229,6 @@ class Generator {
   /// in the bytecode registry.
   llvm::StringMap<ByteCodeField> constraintToMemIndex;
 
-  /// Mapping from the name of an externally registered creation method to its
-  /// index in the bytecode registry.
-  llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
-
   /// Mapping from rewriter function name to the bytecode address of the
   /// rewriter function in byte.
   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
@@ -492,16 +482,16 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
             pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
-            pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
-            pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
-            pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
-            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
-            pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
-            pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
-            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
-            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
-            pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
-            pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
+            pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
+            pdl_interp::EraseOp, pdl_interp::FinalizeOp,
+            pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
+            pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
+            pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp,
+            pdl_interp::InferredTypeOp, pdl_interp::IsNotNullOp,
+            pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
+            pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
+            pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
+            pdl_interp::SwitchResultCountOp>(
           [&](auto interpOp) { this->generate(interpOp, writer); })
       .Default([](Operation *) {
         llvm_unreachable("unknown `pdl_interp` operation");
@@ -522,8 +512,16 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
   assert(externalRewriterToMemIndex.count(op.name()) &&
          "expected index for rewrite function");
   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
-                op.constParamsAttr(), op.root());
+                op.constParamsAttr());
   writer.appendPDLValueList(op.args());
+
+#ifndef NDEBUG
+  // In debug mode we also append the number of results so that we can assert
+  // that the native rewrite function gave us the correct number of results.
+  writer.append(ByteCodeField(op.results().size()));
+#endif
+  for (Value result : op.results())
+    writer.append(result);
 }
 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
   writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
@@ -559,14 +557,6 @@ void Generator::generate(pdl_interp::CreateAttributeOp op,
   // Simply repoint the memory index of the result to the constant.
   getMemIndex(op.attribute()) = getMemIndex(op.value());
 }
-void Generator::generate(pdl_interp::CreateNativeOp op,
-                         ByteCodeWriter &writer) {
-  assert(nativeCreateToMemIndex.count(op.name()) &&
-         "expected index for creation function");
-  writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
-                op.result(), op.constParamsAttr());
-  writer.appendPDLValueList(op.args());
-}
 void Generator::generate(pdl_interp::CreateOperationOp op,
                          ByteCodeWriter &writer) {
   writer.append(OpCode::CreateOperation, op.operation(),
@@ -678,18 +668,15 @@ void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
 
 PDLByteCode::PDLByteCode(ModuleOp module,
                          llvm::StringMap<PDLConstraintFunction> constraintFns,
-                         llvm::StringMap<PDLCreateFunction> createFns,
                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
   Generator generator(module.getContext(), uniquedData, matcherByteCode,
                       rewriterByteCode, patterns, maxValueMemoryIndex,
-                      constraintFns, createFns, rewriteFns);
+                      constraintFns, rewriteFns);
   generator.generate(module);
 
   // Initialize the external functions.
   for (auto &it : constraintFns)
     constraintFunctions.push_back(std::move(it.second));
-  for (auto &it : createFns)
-    createFunctions.push_back(std::move(it.second));
   for (auto &it : rewriteFns)
     rewriteFunctions.push_back(std::move(it.second));
 }
@@ -717,12 +704,11 @@ class ByteCodeExecutor {
                    ArrayRef<PatternBenefit> currentPatternBenefits,
                    ArrayRef<PDLByteCodePattern> patterns,
                    ArrayRef<PDLConstraintFunction> constraintFunctions,
-                   ArrayRef<PDLCreateFunction> createFunctions,
                    ArrayRef<PDLRewriteFunction> rewriteFunctions)
       : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
         code(code), currentPatternBenefits(currentPatternBenefits),
         patterns(patterns), constraintFunctions(constraintFunctions),
-        createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
+        rewriteFunctions(rewriteFunctions) {}
 
   /// Start executing the code at the current bytecode index. `matches` is an
   /// optional field provided when this function is executed in a matching
@@ -740,7 +726,6 @@ class ByteCodeExecutor {
   void executeCheckOperandCount();
   void executeCheckOperationName();
   void executeCheckResultCount();
-  void executeCreateNative(PatternRewriter &rewriter);
   void executeCreateOperation(PatternRewriter &rewriter,
                               Location mainRewriteLoc);
   void executeEraseOp(PatternRewriter &rewriter);
@@ -866,9 +851,17 @@ class ByteCodeExecutor {
   ArrayRef<PatternBenefit> currentPatternBenefits;
   ArrayRef<PDLByteCodePattern> patterns;
   ArrayRef<PDLConstraintFunction> constraintFunctions;
-  ArrayRef<PDLCreateFunction> createFunctions;
   ArrayRef<PDLRewriteFunction> rewriteFunctions;
 };
+
+/// This class is a subclass of `PDLResultList` that provides access to the
+/// returned results. This API is not on `PDLResultList` to avoid
+/// overexposing access to information specific solely to the ByteCode.
+class ByteCodeRewriteResultList : public PDLResultList {
+public:
+  /// Return the list of PDL results.
+  MutableArrayRef<PDLValue> getResults() { return results; }
+};
 } // end anonymous namespace
 
 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
@@ -892,18 +885,29 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
   ArrayAttr constParams = read<ArrayAttr>();
-  Operation *root = read<Operation *>();
   SmallVector<PDLValue, 16> args;
   readList<PDLValue>(args);
 
   LLVM_DEBUG({
-    llvm::dbgs() << "  * Root: " << *root << "\n  * Arguments: ";
+    llvm::dbgs() << "  * Arguments: ";
     llvm::interleaveComma(args, llvm::dbgs());
     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
   });
-
-  // Invoke the native rewrite function.
-  rewriteFn(root, args, constParams, rewriter);
+  ByteCodeRewriteResultList results;
+  rewriteFn(args, constParams, rewriter, results);
+
+  // In debug mode, verify the function returned the expected result count.
+#ifndef NDEBUG
+  ByteCodeField expectedNumberOfResults = read();
+  assert(results.getResults().size() == expectedNumberOfResults &&
+         "native PDL rewrite function returned unexpected number of results");
+#endif
+
+  // Store the results in the bytecode memory.
+  for (PDLValue &result : results.getResults()) {
+    LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
+    memory[read()] = result.getAsOpaquePointer();
+  }
 }
 
 void ByteCodeExecutor::executeAreEqual() {
@@ -950,26 +954,6 @@ void ByteCodeExecutor::executeCheckResultCount() {
   selectJump(op->getNumResults() == expectedCount);
 }
 
-void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) {
-  LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
-  const PDLCreateFunction &createFn = createFunctions[read()];
-  ByteCodeField resultIndex = read();
-  ArrayAttr constParams = read<ArrayAttr>();
-  SmallVector<PDLValue, 16> args;
-  readList<PDLValue>(args);
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "  * Arguments: ";
-    llvm::interleaveComma(args, llvm::dbgs());
-    llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
-  });
-
-  PDLValue result = createFn(args, constParams, rewriter);
-  memory[resultIndex] = result.getAsOpaquePointer();
-
-  LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
-}
-
 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
                                               Location mainRewriteLoc) {
   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
@@ -1246,9 +1230,6 @@ void ByteCodeExecutor::execute(
     case CheckResultCount:
       executeCheckResultCount();
       break;
-    case CreateNative:
-      executeCreateNative(rewriter);
-      break;
     case CreateOperation:
       executeCreateOperation(rewriter, *mainRewriteLoc);
       break;
@@ -1338,8 +1319,7 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
   // The matcher function always starts at code address 0.
   ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
                             matcherByteCode, state.currentPatternBenefits,
-                            patterns, constraintFunctions, createFunctions,
-                            rewriteFunctions);
+                            patterns, constraintFunctions, rewriteFunctions);
   executor.execute(rewriter, &matches);
 
   // Order the found matches by benefit.
@@ -1356,9 +1336,9 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
   // memory buffer.
   llvm::copy(match.values, state.memory.begin());
 
-  ByteCodeExecutor executor(
-      &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
-      uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
-      constraintFunctions, createFunctions, rewriteFunctions);
+  ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()],
+                            state.memory, uniquedData, rewriterByteCode,
+                            state.currentPatternBenefits, patterns,
+                            constraintFunctions, rewriteFunctions);
   executor.execute(rewriter, /*matches=*/nullptr, match.location);
 }

diff  --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
index 38dbbcd855ce..f6a3bcbe54f9 100644
--- a/mlir/lib/Rewrite/ByteCode.h
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -114,7 +114,6 @@ class PDLByteCode {
   /// the PDL interpreter dialect.
   PDLByteCode(ModuleOp module,
               llvm::StringMap<PDLConstraintFunction> constraintFns,
-              llvm::StringMap<PDLCreateFunction> createFns,
               llvm::StringMap<PDLRewriteFunction> rewriteFns);
 
   /// Return the patterns held by the bytecode.
@@ -160,7 +159,6 @@ class PDLByteCode {
 
   /// A set of user defined functions invoked via PDL.
   std::vector<PDLConstraintFunction> constraintFunctions;
-  std::vector<PDLCreateFunction> createFunctions;
   std::vector<PDLRewriteFunction> rewriteFunctions;
 
   /// The maximum memory index used by a value.

diff  --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
index 40f7aec44e51..c2de51a647dd 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
@@ -70,7 +70,7 @@ FrozenRewritePatternList::FrozenRewritePatternList(
   // Generate the pdl bytecode.
   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
       pdlModule, pdlPatterns.takeConstraintFunctions(),
-      pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
+      pdlPatterns.takeRewriteFunctions());
 }
 
 FrozenRewritePatternList::~FrozenRewritePatternList() {}

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index c856ab5c9f6f..a42b51604945 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -24,7 +24,7 @@ module @simple {
 
   // CHECK: module @rewriters
   // CHECK:   func @pdl_generated_rewriter(%[[REWRITE_ROOT:.*]]: !pdl.operation)
-  // CHECK:     pdl_interp.apply_rewrite "rewriter" on %[[REWRITE_ROOT]]
+  // CHECK:     pdl_interp.apply_rewrite "rewriter"(%[[REWRITE_ROOT]]
   // CHECK:     pdl_interp.finalize
   pdl.pattern : benefit(1) {
     %root = pdl.operation "foo.op"()
@@ -72,7 +72,7 @@ module @constraints {
     %root = pdl.operation(%input0, %input1)
     %result0 = pdl.result 0 of %root
 
-    pdl.apply_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
+    pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
     pdl.rewrite %root with "rewriter"
   }
 }
@@ -194,7 +194,7 @@ module @predicate_ordering  {
 
   pdl.pattern : benefit(1) {
     %resultType = pdl.type
-    pdl.apply_constraint "typeConstraint"[](%resultType : !pdl.type)
+    pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type)
     %root = pdl.operation -> %resultType
     pdl.rewrite %root with "rewriter"
   }

diff  --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
index 5652b2118afe..3d0d565c547f 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
@@ -6,7 +6,7 @@
 module @external {
   // CHECK: module @rewriters
   // CHECK:   func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value)
-  // CHECK:     pdl_interp.apply_rewrite "rewriter" [true](%[[INPUT]] : !pdl.value) on %[[ROOT]]
+  // CHECK:     pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value)
   pdl.pattern : benefit(1) {
     %input = pdl.operand
     %root = pdl.operation "foo.op"(%input)
@@ -170,17 +170,17 @@ module @replace_with_no_results {
 
 // -----
 
-// CHECK-LABEL: module @create_native
-module @create_native {
+// CHECK-LABEL: module @apply_native_rewrite
+module @apply_native_rewrite {
   // CHECK: module @rewriters
   // CHECK:   func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
-  // CHECK:     %[[TYPE:.*]] = pdl_interp.create_native "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
+  // CHECK:     %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
   // CHECK:     pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
   pdl.pattern : benefit(1) {
     %type = pdl.type
     %root = pdl.operation "foo.op" -> %type
     pdl.rewrite %root {
-      %newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type
+      %newType = pdl.apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type
       %newOp = pdl.operation "foo.op" -> %newType
       pdl.replace %root with %newOp
     }

diff  --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir
index 0f900bbe3f53..a054da24ba4d 100644
--- a/mlir/test/Dialect/PDL/invalid.mlir
+++ b/mlir/test/Dialect/PDL/invalid.mlir
@@ -1,19 +1,33 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
-// pdl::ApplyConstraintOp
+// pdl::ApplyNativeConstraintOp
 //===----------------------------------------------------------------------===//
 
 pdl.pattern : benefit(1) {
   %op = pdl.operation "foo.op"
 
   // expected-error@below {{expected at least one argument}}
-  "pdl.apply_constraint"() {name = "foo", params = []} : () -> ()
+  "pdl.apply_native_constraint"() {name = "foo", params = []} : () -> ()
   pdl.rewrite %op with "rewriter"
 }
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// pdl::ApplyNativeRewriteOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+  %op = pdl.operation "foo.op"
+  pdl.rewrite %op {
+    // expected-error@below {{expected at least one argument}}
+    "pdl.apply_native_rewrite"() {name = "foo", params = []} : () -> ()
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // pdl::AttributeOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index b2a22d0a8749..2093d03bbf25 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -58,7 +58,7 @@ module @patterns {
   module @rewriters {
     func @success(%root : !pdl.operation) {
       %operand = pdl_interp.get_operand 0 of %root
-      pdl_interp.apply_rewrite "rewriter"[42](%operand : !pdl.value) on %root
+      pdl_interp.apply_rewrite "rewriter"[42](%root, %operand : !pdl.operation, !pdl.value)
       pdl_interp.finalize
     }
   }
@@ -72,6 +72,35 @@ module @ir attributes { test.apply_rewrite_1 } {
   %input = "test.op_input"() : () -> i32
   "test.op"(%input) : (i32) -> ()
 }
+
+// -----
+
+module @patterns {
+  func @matcher(%root : !pdl.operation) {
+    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+  ^pat:
+    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+  ^end:
+    pdl_interp.finalize
+  }
+
+  module @rewriters {
+    func @success(%root : !pdl.operation) {
+      %op = pdl_interp.apply_rewrite "creator"(%root : !pdl.operation) : !pdl.operation
+      pdl_interp.erase %root
+      pdl_interp.finalize
+    }
+  }
+}
+
+// CHECK-LABEL: test.apply_rewrite_2
+// CHECK: "test.success"
+module @ir attributes { test.apply_rewrite_2 } {
+  "test.op"() : () -> ()
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -317,38 +346,6 @@ module @ir attributes { test.check_type_1 } {
 
 // Fully tested within the tests for other operations.
 
-//===----------------------------------------------------------------------===//
-// pdl_interp::CreateNativeOp
-//===----------------------------------------------------------------------===//
-
-// -----
-
-module @patterns {
-  func @matcher(%root : !pdl.operation) {
-    pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
-
-  ^pat:
-    pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
-
-  ^end:
-    pdl_interp.finalize
-  }
-
-  module @rewriters {
-    func @success(%root : !pdl.operation) {
-      %op = pdl_interp.create_native "creator"(%root : !pdl.operation) : !pdl.operation
-      pdl_interp.erase %root
-      pdl_interp.finalize
-    }
-  }
-}
-
-// CHECK-LABEL: test.create_native_1
-// CHECK: "test.success"
-module @ir attributes { test.create_native_1 } {
-  "test.op"() : () -> ()
-}
-
 //===----------------------------------------------------------------------===//
 // pdl_interp::CreateOperationOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 3b23cb103675..e60022ba94cc 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -26,18 +26,18 @@ static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
 }
 
 // Custom creator invoked from PDL.
-static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
-                             PatternRewriter &rewriter) {
-  return rewriter.createOperation(
-      OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"));
+static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+                         PatternRewriter &rewriter, PDLResultList &results) {
+  results.push_back(rewriter.createOperation(
+      OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
 }
 
 /// Custom rewriter invoked from PDL.
-static void customRewriter(Operation *root, ArrayRef<PDLValue> args,
-                           ArrayAttr constantParams,
-                           PatternRewriter &rewriter) {
+static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+                           PatternRewriter &rewriter, PDLResultList &results) {
+  Operation *root = args[0].cast<Operation *>();
   OperationState successOpState(root->getLoc(), "test.success");
-  successOpState.addOperands(args[0].cast<Value>());
+  successOpState.addOperands(args[1].cast<Value>());
   successOpState.addAttribute("constantParams", constantParams);
   rewriter.createOperation(successOpState);
   rewriter.eraseOp(root);
@@ -63,7 +63,7 @@ struct TestPDLByteCodePass
                                           customMultiEntityConstraint);
     pdlPattern.registerConstraintFunction("single_entity_constraint",
                                           customSingleEntityConstraint);
-    pdlPattern.registerCreateFunction("creator", customCreate);
+    pdlPattern.registerRewriteFunction("creator", customCreate);
     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
 
     OwningRewritePatternList patternList(std::move(pdlPattern));


        


More information about the Mlir-commits mailing list