[Mlir-commits] [mlir] dc3dc97 - Remove the `conversionCallStack` from the MLIR TypeConverter
Mehdi Amini
llvmlistbot at llvm.org
Sun Aug 27 16:14:42 PDT 2023
Author: Mehdi Amini
Date: 2023-08-27T16:14:31-07:00
New Revision: dc3dc97410af5a298f9374da1f8ca797c43f5f08
URL: https://github.com/llvm/llvm-project/commit/dc3dc97410af5a298f9374da1f8ca797c43f5f08
DIFF: https://github.com/llvm/llvm-project/commit/dc3dc97410af5a298f9374da1f8ca797c43f5f08.diff
LOG: Remove the `conversionCallStack` from the MLIR TypeConverter
This vector keeps track of recursive types through the recursive invocations
of `convertType()`. However, this is only useful for some specific
cases, in which the dedicated conversion callbacks can handle this stack
privately.
This allows removing a mutable member of the type converter.
Reviewed By: springerm
Differential Revision: https://reviews.llvm.org/D158351
Added:
Modified:
flang/include/flang/Optimizer/CodeGen/TypeConverter.h
flang/lib/Optimizer/CodeGen/TypeConverter.cpp
mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index f42c40eb68902b..5b8b51b5a30bc8 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -57,8 +57,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<mlir::LogicalResult>
convertRecordType(fir::RecordType derived,
- llvm::SmallVectorImpl<mlir::Type> &results,
- llvm::ArrayRef<mlir::Type> callStack) const;
+ llvm::SmallVectorImpl<mlir::Type> &results);
// Is an extended descriptor needed given the element type of a fir.box type ?
// Extended descriptors are required for derived types.
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index fd5f0c7135fea2..77e94c00ec0792 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -21,6 +21,7 @@
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
namespace fir {
@@ -81,11 +82,10 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA)
});
addConversion(
[&](fir::PointerType pointer) { return convertPointerLike(pointer); });
- addConversion([&](fir::RecordType derived,
- llvm::SmallVectorImpl<mlir::Type> &results,
- llvm::ArrayRef<mlir::Type> callStack) {
- return convertRecordType(derived, results, callStack);
- });
+ addConversion(
+ [&](fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
+ return convertRecordType(derived, results);
+ });
addConversion(
[&](fir::RealType real) { return convertRealType(real.getFKind()); });
addConversion(
@@ -167,14 +167,19 @@ mlir::Type LLVMTypeConverter::indexType() const {
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<mlir::LogicalResult> LLVMTypeConverter::convertRecordType(
- fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results,
- llvm::ArrayRef<mlir::Type> callStack) const {
+ fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
auto name = derived.getName();
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
- if (llvm::count(callStack, derived) > 1) {
+
+ auto &callStack = getCurrentThreadRecursiveStack();
+ if (llvm::count(callStack, derived)) {
results.push_back(st);
return mlir::success();
}
+ callStack.push_back(derived);
+ auto popConversionCallStack =
+ llvm::make_scope_exit([&callStack]() { callStack.pop_back(); });
+
llvm::SmallVector<mlir::Type> members;
for (auto mem : derived.getTypeList()) {
// Prevent fir.box from degenerating to a pointer to a descriptor in the
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 94a2c7c2c52ae5..c7128b4d233dfb 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -162,6 +162,12 @@ class LLVMTypeConverter : public TypeConverter {
/// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;
+ // Recursive structure detection.
+ // We store one entry per thread here, and rely on locking.
+ DenseMap<uint64_t, std::unique_ptr<SmallVector<Type>>> conversionCallStack;
+ llvm::sys::SmartRWMutex<true> callStackMutex;
+ SmallVector<Type> &getCurrentThreadRecursiveStack();
+
private:
/// Convert a function type. The arguments and results are converted one by
/// one. Additionally, if the function returns more than one value, pack the
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 6e11c3ed0a0179..0fce86bae412c2 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -307,7 +307,7 @@ class TypeConverter {
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
- Type, SmallVectorImpl<Type> &, ArrayRef<Type>)>;
+ Type, SmallVectorImpl<Type> &)>;
/// The signature of the callback used to materialize a conversion.
using MaterializationCallbackFn = std::function<std::optional<Value>(
@@ -330,44 +330,30 @@ class TypeConverter {
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
- return wrapCallback<T>(
- [callback = std::forward<FnT>(callback)](
- T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
- if (std::optional<Type> resultOpt = callback(type)) {
- bool wasSuccess = static_cast<bool>(*resultOpt);
- if (wasSuccess)
- results.push_back(*resultOpt);
- return std::optional<LogicalResult>(success(wasSuccess));
- }
- return std::optional<LogicalResult>();
- });
+ return wrapCallback<T>([callback = std::forward<FnT>(callback)](
+ T type, SmallVectorImpl<Type> &results) {
+ if (std::optional<Type> resultOpt = callback(type)) {
+ bool wasSuccess = static_cast<bool>(*resultOpt);
+ if (wasSuccess)
+ results.push_back(*resultOpt);
+ return std::optional<LogicalResult>(success(wasSuccess));
+ }
+ return std::optional<LogicalResult>();
+ });
}
/// With callback of form: `std::optional<LogicalResult>(
- /// T, SmallVectorImpl<Type> &)`.
+ /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
ConversionCallbackFn>
- wrapCallback(FnT &&callback) const {
- return wrapCallback<T>(
- [callback = std::forward<FnT>(callback)](
- T type, SmallVectorImpl<Type> &results, ArrayRef<Type>) {
- return callback(type, results);
- });
- }
- /// With callback of form: `std::optional<LogicalResult>(
- /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
- template <typename T, typename FnT>
- std::enable_if_t<
- std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, ArrayRef<Type>>,
- ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- Type type, SmallVectorImpl<Type> &results,
- ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
+ Type type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
T derivedType = dyn_cast<T>(type);
if (!derivedType)
return std::nullopt;
- return callback(derivedType, results, callStack);
+ return callback(derivedType, results);
};
}
@@ -435,10 +421,6 @@ class TypeConverter {
mutable DenseMap<Type, Type> cachedDirectConversions;
/// This cache stores the successful 1->N conversions, where N != 1.
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
-
- /// Stores the types that are being converted in the case when convertType
- /// is being called recursively to convert nested types.
- mutable SmallVector<Type, 2> conversionCallStack;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index b0842b9972c76d..2d2a2bff7013c6 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -11,10 +11,34 @@
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Support/Threading.h"
+#include <memory>
+#include <mutex>
#include <optional>
using namespace mlir;
+SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
+ {
+ // Most of the time, the entry already exists in the map.
+ std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
+ std::defer_lock);
+ if (getContext().isMultithreadingEnabled())
+ lock.lock();
+ auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
+ if (recursiveStack != conversionCallStack.end())
+ return *recursiveStack->second;
+ }
+
+ // First time this thread gets here, we have to get an exclusive access to
+ // inset in the map
+ std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
+ auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
+ llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
+ return *recursiveStackInserted.first->second.get();
+}
+
/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const DataLayoutAnalysis *analysis)
@@ -56,8 +80,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return LLVM::LLVMPointerType::get(pointee, type.getAddressSpace());
return std::nullopt;
});
- addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
- ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
+
+ addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
+ -> std::optional<LogicalResult> {
// Fastpath for types that won't be converted by this callback anyway.
if (LLVM::isCompatibleType(type)) {
results.push_back(type);
@@ -75,10 +100,15 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
type.getContext(),
("_Converted_" + std::to_string(counter) + type.getName()).str());
}
- if (llvm::count(callStack, type) > 1) {
+
+ SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
+ if (llvm::count(recursiveStack, type)) {
results.push_back(convertedType);
return success();
}
+ recursiveStack.push_back(type);
+ auto popConversionCallStack = llvm::make_scope_exit(
+ [&recursiveStack]() { recursiveStack.pop_back(); });
SmallVector<Type> convertedElemTypes;
convertedElemTypes.reserve(type.getBody().size());
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a345f8d5b5be49..5ce3281b51a1eb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2935,12 +2935,9 @@ LogicalResult TypeConverter::convertType(Type t,
// Walk the added converters in reverse order to apply the most recently
// registered first.
size_t currentCount = results.size();
- conversionCallStack.push_back(t);
- auto popConversionCallStack =
- llvm::make_scope_exit([this]() { conversionCallStack.pop_back(); });
+
for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
- if (std::optional<LogicalResult> result =
- converter(t, results, conversionCallStack)) {
+ if (std::optional<LogicalResult> result = converter(t, results)) {
if (!succeeded(*result)) {
cachedDirectConversions.try_emplace(t, nullptr);
return failure();
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 30ed4109ad8bd1..47afe6d2eecb6d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -17,6 +17,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/ScopeExit.h"
using namespace mlir;
using namespace test;
@@ -1374,6 +1375,7 @@ struct TestTypeConversionDriver
void runOnOperation() override {
// Initialize the type converter.
+ SmallVector<Type, 2> conversionCallStack;
TypeConverter converter;
/// Add the legal set of type conversions.
@@ -1394,8 +1396,8 @@ struct TestTypeConversionDriver
converter.addConversion(
// Convert a recursive self-referring type into a non-self-referring
// type named "outer_converted_type" that contains a SimpleAType.
- [&](test::TestRecursiveType type, SmallVectorImpl<Type> &results,
- ArrayRef<Type> callStack) -> std::optional<LogicalResult> {
+ [&](test::TestRecursiveType type,
+ SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
// If the type is already converted, return it to indicate that it is
// legal.
if (type.getName() == "outer_converted_type") {
@@ -1403,11 +1405,16 @@ struct TestTypeConversionDriver
return success();
}
+ conversionCallStack.push_back(type);
+ auto popConversionCallStack = llvm::make_scope_exit(
+ [&conversionCallStack]() { conversionCallStack.pop_back(); });
+
// If the type is on the call stack more than once (it is there at
// least once because of the _current_ call, which is always the last
// element on the stack), we've hit the recursive case. Just return
// SimpleAType here to create a non-recursive type as a result.
- if (llvm::is_contained(callStack.drop_back(), type)) {
+ if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(),
+ type)) {
results.push_back(test::SimpleAType::get(type.getContext()));
return success();
}
More information about the Mlir-commits
mailing list