[Mlir-commits] [mlir] d843755 - [NFC] Refactor function declaration addition in AsyncToLLVM
Rahul Joshi
llvmlistbot at llvm.org
Fri Nov 13 13:00:09 PST 2020
Author: Rahul Joshi
Date: 2020-11-13T12:53:19-08:00
New Revision: d8437552205d5036f3746d760002fbb637f4018d
URL: https://github.com/llvm/llvm-project/commit/d8437552205d5036f3746d760002fbb637f4018d
DIFF: https://github.com/llvm/llvm-project/commit/d8437552205d5036f3746d760002fbb637f4018d.diff
LOG: [NFC] Refactor function declaration addition in AsyncToLLVM
- Extract repeated code into helper function/lambdas.
Differential Revision: https://reviews.llvm.org/D91453
Added:
Modified:
mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index f063e02fd067..f0bd72028184 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -106,43 +106,22 @@ struct AsyncAPI {
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
auto builder = OpBuilder::atBlockTerminator(module.getBody());
- MLIRContext *ctx = module.getContext();
- Location loc = module.getLoc();
-
- if (!module.lookupSymbol(kCreateToken))
- builder.create<FuncOp>(loc, kCreateToken,
- AsyncAPI::createTokenFunctionType(ctx));
-
- if (!module.lookupSymbol(kCreateGroup))
- builder.create<FuncOp>(loc, kCreateGroup,
- AsyncAPI::createGroupFunctionType(ctx));
-
- if (!module.lookupSymbol(kEmplaceToken))
- builder.create<FuncOp>(loc, kEmplaceToken,
- AsyncAPI::emplaceTokenFunctionType(ctx));
-
- if (!module.lookupSymbol(kAwaitToken))
- builder.create<FuncOp>(loc, kAwaitToken,
- AsyncAPI::awaitTokenFunctionType(ctx));
-
- if (!module.lookupSymbol(kAwaitGroup))
- builder.create<FuncOp>(loc, kAwaitGroup,
- AsyncAPI::awaitGroupFunctionType(ctx));
+ auto addFuncDecl = [&](StringRef name, FunctionType type) {
+ if (module.lookupSymbol(name))
+ return;
+ builder.create<FuncOp>(module.getLoc(), name, type);
+ };
- if (!module.lookupSymbol(kExecute))
- builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx));
-
- if (!module.lookupSymbol(kAddTokenToGroup))
- builder.create<FuncOp>(loc, kAddTokenToGroup,
- AsyncAPI::addTokenToGroupFunctionType(ctx));
-
- if (!module.lookupSymbol(kAwaitAndExecute))
- builder.create<FuncOp>(loc, kAwaitAndExecute,
- AsyncAPI::awaitAndExecuteFunctionType(ctx));
-
- if (!module.lookupSymbol(kAwaitAllAndExecute))
- builder.create<FuncOp>(loc, kAwaitAllAndExecute,
- AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+ MLIRContext *ctx = module.getContext();
+ addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
+ addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
+ addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
+ addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
+ addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
+ addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
+ addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
+ addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
+ addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
}
//===----------------------------------------------------------------------===//
@@ -158,13 +137,21 @@ static constexpr const char *kCoroEnd = "llvm.coro.end";
static constexpr const char *kCoroFree = "llvm.coro.free";
static constexpr const char *kCoroResume = "llvm.coro.resume";
+/// Adds an LLVM function declaration to a module.
+static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name,
+ LLVM::LLVMType ret,
+ ArrayRef<LLVM::LLVMType> params) {
+ if (module.lookupSymbol(name))
+ return;
+ LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false);
+ builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type);
+}
+
/// Adds coroutine intrinsics declarations to the module.
static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
using namespace mlir::LLVM;
MLIRContext *ctx = module.getContext();
- Location loc = module.getLoc();
-
OpBuilder builder(module.getBody()->getTerminator());
auto token = LLVMTokenType::get(ctx);
@@ -176,38 +163,14 @@ static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
auto i64 = LLVMType::getInt64Ty(ctx);
auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
- if (!module.lookupSymbol(kCoroId))
- builder.create<LLVMFuncOp>(
- loc, kCoroId,
- LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false));
-
- if (!module.lookupSymbol(kCoroSizeI64))
- builder.create<LLVMFuncOp>(loc, kCoroSizeI64,
- LLVMType::getFunctionTy(i64, false));
-
- if (!module.lookupSymbol(kCoroBegin))
- builder.create<LLVMFuncOp>(
- loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
-
- if (!module.lookupSymbol(kCoroSave))
- builder.create<LLVMFuncOp>(loc, kCoroSave,
- LLVMType::getFunctionTy(token, i8Ptr, false));
-
- if (!module.lookupSymbol(kCoroSuspend))
- builder.create<LLVMFuncOp>(loc, kCoroSuspend,
- LLVMType::getFunctionTy(i8, {token, i1}, false));
-
- if (!module.lookupSymbol(kCoroEnd))
- builder.create<LLVMFuncOp>(loc, kCoroEnd,
- LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false));
-
- if (!module.lookupSymbol(kCoroFree))
- builder.create<LLVMFuncOp>(
- loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
-
- if (!module.lookupSymbol(kCoroResume))
- builder.create<LLVMFuncOp>(loc, kCoroResume,
- LLVMType::getFunctionTy(voidTy, i8Ptr, false));
+ addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
+ addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
+ addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
+ addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
+ addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
+ addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
+ addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
+ addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
}
//===----------------------------------------------------------------------===//
@@ -222,21 +185,14 @@ static void addCRuntimeDeclarations(ModuleOp module) {
using namespace mlir::LLVM;
MLIRContext *ctx = module.getContext();
- Location loc = module.getLoc();
-
OpBuilder builder(module.getBody()->getTerminator());
auto voidTy = LLVMType::getVoidTy(ctx);
auto i64 = LLVMType::getInt64Ty(ctx);
auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
- if (!module.lookupSymbol(kMalloc))
- builder.create<LLVM::LLVMFuncOp>(
- loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false));
-
- if (!module.lookupSymbol(kFree))
- builder.create<LLVM::LLVMFuncOp>(
- loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false));
+ addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
+ addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
}
//===----------------------------------------------------------------------===//
@@ -261,8 +217,8 @@ static void addResumeFunction(ModuleOp module) {
auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
- loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
- SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private);
+ loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false));
+ resumeOp.setPrivate();
auto *block = resumeOp.addEntryBlock();
OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);
More information about the Mlir-commits mailing list