[llvm-branch-commits] [mlir] 00a61b3 - [mlir][ODS] Add new RangedTypesMatchWith operation predicate

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 14 11:47:16 PST 2021


Author: River Riddle
Date: 2021-01-14T11:35:49-08:00
New Revision: 00a61b327dd8a7071ce0baadd16ea4c7b7e31e73

URL: https://github.com/llvm/llvm-project/commit/00a61b327dd8a7071ce0baadd16ea4c7b7e31e73
DIFF: https://github.com/llvm/llvm-project/commit/00a61b327dd8a7071ce0baadd16ea4c7b7e31e73.diff

LOG: [mlir][ODS] Add new RangedTypesMatchWith operation predicate

This is a variant of TypesMatchWith that supports variadic arguments. It is necessary because ranges generally can't be compared for equality with the default `operator==`; instead, the new predicate passes a comparator that checks the two ranges element-wise with `std::equal`.

Differential Revision: https://reviews.llvm.org/D94574

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 73ddbc1d56eb..3b55e51d8178 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2191,16 +2191,28 @@ class AllTypesMatch<list<string> names> :
     AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
 
 // A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
+// An optional comparator function may be provided that changes the above form
+// into: `comparator(transform(lhs.getType()), rhs.getType())`.
 class TypesMatchWith<string summary, string lhsArg, string rhsArg,
-                     string transform> :
-    PredOpTrait<summary, CPred<
-      !subst("$_self", "$" # lhsArg # ".getType()", transform)
-      # " == $" # rhsArg # ".getType()">> {
+                     string transform, string comparator = "std::equal_to<>()">
+  : PredOpTrait<summary, CPred<
+      comparator # "(" #
+      !subst("$_self", "$" # lhsArg # ".getType()", transform) #
+      ", $" # rhsArg # ".getType())">> {
   string lhs = lhsArg;
   string rhs = rhsArg;
   string transformer = transform;
 }
 
+// Special variant of `TypesMatchWith` that provides a comparator suitable for
+// ranged arguments.
+class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
+                           string transform>
+  : TypesMatchWith<summary, lhsArg, rhsArg, transform,
+    "[](auto &&lhs, auto &&rhs) { "
+    "return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());"
+    " }">;
+
 // Type Constraint operand `idx`'s Element type is `type`.
 class TCopVTEtIs<int idx, Type type> : And<[
    CPred<"$_op.getNumOperands() > " # idx>,

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 1fc419cc375f..d1cbe77ac21b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1733,6 +1733,15 @@ def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
   let assemblyFormat = "attr-dict $value `:` type($value)";
 }
 
+def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [
+    RangedTypesMatchWith<"result type matches operand", "value", "result",
+                         "llvm::make_range($_self.begin(), $_self.end())">
+  ]> {
+  let arguments = (ins Variadic<AnyType>:$value);
+  let results = (outs Variadic<AnyType>:$result);
+  let assemblyFormat = "attr-dict $value `:` type($value)";
+}
+
 def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
     TypesMatchWith<"result type matches constant", "value", "result", "$_self">
   ]> {

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 334313debda1..4eb64772aee2 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -308,5 +308,8 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
 // CHECK: test.format_types_match_var %[[I64]] : i64
 %ignored_res3 = test.format_types_match_var %i64 : i64
 
+// CHECK: test.format_types_match_variadic %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64
+%ignored_res4:3 = test.format_types_match_variadic %i64, %i64, %i64 : i64, i64, i64
+
 // CHECK: test.format_types_match_attr 1 : i64
-%ignored_res4 = test.format_types_match_attr 1 : i64
+%ignored_res5 = test.format_types_match_attr 1 : i64

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 749ef1613c14..bba796f9b492 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1287,10 +1287,16 @@ void OperationFormat::genParserTypeResolution(Operator &op,
     if (Optional<int> val = resolver.getBuilderIdx()) {
       body << "odsBuildableType" << *val;
     } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
-      if (Optional<StringRef> tform = resolver.getVarTransformer())
-        body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
-      else
+      if (Optional<StringRef> tform = resolver.getVarTransformer()) {
+        FmtContext fmtContext;
+        if (var->isVariadic())
+          fmtContext.withSelf(var->name + "Types");
+        else
+          fmtContext.withSelf(var->name + "Types[0]");
+        body << tgfmt(*tform, &fmtContext);
+      } else {
         body << var->name << "Types";
+      }
     } else if (const NamedAttribute *attr = resolver.getAttribute()) {
       if (Optional<StringRef> tform = resolver.getVarTransformer())
         body << tgfmt(*tform,


        


More information about the llvm-branch-commits mailing list