[Mlir-commits] [mlir] Improve MLIR attribute get() method efficiency when complex members are involved (PR #68067)
Mehdi Amini
llvmlistbot at llvm.org
Mon Oct 2 23:23:55 PDT 2023
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/68067
>From 10825ea55946acf64377b01e3439b6b3267d4572 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 2 Oct 2023 21:48:06 -0700
Subject: [PATCH] Improve MLIR attribute get() method efficiency when complex
members are involved
This ensures that the proper forward/move operations are involved: we go from 6
copy-constructions to 0 (!) when building without assertions.
---
mlir/include/mlir/IR/StorageUniquerSupport.h | 9 +++--
mlir/include/mlir/Support/StorageUniquer.h | 5 ++-
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 11 +++++
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 40 +++++++++++++++++++
mlir/test/lib/Dialect/Test/TestAttributes.h | 13 ++++++
mlir/test/mlir-tblgen/attrdefs.td | 4 +-
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 17 ++++----
mlir/unittests/IR/AttributeTest.cpp | 23 +++++++++++
8 files changed, 107 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index c466e230d341d3e..982d5220ab52ce9 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -175,11 +175,11 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// function is guaranteed to return a non null object and will assert if
/// the arguments provided are invalid.
template <typename... Args>
- static ConcreteT get(MLIRContext *ctx, Args... args) {
+ static ConcreteT get(MLIRContext *ctx, Args &&...args) {
// Ensure that the invariants are correct for construction.
assert(
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
- return UniquerT::template get<ConcreteT>(ctx, args...);
+ return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
}
/// Get or create a new ConcreteT instance within the ctx, defined at
@@ -187,8 +187,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// invalid, errors are emitted using the provided location and a null object
/// is returned.
template <typename... Args>
- static ConcreteT getChecked(const Location &loc, Args... args) {
- return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
+ static ConcreteT getChecked(const Location &loc, Args &&...args) {
+ return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc),
+ std::forward<Args>(args)...);
}
/// Get or create a new ConcreteT instance within the ctx. If the arguments
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h
index 13359bf91f40d17..baaedc47dcb2cd5 100644
--- a/mlir/include/mlir/Support/StorageUniquer.h
+++ b/mlir/include/mlir/Support/StorageUniquer.h
@@ -16,6 +16,7 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Allocator.h"
+#include <utility>
namespace mlir {
namespace detail {
@@ -300,9 +301,9 @@ class StorageUniquer {
static typename ImplTy::KeyTy getKey(Args &&...args) {
if constexpr (llvm::is_detected<detail::has_impltype_getkey_t, ImplTy,
Args...>::value)
- return ImplTy::getKey(args...);
+ return ImplTy::getKey(std::forward<Args>(args)...);
else
- return typename ImplTy::KeyTy(args...);
+ return typename ImplTy::KeyTy(std::forward<Args>(args)...);
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index ec0a5548a160338..945c54c04d47ce8 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -323,5 +323,16 @@ def Test_IteratorTypeArrayAttr
: TypedArrayAttrBase<Test_IteratorTypeEnum,
"Iterator type should be an enum.">;
+def TestParamCopyCount : AttrParameter<"CopyCount", "", "const CopyCount &"> {}
+
+// Test overriding attribute builders with a custom builder.
+def TestCopyCount : Test_Attr<"TestCopyCount"> {
+ let mnemonic = "copy_count";
+ let parameters = (ins TestParamCopyCount:$copy_count);
+ let assemblyFormat = "`<` $copy_count `>`";
+}
+
+
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 7fc2e6ab3ec0a0a..c240354e5d99044 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -22,6 +22,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace test;
@@ -175,6 +176,45 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
p << (*result ? "true" : "false");
}
+//===----------------------------------------------------------------------===//
+// CopyCountAttr Implementation
+//===----------------------------------------------------------------------===//
+
+CopyCount::CopyCount(const CopyCount &rhs) : value(rhs.value) {
+ CopyCount::counter++;
+}
+
+CopyCount &CopyCount::operator=(const CopyCount &rhs) {
+ CopyCount::counter++;
+ value = rhs.value;
+ return *this;
+}
+
+int CopyCount::counter;
+
+static bool operator==(const test::CopyCount &lhs, const test::CopyCount &rhs) {
+ return lhs.value == rhs.value;
+}
+
+llvm::raw_ostream &test::operator<<(llvm::raw_ostream &os,
+ const test::CopyCount &value) {
+ return os << value.value;
+}
+
+template <>
+struct mlir::FieldParser<test::CopyCount> {
+ static FailureOr<test::CopyCount> parse(AsmParser &parser) {
+ std::string value;
+ if (parser.parseKeyword(value))
+ return failure();
+ return test::CopyCount(value);
+ }
+};
+namespace test {
+llvm::hash_code hash_value(const test::CopyCount &copyCount) {
+ return llvm::hash_value(copyCount.value);
+}
+} // namespace test
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index cc73e078bf7e20b..ef6eae51fdd628a 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -29,6 +29,19 @@
namespace test {
class TestDialect;
+// Payload class for the CopyCountAttr.
+class CopyCount {
+public:
+ CopyCount(std::string value) : value(value) {}
+ CopyCount(const CopyCount &rhs);
+ CopyCount &operator=(const CopyCount &rhs);
+ CopyCount(CopyCount &&rhs) = default;
+ CopyCount &operator=(CopyCount &&rhs) = default;
+ static int counter;
+ std::string value;
+};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ const test::CopyCount &value);
/// A handle used to reference external elements instances.
using TestDialectResourceBlobHandle =
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 020cd0ca65b691c..d7228368e71eb05 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -82,7 +82,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// Check that AttributeSelfTypeParameter is handled properly.
// DEF-LABEL: struct CompoundAAttrStorage
// DEF: CompoundAAttrStorage(
-// DEF-SAME: inner(inner)
+// DEF-SAME: inner(std::move(inner))
// DEF: bool operator==(const KeyTy &tblgenKey) const {
// DEF-NEXT: return
@@ -94,7 +94,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// DEF: static CompoundAAttrStorage *construct
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
-// DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner);
+// DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
// DEF: ::mlir::Type CompoundAAttr::getInner() const {
// DEF-NEXT: return getImpl()->inner;
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index f6e43d42d29f069..f14d33c7d13d310 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -353,7 +353,7 @@ void DefGen::emitDefaultBuilder() {
MethodBody &body = m->body().indent();
auto scope = body.scope("return Base::get(context", ");");
for (const auto &param : params)
- body << ", " << param.getName();
+ body << ", std::move(" << param.getName() << ")";
}
void DefGen::emitCheckedBuilder() {
@@ -474,8 +474,10 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
void DefGen::emitStorageConstructor() {
Constructor *ctor =
storageCls->addConstructor<Method::Inline>(getBuilderParams({}));
- for (auto &param : params)
- ctor->addMemberInitializer(param.getName(), param.getName());
+ for (auto &param : params) {
+ std::string movedValue = ("std::move(" + param.getName() + ")").str();
+ ctor->addMemberInitializer(param.getName(), movedValue);
+ }
}
void DefGen::emitKeyType() {
@@ -525,11 +527,11 @@ void DefGen::emitConstruct() {
: Method::Static,
MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType),
"allocator"),
- MethodParameter("const KeyTy &", "tblgenKey"));
+ MethodParameter("KeyTy &", "tblgenKey"));
if (!def.hasStorageCustomConstructor()) {
auto &body = construct->body().indent();
for (const auto &it : llvm::enumerate(params)) {
- body << formatv("auto {0} = std::get<{1}>(tblgenKey);\n",
+ body << formatv("auto {0} = std::move(std::get<{1}>(tblgenKey));\n",
it.value().getName(), it.index());
}
// Use the parameters' custom allocator code, if provided.
@@ -544,8 +546,9 @@ void DefGen::emitConstruct() {
body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(",
def.getStorageClassName()),
");");
- llvm::interleaveComma(params, body,
- [&](auto &param) { body << param.getName(); });
+ llvm::interleaveComma(params, body, [&](auto &param) {
+ body << "std::move(" << param.getName() << ")";
+ });
}
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 9afbce037b408c0..6307a10bad4cd93 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -13,6 +13,8 @@
#include "gtest/gtest.h"
#include <optional>
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+
using namespace mlir;
using namespace mlir::detail;
@@ -459,4 +461,25 @@ TEST(SubElementTest, Nested) {
ArrayRef<Attribute>(
{strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
}
+
+// Test how many times we call copy-ctor when building an attribute.
+TEST(CopyCountAttr, CopyCount) {
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+
+ test::CopyCount::counter = 0;
+ test::CopyCount copyCount("hello");
+ test::TestCopyCountAttr::get(&context, std::move(copyCount));
+ int counter1 = test::CopyCount::counter;
+ test::CopyCount::counter = 0;
+ test::TestCopyCountAttr::get(&context, std::move(copyCount));
+#ifndef NDEBUG
+ EXPECT_EQ(counter1, 1);
+ EXPECT_EQ(test::CopyCount::counter, 1);
+#else
+ EXPECT_EQ(counter1, 0);
+ EXPECT_EQ(test::CopyCount::counter, 0);
+#endif
+}
+
} // namespace
More information about the Mlir-commits
mailing list