[Mlir-commits] [mlir] [mlir] Fix conflict of user defined reserved functions with internal prototypes (PR #123378)
Luohao Wang
llvmlistbot at llvm.org
Fri Jan 17 10:13:06 PST 2025
https://github.com/Luohaothu updated https://github.com/llvm/llvm-project/pull/123378
>From 188a59cd61ab1459e5ca1bb73f905c158d9e3cf0 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:09 +0800
Subject: [PATCH 1/3] [mlir] Reject reserved functions redefined with a mismatched type
---
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 3 +-
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 56 ++++++++++++-------
2 files changed, 38 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 852490cf7428f8..3095c83b90db9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
- Type resultType = {}, bool isVarArg = false);
+ Type resultType = {}, bool isVarArg = false,
+ bool isReserved = false);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 88421a16ccf9fb..ecc31df40ea52f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
StringRef name,
ArrayRef<Type> paramTypes,
- Type resultType, bool isVarArg) {
+ Type resultType, bool isVarArg, bool isReserved) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, name));
- if (func)
+ auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
+ // Reject an existing symbol whose type differs from the expected prototype. NOTE(review): the exit(0) calls below end the whole process with a success status; returning a failure to callers would be safer.
+ if (func) {
+ if (funcT != func.getFunctionType()) {
+ if (isReserved) {
+ func.emitError("redefinition of reserved function '" + name + "' of different type ")
+ .append(func.getFunctionType())
+ .append(" is prohibited");
+ exit(0);
+ } else {
+ func.emitError("redefinition of function '" + name + "' of different type ")
+ .append(funcT)
+ .append(" is prohibited");
+ exit(0);
+ }
+ }
return func;
+ }
OpBuilder b(moduleOp->getRegion(0));
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintI64,
IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintU64,
IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF64,
Float64Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -110,51 +126,51 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp
@@ -162,13 +178,13 @@ mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
{indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
return LLVM::lookupOrCreateFn(
moduleOp, kMemRefCopy,
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
>From cc9e7dd0813bc030216907aaeaea964d313567cc Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:28 +0800
Subject: [PATCH 2/3] [mlir] Add test
---
mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 +++++++++++
1 file changed, 11 insertions(+)
create mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
new file mode 100644
index 00000000000000..f744e4f7635ea7
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 + 1)>
+module {
+ // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+ llvm.func @malloc(i64)
+ func.func @issue_120950() {
+ %alloc = memref.alloc() : memref<1024x64xf32, 1>
+ llvm.return
+ }
+}
>From 832a33e098f2833f87258d580d638f01e4869582 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 02:12:39 +0800
Subject: [PATCH 3/3] [mlir] Reformat code
---
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 77 +++++++++++--------
1 file changed, 45 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index ecc31df40ea52f..757a1acf3626f6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,7 +48,8 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
StringRef name,
ArrayRef<Type> paramTypes,
- Type resultType, bool isVarArg, bool isReserved) {
+ Type resultType, bool isVarArg,
+ bool isReserved) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
@@ -58,14 +59,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
if (func) {
if (funcT != func.getFunctionType()) {
if (isReserved) {
- func.emitError("redefinition of reserved function '" + name + "' of different type ")
- .append(func.getFunctionType())
- .append(" is prohibited");
+ func.emitError("redefinition of reserved function '" + name +
+ "' of different type ")
+ .append(func.getFunctionType())
+ .append(" is prohibited");
exit(0);
} else {
- func.emitError("redefinition of function '" + name + "' of different type ")
- .append(funcT)
- .append(" is prohibited");
+ func.emitError("redefinition of function '" + name +
+ "' of different type ")
+ .append(funcT)
+ .append(" is prohibited");
exit(0);
}
}
@@ -78,39 +81,41 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintI64,
- IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ return lookupOrCreateFn(
+ moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintU64,
- IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ return lookupOrCreateFn(
+ moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintF32,
- Float32Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ return lookupOrCreateFn(
+ moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintF64,
- Float64Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ return lookupOrCreateFn(
+ moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -126,39 +131,46 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()), false, true);
+ getVoidPtr(moduleOp->getContext()), false,
+ true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
- getVoidPtr(moduleOp->getContext()), false, true);
+ getVoidPtr(moduleOp->getContext()), false,
+ true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
@@ -170,15 +182,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()), false, true);
+ getVoidPtr(moduleOp->getContext()), false,
+ true);
}
LLVM::LLVMFuncOp
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
- return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
- {indexType, indexType},
- getVoidPtr(moduleOp->getContext()), false, true);
+ return LLVM::lookupOrCreateFn(
+ moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
More information about the Mlir-commits
mailing list