[Mlir-commits] [mlir] [mlir] Fix conflict of user defined reserved functions with internal prototypes (PR #123378)

Luohao Wang llvmlistbot at llvm.org
Fri Jan 17 09:55:02 PST 2025


https://github.com/Luohaothu created https://github.com/llvm/llvm-project/pull/123378

Related to #120950

When lowering from `memref` to LLVM, `malloc` and other intrinsic functions from `libc` are declared in the current module. A user's redefinition of these reserved functions poisons the internal analysis with a wrong prototype. This patch adds an assertion on the found function's type and reports an error if it does not match the intended type.

>From 188a59cd61ab1459e5ca1bb73f905c158d9e3cf0 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:09 +0800
Subject: [PATCH 1/2] [mlir] Add assertion on reserved function's type

---
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   |  3 +-
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 56 ++++++++++++-------
 2 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 852490cf7428f8..3095c83b90db9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
-                                  Type resultType = {}, bool isVarArg = false);
+                                  Type resultType = {}, bool isVarArg = false,
+                                  bool isReserved = false);
 
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 88421a16ccf9fb..ecc31df40ea52f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
                                               StringRef name,
                                               ArrayRef<Type> paramTypes,
-                                              Type resultType, bool isVarArg) {
+                                              Type resultType, bool isVarArg, bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
       SymbolTable::lookupSymbolIn(moduleOp, name));
-  if (func)
+  auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
+  // Assert the signature of the found function is same as expected
+  if (func) {
+    if (funcT != func.getFunctionType()) {
+      if (isReserved) {
+        func.emitError("redefinition of reserved function '" + name + "' of different type ")
+        .append(func.getFunctionType())
+        .append(" is prohibited");
+        exit(0);
+      } else {
+        func.emitError("redefinition of function '" + name + "' of different type ")
+        .append(funcT)
+        .append(" is prohibited");
+        exit(0);
+      }
+    }
     return func;
+  }
   OpBuilder b(moduleOp->getRegion(0));
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintI64,
                           IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintU64,
                           IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF32,
                           Float32Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF64,
                           Float64Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -110,51 +126,51 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp
@@ -162,13 +178,13 @@ mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
                                 {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
   return LLVM::lookupOrCreateFn(
       moduleOp, kMemRefCopy,
       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }

>From cc9e7dd0813bc030216907aaeaea964d313567cc Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:28 +0800
Subject: [PATCH 2/2] [mlir] Add test

---
 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 +++++++++++
 1 file changed, 11 insertions(+)
 create mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir

diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
new file mode 100644
index 00000000000000..f744e4f7635ea7
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 + 1)>
+module {
+  // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+  llvm.func @malloc(i64)
+  func.func @issue_120950() {
+    %alloc = memref.alloc() : memref<1024x64xf32, 1>
+    llvm.return
+  }
+}



More information about the Mlir-commits mailing list