[Mlir-commits] [mlir] a8daefe - Lock the MLIR TypeConverter cache management to make it thread-safe (NFC)

Mehdi Amini llvmlistbot at llvm.org
Sun Aug 27 16:58:31 PDT 2023


Author: Mehdi Amini
Date: 2023-08-27T16:45:33-07:00
New Revision: a8daefed341f56a28cec3a83c86b2e65471aae88

URL: https://github.com/llvm/llvm-project/commit/a8daefed341f56a28cec3a83c86b2e65471aae88
DIFF: https://github.com/llvm/llvm-project/commit/a8daefed341f56a28cec3a83c86b2e65471aae88.diff

LOG: Lock the MLIR TypeConverter cache management to make it thread-safe (NFC)

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158354

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 0fce86bae412c2..6de981d35c8c3a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -38,6 +38,22 @@ class Value;
 class TypeConverter {
 public:
   virtual ~TypeConverter() = default;
+  TypeConverter() = default;
+  // Copy the registered conversions, but not the caches
+  TypeConverter(const TypeConverter &other)
+      : conversions(other.conversions),
+        argumentMaterializations(other.argumentMaterializations),
+        sourceMaterializations(other.sourceMaterializations),
+        targetMaterializations(other.targetMaterializations),
+        typeAttributeConversions(other.typeAttributeConversions) {}
+  TypeConverter &operator=(const TypeConverter &other) {
+    conversions = other.conversions;
+    argumentMaterializations = other.argumentMaterializations;
+    sourceMaterializations = other.sourceMaterializations;
+    targetMaterializations = other.targetMaterializations;
+    typeAttributeConversions = other.typeAttributeConversions;
+    return *this;
+  }
 
   /// This class provides all of the information necessary to convert a type
   /// signature.
@@ -421,6 +437,8 @@ class TypeConverter {
   mutable DenseMap<Type, Type> cachedDirectConversions;
   /// This cache stores the successful 1->N conversions, where N != 1.
   mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
+  /// A mutex used for cache access
+  mutable llvm::sys::SmartRWMutex<true> cacheMutex;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5ce3281b51a1eb..3fcf32094414b0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2920,24 +2920,34 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
-  auto existingIt = cachedDirectConversions.find(t);
-  if (existingIt != cachedDirectConversions.end()) {
-    if (existingIt->second)
-      results.push_back(existingIt->second);
-    return success(existingIt->second != nullptr);
-  }
-  auto multiIt = cachedMultiConversions.find(t);
-  if (multiIt != cachedMultiConversions.end()) {
-    results.append(multiIt->second.begin(), multiIt->second.end());
-    return success();
+  {
+    std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
+                                                         std::defer_lock);
+    if (t.getContext()->isMultithreadingEnabled())
+      cacheReadLock.lock();
+    auto existingIt = cachedDirectConversions.find(t);
+    if (existingIt != cachedDirectConversions.end()) {
+      if (existingIt->second)
+        results.push_back(existingIt->second);
+      return success(existingIt->second != nullptr);
+    }
+    auto multiIt = cachedMultiConversions.find(t);
+    if (multiIt != cachedMultiConversions.end()) {
+      results.append(multiIt->second.begin(), multiIt->second.end());
+      return success();
+    }
   }
-
   // Walk the added converters in reverse order to apply the most recently
   // registered first.
   size_t currentCount = results.size();
 
+  std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex,
+                                                        std::defer_lock);
+
   for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
     if (std::optional<LogicalResult> result = converter(t, results)) {
+      if (t.getContext()->isMultithreadingEnabled())
+        cacheWriteLock.lock();
       if (!succeeded(*result)) {
         cachedDirectConversions.try_emplace(t, nullptr);
         return failure();


        


More information about the Mlir-commits mailing list