[Mlir-commits] [mlir] [MLIR] Add optional cached symbol tables to LLVM conversion patterns (PR #144032)
Michele Scuttari
llvmlistbot at llvm.org
Fri Jun 20 09:31:29 PDT 2025
https://github.com/mscuttari updated https://github.com/llvm/llvm-project/pull/144032
>From cbc49497277afe104295c6dc400aa8cc9aa3bb96 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Thu, 12 Jun 2025 23:18:07 +0200
Subject: [PATCH 1/4] Cache symbol tables in LLVM conversion patterns
---
.../ControlFlowToLLVM/ControlFlowToLLVM.h | 7 +-
.../Conversion/FuncToLLVM/ConvertFuncToLLVM.h | 9 +-
.../Conversion/LLVMCommon/PrintCallHelper.h | 4 +-
.../Conversion/MemRefToLLVM/MemRefToLLVM.h | 4 +-
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 83 +++++++----
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 13 +-
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 72 +++++----
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 23 ++-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 140 +++++++++++++-----
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 36 +++--
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 138 +++++++++++------
11 files changed, 353 insertions(+), 176 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
index 88f18022da9bb..2dfb6b03bcfcd 100644
--- a/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h
@@ -20,6 +20,7 @@ class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class Pass;
+class SymbolTableCollection;
#define GEN_PASS_DECL_CONVERTCONTROLFLOWTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -39,9 +40,9 @@ void populateControlFlowToLLVMConversionPatterns(
/// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
/// set to false, the program execution continues when a condition is
/// unsatisfied.
-void populateAssertToLLVMConversionPattern(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- bool abortOnFailure = true);
+void populateAssertToLLVMConversionPattern(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ bool abortOnFailure = true, SymbolTableCollection *symbolTables = nullptr);
void registerConvertControlFlowToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
index b1ea2740c0605..e530b0a43b8e0 100644
--- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
+++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
@@ -27,20 +27,23 @@ class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class SymbolTable;
+class SymbolTableCollection;
/// Convert input FunctionOpInterface operation to LLVMFuncOp by using the
/// provided LLVMTypeConverter. Return failure if failed to so.
FailureOr<LLVM::LLVMFuncOp>
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &converter);
+ const LLVMTypeConverter &converter,
+ SymbolTableCollection *symbolTables = nullptr);
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
/// `emitCWrappers` is set, the pattern will also produce functions
/// that pass memref descriptors by pointer-to-structure in addition to the
/// default unpacked form.
void populateFuncToLLVMFuncOpConversionPattern(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ SymbolTableCollection *symbolTables = nullptr);
/// Collect the patterns to convert from the Func dialect to LLVM. The
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
@@ -57,7 +60,7 @@ void populateFuncToLLVMFuncOpConversionPattern(
/// not an error to provide it anyway.
void populateFuncToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- const SymbolTable *symbolTable = nullptr);
+ SymbolTableCollection *symbolTables = nullptr);
void registerConvertFuncToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 33402301115b7..d7de40555bb6a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -17,6 +17,7 @@ namespace mlir {
class OpBuilder;
class LLVMTypeConverter;
+class SymbolTableCollection;
namespace LLVM {
@@ -26,7 +27,8 @@ namespace LLVM {
LogicalResult createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter,
- bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
+ bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {},
+ SymbolTableCollection *symbolTables = nullptr);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
index 996a64baf9dd5..e93d5bdce7bf2 100644
--- a/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
+++ b/mlir/include/mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h
@@ -16,6 +16,7 @@ class DialectRegistry;
class Pass;
class LLVMTypeConverter;
class RewritePatternSet;
+class SymbolTableCollection;
#define GEN_PASS_DECL_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -23,7 +24,8 @@ class RewritePatternSet;
/// Collect a set of patterns to convert memory-related operations from the
/// MemRef dialect to the LLVM dialect.
void populateFinalizeMemRefToLLVMConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ SymbolTableCollection *symbolTables = nullptr);
void registerConvertMemRefToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 4a7ec6f2efe64..8ad9ed18acebd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -24,6 +24,7 @@ class OpBuilder;
class Operation;
class Type;
class ValueRange;
+class SymbolTableCollection;
namespace LLVM {
class LLVMFuncOp;
@@ -33,55 +34,73 @@ class LLVMFuncOp;
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
/// Failure if an unexpected version of function is found.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
- Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
- std::optional<StringRef> runtimeFunctionName = {});
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
- Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
- Operation *moduleOp);
+ std::optional<StringRef> runtimeFunctionName = {},
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
- Operation *moduleOp);
+lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables = nullptr);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(
+ OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
- Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
- Operation *moduleOp);
+lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables = nullptr);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
- Type unrankedDescriptorType);
+ Type unrankedDescriptorType,
+ SymbolTableCollection *symbolTables = nullptr);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
/// Return a failure if the FuncOp found has unexpected signature.
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {}, Type resultType = {},
- bool isVarArg = false, bool isReserved = false);
+ bool isVarArg = false, bool isReserved = false,
+ SymbolTableCollection *symbolTables = nullptr);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd003bd5b5..eaa8e7d26d4bd 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -44,9 +44,10 @@ namespace {
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
- bool abortOnFailedAssert = true)
+ bool abortOnFailedAssert = true,
+ SymbolTableCollection *symbolTables = nullptr)
: ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
- abortOnFailedAssert(abortOnFailedAssert) {}
+ abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
@@ -64,7 +65,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
auto createResult = LLVM::createPrintStrCall(
rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
/*addNewLine=*/false,
- /*runtimeFunctionName=*/"puts");
+ /*runtimeFunctionName=*/"puts", symbolTables);
if (createResult.failed())
return failure();
@@ -96,6 +97,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
/// If set to `false`, messages are printed but program execution continues.
/// This is useful for testing asserts.
bool abortOnFailedAssert = true;
+
+ SymbolTableCollection *symbolTables = nullptr;
};
/// Helper function for converting branch ops. This function converts the
@@ -227,8 +230,8 @@ void mlir::cf::populateControlFlowToLLVMConversionPatterns(
void mlir::cf::populateAssertToLLVMConversionPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool abortOnFailure) {
- patterns.add<AssertOpLowering>(converter, abortOnFailure);
+ bool abortOnFailure, SymbolTableCollection *symbolTables) {
+ patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 328c605add65c..6a6371921c1d5 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -299,10 +299,9 @@ static void restoreByValRefArgumentType(
}
}
-FailureOr<LLVM::LLVMFuncOp>
-mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
- ConversionPatternRewriter &rewriter,
- const LLVMTypeConverter &converter) {
+FailureOr<LLVM::LLVMFuncOp> mlir::convertFuncOpToLLVMFuncOp(
+ FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) {
// Check the funcOp has `FunctionType`.
auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
if (!funcTy)
@@ -365,10 +364,25 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp, attributes);
+
+ Operation *symbolTableOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
+
+ if (symbolTables && symbolTableOp) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+ symbolTable.remove(funcOp);
+ }
+
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
attributes);
+
+ if (symbolTables && symbolTableOp) {
+ auto ip = rewriter.getInsertionPoint();
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+ symbolTable.insert(newFuncOp, ip);
+ }
+
cast<FunctionOpInterface>(newFuncOp.getOperation())
.setVisibility(funcOp.getVisibility());
@@ -477,16 +491,20 @@ namespace {
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
-struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
- FuncOpConversion(const LLVMTypeConverter &converter)
- : ConvertOpToLLVMPattern(converter) {}
+class FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit FuncOpConversion(const LLVMTypeConverter &converter,
+ SymbolTableCollection *symbolTables = nullptr)
+ : ConvertOpToLLVMPattern(converter), symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
- *getTypeConverter());
+ *getTypeConverter(), symbolTables);
if (failed(newFuncOp))
return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
@@ -595,11 +613,12 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
public:
- CallOpLowering(const LLVMTypeConverter &typeConverter,
- // Can be nullptr.
- const SymbolTable *symbolTable, PatternBenefit benefit = 1)
+ explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
+ // Can be nullptr.
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
: CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
- symbolTable(symbolTable) {}
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
@@ -607,10 +626,10 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
useBarePtrCallConv = true;
- } else if (symbolTable != nullptr) {
+ } else if (symbolTables != nullptr) {
// Fast lookup.
Operation *callee =
- symbolTable->lookup(callOp.getCalleeAttr().getValue());
+ symbolTables->lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
useBarePtrCallConv =
callee != nullptr && callee->hasAttr(barePtrAttrName);
} else {
@@ -624,7 +643,7 @@ class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
}
private:
- const SymbolTable *symbolTable = nullptr;
+ SymbolTableCollection *symbolTables = nullptr;
};
struct CallIndirectOpLowering
@@ -735,16 +754,17 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
} // namespace
void mlir::populateFuncToLLVMFuncOpConversionPattern(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- patterns.add<FuncOpConversion>(converter);
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ SymbolTableCollection *symbolTables) {
+ patterns.add<FuncOpConversion>(converter, symbolTables);
}
void mlir::populateFuncToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- const SymbolTable *symbolTable) {
- populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
+ SymbolTableCollection *symbolTables) {
+ populateFuncToLLVMFuncOpConversionPattern(converter, patterns, symbolTables);
patterns.add<CallIndirectOpLowering>(converter);
- patterns.add<CallOpLowering>(converter, symbolTable);
+ patterns.add<CallOpLowering>(converter, symbolTables);
patterns.add<ConstantOpLowering>(converter);
patterns.add<ReturnOpLowering>(converter);
}
@@ -784,15 +804,11 @@ struct ConvertFuncToLLVMPass
LLVMTypeConverter typeConverter(&getContext(), options,
&dataLayoutAnalysis);
- std::optional<SymbolTable> optSymbolTable = std::nullopt;
- const SymbolTable *symbolTable = nullptr;
- if (!options.useBarePtrCallConv) {
- optSymbolTable.emplace(m);
- symbolTable = &optSymbolTable.value();
- }
-
RewritePatternSet patterns(&getContext());
- populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
+ SymbolTableCollection symbolTables;
+
+ populateFuncToLLVMConversionPatterns(typeConverter, patterns,
+ &symbolTables);
LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 2815e05b3e11b..b2f84a02a52d1 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -17,8 +17,22 @@
using namespace mlir;
using namespace llvm;
-static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
- StringRef symbolName) {
+static std::string
+ensureSymbolNameIsUnique(ModuleOp moduleOp, StringRef symbolName,
+ SymbolTableCollection *symbolTables = nullptr) {
+ if (symbolTables) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
+ unsigned counter = 0;
+ SmallString<128> uniqueName = symbolTable.generateSymbolName<128>(
+ symbolName,
+ [&](const SmallString<128> &tentativeName) {
+ return symbolTable.lookupSymbolIn(moduleOp, tentativeName) != nullptr;
+ },
+ counter);
+
+ return static_cast<std::string>(uniqueName);
+ }
+
static int counter = 0;
std::string uniqueName = std::string(symbolName);
while (moduleOp.lookupSymbol(uniqueName)) {
@@ -30,7 +44,8 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
LogicalResult mlir::LLVM::createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
- std::optional<StringRef> runtimeFunctionName) {
+ std::optional<StringRef> runtimeFunctionName,
+ SymbolTableCollection *symbolTables) {
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointToStart(moduleOp.getBody());
MLIRContext *ctx = builder.getContext();
@@ -49,7 +64,7 @@ LogicalResult mlir::LLVM::createPrintStrCall(
LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
auto globalOp = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
- ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+ ensureSymbolNameIsUnique(moduleOp, symbolName, symbolTables), dataAttr);
auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
// Emit call to `printStr` in runtime library.
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index ade4e4d3de8ec..93c7bd46977a5 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -45,35 +45,39 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
}
static FailureOr<LLVM::LLVMFuncOp>
-getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
- ModuleOp module) {
+getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
+ SymbolTableCollection *symbolTables) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericFreeFn(b, module);
+ return LLVM::lookupOrCreateGenericFreeFn(b, module, symbolTables);
- return LLVM::lookupOrCreateFreeFn(b, module);
+ return LLVM::lookupOrCreateFreeFn(b, module, symbolTables);
}
static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
- Operation *module, Type indexType) {
+ Operation *module, Type indexType,
+ SymbolTableCollection *symbolTables) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);
+ return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType,
+ symbolTables);
- return LLVM::lookupOrCreateMallocFn(b, module, indexType);
+ return LLVM::lookupOrCreateMallocFn(b, module, indexType, symbolTables);
}
static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
- Operation *module, Type indexType) {
+ Operation *module, Type indexType,
+ SymbolTableCollection *symbolTables) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);
+ return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType,
+ symbolTables);
- return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
+ return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType, symbolTables);
}
/// Computes the aligned value for 'input' as follows:
@@ -123,8 +127,15 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
return allocatedPtr;
}
-struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
- using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
@@ -135,9 +146,10 @@ struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
return rewriter.notifyMatchFailure(op, "incompatible memref type");
// Get or insert alloc function into the module.
- FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
- rewriter, getTypeConverter(),
- op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
+ FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
+ getNotalignedAllocFn(rewriter, getTypeConverter(),
+ op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType(), symbolTables);
if (failed(allocFuncOp))
return failure();
@@ -207,8 +219,15 @@ struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
}
};
-struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
- using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;
+class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
@@ -219,9 +238,10 @@ struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
return rewriter.notifyMatchFailure(op, "incompatible memref type");
// Get or insert alloc function into module.
- FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
- rewriter, getTypeConverter(),
- op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
+ FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
+ getAlignedAllocFn(rewriter, getTypeConverter(),
+ op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType(), symbolTables);
if (failed(allocFuncOp))
return failure();
@@ -443,18 +463,23 @@ struct AssumeAlignmentOpLowering
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
-struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
- using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
+class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
+ SymbolTableCollection *symbolTables = nullptr;
- explicit DeallocOpLowering(const LLVMTypeConverter &converter)
- : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
+public:
+ explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
- FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
- rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
+ FailureOr<LLVM::LLVMFuncOp> freeFunc =
+ getFreeFn(rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>(),
+ symbolTables);
if (failed(freeFunc))
return failure();
Value allocatedPtr;
@@ -707,9 +732,15 @@ convertGlobalMemrefTypeToLLVM(MemRefType type,
}
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
-struct GlobalMemrefOpLowering
- : public ConvertOpToLLVMPattern<memref::GlobalOp> {
- using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;
+class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
@@ -740,9 +771,31 @@ struct GlobalMemrefOpLowering
if (failed(addressSpace))
return global.emitOpError(
"memory space cannot be converted to an integer address space");
+
+ if (symbolTables) {
+ Operation *symbolTableOp =
+ global->getParentWithTrait<OpTrait::SymbolTable>();
+
+ if (symbolTableOp) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+ symbolTable.remove(global);
+ }
+ }
+
auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
global, arrayTy, global.getConstant(), linkage, global.getSymName(),
initialValue, alignment, *addressSpace);
+
+ if (symbolTables) {
+ Operation *symbolTableOp =
+ global->getParentWithTrait<OpTrait::SymbolTable>();
+
+ if (symbolTableOp) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp);
+ symbolTable.insert(newGlobal, rewriter.getInsertionPoint());
+ }
+ }
+
if (!global.isExternal() && global.isUninitialized()) {
rewriter.createBlock(&newGlobal.getInitializerRegion());
Value undef[] = {
@@ -994,8 +1047,15 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
-struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
- using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
+class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
+public:
+ explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr,
+ PatternBenefit benefit = 1)
+ : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
+ symbolTables(symbolTables) {}
LogicalResult
lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
@@ -1090,7 +1150,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
- sourcePtr.getType());
+ sourcePtr.getType(), symbolTables);
if (failed(copyFn))
return failure();
rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
@@ -1915,7 +1975,8 @@ class ExtractStridedMetadataOpLowering
} // namespace
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ SymbolTableCollection *symbolTables) {
// clang-format off
patterns.add<
AllocaOpLowering,
@@ -1926,11 +1987,9 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
DimOpLowering,
ExtractStridedMetadataOpLowering,
GenericAtomicRMWOpLowering,
- GlobalMemrefOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,
MemRefCastOpLowering,
- MemRefCopyOpLowering,
MemorySpaceCastOpLowering,
MemRefReinterpretCastOpLowering,
MemRefReshapeOpLowering,
@@ -1943,11 +2002,14 @@ void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
TransposeOpLowering,
ViewOpLowering>(converter);
// clang-format on
+ patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(converter,
+ symbolTables);
auto allocLowering = converter.getOptions().allocLowering;
if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
- patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
+ patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter,
+ symbolTables);
else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
- patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
+ patterns.add<AllocOpLowering, DeallocOpLowering>(converter, symbolTables);
}
namespace {
@@ -1974,7 +2036,9 @@ struct FinalizeMemRefToLLVMConversionPass
LLVMTypeConverter typeConverter(&getContext(), options,
&dataLayoutAnalysis);
RewritePatternSet patterns(&getContext());
- populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
+ SymbolTableCollection symbolTables;
+ populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns,
+ &symbolTables);
LLVMConversionTarget target(getContext());
target.addLegalOp<func::FuncOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f725993635672..d53d11f87efe8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1595,8 +1595,14 @@ class VectorCreateMaskOpConversion
};
class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
+ SymbolTableCollection *symbolTables = nullptr;
+
public:
- using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;
+ explicit VectorPrintOpConversion(
+ const LLVMTypeConverter &typeConverter,
+ SymbolTableCollection *symbolTables = nullptr)
+ : ConvertOpToLLVMPattern<vector::PrintOp>(typeConverter),
+ symbolTables(symbolTables) {}
// Lowering implementation that relies on a small runtime support library,
// which only needs to provide a few printing methods (single value for all
@@ -1643,13 +1649,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
switch (punct) {
case PrintPunctuation::Close:
- return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
+ return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent,
+ symbolTables);
case PrintPunctuation::Open:
- return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
+ return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent,
+ symbolTables);
case PrintPunctuation::Comma:
- return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
+ return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent,
+ symbolTables);
case PrintPunctuation::NewLine:
- return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
+ return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent,
+ symbolTables);
default:
llvm_unreachable("unexpected punctuation");
}
@@ -1683,17 +1693,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
PrintConversion conversion = PrintConversion::None;
FailureOr<Operation *> printer;
if (printType.isF32()) {
- printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent, symbolTables);
} else if (printType.isF64()) {
- printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent, symbolTables);
} else if (printType.isF16()) {
conversion = PrintConversion::Bitcast16; // bits!
- printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent, symbolTables);
} else if (printType.isBF16()) {
conversion = PrintConversion::Bitcast16; // bits!
- printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent, symbolTables);
} else if (printType.isIndex()) {
- printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
+ printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1703,7 +1713,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
- printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
+ printer =
+ LLVM::lookupOrCreatePrintU64Fn(rewriter, parent, symbolTables);
} else {
return failure();
}
@@ -1716,7 +1727,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
- printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
+ printer =
+ LLVM::lookupOrCreatePrintI64Fn(rewriter, parent, symbolTables);
} else {
return failure();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 1b4a8f496d3d0..62ace5d20bd88 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -44,15 +44,28 @@ static constexpr llvm::StringRef kGenericAlignedAlloc =
static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
+namespace {
+LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp,
+ SymbolTableCollection *symbolTables) {
+ if (symbolTables) {
+ return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(
+ symbolTableOp, StringAttr::get(symbolTableOp->getContext(), name));
+ }
+
+ return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+ SymbolTable::lookupSymbolIn(symbolTableOp, name));
+}
+} // namespace
+
/// Generic print function lookupOrCreate helper.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes, Type resultType,
- bool isVarArg, bool isReserved) {
+ bool isVarArg, bool isReserved,
+ SymbolTableCollection *symbolTables) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
- auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
- SymbolTable::lookupSymbolIn(moduleOp, name));
+ auto func = lookupFuncOp(name, moduleOp, symbolTables);
auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
// Assert the signature of the found function is same as expected
if (func) {
@@ -73,60 +86,75 @@ mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
OpBuilder::InsertionGuard g(b);
assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
- return b.create<LLVM::LLVMFuncOp>(
+ auto funcOp = b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
+
+ if (symbolTables) {
+ SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
+ symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
+ }
+
+ return funcOp;
}
static FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
- ArrayRef<Type> paramTypes, Type resultType) {
+ ArrayRef<Type> paramTypes, Type resultType,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
- /*isVarArg=*/false, /*isReserved=*/true);
+ /*isVarArg=*/false, /*isReserved=*/true,
+ symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -140,90 +168,102 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
OpBuilder &b, Operation *moduleOp,
- std::optional<StringRef> runtimeFunctionName) {
+ std::optional<StringRef> runtimeFunctionName,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintOpen, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintClose, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintComma, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kPrintNewline, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp,
- Type indexType) {
+ Type indexType,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()),
+ symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
- Type indexType) {
- return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc,
- {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ Type indexType,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kAlignedAlloc, {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp,
- Type indexType) {
+ Type indexType,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()),
+ symbolTables);
}
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(
- OpBuilder &b, Operation *moduleOp, Type indexType) {
- return lookupOrCreateReservedFn(b, moduleOp, kGenericAlignedAlloc,
- {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ OpBuilder &b, Operation *moduleOp, Type indexType,
+ SymbolTableCollection *symbolTables) {
+ return lookupOrCreateReservedFn(
+ b, moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()), symbolTables);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp) {
+mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
+ SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
-FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp,
- Type indexType,
- Type unrankedDescriptorType) {
+FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn(
+ OpBuilder &b, Operation *moduleOp, Type indexType,
+ Type unrankedDescriptorType, SymbolTableCollection *symbolTables) {
return lookupOrCreateReservedFn(
b, moduleOp, kMemRefCopy,
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
}
>From 74a74cddc55e8239cae01ffd8282d4636528ac5a Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Fri, 20 Jun 2025 18:30:39 +0200
Subject: [PATCH 2/4] Set default value of pattern constructor argument
---
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 6a6371921c1d5..00c6736e6f21f 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -614,7 +614,6 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
public:
explicit CallOpLowering(const LLVMTypeConverter &typeConverter,
- // Can be nullptr.
SymbolTableCollection *symbolTables = nullptr,
PatternBenefit benefit = 1)
: CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
>From bd7fe2a9e3c900b24c370a898ddddacc8ba09fcf Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Fri, 20 Jun 2025 18:30:58 +0200
Subject: [PATCH 3/4] Add documentation
---
mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index b2f84a02a52d1..49c73fbc9dd79 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -17,6 +17,10 @@
using namespace mlir;
using namespace llvm;
+/// Check whether a given symbol name is already in use within the module
+/// operation. If no symbol with that name is present, the requested name is
+/// returned unchanged. Otherwise, a unique, not-yet-used identifier derived
+/// from the requested name is computed.
static std::string
ensureSymbolNameIsUnique(ModuleOp moduleOp, StringRef symbolName,
SymbolTableCollection *symbolTables = nullptr) {
>From 9686663a4e79e3c7823e4414c89fade3cd348ab6 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Fri, 20 Jun 2025 18:31:15 +0200
Subject: [PATCH 4/4] Add documentation
---
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 62ace5d20bd88..89f765dacda35 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,8 +45,11 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
namespace {
+/// Search for an LLVMFuncOp with a given name within an operation with the
+/// SymbolTable trait. An optional collection of cached symbol tables can be
+/// given to avoid a linear scan of the symbol table operation.
LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp,
- SymbolTableCollection *symbolTables) {
+ SymbolTableCollection *symbolTables = nullptr) {
if (symbolTables) {
return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(
symbolTableOp, StringAttr::get(symbolTableOp->getContext(), name));
More information about the Mlir-commits
mailing list