[Mlir-commits] [mlir] d843755 - [NFC] Refactor function declaration addition in AsyncToLLVM

Rahul Joshi llvmlistbot at llvm.org
Fri Nov 13 13:00:09 PST 2020


Author: Rahul Joshi
Date: 2020-11-13T12:53:19-08:00
New Revision: d8437552205d5036f3746d760002fbb637f4018d

URL: https://github.com/llvm/llvm-project/commit/d8437552205d5036f3746d760002fbb637f4018d
DIFF: https://github.com/llvm/llvm-project/commit/d8437552205d5036f3746d760002fbb637f4018d.diff

LOG: [NFC] Refactor function declaration addition in AsyncToLLVM

- Extract repeated code into helper function/lambdas.

Differential Revision: https://reviews.llvm.org/D91453

Added: 
    

Modified: 
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index f063e02fd067..f0bd72028184 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -106,43 +106,22 @@ struct AsyncAPI {
 static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
   auto builder = OpBuilder::atBlockTerminator(module.getBody());
 
-  MLIRContext *ctx = module.getContext();
-  Location loc = module.getLoc();
-
-  if (!module.lookupSymbol(kCreateToken))
-    builder.create<FuncOp>(loc, kCreateToken,
-                           AsyncAPI::createTokenFunctionType(ctx));
-
-  if (!module.lookupSymbol(kCreateGroup))
-    builder.create<FuncOp>(loc, kCreateGroup,
-                           AsyncAPI::createGroupFunctionType(ctx));
-
-  if (!module.lookupSymbol(kEmplaceToken))
-    builder.create<FuncOp>(loc, kEmplaceToken,
-                           AsyncAPI::emplaceTokenFunctionType(ctx));
-
-  if (!module.lookupSymbol(kAwaitToken))
-    builder.create<FuncOp>(loc, kAwaitToken,
-                           AsyncAPI::awaitTokenFunctionType(ctx));
-
-  if (!module.lookupSymbol(kAwaitGroup))
-    builder.create<FuncOp>(loc, kAwaitGroup,
-                           AsyncAPI::awaitGroupFunctionType(ctx));
+  auto addFuncDecl = [&](StringRef name, FunctionType type) {
+    if (module.lookupSymbol(name))
+      return;
+    builder.create<FuncOp>(module.getLoc(), name, type);
+  };
 
-  if (!module.lookupSymbol(kExecute))
-    builder.create<FuncOp>(loc, kExecute, AsyncAPI::executeFunctionType(ctx));
-
-  if (!module.lookupSymbol(kAddTokenToGroup))
-    builder.create<FuncOp>(loc, kAddTokenToGroup,
-                           AsyncAPI::addTokenToGroupFunctionType(ctx));
-
-  if (!module.lookupSymbol(kAwaitAndExecute))
-    builder.create<FuncOp>(loc, kAwaitAndExecute,
-                           AsyncAPI::awaitAndExecuteFunctionType(ctx));
-
-  if (!module.lookupSymbol(kAwaitAllAndExecute))
-    builder.create<FuncOp>(loc, kAwaitAllAndExecute,
-                           AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+  MLIRContext *ctx = module.getContext();
+  addFuncDecl(kCreateToken, AsyncAPI::createTokenFunctionType(ctx));
+  addFuncDecl(kCreateGroup, AsyncAPI::createGroupFunctionType(ctx));
+  addFuncDecl(kEmplaceToken, AsyncAPI::emplaceTokenFunctionType(ctx));
+  addFuncDecl(kAwaitToken, AsyncAPI::awaitTokenFunctionType(ctx));
+  addFuncDecl(kAwaitGroup, AsyncAPI::awaitGroupFunctionType(ctx));
+  addFuncDecl(kExecute, AsyncAPI::executeFunctionType(ctx));
+  addFuncDecl(kAddTokenToGroup, AsyncAPI::addTokenToGroupFunctionType(ctx));
+  addFuncDecl(kAwaitAndExecute, AsyncAPI::awaitAndExecuteFunctionType(ctx));
+  addFuncDecl(kAwaitAllAndExecute, AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
 }
 
 //===----------------------------------------------------------------------===//
@@ -158,13 +137,21 @@ static constexpr const char *kCoroEnd = "llvm.coro.end";
 static constexpr const char *kCoroFree = "llvm.coro.free";
 static constexpr const char *kCoroResume = "llvm.coro.resume";
 
+/// Adds an LLVM function declaration to a module.
+static void addLLVMFuncDecl(ModuleOp module, OpBuilder &builder, StringRef name,
+                            LLVM::LLVMType ret,
+                            ArrayRef<LLVM::LLVMType> params) {
+  if (module.lookupSymbol(name))
+    return;
+  LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false);
+  builder.create<LLVM::LLVMFuncOp>(module.getLoc(), name, type);
+}
+
 /// Adds coroutine intrinsics declarations to the module.
 static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
   using namespace mlir::LLVM;
 
   MLIRContext *ctx = module.getContext();
-  Location loc = module.getLoc();
-
   OpBuilder builder(module.getBody()->getTerminator());
 
   auto token = LLVMTokenType::get(ctx);
@@ -176,38 +163,14 @@ static void addCoroutineIntrinsicsDeclarations(ModuleOp module) {
   auto i64 = LLVMType::getInt64Ty(ctx);
   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
 
-  if (!module.lookupSymbol(kCoroId))
-    builder.create<LLVMFuncOp>(
-        loc, kCoroId,
-        LLVMType::getFunctionTy(token, {i32, i8Ptr, i8Ptr, i8Ptr}, false));
-
-  if (!module.lookupSymbol(kCoroSizeI64))
-    builder.create<LLVMFuncOp>(loc, kCoroSizeI64,
-                               LLVMType::getFunctionTy(i64, false));
-
-  if (!module.lookupSymbol(kCoroBegin))
-    builder.create<LLVMFuncOp>(
-        loc, kCoroBegin, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
-
-  if (!module.lookupSymbol(kCoroSave))
-    builder.create<LLVMFuncOp>(loc, kCoroSave,
-                               LLVMType::getFunctionTy(token, i8Ptr, false));
-
-  if (!module.lookupSymbol(kCoroSuspend))
-    builder.create<LLVMFuncOp>(loc, kCoroSuspend,
-                               LLVMType::getFunctionTy(i8, {token, i1}, false));
-
-  if (!module.lookupSymbol(kCoroEnd))
-    builder.create<LLVMFuncOp>(loc, kCoroEnd,
-                               LLVMType::getFunctionTy(i1, {i8Ptr, i1}, false));
-
-  if (!module.lookupSymbol(kCoroFree))
-    builder.create<LLVMFuncOp>(
-        loc, kCoroFree, LLVMType::getFunctionTy(i8Ptr, {token, i8Ptr}, false));
-
-  if (!module.lookupSymbol(kCoroResume))
-    builder.create<LLVMFuncOp>(loc, kCoroResume,
-                               LLVMType::getFunctionTy(voidTy, i8Ptr, false));
+  addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr});
+  addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {});
+  addLLVMFuncDecl(module, builder, kCoroBegin, i8Ptr, {token, i8Ptr});
+  addLLVMFuncDecl(module, builder, kCoroSave, token, {i8Ptr});
+  addLLVMFuncDecl(module, builder, kCoroSuspend, i8, {token, i1});
+  addLLVMFuncDecl(module, builder, kCoroEnd, i1, {i8Ptr, i1});
+  addLLVMFuncDecl(module, builder, kCoroFree, i8Ptr, {token, i8Ptr});
+  addLLVMFuncDecl(module, builder, kCoroResume, voidTy, {i8Ptr});
 }
 
 //===----------------------------------------------------------------------===//
@@ -222,21 +185,14 @@ static void addCRuntimeDeclarations(ModuleOp module) {
   using namespace mlir::LLVM;
 
   MLIRContext *ctx = module.getContext();
-  Location loc = module.getLoc();
-
   OpBuilder builder(module.getBody()->getTerminator());
 
   auto voidTy = LLVMType::getVoidTy(ctx);
   auto i64 = LLVMType::getInt64Ty(ctx);
   auto i8Ptr = LLVMType::getInt8PtrTy(ctx);
 
-  if (!module.lookupSymbol(kMalloc))
-    builder.create<LLVM::LLVMFuncOp>(
-        loc, kMalloc, LLVMType::getFunctionTy(i8Ptr, {i64}, false));
-
-  if (!module.lookupSymbol(kFree))
-    builder.create<LLVM::LLVMFuncOp>(
-        loc, kFree, LLVMType::getFunctionTy(voidTy, i8Ptr, false));
+  addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64});
+  addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr});
 }
 
 //===----------------------------------------------------------------------===//
@@ -261,8 +217,8 @@ static void addResumeFunction(ModuleOp module) {
   auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx);
 
   auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
-      loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
-  SymbolTable::setSymbolVisibility(resumeOp, SymbolTable::Visibility::Private);
+      loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false));
+  resumeOp.setPrivate();
 
   auto *block = resumeOp.addEntryBlock();
   OpBuilder blockBuilder = OpBuilder::atBlockEnd(block);


        


More information about the Mlir-commits mailing list