[flang-commits] [flang] dc3dc97 - Remove the `conversionCallStack` from the MLIR TypeConverter

Mehdi Amini via flang-commits flang-commits at lists.llvm.org
Sun Aug 27 16:14:40 PDT 2023


Author: Mehdi Amini
Date: 2023-08-27T16:14:31-07:00
New Revision: dc3dc97410af5a298f9374da1f8ca797c43f5f08

URL: https://github.com/llvm/llvm-project/commit/dc3dc97410af5a298f9374da1f8ca797c43f5f08
DIFF: https://github.com/llvm/llvm-project/commit/dc3dc97410af5a298f9374da1f8ca797c43f5f08.diff

LOG: Remove the `conversionCallStack` from the MLIR TypeConverter

This vector keeps track of recursive types through the recursive invocations
of `convertType()`. However, this is only useful in some specific cases, in
which the dedicated conversion callbacks can maintain this stack privately.

This allows removing a mutable member of the type converter.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158351

Added: 
    

Modified: 
    flang/include/flang/Optimizer/CodeGen/TypeConverter.h
    flang/lib/Optimizer/CodeGen/TypeConverter.cpp
    mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index f42c40eb68902b..5b8b51b5a30bc8 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -57,8 +57,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
   // fir.type<name(p : TY'...){f : TY...}>  -->  llvm<"%name = { ty... }">
   std::optional<mlir::LogicalResult>
   convertRecordType(fir::RecordType derived,
-                    llvm::SmallVectorImpl<mlir::Type> &results,
-                    llvm::ArrayRef<mlir::Type> callStack) const;
+                    llvm::SmallVectorImpl<mlir::Type> &results);
 
   // Is an extended descriptor needed given the element type of a fir.box type ?
   // Extended descriptors are required for derived types.

diff  --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index fd5f0c7135fea2..77e94c00ec0792 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -21,6 +21,7 @@
 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 
 namespace fir {
@@ -81,11 +82,10 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)
   });
   addConversion(
       [&](fir::PointerType pointer) { return convertPointerLike(pointer); });
-  addConversion([&](fir::RecordType derived,
-                    llvm::SmallVectorImpl<mlir::Type> &results,
-                    llvm::ArrayRef<mlir::Type> callStack) {
-    return convertRecordType(derived, results, callStack);
-  });
+  addConversion(
+      [&](fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
+        return convertRecordType(derived, results);
+      });
   addConversion(
       [&](fir::RealType real) { return convertRealType(real.getFKind()); });
   addConversion(
@@ -167,14 +167,19 @@ mlir::Type LLVMTypeConverter::indexType() const {
 
 // fir.type<name(p : TY'...){f : TY...}>  -->  llvm<"%name = { ty... }">
 std::optional<mlir::LogicalResult> LLVMTypeConverter::convertRecordType(
-    fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results,
-    llvm::ArrayRef<mlir::Type> callStack) const {
+    fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
   auto name = derived.getName();
   auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
-  if (llvm::count(callStack, derived) > 1) {
+
+  auto &callStack = getCurrentThreadRecursiveStack();
+  if (llvm::count(callStack, derived)) {
     results.push_back(st);
     return mlir::success();
   }
+  callStack.push_back(derived);
+  auto popConversionCallStack =
+      llvm::make_scope_exit([&callStack]() { callStack.pop_back(); });
+
   llvm::SmallVector<mlir::Type> members;
   for (auto mem : derived.getTypeList()) {
     // Prevent fir.box from degenerating to a pointer to a descriptor in the

diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 94a2c7c2c52ae5..c7128b4d233dfb 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -162,6 +162,12 @@ class LLVMTypeConverter : public TypeConverter {
   /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
 
+  // Recursive structure detection: one conversion stack per thread, keyed by
+  // thread id; accesses to the map are synchronized with callStackMutex.
+  DenseMap<uint64_t, std::unique_ptr<SmallVector<Type>>> conversionCallStack;
+  llvm::sys::SmartRWMutex<true> callStackMutex;
+  SmallVector<Type> &getCurrentThreadRecursiveStack();
+
 private:
   /// Convert a function type.  The arguments and results are converted one by
   /// one.  Additionally, if the function returns more than one value, pack the

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6e11c3ed0a0179..0fce86bae412c2 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -307,7 +307,7 @@ class TypeConverter {
   /// types is empty, the type is removed and any usages of the existing value
   /// are expected to be removed during conversion.
   using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
-      Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
+      Type, SmallVectorImpl<Type> &)>;
 
   /// The signature of the callback used to materialize a conversion.
   using MaterializationCallbackFn = std::function<std::optional<Value>(
@@ -330,44 +330,30 @@ class TypeConverter {
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
   wrapCallback(FnT &&callback) const {
-    return wrapCallback<T>(
-        [callback = std::forward<FnT>(callback)](
-            T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
-          if (std::optional<Type> resultOpt = callback(type)) {
-            bool wasSuccess = static_cast<bool>(*resultOpt);
-            if (wasSuccess)
-              results.push_back(*resultOpt);
-            return std::optional<LogicalResult>(success(wasSuccess));
-          }
-          return std::optional<LogicalResult>();
-        });
+    return wrapCallback<T>([callback = std::forward<FnT>(callback)](
+                               T type, SmallVectorImpl<Type> &results) {
+      if (std::optional<Type> resultOpt = callback(type)) {
+        bool wasSuccess = static_cast<bool>(*resultOpt);
+        if (wasSuccess)
+          results.push_back(*resultOpt);
+        return std::optional<LogicalResult>(success(wasSuccess));
+      }
+      return std::optional<LogicalResult>();
+    });
   }
   /// With callback of form: `std::optional<LogicalResult>(
-  ///     T, SmallVectorImpl<Type> &)`.
+  ///     T, SmallVectorImpl<Type> &)`.
   template <typename T, typename FnT>
   std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
                    ConversionCallbackFn>
-  wrapCallback(FnT &&callback) const {
-    return wrapCallback<T>(
-        [callback = std::forward<FnT>(callback)](
-            T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
-          return callback(type, results);
-        });
-  }
-  /// With callback of form: `std::optional<LogicalResult>(
-  ///     T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
-  template <typename T, typename FnT>
-  std::enable_if_t<
-      std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, ArrayRef<Type>>,
-      ConversionCallbackFn>
   wrapCallback(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
-               Type type, SmallVectorImpl<Type> &results,
-               ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
+               Type type,
+               SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
       T derivedType = dyn_cast<T>(type);
       if (!derivedType)
         return std::nullopt;
-      return callback(derivedType, results, callStack);
+      return callback(derivedType, results);
     };
   }
 
@@ -435,10 +421,6 @@ class TypeConverter {
   mutable DenseMap<Type, Type> cachedDirectConversions;
   /// This cache stores the successful 1->N conversions, where N != 1.
   mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
-
-  /// Stores the types that are being converted in the case when convertType
-  /// is being called recursively to convert nested types.
-  mutable SmallVector<Type, 2> conversionCallStack;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index b0842b9972c76d..2d2a2bff7013c6 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -11,10 +11,34 @@
 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/Threading.h"
+#include <memory>
+#include <mutex>
 #include <optional>
 
 using namespace mlir;
 
+SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
+  {
+    // Fast path: this thread's stack usually exists; a shared lock suffices.
+    std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
+                                                    std::defer_lock);
+    if (getContext().isMultithreadingEnabled()) // no lock if single-threaded
+      lock.lock();
+    auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
+    if (recursiveStack != conversionCallStack.end())
+      return *recursiveStack->second;
+  } // Release the shared lock before acquiring the exclusive one below.
+
+  // Slow path (first call on this thread): insert under an exclusive lock.
+  // Stacks are owned via unique_ptr, so the reference survives map rehashing.
+  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
+  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
+      llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
+  return *recursiveStackInserted.first->second.get();
+}
+
 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const DataLayoutAnalysis *analysis)
@@ -56,8 +80,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
     return std::nullopt;
   });
-  addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
-                    ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
+
+  addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
+                    -> std::optional<LogicalResult> {
     // Fastpath for types that won't be converted by this callback anyway.
     if (LLVM::isCompatibleType(type)) {
       results.push_back(type);
@@ -75,10 +100,15 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
             type.getContext(),
             ("_Converted_" + std::to_string(counter) + type.getName()).str());
       }
-      if (llvm::count(callStack, type) > 1) {
+
+      SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
+      if (llvm::count(recursiveStack, type)) {
         results.push_back(convertedType);
         return success();
       }
+      recursiveStack.push_back(type);
+      auto popConversionCallStack = llvm::make_scope_exit(
+          [&recursiveStack]() { recursiveStack.pop_back(); });
 
       SmallVector<Type> convertedElemTypes;
       convertedElemTypes.reserve(type.getBody().size());

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a345f8d5b5be49..5ce3281b51a1eb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2935,12 +2935,9 @@ LogicalResult TypeConverter::convertType(Type t,
   // Walk the added converters in reverse order to apply the most recently
   // registered first.
   size_t currentCount = results.size();
-  conversionCallStack.push_back(t);
-  auto popConversionCallStack =
-      llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
+
   for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
-    if (std::optional<LogicalResult> result =
-            converter(t, results, conversionCallStack)) {
+    if (std::optional<LogicalResult> result = converter(t, results)) {
       if (!succeeded(*result)) {
         cachedDirectConversions.try_emplace(t, nullptr);
         return failure();

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 30ed4109ad8bd1..47afe6d2eecb6d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ScopeExit.h"
 
 using namespace mlir;
 using namespace test;
@@ -1374,6 +1375,7 @@ struct TestTypeConversionDriver
 
   void runOnOperation() override {
     // Initialize the type converter.
+    SmallVector<Type, 2> conversionCallStack;
     TypeConverter converter;
 
     /// Add the legal set of type conversions.
@@ -1394,8 +1396,8 @@ struct TestTypeConversionDriver
     converter.addConversion(
         // Convert a recursive self-referring type into a non-self-referring
         // type named "outer_converted_type" that contains a SimpleAType.
-        [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
-            ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
+        [&](test::TestRecursiveType type,
+            SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
           // If the type is already converted, return it to indicate that it is
           // legal.
           if (type.getName() == "outer_converted_type") {
@@ -1403,11 +1405,16 @@ struct TestTypeConversionDriver
             return success();
           }
 
+          conversionCallStack.push_back(type);
+          auto popConversionCallStack = llvm::make_scope_exit(
+              [&conversionCallStack]() { conversionCallStack.pop_back(); });
+
           // If the type is on the call stack more than once (it is there at
           // least once because of the _current_ call, which is always the last
           // element on the stack), we've hit the recursive case. Just return
           // SimpleAType here to create a non-recursive type as a result.
-          if (llvm::is_contained(callStack.drop_back(), type)) {
+          if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
+                                 type)) {
             results.push_back(test::SimpleAType::get(type.getContext()));
             return success();
           }


        


More information about the flang-commits mailing list