[Mlir-commits] [mlir] 3a833a0 - [mlir][PDL] Add support for variadic operands and results in the PDL Interpreter
River Riddle
llvmlistbot at llvm.org
Tue Mar 16 13:20:35 PDT 2021
Author: River Riddle
Date: 2021-03-16T13:20:19-07:00
New Revision: 3a833a0e0e526d4ef3f0037eaa2ace3511f216ce
URL: https://github.com/llvm/llvm-project/commit/3a833a0e0e526d4ef3f0037eaa2ace3511f216ce
DIFF: https://github.com/llvm/llvm-project/commit/3a833a0e0e526d4ef3f0037eaa2ace3511f216ce.diff
LOG: [mlir][PDL] Add support for variadic operands and results in the PDL Interpreter
This revision extends the PDL Interpreter dialect to add support for variadic operands and results, with ranges of these values represented via the recently added !pdl.range type. To support this extension, five new operations have been added that closely match their single-value counterparts:
* pdl_interp.check_types : Compare a range of types with a known range.
* pdl_interp.create_types : Create a constant range of types.
* pdl_interp.get_operands : Get a range of operands from an operation.
* pdl_interp.get_results : Get a range of results from an operation.
* pdl_interp.switch_types : Switch on a range of types.
This revision handles adding support in the interpreter dialect and the conversion from PDL to PDLInterp. Support for variadic operands and results in the bytecode will be added in a followup revision.
Differential Revision: https://reviews.llvm.org/D95722
Added:
Modified:
mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Rewrite/ByteCode.cpp
mlir/lib/TableGen/Predicate.cpp
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
mlir/test/Dialect/PDLInterp/ops.mlir
mlir/test/Rewrite/pdl-bytecode.mlir
mlir/test/mlir-tblgen/op-attribute.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 8f8a5b130175..e35208747ade 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -168,7 +168,7 @@ def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
def PDLInterp_AreEqualOp
: PDLInterp_PredicateOp<"are_equal", [NoSideEffect, SameTypeOperands]> {
- let summary = "Check if two positional values are equivalent";
+ let summary = "Check if two positional values or ranges are equivalent";
let description = [{
`pdl_interp.are_equal` operations compare two positional values for
equality. On success, this operation branches to the true destination,
@@ -241,19 +241,29 @@ def PDLInterp_CheckOperandCountOp
let summary = "Check the number of operands of an `Operation`";
let description = [{
`pdl_interp.check_operand_count` operations compare the number of operands
- of a given operation value with a constant. On success, this operation
- branches to the true destination, otherwise the false destination is taken.
+ of a given operation value with a constant. The comparison is either exact
+ or at_least, with the latter used to compare against a minimum number of
+ expected operands. On success, this operation branches to the true
+ destination, otherwise the false destination is taken.
Example:
```mlir
+ // Check for exact equality.
pdl_interp.check_operand_count of %op is 2 -> ^matchDest, ^failureDest
+
+ // Check for at least N operands.
+ pdl_interp.check_operand_count of %op is at_least 2 -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_Operation:$operation,
- Confined<I32Attr, [IntNonNegative]>:$count);
- let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors";
+ Confined<I32Attr, [IntNonNegative]>:$count,
+ UnitAttr:$compareAtLeast);
+ let assemblyFormat = [{
+ `of` $operation `is` (`at_least` $compareAtLeast^)? $count attr-dict
+ `->` successors
+ }];
}
//===----------------------------------------------------------------------===//
@@ -288,19 +298,29 @@ def PDLInterp_CheckResultCountOp
let summary = "Check the number of results of an `Operation`";
let description = [{
`pdl_interp.check_result_count` operations compare the number of results
- of a given operation value with a constant. On success, this operation
- branches to the true destination, otherwise the false destination is taken.
+ of a given operation value with a constant. The comparison is either exact
+ or at_least, with the latter used to compare against a minimum number of
+ expected results. On success, this operation branches to the true
+ destination, otherwise the false destination is taken.
Example:
```mlir
- pdl_interp.check_result_count of %op is 0 -> ^matchDest, ^failureDest
+ // Check for exact equality.
+ pdl_interp.check_result_count of %op is 2 -> ^matchDest, ^failureDest
+
+ // Check for at least N results.
+ pdl_interp.check_result_count of %op is at_least 2 -> ^matchDest, ^failureDest
```
}];
let arguments = (ins PDL_Operation:$operation,
- Confined<I32Attr, [IntNonNegative]>:$count);
- let assemblyFormat = "`of` $operation `is` $count attr-dict `->` successors";
+ Confined<I32Attr, [IntNonNegative]>:$count,
+ UnitAttr:$compareAtLeast);
+ let assemblyFormat = [{
+ `of` $operation `is` (`at_least` $compareAtLeast^)? $count attr-dict
+ `->` successors
+ }];
}
//===----------------------------------------------------------------------===//
@@ -326,6 +346,30 @@ def PDLInterp_CheckTypeOp
let assemblyFormat = "$value `is` $type attr-dict `->` successors";
}
+//===----------------------------------------------------------------------===//
+// pdl_interp::CheckTypesOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CheckTypesOp : PDLInterp_PredicateOp<"check_types",
+ [NoSideEffect]> {
+ // Operands: the range of types being tested (`value`) and the statically
+ // known types it is compared against (`types`).
+ let arguments = (ins
+ PDL_RangeOf<PDL_Type>:$value,
+ TypeArrayAttr:$types
+ );
+ let assemblyFormat = "$value `are` $types attr-dict `->` successors";
+
+ let summary = "Compare a range of types to a range of known values";
+ let description = [{
+ `pdl_interp.check_types` operations compare a range of types with a
+ statically known range of types. On success, this operation branches
+ to the true destination, otherwise the false destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.check_types %type are [i32, i64] -> ^matchDest, ^failureDest
+ ```
+ }];
+}
+
//===----------------------------------------------------------------------===//
// pdl_interp::CreateAttributeOp
//===----------------------------------------------------------------------===//
@@ -363,21 +407,23 @@ def PDLInterp_CreateOperationOp
let summary = "Create an instance of a specific `Operation`";
let description = [{
`pdl_interp.create_operation` operations create an `Operation` instance with
- the specified attributes, operands, and result types.
+ the specified attributes, operands, and result types. See `pdl.operation`
+ for a more detailed description on the interpretation of the arguments to
+ this operation.
Example:
```mlir
// Create an instance of a `foo.op` operation.
- %op = pdl_interp.create_operation "foo.op"(%arg0) {"attrA" = %attr0} -> %type, %type
+ %op = pdl_interp.create_operation "foo.op"(%arg0 : !pdl.value) {"attrA" = %attr0} -> (%type : !pdl.type)
```
}];
let arguments = (ins StrAttr:$name,
- Variadic<PDL_Value>:$operands,
+ Variadic<PDL_InstOrRangeOf<PDL_Value>>:$operands,
Variadic<PDL_Attribute>:$attributes,
StrArrayAttr:$attributeNames,
- Variadic<PDL_Type>:$types);
+ Variadic<PDL_InstOrRangeOf<PDL_Type>>:$types);
let results = (outs PDL_Operation:$operation);
let builders = [
@@ -386,9 +432,13 @@ def PDLInterp_CreateOperationOp
"ArrayAttr":$attributeNames), [{
build($_builder, $_state, $_builder.getType<pdl::OperationType>(), name,
operands, attributes, attributeNames, types);
- }]>];
- let parser = [{ return ::parseCreateOperationOp(parser, result); }];
- let printer = [{ ::print(p, *this); }];
+ }]>
+ ];
+ let assemblyFormat = [{
+ $name (`(` $operands^ `:` type($operands) `)`)?
+ custom<CreateOperationOpAttributes>($attributes, $attributeNames)
+ (`->` `(` $types^ `:` type($types) `)`)? attr-dict
+ }];
}
//===----------------------------------------------------------------------===//
@@ -419,6 +469,28 @@ def PDLInterp_CreateTypeOp : PDLInterp_Op<"create_type", [NoSideEffect]> {
];
}
+//===----------------------------------------------------------------------===//
+// pdl_interp::CreateTypesOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_CreateTypesOp : PDLInterp_Op<"create_types", [NoSideEffect]> {
+ let summary = "Create an interpreter handle to a range of constant `Type`s";
+ let description = [{
+ `pdl_interp.create_types` operations generate a handle within the
+ interpreter for a specific range of constant type values. The result is a
+ `!pdl.range<type>` value.
+
+ Example:
+
+ ```mlir
+ %types = pdl_interp.create_types [i64, i64]
+ ```
+ }];
+
+ // The constant types, held as an array attribute of `TypeAttr`s.
+ let arguments = (ins TypeArrayAttr:$value);
+ let results = (outs PDL_RangeOf<PDL_Type>:$result);
+ let assemblyFormat = "$value attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// pdl_interp::EraseOp
//===----------------------------------------------------------------------===//
@@ -523,19 +595,20 @@ def PDLInterp_GetDefiningOpOp
let summary = "Get the defining operation of a `Value`";
let description = [{
`pdl_interp.get_defining_op` operations try to get the defining operation
- of a specific value. If the value is not an operation result, null is
- returned.
+ of a specific value or range of values. In the case of a range, the defining
+ op of the first value is returned. If the value is not an operation result
+ or range of operation results, null is returned.
Example:
```mlir
- %op = pdl_interp.get_defining_op of %value
+ %op = pdl_interp.get_defining_op of %value : !pdl.value
```
}];
- let arguments = (ins PDL_Value:$value);
+ let arguments = (ins PDL_InstOrRangeOf<PDL_Value>:$value);
let results = (outs PDL_Operation:$operation);
- let assemblyFormat = "`of` $value attr-dict";
+ let assemblyFormat = "`of` $value `:` type($value) attr-dict";
}
//===----------------------------------------------------------------------===//
@@ -562,6 +635,49 @@ def PDLInterp_GetOperandOp : PDLInterp_Op<"get_operand", [NoSideEffect]> {
let assemblyFormat = "$index `of` $operation attr-dict";
}
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetOperandsOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetOperandsOp : PDLInterp_Op<"get_operands", [NoSideEffect]> {
+ let summary = "Get a specified operand group from an `Operation`";
+ let description = [{
+ `pdl_interp.get_operands` operations try to get a specific operand
+ group from an operation. If the expected result is a single Value, null is
+ returned if the operand group is not of size 1. If a range is expected,
+ null is returned if the operand group is invalid. If no index is provided,
+ the returned operand group corresponds to all operands of the operation.
+
+ Example:
+
+ ```mlir
+ // Get the first group of operands from an operation, and expect a single
+ // element.
+ %operand = pdl_interp.get_operands 0 of %op : !pdl.value
+
+ // Get the first group of operands from an operation.
+ %operands = pdl_interp.get_operands 0 of %op : !pdl.range<value>
+
+ // Get all of the operands from an operation.
+ %operands = pdl_interp.get_operands of %op : !pdl.range<value>
+ ```
+ }];
+
+ // `index` is optional: when it is absent, all operands of `operation` are
+ // returned as one group.
+ let arguments = (ins PDL_Operation:$operation,
+ OptionalAttr<Confined<I32Attr, [IntNonNegative]>>:$index);
+ let results = (outs PDL_InstOrRangeOf<PDL_Value>:$value);
+
+ // Convenience builder taking an optional C++ index instead of an attribute;
+ // an empty optional leaves the `index` attribute unset.
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "Value":$operation,
+ "Optional<unsigned>":$index), [{
+ IntegerAttr indexAttr;
+ if (index)
+ indexAttr = $_builder.getI32IntegerAttr(*index);
+ build($_builder, $_state, resultType, operation, indexAttr);
+ }]>
+ ];
+ let assemblyFormat = "($index^)? `of` $operation `:` type($value) attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// pdl_interp::GetResultOp
//===----------------------------------------------------------------------===//
@@ -586,59 +702,117 @@ def PDLInterp_GetResultOp : PDLInterp_Op<"get_result", [NoSideEffect]> {
let assemblyFormat = "$index `of` $operation attr-dict";
}
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetResultsOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetResultsOp : PDLInterp_Op<"get_results", [NoSideEffect]> {
+ let summary = "Get a specified result group from an `Operation`";
+ let description = [{
+ `pdl_interp.get_results` operations try to get a specific result group
+ from an operation. If the expected result is a single Value, null is
+ returned if the result group is not of size 1. If a range is expected,
+ null is returned if the result group is invalid. If no index is provided,
+ the returned result group corresponds to all results of the operation.
+
+ Example:
+
+ ```mlir
+ // Get the first group of results from an operation, and expect a single
+ // element.
+ %result = pdl_interp.get_results 0 of %op : !pdl.value
+
+ // Get the first group of results from an operation.
+ %results = pdl_interp.get_results 0 of %op : !pdl.range<value>
+
+ // Get all of the results from an operation.
+ %results = pdl_interp.get_results of %op : !pdl.range<value>
+ ```
+ }];
+
+ let arguments = (ins
+ PDL_Operation:$operation,
+ OptionalAttr<Confined<I32Attr, [IntNonNegative]>>:$index
+ );
+ let results = (outs PDL_InstOrRangeOf<PDL_Value>:$value);
+ let assemblyFormat = "($index^)? `of` $operation `:` type($value) attr-dict";
+ let builders = [
+ // Build from an optional C++ index; an empty optional leaves `index` unset.
+ OpBuilder<(ins "Type":$resultType, "Value":$operation,
+ "Optional<unsigned>":$index), [{
+ build($_builder, $_state, resultType, operation,
+ index ? $_builder.getI32IntegerAttr(*index) : IntegerAttr());
+ }]>,
+ // Build an op that returns every result of `operation` as a
+ // `!pdl.range<value>` (no `index` attribute).
+ OpBuilder<(ins "Value":$operation), [{
+ build($_builder, $_state,
+ pdl::RangeType::get($_builder.getType<pdl::ValueType>()), operation,
+ IntegerAttr());
+ }]>,
+ ];
+}
+
//===----------------------------------------------------------------------===//
// pdl_interp::GetValueTypeOp
//===----------------------------------------------------------------------===//
-// Get a type from the root operation, held in the rewriter context.
-def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect]> {
+def PDLInterp_GetValueTypeOp : PDLInterp_Op<"get_value_type", [NoSideEffect,
+ TypesMatchWith<"`value` type matches arity of `result`",
+ "result", "value", "getGetValueTypeOpValueType($_self)">]> {
let summary = "Get the result type of a specified `Value`";
let description = [{
`pdl_interp.get_value_type` operations get the resulting type of a specific
- value.
+ value or range thereof.
Example:
```mlir
- %type = pdl_interp.get_value_type of %value
+ // Get the type of a single value.
+ %type = pdl_interp.get_value_type of %value : !pdl.type
+
+ // Get the types of a value range.
+ %type = pdl_interp.get_value_type of %values : !pdl.range<type>
```
}];
- let arguments = (ins PDL_Value:$value);
- let results = (outs PDL_Type:$result);
- let assemblyFormat = "`of` $value attr-dict";
+ let arguments = (ins PDL_InstOrRangeOf<PDL_Value>:$value);
+ let results = (outs PDL_InstOrRangeOf<PDL_Type>:$result);
+ let assemblyFormat = "`of` $value `:` type($result) attr-dict";
let builders = [
OpBuilder<(ins "Value":$value), [{
- build($_builder, $_state, $_builder.getType<pdl::TypeType>(), value);
+ Type valType = value.getType();
+ Type typeType = $_builder.getType<pdl::TypeType>();
+ build($_builder, $_state,
+ valType.isa<pdl::RangeType>() ? pdl::RangeType::get(typeType)
+ : typeType,
+ value);
}]>
];
}
//===----------------------------------------------------------------------===//
-// pdl_interp::InferredTypeOp
+// pdl_interp::InferredTypesOp
//===----------------------------------------------------------------------===//
-def PDLInterp_InferredTypeOp : PDLInterp_Op<"inferred_type"> {
- let summary = "Generate a handle to a Type that is \"inferred\"";
+def PDLInterp_InferredTypesOp : PDLInterp_Op<"inferred_types"> {
+ let summary = "Generate a handle to a range of Types that are \"inferred\"";
let description = [{
- `pdl_interp.inferred_type` operations generate a handle to a type that
- should be inferred. This signals to other operations, such as
- `pdl_interp.create_operation`, that this type should be inferred.
+ `pdl_interp.inferred_types` operations generate handles to ranges of types
+ that should be inferred. This signals to other operations, such as
+ `pdl_interp.create_operation`, that these types should be inferred.
Example:
```mlir
- pdl_interp.inferred_type
+ %types = pdl_interp.inferred_types
```
}];
- let results = (outs PDL_Type:$type);
+ let results = (outs PDL_RangeOf<PDL_Type>:$type);
let assemblyFormat = "attr-dict";
-
let builders = [
OpBuilder<(ins), [{
- build($_builder, $_state, $_builder.getType<pdl::TypeType>());
- }]>,
+ build($_builder, $_state,
+ pdl::RangeType::get($_builder.getType<pdl::TypeType>()));
+ }]>
];
}
@@ -650,7 +824,8 @@ def PDLInterp_IsNotNullOp
: PDLInterp_PredicateOp<"is_not_null", [NoSideEffect]> {
let summary = "Check if a positional value is non-null";
let description = [{
- `pdl_interp.is_not_null` operations check that a positional value exists. On
+ `pdl_interp.is_not_null` operations check that a positional value or range
+ exists. For ranges, null is distinct from empty: an empty range still exists. On
success, this operation branches to the true destination. Otherwise, the
false destination is taken.
@@ -718,12 +893,15 @@ def PDLInterp_ReplaceOp : PDLInterp_Op<"replace"> {
```mlir
// Replace root node with 2 values:
- pdl_interp.replace %root with (%val0, %val1)
+ pdl_interp.replace %root with (%val0, %val1 : !pdl.value, !pdl.value)
```
}];
let arguments = (ins PDL_Operation:$operation,
- Variadic<PDL_Value>:$replValues);
- let assemblyFormat = "$operation `with` `(` $replValues `)` attr-dict";
+ Variadic<PDL_InstOrRangeOf<PDL_Value>>:$replValues);
+ let assemblyFormat = [{
+ $operation `with` ` ` `(` ($replValues^ `:` type($replValues))? `)`
+ attr-dict
+ }];
}
//===----------------------------------------------------------------------===//
@@ -886,9 +1064,9 @@ def PDLInterp_SwitchTypeOp : PDLInterp_SwitchOp<"switch_type", [NoSideEffect]> {
}];
let builders = [
- OpBuilder<(ins "Value":$edge, "TypeRange":$types, "Block *":$defaultDest,
- "BlockRange":$dests), [{
- build($_builder, $_state, edge, $_builder.getTypeArrayAttr(types),
+ OpBuilder<(ins "Value":$edge, "ArrayRef<Attribute>":$types,
+ "Block *":$defaultDest, "BlockRange":$dests), [{
+ build($_builder, $_state, edge, $_builder.getArrayAttr(types),
defaultDest, dests);
}]>,
];
@@ -898,4 +1076,45 @@ def PDLInterp_SwitchTypeOp : PDLInterp_SwitchOp<"switch_type", [NoSideEffect]> {
}];
}
+//===----------------------------------------------------------------------===//
+// pdl_interp::SwitchTypesOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_SwitchTypesOp : PDLInterp_SwitchOp<"switch_types",
+ [NoSideEffect]> {
+ let summary = "Switch on a range of `Type` values";
+ let description = [{
+ `pdl_interp.switch_types` operations compare a range of types with a set of
+ statically known ranges. If the value matches one of the provided case
+ values the destination for that case value is taken, otherwise the default
+ destination is taken.
+
+ Example:
+
+ ```mlir
+ pdl_interp.switch_types %type to [[i32], [i64, i64]](^i32Dest, ^i64Dest) -> ^defaultDest
+ ```
+ }];
+
+ // Each element of `caseValues` is itself an array of types, i.e. one
+ // complete range to compare `value` against.
+ let arguments = (ins
+ PDL_RangeOf<PDL_Type>:$value,
+ TypedArrayAttrBase<TypeArrayAttr, "type-array array attribute">:$caseValues
+ );
+ let assemblyFormat = [{
+ $value `to` $caseValues `(` $cases `)` attr-dict `->` $defaultDest
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Value":$edge, "ArrayRef<Attribute>":$types,
+ "Block *":$defaultDest, "BlockRange":$dests), [{
+ build($_builder, $_state, edge, $_builder.getArrayAttr(types),
+ defaultDest, dests);
+ }]>,
+ ];
+
+ let extraClassDeclaration = [{
+ // Returns each case value as the ArrayAttr of types it matches.
+ auto getCaseTypes() { return caseValues().getAsRange<ArrayAttr>(); }
+ }];
+}
+
#endif // MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 70a5236d885f..5a7037af63d2 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1439,7 +1439,7 @@ class TypedArrayAttrBase<Attr element, string summary>: ArrayAttrBase<
CPred<"$_self.isa<::mlir::ArrayAttr>()">,
// Guarantee all elements satisfy the constraints from `element`
Concat<"::llvm::all_of($_self.cast<::mlir::ArrayAttr>(), "
- "[](::mlir::Attribute attr) { return ",
+ "[&](::mlir::Attribute attr) { return ",
SubstLeaves<"$_self", "attr", element.predicate>,
"; })">]>,
summary> {
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index d1da22671d95..57a0885c03c8 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -56,9 +56,8 @@ struct PatternLowering {
/// Create an interpreter switch predicate operation, with a provided default
/// and several case destinations.
- void generateSwitch(Block *currentBlock, Qualifier *question, Value val,
- Block *defaultDest,
- ArrayRef<std::pair<Qualifier *, Block *>> dests);
+ void generateSwitch(SwitchNode *switchNode, Block *currentBlock,
+ Qualifier *question, Value val, Block *defaultDest);
/// Create the interpreter operations to record a successful pattern match.
void generateRecordMatch(Block *currentBlock, Block *nextBlock,
@@ -88,9 +87,15 @@ struct PatternLowering {
void generateRewriter(pdl::ResultOp resultOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
+ void generateRewriter(pdl::ResultsOp resultOp,
+ DenseMap<Value, Value> &rewriteValues,
+ function_ref<Value(Value)> mapRewriteValue);
void generateRewriter(pdl::TypeOp typeOp,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue);
+ void generateRewriter(pdl::TypesOp typeOp,
+ DenseMap<Value, Value> &rewriteValues,
+ function_ref<Value(Value)> mapRewriteValue);
/// Generate the values used for resolving the result types of an operation
/// created within a dag rewriter region.
@@ -200,12 +205,7 @@ Block *PatternLowering::generateMatcher(MatcherNode &node) {
// Generate code for a switch node.
} else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
- // Collect the next blocks for all of the children and generate a switch.
- llvm::MapVector<Qualifier *, Block *> children;
- for (auto &it : switchNode->getChildren())
- children.insert({it.first, generateMatcher(*it.second)});
- generateSwitch(block, node.getQuestion(), val, nextBlock,
- children.takeVector());
+ generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock);
// Generate code for a success node.
} else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
@@ -242,6 +242,14 @@ Value PatternLowering::getValueAt(Block *cur, Position *pos) {
operandPos->getOperandNumber());
break;
}
+ case Predicates::OperandGroupPos: {
+ auto *operandPos = cast<OperandGroupPosition>(pos);
+ Type valueTy = builder.getType<pdl::ValueType>();
+ value = builder.create<pdl_interp::GetOperandsOp>(
+ loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
+ parentVal, operandPos->getOperandGroupNumber());
+ break;
+ }
case Predicates::AttributePos: {
auto *attrPos = cast<AttributePosition>(pos);
value = builder.create<pdl_interp::GetAttributeOp>(
@@ -250,10 +258,10 @@ Value PatternLowering::getValueAt(Block *cur, Position *pos) {
break;
}
case Predicates::TypePos: {
- if (parentVal.getType().isa<pdl::ValueType>())
- value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
- else
+ if (parentVal.getType().isa<pdl::AttributeType>())
value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
+ else
+ value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
break;
}
case Predicates::ResultPos: {
@@ -263,6 +271,14 @@ Value PatternLowering::getValueAt(Block *cur, Position *pos) {
resPos->getResultNumber());
break;
}
+ case Predicates::ResultGroupPos: {
+ auto *resPos = cast<ResultGroupPosition>(pos);
+ Type valueTy = builder.getType<pdl::ValueType>();
+ value = builder.create<pdl_interp::GetResultsOp>(
+ loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
+ parentVal, resPos->getResultGroupNumber());
+ break;
+ }
default:
llvm_unreachable("Generating unknown Position getter");
break;
@@ -277,7 +293,8 @@ void PatternLowering::generatePredicate(Block *currentBlock,
Block *falseDest) {
builder.setInsertionPointToEnd(currentBlock);
Location loc = val.getLoc();
- switch (question->getKind()) {
+ Predicates::Kind kind = question->getKind();
+ switch (kind) {
case Predicates::IsNotNullQuestion:
builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
break;
@@ -289,8 +306,12 @@ void PatternLowering::generatePredicate(Block *currentBlock,
}
case Predicates::TypeQuestion: {
auto *ans = cast<TypeAnswer>(answer);
- builder.create<pdl_interp::CheckTypeOp>(
- loc, val, TypeAttr::get(ans->getValue()), trueDest, falseDest);
+ if (val.getType().isa<pdl::RangeType>())
+ builder.create<pdl_interp::CheckTypesOp>(
+ loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest);
+ else
+ builder.create<pdl_interp::CheckTypeOp>(
+ loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest);
break;
}
case Predicates::AttributeQuestion: {
@@ -299,18 +320,20 @@ void PatternLowering::generatePredicate(Block *currentBlock,
trueDest, falseDest);
break;
}
- case Predicates::OperandCountQuestion: {
- auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
+ case Predicates::OperandCountAtLeastQuestion:
+ case Predicates::OperandCountQuestion:
builder.create<pdl_interp::CheckOperandCountOp>(
- loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
+ loc, val, cast<UnsignedAnswer>(answer)->getValue(),
+ /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
+ trueDest, falseDest);
break;
- }
- case Predicates::ResultCountQuestion: {
- auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
+ case Predicates::ResultCountAtLeastQuestion:
+ case Predicates::ResultCountQuestion:
builder.create<pdl_interp::CheckResultCountOp>(
- loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
+ loc, val, cast<UnsignedAnswer>(answer)->getValue(),
+ /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
+ trueDest, falseDest);
break;
- }
case Predicates::EqualToQuestion: {
auto *equalToQuestion = cast<EqualToQuestion>(question);
builder.create<pdl_interp::AreEqualOp>(
@@ -336,7 +359,7 @@ void PatternLowering::generatePredicate(Block *currentBlock,
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
- ArrayRef<std::pair<Qualifier *, Block *>> dests) {
+ llvm::MapVector<Qualifier *, Block *> &dests) {
std::vector<ValT> values;
std::vector<Block *> blocks;
values.reserve(dests.size());
@@ -348,27 +371,83 @@ static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
}
-void PatternLowering::generateSwitch(
- Block *currentBlock, Qualifier *question, Value val, Block *defaultDest,
- ArrayRef<std::pair<Qualifier *, Block *>> dests) {
+void PatternLowering::generateSwitch(SwitchNode *switchNode,
+ Block *currentBlock, Qualifier *question,
+ Value val, Block *defaultDest) {
+ // If the switch question is not an exact answer, i.e. for the `at_least`
+ // cases, we generate a special block sequence.
+ Predicates::Kind kind = question->getKind();
+ if (kind == Predicates::OperandCountAtLeastQuestion ||
+ kind == Predicates::ResultCountAtLeastQuestion) {
+ // Order the children such that the cases are in reverse numerical order.
+ SmallVector<unsigned> sortedChildren(
+ llvm::seq<unsigned>(0, switchNode->getChildren().size()));
+ llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
+ return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
+ cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
+ });
+
+ // Build the destination for each child using the next highest child as a
+ // failure destination. This essentially creates the following control
+ // flow:
+ //
+ // if (operand_count < 1)
+ // goto failure
+ // if (child1.match())
+ // ...
+ //
+ // if (operand_count < 2)
+ // goto failure
+ // if (child2.match())
+ // ...
+ //
+ // failure:
+ // ...
+ //
+ failureBlockStack.push_back(defaultDest);
+ for (unsigned idx : sortedChildren) {
+ auto &child = switchNode->getChild(idx);
+ Block *childBlock = generateMatcher(*child.second);
+ Block *predicateBlock = builder.createBlock(childBlock);
+ generatePredicate(predicateBlock, question, child.first, val, childBlock,
+ defaultDest);
+ failureBlockStack.back() = predicateBlock;
+ }
+ Block *firstPredicateBlock = failureBlockStack.pop_back_val();
+ currentBlock->getOperations().splice(currentBlock->end(),
+ firstPredicateBlock->getOperations());
+ firstPredicateBlock->erase();
+ return;
+ }
+
+ // Otherwise, generate each of the children and generate an interpreter
+ // switch.
+ llvm::MapVector<Qualifier *, Block *> children;
+ for (auto &it : switchNode->getChildren())
+ children.insert({it.first, generateMatcher(*it.second)});
builder.setInsertionPointToEnd(currentBlock);
+
switch (question->getKind()) {
case Predicates::OperandCountQuestion:
return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
- int32_t>(val, defaultDest, builder, dests);
+ int32_t>(val, defaultDest, builder, children);
case Predicates::ResultCountQuestion:
return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
- int32_t>(val, defaultDest, builder, dests);
+ int32_t>(val, defaultDest, builder, children);
case Predicates::OperationNameQuestion:
return createSwitchOp<pdl_interp::SwitchOperationNameOp,
OperationNameAnswer>(val, defaultDest, builder,
- dests);
+ children);
case Predicates::TypeQuestion:
+ if (val.getType().isa<pdl::RangeType>()) {
+ return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
+ val, defaultDest, builder, children);
+ }
return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
- val, defaultDest, builder, dests);
+ val, defaultDest, builder, children);
case Predicates::AttributeQuestion:
return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
- val, defaultDest, builder, dests);
+ val, defaultDest, builder, children);
default:
llvm_unreachable("Generating unknown switch predicate.");
}
@@ -436,6 +515,11 @@ SymbolRefAttr PatternLowering::generateRewriter(
return newValue = builder.create<pdl_interp::CreateTypeOp>(
typeOp.getLoc(), type);
}
+ } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
+ if (ArrayAttr type = typeOp.typesAttr()) {
+ return newValue = builder.create<pdl_interp::CreateTypesOp>(
+ typeOp.getLoc(), typeOp.getType(), type);
+ }
}
// Otherwise, add this as an input to the rewriter.
@@ -460,10 +544,10 @@ SymbolRefAttr PatternLowering::generateRewriter(
for (Operation &rewriteOp : *rewriter.getBody()) {
llvm::TypeSwitch<Operation *>(&rewriteOp)
.Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
- pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::TypeOp>(
- [&](auto op) {
- this->generateRewriter(op, rewriteValues, mapRewriteValue);
- });
+ pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
+ pdl::TypeOp, pdl::TypesOp>([&](auto op) {
+ this->generateRewriter(op, rewriteValues, mapRewriteValue);
+ });
}
}
@@ -529,14 +613,39 @@ void PatternLowering::generateRewriter(
rewriteValues[operationOp.op()] = createdOp;
// Generate accesses for any results that have their types constrained.
- for (auto it : llvm::enumerate(operationOp.types())) {
+ // Handle the case where there is a single range representing all of the
+ // result types.
+ OperandRange resultTys = operationOp.types();
+ if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
+ Value &type = rewriteValues[resultTys[0]];
+ if (!type) {
+ auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
+ type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
+ }
+ return;
+ }
+
+ // Otherwise, populate the individual results.
+ bool seenVariableLength = false;
+ Type valueTy = builder.getType<pdl::ValueType>();
+ Type valueRangeTy = pdl::RangeType::get(valueTy);
+ for (auto it : llvm::enumerate(resultTys)) {
Value &type = rewriteValues[it.value()];
if (type)
continue;
-
- Value getResultVal = builder.create<pdl_interp::GetResultOp>(
- loc, builder.getType<pdl::ValueType>(), createdOp, it.index());
- type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
+ bool isVariadic = it.value().getType().isa<pdl::RangeType>();
+ seenVariableLength |= isVariadic;
+
+ // After a variable length result has been seen, we need to use result
+ // groups because the exact index of the result is not statically known.
+ Value resultVal;
+ if (seenVariableLength)
+ resultVal = builder.create<pdl_interp::GetResultsOp>(
+ loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
+ else
+ resultVal = builder.create<pdl_interp::GetResultOp>(
+ loc, valueTy, createdOp, it.index());
+ type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
}
}
@@ -549,11 +658,12 @@ void PatternLowering::generateRewriter(
// for using an operation for simplicitly, but the interpreter isn't as
// user facing.
if (Value replOp = replaceOp.replOperation()) {
- pdl::OperationOp op = cast<pdl::OperationOp>(replOp.getDefiningOp());
- for (unsigned i = 0, e = op.types().size(); i < e; ++i)
- replOperands.push_back(builder.create<pdl_interp::GetResultOp>(
- replOp.getLoc(), builder.getType<pdl::ValueType>(),
- mapRewriteValue(replOp), i));
+ // Don't use replace if we know the replaced operation has no results.
+ auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
+ if (!opOp || !opOp.types().empty()) {
+ replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
+ replOp.getLoc(), mapRewriteValue(replOp)));
+ }
} else {
for (Value operand : replaceOp.replValues())
replOperands.push_back(mapRewriteValue(operand));
@@ -578,15 +688,33 @@ void PatternLowering::generateRewriter(
mapRewriteValue(resultOp.parent()), resultOp.index());
}
+/// Lower a `pdl.results` op inside the rewriter. This emits a
+/// `pdl_interp.get_results` on the parent operation (mapped through
+/// `mapRewriteValue`), keeping the op's result type and its optional group
+/// index. The new value is recorded as the rewrite value of `resultOp`.
+void PatternLowering::generateRewriter(
+ pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
+ function_ref<Value(Value)> mapRewriteValue) {
+ rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
+ resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
+ resultOp.index());
+}
+
+/// Lower a `pdl.type` op inside the rewriter. Only a constant type is
+/// materialized here, as a `pdl_interp.create_type`. A non-constant type gets
+/// no rewrite value from this function.
void PatternLowering::generateRewriter(
pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
// If the type isn't constant, the users (e.g. OperationOp) will resolve this
// type.
if (TypeAttr typeAttr = typeOp.typeAttr()) {
- Value newType =
+ rewriteValues[typeOp] =
builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
- rewriteValues[typeOp] = newType;
+ }
+}
+
+/// Lower a `pdl.types` op inside the rewriter. Only a constant range of types
+/// is materialized here, as a `pdl_interp.create_types` that keeps the op's
+/// result type. A non-constant range gets no rewrite value from this function.
+void PatternLowering::generateRewriter(
+ pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
+ function_ref<Value(Value)> mapRewriteValue) {
+ // If the types aren't constant, the users (e.g. OperationOp) will resolve
+ // these types.
+ if (ArrayAttr typeAttr = typeOp.typesAttr()) {
+ rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
+ typeOp.getLoc(), typeOp.getType(), typeAttr);
}
}
@@ -594,28 +722,38 @@ void PatternLowering::generateOperationResultTypeRewriter(
pdl::OperationOp op, SmallVectorImpl<Value> &types,
DenseMap<Value, Value> &rewriteValues,
function_ref<Value(Value)> mapRewriteValue) {
- // Functor that returns if the given use can be used to infer a type.
+ // Look for an operation that was replaced by `op`. The result types will be
+ // inferred from the results that were replaced.
Block *rewriterBlock = op->getBlock();
- auto getReplacedOperationFrom = [&](OpOperand &use) -> Operation * {
+ Value replacedOp;
+ for (OpOperand &use : op.op().getUses()) {
// Check that the use corresponds to a ReplaceOp and that it is the
// replacement value, not the operation being replaced.
pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
if (!replOpUser || use.getOperandNumber() == 0)
- return nullptr;
+ continue;
// Make sure the replaced operation was defined before this one.
- Operation *replacedOp = replOpUser.operation().getDefiningOp();
- if (replacedOp->getBlock() != rewriterBlock ||
- replacedOp->isBeforeInBlock(op))
- return replacedOp;
- return nullptr;
- };
+ Value replOpVal = replOpUser.operation();
+ Operation *replacedOp = replOpVal.getDefiningOp();
+ if (replacedOp->getBlock() == rewriterBlock &&
+ !replacedOp->isBeforeInBlock(op))
+ continue;
+
+ Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
+ replacedOp->getLoc(), mapRewriteValue(replOpVal));
+ types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
+ replacedOp->getLoc(), replacedOpResults));
+ return;
+ }
+
+ // Check if the operation has type inference support.
+ if (op.hasTypeInference()) {
+ types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
+ return;
+ }
- // If non-None/non-Null, this is an operation that is replaced by `op`.
- // If Null, there is no full replacement operation for `op`.
- // If None, a replacement operation hasn't been searched for.
- Optional<Operation *> fullReplacedOperation;
- bool hasTypeInference = op.hasTypeInference();
- auto resultTypeValues = op.types();
+ // Otherwise, handle inference for each of the result types individually.
+ OperandRange resultTypeValues = op.types();
types.reserve(resultTypeValues.size());
for (auto it : llvm::enumerate(resultTypeValues)) {
Value resultType = it.value();
@@ -632,30 +770,11 @@ void PatternLowering::generateOperationResultTypeRewriter(
continue;
}
- // Check if the operation has type inference support.
- if (hasTypeInference) {
- types.push_back(builder.create<pdl_interp::InferredTypeOp>(op.getLoc()));
- continue;
- }
-
- // Look for an operation that was replaced by `op`. The result type will be
- // inferred from the result that was replaced. There is guaranteed to be a
- // replacement for either the op, or this specific result. Note that this is
- // guaranteed by the verifier of `pdl::OperationOp`.
- Operation *replacedOp = nullptr;
- if (!fullReplacedOperation.hasValue()) {
- for (OpOperand &use : op.op().getUses())
- if ((replacedOp = getReplacedOperationFrom(use)))
- break;
- fullReplacedOperation = replacedOp;
- assert(fullReplacedOperation &&
- "expected replaced op to infer a result type from");
- } else {
- replacedOp = fullReplacedOperation.getValue();
- }
-
- auto replOpOp = cast<pdl::OperationOp>(replacedOp);
- types.push_back(mapRewriteValue(replOpOp.types()[it.index()]));
+ // The verifier asserts that the result types of each pdl.operation can be
+ // inferred. If we reach here, there is a bug either in the logic above or
+ // in the verifier for pdl.operation.
+ op->emitOpError() << "unable to infer result type for operation";
+ llvm_unreachable("unable to infer result type for operation");
}
}
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
index 3eaeb13cffc0..8983ecb8d324 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.cpp
@@ -17,6 +17,13 @@ using namespace mlir::pdl_to_pdl_interp;
Position::~Position() {}
+/// Returns the depth of the first operation position on the parent chain,
+/// starting with this position itself. An operation position reports its own
+/// depth; any other kind of position defers to its parent.
+unsigned Position::getOperationDepth() const {
+ if (const auto *operationPos = dyn_cast<OperationPosition>(this))
+ return operationPos->getDepth();
+ // NOTE(review): assumes every non-operation position has a non-null parent.
+ // The constructors visible in this patch set one; confirm for all kinds.
+ return parent->getOperationDepth();
+}
+
//===----------------------------------------------------------------------===//
// AttributePosition
@@ -32,18 +39,8 @@ OperandPosition::OperandPosition(const KeyTy &key) : Base(key) {
}
//===----------------------------------------------------------------------===//
-// OperationPosition
-
-OperationPosition *OperationPosition::get(StorageUniquer &uniquer,
- ArrayRef<unsigned> index) {
- assert(!index.empty() && "expected at least two indices");
-
- // Set the parent position if this isn't the root.
- Position *parent = nullptr;
- if (index.size() > 1) {
- auto *node = OperationPosition::get(uniquer, index.drop_back());
- parent = OperandPosition::get(uniquer, std::make_pair(node, index.back()));
- }
- return uniquer.get<OperationPosition>(
- [parent](OperationPosition *node) { node->parent = parent; }, index);
+// OperandGroupPosition
+
+/// The parent of an operand group is the operation position stored as the
+/// first element of the key.
+OperandGroupPosition::OperandGroupPosition(const KeyTy &key) : Base(key) {
+ parent = std::get<0>(key);
}
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 4d5c909465da..1c8fece05e07 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -45,8 +45,10 @@ enum Kind : unsigned {
/// Positions, ordered by decreasing priority.
OperationPos,
OperandPos,
+ OperandGroupPos,
AttributePos,
ResultPos,
+ ResultGroupPos,
TypePos,
// Questions, ordered by dependency and decreasing priority.
@@ -54,7 +56,9 @@ enum Kind : unsigned {
OperationNameQuestion,
TypeQuestion,
AttributeQuestion,
+ OperandCountAtLeastQuestion,
OperandCountQuestion,
+ ResultCountAtLeastQuestion,
ResultCountQuestion,
EqualToQuestion,
ConstraintQuestion,
@@ -129,21 +133,15 @@ struct OperationPosition;
/// predicates, and assists generating bytecode and memory management.
///
/// Operation positions form the base of other positions, which are formed
-/// relative to a parent operation, e.g. OperandPosition<[0] -> 1>. Operations
-/// are indexed by child index: [0, 1, 2] refers to the 3rd child of the 2nd
-/// child of the root operation.
-///
-/// Positions are linked to their parent position, which describes how to obtain
-/// a positional value. As a concrete example, getting OperationPosition<[0, 1]>
-/// would be `root->getOperand(1)->getDefiningOp()`, so its parent is
-/// OperandPosition<[0] -> 1>, whose parent is OperationPosition<[0]>.
+/// relative to a parent operation. Operations are anchored at Operand nodes,
+/// except for the root operation which is parentless.
class Position : public StorageUniquer::BaseStorage {
public:
explicit Position(Predicates::Kind kind) : kind(kind) {}
virtual ~Position();
- /// Returns the base node position. This is an array of indices.
- virtual ArrayRef<unsigned> getIndex() const = 0;
+ /// Returns the depth of the first ancestor operation position.
+ unsigned getOperationDepth() const;
/// Returns the parent position. The root operation position has no parent.
Position *getParent() const { return parent; }
@@ -170,9 +168,6 @@ struct AttributePosition
Predicates::AttributePos> {
explicit AttributePosition(const KeyTy &key);
- /// Returns the index of this position.
- ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }
-
/// Returns the attribute name of this position.
Identifier getName() const { return key.second; }
};
@@ -187,42 +182,61 @@ struct OperandPosition
Predicates::OperandPos> {
explicit OperandPosition(const KeyTy &key);
- /// Returns the index of this position.
- ArrayRef<unsigned> getIndex() const final { return parent->getIndex(); }
-
/// Returns the operand number of this position.
unsigned getOperandNumber() const { return key.second; }
};
+//===----------------------------------------------------------------------===//
+// OperandGroupPosition
+
+/// A position describing an operand group of an operation. The key is
+/// (parent operation position, group number, isVariadic).
+struct OperandGroupPosition
+ : public PredicateBase<
+ OperandGroupPosition, Position,
+ std::tuple<OperationPosition *, Optional<unsigned>, bool>,
+ Predicates::OperandGroupPos> {
+ explicit OperandGroupPosition(const KeyTy &key) ;
+
+ /// Returns a hash suitable for the given keytype.
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ return llvm::hash_value(key);
+ }
+
+ /// Returns the group number of this position. If None, this group refers to
+ /// all operands.
+ Optional<unsigned> getOperandGroupNumber() const { return std::get<1>(key); }
+
+ /// Returns if the operand group has unknown size. If false, the operand group
+ /// has at most one element.
+ bool isVariadic() const { return std::get<2>(key); }
+};
+
//===----------------------------------------------------------------------===//
// OperationPosition
/// An operation position describes an operation node in the IR. Other position
/// kinds are formed with respect to an operation position.
-struct OperationPosition
- : public PredicateBase<OperationPosition, Position, ArrayRef<unsigned>,
- Predicates::OperationPos> {
- using Base::Base;
+struct OperationPosition : public PredicateBase<OperationPosition, Position,
+ std::pair<Position *, unsigned>,
+ Predicates::OperationPos> {
+ // The key is (parent position, depth). The root has a null parent and
+ // depth 0. Any other operation position is anchored at an operand (or
+ // operand group) position of its user.
+ explicit OperationPosition(const KeyTy &key) : Base(key) {
+ parent = key.first;
+ }
- /// Gets the root position, which is always [0].
+ /// Gets the root position: no parent, depth 0.
static OperationPosition *getRoot(StorageUniquer &uniquer) {
- return get(uniquer, ArrayRef<unsigned>(0));
+ return Base::get(uniquer, nullptr, 0);
}
- /// Gets a node position for the given index.
- static OperationPosition *get(StorageUniquer &uniquer,
- ArrayRef<unsigned> index);
-
- /// Constructs an instance with the given storage allocator.
- static OperationPosition *construct(StorageUniquer::StorageAllocator &alloc,
- ArrayRef<unsigned> key) {
- return Base::construct(alloc, alloc.copyInto(key));
+ /// Gets an operation position anchored at `parent`. Its depth is one more
+ /// than the operation depth of `parent`.
+ static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
+ return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
}
- /// Returns the index of this position.
- ArrayRef<unsigned> getIndex() const final { return key; }
+ /// Returns the depth of this position; the root has depth 0.
+ unsigned getDepth() const { return key.second; }
/// Returns if this operation position corresponds to the root.
- bool isRoot() const { return key.size() == 1 && key[0] == 0; }
+ bool isRoot() const { return getDepth() == 0; }
};
//===----------------------------------------------------------------------===//
@@ -235,13 +249,37 @@ struct ResultPosition
Predicates::ResultPos> {
explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }
- /// Returns the index of this position.
- ArrayRef<unsigned> getIndex() const final { return key.first->getIndex(); }
-
/// Returns the result number of this position.
unsigned getResultNumber() const { return key.second; }
};
+//===----------------------------------------------------------------------===//
+// ResultGroupPosition
+
+/// A position describing a result group of an operation. The key is
+/// (parent operation position, group number, isVariadic).
+struct ResultGroupPosition
+ : public PredicateBase<
+ ResultGroupPosition, Position,
+ std::tuple<OperationPosition *, Optional<unsigned>, bool>,
+ Predicates::ResultGroupPos> {
+ explicit ResultGroupPosition(const KeyTy &key) : Base(key) {
+ parent = std::get<0>(key);
+ }
+
+ /// Returns a hash suitable for the given keytype.
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ return llvm::hash_value(key);
+ }
+
+ /// Returns the group number of this position. If None, this group refers to
+ /// all results.
+ Optional<unsigned> getResultGroupNumber() const { return std::get<1>(key); }
+
+ /// Returns if the result group has unknown size. If false, the result group
+ /// has at most one element.
+ bool isVariadic() const { return std::get<2>(key); }
+};
+
//===----------------------------------------------------------------------===//
// TypePosition
@@ -250,14 +288,11 @@ struct ResultPosition
struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
Predicates::TypePos> {
explicit TypePosition(const KeyTy &key) : Base(key) {
+ // The parent may be a single attribute, operand or result, or an operand or
+ // result group. For a group, this position refers to the type(s) of that
+ // whole group.
- assert((isa<AttributePosition>(key) || isa<OperandPosition>(key) ||
- isa<ResultPosition>(key)) &&
+ assert((isa<AttributePosition, OperandPosition, OperandGroupPosition,
+ ResultPosition, ResultGroupPosition>(key)) &&
"expected parent to be an attribute, operand, or result");
parent = key;
}
-
- /// Returns the index of this position.
- ArrayRef<unsigned> getIndex() const final { return key->getIndex(); }
};
//===----------------------------------------------------------------------===//
@@ -311,8 +346,9 @@ struct TrueAnswer
using Base::Base;
};
-/// An Answer representing a `Type` value.
-struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Type,
+/// An Answer representing a `Type` value or a range of types. The value is
+/// stored as either a TypeAttr (a single type) or an ArrayAttr of TypeAttr (a
+/// range of types).
+struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
Predicates::TypeAnswer> {
using Base::Base;
};
@@ -365,6 +401,9 @@ struct IsNotNullQuestion
struct OperandCountQuestion
: public PredicateBase<OperandCountQuestion, Qualifier, void,
Predicates::OperandCountQuestion> {};
+/// Compare the number of operands of an operation with a known minimum, i.e.
+/// check that the operation has at least that many operands.
+struct OperandCountAtLeastQuestion
+ : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
+ Predicates::OperandCountAtLeastQuestion> {};
/// Compare the name of an operation with a known value.
struct OperationNameQuestion
@@ -375,6 +414,9 @@ struct OperationNameQuestion
struct ResultCountQuestion
: public PredicateBase<ResultCountQuestion, Qualifier, void,
Predicates::ResultCountQuestion> {};
+/// Compare the number of results of an operation with a known minimum, i.e.
+/// check that the operation has at least that many results.
+struct ResultCountAtLeastQuestion
+ : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
+ Predicates::ResultCountAtLeastQuestion> {};
/// Compare the type of an attribute or value with a known type.
struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
@@ -392,8 +434,10 @@ class PredicateUniquer : public StorageUniquer {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<OperandPosition>();
+ registerParametricStorageType<OperandGroupPosition>();
registerParametricStorageType<OperationPosition>();
registerParametricStorageType<ResultPosition>();
+ registerParametricStorageType<ResultGroupPosition>();
registerParametricStorageType<TypePosition>();
// Register the types of Questions with the uniquer.
@@ -409,8 +453,10 @@ class PredicateUniquer : public StorageUniquer {
registerSingletonStorageType<AttributeQuestion>();
registerSingletonStorageType<IsNotNullQuestion>();
registerSingletonStorageType<OperandCountQuestion>();
+ registerSingletonStorageType<OperandCountAtLeastQuestion>();
registerSingletonStorageType<OperationNameQuestion>();
registerSingletonStorageType<ResultCountQuestion>();
+ registerSingletonStorageType<ResultCountAtLeastQuestion>();
registerSingletonStorageType<TypeQuestion>();
}
};
@@ -433,10 +479,10 @@ class PredicateBuilder {
Position *getRoot() { return OperationPosition::getRoot(uniquer); }
/// Returns the parent position defining the value held by the given operand.
- OperationPosition *getParent(OperandPosition *p) {
- std::vector<unsigned> index = p->getIndex();
- index.push_back(p->getOperandNumber());
- return OperationPosition::get(uniquer, index);
+ /// `p` must be an OperandPosition or an OperandGroupPosition. The returned
+ /// operation position is anchored at `p`.
+ OperationPosition *getOperandDefiningOp(Position *p) {
+ assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
+ "expected operand position");
+ return OperationPosition::get(uniquer, p);
}
/// Returns an attribute position for an attribute of the given operation.
@@ -449,11 +495,29 @@ class PredicateBuilder {
return OperandPosition::get(uniquer, p, operand);
}
+ /// Returns a position for a group of operands of the given operation. A
+ /// `group` of None refers to all operands. `isVariadic` marks a group whose
+ /// size is not known statically; otherwise the group has at most one element.
+ Position *getOperandGroup(OperationPosition *p, Optional<unsigned> group,
+ bool isVariadic) {
+ return OperandGroupPosition::get(uniquer, p, group, isVariadic);
+ }
+ /// Returns a variadic position covering all operands of the given operation.
+ Position *getAllOperands(OperationPosition *p) {
+ return getOperandGroup(p, /*group=*/llvm::None, /*isVariadic=*/true);
+ }
+
/// Returns a result position for a result of the given operation.
Position *getResult(OperationPosition *p, unsigned result) {
return ResultPosition::get(uniquer, p, result);
}
+ /// Returns a position for a group of results of the given operation. A
+ /// `group` of None refers to all results. `isVariadic` marks a group whose
+ /// size is not known statically; otherwise the group has at most one element.
+ Position *getResultGroup(OperationPosition *p, Optional<unsigned> group,
+ bool isVariadic) {
+ return ResultGroupPosition::get(uniquer, p, group, isVariadic);
+ }
+ /// Returns a variadic position covering all results of the given operation.
+ Position *getAllResults(OperationPosition *p) {
+ return getResultGroup(p, /*group=*/llvm::None, /*isVariadic=*/true);
+ }
+
/// Returns a type position for the given entity.
Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
@@ -496,6 +560,10 @@ class PredicateBuilder {
return {OperandCountQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}
+ /// Create a predicate checking that an operation has at least `count`
+ /// operands.
+ Predicate getOperandCountAtLeast(unsigned count) {
+ return {OperandCountAtLeastQuestion::get(uniquer),
+ UnsignedAnswer::get(uniquer, count)};
+ }
/// Create a predicate comparing the name of an operation to a known value.
Predicate getOperationName(StringRef name) {
@@ -509,10 +577,15 @@ class PredicateBuilder {
return {ResultCountQuestion::get(uniquer),
UnsignedAnswer::get(uniquer, count)};
}
+ /// Create a predicate checking that an operation has at least `count`
+ /// results.
+ Predicate getResultCountAtLeast(unsigned count) {
+ return {ResultCountAtLeastQuestion::get(uniquer),
+ UnsignedAnswer::get(uniquer, count)};
+ }
/// Create a predicate comparing the type of an attribute or value to a known
- /// type.
- Predicate getTypeConstraint(Type type) {
+ /// type, or the types of a range of values to a known range of types. `type`
+ /// is expected to be either a TypeAttr, or an ArrayAttr of TypeAttr for a
+ /// range.
+ Predicate getTypeConstraint(Attribute type) {
return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
}
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index 885fbad0f976..bcd32dfa4bef 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -28,7 +28,13 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
/// Compares the depths of two positions.
+/// Returns true if `lhs` is anchored at a shallower operation depth than
+/// `rhs`. Positions under operations of the same depth compare equal.
static bool comparePosDepth(Position *lhs, Position *rhs) {
- return lhs->getIndex().size() < rhs->getIndex().size();
+ return lhs->getOperationDepth() < rhs->getOperationDepth();
+}
+
+/// Returns the number of elements within `values` whose type is not a
+/// `pdl::RangeType`, i.e. the number of non-variadic values.
+static unsigned getNumNonRangeValues(ValueRange values) {
+ return llvm::count_if(values.getTypes(),
+ [](Type type) { return !type.isa<pdl::RangeType>(); });
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
@@ -46,28 +52,50 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
predList.emplace_back(pos, builder.getAttributeConstraint(value));
}
-static void getTreePredicates(std::vector<PositionalPredicate> &predList,
- Value val, PredicateBuilder &builder,
- DenseMap<Value, Position *> &inputs,
- OperandPosition *pos) {
- assert(val.getType().isa<pdl::ValueType>() && "expected value type");
-
- // Prevent traversal into a null value.
- predList.emplace_back(pos, builder.getIsNotNull());
+/// Collect all of the predicates for the given operand position. `pos` is an
+/// OperandPosition or an OperandGroupPosition, and `val` is the PDL value
+/// bound at that position.
+static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
+ Value val, PredicateBuilder &builder,
+ DenseMap<Value, Position *> &inputs,
+ Position *pos) {
+ Type valueType = val.getType();
+ bool isVariadic = valueType.isa<pdl::RangeType>();
// If this is a typed operand, add a type constraint.
- if (auto in = val.getDefiningOp<pdl::OperandOp>()) {
- if (Value type = in.type())
- getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
-
- // Otherwise, recurse into a result node.
- } else if (auto resultOp = val.getDefiningOp<pdl::ResultOp>()) {
- OperationPosition *parentPos = builder.getParent(pos);
- Position *resultPos = builder.getResult(parentPos, resultOp.index());
- predList.emplace_back(parentPos, builder.getIsNotNull());
- predList.emplace_back(resultPos, builder.getEqualTo(pos));
- getTreePredicates(predList, resultOp.parent(), builder, inputs, parentPos);
- }
+ // Dispatch on the op that defines `val`. Operand-like ops add null and type
+ // checks; result-like ops tie this operand to a result of its defining
+ // operation and recurse into that operation. Values defined by any other op
+ // add no predicates here.
+ TypeSwitch<Operation *>(val.getDefiningOp())
+ .Case<pdl::OperandOp, pdl::OperandsOp>([&](auto op) {
+ // Prevent traversal into a null value if the operand has a proper
+ // index: always for a single pdl.operand, and for pdl.operands only
+ // when it is a numbered group rather than the group of all operands.
+ if (std::is_same<pdl::OperandOp, decltype(op)>::value ||
+ cast<OperandGroupPosition>(pos)->getOperandGroupNumber())
+ predList.emplace_back(pos, builder.getIsNotNull());
+
+ if (Value type = op.type())
+ getTreePredicates(predList, type, builder, inputs,
+ builder.getType(pos));
+ })
+ .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto op) {
+ Optional<unsigned> index = op.index();
+
+ // Prevent traversal into a null value if the result has a proper index.
+ if (index)
+ predList.emplace_back(pos, builder.getIsNotNull());
+
+ // Get the parent operation of this operand.
+ OperationPosition *parentPos = builder.getOperandDefiningOp(pos);
+ predList.emplace_back(parentPos, builder.getIsNotNull());
+
+ // Ensure that the operands match the corresponding results of the
+ // parent operation. A pdl.result always carries an index, so `*index`
+ // is safe in that branch. A pdl.results maps to a result group whose
+ // index may be None (all results).
+ Position *resultPos = nullptr;
+ if (std::is_same<pdl::ResultOp, decltype(op)>::value)
+ resultPos = builder.getResult(parentPos, *index);
+ else
+ resultPos = builder.getResultGroup(parentPos, index, isVariadic);
+ predList.emplace_back(resultPos, builder.getEqualTo(pos));
+
+ // Collect the predicates of the parent operation.
+ getTreePredicates(predList, op.parent(), builder, inputs, parentPos);
+ });
}
static void getTreePredicates(std::vector<PositionalPredicate> &predList,
@@ -86,11 +114,25 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
if (Optional<StringRef> opName = op.name())
predList.emplace_back(pos, builder.getOperationName(*opName));
- // Check that the operation has the proper number of operands and results.
+ // Check that the operation has the proper number of operands. If there are
+ // any variable length operands, we check a minimum instead of an exact count.
OperandRange operands = op.operands();
+ unsigned minOperands = getNumNonRangeValues(operands);
+ if (minOperands != operands.size()) {
+ if (minOperands)
+ predList.emplace_back(pos, builder.getOperandCountAtLeast(minOperands));
+ } else {
+ predList.emplace_back(pos, builder.getOperandCount(minOperands));
+ }
+
+ // Check that the operation has the proper number of results. If there are
+ // any variable length results, we check a minimum instead of an exact count.
OperandRange types = op.types();
- predList.emplace_back(pos, builder.getOperandCount(operands.size()));
- predList.emplace_back(pos, builder.getResultCount(types.size()));
+ unsigned minResults = getNumNonRangeValues(types);
+ if (minResults == types.size())
+ predList.emplace_back(pos, builder.getResultCount(types.size()));
+ else if (minResults)
+ predList.emplace_back(pos, builder.getResultCountAtLeast(minResults));
// Recurse into any attributes, operands, or results.
for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
@@ -99,15 +141,47 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
builder.getAttribute(opPos,
std::get<0>(it).cast<StringAttr>().getValue()));
}
- for (auto operandIt : llvm::enumerate(operands)) {
- getTreePredicates(predList, operandIt.value(), builder, inputs,
- builder.getOperand(opPos, operandIt.index()));
+
+ // Process the operands and results of the operation. For all values up to
+ // the first variable length value, we use the concrete operand/result
+ // number. After that, we use the "group" given that we can't know the
+ // concrete indices until runtime. If there is only one variadic operand
+ // group, we treat it as all of the operands/results of the operation.
+ /// Operands.
+ if (operands.size() == 1 && operands[0].getType().isa<pdl::RangeType>()) {
+ getTreePredicates(predList, operands.front(), builder, inputs,
+ builder.getAllOperands(opPos));
+ } else {
+ bool foundVariableLength = false;
+ for (auto operandIt : llvm::enumerate(operands)) {
+ bool isVariadic = operandIt.value().getType().isa<pdl::RangeType>();
+ foundVariableLength |= isVariadic;
+
+ Position *pos =
+ foundVariableLength
+ ? builder.getOperandGroup(opPos, operandIt.index(), isVariadic)
+ : builder.getOperand(opPos, operandIt.index());
+ getTreePredicates(predList, operandIt.value(), builder, inputs, pos);
+ }
}
- for (auto &resultIt : llvm::enumerate(types)) {
- auto *resultPos = builder.getResult(pos, resultIt.index());
- predList.emplace_back(resultPos, builder.getIsNotNull());
- getTreePredicates(predList, resultIt.value(), builder, inputs,
- builder.getType(resultPos));
+ /// Results.
+ if (types.size() == 1 && types[0].getType().isa<pdl::RangeType>()) {
+ getTreePredicates(predList, types.front(), builder, inputs,
+ builder.getType(builder.getAllResults(opPos)));
+ } else {
+ bool foundVariableLength = false;
+ for (auto &resultIt : llvm::enumerate(types)) {
+ bool isVariadic = resultIt.value().getType().isa<pdl::RangeType>();
+ foundVariableLength |= isVariadic;
+
+ auto *resultPos =
+ foundVariableLength
+ ? builder.getResultGroup(pos, resultIt.index(), isVariadic)
+ : builder.getResult(pos, resultIt.index());
+ predList.emplace_back(resultPos, builder.getIsNotNull());
+ getTreePredicates(predList, resultIt.value(), builder, inputs,
+ builder.getType(resultPos));
+ }
}
}
@@ -115,12 +189,14 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
Value val, PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs,
TypePosition *pos) {
- assert(val.getType().isa<pdl::TypeType>() && "expected value type");
- pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());
-
// Check for a constraint on a constant type.
- if (Optional<Type> type = typeOp.type())
- predList.emplace_back(pos, builder.getTypeConstraint(*type));
+ if (pdl::TypeOp typeOp = val.getDefiningOp<pdl::TypeOp>()) {
+ if (Attribute type = typeOp.typeAttr())
+ predList.emplace_back(pos, builder.getTypeConstraint(type));
+ } else if (pdl::TypesOp typeOp = val.getDefiningOp<pdl::TypesOp>()) {
+ if (Attribute typeAttr = typeOp.typesAttr())
+ predList.emplace_back(pos, builder.getTypeConstraint(typeAttr));
+ }
}
/// Collect the tree predicates anchored at the given value.
@@ -133,8 +209,8 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
if (!it.second) {
// If this is an input value that has been visited in the tree, add a
// constraint to ensure that both instances refer to the same value.
- if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperationOp, pdl::TypeOp>(
- val.getDefiningOp())) {
+ if (isa<pdl::AttributeOp, pdl::OperandOp, pdl::OperandsOp, pdl::OperationOp,
+ pdl::TypeOp>(val.getDefiningOp())) {
auto minMaxPositions =
std::minmax(pos, it.first->second, comparePosDepth);
predList.emplace_back(minMaxPositions.second,
@@ -144,9 +220,11 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
}
TypeSwitch<Position *>(pos)
- .Case<AttributePosition, OperandPosition, OperationPosition,
- TypePosition>([&](auto *derivedPos) {
- getTreePredicates(predList, val, builder, inputs, derivedPos);
+ .Case<AttributePosition, OperationPosition, TypePosition>([&](auto *pos) {
+ getTreePredicates(predList, val, builder, inputs, pos);
+ })
+ .Case<OperandPosition, OperandGroupPosition>([&](auto *pos) {
+ getOperandTreePredicates(predList, val, builder, inputs, pos);
})
.Default([](auto *) { llvm_unreachable("unexpected position kind"); });
}
@@ -180,11 +258,30 @@ static void getResultPredicates(pdl::ResultOp op,
Position *&resultPos = inputs[op];
if (resultPos)
return;
+
+ // Ensure that the result isn't null.
auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
resultPos = builder.getResult(parentPos, op.index());
predList.emplace_back(resultPos, builder.getIsNotNull());
}
+static void getResultPredicates(pdl::ResultsOp op,
+ std::vector<PositionalPredicate> &predList,
+ PredicateBuilder &builder,
+ DenseMap<Value, Position *> &inputs) {
+ Position *&resultPos = inputs[op];
+ if (resultPos)
+ return;
+
+ // Ensure that the result isn't null if the result has an index.
+ auto *parentPos = cast<OperationPosition>(inputs.lookup(op.parent()));
+ bool isVariadic = op.getType().isa<pdl::RangeType>();
+ Optional<unsigned> index = op.index();
+ resultPos = builder.getResultGroup(parentPos, index, isVariadic);
+ if (index)
+ predList.emplace_back(resultPos, builder.getIsNotNull());
+}
+
/// Collect all of the predicates that cannot be determined via walking the
/// tree.
static void getNonTreePredicates(pdl::PatternOp pattern,
@@ -192,10 +289,13 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
- if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(&op))
- getConstraintPredicates(constraintOp, predList, builder, inputs);
- else if (auto resultOp = dyn_cast<pdl::ResultOp>(&op))
- getResultPredicates(resultOp, predList, builder, inputs);
+ TypeSwitch<Operation *>(&op)
+ .Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
+ getConstraintPredicates(constraintOp, predList, builder, inputs);
+ })
+ .Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
+ getResultPredicates(resultOp, predList, builder, inputs);
+ });
}
}
@@ -254,10 +354,10 @@ struct OrderedPredicate {
// * lower position dependency
// * lower predicate dependency
auto *rhsPos = rhs.position;
- return std::make_tuple(primary, secondary, rhsPos->getIndex().size(),
+ return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(),
rhsPos->getKind(), rhs.question->getKind()) >
std::make_tuple(rhs.primary, rhs.secondary,
- position->getIndex().size(), position->getKind(),
+ position->getOperationDepth(), position->getKind(),
question->getKind());
}
};
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
index 1621fa96747b..ac2fa98d7c7b 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h
@@ -190,6 +190,12 @@ struct SwitchNode : public MatcherNode {
using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>;
ChildMapT &getChildren() { return children; }
+ /// Returns the child at the given index.
+ std::pair<Qualifier *, std::unique_ptr<MatcherNode>> &getChild(unsigned i) {
+ assert(i < children.size() && "invalid child index");
+ return *std::next(children.begin(), i);
+ }
+
private:
/// Switch predicate "answers" select the child. Answers that are not found
/// default to the failure node.
diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index 8b9c27c63e82..a93f3c48503c 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -29,28 +29,12 @@ void PDLInterpDialect::initialize() {
// pdl_interp::CreateOperationOp
//===----------------------------------------------------------------------===//
-static ParseResult parseCreateOperationOp(OpAsmParser &p,
- OperationState &state) {
- if (p.parseOptionalAttrDict(state.attributes))
- return failure();
+static ParseResult parseCreateOperationOpAttributes(
+ OpAsmParser &p, SmallVectorImpl<OpAsmParser::OperandType> &attrOperands,
+ ArrayAttr &attrNamesAttr) {
Builder &builder = p.getBuilder();
-
- // Parse the operation name.
- StringAttr opName;
- if (p.parseAttribute(opName, "name", state.attributes))
- return failure();
-
- // Parse the operands.
- SmallVector<OpAsmParser::OperandType, 4> operands;
- if (p.parseLParen() || p.parseOperandList(operands) || p.parseRParen() ||
- p.resolveOperands(operands, builder.getType<pdl::ValueType>(),
- state.operands))
- return failure();
-
- // Parse the attributes.
SmallVector<Attribute, 4> attrNames;
if (succeeded(p.parseOptionalLBrace())) {
- SmallVector<OpAsmParser::OperandType, 4> attrOps;
do {
StringAttr nameAttr;
OpAsmParser::OperandType operand;
@@ -58,60 +42,35 @@ static ParseResult parseCreateOperationOp(OpAsmParser &p,
p.parseOperand(operand))
return failure();
attrNames.push_back(nameAttr);
- attrOps.push_back(operand);
+ attrOperands.push_back(operand);
} while (succeeded(p.parseOptionalComma()));
-
- if (p.parseRBrace() ||
- p.resolveOperands(attrOps, builder.getType<pdl::AttributeType>(),
- state.operands))
- return failure();
- }
- state.addAttribute("attributeNames", builder.getArrayAttr(attrNames));
- state.addTypes(builder.getType<pdl::OperationType>());
-
- // Parse the result types.
- SmallVector<OpAsmParser::OperandType, 4> opResultTypes;
- if (p.parseArrow())
- return failure();
- if (succeeded(p.parseOptionalLParen())) {
- if (p.parseRParen())
+ if (p.parseRBrace())
return failure();
- } else if (p.parseOperandList(opResultTypes) ||
- p.resolveOperands(opResultTypes, builder.getType<pdl::TypeType>(),
- state.operands)) {
- return failure();
}
-
- int32_t operandSegmentSizes[] = {static_cast<int32_t>(operands.size()),
- static_cast<int32_t>(attrNames.size()),
- static_cast<int32_t>(opResultTypes.size())};
- state.addAttribute("operand_segment_sizes",
- builder.getI32VectorAttr(operandSegmentSizes));
+ attrNamesAttr = builder.getArrayAttr(attrNames);
return success();
}
-static void print(OpAsmPrinter &p, CreateOperationOp op) {
- p << "pdl_interp.create_operation ";
- p.printOptionalAttrDict(op->getAttrs(),
- {"attributeNames", "name", "operand_segment_sizes"});
- p << '"' << op.name() << "\"(" << op.operands() << ')';
+static void printCreateOperationOpAttributes(OpAsmPrinter &p,
+ CreateOperationOp op,
+ OperandRange attrArgs,
+ ArrayAttr attrNames) {
+ if (attrNames.empty())
+ return;
+ p << " {";
+ interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
+ [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
+ p << '}';
+}
- // Emit the optional attributes.
- ArrayAttr attrNames = op.attributeNames();
- if (!attrNames.empty()) {
- Operation::operand_range attrArgs = op.attributes();
- p << " {";
- interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
- [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
- p << '}';
- }
+//===----------------------------------------------------------------------===//
+// pdl_interp::GetValueTypeOp
+//===----------------------------------------------------------------------===//
- // Print the result type constraints of the operation.
- auto types = op.types();
- if (types.empty())
- p << " -> ()";
- else
- p << " -> " << op.types();
+/// Given the result type of a `GetValueTypeOp`, return the expected input type.
+static Type getGetValueTypeOpValueType(Type type) {
+ Type valueTy = pdl::ValueType::get(type.getContext());
+ return type.isa<pdl::RangeType>() ? pdl::RangeType::get(valueTy) : valueTy;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index c09892caec1b..ef96e25c7be3 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -208,7 +208,7 @@ class Generator {
void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
- void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
@@ -487,7 +487,7 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp,
- pdl_interp::InferredTypeOp, pdl_interp::IsNotNullOp,
+ pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
@@ -615,9 +615,9 @@ void Generator::generate(pdl_interp::GetValueTypeOp op,
ByteCodeWriter &writer) {
writer.append(OpCode::GetValueType, op.result(), op.value());
}
-void Generator::generate(pdl_interp::InferredTypeOp op,
+void Generator::generate(pdl_interp::InferredTypesOp op,
ByteCodeWriter &writer) {
- // InferType maps to a null type as a marker for inferring a result type.
+ // InferredTypesOp maps to a null type as a marker for inferring result types.
getMemIndex(op.type()) = getMemIndex(Type());
}
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
@@ -980,16 +980,12 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
// TODO: Handle failure.
- SmallVector<Type, 2> inferredTypes;
+ state.types.clear();
if (failed(concept->inferReturnTypes(
state.getContext(), state.location, state.operands,
state.attributes.getDictionary(state.getContext()), state.regions,
- inferredTypes)))
+ state.types)))
return;
-
- for (unsigned i = 0, e = state.types.size(); i != e; ++i)
- if (!state.types[i])
- state.types[i] = inferredTypes[i];
}
Operation *resultOp = rewriter.createOperation(state);
memory[memIndex] = resultOp;
diff --git a/mlir/lib/TableGen/Predicate.cpp b/mlir/lib/TableGen/Predicate.cpp
index a37847f0d489..dd71540c15d4 100644
--- a/mlir/lib/TableGen/Predicate.cpp
+++ b/mlir/lib/TableGen/Predicate.cpp
@@ -133,6 +133,23 @@ namespace {
using Subst = std::pair<StringRef, StringRef>;
} // end anonymous namespace
+/// Perform the given substitutions on 'str' in-place.
+static void performSubstitutions(std::string &str,
+ ArrayRef<Subst> substitutions) {
+ // Apply all parent substitutions from innermost to outermost.
+ for (const auto &subst : llvm::reverse(substitutions)) {
+ auto pos = str.find(std::string(subst.first));
+ while (pos != std::string::npos) {
+ str.replace(pos, subst.first.size(), std::string(subst.second));
+ // Skip the newly inserted substring, which may itself contain the
+ // pattern being searched for.
+ pos += subst.second.size();
+ // Find the next possible match position.
+ pos = str.find(std::string(subst.first), pos);
+ }
+ }
+}
+
// Build the predicate tree starting from the top-level predicate, which may
// have children, and perform leaf substitutions inplace. Note that after
// substitution, nodes are still pointing to the original TableGen record.
@@ -147,19 +164,7 @@ buildPredicateTree(const Pred &root,
rootNode->predicate = &root;
if (!root.isCombined()) {
rootNode->expr = root.getCondition();
- // Apply all parent substitutions from innermost to outermost.
- for (const auto &subst : llvm::reverse(substitutions)) {
- auto pos = rootNode->expr.find(std::string(subst.first));
- while (pos != std::string::npos) {
- rootNode->expr.replace(pos, subst.first.size(),
- std::string(subst.second));
- // Skip the newly inserted substring, which itself may consider the
- // pattern to match.
- pos += subst.second.size();
- // Find the next possible match position.
- pos = rootNode->expr.find(std::string(subst.first), pos);
- }
- }
+ performSubstitutions(rootNode->expr, substitutions);
return rootNode;
}
@@ -170,12 +175,14 @@ buildPredicateTree(const Pred &root,
const auto &substPred = static_cast<const SubstLeavesPred &>(root);
allSubstitutions.push_back(
{substPred.getPattern(), substPred.getReplacement()});
- }
- // If the current predicate is a ConcatPred, record the prefix and suffix.
- else if (rootNode->kind == PredCombinerKind::Concat) {
+
+ // If the current predicate is a ConcatPred, record the prefix and suffix.
+ } else if (rootNode->kind == PredCombinerKind::Concat) {
const auto &concatPred = static_cast<const ConcatPred &>(root);
rootNode->prefix = std::string(concatPred.getPrefix());
+ performSubstitutions(rootNode->prefix, substitutions);
rootNode->suffix = std::string(concatPred.getSuffix());
+ performSubstitutions(rootNode->suffix, substitutions);
}
// Build child subtrees.
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 0792f76cba7a..0af77a24efb4 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -103,6 +103,59 @@ module @inputs {
// -----
+// CHECK-LABEL: module @variadic_inputs
+module @variadic_inputs {
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK-DAG: pdl_interp.check_operand_count of %[[ROOT]] is at_least 2
+
+ // The first operand has a known index.
+ // CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
+ // CHECK-DAG: pdl_interp.is_not_null %[[INPUT]] : !pdl.value
+
+ // The second operand is a group of unknown size, with a type constraint.
+ // CHECK-DAG: %[[VAR_INPUTS:.*]] = pdl_interp.get_operands 1 of %[[ROOT]] : !pdl.range<value>
+ // CHECK-DAG: pdl_interp.is_not_null %[[VAR_INPUTS]] : !pdl.range<value>
+
+ // CHECK-DAG: %[[INPUT_TYPE:.*]] = pdl_interp.get_value_type of %[[VAR_INPUTS]] : !pdl.range<type>
+ // CHECK-DAG: pdl_interp.check_types %[[INPUT_TYPE]] are [i64]
+
+ // The third operand is at an unknown offset due to operand 1, but is expected
+ // to be of size 1.
+ // CHECK-DAG: %[[INPUT2:.*]] = pdl_interp.get_operands 2 of %[[ROOT]] : !pdl.value
+ // CHECK-DAG: pdl_interp.are_equal %[[INPUT]], %[[INPUT2]] : !pdl.value
+ pdl.pattern : benefit(1) {
+ %types = pdl.types : [i64]
+ %inputs = pdl.operands : %types
+ %input = pdl.operand
+ %root = pdl.operation(%input, %inputs, %input : !pdl.value, !pdl.range<value>, !pdl.value)
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @single_operand_range
+module @single_operand_range {
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+
+ // Check that the operand range is treated as all of the operands of the
+ // operation.
+ // CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_operands of %[[ROOT]]
+ // CHECK-DAG: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]] : !pdl.range<type>
+ // CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPES]] are [i64]
+
+ // The operand count is unknown, so there is no need to check for it.
+ // CHECK-NOT: pdl_interp.check_operand_count
+ pdl.pattern : benefit(1) {
+ %types = pdl.types : [i64]
+ %operands = pdl.operands : %types
+ %root = pdl.operation(%operands : !pdl.range<value>)
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
// CHECK-LABEL: module @results
module @results {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
@@ -127,6 +180,57 @@ module @results {
// -----
+// CHECK-LABEL: module @variadic_results
+module @variadic_results {
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK-DAG: pdl_interp.check_result_count of %[[ROOT]] is at_least 2
+
+ // The first result has a known index.
+ // CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
+ // CHECK-DAG: pdl_interp.is_not_null %[[RESULT]] : !pdl.value
+
+ // The second result is a group of unknown size, with a type constraint.
+ // CHECK-DAG: %[[VAR_RESULTS:.*]] = pdl_interp.get_results 1 of %[[ROOT]] : !pdl.range<value>
+ // CHECK-DAG: pdl_interp.is_not_null %[[VAR_RESULTS]] : !pdl.range<value>
+
+ // CHECK-DAG: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[VAR_RESULTS]] : !pdl.range<type>
+ // CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPE]] are [i64]
+
+ // The third result is at an unknown offset due to result 1, but is expected
+ // to be of size 1.
+ // CHECK-DAG: %[[RESULT2:.*]] = pdl_interp.get_results 2 of %[[ROOT]] : !pdl.value
+ // CHECK-DAG: pdl_interp.is_not_null %[[RESULT2]] : !pdl.value
+ pdl.pattern : benefit(1) {
+ %types = pdl.types : [i64]
+ %type = pdl.type
+ %root = pdl.operation -> (%type, %types, %type : !pdl.type, !pdl.range<type>, !pdl.type)
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @single_result_range
+module @single_result_range {
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+
+ // Check that the result range is treated as all of the results of the
+ // operation.
+ // CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]]
+ // CHECK-DAG: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]] : !pdl.range<type>
+ // CHECK-DAG: pdl_interp.check_types %[[RESULT_TYPES]] are [i64]
+
+ // The result count is unknown, so there is no need to check for it.
+ // CHECK-NOT: pdl_interp.check_result_count
+ pdl.pattern : benefit(1) {
+ %types = pdl.types : [i64]
+ %root = pdl.operation -> (%types : !pdl.range<type>)
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
// CHECK-LABEL: module @results_as_operands
module @results_as_operands {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
@@ -160,8 +264,29 @@ module @results_as_operands {
// -----
-// CHECK-LABEL: module @switch_result_types
-module @switch_result_types {
+// CHECK-LABEL: module @single_result_range_as_operands
+module @single_result_range_as_operands {
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK-DAG: %[[OPERANDS:.*]] = pdl_interp.get_operands of %[[ROOT]] : !pdl.range<value>
+ // CHECK-DAG: %[[OP:.*]] = pdl_interp.get_defining_op of %[[OPERANDS]] : !pdl.range<value>
+ // CHECK-DAG: pdl_interp.is_not_null %[[OP]]
+ // CHECK-DAG: %[[RESULTS:.*]] = pdl_interp.get_results of %[[OP]] : !pdl.range<value>
+ // CHECK-DAG: pdl_interp.are_equal %[[RESULTS]], %[[OPERANDS]] : !pdl.range<value>
+
+ pdl.pattern : benefit(1) {
+ %types = pdl.types
+ %inputOp = pdl.operation -> (%types : !pdl.range<type>)
+ %results = pdl.results of %inputOp
+
+ %root = pdl.operation(%results : !pdl.range<value>)
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @switch_single_result_type
+module @switch_single_result_type {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
@@ -178,6 +303,84 @@ module @switch_result_types {
}
}
+// -----
+
+// CHECK-LABEL: module @switch_result_types
+module @switch_result_types {
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]]
+ // CHECK: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]]
+ // CHECK: pdl_interp.switch_types %[[RESULT_TYPES]] to {{\[\[}}i32], [i64, i32]]
+ pdl.pattern : benefit(1) {
+ %types = pdl.types : [i32]
+ %root = pdl.operation -> (%types : !pdl.range<type>)
+ pdl.rewrite %root with "rewriter"
+ }
+ pdl.pattern : benefit(1) {
+ %types = pdl.types : [i64, i32]
+ %root = pdl.operation -> (%types : !pdl.range<type>)
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @switch_operand_count_at_least
+module @switch_operand_count_at_least {
+ // Check that when there are multiple "at_least" checks, the failure branch
+ // goes to the next one in increasing order.
+
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK: pdl_interp.check_operand_count of %[[ROOT]] is at_least 1 -> ^[[PATTERN_1_NEXT_BLOCK:.*]],
+ // CHECK: ^bb2:
+ // CHECK-NEXT: pdl_interp.check_operand_count of %[[ROOT]] is at_least 2
+ // CHECK: ^[[PATTERN_1_NEXT_BLOCK]]:
+ // CHECK-NEXT: {{.*}} -> ^{{.*}}, ^bb2
+ pdl.pattern : benefit(1) {
+ %operand = pdl.operand
+ %operands = pdl.operands
+ %root = pdl.operation(%operand, %operands : !pdl.value, !pdl.range<value>)
+ pdl.rewrite %root with "rewriter"
+ }
+ pdl.pattern : benefit(1) {
+ %operand = pdl.operand
+ %operand2 = pdl.operand
+ %operands = pdl.operands
+ %root = pdl.operation(%operand, %operand2, %operands : !pdl.value, !pdl.value, !pdl.range<value>)
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @switch_result_count_at_least
+module @switch_result_count_at_least {
+ // Check that when there are multiple "at_least" checks, the failure branch
+ // goes to the next one in increasing order.
+
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK: pdl_interp.check_result_count of %[[ROOT]] is at_least 1 -> ^[[PATTERN_1_NEXT_BLOCK:.*]],
+ // CHECK: ^[[PATTERN_2_BLOCK:[a-zA-Z_0-9]*]]:
+ // CHECK: pdl_interp.check_result_count of %[[ROOT]] is at_least 2
+ // CHECK: ^[[PATTERN_1_NEXT_BLOCK]]:
+ // CHECK-NEXT: pdl_interp.get_result
+ // CHECK-NEXT: pdl_interp.is_not_null {{.*}} -> ^{{.*}}, ^[[PATTERN_2_BLOCK]]
+ pdl.pattern : benefit(1) {
+ %type = pdl.type
+ %types = pdl.types
+ %root = pdl.operation -> (%type, %types : !pdl.type, !pdl.range<type>)
+ pdl.rewrite %root with "rewriter"
+ }
+ pdl.pattern : benefit(1) {
+ %type = pdl.type
+ %type2 = pdl.type
+ %types = pdl.types
+ %root = pdl.operation -> (%type, %type2, %types : !pdl.type, !pdl.type, !pdl.range<type>)
+ pdl.rewrite %root with "rewriter"
+ }
+}
+
+
// -----
// CHECK-LABEL: module @predicate_ordering
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
index 67ac7c811ab7..58d1c3177dad 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir
@@ -37,7 +37,7 @@ module @operation_attributes {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ATTR:.*]]: !pdl.attribute, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[ATTR1:.*]] = pdl_interp.create_attribute true
- // CHECK: pdl_interp.create_operation "foo.op"() {"attr" = %[[ATTR]], "attr1" = %[[ATTR1]]}
+ // CHECK: pdl_interp.create_operation "foo.op" {"attr" = %[[ATTR]], "attr1" = %[[ATTR1]]}
pdl.pattern : benefit(1) {
%attr = pdl.attribute
%root = pdl.operation "foo.op" {"attr" = %attr}
@@ -55,9 +55,9 @@ module @operation_attributes {
module @operation_operands {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation)
- // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]])
+ // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]] : !pdl.value)
// CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
- // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
+ // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]] : !pdl.value)
pdl.pattern : benefit(1) {
%operand = pdl.operand
%root = pdl.operation "foo.op"(%operand : !pdl.value)
@@ -77,9 +77,9 @@ module @operation_operands {
module @operation_operands {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value, %[[ROOT:.*]]: !pdl.operation)
- // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]])
+ // CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation "foo.op"(%[[OPERAND]] : !pdl.value)
// CHECK: %[[OPERAND1:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
- // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]])
+ // CHECK: pdl_interp.create_operation "foo.op2"(%[[OPERAND1]] : !pdl.value)
pdl.pattern : benefit(1) {
%operand = pdl.operand
%root = pdl.operation "foo.op"(%operand : !pdl.value)
@@ -95,11 +95,13 @@ module @operation_operands {
// -----
-// CHECK-LABEL: module @operation_result_types
-module @operation_result_types {
+// CHECK-LABEL: module @operation_infer_types_from_replaceop
+module @operation_infer_types_from_replaceop {
// CHECK: module @rewriters
- // CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type, %[[TYPE1:.*]]: !pdl.type
- // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]], %[[TYPE1]]
+ // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation
+ // CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[ROOT]]
+ // CHECK: %[[RESULT_TYPES:.*]] = pdl_interp.get_value_type of %[[RESULTS]]
+ // CHECK: pdl_interp.create_operation "foo.op" -> (%[[RESULT_TYPES]] : !pdl.range<type>)
pdl.pattern : benefit(1) {
%rootType = pdl.type
%rootType1 = pdl.type
@@ -114,13 +116,46 @@ module @operation_result_types {
// -----
+// CHECK-LABEL: module @operation_infer_types_from_otherop_individual_results
+module @operation_infer_types_from_otherop_individual_results {
+ // CHECK: module @rewriters
+ // CHECK: func @pdl_generated_rewriter(%[[TYPE:.*]]: !pdl.type, %[[TYPES:.*]]: !pdl.range<type>
+ // CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range<type>)
+ pdl.pattern : benefit(1) {
+ %rootType = pdl.type
+ %rootTypes = pdl.types
+ %root = pdl.operation "foo.op" -> (%rootType, %rootTypes : !pdl.type, !pdl.range<type>)
+ pdl.rewrite %root {
+ %newOp = pdl.operation "foo.op" -> (%rootType, %rootTypes : !pdl.type, !pdl.range<type>)
+ }
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @operation_infer_types_from_otherop_results
+module @operation_infer_types_from_otherop_results {
+ // CHECK: module @rewriters
+ // CHECK: func @pdl_generated_rewriter(%[[TYPES:.*]]: !pdl.range<type>
+ // CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPES]] : !pdl.range<type>)
+ pdl.pattern : benefit(1) {
+ %rootTypes = pdl.types
+ %root = pdl.operation "foo.op" -> (%rootTypes : !pdl.range<type>)
+ pdl.rewrite %root {
+ %newOp = pdl.operation "foo.op" -> (%rootTypes : !pdl.range<type>)
+ }
+ }
+}
+
+// -----
+
// CHECK-LABEL: module @replace_with_op
module @replace_with_op {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation
- // CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
- // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
+ // CHECK: %[[RESULTS:.*]] = pdl_interp.get_results of %[[NEWOP]]
+ // CHECK: pdl_interp.replace %[[ROOT]] with (%[[RESULTS]] : !pdl.range<value>)
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%root = pdl.operation "foo.op" -> (%type : !pdl.type)
@@ -136,17 +171,21 @@ module @replace_with_op {
// CHECK-LABEL: module @replace_with_values
module @replace_with_values {
// CHECK: module @rewriters
- // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK: func @pdl_generated_rewriter({{.*}}, %[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[NEWOP:.*]] = pdl_interp.create_operation
- // CHECK: %[[OP_RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
- // CHECK: pdl_interp.replace %[[ROOT]] with(%[[OP_RESULT]])
+ // CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[NEWOP]]
+ // CHECK: %[[RESULTS:.*]] = pdl_interp.get_results 1 of %[[NEWOP]] : !pdl.range<value>
+ // CHECK: %[[RESULTS_2:.*]] = pdl_interp.get_results 2 of %[[NEWOP]] : !pdl.value
+ // CHECK: pdl_interp.replace %[[ROOT]] with (%[[RESULT]], %[[RESULTS]], %[[RESULTS_2]] : !pdl.value, !pdl.range<value>, !pdl.value)
pdl.pattern : benefit(1) {
- %type = pdl.type : i32
- %root = pdl.operation "foo.op" -> (%type : !pdl.type)
+ %types = pdl.types
+ %root = pdl.operation "foo.op" -> (%types : !pdl.range<type>)
pdl.rewrite %root {
- %newOp = pdl.operation "foo.op" -> (%type : !pdl.type)
+ %newOp = pdl.operation "foo.op" -> (%types : !pdl.range<type>)
%newResult = pdl.result 0 of %newOp
- pdl.replace %root with (%newResult : !pdl.value)
+ %newResults = pdl.results 1 of %newOp -> !pdl.range<value>
+ %newResults2 = pdl.results 2 of %newOp -> !pdl.value
+ pdl.replace %root with (%newResult, %newResults, %newResults2 : !pdl.value, !pdl.range<value>, !pdl.value)
}
}
}
@@ -175,14 +214,13 @@ module @apply_native_rewrite {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
- // CHECK: pdl_interp.create_operation "foo.op"() -> %[[TYPE]]
+ // CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]] : !pdl.type)
pdl.pattern : benefit(1) {
%type = pdl.type
%root = pdl.operation "foo.op" -> (%type : !pdl.type)
pdl.rewrite %root {
%newType = pdl.apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type
%newOp = pdl.operation "foo.op" -> (%newType : !pdl.type)
- pdl.replace %root with %newOp
}
}
}
diff --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir
index d76b17c394e8..072dfaddcda2 100644
--- a/mlir/test/Dialect/PDLInterp/ops.mlir
+++ b/mlir/test/Dialect/PDLInterp/ops.mlir
@@ -10,16 +10,16 @@ func @operations(%attribute: !pdl.attribute,
%input: !pdl.value,
%type: !pdl.type) {
// attributes, operands, and results
- %op0 = pdl_interp.create_operation "foo.op"(%input) {"attr" = %attribute} -> %type
+ %op0 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) {"attr" = %attribute} -> (%type : !pdl.type)
// attributes, and results
- %op1 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute} -> %type
+ %op1 = pdl_interp.create_operation "foo.op" {"attr" = %attribute} -> (%type : !pdl.type)
// attributes
- %op2 = pdl_interp.create_operation "foo.op"() {"attr" = %attribute, "attr1" = %attribute} -> ()
+ %op2 = pdl_interp.create_operation "foo.op" {"attr" = %attribute, "attr1" = %attribute}
// operands, and results
- %op3 = pdl_interp.create_operation "foo.op"(%input) -> %type
+ %op3 = pdl_interp.create_operation "foo.op"(%input : !pdl.value) -> (%type : !pdl.type)
pdl_interp.finalize
}
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index 2093d03bbf25..b0acd328147a 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -25,7 +25,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.replaced_by_pattern"() -> ()
+ %op = pdl_interp.create_operation "test.replaced_by_pattern"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -122,7 +122,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -157,7 +157,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -190,7 +190,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -222,7 +222,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -256,7 +256,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -288,7 +288,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -325,7 +325,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -375,7 +375,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -425,8 +425,8 @@ module @patterns {
^pat1:
%operand0 = pdl_interp.get_operand 0 of %root
%operand4 = pdl_interp.get_operand 4 of %root
- %defOp0 = pdl_interp.get_defining_op of %operand0
- %defOp4 = pdl_interp.get_defining_op of %operand4
+ %defOp0 = pdl_interp.get_defining_op of %operand0 : !pdl.value
+ %defOp4 = pdl_interp.get_defining_op of %operand4 : !pdl.value
pdl_interp.are_equal %defOp0, %defOp4 : !pdl.operation -> ^pat2, ^end
^pat2:
@@ -438,7 +438,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -476,8 +476,8 @@ module @patterns {
^pat1:
%result0 = pdl_interp.get_result 0 of %root
%result4 = pdl_interp.get_result 4 of %root
- %result0_type = pdl_interp.get_value_type of %result0
- %result4_type = pdl_interp.get_value_type of %result4
+ %result0_type = pdl_interp.get_value_type of %result0 : !pdl.type
+ %result4_type = pdl_interp.get_value_type of %result4 : !pdl.type
pdl_interp.are_equal %result0_type, %result4_type : !pdl.type -> ^pat2, ^end
^pat2:
@@ -489,7 +489,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -513,7 +513,7 @@ module @ir attributes { test.get_result_1 } {
// Fully tested within the tests for other operations.
//===----------------------------------------------------------------------===//
-// pdl_interp::InferredTypeOp
+// pdl_interp::InferredTypesOp
//===----------------------------------------------------------------------===//
// Fully tested within the tests for other operations.
@@ -549,7 +549,7 @@ module @patterns {
pdl_interp.finalize
}
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -582,7 +582,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
%operand = pdl_interp.get_operand 0 of %root
- pdl_interp.replace %root with (%operand)
+ pdl_interp.replace %root with (%operand : !pdl.value)
pdl_interp.finalize
}
}
@@ -622,7 +622,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -657,7 +657,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -693,7 +693,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -728,7 +728,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
@@ -768,7 +768,7 @@ module @patterns {
module @rewriters {
func @success(%root : !pdl.operation) {
- %op = pdl_interp.create_operation "test.success"() -> ()
+ %op = pdl_interp.create_operation "test.success"
pdl_interp.erase %root
pdl_interp.finalize
}
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 725afd9bc1aa..987f417d867c 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -136,7 +136,7 @@ def BOp : NS_Op<"b_op", []> {
// DEF: if (!((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>())))
// DEF: if (!(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>()))))
// DEF: if (!((tblgen_array_attr.isa<::mlir::ArrayAttr>())))
-// DEF: if (!(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [](::mlir::Attribute attr) { return (some-condition); }))))
+// DEF: if (!(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return (some-condition); }))))
// DEF: if (!(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>()))))
// Test common attribute kind getters' return types
More information about the Mlir-commits
mailing list