[Mlir-commits] [mlir] a71bc5f - [mlir] Improve FieldParser list container detection

Rahul Kayaith llvmlistbot at llvm.org
Mon Apr 3 12:13:05 PDT 2023


Author: rkayaith
Date: 2023-04-03T15:12:59-04:00
New Revision: a71bc5f56d1f992bfd1de2f8e2279b6d5338c6db

URL: https://github.com/llvm/llvm-project/commit/a71bc5f56d1f992bfd1de2f8e2279b6d5338c6db
DIFF: https://github.com/llvm/llvm-project/commit/a71bc5f56d1f992bfd1de2f8e2279b6d5338c6db.diff

LOG: [mlir] Improve FieldParser list container detection

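The current detection logic takes the address `&ContainerT::push_back`.
That expression is ill-formed when `push_back` is overloaded, so SFINAE
silently discards the list-parsing specialization for such containers.
This causes issues with types like `std::vector` and
`SmallVector<SomeNonTriviallyCopyableT>`, which have both
`push_back(const T&)` and `push_back(T&&)`. Instead, detect whether
`push_back` can be called with an rvalue `value_type`, and move each
parsed element into the container rather than copying it.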

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D147101

Added: 
    

Modified: 
    mlir/include/mlir/IR/DialectImplementation.h
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
    mlir/test/mlir-tblgen/attr-or-type-format.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 460318760bb33..f045e1017ea8a 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -140,12 +140,18 @@ struct FieldParser<
   }
 };
 
+namespace detail {
+template <typename T>
+using has_push_back_t = decltype(std::declval<T>().push_back(
+    std::declval<typename T::value_type &&>()));
+} // namespace detail
+
 /// Parse any container that supports back insertion as a list.
 template <typename ContainerT>
-struct FieldParser<
-    ContainerT, std::enable_if_t<std::is_member_function_pointer<
-                                     decltype(&ContainerT::push_back)>::value,
-                                 ContainerT>> {
+struct FieldParser<ContainerT,
+                   std::enable_if_t<llvm::is_detected<detail::has_push_back_t,
+                                                      ContainerT>::value,
+                                    ContainerT>> {
   using ElementT = typename ContainerT::value_type;
   static FailureOr<ContainerT> parse(AsmParser &parser) {
     ContainerT elements;
@@ -153,7 +159,7 @@ struct FieldParser<
       auto element = FieldParser<ElementT>::parse(parser);
       if (failed(element))
         return failure();
-      elements.push_back(*element);
+      elements.push_back(std::move(*element));
       return success();
     };
     if (parser.parseCommaSeparatedList(elementParser))

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index bd7c050f9e857..ba9909d08ae30 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -173,6 +173,11 @@ def TestParamFour : ArrayRefParameter<"int", ""> {
   let printer = "::printIntArray($_printer, $_self)";
 }
 
+
+def TestParamVector : ArrayRefParameter<"int", ""> {
+  let cppStorageType = "std::vector<int>";
+}
+
 def TestParamUnsigned : AttrParameter<"uint64_t", ""> {}
 
 def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> {
@@ -183,6 +188,7 @@ def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> {
     "::mlir::IntegerAttr":$three,
     TestParamFour:$four,
     TestParamUnsigned:$five,
+    TestParamVector:$six,
     // Array of another attribute.
     ArrayRefParameter<
       "AttrWithTypeBuilderAttr", // The parameter C++ type.
@@ -192,7 +198,7 @@ def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> {
 
   let mnemonic = "attr_with_format";
   let assemblyFormat = [{
-    `<` $one `:` struct($two, $four) `:` $three `:` $five `,`
+    `<` $one `:` struct($two, $four) `:` $three `:` $five `:` `[` $six `]` `,`
     `[` `` $arrayOfAttrWithTypeBuilderAttr `]` `>`
   }];
   let genVerifyDecl = 1;

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 24955ffb1713f..e0ccd500aa909 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -98,11 +98,10 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-LogicalResult
-TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                               int64_t one, std::string two, IntegerAttr three,
-                               ArrayRef<int> four, uint64_t five,
-                               ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) {
+LogicalResult TestAttrWithFormatAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError, int64_t one, std::string two,
+    IntegerAttr three, ArrayRef<int> four, uint64_t five, ArrayRef<int> six,
+    ArrayRef<AttrWithTypeBuilderAttr> arrayOfAttrs) {
   if (four.size() != static_cast<unsigned>(one))
     return emitError() << "expected 'one' to equal 'four.size()'";
   return success();

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
index e63b5dec8a6d0..8e421bd424118 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -5,10 +5,10 @@
 // CHECK: !test.type_with_format<2147, three = "hi", two = "hi">
 func.func private @test_roundtrip_parameter_parsers(!test.type_with_format<111, three = #test<attr_ugly begin 5 : index end>, two = "foo">) -> !test.type_with_format<2147, two = "hi", three = "hi">
 attributes {
-  // CHECK: #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64 : 0, [ 10 : i16]
-  attr0 = #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64 : 0, [10 : i16]>,
-  // CHECK: #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8 : 255, [ 10 : i16]>,
-  attr1 = #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8 : 255, [10 : i16]>,
+  // CHECK: #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64 : 0 : [4, 5, 6], [ 10 : i16]
+  attr0 = #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64 : 0 : [4, 5, 6], [10 : i16]>,
+  // CHECK: #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8 : 255 : [9, 10, 11], [ 10 : i16]>,
+  attr1 = #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8 : 255 : [9, 10, 11], [10 : i16]>,
   // CHECK: #test<attr_ugly begin 5 : index end>
   attr2 = #test<attr_ugly begin 5 : index end>,
   // CHECK: #test.attr_params<42, 24>

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.mlir b/mlir/test/mlir-tblgen/attr-or-type-format.mlir
index ead7e83bcb319..6b61deff1b69f 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.mlir
@@ -123,7 +123,7 @@ func.func private @test_type_syntax_error() -> !test.type_with_format<42, two =
 
 func.func private @test_verifier_fails() -> () attributes {
  // expected-error@+1 {{expected 'one' to equal 'four.size()'}}
-  attr = #test.attr_with_format<42 : two = "hello", four = [1, 2, 3] : 42 : i64 : 0, [10 : i16]>
+  attr = #test.attr_with_format<42 : two = "hello", four = [1, 2, 3] : 42 : i64 : 0 : [4, 5, 6], [10 : i16]>
 }
 
 // -----


        


More information about the Mlir-commits mailing list