[Mlir-commits] [mlir] bb37296 - [MLIR] Add optional cached symbol tables to LLVM conversion patterns (#144032)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jun 21 01:55:47 PDT 2025
Author: Michele Scuttari
Date: 2025-06-21T10:55:44+02:00
New Revision: bb372963dfcef9722d5aeb4f65ddb5c50be24e01
URL: https://github.com/llvm/llvm-project/commit/bb372963dfcef9722d5aeb4f65ddb5c50be24e01
DIFF: https://github.com/llvm/llvm-project/commit/bb372963dfcef9722d5aeb4f65ddb5c50be24e01.diff
LOG: [MLIR] Add optional cached symbol tables to LLVM conversion patterns (#144032)
This PR allows optionally speeding up symbol lookups by providing a `SymbolTableCollection` instance to the interested conversion patterns. It is a follow-up to the discussion about symbol / symbol table management carried out on [Discourse](https://discourse.llvm.org/t/symbol-table-as-first-class-citizen-in-builders/86813).
Added:
Modified:
mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index 88f18022da9bb..2dfb6b03bcfcd 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -20,6 +20,7 @@ class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class Pass;
+class SymbolTableCollection;
#define GEN_PASS_DECL_CONVERTCONTROLFLOWTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -39,9 +40,9 @@ void populateControlFlowToLLVMConversionPatterns(
/// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
/// set to false, the program execution continues when a condition is
/// unsatisfied.
-void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- bool abortOnFailure = true);
+void populateAssertToLLVMConversionPattern(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ bool abortOnFailure = true, SymbolTableCollection *symbolTables = nullptr);
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
index b1ea2740c0605..e530b0a43b8e0 100644
--- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
+++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
@@ -27,20 +27,23 @@ class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class SymbolTable;
+class SymbolTableCollection;
/// Convert input FunctionOpInterface operation to LLVMFuncOp by using the
/// provided LLVMTypeConverter. Return failure if failed to so.
FailureOr<LLVM::LLVMFuncOp>
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &converter);
+ const LLVMTypeConverter &converter,
+ SymbolTableCollection *symbolTables = nullptr);
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
/// `emitCWrappers` is set, the pattern will also produce functions
/// that pass memref descriptors by pointer-to-structure in addition to the
/// default unpacked form.
void populateFuncToLLVMFuncOpConversionPattern(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ SymbolTableCollection *symbolTables = nullptr);
/// Collect the patterns to convert from the Func dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
@@ -57,7 +60,7 @@ void populateFuncToLLVMFuncOpConversionPattern(
/// not an error to provide it anyway.
void populateFuncToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- const SymbolTable *symbolTable = nullptr);
+ SymbolTableCollection *symbolTables = nullptr);
void registerConvertFuncToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 33402301115b7..d7de40555bb6a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -17,6 +17,7 @@ namespace mlir {
class OpBuilder;
class LLVMTypeConverter;
+class SymbolTableCollection;
namespace LLVM {
@@ -26,7 +27,8 @@ namespace LLVM {
LogicalResult createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter,
- bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
+ bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {},
+ SymbolTableCollection *symbolTables = nullptr);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
index 996a64baf9dd5..e93d5bdce7bf2 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
@@ -16,6 +16,7 @@ class DialectRegistry;
class Pass;
class LLVMTypeConverter;
class RewritePatternSet;
+class SymbolTableCollection;
#define GEN_PASS_DECL_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -23,7 +24,8 @@ class RewritePatternSet;
/// Collect a set of patterns to convert memory-related operations from the
/// MemRef dialect to the LLVM dialect.
void populateFinalizeMemRefToLLVMConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ SymbolTableCollection *symbolTables = nullptr);
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 4a7ec6f2efe64..8ad9ed18acebd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -24,6 +24,7 @@ class OpBuilder;
class Operation;
class Type;
class ValueRange;
+class SymbolTableCollection;
namespace LLVM {
class LLVMFuncOp;
@@ -33,55 +34,73 @@ class LLVMFuncOp;
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
/// Failure if an unexpected version of function is found.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
- Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
- std::optional<StringRef> runtimeFunctionName = {});
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
- Operation *moduleOp);
+ std::optional<StringRef> runtimeFunctionName = {},
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
- Operation *moduleOp);
+lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(
+ OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
- Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
- Operation *moduleOp);
+lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
- Type unrankedDescriptorType);
+ Type unrankedDescriptorType,
+ SymbolTableCollection *symbolTables = nullptr);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
/// Return a failure if the FuncOp found has unexpected signature.
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {}, Type resultType = {},
- bool isVarArg = false, bool isReserved = false);
+ bool isVarArg = false, bool isReserved = false,
+ SymbolTableCollection *symbolTables = nullptr);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index d31d7d801e149..3d0804fd11b6b 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -44,9 +44,10 @@ namespace {
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
- bool abortOnFailedAssert = true)
+ bool abortOnFailedAssert = true,
+ SymbolTableCollection *symbolTables = nullptr)
: ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
- abortOnFailedAssert(abortOnFailedAssert) {}
+ abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
@@ -64,7 +65,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
auto createResult = LLVM::createPrintStrCall(
rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
/*addNewLine=*/false,
- /*runtimeFunctionName=*/"puts");
+ /*runtimeFunctionName=*/"puts", symbolTables);
if (createResult.failed())
return failure();
@@ -96,6 +97,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
/// If set to `false`, messages are printed but program execution continues.
/// This is useful for testing asserts.
bool abortOnFailedAssert = true;
+
+ SymbolTableCollection *symbolTables = nullptr;
};
/// Helper function for converting branch ops. This function converts the
@@ -232,8 +235,8 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
void mlir::cf::populateAssertToLLVMConversionPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool abortOnFailure) {
- patterns.add<AssertOpLowering>(converter, abortOnFailure);
+ bool abortOnFailure, SymbolTableCollection *symbolTables) {
+ patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 538016927256b..4499cbd4d1a20 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -299,10 +299,9 @@ static void restoreByValRefArgumentType(
}
}
-FailureOr<LLVM::LLVMFuncOp>
-mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
- ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &converter) {
+FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
+ FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
// Check the funcOp has `FunctionType`.
auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
if (!funcTy)
@@ -361,10 +360,25 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp, attributes);
+
+ Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
+
+ if (symbolTables && symbolTableOp) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+ symbolTable.remove(funcOp);
+ }
+
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
attributes);
+
+ if (symbolTables && symbolTableOp) {
+ auto ip = rewriter.getInsertionPoint();
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+ symbolTable.insert(newFuncOp, ip);
+ }
+
cast<FunctionOpInterface>(newFuncOp.getOperation())
.setVisibility(funcOp.getVisibility());
@@ -473,16 +487,20 @@ namespace {
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
-struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
- FuncOpConversion(const LLVMTypeConverter &converter)
- : ConvertOpToLLVMPattern(converter) {}
+class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit FuncOpConversion(const LLVMTypeConverter &converter,
+ SymbolTableCollection *symbolTables = nullptr)
+ : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
- *getTypeConverter());
+ *getTypeConverter(), symbolTables);
if (failed(newFuncOp))
return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
@@ -591,11 +609,11 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
public:
- CallOpLowering(const LLVMTypeConverter &typeConverter,
- // Can be nullptr.
- const SymbolTable *symbolTable, PatternBenefit benefit = 1)
+ explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
: CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
- symbolTable(symbolTable) {}
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
@@ -603,10 +621,10 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
useBarePtrCallConv = true;
- } else if (symbolTable != nullptr) {
+ } else if (symbolTables != nullptr) {
// Fast lookup.
Operation *callee =
- symbolTable->lookup(callOp.getCalleeAttr().getValue());
+ symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
useBarePtrCallConv =
callee != nullptr && callee->hasAttr(barePtrAttrName);
} else {
@@ -620,7 +638,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
}
private:
- const SymbolTable *symbolTable = nullptr;
+ SymbolTableCollection *symbolTables = nullptr;
};
struct CallIndirectOpLowering
@@ -731,16 +749,17 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
} // namespace
void mlir::populateFuncToLLVMFuncOpConversionPattern(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<FuncOpConversion>(converter);
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ SymbolTableCollection *symbolTables) {
+ patterns.add<FuncOpConversion>(converter, symbolTables);
}
void mlir::populateFuncToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- const SymbolTable *symbolTable) {
- populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
+ SymbolTableCollection *symbolTables) {
+ populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
patterns.add<CallIndirectOpLowering>(converter);
- patterns.add<CallOpLowering>(converter, symbolTable);
+ patterns.add<CallOpLowering>(converter, symbolTables);
patterns.add<ConstantOpLowering>(converter);
patterns.add<ReturnOpLowering>(converter);
}
@@ -780,15 +799,11 @@ struct ConvertFuncToLLVMPass
LLVMTypeConverter typeConverter(&getContext(), options,
&dataLayoutAnalysis);
- std::optional<SymbolTable> optSymbolTable = std::nullopt;
- const SymbolTable *symbolTable = nullptr;
- if (!options.useBarePtrCallConv) {
- optSymbolTable.emplace(m);
- symbolTable = &optSymbolTable.value();
- }
-
RewritePatternSet patterns(&getContext());
- populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
+ SymbolTableCollection symbolTables;
+
+ populateFuncToLLVMConversionPatterns(typeConverter, patterns,
+ &symbolTables);
LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 2815e05b3e11b..49c73fbc9dd79 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -17,8 +17,26 @@
using namespace mlir;
using namespace llvm;
-static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
- StringRef symbolName) {
+/// Check if a given symbol name is already in use within the module operation.
+/// If no symbol with such name is present, then the same identifier is
+/// returned. Otherwise, a unique and yet unused identifier is computed starting
+/// from the requested one.
+static std::string
+ensureSymbolNameIsUnique(ModuleOp moduleOp, StringRef symbolName,
+ SymbolTableCollection *symbolTables = nullptr) {
+ if (symbolTables) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
+ unsigned counter = 0;
+ SmallString<128> uniqueName = symbolTable.generateSymbolName<128>(
+ symbolName,
+ [&](const SmallString<128> &tentativeName) {
+ return symbolTable.lookupSymbolIn(moduleOp, tentativeName) != nullptr;
+ },
+ counter);
+
+ return static_cast<std::string>(uniqueName);
+ }
+
static int counter = 0;
std::string uniqueName = std::string(symbolName);
while (moduleOp.lookupSymbol(uniqueName)) {
@@ -30,7 +48,8 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
LogicalResult mlir::LLVM::createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
- std::optional<StringRef> runtimeFunctionName) {
+ std::optional<StringRef> runtimeFunctionName,
+ SymbolTableCollection *symbolTables) {
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointToStart(moduleOp.getBody());
MLIRContext *ctx = builder.getContext();
@@ -49,7 +68,7 @@ LogicalResult mlir::LLVM::createPrintStrCall(
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
auto globalOp = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
- ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+ ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr);
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
// Emit call to `printStr` in runtime library.
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 8ccf1bfc292d5..e8294a5234c4f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -48,35 +48,39 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
}
static FailureOr<LLVM::LLVMFuncOp>
-getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
- ModuleOp module) {
+getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
+ SymbolTableCollection *symbolTables) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericFreeFn(b, module);
+ return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);
- return LLVM::lookupOrCreateFreeFn(b, module);
+ return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
}
static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
- Operation *module, Type indexType) {
+ Operation *module, Type indexType,
+ SymbolTableCollection *symbolTables) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);
+ return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
+ symbolTables);
- return LLVM::lookupOrCreateMallocFn(b, module, indexType);
+ return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
}
static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
- Operation *module, Type indexType) {
+ Operation *module, Type indexType,
+ SymbolTableCollection *symbolTables) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);
+ return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
+ symbolTables);
- return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
+ return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
}
/// Computes the aligned value for 'input' as follows:
@@ -126,8 +130,15 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
return allocatedPtr;
}
-struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
- using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
@@ -138,9 +149,10 @@ struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
return rewriter.notifyMatchFailure(op, "incompatible memref type");
// Get or insert alloc function into the module.
- FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
- rewriter, getTypeConverter(),
- op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
+ FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
+ getNotalignedAllocFn(rewriter, getTypeConverter(),
+ op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType(), symbolTables);
if (failed(allocFuncOp))
return failure();
@@ -210,8 +222,15 @@ struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
}
};
-struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
- using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
@@ -222,9 +241,10 @@ struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
return rewriter.notifyMatchFailure(op, "incompatible memref type");
// Get or insert alloc function into module.
- FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
- rewriter, getTypeConverter(),
- op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
+ FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
+ getAlignedAllocFn(rewriter, getTypeConverter(),
+ op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType(), symbolTables);
if (failed(allocFuncOp))
return failure();
@@ -446,18 +466,23 @@ struct AssumeAlignmentOpLowering
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
-struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
- using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
+class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
+ SymbolTableCollection *symbolTables = nullptr;
- explicit DeallocOpLowering(const LLVMTypeConverter &converter)
- : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
+public:
+ explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
- FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
- rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
+ FailureOr<LLVM::LLVMFuncOp> freeFunc =
+ getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
+ symbolTables);
if (failed(freeFunc))
return failure();
Value allocatedPtr;
@@ -710,9 +735,15 @@ convertGlobalMemrefTypeToLLVM(MemRefType type,
}
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
-struct GlobalMemrefOpLowering
- : public ConvertOpToLLVMPattern<memref::GlobalOp> {
- using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
+class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
@@ -743,9 +774,31 @@ struct GlobalMemrefOpLowering
if (failed(addressSpace))
return global.emitOpError(
"memory space cannot be converted to an integer address space");
+
+ if (symbolTables) {
+ Operation *symbolTableOp =
+ global->getParentWithTrait<OpTrait::SymbolTable>();
+
+ if (symbolTableOp) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+ symbolTable.remove(global);
+ }
+ }
+
auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
global, arrayTy, global.getConstant(), linkage, global.getSymName(),
initialValue, alignment, *addressSpace);
+
+ if (symbolTables) {
+ Operation *symbolTableOp =
+ global->getParentWithTrait<OpTrait::SymbolTable>();
+
+ if (symbolTableOp) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+ symbolTable.insert(newGlobal, rewriter.getInsertionPoint());
+ }
+ }
+
if (!global.isExternal() && global.isUninitialized()) {
rewriter.createBlock(&newGlobal.getInitializerRegion());
Value undef[] = {
@@ -997,8 +1050,15 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
-struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
- using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
+class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
@@ -1093,7 +1153,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
- sourcePtr.getType());
+ sourcePtr.getType(), symbolTables);
if (failed(copyFn))
return failure();
rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
@@ -1928,7 +1988,8 @@ class ExtractStridedMetadataOpLowering
} // namespace
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ SymbolTableCollection *symbolTables) {
// clang-format off
patterns.add<
AllocaOpLowering,
@@ -1939,11 +2000,9 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
DimOpLowering,
ExtractStridedMetadataOpLowering,
GenericAtomicRMWOpLowering,
- GlobalMemrefOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
- MemRefCopyOpLowering,
MemorySpaceCastOpLowering,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
@@ -1956,11 +2015,14 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
TransposeOpLowering,
ViewOpLowering>(converter);
// clang-format on
+ patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
+ symbolTables);
auto allocLowering = converter.getOptions().allocLowering;
if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
- patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
+ patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
+ symbolTables);
else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
- patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
+ patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
}
namespace {
@@ -1987,7 +2049,9 @@ struct FinalizeMemRefToLLVMConversionPass
LLVMTypeConverter typeConverter(&getContext(), options,
&dataLayoutAnalysis);
RewritePatternSet patterns(&getContext());
- populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
+ SymbolTableCollection symbolTables;
+ populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
+ &symbolTables);
LLVMConversionTarget target(getContext());
target.addLegalOp<func::FuncOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f725993635672..d53d11f87efe8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1595,8 +1595,14 @@ class VectorCreateMaskOpConversion
};
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
public:
- using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
+  /// Creates the pattern. `symbolTables` is an optional (possibly null) cache
+  /// of symbol tables that is forwarded to the `lookupOrCreatePrint*Fn`
+  /// helpers when lowering `vector.print`; `benefit` is forwarded to the base
+  /// pattern so callers can still choose the pattern's priority.
+  explicit VectorPrintOpConversion(
+      const LLVMTypeConverter &typeConverter,
+      SymbolTableCollection *symbolTables = nullptr,
+      PatternBenefit benefit = 1)
+      : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter, benefit),
+        symbolTables(symbolTables) {}
// Lowering implementation that relies on a small runtime support library,
// which only needs to provide a few printing methods (single value for all
@@ -1643,13 +1649,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
switch (punct) {
case PrintPunctuation::Close:
- return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
+ return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
+ symbolTables);
case PrintPunctuation::Open:
- return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
+ return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
+ symbolTables);
case PrintPunctuation::Comma:
- return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
+ return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
+ symbolTables);
case PrintPunctuation::NewLine:
- return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
+ return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
+ symbolTables);
default:
llvm_unreachable("unexpected punctuation");
}
@@ -1683,17 +1693,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
PrintConversion conversion = PrintConversion::None;
FailureOr<Operation *> printer;
if (printType.isF32()) {
- printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
} else if (printType.isF64()) {
- printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
} else if (printType.isF16()) {
conversion = PrintConversion::Bitcast16; // bits!
- printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
} else if (printType.isBF16()) {
conversion = PrintConversion::Bitcast16; // bits!
- printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
} else if (printType.isIndex()) {
- printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1703,7 +1713,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
- printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
+ printer =
+ LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
} else {
return failure();
}
@@ -1716,7 +1727,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
- printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
+ printer =
+ LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
} else {
return failure();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 1b4a8f496d3d0..89f765dacda35 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -44,15 +44,31 @@ static constexpr llvm::StringRef kGenericAlignedAlloc =
static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
+/// Search for an LLVMFuncOp with a given name within an operation with the
+/// SymbolTable trait. An optional collection of cached symbol tables can be
+/// given to avoid a linear scan of the symbol table operation. Returns null if
+/// no symbol with that name exists or if the symbol is not an LLVMFuncOp.
+static LLVM::LLVMFuncOp
+lookupFuncOp(StringRef name, Operation *symbolTableOp,
+             SymbolTableCollection *symbolTables = nullptr) {
+  if (symbolTables) {
+    return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(
+        symbolTableOp, StringAttr::get(symbolTableOp->getContext(), name));
+  }
+
+  return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(symbolTableOp, name));
+}
+
/// Generic print function lookupOrCreate helper.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes, Type resultType,
- bool isVarArg, bool isReserved) {
+ bool isVarArg, bool isReserved,
+ SymbolTableCollection *symbolTables) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
- auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
- SymbolTable::lookupSymbolIn(moduleOp, name));
+ auto func = lookupFuncOp(name, moduleOp, symbolTables);
auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
// Assert the signature of the found function is same as expected
if (func) {
@@ -73,60 +89,75 @@ mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
OpBuilder::InsertionGuard g(b);
assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
- return b.create<LLVM::LLVMFuncOp>(
+ auto funcOp = b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
+
+ if (symbolTables) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
+ symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
+ }
+
+ return funcOp;
}
static FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
- ArrayRef<Type> paramTypes, Type resultType) {
+ ArrayRef<Type> paramTypes, Type resultType,
+ SymbolTableCollection *symbolTables) { // May be null (uncached lookup).
 return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
- /*isVarArg=*/false, /*isReserved=*/true);
+ /*isVarArg=*/false, /*isReserved=*/true,
+ symbolTables); // Used for the lookup and to register a newly created decl.
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) { // May be null (uncached lookup).
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintF16,
 IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintBF16,
 IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -140,90 +171,102 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
 OpBuilder &b, Operation *moduleOp,
- std::optional<StringRef> runtimeFunctionName) {
+ std::optional<StringRef> runtimeFunctionName,
+ SymbolTableCollection *symbolTables) { // May be null (uncached lookup).
 return lookupOrCreateReservedFn(
 b, moduleOp, runtimeFunctionName.value_or(kPrintString),
 getCharPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintOpen, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintClose, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintComma, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kPrintNewline, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp,
- Type indexType) {
+ Type indexType,
+ SymbolTableCollection *symbolTables) { // May be null (uncached lookup).
 return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()),
+ symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
- Type indexType) {
- return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc,
- {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ Type indexType,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kAlignedAlloc, {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp,
- Type indexType) {
+ Type indexType,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()),
+ symbolTables);
}
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(
- OpBuilder &b, Operation *moduleOp, Type indexType) {
- return lookupOrCreateReservedFn(b, moduleOp, kGenericAlignedAlloc,
- {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
-FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp,
- Type indexType,
- Type unrankedDescriptorType) {
+FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn(
+ OpBuilder &b, Operation *moduleOp, Type indexType,
+ Type unrankedDescriptorType, SymbolTableCollection *symbolTables) {
 return lookupOrCreateReservedFn(
 b, moduleOp, kMemRefCopy,
 ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
More information about the Mlir-commits
mailing list