[Mlir-commits] [mlir] e84f6b6 - [mlir] Fix conflict of user defined reserved functions with internal prototypes (#123378)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 28 05:40:51 PST 2025


Author: Luohao Wang
Date: 2025-01-28T14:40:47+01:00
New Revision: e84f6b6a88c1222d512edf0644c8f869dd12b8ef

URL: https://github.com/llvm/llvm-project/commit/e84f6b6a88c1222d512edf0644c8f869dd12b8ef
DIFF: https://github.com/llvm/llvm-project/commit/e84f6b6a88c1222d512edf0644c8f869dd12b8ef.diff

LOG: [mlir] Fix conflict of user defined reserved functions with internal prototypes (#123378)

On lowering from `memref` to LLVM, `malloc` and other intrinsic
functions from `libc` will be declared in the current module. A user's
redefinition of these reserved functions will poison the internal
analysis with a wrong prototype. This patch adds a check on the found
function's type and reports an error if it does not match the intended type.

Related to #120950


---------

Co-authored-by: Luohao Wang <Luohaothu at users.noreply.github.com>

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
    mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
    mlir/lib/Conversion/LLVMCommon/Pattern.cpp
    mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
    mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
    mlir/test/Conversion/MemRefToLLVM/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index c2742b6fc1d737..33402301115b73 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -23,11 +23,10 @@ namespace LLVM {
 /// Generate IR that prints the given string to stdout.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
-void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                        StringRef symbolName, StringRef string,
-                        const LLVMTypeConverter &typeConverter,
-                        bool addNewline = true,
-                        std::optional<StringRef> runtimeFunctionName = {});
+LogicalResult createPrintStrCall(
+    OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
+    StringRef string, const LLVMTypeConverter &typeConverter,
+    bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
 } // namespace LLVM
 
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 852490cf7428f8..05e9fe9d58859c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -16,7 +16,6 @@
 
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
-#include <optional>
 
 namespace mlir {
 class Location;
@@ -29,42 +28,47 @@ class ValueRange;
 namespace LLVM {
 class LLVMFuncOp;
 
-/// Helper functions to lookup or create the declaration for commonly used
+/// Helper functions to look up or create the declaration for commonly used
 /// external C function calls. The list of functions provided here must be
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
+/// Failure if an unexpected version of function is found.
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
 lookupOrCreatePrintStringFn(Operation *moduleOp,
                             std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
-                                              Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
-                                              Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
-                                                     Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
-                                            Type unrankedDescriptorType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp,
+                                                   Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
+                                                         Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
+                                                         Type indexType);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
+                           Type unrankedDescriptorType);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
-                                  ArrayRef<Type> paramTypes = {},
-                                  Type resultType = {}, bool isVarArg = false);
+/// Return a failure if the FuncOp found has unexpected signature.
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateFn(Operation *moduleOp, StringRef name,
+                 ArrayRef<Type> paramTypes = {}, Type resultType = {},
+                 bool isVarArg = false, bool isReserved = false);
 
 } // namespace LLVM
 } // namespace mlir

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 9b5aeb3fef30b4..47d4474a5c28d7 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
     // Allocate memory for the coroutine frame.
     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
         op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
+    if (failed(allocFuncOp))
+      return failure();
     auto coroAlloc = rewriter.create<LLVM::CallOp>(
-        loc, allocFuncOp, ValueRange{coroAlign, coroSize});
+        loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
 
     // Begin a coroutine: @llvm.coro.begin.
     auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
@@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
     // Free the memory.
     auto freeFuncOp =
         LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
+    if (failed(freeFuncOp))
+      return failure();
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
                                               ValueRange(coroMem.getResult()));
 
     return success();

diff  --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index d0ffb94f3f96a9..debfd003bd5b5e 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -61,9 +61,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
-                             *getTypeConverter(), /*addNewLine=*/false,
-                             /*runtimeFunctionName=*/"puts");
+    auto createResult = LLVM::createPrintStrCall(
+        rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
+        /*addNewLine=*/false,
+        /*runtimeFunctionName=*/"puts");
+    if (createResult.failed())
+      return failure();
+
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");

diff  --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index a47a2872ceb073..840bd3df61a063 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
 
   // Find the malloc and free, or declare them if necessary.
   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
-  LLVM::LLVMFuncOp freeFunc, mallocFunc;
-  if (toDynamic)
+  FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
+  if (toDynamic) {
     mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
-  if (!toDynamic)
+    if (failed(mallocFunc))
+      return failure();
+  }
+  if (!toDynamic) {
     freeFunc = LLVM::lookupOrCreateFreeFn(module);
+    if (failed(freeFunc))
+      return failure();
+  }
 
   unsigned unrankedMemrefPos = 0;
   for (unsigned i = 0, e = operands.size(); i < e; ++i) {
@@ -293,7 +299,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     // Allocate memory, copy, and free the source if necessary.
     Value memory =
         toDynamic
-            ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
+            ? builder
+                  .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
                   .getResult()
             : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
                                              IntegerType::get(getContext(), 8),
@@ -302,7 +309,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     Value source = desc.memRefDescPtr(builder, loc);
     builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
     if (!toDynamic)
-      builder.create<LLVM::CallOp>(loc, freeFunc, source);
+      builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
 
     // Create a new descriptor. The same descriptor can be returned multiple
     // times, attempting to modify its pointer can lead to memory leaks

diff  --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index bd7b401efec17a..337c01f01a7cc7 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
   return uniqueName;
 }
 
-void mlir::LLVM::createPrintStrCall(
+LogicalResult mlir::LLVM::createPrintStrCall(
     OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
     StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
     std::optional<StringRef> runtimeFunctionName) {
@@ -59,8 +59,11 @@ void mlir::LLVM::createPrintStrCall(
   SmallVector<LLVM::GEPArg> indices(1, 0);
   Value gep =
       builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
-  Operation *printer =
+  FailureOr<LLVM::LLVMFuncOp> printer =
       LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
-  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
-                               gep);
+  if (failed(printer))
+    return failure();
+  builder.create<LLVM::CallOp>(loc, TypeRange(),
+                               SymbolRefAttr::get(printer.value()), gep);
+  return success();
 }

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index a6408391b1330c..c5b2e83df93dcb 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -14,9 +14,9 @@
 
 using namespace mlir;
 
-namespace {
-LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                      Operation *module, Type indexType) {
+static FailureOr<LLVM::LLVMFuncOp>
+getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+                     Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
   if (useGenericFn)
     return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
@@ -24,8 +24,9 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
   return LLVM::lookupOrCreateMallocFn(module, indexType);
 }
 
-LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                   Operation *module, Type indexType) {
+static FailureOr<LLVM::LLVMFuncOp>
+getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+                  Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -34,8 +35,6 @@ LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
   return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
 }
 
-} // end namespace
-
 Value AllocationOpLLVMLowering::createAligned(
     ConversionPatternRewriter &rewriter, Location loc, Value input,
     Value alignment) {
@@ -80,10 +79,13 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
         << " to integer address space "
            "failed. Consider adding memory space conversions.";
   }
-  LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
+  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
       getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
       getIndexType());
-  auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
+  if (failed(allocFuncOp))
+    return std::make_tuple(Value(), Value());
+  auto results =
+      rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
 
   Value allocatedPtr =
       castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -146,11 +148,13 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
     sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
 
   Type elementPtrType = this->getElementPtrType(memRefType);
-  LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
+  FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
       getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
       getIndexType());
+  if (failed(allocFuncOp))
+    return Value();
   auto results = rewriter.create<LLVM::CallOp>(
-      loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
+      loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
 
   return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                              elementPtrType, *getTypeConverter());

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index f7542b8b3bc5c7..af1dba4587dc1f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -38,12 +38,12 @@ using namespace mlir;
 
 namespace {
 
-bool isStaticStrideOrOffset(int64_t strideOrOffset) {
+static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
   return !ShapedType::isDynamic(strideOrOffset);
 }
 
-LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
-                           ModuleOp module) {
+static FailureOr<LLVM::LLVMFuncOp>
+getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
-    LLVM::LLVMFuncOp freeFunc =
+    FailureOr<LLVM::LLVMFuncOp> freeFunc =
         getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
+    if (failed(freeFunc))
+      return failure();
     Value allocatedPtr;
     if (auto unrankedTy =
             llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
@@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
       allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                          .allocatedPtr(rewriter, op.getLoc());
     }
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
+                                              allocatedPtr);
     return success();
   }
 };
@@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
-    rewriter.create<LLVM::CallOp>(loc, copyFn,
+    if (failed(copyFn))
+      return failure();
+    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
                                   ValueRange{elemSize, sourcePtr, targetPtr});
 
     // Restore stack used for descriptors

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a1e21cb524bd9a..baed98c13adc7c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1546,11 +1546,15 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 
     auto punct = printOp.getPunctuation();
     if (auto stringLiteral = printOp.getStringLiteral()) {
-      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
-                               *stringLiteral, *getTypeConverter(),
-                               /*addNewline=*/false);
+      auto createResult =
+          LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+                                   *stringLiteral, *getTypeConverter(),
+                                   /*addNewline=*/false);
+      if (createResult.failed())
+        return failure();
+
     } else if (punct != PrintPunctuation::NoPunctuation) {
-      emitCall(rewriter, printOp->getLoc(), [&] {
+      FailureOr<LLVM::LLVMFuncOp> op = [&]() {
         switch (punct) {
         case PrintPunctuation::Close:
           return LLVM::lookupOrCreatePrintCloseFn(parent);
@@ -1563,7 +1567,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
         default:
           llvm_unreachable("unexpected punctuation");
         }
-      }());
+      }();
+      if (failed(op))
+        return failure();
+      emitCall(rewriter, printOp->getLoc(), op.value());
     }
 
     rewriter.eraseOp(printOp);
@@ -1588,7 +1595,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 
     // Make sure element type has runtime support.
     PrintConversion conversion = PrintConversion::None;
-    Operation *printer;
+    FailureOr<Operation *> printer;
     if (printType.isF32()) {
       printer = LLVM::lookupOrCreatePrintF32Fn(parent);
     } else if (printType.isF64()) {
@@ -1631,6 +1638,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     } else {
       return failure();
     }
+    if (failed(printer))
+      return failure();
 
     switch (conversion) {
     case PrintConversion::ZeroExt64:
@@ -1648,7 +1657,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     case PrintConversion::None:
       break;
     }
-    emitCall(rewriter, loc, printer, value);
+    emitCall(rewriter, loc, printer.value(), value);
     return success();
   }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 88421a16ccf9fb..68d4426e653019 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,56 +45,85 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
 /// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
-                                              StringRef name,
-                                              ArrayRef<Type> paramTypes,
-                                              Type resultType, bool isVarArg) {
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
+                             ArrayRef<Type> paramTypes, Type resultType,
+                             bool isVarArg, bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
       SymbolTable::lookupSymbolIn(moduleOp, name));
-  if (func)
+  auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
+  // Assert the signature of the found function is same as expected
+  if (func) {
+    if (funcT != func.getFunctionType()) {
+      if (isReserved) {
+        func.emitError("redefinition of reserved function '")
+            << name << "' of different type " << func.getFunctionType()
+            << " is prohibited";
+      } else {
+        func.emitError("redefinition of function '")
+            << name << "' of different type " << funcT << " is prohibited";
+      }
+      return failure();
+    }
     return func;
+  }
   OpBuilder b(moduleOp->getRegion(0));
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
       LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintI64,
-                          IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+static FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateReservedFn(Operation *moduleOp, StringRef name,
+                         ArrayRef<Type> paramTypes, Type resultType) {
+  return lookupOrCreateFn(moduleOp, name, paramTypes, resultType,
+                          /*isVarArg=*/false, /*isReserved=*/true);
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintU64,
-                          IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintF16,
-                          IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintBF16,
-                          IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintF16,
+      IntegerType::get(moduleOp->getContext(), 16), // bits!
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintF32,
-                          Float32Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintBF16,
+      IntegerType::get(moduleOp->getContext(), 16), // bits!
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintF64,
-                          Float64Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -106,75 +135,87 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
   return getCharPtr(context);
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
+FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
-  return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
-                          getCharPtr(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+  return lookupOrCreateReservedFn(
+      moduleOp, runtimeFunctionName.value_or(kPrintString),
+      getCharPtr(moduleOp->getContext()),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintOpen, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintOpen, {},
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintClose, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintClose, {},
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintComma, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintComma, {},
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintNewline, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
+      moduleOp, kPrintNewline, {},
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
-                                                    Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) {
+  return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType,
+                                  getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
-                                                          Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) {
+  return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc,
+                                  {indexType, indexType},
+                                  getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
-  return LLVM::lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
-                                                          Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) {
+  return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType,
+                                  getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
-                                {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+  return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc,
+                                  {indexType, indexType},
+                                  getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
-  return LLVM::lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
+  return lookupOrCreateReservedFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
 mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                        Type unrankedDescriptorType) {
-  return LLVM::lookupOrCreateFn(
+  return lookupOrCreateReservedFn(
       moduleOp, kMemRefCopy,
       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
       LLVM::LLVMVoidType::get(moduleOp->getContext()));

diff  --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
index 40dd75af1dd770..1e12b83a24b5a7 100644
--- a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
@@ -2,6 +2,13 @@
 // Since the error is at an unknown location, we use FileCheck instead of
 // -veri-y-diagnostics here
 
+// CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+llvm.func @malloc(i64)
+func.func @redef_reserved() {
+    %alloc = memref.alloc() : memref<1024x64xf32, 1>
+    llvm.return
+}
+
 // CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions.
 // CHECK-LABEL: @bad_address_space
 func.func @bad_address_space(%a: memref<2xindex, "foo">) {


        


More information about the Mlir-commits mailing list