[Mlir-commits] [mlir] [mlir] Fix conflict of user-defined reserved functions with internal prototypes (PR #123378)

Luohao Wang llvmlistbot at llvm.org
Fri Jan 17 10:13:06 PST 2025


https://github.com/Luohaothu updated https://github.com/llvm/llvm-project/pull/123378

>From 188a59cd61ab1459e5ca1bb73f905c158d9e3cf0 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:09 +0800
Subject: [PATCH 1/3] [mlir] Diagnose redefinitions of reserved functions with a mismatched type

---
 .../mlir/Dialect/LLVMIR/FunctionCallUtils.h   |  3 +-
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 56 ++++++++++++-------
 2 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 852490cf7428f8..3095c83b90db9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
-                                  Type resultType = {}, bool isVarArg = false);
+                                  Type resultType = {}, bool isVarArg = false,
+                                  bool isReserved = false);
 
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 88421a16ccf9fb..ecc31df40ea52f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
                                               StringRef name,
                                               ArrayRef<Type> paramTypes,
-                                              Type resultType, bool isVarArg) {
+                                              Type resultType, bool isVarArg, bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
       SymbolTable::lookupSymbolIn(moduleOp, name));
-  if (func)
+  auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
+  // Assert the signature of the found function is same as expected
+  if (func) {
+    if (funcT != func.getFunctionType()) {
+      if (isReserved) {
+        func.emitError("redefinition of reserved function '" + name + "' of different type ")
+        .append(func.getFunctionType())
+        .append(" is prohibited");
+        exit(0);
+      } else {
+        func.emitError("redefinition of function '" + name + "' of different type ")
+        .append(funcT)
+        .append(" is prohibited");
+        exit(0);
+      }
+    }
     return func;
+  }
   OpBuilder b(moduleOp->getRegion(0));
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintI64,
                           IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintU64,
                           IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF32,
                           Float32Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF64,
                           Float64Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -110,51 +126,51 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp
@@ -162,13 +178,13 @@ mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
                                 {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
   return LLVM::lookupOrCreateFn(
       moduleOp, kMemRefCopy,
       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }

>From cc9e7dd0813bc030216907aaeaea964d313567cc Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:28 +0800
Subject: [PATCH 2/3] [mlir] Add test for reserved-function type mismatch (issue #120950)

---
 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 +++++++++++
 1 file changed, 11 insertions(+)
 create mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir

diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
new file mode 100644
index 00000000000000..f744e4f7635ea7
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 + 1)>
+module {
+  // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+  llvm.func @malloc(i64)
+  func.func @issue_120950() {
+    %alloc = memref.alloc() : memref<1024x64xf32, 1>
+    llvm.return
+  }
+}

>From 832a33e098f2833f87258d580d638f01e4869582 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 02:12:39 +0800
Subject: [PATCH 3/3] [mlir] Apply clang-format to FunctionCallUtils.cpp

---
 .../Dialect/LLVMIR/IR/FunctionCallUtils.cpp   | 77 +++++++++++--------
 1 file changed, 45 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index ecc31df40ea52f..757a1acf3626f6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,7 +48,8 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
                                               StringRef name,
                                               ArrayRef<Type> paramTypes,
-                                              Type resultType, bool isVarArg, bool isReserved) {
+                                              Type resultType, bool isVarArg,
+                                              bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
@@ -58,14 +59,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
   if (func) {
     if (funcT != func.getFunctionType()) {
       if (isReserved) {
-        func.emitError("redefinition of reserved function '" + name + "' of different type ")
-        .append(func.getFunctionType())
-        .append(" is prohibited");
+        func.emitError("redefinition of reserved function '" + name +
+                       "' of different type ")
+            .append(func.getFunctionType())
+            .append(" is prohibited");
         exit(0);
       } else {
-        func.emitError("redefinition of function '" + name + "' of different type ")
-        .append(funcT)
-        .append(" is prohibited");
+        func.emitError("redefinition of function '" + name +
+                       "' of different type ")
+            .append(funcT)
+            .append(" is prohibited");
         exit(0);
       }
     }
@@ -78,39 +81,41 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintI64,
-                          IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+  return lookupOrCreateFn(
+      moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintU64,
-                          IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+  return lookupOrCreateFn(
+      moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintF32,
-                          Float32Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+  return lookupOrCreateFn(
+      moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
-  return lookupOrCreateFn(moduleOp, kPrintF64,
-                          Float64Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+  return lookupOrCreateFn(
+      moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -126,39 +131,46 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()),
+                          false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
-                                getVoidPtr(moduleOp->getContext()), false, true);
+                                getVoidPtr(moduleOp->getContext()), false,
+                                true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()), false, true);
+                                getVoidPtr(moduleOp->getContext()), false,
+                                true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
@@ -170,15 +182,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
-                                getVoidPtr(moduleOp->getContext()), false, true);
+                                getVoidPtr(moduleOp->getContext()), false,
+                                true);
 }
 
 LLVM::LLVMFuncOp
 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
-  return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
-                                {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()), false, true);
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+      getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {



More information about the Mlir-commits mailing list