[Mlir-commits] [mlir] 5fc28e7 - Improve MLIR Attribute::get() method efficiency by reducing the amount of argument copies (#68067)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 3 18:07:52 PDT 2023


Author: Mehdi Amini
Date: 2023-10-03T18:07:46-07:00
New Revision: 5fc28e7a8d989fffafec2301529d0164a28799b2

URL: https://github.com/llvm/llvm-project/commit/5fc28e7a8d989fffafec2301529d0164a28799b2
DIFF: https://github.com/llvm/llvm-project/commit/5fc28e7a8d989fffafec2301529d0164a28799b2.diff

LOG: Improve MLIR Attribute::get() method efficiency by reducing the amount of argument copies (#68067)

This ensures that the proper std::forward/std::move calls are used; we go from 6
copy-constructions to 0 (!) on Attribute creation in release builds.

Added: 
    

Modified: 
    mlir/include/mlir/IR/StorageUniquerSupport.h
    mlir/include/mlir/Support/StorageUniquer.h
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestAttributes.h
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index c466e230d341d3e..982d5220ab52ce9 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -175,11 +175,11 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// function is guaranteed to return a non null object and will assert if
   /// the arguments provided are invalid.
   template <typename... Args>
-  static ConcreteT get(MLIRContext *ctx, Args... args) {
+  static ConcreteT get(MLIRContext *ctx, Args &&...args) {
     // Ensure that the invariants are correct for construction.
     assert(
         succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
-    return UniquerT::template get<ConcreteT>(ctx, args...);
+    return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx, defined at
@@ -187,8 +187,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// invalid, errors are emitted using the provided location and a null object
   /// is returned.
   template <typename... Args>
-  static ConcreteT getChecked(const Location &loc, Args... args) {
-    return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
+  static ConcreteT getChecked(const Location &loc, Args &&...args) {
+    return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc),
+                                 std::forward<Args>(args)...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx. If the arguments

diff  --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 13359bf91f40d17..2369d9cf90ff9dd 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -16,6 +16,7 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Allocator.h"
+#include <utility>
 
 namespace mlir {
 namespace detail {
@@ -207,7 +208,7 @@ class StorageUniquer {
 
     // Generate a constructor function for the derived storage.
     auto ctorFn = [&](StorageAllocator &allocator) {
-      auto *storage = Storage::construct(allocator, derivedKey);
+      auto *storage = Storage::construct(allocator, std::move(derivedKey));
       if (initFn)
         initFn(storage);
       return storage;
@@ -300,9 +301,9 @@ class StorageUniquer {
   static typename ImplTy::KeyTy getKey(Args &&...args) {
     if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy,
                                     Args...>::value)
-      return ImplTy::getKey(args...);
+      return ImplTy::getKey(std::forward<Args>(args)...);
     else
-      return typename ImplTy::KeyTy(args...);
+      return typename ImplTy::KeyTy(std::forward<Args>(args)...);
   }
 
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index ec0a5548a160338..945c54c04d47ce8 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -323,5 +323,16 @@ def Test_IteratorTypeArrayAttr
     : TypedArrayAttrBase<Test_IteratorTypeEnum,
   "Iterator type should be an enum.">;
 
+def TestParamCopyCount : AttrParameter<"CopyCount", "", "const CopyCount &"> {}
+
+// Test overriding attribute builders with a custom builder.
+def TestCopyCount : Test_Attr<"TestCopyCount"> {
+  let mnemonic = "copy_count";
+  let parameters = (ins TestParamCopyCount:$copy_count);
+  let assemblyFormat = "`<` $copy_count `>`";
+}
+
+
+
 
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 7fc2e6ab3ec0a0a..c240354e5d99044 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace test;
@@ -175,6 +176,45 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
   p << (*result ? "true" : "false");
 }
 
+//===----------------------------------------------------------------------===//
+// CopyCountAttr Implementation
+//===----------------------------------------------------------------------===//
+
+CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
+  CopyCount::counter++;
+}
+
+CopyCount &CopyCount::operator=(const CopyCount &rhs) {
+  CopyCount::counter++;
+  value = rhs.value;
+  return *this;
+}
+
+int CopyCount::counter;
+
+static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
+  return lhs.value == rhs.value;
+}
+
+llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
+                                    const test::CopyCount &value) {
+  return os << value.value;
+}
+
+template <>
+struct mlir::FieldParser<test::CopyCount> {
+  static FailureOr<test::CopyCount> parse(AsmParser &parser) {
+    std::string value;
+    if (parser.parseKeyword(value))
+      return failure();
+    return test::CopyCount(value);
+  }
+};
+namespace test {
+llvm::hash_code hash_value(const test::CopyCount &copyCount) {
+  return llvm::hash_value(copyCount.value);
+}
+} // namespace test
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index cc73e078bf7e20b..ef6eae51fdd628a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -29,6 +29,19 @@
 
 namespace test {
 class TestDialect;
+// Payload class for the CopyCountAttr.
+class CopyCount {
+public:
+  CopyCount(std::string value) : value(value) {}
+  CopyCount(const CopyCount &rhs);
+  CopyCount &operator=(const CopyCount &rhs);
+  CopyCount(CopyCount &&rhs) = default;
+  CopyCount &operator=(CopyCount &&rhs) = default;
+  static int counter;
+  std::string value;
+};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const test::CopyCount &value);
 
 /// A handle used to reference external elements instances.
 using TestDialectResourceBlobHandle =

diff  --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 020cd0ca65b691c..d7228368e71eb05 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -82,7 +82,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // Check that AttributeSelfTypeParameter is handled properly.
 // DEF-LABEL: struct CompoundAAttrStorage
 // DEF: CompoundAAttrStorage(
-// DEF-SAME: inner(inner)
+// DEF-SAME: inner(std::move(inner))
 
 // DEF: bool operator==(const KeyTy &tblgenKey) const {
 // DEF-NEXT: return
@@ -94,7 +94,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 
 // DEF: static CompoundAAttrStorage *construct
 // DEF: return new (allocator.allocate<CompoundAAttrStorage>())
-// DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
+// DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
 
 // DEF: ::mlir::Type CompoundAAttr::getInner() const {
 // DEF-NEXT: return getImpl()->inner;

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index f6e43d42d29f069..51f24de2442b957 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -353,7 +353,7 @@ void DefGen::emitDefaultBuilder() {
   MethodBody &body = m->body().indent();
   auto scope = body.scope("return Base::get(context", ");");
   for (const auto &param : params)
-    body << ", " << param.getName();
+    body << ", std::move(" << param.getName() << ")";
 }
 
 void DefGen::emitCheckedBuilder() {
@@ -474,8 +474,10 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
 void DefGen::emitStorageConstructor() {
   Constructor *ctor =
       storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
-  for (auto &param : params)
-    ctor->addMemberInitializer(param.getName(), param.getName());
+  for (auto &param : params) {
+    std::string movedValue = ("std::move(" + param.getName() + ")").str();
+    ctor->addMemberInitializer(param.getName(), movedValue);
+  }
 }
 
 void DefGen::emitKeyType() {
@@ -525,11 +527,11 @@ void DefGen::emitConstruct() {
                                         : Method::Static,
       MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
                       "allocator"),
-      MethodParameter("const KeyTy &", "tblgenKey"));
+      MethodParameter("KeyTy &&", "tblgenKey"));
   if (!def.hasStorageCustomConstructor()) {
     auto &body = construct->body().indent();
     for (const auto &it : llvm::enumerate(params)) {
-      body << formatv("auto {0} = std::get<{1}>(tblgenKey);\n",
+      body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
                       it.value().getName(), it.index());
     }
     // Use the parameters' custom allocator code, if provided.
@@ -544,8 +546,9 @@ void DefGen::emitConstruct() {
         body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
                           def.getStorageClassName()),
                    ");");
-    llvm::interleaveComma(params, body,
-                          [&](auto &param) { body << param.getName(); });
+    llvm::interleaveComma(params, body, [&](auto &param) {
+      body << "std::move(" << param.getName() << ")";
+    });
   }
 }
 

diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 9afbce037b408c0..20167825d4ca829 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -13,6 +13,8 @@
 #include "gtest/gtest.h"
 #include <optional>
 
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+
 using namespace mlir;
 using namespace mlir::detail;
 
@@ -459,4 +461,26 @@ TEST(SubElementTest, Nested) {
             ArrayRef<Attribute>(
                 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
 }
+
+// Test how many times we call copy-ctor when building an attribute.
+TEST(CopyCountAttr, CopyCount) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  test::CopyCount::counter = 0;
+  test::CopyCount copyCount("hello");
+  test::TestCopyCountAttr::get(&context, std::move(copyCount));
+  int counter1 = test::CopyCount::counter;
+  test::CopyCount::counter = 0;
+  test::TestCopyCountAttr::get(&context, std::move(copyCount));
+#ifndef NDEBUG
+  // One verification enabled only in assert-mode requires a copy.
+  EXPECT_EQ(counter1, 1);
+  EXPECT_EQ(test::CopyCount::counter, 1);
+#else
+  EXPECT_EQ(counter1, 0);
+  EXPECT_EQ(test::CopyCount::counter, 0);
+#endif
+}
+
 } // namespace


        


More information about the Mlir-commits mailing list