[Mlir-commits] [mlir] Improve MLIR attribute get() method efficiency when complex members are involved (PR #68067)

Mehdi Amini llvmlistbot at llvm.org
Mon Oct 2 22:33:15 PDT 2023


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/68067

This ensures that std::forward/std::move are used where appropriate: we go from 6 copy-constructions to 0 (!) in builds without assertions.

>From 8c0575424dc027259b1b1911fa2a40908a435673 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 2 Oct 2023 21:48:06 -0700
Subject: [PATCH] Improve MLIR attribute get() method efficiency when complex
 members are involved

This ensures that std::forward/std::move are used where appropriate: we go
from 6 copy-constructions to 0 (!) in builds without assertions.
---
 mlir/include/mlir/IR/StorageUniquerSupport.h  |  9 +++--
 mlir/include/mlir/Support/StorageUniquer.h    |  5 ++-
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    | 11 +++++
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 40 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestAttributes.h   | 13 ++++++
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp   | 17 ++++----
 mlir/unittests/IR/AttributeTest.cpp           | 23 +++++++++++
 7 files changed, 105 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index c466e230d341d3e..982d5220ab52ce9 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -175,11 +175,11 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// function is guaranteed to return a non null object and will assert if
   /// the arguments provided are invalid.
   template <typename... Args>
-  static ConcreteT get(MLIRContext *ctx, Args... args) {
+  static ConcreteT get(MLIRContext *ctx, Args &&...args) {
     // Ensure that the invariants are correct for construction.
     assert(
         succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
-    return UniquerT::template get<ConcreteT>(ctx, args...);
+    return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx, defined at
@@ -187,8 +187,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   /// invalid, errors are emitted using the provided location and a null object
   /// is returned.
   template <typename... Args>
-  static ConcreteT getChecked(const Location &loc, Args... args) {
-    return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
+  static ConcreteT getChecked(const Location &loc, Args &&...args) {
+    return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc),
+                                 std::forward<Args>(args)...);
   }
 
   /// Get or create a new ConcreteT instance within the ctx. If the arguments
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 13359bf91f40d17..baaedc47dcb2cd5 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -16,6 +16,7 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Allocator.h"
+#include <utility>
 
 namespace mlir {
 namespace detail {
@@ -300,9 +301,9 @@ class StorageUniquer {
   static typename ImplTy::KeyTy getKey(Args &&...args) {
     if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy,
                                     Args...>::value)
-      return ImplTy::getKey(args...);
+      return ImplTy::getKey(std::forward<Args>(args)...);
     else
-      return typename ImplTy::KeyTy(args...);
+      return typename ImplTy::KeyTy(std::forward<Args>(args)...);
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index ec0a5548a160338..945c54c04d47ce8 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -323,5 +323,16 @@ def Test_IteratorTypeArrayAttr
     : TypedArrayAttrBase<Test_IteratorTypeEnum,
   "Iterator type should be an enum.">;
 
+def TestParamCopyCount : AttrParameter<"CopyCount", "", "const CopyCount &"> {}
+
+// Attribute wrapping a CopyCount payload, used to count parameter copies.
+def TestCopyCount : Test_Attr<"TestCopyCount"> {
+  let mnemonic = "copy_count";
+  let parameters = (ins TestParamCopyCount:$copy_count);
+  let assemblyFormat = "`<` $copy_count `>`";
+}
+
+
+
 
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 7fc2e6ab3ec0a0a..c240354e5d99044 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace test;
@@ -175,6 +176,45 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
   p << (*result ? "true" : "false");
 }
 
+//===----------------------------------------------------------------------===//
+// CopyCountAttr Implementation
+//===----------------------------------------------------------------------===//
+
+CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
+  CopyCount::counter++;
+}
+
+CopyCount &CopyCount::operator=(const CopyCount &rhs) {
+  CopyCount::counter++;
+  value = rhs.value;
+  return *this;
+}
+
+int CopyCount::counter;
+
+static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
+  return lhs.value == rhs.value;
+}
+
+llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
+                                    const test::CopyCount &value) {
+  return os << value.value;
+}
+
+template <>
+struct mlir::FieldParser<test::CopyCount> {
+  static FailureOr<test::CopyCount> parse(AsmParser &parser) {
+    StringRef value;
+    if (parser.parseKeyword(&value))
+      return failure();
+    return test::CopyCount(value.str());
+  }
+};
+namespace test {
+llvm::hash_code hash_value(const test::CopyCount &copyCount) {
+  return llvm::hash_value(copyCount.value);
+}
+} // namespace test
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index cc73e078bf7e20b..ef6eae51fdd628a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -29,6 +29,19 @@
 
 namespace test {
 class TestDialect;
+// Payload for TestCopyCountAttr; copies (but not moves) bump `counter`.
+class CopyCount {
+public:
+  CopyCount(std::string value) : value(std::move(value)) {}
+  CopyCount(const CopyCount &rhs);
+  CopyCount &operator=(const CopyCount &rhs);
+  CopyCount(CopyCount &&rhs) = default;
+  CopyCount &operator=(CopyCount &&rhs) = default;
+  static int counter;
+  std::string value;
+};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const test::CopyCount &value);
 
 /// A handle used to reference external elements instances.
 using TestDialectResourceBlobHandle =
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index f6e43d42d29f069..f14d33c7d13d310 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -353,7 +353,7 @@ void DefGen::emitDefaultBuilder() {
   MethodBody &body = m->body().indent();
   auto scope = body.scope("return Base::get(context", ");");
   for (const auto &param : params)
-    body << ", " << param.getName();
+    body << ", std::move(" << param.getName() << ")";
 }
 
 void DefGen::emitCheckedBuilder() {
@@ -474,8 +474,10 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
 void DefGen::emitStorageConstructor() {
   Constructor *ctor =
       storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
-  for (auto &param : params)
-    ctor->addMemberInitializer(param.getName(), param.getName());
+  for (auto &param : params) {
+    std::string movedValue = ("std::move(" + param.getName() + ")").str();
+    ctor->addMemberInitializer(param.getName(), movedValue);
+  }
 }
 
 void DefGen::emitKeyType() {
@@ -525,11 +527,11 @@ void DefGen::emitConstruct() {
                                         : Method::Static,
       MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
                       "allocator"),
-      MethodParameter("const KeyTy &", "tblgenKey"));
+      MethodParameter("KeyTy &", "tblgenKey"));
   if (!def.hasStorageCustomConstructor()) {
     auto &body = construct->body().indent();
     for (const auto &it : llvm::enumerate(params)) {
-      body << formatv("auto {0} = std::get<{1}>(tblgenKey);\n",
+      body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
                       it.value().getName(), it.index());
     }
     // Use the parameters' custom allocator code, if provided.
@@ -544,8 +546,9 @@ void DefGen::emitConstruct() {
         body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
                           def.getStorageClassName()),
                    ");");
-    llvm::interleaveComma(params, body,
-                          [&](auto &param) { body << param.getName(); });
+    llvm::interleaveComma(params, body, [&](auto &param) {
+      body << "std::move(" << param.getName() << ")";
+    });
   }
 }
 
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 9afbce037b408c0..6307a10bad4cd93 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -13,6 +13,8 @@
 #include "gtest/gtest.h"
 #include <optional>
 
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+
 using namespace mlir;
 using namespace mlir::detail;
 
@@ -459,4 +461,25 @@ TEST(SubElementTest, Nested) {
             ArrayRef<Attribute>(
                 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
 }
+
+// Count CopyCount copies made when building an attribute and re-getting it.
+TEST(CopyCountAttr, CopyCount) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  test::CopyCount::counter = 0;
+  test::CopyCount copyCount("hello");
+  test::TestCopyCountAttr::get(&context, std::move(copyCount));
+  int counter1 = test::CopyCount::counter;
+  test::CopyCount::counter = 0;
+  test::TestCopyCountAttr::get(&context, test::CopyCount("hello"));
+#ifndef NDEBUG
+  EXPECT_EQ(counter1, 1);
+  EXPECT_EQ(test::CopyCount::counter, 1);
+#else
+  EXPECT_EQ(counter1, 0);
+  EXPECT_EQ(test::CopyCount::counter, 0);
+#endif
+}
+
 } // namespace



More information about the Mlir-commits mailing list