[Mlir-commits] [mlir] a8daefe - Lock the MLIR TypeConverter caches management to make it thread-safe (NFC)
Mehdi Amini
llvmlistbot at llvm.org
Sun Aug 27 16:58:31 PDT 2023
Author: Mehdi Amini
Date: 2023-08-27T16:45:33-07:00
New Revision: a8daefed341f56a28cec3a83c86b2e65471aae88
URL: https://github.com/llvm/llvm-project/commit/a8daefed341f56a28cec3a83c86b2e65471aae88
DIFF: https://github.com/llvm/llvm-project/commit/a8daefed341f56a28cec3a83c86b2e65471aae88.diff
LOG: Lock the MLIR TypeConverter caches management to make it thread-safe (NFC)
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D158354
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 0fce86bae412c2..6de981d35c8c3a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -38,6 +38,22 @@ class Value;
class TypeConverter {
public:
virtual ~TypeConverter() = default;
+ TypeConverter() = default;
+ // Copy the registered conversions, but not the caches
+ TypeConverter(const TypeConverter &other)
+ : conversions(other.conversions),
+ argumentMaterializations(other.argumentMaterializations),
+ sourceMaterializations(other.sourceMaterializations),
+ targetMaterializations(other.targetMaterializations),
+ typeAttributeConversions(other.typeAttributeConversions) {}
+ TypeConverter &operator=(const TypeConverter &other) {
+ conversions = other.conversions;
+ argumentMaterializations = other.argumentMaterializations;
+ sourceMaterializations = other.sourceMaterializations;
+ targetMaterializations = other.targetMaterializations;
+ typeAttributeConversions = other.typeAttributeConversions;
+ return *this;
+ }
/// This class provides all of the information necessary to convert a type
/// signature.
@@ -421,6 +437,8 @@ class TypeConverter {
mutable DenseMap<Type, Type> cachedDirectConversions;
/// This cache stores the successful 1->N conversions, where N != 1.
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
+ /// A mutex used for cache access
+ mutable llvm::sys::SmartRWMutex<true> cacheMutex;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5ce3281b51a1eb..3fcf32094414b0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2920,24 +2920,34 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
- auto existingIt = cachedDirectConversions.find(t);
- if (existingIt != cachedDirectConversions.end()) {
- if (existingIt->second)
- results.push_back(existingIt->second);
- return success(existingIt->second != nullptr);
- }
- auto multiIt = cachedMultiConversions.find(t);
- if (multiIt != cachedMultiConversions.end()) {
- results.append(multiIt->second.begin(), multiIt->second.end());
- return success();
+ {
+ std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
+ std::defer_lock);
+ if (t.getContext()->isMultithreadingEnabled())
+ cacheReadLock.lock();
+ auto existingIt = cachedDirectConversions.find(t);
+ if (existingIt != cachedDirectConversions.end()) {
+ if (existingIt->second)
+ results.push_back(existingIt->second);
+ return success(existingIt->second != nullptr);
+ }
+ auto multiIt = cachedMultiConversions.find(t);
+ if (multiIt != cachedMultiConversions.end()) {
+ results.append(multiIt->second.begin(), multiIt->second.end());
+ return success();
+ }
}
-
// Walk the added converters in reverse order to apply the most recently
// registered first.
size_t currentCount = results.size();
+ std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
+ std::defer_lock);
+
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
if (std::optional<LogicalResult> result = converter(t, results)) {
+ if (t.getContext()->isMultithreadingEnabled())
+ cacheWriteLock.lock();
if (!succeeded(*result)) {
cachedDirectConversions.try_emplace(t, nullptr);
return failure();
More information about the Mlir-commits
mailing list