[Mlir-commits] [mlir] 02c4c0d - [mlir][pdl] Remove CreateNativeOp in favor of a more general ApplyNativeRewriteOp.
River Riddle
llvmlistbot at llvm.org
Tue Mar 16 13:20:30 PDT 2021
Author: River Riddle
Date: 2021-03-16T13:20:18-07:00
New Revision: 02c4c0d5b2adc79c122bd2662a4458f75771aecf
URL: https://github.com/llvm/llvm-project/commit/02c4c0d5b2adc79c122bd2662a4458f75771aecf
DIFF: https://github.com/llvm/llvm-project/commit/02c4c0d5b2adc79c122bd2662a4458f75771aecf.diff
LOG: [mlir][pdl] Remove CreateNativeOp in favor of a more general ApplyNativeRewriteOp.
This has numerous benefits, given the overly clunky nature of CreateNativeOp:
* Users can now call into arbitrary rewrite functions from inside of PDL, allowing for a more natural interleaving of PDL and C++ and enabling more of the pattern to be expressed in PDL.
* Removes the need for an additional set of C++ functions/registry/etc. The new ApplyNativeRewriteOp will use the same PDLRewriteFunction as the existing RewriteOp. This reduces the API surface area exposed to users.
This revision also introduces a new PDLResultList class. This class is used to provide results of native rewrite functions back to PDL. We introduce a new class instead of using a SmallVector to simplify the work necessary for variadics, given that ranges will require some changes to the structure of PDLValue.
Differential Revision: https://reviews.llvm.org/D95720
Added:
Modified:
mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/IR/PatternMatch.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/Rewrite/ByteCode.h
mlir/lib/Rewrite/FrozenRewritePatternList.cpp
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
mlir/test/Dialect/PDL/invalid.mlir
mlir/test/Rewrite/pdl-bytecode.mlir
mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 76e4c5d022a4..74f3fce08933 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -29,17 +29,17 @@ class PDL_Op<string mnemonic, list<OpTrait> traits = []>
}
//===----------------------------------------------------------------------===//
-// pdl::ApplyConstraintOp
+// pdl::ApplyNativeConstraintOp
//===----------------------------------------------------------------------===//
-def PDL_ApplyConstraintOp
- : PDL_Op<"apply_constraint", [HasParent<"pdl::PatternOp">]> {
- let summary = "Apply a generic constraint to a set of provided entities";
+def PDL_ApplyNativeConstraintOp
+ : PDL_Op<"apply_native_constraint", [HasParent<"pdl::PatternOp">]> {
+ let summary = "Apply a native constraint to a set of provided entities";
let description = [{
- `apply_constraint` operations apply a generic constraint, that has been
- registered externally with the consumer of PDL, to a given set of entities.
- The constraint is permitted to accept any number of constant valued
- parameters.
+ `pdl.apply_native_constraint` operations apply a native C++ constraint, that
+ has been registered externally with the consumer of PDL, to a given set of
+ entities. The constraint is permitted to accept any number of constant
+ valued parameters.
Example:
@@ -47,7 +47,7 @@ def PDL_ApplyConstraintOp
// Apply `myConstraint` to the entities defined by `input`, `attr`, and
// `op`. `42`, `"abc"`, and `i32` are constant parameters passed to the
// constraint.
- pdl.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+ pdl.apply_native_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
```
}];
@@ -67,6 +67,58 @@ def PDL_ApplyConstraintOp
];
}
+//===----------------------------------------------------------------------===//
+// pdl::ApplyNativeRewriteOp
+//===----------------------------------------------------------------------===//
+
+def PDL_ApplyNativeRewriteOp
+ : PDL_Op<"apply_native_rewrite", [HasParent<"pdl::RewriteOp">]> {
+ let summary = "Apply a native rewrite method inside of a pdl.rewrite region";
+ let description = [{
+ `pdl.apply_native_rewrite` operations apply a native C++ function, that has
+ been registered externally with the consumer of PDL, to perform a rewrite
+ and optionally return a number of values. The native function may accept any
+ number of arguments and constant attribute parameters. This operation is
+ used within a pdl.rewrite region to enable the interleaving of native
+ rewrite methods with other pdl constructs.
+
+ Example:
+
+ ```mlir
+ // Apply a native rewrite method that returns an attribute.
+ %ret = pdl.apply_native_rewrite "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute
+ ```
+
+ ```c++
+ // The native rewrite as defined in C++:
+ static void myNativeFunc(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+ PatternRewriter &rewriter,
+ PDLResultList &results) {
+ Value arg0 = args[0].cast<Value>();
+ Value arg1 = args[1].cast<Value>();
+ IntegerAttr param0 = constantParams[0].cast<IntegerAttr>();
+ StringAttr param1 = constantParams[1].cast<StringAttr>();
+
+ // Just push back the first param attribute.
+ results.push_back(param0);
+ }
+
+ void registerNativeRewrite(PDLPatternModule &pdlModule) {
+ pdlModule.registerRewriteFunction("myNativeFunc", myNativeFunc);
+ }
+ ```
+ }];
+
+ let arguments = (ins StrAttr:$name,
+ Variadic<PDL_AnyType>:$args,
+ OptionalAttr<ArrayAttr>:$constParams);
+ let results = (outs Variadic<PDL_AnyType>:$results);
+ let assemblyFormat = [{
+ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($results)
+ attr-dict
+ }];
+}
+
//===----------------------------------------------------------------------===//
// pdl::AttributeOp
//===----------------------------------------------------------------------===//
@@ -113,39 +165,6 @@ def PDL_AttributeOp : PDL_Op<"attribute"> {
];
}
-//===----------------------------------------------------------------------===//
-// pdl::CreateNativeOp
-//===----------------------------------------------------------------------===//
-
-def PDL_CreateNativeOp
- : PDL_Op<"create_native", [HasParent<"pdl::RewriteOp">]> {
- let summary = "Call a native creation method to construct an `Attribute`, "
- "`Operation`, `Type`, or `Value`";
- let description = [{
- `pdl.create_native` operations invoke a native C++ function, that has been
- registered externally with the consumer of PDL, to create an `Attribute`,
- `Operation`, `Type`, or `Value`. The native function must produce a value
- of the specified return type, and may accept any number of positional
- arguments and constant attribute parameters.
-
- Example:
-
- ```mlir
- %ret = pdl.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute
- ```
- }];
-
- let arguments = (ins StrAttr:$name,
- Variadic<PDL_AnyType>:$args,
- OptionalAttr<ArrayAttr>:$constParams);
- let results = (outs PDL_AnyType:$result);
- let assemblyFormat = [{
- $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
- attr-dict
- }];
- let verifier = ?;
-}
-
//===----------------------------------------------------------------------===//
// pdl::EraseOp
//===----------------------------------------------------------------------===//
@@ -233,9 +252,10 @@ def PDL_OperationOp
`pdl.rewrite`, all of the result types must be "inferable". This means that
the type must be attributable to either a constant type value or the result
type of another entity, such as an attribute, the result of a
- `createNative`, or the result type of another operation. If the result type
- value does not meet any of these criteria, the operation must provide the
- `InferTypeOpInterface` to ensure that the result types can be inferred.
+ `apply_native_rewrite`, or the result type of another operation. If the
+ result type value does not meet any of these criteria, the operation must
+ implement the `InferTypeOpInterface` to ensure that the result types can be
+ inferred.
Example:
@@ -416,13 +436,14 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
let summary = "Specify the rewrite of a matched pattern";
let description = [{
`pdl.rewrite` operations terminate the region of a `pdl.pattern` and specify
- the rewrite of a `pdl.pattern`, on the specified root operation. The
+ the main rewrite of a `pdl.pattern`, on the specified root operation. The
rewrite is specified either via a string name (`name`) to an external
rewrite function, or via the region body. The rewrite region, if specified,
must contain a single block and terminate via the `pdl.rewrite_end`
operation. If the rewrite is external, it also takes a set of constant
parameters and a set of additional positional values defined within the
- matcher as arguments.
+ matcher as arguments. If the rewrite is external, the root operation is
+ passed to the native function as the first argument.
Example:
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 517a0f4f0af0..8f8a5b130175 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -130,32 +130,35 @@ def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
let description = [{
`pdl_interp.apply_rewrite` operations invoke an external rewriter that has
been registered with the interpreter to perform the rewrite after a
- successful match. The rewrite is passed the root operation being matched, a
- set of additional positional arguments generated within the matcher, and a
- set of constant parameters.
+ successful match. The rewrite is passed a set of positional arguments
+ and a set of constant parameters. The rewrite function may return any
+ number of results.
Example:
```mlir
// Rewriter operating solely on the root operation.
- pdl_interp.apply_rewrite "rewriter" on %root
+ pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation)
+
+ // Rewriter operating solely on the root operation and returning an attribute.
+ %attr = pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation) : !pdl.attribute
// Rewriter operating on the root operation along with additional arguments
// from the matcher.
- pdl_interp.apply_rewrite "rewriter"(%value : !pdl.value) on %root
+ pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation, %value : !pdl.value)
// Rewriter operating on the root operation along with additional arguments
// and constant parameters.
- pdl_interp.apply_rewrite "rewriter"[42](%value : !pdl.value) on %root
+ pdl_interp.apply_rewrite "rewriter"[42](%root : !pdl.operation, %value : !pdl.value)
```
}];
let arguments = (ins StrAttr:$name,
- PDL_Operation:$root,
Variadic<PDL_AnyType>:$args,
OptionalAttr<ArrayAttr>:$constParams);
+ let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
- $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `on` $root
- attr-dict
+ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)?
+ (`:` type($results)^)? attr-dict
}];
}
@@ -351,38 +354,6 @@ def PDLInterp_CreateAttributeOp
}]>];
}
-//===----------------------------------------------------------------------===//
-// pdl_interp::CreateNativeOp
-//===----------------------------------------------------------------------===//
-
-def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> {
- let summary = "Call a native creation method to construct an `Attribute`, "
- "`Operation`, `Type`, or `Value`";
- let description = [{
- `pdl_interp.create_native` operations invoke a native C++ function, that has
- been registered externally with the consumer of PDL, to create an
- `Attribute`, `Operation`, `Type`, or `Value`. The native function must
- produce a value of the specified return type, and may accept any number of
- positional arguments and constant attribute parameters.
-
- Example:
-
- ```mlir
- %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute
- ```
- }];
-
- let arguments = (ins StrAttr:$name,
- Variadic<PDL_AnyType>:$args,
- OptionalAttr<ArrayAttr>:$constParams);
- let results = (outs PDL_AnyType:$result);
- let assemblyFormat = [{
- $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
- attr-dict
- }];
- let verifier = ?;
-}
-
//===----------------------------------------------------------------------===//
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 8e1a5b98c318..56da9b870948 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -302,6 +302,33 @@ inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
return os;
}
+//===----------------------------------------------------------------------===//
+// PDLResultList
+
+/// The class represents a list of PDL results, returned by a native rewrite
+/// method. It provides the mechanism with which to pass PDLValues back to the
+/// PDL bytecode.
+class PDLResultList {
+public:
+ /// Push a new Attribute value onto the result list.
+ void push_back(Attribute value) { results.push_back(value); }
+
+ /// Push a new Operation onto the result list.
+ void push_back(Operation *value) { results.push_back(value); }
+
+ /// Push a new Type onto the result list.
+ void push_back(Type value) { results.push_back(value); }
+
+ /// Push a new Value onto the result list.
+ void push_back(Value value) { results.push_back(value); }
+
+protected:
+ PDLResultList() = default;
+
+ /// The PDL results held by this list.
+ SmallVector<PDLValue> results;
+};
+
//===----------------------------------------------------------------------===//
// PDLPatternModule
@@ -311,16 +338,13 @@ inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
/// success if the constraint successfully held, failure otherwise.
using PDLConstraintFunction = std::function<LogicalResult(
ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
-/// A native PDL creation function. This function creates a new PDLValue given
-/// a set of existing PDL values, a set of constant parameters specified in
-/// Attribute form, and a PatternRewriter. Returns the newly created PDLValue.
-using PDLCreateFunction =
- std::function<PDLValue(ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
-/// A native PDL rewrite function. This function rewrites the given root
-/// operation using the provided PatternRewriter. This method is only invoked
-/// when the corresponding match was successful.
-using PDLRewriteFunction = std::function<void(Operation *, ArrayRef<PDLValue>,
- ArrayAttr, PatternRewriter &)>;
+/// A native PDL rewrite function. This function performs a rewrite on the
+/// given set of values and constant parameters. Any results from this rewrite
+/// that should be passed back to PDL should be added to the provided result
+/// list. This method is only invoked when the corresponding match was
+/// successful.
+using PDLRewriteFunction = std::function<void(
+ ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &, PDLResultList &)>;
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given opaque PDLValue entity. The second parameter is a set
/// of constant value parameters specified in Attribute form. Returns success if
@@ -367,9 +391,6 @@ class PDLPatternModule {
});
}
- /// Register a creation function.
- void registerCreateFunction(StringRef name, PDLCreateFunction createFn);
-
/// Register a rewrite function.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
@@ -380,13 +401,6 @@ class PDLPatternModule {
llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
return constraintFunctions;
}
- /// Return the set of the registered create functions.
- const llvm::StringMap<PDLCreateFunction> &getCreateFunctions() const {
- return createFunctions;
- }
- llvm::StringMap<PDLCreateFunction> takeCreateFunctions() {
- return createFunctions;
- }
/// Return the set of the registered rewrite functions.
const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
return rewriteFunctions;
@@ -399,7 +413,6 @@ class PDLPatternModule {
void clear() {
pdlModule = nullptr;
constraintFunctions.clear();
- createFunctions.clear();
rewriteFunctions.clear();
}
@@ -409,7 +422,6 @@ class PDLPatternModule {
/// The external functions referenced from within the PDL module.
llvm::StringMap<PDLConstraintFunction> constraintFunctions;
- llvm::StringMap<PDLCreateFunction> createFunctions;
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
};
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index 3368ceb9be88..d1da22671d95 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -70,6 +70,9 @@ struct PatternLowering {
SmallVectorImpl<Position *> &usedMatchValues);
/// Generate the rewriter code for the given operation.
+ void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
+ DenseMap<Value, Value> &rewriteValues,
+ function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::AttributeOp attrOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
@@ -79,9 +82,6 @@ struct PatternLowering {
void generateRewriter(pdl::OperationOp operationOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
- void generateRewriter(pdl::CreateNativeOp createNativeOp,
- DenseMap<Value, Value> &rewriteValues,
- function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::ReplaceOp replaceOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
@@ -449,17 +449,17 @@ SymbolRefAttr PatternLowering::generateRewriter(
// method.
pdl::RewriteOp rewriter = pattern.getRewriter();
if (StringAttr rewriteName = rewriter.nameAttr()) {
- Value root = mapRewriteValue(rewriter.root());
- SmallVector<Value, 4> args = llvm::to_vector<4>(
- llvm::map_range(rewriter.externalArgs(), mapRewriteValue));
+ auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
+ SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root()));
+ args.append(mappedArgs.begin(), mappedArgs.end());
builder.create<pdl_interp::ApplyRewriteOp>(
- rewriter.getLoc(), rewriteName, root, args,
+ rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
rewriter.externalConstParamsAttr());
} else {
// Otherwise this is a dag rewriter defined using PDL operations.
for (Operation &rewriteOp : *rewriter.getBody()) {
llvm::TypeSwitch<Operation *>(&rewriteOp)
- .Case<pdl::AttributeOp, pdl::CreateNativeOp, pdl::EraseOp,
+ .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::TypeOp>(
[&](auto op) {
this->generateRewriter(op, rewriteValues, mapRewriteValue);
@@ -478,6 +478,19 @@ SymbolRefAttr PatternLowering::generateRewriter(
builder.getSymbolRefAttr(rewriterFunc));
}
+void PatternLowering::generateRewriter(
+ pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
+ function_ref<Value(Value)> mapRewriteValue) {
+ SmallVector<Value, 2> arguments;
+ for (Value argument : rewriteOp.args())
+ arguments.push_back(mapRewriteValue(argument));
+ auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
+ rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
+ arguments, rewriteOp.constParamsAttr());
+ for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
+ rewriteValues[std::get<0>(it)] = std::get<1>(it);
+}
+
void PatternLowering::generateRewriter(
pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
@@ -527,18 +540,6 @@ void PatternLowering::generateRewriter(
}
}
-void PatternLowering::generateRewriter(
- pdl::CreateNativeOp createNativeOp, DenseMap<Value, Value> &rewriteValues,
- function_ref<Value(Value)> mapRewriteValue) {
- SmallVector<Value, 2> arguments;
- for (Value argument : createNativeOp.args())
- arguments.push_back(mapRewriteValue(argument));
- Value result = builder.create<pdl_interp::CreateNativeOp>(
- createNativeOp.getLoc(), createNativeOp.result().getType(),
- createNativeOp.nameAttr(), arguments, createNativeOp.constParamsAttr());
- rewriteValues[createNativeOp] = result;
-}
-
void PatternLowering::generateRewriter(
pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 0db35f050515..885fbad0f976 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -153,7 +153,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
/// Collect all of the predicates related to constraints within the given
/// pattern operation.
-static void getConstraintPredicates(pdl::ApplyConstraintOp op,
+static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
@@ -192,7 +192,7 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
- if (auto constraintOp = dyn_cast<pdl::ApplyConstraintOp>(&op))
+ if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(&op))
getConstraintPredicates(constraintOp, predList, builder, inputs);
else if (auto resultOp = dyn_cast<pdl::ResultOp>(&op))
getResultPredicates(resultOp, predList, builder, inputs);
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index d35aab41ba8f..dc1f501825bd 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -64,15 +64,25 @@ verifyHasBindingUseInMatcher(Operation *op,
}
//===----------------------------------------------------------------------===//
-// pdl::ApplyConstraintOp
+// pdl::ApplyNativeConstraintOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(ApplyConstraintOp op) {
+static LogicalResult verify(ApplyNativeConstraintOp op) {
if (op.getNumOperands() == 0)
return op.emitOpError("expected at least one argument");
return success();
}
+//===----------------------------------------------------------------------===//
+// pdl::ApplyNativeRewriteOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ApplyNativeRewriteOp op) {
+ if (op.getNumOperands() == 0 && op.getNumResults() == 0)
+ return op.emitOpError("expected at least one argument or result");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// pdl::AttributeOp
//===----------------------------------------------------------------------===//
@@ -162,9 +172,9 @@ static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
Operation *resultTypeOp = it.value().getDefiningOp();
assert(resultTypeOp && "expected valid result type operation");
- // If the op was defined by a `create_native`, it is guaranteed to be
+ // If the op was defined by an `apply_native_rewrite`, it is guaranteed to be
// usable.
- if (isa<CreateNativeOp>(resultTypeOp))
+ if (isa<ApplyNativeRewriteOp>(resultTypeOp))
continue;
// If the type is already constrained, there is nothing to do.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 90e89a536405..034698d85cb1 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -102,7 +102,6 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
// Steal the other state if we have no patterns.
if (!pdlModule) {
constraintFunctions = std::move(other.constraintFunctions);
- createFunctions = std::move(other.createFunctions);
rewriteFunctions = std::move(other.rewriteFunctions);
pdlModule = std::move(other.pdlModule);
return;
@@ -110,8 +109,6 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
// Steal the functions of the other module.
for (auto &it : constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
- for (auto &it : createFunctions)
- registerCreateFunction(it.first(), std::move(it.second));
for (auto &it : rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));
@@ -132,13 +129,7 @@ void PDLPatternModule::registerConstraintFunction(
assert(it.second &&
"constraint with the given name has already been registered");
}
-void PDLPatternModule::registerCreateFunction(StringRef name,
- PDLCreateFunction createFn) {
- auto it = createFunctions.try_emplace(name, std::move(createFn));
- (void)it;
- assert(it.second && "native create function with the given name has "
- "already been registered");
-}
+
void PDLPatternModule::registerRewriteFunction(StringRef name,
PDLRewriteFunction rewriteFn) {
auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 1986b3f87d96..c09892caec1b 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -80,8 +80,6 @@ enum OpCode : ByteCodeField {
CheckOperationName,
/// Compare the result count of an operation with a constant.
CheckResultCount,
- /// Invoke a native creation method.
- CreateNative,
/// Create an operation.
CreateOperation,
/// Erase an operation.
@@ -148,15 +146,12 @@ class Generator {
SmallVectorImpl<PDLByteCodePattern> &patterns,
ByteCodeField &maxValueMemoryIndex,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
- llvm::StringMap<PDLCreateFunction> &createFns,
llvm::StringMap<PDLRewriteFunction> &rewriteFns)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
maxValueMemoryIndex(maxValueMemoryIndex) {
for (auto it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
- for (auto it : llvm::enumerate(createFns))
- nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
for (auto it : llvm::enumerate(rewriteFns))
externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
}
@@ -203,7 +198,6 @@ class Generator {
void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
- void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
@@ -235,10 +229,6 @@ class Generator {
/// in the bytecode registry.
llvm::StringMap<ByteCodeField> constraintToMemIndex;
- /// Mapping from the name of an externally registered creation method to its
- /// index in the bytecode registry.
- llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
-
/// Mapping from rewriter function name to the bytecode address of the
/// rewriter function in byte.
llvm::StringMap<ByteCodeAddr> rewriterToAddr;
@@ -492,16 +482,16 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
- pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
- pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
- pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
- pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
- pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
- pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
- pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
- pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
- pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
- pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
+ pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
+ pdl_interp::EraseOp, pdl_interp::FinalizeOp,
+ pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
+ pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
+ pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp,
+ pdl_interp::InferredTypeOp, pdl_interp::IsNotNullOp,
+ pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
+ pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
+ pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
+ pdl_interp::SwitchResultCountOp>(
[&](auto interpOp) { this->generate(interpOp, writer); })
.Default([](Operation *) {
llvm_unreachable("unknown `pdl_interp` operation");
@@ -522,8 +512,16 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
assert(externalRewriterToMemIndex.count(op.name()) &&
"expected index for rewrite function");
writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
- op.constParamsAttr(), op.root());
+ op.constParamsAttr());
writer.appendPDLValueList(op.args());
+
+#ifndef NDEBUG
+ // In debug mode we also append the number of results so that we can assert
+ // that the native rewrite function gave us the correct number of results.
+ writer.append(ByteCodeField(op.results().size()));
+#endif
+ for (Value result : op.results())
+ writer.append(result);
}
void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
@@ -559,14 +557,6 @@ void Generator::generate(pdl_interp::CreateAttributeOp op,
// Simply repoint the memory index of the result to the constant.
getMemIndex(op.attribute()) = getMemIndex(op.value());
}
-void Generator::generate(pdl_interp::CreateNativeOp op,
- ByteCodeWriter &writer) {
- assert(nativeCreateToMemIndex.count(op.name()) &&
- "expected index for creation function");
- writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
- op.result(), op.constParamsAttr());
- writer.appendPDLValueList(op.args());
-}
void Generator::generate(pdl_interp::CreateOperationOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::CreateOperation, op.operation(),
@@ -678,18 +668,15 @@ void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
PDLByteCode::PDLByteCode(ModuleOp module,
llvm::StringMap<PDLConstraintFunction> constraintFns,
- llvm::StringMap<PDLCreateFunction> createFns,
llvm::StringMap<PDLRewriteFunction> rewriteFns) {
Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex,
- constraintFns, createFns, rewriteFns);
+ constraintFns, rewriteFns);
generator.generate(module);
// Initialize the external functions.
for (auto &it : constraintFns)
constraintFunctions.push_back(std::move(it.second));
- for (auto &it : createFns)
- createFunctions.push_back(std::move(it.second));
for (auto &it : rewriteFns)
rewriteFunctions.push_back(std::move(it.second));
}
@@ -717,12 +704,11 @@ class ByteCodeExecutor {
ArrayRef<PatternBenefit> currentPatternBenefits,
ArrayRef<PDLByteCodePattern> patterns,
ArrayRef<PDLConstraintFunction> constraintFunctions,
- ArrayRef<PDLCreateFunction> createFunctions,
ArrayRef<PDLRewriteFunction> rewriteFunctions)
: curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
code(code), currentPatternBenefits(currentPatternBenefits),
patterns(patterns), constraintFunctions(constraintFunctions),
- createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
+ rewriteFunctions(rewriteFunctions) {}
/// Start executing the code at the current bytecode index. `matches` is an
/// optional field provided when this function is executed in a matching
@@ -740,7 +726,6 @@ class ByteCodeExecutor {
void executeCheckOperandCount();
void executeCheckOperationName();
void executeCheckResultCount();
- void executeCreateNative(PatternRewriter &rewriter);
void executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc);
void executeEraseOp(PatternRewriter &rewriter);
@@ -866,9 +851,17 @@ class ByteCodeExecutor {
ArrayRef<PatternBenefit> currentPatternBenefits;
ArrayRef<PDLByteCodePattern> patterns;
ArrayRef<PDLConstraintFunction> constraintFunctions;
- ArrayRef<PDLCreateFunction> createFunctions;
ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
+
+/// A PDLResultList subclass that gives the bytecode executor access to the
+/// results appended by a native rewrite function. This accessor is not on
+/// `PDLResultList` to avoid overexposing information specific to the ByteCode.
+class ByteCodeRewriteResultList : public PDLResultList {
+public:
+ /// Return the PDL results appended so far by the native rewrite function.
+ MutableArrayRef<PDLValue> getResults() { return results; }
+};
} // end anonymous namespace
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
@@ -892,18 +885,29 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
ArrayAttr constParams = read<ArrayAttr>();
- Operation *root = read<Operation *>();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
LLVM_DEBUG({
- llvm::dbgs() << " * Root: " << *root << "\n * Arguments: ";
+ llvm::dbgs() << " * Arguments: ";
llvm::interleaveComma(args, llvm::dbgs());
llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
});
-
- // Invoke the native rewrite function.
- rewriteFn(root, args, constParams, rewriter);
+ ByteCodeRewriteResultList results;
+ rewriteFn(args, constParams, rewriter, results);
+
+ // Debug-only: check the encoded result count (the writer must guard it with !NDEBUG too).
+#ifndef NDEBUG
+ ByteCodeField expectedNumberOfResults = read();
+ assert(results.getResults().size() == expectedNumberOfResults &&
+ "native PDL rewrite function returned unexpected number of results");
+#endif
+
+ // Store the results in the bytecode memory.
+ for (PDLValue &result : results.getResults()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
+ memory[read()] = result.getAsOpaquePointer();
+ }
}
void ByteCodeExecutor::executeAreEqual() {
@@ -950,26 +954,6 @@ void ByteCodeExecutor::executeCheckResultCount() {
selectJump(op->getNumResults() == expectedCount);
}
-void ByteCodeExecutor::executeCreateNative(PatternRewriter &rewriter) {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
- const PDLCreateFunction &createFn = createFunctions[read()];
- ByteCodeField resultIndex = read();
- ArrayAttr constParams = read<ArrayAttr>();
- SmallVector<PDLValue, 16> args;
- readList<PDLValue>(args);
-
- LLVM_DEBUG({
- llvm::dbgs() << " * Arguments: ";
- llvm::interleaveComma(args, llvm::dbgs());
- llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
- });
-
- PDLValue result = createFn(args, constParams, rewriter);
- memory[resultIndex] = result.getAsOpaquePointer();
-
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
-}
-
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc) {
LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
@@ -1246,9 +1230,6 @@ void ByteCodeExecutor::execute(
case CheckResultCount:
executeCheckResultCount();
break;
- case CreateNative:
- executeCreateNative(rewriter);
- break;
case CreateOperation:
executeCreateOperation(rewriter, *mainRewriteLoc);
break;
@@ -1338,8 +1319,7 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
// The matcher function always starts at code address 0.
ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
matcherByteCode, state.currentPatternBenefits,
- patterns, constraintFunctions, createFunctions,
- rewriteFunctions);
+ patterns, constraintFunctions, rewriteFunctions);
executor.execute(rewriter, &matches);
// Order the found matches by benefit.
@@ -1356,9 +1336,9 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
// memory buffer.
llvm::copy(match.values, state.memory.begin());
- ByteCodeExecutor executor(
- &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
- uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
- constraintFunctions, createFunctions, rewriteFunctions);
+ ByteCodeExecutor executor(&rewriterByteCode[match.pattern->getRewriterAddr()],
+ state.memory, uniquedData, rewriterByteCode,
+ state.currentPatternBenefits, patterns,
+ constraintFunctions, rewriteFunctions);
executor.execute(rewriter, /*matches=*/nullptr, match.location);
}
diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h
index 38dbbcd855ce..f6a3bcbe54f9 100644
--- a/mlir/lib/Rewrite/ByteCode.h
+++ b/mlir/lib/Rewrite/ByteCode.h
@@ -114,7 +114,6 @@ class PDLByteCode {
/// the PDL interpreter dialect.
PDLByteCode(ModuleOp module,
llvm::StringMap<PDLConstraintFunction> constraintFns,
- llvm::StringMap<PDLCreateFunction> createFns,
llvm::StringMap<PDLRewriteFunction> rewriteFns);
/// Return the patterns held by the bytecode.
@@ -160,7 +159,6 @@ class PDLByteCode {
/// A set of user defined functions invoked via PDL.
std::vector<PDLConstraintFunction> constraintFunctions;
- std::vector<PDLCreateFunction> createFunctions;
std::vector<PDLRewriteFunction> rewriteFunctions;
/// The maximum memory index used by a value.
diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
index 40f7aec44e51..c2de51a647dd 100644
--- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
+++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp
@@ -70,7 +70,7 @@ FrozenRewritePatternList::FrozenRewritePatternList(
// Generate the pdl bytecode.
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
pdlModule, pdlPatterns.takeConstraintFunctions(),
- pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
+ pdlPatterns.takeRewriteFunctions());
}
FrozenRewritePatternList::~FrozenRewritePatternList() {}
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index c856ab5c9f6f..a42b51604945 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -24,7 +24,7 @@ module @simple {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[REWRITE_ROOT:.*]]: !pdl.operation)
- // CHECK: pdl_interp.apply_rewrite "rewriter" on %[[REWRITE_ROOT]]
+ // CHECK: pdl_interp.apply_rewrite "rewriter"(%[[REWRITE_ROOT]]
// CHECK: pdl_interp.finalize
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"()
@@ -72,7 +72,7 @@ module @constraints {
%root = pdl.operation(%input0, %input1)
%result0 = pdl.result 0 of %root
- pdl.apply_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
+ pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
pdl.rewrite %root with "rewriter"
}
}
@@ -194,7 +194,7 @@ module @predicate_ordering {
pdl.pattern : benefit(1) {
%resultType = pdl.type
- pdl.apply_constraint "typeConstraint"[](%resultType : !pdl.type)
+ pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type)
%root = pdl.operation -> %resultType
pdl.rewrite %root with "rewriter"
}
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
index 5652b2118afe..3d0d565c547f 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
@@ -6,7 +6,7 @@
module @external {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value)
- // CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[INPUT]] : !pdl.value) on %[[ROOT]]
+ // CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value)
pdl.pattern : benefit(1) {
%input = pdl.operand
%root = pdl.operation "foo.op"(%input)
@@ -170,17 +170,17 @@ module @replace_with_no_results {
// -----
-// CHECK-LABEL: module @create_native
-module @create_native {
+// CHECK-LABEL: module @apply_native_rewrite
+module @apply_native_rewrite {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
- // CHECK: %[[TYPE:.*]] = pdl_interp.create_native "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
+ // CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
// CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
pdl.pattern : benefit(1) {
%type = pdl.type
%root = pdl.operation "foo.op" -> %type
pdl.rewrite %root {
- %newType = pdl.create_native "functor"[true](%root : !pdl.operation) : !pdl.type
+ %newType = pdl.apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type
%newOp = pdl.operation "foo.op" -> %newType
pdl.replace %root with %newOp
}
diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir
index 0f900bbe3f53..a054da24ba4d 100644
--- a/mlir/test/Dialect/PDL/invalid.mlir
+++ b/mlir/test/Dialect/PDL/invalid.mlir
@@ -1,19 +1,33 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
//===----------------------------------------------------------------------===//
-// pdl::ApplyConstraintOp
+// pdl::ApplyNativeConstraintOp
//===----------------------------------------------------------------------===//
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
// expected-error@below {{expected at least one argument}}
- "pdl.apply_constraint"() {name = "foo", params = []} : () -> ()
+ "pdl.apply_native_constraint"() {name = "foo", params = []} : () -> ()
pdl.rewrite %op with "rewriter"
}
// -----
+//===----------------------------------------------------------------------===//
+// pdl::ApplyNativeRewriteOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+ %op = pdl.operation "foo.op"
+ pdl.rewrite %op {
+ // expected-error@below {{expected at least one argument}}
+ "pdl.apply_native_rewrite"() {name = "foo", params = []} : () -> ()
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl::AttributeOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index b2a22d0a8749..2093d03bbf25 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -58,7 +58,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
%operand = pdl_interp.get_operand 0 of %root
- pdl_interp.apply_rewrite "rewriter"[42](%operand : !pdl.value) on %root
+ pdl_interp.apply_rewrite "rewriter"[42](%root, %operand : !pdl.operation, !pdl.value)
pdl_interp.finalize
}
}
@@ -72,6 +72,35 @@ module @ir attributes { test.apply_rewrite_1 } {
%input = "test.op_input"() : () -> i32
"test.op"(%input) : (i32) -> ()
}
+
+// -----
+
+module @patterns {
+ func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
+
+ ^pat:
+ pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ func @success(%root : !pdl.operation) {
+ %op = pdl_interp.apply_rewrite "creator"(%root : !pdl.operation) : !pdl.operation
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_rewrite_2
+// CHECK: "test.success"
+module @ir attributes { test.apply_rewrite_2 } {
+ "test.op"() : () -> ()
+}
+
// -----
//===----------------------------------------------------------------------===//
@@ -317,38 +346,6 @@ module @ir attributes { test.check_type_1 } {
// Fully tested within the tests for other operations.
-//===----------------------------------------------------------------------===//
-// pdl_interp::CreateNativeOp
-//===----------------------------------------------------------------------===//
-
-// -----
-
-module @patterns {
- func @matcher(%root : !pdl.operation) {
- pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end
-
- ^pat:
- pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end
-
- ^end:
- pdl_interp.finalize
- }
-
- module @rewriters {
- func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_native "creator"(%root : !pdl.operation) : !pdl.operation
- pdl_interp.erase %root
- pdl_interp.finalize
- }
- }
-}
-
-// CHECK-LABEL: test.create_native_1
-// CHECK: "test.success"
-module @ir attributes { test.create_native_1 } {
- "test.op"() : () -> ()
-}
-
//===----------------------------------------------------------------------===//
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index 3b23cb103675..e60022ba94cc 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -26,18 +26,18 @@ static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
}
// Custom creator invoked from PDL.
-static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
- PatternRewriter &rewriter) {
- return rewriter.createOperation(
- OperationState(args[0].cast<Operation *>()->getLoc(), "test.success"));
+static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+ PatternRewriter &rewriter, PDLResultList &results) {
+ results.push_back(rewriter.createOperation(
+ OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
}
/// Custom rewriter invoked from PDL.
-static void customRewriter(Operation *root, ArrayRef<PDLValue> args,
- ArrayAttr constantParams,
- PatternRewriter &rewriter) {
+static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
+ PatternRewriter &rewriter, PDLResultList &results) {
+ Operation *root = args[0].cast<Operation *>();
OperationState successOpState(root->getLoc(), "test.success");
- successOpState.addOperands(args[0].cast<Value>());
+ successOpState.addOperands(args[1].cast<Value>());
successOpState.addAttribute("constantParams", constantParams);
rewriter.createOperation(successOpState);
rewriter.eraseOp(root);
@@ -63,7 +63,7 @@ struct TestPDLByteCodePass
customMultiEntityConstraint);
pdlPattern.registerConstraintFunction("single_entity_constraint",
customSingleEntityConstraint);
- pdlPattern.registerCreateFunction("creator", customCreate);
+ pdlPattern.registerRewriteFunction("creator", customCreate);
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
OwningRewritePatternList patternList(std::move(pdlPattern));
More information about the Mlir-commits
mailing list