[Mlir-commits] [mlir] [mlir] Fix conflict of user defined reserved functions with internal prototypes (PR #123378)

Luohao Wang llvmlistbot at llvm.org
Tue Jan 21 02:29:41 PST 2025


https://github.com/Luohaothu updated https://github.com/llvm/llvm-project/pull/123378

>From 3b892e6a0618cd3092a16aa2e56bb80594c9c4ba Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:09 +0800
Subject: [PATCH 1/6] [mlir] Add assertion on reserved function's type

---
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   |  3 +-
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 56 ++++++++++++-------
 2 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 852490cf7428f8..3095c83b90db9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
-                                  Type resultType = {}, bool isVarArg = false);
+                                  Type resultType = {}, bool isVarArg = false,
+                                  bool isReserved = false);
 
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 88421a16ccf9fb..ecc31df40ea52f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
                                               StringRef name,
                                               ArrayRef<Type> paramTypes,
-                                              Type resultType, bool isVarArg) {
+                                              Type resultType, bool isVarArg, bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
       SymbolTable::lookupSymbolIn(moduleOp, name));
-  if (func)
+  auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
+  // Assert the signature of the found function is same as expected
+  if (func) {
+    if (funcT != func.getFunctionType()) {
+      if (isReserved) {
+        func.emitError("redefinition of reserved function '" + name + "' of different type ")
+        .append(func.getFunctionType())
+        .append(" is prohibited");
+        exit(0);
+      } else {
+        func.emitError("redefinition of function '" + name + "' of different type ")
+        .append(funcT)
+        .append(" is prohibited");
+        exit(0);
+      }
+    }
     return func;
+  }
   OpBuilder b(moduleOp->getRegion(0));
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintI64,
                           IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintU64,
                           IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF32,
                           Float32Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF64,
                           Float64Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -110,51 +126,51 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp
@@ -162,13 +178,13 @@ mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
                                 {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
   return LLVM::lookupOrCreateFn(
       moduleOp, kMemRefCopy,
       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }

>From bdb69fc838254c35ea55542c39dbf9392cd6d4b2 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:28 +0800
Subject: [PATCH 2/6] [mlir] Add regression test for issue #120950

---
 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 +++++++++++
 1 file changed, 11 insertions(+)
 create mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir

diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
new file mode 100644
index 00000000000000..f744e4f7635ea7
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 + 1)>
+module {
+  // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+  llvm.func @malloc(i64)
+  func.func @issue_120950() {
+    %alloc = memref.alloc() : memref<1024x64xf32, 1>
+    llvm.return
+  }
+}

>From d26a77d431b4b18ed5b320185a55f0984c6f3aea Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 02:12:39 +0800
Subject: [PATCH 3/6] [mlir] Reformat code

---
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 77 +++++++++++--------
 1 file changed, 45 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index ecc31df40ea52f..757a1acf3626f6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,7 +48,8 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
                                               StringRef name,
                                               ArrayRef<Type> paramTypes,
-                                              Type resultType, bool isVarArg, bool isReserved) {
+                                              Type resultType, bool isVarArg,
+                                              bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
@@ -58,14 +59,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
   if (func) {
     if (funcT != func.getFunctionType()) {
       if (isReserved) {
-        func.emitError("redefinition of reserved function '" + name + "' of different type ")
-        .append(func.getFunctionType())
-        .append(" is prohibited");
+        func.emitError("redefinition of reserved function '" + name +
+                       "' of different type ")
+            .append(func.getFunctionType())
+            .append(" is prohibited");
         exit(0);
       } else {
-        func.emitError("redefinition of function '" + name + "' of different type ")
-        .append(funcT)
-        .append(" is prohibited");
+        func.emitError("redefinition of function '" + name +
+                       "' of different type ")
+            .append(funcT)
+            .append(" is prohibited");
         exit(0);
       }
     }
@@ -78,39 +81,41 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintI64,
-                          IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+  return lookupOrCreateFn(
+      moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintU64,
-                          IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+  return lookupOrCreateFn(
+      moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintF32,
-                          Float32Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+  return lookupOrCreateFn(
+      moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintF64,
-                          Float64Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+  return lookupOrCreateFn(
+      moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -126,39 +131,46 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
-                                getVoidPtr(moduleOp->getContext()), false, true);
+                                getVoidPtr(moduleOp->getContext()), false,
+                                true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()), false, true);
+                                getVoidPtr(moduleOp->getContext()), false,
+                                true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
@@ -170,15 +182,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
-                                getVoidPtr(moduleOp->getContext()), false, true);
+                                getVoidPtr(moduleOp->getContext()), false,
+                                true);
 }
 
 LLVM::LLVMFuncOp
 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
-                                {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()), false, true);
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+      getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {

>From 5ec38d6c09ff700dab72d27e3610c7866a23c7fc Mon Sep 17 00:00:00 2001
From: Luohao Wang <Luohaothu at users.noreply.github.com>
Date: Tue, 21 Jan 2025 17:39:13 +0800
Subject: [PATCH 4/6] [mlir] Wrap the return value of function lookups in
 `FailureOr` for error handling

---
 .../Conversion/LLVMCommon/PrintCallHelper.h   |   2 +-
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   |  43 ++---
 .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp    |   8 +-
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   |   6 +-
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    |  16 +-
 .../Conversion/LLVMCommon/PrintCallHelper.cpp |  14 +-
 .../MemRefToLLVM/AllocLikeConversion.cpp      |  14 +-
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |  15 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  48 +++--
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 174 ++++++++++--------
 10 files changed, 196 insertions(+), 144 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index c2742b6fc1d737..5af86956c0ad92 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -23,7 +23,7 @@ namespace LLVM {
 /// Generate IR that prints the given string to stdout.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
-void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
                         StringRef symbolName, StringRef string,
                         const LLVMTypeConverter &typeConverter,
                         bool addNewline = true,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 3095c83b90db9e..473a69019d2399 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -16,7 +16,6 @@
 
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
-#include <optional>
 
 namespace mlir {
 class Location;
@@ -29,40 +28,42 @@ class ValueRange;
 namespace LLVM {
 class LLVMFuncOp;
 
-/// Helper functions to lookup or create the declaration for commonly used
+/// Helper functions to look up or create the declaration for commonly used
 /// external C function calls. The list of functions provided here must be
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
+/// Failure if an unexpected version of function is found.
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreatePrintStringFn(Operation *moduleOp,
                             std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                               Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                               Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                      Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                             Type unrankedDescriptorType);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
+/// Return a failure if the FuncOp found has unexpected signature.
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFn(Operation *moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
                                   Type resultType = {}, bool isVarArg = false,
                                   bool isReserved = false);
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 9b5aeb3fef30b4..47d4474a5c28d7 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
     // Allocate memory for the coroutine frame.
     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
+    if (failed(allocFuncOp))
+      return failure();
     auto coroAlloc = rewriter.create<LLVM::CallOp>(
-        loc, allocFuncOp, ValueRange{coroAlign, coroSize});
+        loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
 
     // Begin a coroutine: @llvm.coro.begin.
     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
@@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
     // Free the memory.
     auto freeFuncOp =
         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
+    if (failed(freeFuncOp))
+      return failure();
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
                                               ValueRange(coroMem.getResult()));
 
     return success();
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index d0ffb94f3f96a9..cdcb613e04ab12 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -61,9 +61,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+    if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
                              *getTypeConverter(), /*addNewLine=*/false,
-                             /*runtimeFunctionName=*/"puts");
+                             /*runtimeFunctionName=*/"puts").failed()) {
+      return failure();
+    }
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index a47a2872ceb073..10f72cda7706db 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
 
   // Find the malloc and free, or declare them if necessary.
   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
-  LLVM::LLVMFuncOp freeFunc, mallocFunc;
-  if (toDynamic)
+  FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
+  if (toDynamic) {
     mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
-  if (!toDynamic)
+    if (failed(mallocFunc))
+      return failure();
+  }
+  if (!toDynamic) {
     freeFunc = LLVM::lookupOrCreateFreeFn(module);
+    if (failed(freeFunc))
+      return failure();
+  }
 
   unsigned unrankedMemrefPos = 0;
   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
@@ -293,7 +299,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     // Allocate memory, copy, and free the source if necessary.
     Value memory =
         toDynamic
-            ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
+            ? builder.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
                   .getResult()
             : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
                                              IntegerType::get(getContext(), 8),
@@ -302,7 +308,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     Value source = desc.memRefDescPtr(builder, loc);
     builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
     if (!toDynamic)
-      builder.create<LLVM::CallOp>(loc, freeFunc, source);
+      builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
 
     // Create a new descriptor. The same descriptor can be returned multiple
     // times, attempting to modify its pointer can lead to memory leaks
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index bd7b401efec17a..607e1d65045523 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
   return uniqueName;
 }
 
-void mlir::LLVM::createPrintStrCall(
+LogicalResult mlir::LLVM::createPrintStrCall(
     OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
     StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
     std::optional<StringRef> runtimeFunctionName) {
@@ -59,8 +59,12 @@ void mlir::LLVM::createPrintStrCall(
   SmallVector<LLVM::GEPArg> indices(1, 0);
   Value gep =
       builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
-  Operation *printer =
-      LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
-  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
-                               gep);
+  if (auto printer =
+          LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); succeeded(printer)) {
+    builder.create<LLVM::CallOp>(loc, TypeRange(),
+                                 SymbolRefAttr::get(printer.value()), gep);
+  } else {
+    return failure();
+  }
+  return success();
 }
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index a6408391b1330c..0ee92722157f35 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -15,7 +15,7 @@
 using namespace mlir;
 
 namespace {
-LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
+FailureOr<LLVM::LLVMFuncOp> getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
                                       Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
   if (useGenericFn)
@@ -24,7 +24,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
   return LLVM::lookupOrCreateMallocFn(module, indexType);
 }
 
-LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
+FailureOr<LLVM::LLVMFuncOp> getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
                                    Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
@@ -80,10 +80,11 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
         << " to integer address space "
            "failed. Consider adding memory space conversions.";
   }
-  LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
+  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
       getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
       getIndexType());
-  auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
+  if (failed(allocFuncOp)) return std::make_tuple(Value(), Value());
+  auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
 
   Value allocatedPtr =
       castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -146,11 +147,12 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
     sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
 
   Type elementPtrType = this->getElementPtrType(memRefType);
-  LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
+  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
       getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
       getIndexType());
+  if (failed(allocFuncOp)) return Value();
   auto results = rewriter.create<LLVM::CallOp>(
-      loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
+      loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
 
   return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                              elementPtrType, *getTypeConverter());
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index f7542b8b3bc5c7..ac27e0dd09bdcd 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -42,8 +42,8 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) {
   return !ShapedType::isDynamic(strideOrOffset);
 }
 
-LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
-                           ModuleOp module) {
+FailureOr<LLVM::LLVMFuncOp> getFreeFn(const LLVMTypeConverter *typeConverter,
+                                      ModuleOp module) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
-    LLVM::LLVMFuncOp freeFunc =
+    auto freeFunc =
         getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
+    if (failed(freeFunc))
+      return failure();
     Value allocatedPtr;
     if (auto unrankedTy =
             llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
@@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
       allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                          .allocatedPtr(rewriter, op.getLoc());
     }
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
+                                              allocatedPtr);
     return success();
   }
 };
@@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
-    rewriter.create<LLVM::CallOp>(loc, copyFn,
+    if (failed(copyFn))
+      return failure();
+    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
                                   ValueRange{elemSize, sourcePtr, targetPtr});
 
     // Restore stack used for descriptors
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a1e21cb524bd9a..79617506008fab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1546,24 +1546,32 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 
     auto punct = printOp.getPunctuation();
     if (auto stringLiteral = printOp.getStringLiteral()) {
-      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
-                               *stringLiteral, *getTypeConverter(),
-                               /*addNewline=*/false);
+      if (LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+                                   *stringLiteral, *getTypeConverter(),
+                                   /*addNewline=*/false)
+              .failed()) {
+        return failure();
+      }
     } else if (punct != PrintPunctuation::NoPunctuation) {
-      emitCall(rewriter, printOp->getLoc(), [&] {
-        switch (punct) {
-        case PrintPunctuation::Close:
-          return LLVM::lookupOrCreatePrintCloseFn(parent);
-        case PrintPunctuation::Open:
-          return LLVM::lookupOrCreatePrintOpenFn(parent);
-        case PrintPunctuation::Comma:
-          return LLVM::lookupOrCreatePrintCommaFn(parent);
-        case PrintPunctuation::NewLine:
-          return LLVM::lookupOrCreatePrintNewlineFn(parent);
-        default:
-          llvm_unreachable("unexpected punctuation");
-        }
-      }());
+      if (auto op = [&]() -> FailureOr<LLVM::LLVMFuncOp> {
+            switch (punct) {
+            case PrintPunctuation::Close:
+              return LLVM::lookupOrCreatePrintCloseFn(parent);
+            case PrintPunctuation::Open:
+              return LLVM::lookupOrCreatePrintOpenFn(parent);
+            case PrintPunctuation::Comma:
+              return LLVM::lookupOrCreatePrintCommaFn(parent);
+            case PrintPunctuation::NewLine:
+              return LLVM::lookupOrCreatePrintNewlineFn(parent);
+            default:
+              llvm_unreachable("unexpected punctuation");
+            }
+          }();
+          succeeded(op)) {
+        emitCall(rewriter, printOp->getLoc(), op.value());
+      } else {
+        return failure();
+      }
     }
 
     rewriter.eraseOp(printOp);
@@ -1588,7 +1596,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 
     // Make sure element type has runtime support.
     PrintConversion conversion = PrintConversion::None;
-    Operation *printer;
+    FailureOr<Operation *> printer;
     if (printType.isF32()) {
       printer = LLVM::lookupOrCreatePrintF32Fn(parent);
     } else if (printType.isF64()) {
@@ -1631,6 +1639,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     } else {
       return failure();
     }
+    if (failed(printer))
+      return failure();
 
     switch (conversion) {
     case PrintConversion::ZeroExt64:
@@ -1648,7 +1658,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     case PrintConversion::None:
       break;
     }
-    emitCall(rewriter, loc, printer, value);
+    emitCall(rewriter, loc, printer.value(), value);
     return success();
   }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 757a1acf3626f6..c2c87bc7544bd7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,11 +45,10 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
 /// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
-                                              StringRef name,
-                                              ArrayRef<Type> paramTypes,
-                                              Type resultType, bool isVarArg,
-                                              bool isReserved) {
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
+                             ArrayRef<Type> paramTypes, Type resultType,
+                             bool isVarArg, bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
@@ -63,14 +62,13 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
                        "' of different type ")
             .append(func.getFunctionType())
             .append(" is prohibited");
-        exit(0);
       } else {
         func.emitError("redefinition of function '" + name +
                        "' of different type ")
             .append(funcT)
             .append(" is prohibited");
-        exit(0);
       }
+      return failure();
     }
     return func;
   }
@@ -80,42 +78,58 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
       LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(
+namespace {
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateReservedFn(Operation *moduleOp,
+                                                     StringRef name,
+                                                     ArrayRef<Type> paramTypes,
+                                                     Type resultType) {
+  return lookupOrCreateFn(moduleOp, name, paramTypes, resultType,
+                          /*isVarArg=*/false, /*isReserved=*/true);
+}
+} // namespace
+
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
       moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
       moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintF16,
-                          IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
-                          false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintF16,
+      IntegerType::get(moduleOp->getContext(), 16), // bits!
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintBF16,
-                          IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
-                          false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintBF16,
+      IntegerType::get(moduleOp->getContext(), 16), // bits!
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
       moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
       moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -127,84 +141,88 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
   return getCharPtr(context);
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
+FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
-  return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
-                          getCharPtr(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
-                          false, true);
+  return lookupOrCreateReservedFn(
+      moduleOp, runtimeFunctionName.value_or(kPrintString),
+      getCharPtr(moduleOp->getContext()),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintOpen, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
-                          false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintOpen, {},
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintClose, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
-                          false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintClose, {},
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintComma, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
-                          false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintComma, {},
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintNewline, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
-                          false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintNewline, {},
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
-                                                    Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
-                                getVoidPtr(moduleOp->getContext()), false,
-                                true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) {
+  return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType,
+                                        getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
-                                                          Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()), false,
-                                true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) {
+  return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc,
+                                        {indexType, indexType},
+                                        getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
-  return LLVM::lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
-                                                          Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
-                                getVoidPtr(moduleOp->getContext()), false,
-                                true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) {
+  return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType,
+                                        getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
-  return LLVM::lookupOrCreateFn(
-      moduleOp, kGenericAlignedAlloc, {indexType, indexType},
-      getVoidPtr(moduleOp->getContext()), false, true);
+  return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc,
+                                        {indexType, indexType},
+                                        getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
-  return LLVM::lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                        Type unrankedDescriptorType) {
-  return LLVM::lookupOrCreateFn(
+  return lookupOrCreateReservedFn(
       moduleOp, kMemRefCopy,
       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }

>From 9079caff8b7bfa80bef19b2d33bdd7e361a8eb66 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Tue, 21 Jan 2025 18:16:52 +0800
Subject: [PATCH 5/6] [mlir] [Test] Moved & renamed test case

---
 mlir/test/Conversion/MemRefToLLVM/invalid.mlir      |  7 +++++++
 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 -----------
 2 files changed, 7 insertions(+), 11 deletions(-)
 delete mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir

diff --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
index 40dd75af1dd770..1e12b83a24b5a7 100644
--- a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
@@ -2,6 +2,13 @@
 // Since the error is at an unknown location, we use FileCheck instead of
 // -veri-y-diagnostics here
 
+// CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+llvm.func @malloc(i64)
+func.func @redef_reserved() {
+    %alloc = memref.alloc() : memref<1024x64xf32, 1>
+    llvm.return
+}
+
 // CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions.
 // CHECK-LABEL: @bad_address_space
 func.func @bad_address_space(%a: memref<2xindex, "foo">) {
diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
deleted file mode 100644
index f744e4f7635ea7..00000000000000
--- a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
-
-#map = affine_map<(d0) -> (d0 + 1)>
-module {
-  // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
-  llvm.func @malloc(i64)
-  func.func @issue_120950() {
-    %alloc = memref.alloc() : memref<1024x64xf32, 1>
-    llvm.return
-  }
-}

>From 475409c4fac865f26c94b751a9b0dbfcc937b83a Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Tue, 21 Jan 2025 18:29:22 +0800
Subject: [PATCH 6/6] Reformat code

---
 .../Conversion/LLVMCommon/PrintCallHelper.h   |  9 ++++---
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   | 24 ++++++++++---------
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   |  8 ++++---
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    |  3 ++-
 .../Conversion/LLVMCommon/PrintCallHelper.cpp |  3 ++-
 .../MemRefToLLVM/AllocLikeConversion.cpp      | 19 +++++++++------
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 12 +++++-----
 7 files changed, 44 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 5af86956c0ad92..33402301115b73 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -23,11 +23,10 @@ namespace LLVM {
 /// Generate IR that prints the given string to stdout.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
-LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                        StringRef symbolName, StringRef string,
-                        const LLVMTypeConverter &typeConverter,
-                        bool addNewline = true,
-                        std::optional<StringRef> runtimeFunctionName = {});
+LogicalResult createPrintStrCall(
+    OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
+    StringRef string, const LLVMTypeConverter &typeConverter,
+    bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
 } // namespace LLVM
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 473a69019d2399..05e9fe9d58859c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -49,24 +49,26 @@ FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
 FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
 FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
 FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp,
+                                                   Type indexType);
 FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
-                                              Type indexType);
+                                                         Type indexType);
 FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
 FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
-                                              Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
-                                                     Type indexType);
+                                                         Type indexType);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType);
 FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
-                                            Type unrankedDescriptorType);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
+                           Type unrankedDescriptorType);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 /// Return a failure if the FuncOp found has unexpected signature.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFn(Operation *moduleOp, StringRef name,
-                                  ArrayRef<Type> paramTypes = {},
-                                  Type resultType = {}, bool isVarArg = false,
-                                  bool isReserved = false);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateFn(Operation *moduleOp, StringRef name,
+                 ArrayRef<Type> paramTypes = {}, Type resultType = {},
+                 bool isVarArg = false, bool isReserved = false);
 
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index cdcb613e04ab12..f2fc235fecb289 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -61,9 +61,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
-                             *getTypeConverter(), /*addNewLine=*/false,
-                             /*runtimeFunctionName=*/"puts").failed()) {
+    if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg",
+                                 op.getMsg(), *getTypeConverter(),
+                                 /*addNewLine=*/false,
+                                 /*runtimeFunctionName=*/"puts")
+            .failed()) {
       return failure();
     }
     if (abortOnFailedAssert) {
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 10f72cda7706db..840bd3df61a063 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -299,7 +299,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     // Allocate memory, copy, and free the source if necessary.
     Value memory =
         toDynamic
-            ? builder.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
+            ? builder
+                  .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
                   .getResult()
             : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
                                              IntegerType::get(getContext(), 8),
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 607e1d65045523..381e2ffea8eb29 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -60,7 +60,8 @@ LogicalResult mlir::LLVM::createPrintStrCall(
   Value gep =
       builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
   if (auto printer =
-          LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); succeeded(printer)) {
+          LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
+      succeeded(printer)) {
     builder.create<LLVM::CallOp>(loc, TypeRange(),
                                  SymbolRefAttr::get(printer.value()), gep);
   } else {
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 0ee92722157f35..1712d0b5844b88 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -15,8 +15,9 @@
 using namespace mlir;
 
 namespace {
-FailureOr<LLVM::LLVMFuncOp> getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                      Operation *module, Type indexType) {
+FailureOr<LLVM::LLVMFuncOp>
+getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+                     Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
   if (useGenericFn)
     return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
@@ -24,8 +25,9 @@ FailureOr<LLVM::LLVMFuncOp> getNotalignedAllocFn(const LLVMTypeConverter *typeCo
   return LLVM::lookupOrCreateMallocFn(module, indexType);
 }
 
-FailureOr<LLVM::LLVMFuncOp> getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                   Operation *module, Type indexType) {
+FailureOr<LLVM::LLVMFuncOp>
+getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+                  Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -83,8 +85,10 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
   FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
       getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
       getIndexType());
-  if (failed(allocFuncOp)) return std::make_tuple(Value(), Value());
-  auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
+  if (failed(allocFuncOp))
+    return std::make_tuple(Value(), Value());
+  auto results =
+      rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
 
   Value allocatedPtr =
       castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -150,7 +154,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
   FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
       getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
       getIndexType());
-  if (failed(allocFuncOp)) return Value();
+  if (failed(allocFuncOp))
+    return Value();
   auto results = rewriter.create<LLVM::CallOp>(
       loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index c2c87bc7544bd7..9df5c4554c2360 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -180,14 +180,14 @@ mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
 FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) {
   return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType,
-                                        getVoidPtr(moduleOp->getContext()));
+                                  getVoidPtr(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) {
   return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc,
-                                        {indexType, indexType},
-                                        getVoidPtr(moduleOp->getContext()));
+                                  {indexType, indexType},
+                                  getVoidPtr(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>
@@ -200,15 +200,15 @@ mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
 FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) {
   return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType,
-                                        getVoidPtr(moduleOp->getContext()));
+                                  getVoidPtr(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
   return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc,
-                                        {indexType, indexType},
-                                        getVoidPtr(moduleOp->getContext()));
+                                  {indexType, indexType},
+                                  getVoidPtr(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>



More information about the Mlir-commits mailing list