[llvm-branch-commits] [mlir] 00a61b3 - [mlir][ODS] Add new RangedTypesMatchWith operation predicate
River Riddle via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 14 11:47:16 PST 2021
Author: River Riddle
Date: 2021-01-14T11:35:49-08:00
New Revision: 00a61b327dd8a7071ce0baadd16ea4c7b7e31e73
URL: https://github.com/llvm/llvm-project/commit/00a61b327dd8a7071ce0baadd16ea4c7b7e31e73
DIFF: https://github.com/llvm/llvm-project/commit/00a61b327dd8a7071ce0baadd16ea4c7b7e31e73.diff
LOG: [mlir][ODS] Add new RangedTypesMatchWith operation predicate
This is a variant of TypesMatchWith that provides support for variadic arguments. This is necessary because ranges generally can't be compared for equality with the default comparator (`std::equal_to<>()`, i.e. operator==), so an element-wise `std::equal` comparator is used instead.
Differential Revision: https://reviews.llvm.org/D94574
Added:
Modified:
mlir/include/mlir/IR/OpBase.td
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 73ddbc1d56eb..3b55e51d8178 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2191,16 +2191,28 @@ class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
+// An optional comparator function may be provided that changes the above form
+// into: `comparator(transform(lhs.getType()), rhs.getType())`.
class TypesMatchWith<string summary, string lhsArg, string rhsArg,
- string transform> :
- PredOpTrait<summary, CPred<
- !subst("$_self", "$" # lhsArg # ".getType()", transform)
- # " == $" # rhsArg # ".getType()">> {
+ string transform, string comparator = "std::equal_to<>()">
+ : PredOpTrait<summary, CPred<
+ comparator # "(" #
+ !subst("$_self", "$" # lhsArg # ".getType()", transform) #
+ ", $" # rhsArg # ".getType())">> {
string lhs = lhsArg;
string rhs = rhsArg;
string transformer = transform;
}
+// Special variant of `TypesMatchWith` that provides a comparator suitable for
+// ranged arguments.
+class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
+ string transform>
+ : TypesMatchWith<summary, lhsArg, rhsArg, transform,
+ "[](auto &&lhs, auto &&rhs) { "
+ "return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());"
+ " }">;
+
// Type Constraint operand `idx`'s Element type is `type`.
class TCopVTEtIs<int idx, Type type> : And<[
CPred<"$_op.getNumOperands() > " # idx>,
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1fc419cc375f..d1cbe77ac21b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1733,6 +1733,15 @@ def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
let assemblyFormat = "attr-dict $value `:` type($value)";
}
+def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [
+ RangedTypesMatchWith<"result type matches operand", "value", "result",
+ "llvm::make_range($_self.begin(), $_self.end())">
+ ]> {
+ let arguments = (ins Variadic<AnyType>:$value);
+ let results = (outs Variadic<AnyType>:$result);
+ let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
TypesMatchWith<"result type matches constant", "value", "result", "$_self">
]> {
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 334313debda1..4eb64772aee2 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -308,5 +308,8 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.format_types_match_var %[[I64]] : i64
%ignored_res3 = test.format_types_match_var %i64 : i64
+// CHECK: test.format_types_match_variadic %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64
+%ignored_res4:3 = test.format_types_match_variadic %i64, %i64, %i64 : i64, i64, i64
+
// CHECK: test.format_types_match_attr 1 : i64
-%ignored_res4 = test.format_types_match_attr 1 : i64
+%ignored_res5 = test.format_types_match_attr 1 : i64
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 749ef1613c14..bba796f9b492 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1287,10 +1287,16 @@ void OperationFormat::genParserTypeResolution(Operator &op,
if (Optional<int> val = resolver.getBuilderIdx()) {
body << "odsBuildableType" << *val;
} else if (const NamedTypeConstraint *var = resolver.getVariable()) {
- if (Optional<StringRef> tform = resolver.getVarTransformer())
- body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
- else
+ if (Optional<StringRef> tform = resolver.getVarTransformer()) {
+ FmtContext fmtContext;
+ if (var->isVariadic())
+ fmtContext.withSelf(var->name + "Types");
+ else
+ fmtContext.withSelf(var->name + "Types[0]");
+ body << tgfmt(*tform, &fmtContext);
+ } else {
body << var->name << "Types";
+ }
} else if (const NamedAttribute *attr = resolver.getAttribute()) {
if (Optional<StringRef> tform = resolver.getVarTransformer())
body << tgfmt(*tform,
More information about the llvm-branch-commits
mailing list