[Mlir-commits] [mlir] d289a97 - [mlir][PDL] Add a PDL Interpreter Dialect
River Riddle
llvmlistbot at llvm.org
Wed Aug 26 05:23:48 PDT 2020
Author: River Riddle
Date: 2020-08-26T05:22:27-07:00
New Revision: d289a97f91443177b605926668512479c2cee37b
URL: https://github.com/llvm/llvm-project/commit/d289a97f91443177b605926668512479c2cee37b
DIFF: https://github.com/llvm/llvm-project/commit/d289a97f91443177b605926668512479c2cee37b.diff
LOG: [mlir][PDL] Add a PDL Interpreter Dialect
The PDL Interpreter dialect provides a lower level abstraction compared to the PDL dialect, and is targeted towards low level optimization and interpreter code generation. The dialect operations encapsulate low-level pattern match and rewrite "primitives", such as navigating the IR (Operation::getOperand), creating new operations (OpBuilder::create), etc. Many of the operations within this dialect also fuse branching control flow with some form of a predicate comparison operation. This type of fusion reduces the amount of work that an interpreter must do when executing.
An example of this representation is shown below:
```mlir
// The following high level PDL pattern:
pdl.pattern : benefit(1) {
%resultType = pdl.type
%inputOperand = pdl.input
%root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
pdl.rewrite %root {
pdl.replace %root with (%inputOperand)
}
}
// May be represented in the interpreter dialect as follows:
module {
func @matcher(%arg0: !pdl.operation) {
pdl_interp.check_operation_name of %arg0 is "foo.op" -> ^bb2, ^bb1
^bb1:
pdl_interp.finalize
^bb2:
pdl_interp.check_operand_count of %arg0 is 1 -> ^bb3, ^bb1
^bb3:
pdl_interp.check_result_count of %arg0 is 1 -> ^bb4, ^bb1
^bb4:
%0 = pdl_interp.get_operand 0 of %arg0
pdl_interp.is_not_null %0 : !pdl.value -> ^bb5, ^bb1
^bb5:
%1 = pdl_interp.get_result 0 of %arg0
pdl_interp.is_not_null %1 : !pdl.value -> ^bb6, ^bb1
^bb6:
pdl_interp.record_match @rewriters::@rewriter(%0, %arg0 : !pdl.value, !pdl.operation) : benefit(1), loc([%arg0]), root("foo.op") -> ^bb1
}
module @rewriters {
func @rewriter(%arg0: !pdl.value, %arg1: !pdl.operation) {
pdl_interp.replace %arg1 with(%arg0)
pdl_interp.finalize
}
}
}
```
Differential Revision: https://reviews.llvm.org/D84579
Added:
mlir/include/mlir/Dialect/PDLInterp/CMakeLists.txt
mlir/include/mlir/Dialect/PDLInterp/IR/CMakeLists.txt
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/lib/Dialect/PDLInterp/CMakeLists.txt
mlir/lib/Dialect/PDLInterp/IR/CMakeLists.txt
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/test/Dialect/PDLInterp/ops.mlir
Modified:
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/Dialect/PDL/IR/PDLBase.td
mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/PDL/IR/PDL.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/test/Dialect/PDL/invalid.mlir
mlir/test/Dialect/PDL/ops.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index 118210c45fe8..6426fa8a91e7 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(LLVMIR)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
+add_subdirectory(PDLInterp)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td b/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td
index a3392f215ae5..9802bf943157 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td
@@ -49,7 +49,7 @@ def PDL_Dialect : Dialect {
%resultType = pdl.type
%inputOperand = pdl.input
%root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType
- pdl.rewrite(%root) {
+ pdl.rewrite %root {
pdl.replace %root with (%inputOperand)
}
}
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 1e865fbd921d..73b4f263accd 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -51,17 +51,18 @@ def PDL_ApplyConstraintOp
```
}];
- let arguments = (ins Variadic<PDL_PositionalValue>:$args,
- ArrayAttr:$params,
- StrAttr:$name);
- let assemblyFormat = "$name $params `(` $args `:` type($args) `)` attr-dict";
+ let arguments = (ins StrAttr:$name,
+ Variadic<PDL_PositionalValue>:$args,
+ OptionalAttr<ArrayAttr>:$constParams);
+ let assemblyFormat = [{
+ $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict
+ }];
let builders = [
- OpBuilder<"OpBuilder &builder, OperationState &state, "
- "ValueRange args, ArrayRef<Attribute> params, "
- "StringRef name", [{
- build(builder, state, args, builder.getArrayAttr(params),
- builder.getStringAttr(name));
+ OpBuilder<"OpBuilder &builder, OperationState &state, StringRef name, "
+ "ValueRange args = {}, ArrayRef<Attribute> params = {}", [{
+ build(builder, state, builder.getStringAttr(name), args,
+ params.empty() ? ArrayAttr() : builder.getArrayAttr(params));
}]>,
];
}
@@ -135,12 +136,13 @@ def PDL_CreateNativeOp
```
}];
- let arguments = (ins StrAttr:$name, Variadic<PDL_PositionalValue>:$arguments,
- ArrayAttr:$constantParams);
+ let arguments = (ins StrAttr:$name,
+ Variadic<PDL_PositionalValue>:$args,
+ OptionalAttr<ArrayAttr>:$constParams);
let results = (outs PDL_PositionalValue:$result);
let assemblyFormat = [{
- $name $constantParams (`(` $arguments^ `:` type($arguments) `)`)?
- `:` type($result) attr-dict
+ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
+ attr-dict
}];
let verifier = ?;
}
@@ -222,7 +224,7 @@ def PDL_OperationOp
`pdl.operation`s are composed of a name, and a set of attribute, operand,
and result type values, that map to what those that would be on a
constructed instance of that operation. The results of a `pdl.operation` are
- a handle to the operation itself, and a handle to each of the operation
+ a handle to the operation itself, and a handle to each of the operation
result values.
When used within a matching context, the name of the operation may be
@@ -380,16 +382,18 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
rewrite is specified either via a string name (`name`) to an external
rewrite function, or via the region body. The rewrite region, if specified,
must contain a single block and terminate via the `pdl.rewrite_end`
- operation.
+ operation. If the rewrite is external, it also takes a set of constant
+ parameters and a set of additional positional values defined within the
+ matcher as arguments.
Example:
```mlir
// Specify an external rewrite function:
- pdl.rewrite "myExternalRewriter"(%root)
+ pdl.rewrite %root with "myExternalRewriter"(%value : !pdl.value)
// Specify the rewrite inline using PDL:
- pdl.rewrite(%root) {
+ pdl.rewrite %root {
%op = pdl.operation "foo.op"(%arg0, %arg1)
pdl.replace %root with %op
}
@@ -397,7 +401,9 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
}];
let arguments = (ins PDL_Operation:$root,
- OptionalAttr<StrAttr>:$name);
+ OptionalAttr<StrAttr>:$name,
+ Variadic<PDL_PositionalValue>:$externalArgs,
+ OptionalAttr<ArrayAttr>:$externalConstParams);
let regions = (region AnyRegion:$body);
}
diff --git a/mlir/include/mlir/Dialect/PDLInterp/CMakeLists.txt b/mlir/include/mlir/Dialect/PDLInterp/CMakeLists.txt
new file mode 100644
index 000000000000..f33061b2d87c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/PDLInterp/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/PDLInterp/IR/CMakeLists.txt
new file mode 100644
index 000000000000..8a707660fa26
--- /dev/null
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_mlir_dialect(PDLInterpOps pdl_interp)
+add_mlir_doc(PDLInterpOps -gen-op-doc PDLInterpOps Dialects/)
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h
new file mode 100644
index 000000000000..6d895679b3d6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h
@@ -0,0 +1,39 @@
+//===- PDLInterp.h - PDL Interpreter dialect --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interpreter dialect for the PDL pattern descriptor
+// language.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_
+#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_
+
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+namespace mlir {
+namespace pdl_interp {
+//===----------------------------------------------------------------------===//
+// PDLInterp Dialect
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// PDLInterp Dialect Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.h.inc"
+
+} // end namespace pdl_interp
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
new file mode 100644
index 000000000000..58a2032a2182
--- /dev/null
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -0,0 +1,926 @@
+//===- PDLInterpOps.td - Pattern Interpreter Dialect -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the PDL interpreter dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
+#define MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
+
+include "mlir/Dialect/PDL/IR/PDLBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// PDLInterp Dialect
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_Dialect : Dialect {
+ let summary = "Interpreted pattern execution dialect";
+ let description = [{
+ The PDL Interpreter dialect provides a lower level abstraction compared to
+ the PDL dialect, and is targeted towards low level optimization and
+ interpreter code generation. The dialect operations encapsulate
+ low-level pattern match and rewrite "primitives", such as navigating the
+ IR (Operation::getOperand), creating new operations (OpBuilder::create),
+ etc. Many of the operations within this dialect also fuse branching control
+ flow with some form of a predicate comparison operation. This type of fusion
+ reduces the amount of work that an interpreter must do when executing.
+ }];
+
+ let name = "pdl_interp";
+ let cppNamespace = "mlir::pdl_interp";
+ let dependentDialects = ["pdl::PDLDialect"];
+}
+
+//===----------------------------------------------------------------------===//
+// PDLInterp Operations
+//===----------------------------------------------------------------------===//
+
+// Generic interpreter operation.
+class PDLInterp_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<PDLInterp_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// PDLInterp_PredicateOp
+
+// Check operations evaluate a predicate on a positional value and then
+// conditionally branch on the result.
+class PDLInterp_PredicateOp<string mnemonic, list<OpTrait> traits = []> :
+ PDLInterp_Op<mnemonic, !listconcat([Terminator], traits)> {
+ let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
+}
+
+//===----------------------------------------------------------------------===//
+// PDLInterp_SwitchOp
+
+// Switch operations compare a positional value against a set of case values
+// and branch to the successor paired with the matching case value, or to the
+// default destination if no case value matches. Each case value must have
+// exactly one corresponding case destination (enforced by the verifier).
+class PDLInterp_SwitchOp<string mnemonic, list<OpTrait> traits = []> :
+ PDLInterp_Op<mnemonic, !listconcat([Terminator], traits)> {
+ let successors = (successor AnySuccessor:$defaultDest,
+ VariadicSuccessor<AnySuccessor>:$cases);
+
+ let verifier = [{
+ // Verify that the number of case destinations matches the number of case
+ // values.
+ size_t numDests = cases().size();
+ size_t numValues = caseValues().size();
+ if (numDests != numValues) {
+ return emitOpError("expected number of cases to match the number of case "
+ "values, got ")
+ << numDests << " but expected " << numValues;
+ }
+ return success();
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ApplyConstraintOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
+ let summary = "Apply a constraint to a set of positional values";
+ let description = [{
+ `pdl_interp.apply_constraint` operations apply a generic constraint, that
+ has been registered with the interpreter, with a given set of positional
+ values. The constraint may have any number of constant parameters. On
+ success, this operation branches to the true destination, otherwise the
+ false destination is taken.
+
+ Example:
+
+ ```mlir
+ // Apply `myConstraint` to the entities defined by `input`, `attr`, and
+ // `op`.
+ pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest
+ ```
+ }];
+
+ let arguments = (ins StrAttr:$name,
+ Variadic<PDL_PositionalValue>:$args,
+ OptionalAttr<ArrayAttr>:$constParams);
+ let assemblyFormat = [{
+ $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict `->`
+ successors
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ApplyRewriteOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
+ let summary = "Invoke and apply an externally registered rewrite method";
+ let description = [{
+ `pdl_interp.apply_rewrite` operations invoke an external rewriter that has
+ been registered with the interpreter to perform the rewrite after a
+ successful match. The rewrite is passed the root operation being matched, a
+ set of additional positional arguments generated within the matcher, and a
+ set of constant parameters.
+
+ Example:
+
+ ```mlir
+ // Rewriter operating solely on the root operation.
+ pdl_interp.apply_rewrite "rewriter" on %root
+
+ // Rewriter operating on the root operation along with additional arguments
+ // from the matcher.
+ pdl_interp.apply_rewrite "rewriter"(%value : !pdl.value) on %root
+
+ // Rewriter operating on the root operation along with additional arguments
+ // and constant parameters.
+ pdl_interp.apply_rewrite "rewriter"[42](%value : !pdl.value) on %root
+ ```
+ }];
+ let arguments = (ins StrAttr:$name,
+ PDL_Operation:$root,
+ Variadic<PDL_PositionalValue>:$args,
+ OptionalAttr<ArrayAttr>:$constParams);
+ let assemblyFormat = [{
+ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `on` $root
+ attr-dict
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::AreEqualOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_AreEqualOp
+ : PDLInterp_PredicateOp<"are_equal", [NoSideEffect, SameTypeOperands]> {
+ let summary = "Check if two positional values are equivalent";
+ let description = [{
+ `pdl_interp.are_equal` operations compare two positional values for
+ equality. On success, this operation branches to the true destination,
+ otherwise the false destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.are_equal %result1, %result2 : !pdl.value -> ^matchDest, ^failureDest
+ ```
+ }];
+
+ let arguments = (ins PDL_PositionalValue:$lhs,
+ PDL_PositionalValue:$rhs);
+ let assemblyFormat = "operands `:` type($lhs) attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::BranchOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_BranchOp : PDLInterp_Op<"branch", [NoSideEffect, Terminator]> {
+ let summary = "General branch operation";
+ let description = [{
+ `pdl_interp.branch` operations expose general branch functionality to the
+ interpreter, and are generally used to branch from one pattern match
+ sequence to another.
+
+ Example:
+
+ ```mlir
+ pdl_interp.branch ^dest
+ ```
+ }];
+
+ let successors = (successor AnySuccessor:$dest);
+ let assemblyFormat = "$dest attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckAttributeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckAttributeOp
+ : PDLInterp_PredicateOp<"check_attribute", [NoSideEffect]> {
+ let summary = "Check the value of an `Attribute`";
+ let description = [{
+ `pdl_interp.check_attribute` operations compare the value of a given
+ attribute with a constant value. On success, this operation branches to the
+ true destination, otherwise the false destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.check_attribute %attr is 10 -> ^matchDest, ^failureDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Attribute:$attribute, AnyAttr:$constantValue);
+ let assemblyFormat = [{
+ $attribute `is` $constantValue attr-dict `->` successors
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckOperandCountOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckOperandCountOp
+ : PDLInterp_PredicateOp<"check_operand_count", [NoSideEffect]> {
+ let summary = "Check the number of operands of an `Operation`";
+ let description = [{
+ `pdl_interp.check_operand_count` operations compare the number of operands
+ of a given operation value with a constant. On success, this operation
+ branches to the true destination, otherwise the false destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.check_operand_count of %op is 2 -> ^matchDest, ^failureDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation,
+ Confined<I32Attr, [IntNonNegative]>:$count);
+ let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckOperationNameOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckOperationNameOp
+ : PDLInterp_PredicateOp<"check_operation_name", [NoSideEffect]> {
+ let summary = "Check the OperationName of an `Operation`";
+ let description = [{
+ `pdl_interp.check_operation_name` operations compare the name of a given
+ operation with a known name. On success, this operation branches to the true
+ destination, otherwise the false destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.check_operation_name of %op is "foo.op" -> ^matchDest, ^failureDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation, StrAttr:$name);
+ let assemblyFormat = "`of` $operation `is` $name attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckResultCountOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckResultCountOp
+ : PDLInterp_PredicateOp<"check_result_count", [NoSideEffect]> {
+ let summary = "Check the number of results of an `Operation`";
+ let description = [{
+ `pdl_interp.check_result_count` operations compare the number of results
+ of a given operation value with a constant. On success, this operation
+ branches to the true destination, otherwise the false destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.check_result_count of %op is 0 -> ^matchDest, ^failureDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation,
+ Confined<I32Attr, [IntNonNegative]>:$count);
+ let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckTypeOp
+ : PDLInterp_PredicateOp<"check_type", [NoSideEffect]> {
+ let summary = "Compare a type to a known value";
+ let description = [{
+ `pdl_interp.check_type` operations compare a type with a statically known
+ type. On success, this operation branches to the true destination, otherwise
+ the false destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.check_type %type is i32 -> ^matchDest, ^failureDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Type:$value, TypeAttr:$type);
+ let assemblyFormat = "$value `is` $type attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateAttributeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateAttributeOp
+ : PDLInterp_Op<"create_attribute", [NoSideEffect]> {
+ let summary = "Create an interpreter handle to a constant `Attribute`";
+ let description = [{
+ `pdl_interp.create_attribute` operations generate a handle within the
+ interpreter for a specific constant attribute value.
+
+ Example:
+
+ ```mlir
+ pdl_interp.create_attribute 10 : i64
+ ```
+ }];
+
+ let arguments = (ins AnyAttr:$value);
+ let results = (outs PDL_Attribute:$attribute);
+ let assemblyFormat = "$value attr-dict";
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, "
+ "Attribute value", [{
+ build(builder, state, builder.getType<pdl::AttributeType>(), value);
+ }]>];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateNativeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateNativeOp : PDLInterp_Op<"create_native"> {
+ let summary = "Call a native creation method to construct an `Attribute`, "
+ "`Operation`, `Type`, or `Value`";
+ let description = [{
+ `pdl_interp.create_native` operations invoke a native C++ function, that has
+ been registered externally with the consumer of PDL, to create an
+ `Attribute`, `Operation`, `Type`, or `Value`. The native function must
+ produce a value of the specified return type, and may accept any number of
+ positional arguments and constant attribute parameters.
+
+ Example:
+
+ ```mlir
+ %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute
+ ```
+ }];
+
+ let arguments = (ins StrAttr:$name,
+ Variadic<PDL_PositionalValue>:$args,
+ OptionalAttr<ArrayAttr>:$constParams);
+ let results = (outs PDL_PositionalValue:$result);
+ let assemblyFormat = [{
+ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result)
+ attr-dict
+ }];
+ let verifier = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateOperationOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateOperationOp
+ : PDLInterp_Op<"create_operation", [AttrSizedOperandSegments]> {
+ let summary = "Create an instance of a specific `Operation`";
+ let description = [{
+ `pdl_interp.create_operation` operations create an `Operation` instance with
+ the specified attributes, operands, and result types.
+
+ Example:
+
+ ```mlir
+ // Create an instance of a `foo.op` operation.
+ %op = pdl_interp.create_operation "foo.op"(%arg0) {"attrA" = %attr0} -> %type, %type
+ ```
+ }];
+
+ let arguments = (ins StrAttr:$name,
+ Variadic<PDL_Value>:$operands,
+ Variadic<PDL_Attribute>:$attributes,
+ StrArrayAttr:$attributeNames,
+ Variadic<PDL_Type>:$types);
+ let results = (outs PDL_Operation:$operation);
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, StringRef name, "
+ "ValueRange types, ValueRange operands, ValueRange attributes, "
+ "ArrayAttr attributeNames", [{
+ build(builder, state, builder.getType<pdl::OperationType>(), name,
+ operands, attributes, attributeNames, types);
+ }]>];
+ let parser = [{ return ::parseCreateOperationOp(parser, result); }];
+ let printer = [{ ::print(p, *this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateTypeOp : PDLInterp_Op<"create_type", [NoSideEffect]> {
+ let summary = "Create an interpreter handle to a constant `Type`";
+ let description = [{
+ `pdl_interp.create_type` operations generate a handle within the interpreter
+ for a specific constant type value.
+
+ Example:
+
+ ```mlir
+ pdl_interp.create_type i64
+ ```
+ }];
+
+ let arguments = (ins TypeAttr:$value);
+ let results = (outs PDL_Type:$result);
+ let assemblyFormat = "$value attr-dict";
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, TypeAttr type", [{
+ build(builder, state, builder.getType<pdl::TypeType>(), type);
+ }]>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::EraseOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_EraseOp : PDLInterp_Op<"erase"> {
+ let summary = "Mark an operation as `erased`";
+ let description = [{
+ `pdl_interp.erase` operations are used to specify that an operation should
+ be marked as erased. The semantics of this operation correspond with the
+ `eraseOp` method on a `PatternRewriter`.
+
+ Example:
+
+ ```mlir
+ pdl_interp.erase %root
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation);
+ let assemblyFormat = "$operation attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::FinalizeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_FinalizeOp
+ : PDLInterp_Op<"finalize", [NoSideEffect, Terminator]> {
+ let summary = "Finalize a pattern match or rewrite sequence";
+ let description = [{
+ `pdl_interp.finalize` is used to denote the termination of a match or
+ rewrite sequence.
+
+ Example:
+
+ ```mlir
+ pdl_interp.finalize
+ ```
+ }];
+ let assemblyFormat = "attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetAttributeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetAttributeOp : PDLInterp_Op<"get_attribute", [NoSideEffect]> {
+ let summary = "Get a specified attribute value from an `Operation`";
+ let description = [{
+ `pdl_interp.get_attribute` operations try to get a specific attribute from
+ an operation. If the operation does not have that attribute, a null value is
+ returned.
+
+ Example:
+
+ ```mlir
+ %attr = pdl_interp.get_attribute "attr" of %op
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation,
+ StrAttr:$name);
+ let results = (outs PDL_Attribute:$attribute);
+ let assemblyFormat = "$name `of` $operation attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetAttributeTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetAttributeTypeOp
+ : PDLInterp_Op<"get_attribute_type", [NoSideEffect]> {
+ let summary = "Get the result type of a specified `Attribute`";
+ let description = [{
+ `pdl_interp.get_attribute_type` operations get the resulting type of a
+ specific attribute.
+
+ Example:
+
+ ```mlir
+ %type = pdl_interp.get_attribute_type of %attr
+ ```
+ }];
+
+ let arguments = (ins PDL_Attribute:$value);
+ let results = (outs PDL_Type:$result);
+ let assemblyFormat = "`of` $value attr-dict";
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, Value value", [{
+ build(builder, state, builder.getType<pdl::TypeType>(), value);
+ }]>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetDefiningOpOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetDefiningOpOp
+ : PDLInterp_Op<"get_defining_op", [NoSideEffect]> {
+ let summary = "Get the defining operation of a `Value`";
+ let description = [{
+ `pdl_interp.get_defining_op` operations try to get the defining operation
+ of a specific value. If the value is not an operation result, null is
+ returned.
+
+ Example:
+
+ ```mlir
+ %op = pdl_interp.get_defining_op of %value
+ ```
+ }];
+
+ let arguments = (ins PDL_Value:$value);
+ let results = (outs PDL_Operation:$operation);
+ let assemblyFormat = "`of` $value attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetOperandOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetOperandOp : PDLInterp_Op<"get_operand", [NoSideEffect]> {
+ let summary = "Get a specified operand from an `Operation`";
+ let description = [{
+ `pdl_interp.get_operand` operations try to get a specific operand from an
+ operation. If the operation does not have an operand for the given index, a
+ null value is returned.
+
+ Example:
+
+ ```mlir
+ %operand = pdl_interp.get_operand 1 of %op
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation,
+ Confined<I32Attr, [IntNonNegative]>:$index);
+ let results = (outs PDL_Value:$value);
+ let assemblyFormat = "$index `of` $operation attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetResultOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetResultOp : PDLInterp_Op<"get_result", [NoSideEffect]> {
+ let summary = "Get a specified result from an `Operation`";
+ let description = [{
+ `pdl_interp.get_result` operations try to get a specific result from an
+ operation. If the operation does not have a result for the given index, a
+ null value is returned.
+
+ Example:
+
+ ```mlir
+ %result = pdl_interp.get_result 1 of %op
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation,
+ Confined<I32Attr, [IntNonNegative]>:$index);
+ let results = (outs PDL_Value:$value);
+ let assemblyFormat = "$index `of` $operation attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetValueTypeOp
+//===----------------------------------------------------------------------===//
+
+// Get the type of a given positional `Value` (not necessarily tied to the root
+// operation).
+def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect]> {
+ let summary = "Get the result type of a specified `Value`";
+ let description = [{
+ `pdl_interp.get_value_type` operations get the resulting type of a specific
+ value.
+
+ Example:
+
+ ```mlir
+ %type = pdl_interp.get_value_type of %value
+ ```
+ }];
+
+ let arguments = (ins PDL_Value:$value);
+ let results = (outs PDL_Type:$result);
+ let assemblyFormat = "`of` $value attr-dict";
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, Value value", [{
+ build(builder, state, builder.getType<pdl::TypeType>(), value);
+ }]>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::InferredTypeOp
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): unlike the neighboring query ops (get_operand, get_result,
+// get_value_type), this op carries no NoSideEffect trait; confirm intent.
+def PDLInterp_InferredTypeOp : PDLInterp_Op<"inferred_type"> {
+ let summary = "Generate a handle to a Type that is \"inferred\"";
+ let description = [{
+ `pdl_interp.inferred_type` operations generate a handle to a type that
+ should be inferred. This signals to other operations, such as
+ `pdl_interp.create_operation`, that this type should be inferred.
+
+ Example:
+
+ ```mlir
+ pdl_interp.inferred_type
+ ```
+ }];
+ let results = (outs PDL_Type:$type);
+ let assemblyFormat = "attr-dict";
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state", [{
+ build(builder, state, builder.getType<pdl::TypeType>());
+ }]>,
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::IsNotNullOp
+//===----------------------------------------------------------------------===//
+
+// Fused null-check + branch. Useful after `pdl_interp.get_operand` /
+// `pdl_interp.get_result`, which yield a null value for out-of-range indices.
+def PDLInterp_IsNotNullOp
+ : PDLInterp_PredicateOp<"is_not_null", [NoSideEffect]> {
+ let summary = "Check if a positional value is non-null";
+ let description = [{
+ `pdl_interp.is_not_null` operations check that a positional value exists. On
+ success, this operation branches to the true destination. Otherwise, the
+ false destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.is_not_null %value : !pdl.value -> ^matchDest, ^failureDest
+ ```
+ }];
+
+ let arguments = (ins PDL_PositionalValue:$value);
+ let assemblyFormat = "$value `:` type($value) attr-dict `->` successors";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::RecordMatchOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_RecordMatchOp
+ : PDLInterp_Op<"record_match", [AttrSizedOperandSegments, Terminator]> {
+ let summary = "Record the metadata for a successful pattern match";
+ let description = [{
+ `pdl_interp.record_match` operations record a successful pattern match with
+ the interpreter and branch to the next part of the matcher. The metadata
+ recorded by these operations corresponds to a specific `pdl.pattern`, as well
+ as what values were used during that match that should be propagated to the
+ rewriter.
+
+ Example:
+
+ ```mlir
+ pdl_interp.record_match @rewriters::@myRewriter(%root : !pdl.operation) : benefit(1), loc([%root, %op1]), root("foo.op") -> ^nextDest
+ ```
+ }];
+
+ let arguments = (ins Variadic<PDL_PositionalValue>:$inputs,
+ Variadic<PDL_Operation>:$matchedOps,
+ SymbolRefAttr:$rewriter,
+ OptionalAttr<StrAttr>:$rootKind,
+ OptionalAttr<StrArrayAttr>:$generatedOps,
+ Confined<I16Attr, [IntNonNegative]>:$benefit);
+ let successors = (successor AnySuccessor:$dest);
+ let assemblyFormat = [{
+ $rewriter (`(` $inputs^ `:` type($inputs) `)`)? `:`
+ `benefit` `(` $benefit `)` `,`
+ (`generatedOps` `(` $generatedOps^ `)` `,`)?
+ `loc` `(` `[` $matchedOps `]` `)`
+ (`,` `root` `(` $rootKind^ `)`)? attr-dict `->` $dest
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::ReplaceOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ReplaceOp : PDLInterp_Op<"replace"> {
+ let summary = "Mark an operation as `replace`d";
+ let description = [{
+ `pdl_interp.replace` operations are used to specify that an operation
+ should be marked as replaced. The semantics of this operation correspond
+ with the `replaceOp` method on a `PatternRewriter`. The set of replacement
+ values must match the number of results specified by the operation.
+
+ Example:
+
+ ```mlir
+ // Replace root node with 2 values:
+ pdl_interp.replace %root with (%val0, %val1)
+ ```
+ }];
+ let arguments = (ins PDL_Operation:$operation,
+ Variadic<PDL_Value>:$replValues);
+ let assemblyFormat = "$operation `with` `(` $replValues `)` attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchAttributeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchAttributeOp
+ : PDLInterp_SwitchOp<"switch_attribute", [NoSideEffect]> {
+ let summary = "Switch on the value of an `Attribute`";
+ let description = [{
+ `pdl_interp.switch_attribute` operations compare the value of a given
+ attribute with a set of constant attributes. If the value matches one of the
+ provided case values the destination for that case value is taken, otherwise
+ the default destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.switch_attribute %attr to [10, true](^tenDest, ^trueDest) -> ^defaultDest
+ ```
+ }];
+ let arguments = (ins PDL_Attribute:$attribute, ArrayAttr:$caseValues);
+ let assemblyFormat = [{
+ $attribute `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+ }];
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, Value attribute,"
+ "ArrayRef<Attribute> caseValues,"
+ "Block *defaultDest, ArrayRef<Block *> dests", [{
+ build(builder, state, attribute, builder.getArrayAttr(caseValues),
+ defaultDest, dests);
+ }]>];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchOperandCountOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchOperandCountOp
+ : PDLInterp_SwitchOp<"switch_operand_count", [NoSideEffect]> {
+ let summary = "Switch on the operand count of an `Operation`";
+ let description = [{
+ `pdl_interp.switch_operand_count` operations compare the operand count of a
+ given operation with a set of potential counts. If the value matches one of
+ the provided case values the destination for that case value is taken,
+ otherwise the default destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.switch_operand_count of %op to dense<[10, 2]> : vector<2xi32>(^tenDest, ^twoDest) -> ^defaultDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation, I32ElementsAttr:$caseValues);
+ let assemblyFormat = [{
+ `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+ }];
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
+ "ArrayRef<int32_t> counts, Block *defaultDest, "
+ "ArrayRef<Block *> dests", [{
+ build(builder, state, operation, builder.getI32VectorAttr(counts),
+ defaultDest, dests);
+ }]>];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchOperationNameOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchOperationNameOp
+ : PDLInterp_SwitchOp<"switch_operation_name", [NoSideEffect]> {
+ let summary = "Switch on the OperationName of an `Operation`";
+ let description = [{
+ `pdl_interp.switch_operation_name` operations compare the name of a given
+ operation with a set of known names. If the value matches one of the
+ provided case values the destination for that case value is taken, otherwise
+ the default destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"](^fooDest, ^barDest) -> ^defaultDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation,
+ StrArrayAttr:$caseValues);
+ let assemblyFormat = [{
+ `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+ }];
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
+ "ArrayRef<OperationName> names, "
+ "Block *defaultDest, ArrayRef<Block *> dests", [{
+ auto stringNames = llvm::to_vector<8>(llvm::map_range(names,
+ [](OperationName name) { return name.getStringRef(); }));
+ build(builder, state, operation, builder.getStrArrayAttr(stringNames),
+ defaultDest, dests);
+ }]>,
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchResultCountOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchResultCountOp
+ : PDLInterp_SwitchOp<"switch_result_count", [NoSideEffect]> {
+ let summary = "Switch on the result count of an `Operation`";
+ let description = [{
+ `pdl_interp.switch_result_count` operations compare the result count of a
+ given operation with a set of potential counts. If the value matches one of
+ the provided case values the destination for that case value is taken,
+ otherwise the default destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.switch_result_count of %op to dense<[0, 2]> : vector<2xi32>(^zeroDest, ^twoDest) -> ^defaultDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Operation:$operation, I32ElementsAttr:$caseValues);
+ let assemblyFormat = [{
+ `of` $operation `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+ }];
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, Value operation, "
+ "ArrayRef<int32_t> counts, Block *defaultDest, "
+ "ArrayRef<Block *> dests", [{
+ build(builder, state, operation, builder.getI32VectorAttr(counts),
+ defaultDest, dests);
+ }]>];
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchTypeOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchTypeOp : PDLInterp_SwitchOp<"switch_type", [NoSideEffect]> {
+ let summary = "Switch on a `Type` value";
+ let description = [{
+ `pdl_interp.switch_type` operations compare a type with a set of statically
+ known types. If the value matches one of the provided case values the
+ destination for that case value is taken, otherwise the default destination
+ is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.switch_type %type to [i32, i64](^i32Dest, ^i64Dest) -> ^defaultDest
+ ```
+ }];
+
+ let arguments = (ins PDL_Type:$value, TypeArrayAttr:$caseValues);
+ let assemblyFormat = [{
+ $value `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+ }];
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state, Value value, "
+ "TypeRange types, Block *defaultDest, ArrayRef<Block *> dests", [{
+ build(builder, state, value, builder.getTypeArrayAttr(types), defaultDest,
+ dests);
+ }]>,
+ ];
+
+ let extraClassDeclaration = [{
+ auto getCaseTypes() { return caseValues().getAsValueRange<TypeAttr>(); }
+ }];
+}
+
+#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index aa8f2eafb896..d1b25cdbde80 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -217,12 +217,12 @@ class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
public:
+ /// Return a range over the elements of this array, iterated as `AttrTy`
+ /// values via `attr_value_iterator`.
template <typename AttrTy>
- llvm::iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
+ iterator_range<attr_value_iterator<AttrTy>> getAsRange() {
return llvm::make_range(attr_value_iterator<AttrTy>(begin()),
attr_value_iterator<AttrTy>(end()));
}
- template <typename AttrTy, typename UnderlyingTy>
- auto getAsRange() {
+ template <typename AttrTy, typename UnderlyingTy = typename AttrTy::ValueType>
+ auto getAsValueRange() {
return llvm::map_range(getAsRange<AttrTy>(), [](AttrTy attr) {
return static_cast<UnderlyingTy>(attr.getValue());
});
@@ -589,6 +589,9 @@ class ElementsAttr : public Attribute {
/// Returns the number of elements held by this attribute.
int64_t getNumElements() const;
+ /// Returns the number of elements held by this attribute; an alias of
+ /// `getNumElements()`.
+ int64_t size() const { return getNumElements(); }
+
/// Generates a new ElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either an integer or float.
/// This ElementsAttr should contain integers.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index c27585a6e343..aa1cc0a1a2b4 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -139,6 +139,7 @@ class Builder {
ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
+ /// Returns an ArrayAttr holding one TypeAttr per type in `values`.
+ ArrayAttr getTypeArrayAttr(TypeRange values);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5bb4effa0f5c..e0726a9901cf 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -426,6 +426,12 @@ class OpAsmParser {
return parseOptionalAttribute(result, Type(), attrName, attrs);
}
+ /// Specialized variants of `parseOptionalAttribute` that remove potential
+ /// ambiguities in syntax. In the custom-op parser implementation, the
+ /// `ArrayAttr` variant is only attempted when the next token is `[`, and on
+ /// success the attribute is also added to `attrs` under `attrName`.
+ virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
+ StringRef attrName,
+ NamedAttrList &attrs) = 0;
+
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index deb27bbce413..190486a6c040 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -25,6 +25,7 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
@@ -49,6 +50,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
scf::SCFDialect,
omp::OpenMPDialect,
pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect,
quant::QuantizationDialect,
spirv::SPIRVDialect,
StandardOpsDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 36817638b269..790264fc4787 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -6,6 +6,7 @@ add_subdirectory(LLVMIR)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
+add_subdirectory(PDLInterp)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(SDBM)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 559d41146903..c8e20ce57842 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -76,9 +76,7 @@ static LogicalResult isContraction(Operation *op) {
if (!genericOp)
return failure();
- auto mapRange =
- genericOp.indexing_maps().getAsRange<AffineMapAttr, AffineMap>();
-
+ auto mapRange = genericOp.indexing_maps().getAsValueRange<AffineMapAttr>();
return success(
genericOp.getNumInputs() == 2 && genericOp.getNumOutputs() == 1 &&
llvm::all_of(mapRange,
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index fc28e6964c72..0146f0d50b88 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -446,20 +446,39 @@ static LogicalResult verify(ReplaceOp op) {
//===----------------------------------------------------------------------===//
static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) {
- // If the first token isn't a '(', this is an external rewrite.
- StringAttr nameAttr;
- if (failed(p.parseOptionalLParen())) {
- if (p.parseAttribute(nameAttr, "name", state.attributes) || p.parseLParen())
- return failure();
- }
-
// Parse the root operand.
OpAsmParser::OperandType rootOperand;
- if (p.parseOperand(rootOperand) || p.parseRParen() ||
+ if (p.parseOperand(rootOperand) ||
p.resolveOperand(rootOperand, p.getBuilder().getType<OperationType>(),
state.operands))
return failure();
+ // Parse an external rewrite.
+ StringAttr nameAttr;
+ if (succeeded(p.parseOptionalKeyword("with"))) {
+ if (p.parseAttribute(nameAttr, "name", state.attributes))
+ return failure();
+
+ // Parse the optional set of constant parameters.
+ ArrayAttr constantParams;
+ OptionalParseResult constantParamResult = p.parseOptionalAttribute(
+ constantParams, "externalConstParams", state.attributes);
+ if (constantParamResult.hasValue() && failed(*constantParamResult))
+ return failure();
+
+ // Parse the optional additional arguments.
+ if (succeeded(p.parseOptionalLParen())) {
+ SmallVector<OpAsmParser::OperandType, 4> arguments;
+ SmallVector<Type, 4> argumentTypes;
+ llvm::SMLoc argumentLoc = p.getCurrentLocation();
+ if (p.parseOperandList(arguments) ||
+ p.parseColonTypeList(argumentTypes) || p.parseRParen() ||
+ p.resolveOperands(arguments, argumentTypes, argumentLoc,
+ state.operands))
+ return failure();
+ }
+ }
+
// If this isn't an external rewrite, parse the region body.
Region &rewriteRegion = *state.addRegion();
if (!nameAttr) {
@@ -468,27 +487,58 @@ static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) {
return failure();
RewriteOp::ensureTerminator(rewriteRegion, p.getBuilder(), state.location);
}
- return success();
+
+ return p.parseOptionalAttrDictWithKeyword(state.attributes);
}
static void print(OpAsmPrinter &p, RewriteOp op) {
- p << "pdl.rewrite";
+ p << "pdl.rewrite " << op.root();
if (Optional<StringRef> name = op.name()) {
- p << " \"" << *name << "\"(" << op.root() << ")";
- return;
+ p << " with \"" << *name << "\"";
+
+ if (ArrayAttr constantParams = op.externalConstParamsAttr())
+ p << constantParams;
+
+ OperandRange externalArgs = op.externalArgs();
+ if (!externalArgs.empty())
+ p << "(" << externalArgs << " : " << externalArgs.getTypes() << ")";
+ } else {
+ p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
}
- p << "(" << op.root() << ")";
- p.printRegion(op.body(), /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/false);
+ p.printOptionalAttrDictWithKeyword(op.getAttrs(),
+ {"name", "externalConstParams"});
}
+/// Verify a `pdl.rewrite`: an external rewrite (one with a `name`) must have
+/// an empty body region, while an inline rewrite must have a non-empty body
+/// and no external arguments or external constant parameters.
static LogicalResult verify(RewriteOp op) {
Region &rewriteRegion = op.body();
- if (llvm::hasNItemsOrMore(rewriteRegion, 2)) {
- return op.emitOpError()
- << "expected rewrite region when specified to have a single block";
+
+ // Handle the case where the rewrite is external.
+ if (op.name()) {
+ if (!rewriteRegion.empty()) {
+ return op.emitOpError()
+ << "expected rewrite region to be empty when rewrite is external";
+ }
+ return success();
+ }
+
+ // Otherwise, the rewrite is specified inline and its body must be non-empty.
+ // NOTE(review): the explicit "at most one block" check was removed here; this
+ // presumably relies on the op's region trait to enforce it -- confirm.
+ if (rewriteRegion.empty()) {
+ return op.emitOpError() << "expected rewrite region to be non-empty if "
+ "external name is not specified";
}
+
+ // Check that no additional arguments were provided.
+ if (!op.externalArgs().empty()) {
+ return op.emitOpError() << "expected no external arguments when the "
+ "rewrite is specified inline";
+ }
+ if (op.externalConstParams()) {
+ return op.emitOpError() << "expected no external constant parameters when "
+ "the rewrite is specified inline";
+ }
+
return success();
}
diff --git a/mlir/lib/Dialect/PDLInterp/CMakeLists.txt b/mlir/lib/Dialect/PDLInterp/CMakeLists.txt
new file mode 100644
index 000000000000..f33061b2d87c
--- /dev/null
+++ b/mlir/lib/Dialect/PDLInterp/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/PDLInterp/IR/CMakeLists.txt b/mlir/lib/Dialect/PDLInterp/IR/CMakeLists.txt
new file mode 100644
index 000000000000..6e0ebc66494f
--- /dev/null
+++ b/mlir/lib/Dialect/PDLInterp/IR/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRPDLInterp
+ PDLInterp.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/PDLInterp
+
+ DEPENDS
+ MLIRPDLInterpOpsIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRPDL
+ MLIRInferTypeOpInterface
+ MLIRSideEffectInterfaces
+ )
diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
new file mode 100644
index 000000000000..2119d7aeeea2
--- /dev/null
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -0,0 +1,122 @@
+//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::pdl_interp;
+
+//===----------------------------------------------------------------------===//
+// PDLInterp Dialect
+//===----------------------------------------------------------------------===//
+
+/// Register every operation of the PDL Interpreter dialect. The op list is
+/// auto-generated by TableGen into PDLInterpOps.cpp.inc (GET_OP_LIST section).
+void PDLInterpDialect::initialize() {
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateOperationOp
+//===----------------------------------------------------------------------===//
+
+/// Parse the custom form of `pdl_interp.create_operation`:
+///   attr-dict? "name" `(` values `)` (`{` "attr" `=` value (`,` ...)* `}`)?
+///   `->` (`(` `)` | type-values)
+/// Operands are resolved as !pdl.value, then !pdl.attribute, then !pdl.type,
+/// and `operand_segment_sizes` records how many of each were parsed. The op
+/// produces a single !pdl.operation result.
+static ParseResult parseCreateOperationOp(OpAsmParser &p,
+ OperationState &state) {
+ if (p.parseOptionalAttrDict(state.attributes))
+ return failure();
+ Builder &builder = p.getBuilder();
+
+ // Parse the operation name.
+ StringAttr opName;
+ if (p.parseAttribute(opName, "name", state.attributes))
+ return failure();
+
+ // Parse the operands.
+ SmallVector<OpAsmParser::OperandType, 4> operands;
+ if (p.parseLParen() || p.parseOperandList(operands) || p.parseRParen() ||
+ p.resolveOperands(operands, builder.getType<pdl::ValueType>(),
+ state.operands))
+ return failure();
+
+ // Parse the attributes.
+ SmallVector<Attribute, 4> attrNames;
+ if (succeeded(p.parseOptionalLBrace())) {
+ SmallVector<OpAsmParser::OperandType, 4> attrOps;
+ do {
+ StringAttr nameAttr;
+ OpAsmParser::OperandType operand;
+ if (p.parseAttribute(nameAttr) || p.parseEqual() ||
+ p.parseOperand(operand))
+ return failure();
+ attrNames.push_back(nameAttr);
+ attrOps.push_back(operand);
+ } while (succeeded(p.parseOptionalComma()));
+
+ if (p.parseRBrace() ||
+ p.resolveOperands(attrOps, builder.getType<pdl::AttributeType>(),
+ state.operands))
+ return failure();
+ }
+ // Attribute names are stored in `attributeNames`; their values are operands.
+ state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
+ state.addTypes(builder.getType<pdl::OperationType>());
+
+ // Parse the result types. An explicit `()` denotes no result types.
+ SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
+ if (p.parseArrow())
+ return failure();
+ if (succeeded(p.parseOptionalLParen())) {
+ if (p.parseRParen())
+ return failure();
+ } else if (p.parseOperandList(opResultTypes) ||
+ p.resolveOperands(opResultTypes, builder.getType<pdl::TypeType>(),
+ state.operands)) {
+ return failure();
+ }
+
+ // Record how the flat operand list splits into (operands, attribute values,
+ // result types).
+ int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
+ static_cast<int32_t>(attrNames.size()),
+ static_cast<int32_t>(opResultTypes.size())};
+ state.addAttribute("operand_segment_sizes",
+ builder.getI32VectorAttr(operandSegmentSizes));
+ return success();
+}
+
+/// Print a `pdl_interp.create_operation` in the custom form accepted by
+/// `parseCreateOperationOp`:
+///   attr-dict? "name"(values) ({"attr" = value, ...})? -> (type-values | ())
+static void print(OpAsmPrinter &p, CreateOperationOp op) {
+ p << "pdl_interp.create_operation ";
+ p.printOptionalAttrDict(op.getAttrs(),
+ {"attributeNames", "name", "operand_segment_sizes"});
+ // Print the name through the attribute printer so that it is escaped and
+ // round-trips through the `parseAttribute` call in the parser.
+ p << op.nameAttr() << '(' << op.operands() << ')';
+
+ // Emit the optional attributes.
+ ArrayAttr attrNames = op.attributeNames();
+ if (!attrNames.empty()) {
+ Operation::operand_range attrArgs = op.attributes();
+ p << " {";
+ interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
+ [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
+ p << '}';
+ }
+
+ // Print the result type constraints of the operation; `()` marks "none" so
+ // the parser can distinguish it from a following operand list.
+ auto types = op.types();
+ if (types.empty())
+ p << " -> ()";
+ else
+ p << " -> " << types;
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen Auto-Generated Op and Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index c45d03174db9..dfd7b10f4f0b 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -261,6 +261,12 @@ ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
return getArrayAttr(attrs);
}
+/// Build an ArrayAttr holding one TypeAttr per type in `values`, in order.
+ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
+ auto attrs = llvm::to_vector<8>(llvm::map_range(
+ values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
+ return getArrayAttr(attrs);
+}
+
ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 37ee938a4bcd..b7cae2778c10 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -221,6 +221,9 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
return result;
}
}
+/// Parse an optional array attribute. The parse is only attempted when the
+/// current token is `[`; otherwise None is returned without consuming input.
+OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute) {
+ return parseOptionalAttributeWithToken(Token::l_square, attribute);
+}
/// Attribute dictionary.
///
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 0d3c659d0e0f..32d11e571ded 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1045,15 +1045,37 @@ class CustomOpAsmParser : public OpAsmParser {
}
/// Parse an optional attribute.
- OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
- StringRef attrName,
- NamedAttrList &attrs) override {
+ /// Template utilities to simplify specifying multiple derived overloads.
+ /// Each helper forwards to the underlying parser and, only when an attribute
+ /// was present and parsed successfully, appends it to `attrs` as `attrName`.
+ template <typename AttrT>
+ OptionalParseResult
+ parseOptionalAttributeAndAddToList(AttrT &result, Type type,
+ StringRef attrName, NamedAttrList &attrs) {
OptionalParseResult parseResult =
parser.parseOptionalAttribute(result, type);
if (parseResult.hasValue() && succeeded(*parseResult))
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult;
}
+ /// Variant for underlying parse overloads that take no expected type.
+ template <typename AttrT>
+ OptionalParseResult parseOptionalAttributeAndAddToList(AttrT &result,
+ StringRef attrName,
+ NamedAttrList &attrs) {
+ OptionalParseResult parseResult = parser.parseOptionalAttribute(result);
+ if (parseResult.hasValue() && succeeded(*parseResult))
+ attrs.push_back(parser.builder.getNamedAttr(attrName, result));
+ return parseResult;
+ }
+
+ /// Parse an optional attribute of any kind, with an optional expected type.
+ OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) override {
+ return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
+ }
+ /// Parse an optional array attribute (only attempted when the next token is
+ /// `[`).
+ OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
+ StringRef attrName,
+ NamedAttrList &attrs) override {
+ return parseOptionalAttributeAndAddToList(result, attrName, attrs);
+ }
/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 3b2c6e852544..61e54be83139 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -187,6 +187,22 @@ class Parser {
/// Parse an optional attribute with the provided type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
Type type = {});
+ /// Parse an optional array attribute; returns None if the next token is not
+ /// `[`.
+ OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute);
+
+ /// Parse an optional attribute that is demarcated by a specific token.
+ /// Returns None (consuming nothing) if the current token is not `kind`.
+ /// Otherwise parses an attribute, using `type` if provided, and fails with a
+ /// diagnostic if the parsed attribute is not an `AttributeT`.
+ template <typename AttributeT>
+ OptionalParseResult parseOptionalAttributeWithToken(Token::Kind kind,
+ AttributeT &attr,
+ Type type = {}) {
+ if (getToken().isNot(kind))
+ return llvm::None;
+
+ llvm::SMLoc loc = getToken().getLoc();
+ Attribute parsedAttr = parseAttribute(type);
+ if (!parsedAttr)
+ return failure();
+
+ // Check the kind rather than using `cast`, so a mismatched attribute is
+ // reported as a parse error instead of triggering an assertion.
+ attr = parsedAttr.dyn_cast<AttributeT>();
+ if (!attr) {
+ emitError(loc, "invalid kind of attribute specified");
+ return failure();
+ }
+ return success();
+ }
/// Parse an attribute dictionary.
ParseResult parseAttributeDict(NamedAttrList &attributes);
diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir
index 7058d8bddbc4..f5c65409e671 100644
--- a/mlir/test/Dialect/PDL/invalid.mlir
+++ b/mlir/test/Dialect/PDL/invalid.mlir
@@ -9,7 +9,7 @@ pdl.pattern : benefit(1) {
// expected-error@below {{expected at least one argument}}
"pdl.apply_constraint"() {name = "foo", params = []} : () -> ()
- pdl.rewrite "rewriter"(%op)
+ pdl.rewrite %op with "rewriter"
}
// -----
@@ -25,14 +25,14 @@ pdl.pattern : benefit(1) {
%attr = pdl.attribute : %type 10
%op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type
- pdl.rewrite "rewriter"(%op)
+ pdl.rewrite %op with "rewriter"
}
// -----
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
- pdl.rewrite(%op) {
+ pdl.rewrite %op {
%type = pdl.type
// expected-error@below {{expected constant value when specified within a `pdl.rewrite`}}
@@ -44,7 +44,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
- pdl.rewrite(%op) {
+ pdl.rewrite %op {
// expected-error@below {{expected constant value when specified within a `pdl.rewrite`}}
%attr = pdl.attribute
}
@@ -57,7 +57,7 @@ pdl.pattern : benefit(1) {
%unused = pdl.attribute
%op = pdl.operation "foo.op"
- pdl.rewrite "rewriter"(%op)
+ pdl.rewrite %op with "rewriter"
}
// -----
@@ -71,7 +71,7 @@ pdl.pattern : benefit(1) {
%unused = pdl.input
%op = pdl.operation "foo.op"
- pdl.rewrite "rewriter"(%op)
+ pdl.rewrite %op with "rewriter"
}
// -----
@@ -82,7 +82,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"
- pdl.rewrite(%op) {
+ pdl.rewrite %op {
// expected-error@below {{must have an operation name when nested within a `pdl.rewrite`}}
%newOp = pdl.operation
}
@@ -96,14 +96,14 @@ pdl.pattern : benefit(1) {
attributeNames = ["attr"],
operand_segment_sizes = dense<0> : vector<3xi32>
} : () -> (!pdl.operation)
- pdl.rewrite "rewriter"(%op)
+ pdl.rewrite %op with "rewriter"
}
// -----
pdl.pattern : benefit(1) {
%op = pdl.operation "foo.op"()
- pdl.rewrite (%op) {
+ pdl.rewrite %op {
%type = pdl.type
// expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}}
@@ -119,7 +119,7 @@ pdl.pattern : benefit(1) {
%unused = pdl.operation "foo.op"
%op = pdl.operation "foo.op"
- pdl.rewrite "rewriter"(%op)
+ pdl.rewrite %op with "rewriter"
}
// -----
@@ -142,7 +142,7 @@ pdl.pattern : benefit(1) {
"foo.other_op"() : () -> ()
%root = pdl.operation "foo.op"
- pdl.rewrite "foo"(%root)
+ pdl.rewrite %root with "foo"
}
// -----
@@ -153,7 +153,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"
- pdl.rewrite (%root) {
+ pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %newResult = pdl.operation "foo.op" -> %type
@@ -167,7 +167,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root, %oldResult = pdl.operation "foo.op" -> %type
- pdl.rewrite (%root) {
+ pdl.rewrite %root {
%newOp, %newResult = pdl.operation "foo.op" -> %type
// expected-error@below {{expected no replacement values to be provided when the replacement operation is present}}
@@ -181,7 +181,7 @@ pdl.pattern : benefit(1) {
pdl.pattern : benefit(1) {
%root = pdl.operation "foo.op"
- pdl.rewrite (%root) {
+ pdl.rewrite %root {
%type = pdl.type : i32
%newOp, %newResult = pdl.operation "foo.op" -> %type
@@ -192,6 +192,55 @@ pdl.pattern : benefit(1) {
// -----
+//===----------------------------------------------------------------------===//
+// pdl::RewriteOp
+//===----------------------------------------------------------------------===//
+
+pdl.pattern : benefit(1) {
+ %op = pdl.operation "foo.op"
+
+ // expected-error@below {{expected rewrite region to be non-empty if external name is not specified}}
+ "pdl.rewrite"(%op) ({}) : (!pdl.operation) -> ()
+}
+
+// -----
+
+pdl.pattern : benefit(1) {
+ %op = pdl.operation "foo.op"
+
+ // expected-error@below {{expected no external arguments when the rewrite is specified inline}}
+ "pdl.rewrite"(%op, %op) ({
+ ^bb1:
+ pdl.rewrite_end
+ }) : (!pdl.operation, !pdl.operation) -> ()
+}
+
+// -----
+
+pdl.pattern : benefit(1) {
+ %op = pdl.operation "foo.op"
+
+ // expected-error@below {{expected no external constant parameters when the rewrite is specified inline}}
+ "pdl.rewrite"(%op) ({
+ ^bb1:
+ pdl.rewrite_end
+ }) {externalConstParams = []} : (!pdl.operation) -> ()
+}
+
+// -----
+
+pdl.pattern : benefit(1) {
+ %op = pdl.operation "foo.op"
+
+ // expected-error@below {{expected rewrite region to be empty when rewrite is external}}
+ "pdl.rewrite"(%op) ({
+ ^bb1:
+ pdl.rewrite_end
+ }) {name = "foo"} : (!pdl.operation) -> ()
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// pdl::TypeOp
//===----------------------------------------------------------------------===//
@@ -201,5 +250,5 @@ pdl.pattern : benefit(1) {
%unused = pdl.type
%op = pdl.operation "foo.op"
- pdl.rewrite "rewriter"(%op)
+ pdl.rewrite %op with "rewriter"
}
diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index f7c425f3bf29..37db36e654d4 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -1,8 +1,6 @@
// RUN: mlir-opt -split-input-file %s | mlir-opt
-// Verify the printed output can be parsed.
-// RUN: mlir-opt %s | mlir-opt
// Verify the generic form can be parsed.
-// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt
+// RUN: mlir-opt -split-input-file -mlir-print-op-generic %s | mlir-opt
// -----
@@ -15,7 +13,30 @@ pdl.pattern @operations : benefit(1) {
// Operation with input.
%input = pdl.input
%root = pdl.operation(%op0_result, %input)
- pdl.rewrite "rewriter"(%root)
+ pdl.rewrite %root with "rewriter"
+}
+
+// -----
+
+pdl.pattern @rewrite_with_args : benefit(1) {
+ %input = pdl.input
+ %root = pdl.operation(%input)
+ pdl.rewrite %root with "rewriter"(%input : !pdl.value)
+}
+
+// -----
+
+pdl.pattern @rewrite_with_params : benefit(1) {
+ %root = pdl.operation
+ pdl.rewrite %root with "rewriter"["I am param"]
+}
+
+// -----
+
+pdl.pattern @rewrite_with_args_and_params : benefit(1) {
+ %input = pdl.input
+ %root = pdl.operation(%input)
+ pdl.rewrite %root with "rewriter"["I am param"](%input : !pdl.value)
}
// -----
@@ -26,7 +47,7 @@ pdl.pattern @infer_type_from_operation_replace : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
- pdl.rewrite(%root) {
+ pdl.rewrite %root {
%type3 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
pdl.replace %root with %newOp
@@ -41,7 +62,7 @@ pdl.pattern @infer_type_from_result_replace : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
- pdl.rewrite(%root) {
+ pdl.rewrite %root {
%type3 = pdl.type
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3
pdl.replace %root with (%newResults#0, %newResults#1)
@@ -56,7 +77,7 @@ pdl.pattern @infer_type_from_type_used_in_match : benefit(1) {
%type1 = pdl.type : i32
%type2 = pdl.type
%root, %results:2 = pdl.operation -> %type1, %type2
- pdl.rewrite(%root) {
+ pdl.rewrite %root {
%newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2
}
}
diff --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir
new file mode 100644
index 000000000000..d76b17c394e8
--- /dev/null
+++ b/mlir/test/Dialect/PDLInterp/ops.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -split-input-file %s | mlir-opt
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt
+
+// -----
+
+func @operations(%attribute: !pdl.attribute,
+ %input: !pdl.value,
+ %type: !pdl.type) {
+ // attributes, operands, and results
+ %op0 = pdl_interp.create_operation "foo.op"(%input) {"attr" = %attribute} -> %type
+
+ // attributes, and results
+ %op1 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute} -> %type
+
+ // attributes
+ %op2 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute, "attr1" = %attribute} -> ()
+
+ // operands, and results
+ %op3 = pdl_interp.create_operation "foo.op"(%input) -> %type
+
+ pdl_interp.finalize
+}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 82a7312312d3..1f3ac4934224 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -226,7 +226,7 @@ bool LiteralElement::isValidLiteral(StringRef value) {
// If there is only one character, this must either be punctuation or a
// single character bare identifier.
if (value.size() == 1)
- return isalpha(front) || StringRef("_:,=<>()[]?").contains(front);
+ return isalpha(front) || StringRef("_:,=<>()[]{}?").contains(front);
// Check the punctuation that are larger than a single character.
if (value == "->")
@@ -583,6 +583,8 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
.Case("=", "Equal()")
.Case("<", "Less()")
.Case(">", "Greater()")
+ .Case("{", "LBrace()")
+ .Case("}", "RBrace()")
.Case("(", "LParen()")
.Case(")", "RParen()")
.Case("[", "LSquare()")
More information about the Mlir-commits
mailing list