[Mlir-commits] [mlir] [mlir] Fix conflict of user defined reserved functions with internal prototypes (PR #123378)
Luohao Wang
llvmlistbot at llvm.org
Sat Jan 25 23:31:11 PST 2025
https://github.com/Luohaothu updated https://github.com/llvm/llvm-project/pull/123378
>From 3b892e6a0618cd3092a16aa2e56bb80594c9c4ba Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:09 +0800
Subject: [PATCH 1/9] [mlir] Add assertion on reserved function's type
---
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 3 +-
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 56 ++++++++++++-------
2 files changed, 38 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 852490cf7428f8..3095c83b90db9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
- Type resultType = {}, bool isVarArg = false);
+ Type resultType = {}, bool isVarArg = false,
+ bool isReserved = false);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 88421a16ccf9fb..ecc31df40ea52f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
StringRef name,
ArrayRef<Type> paramTypes,
- Type resultType, bool isVarArg) {
+ Type resultType, bool isVarArg, bool isReserved) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, name));
- if (func)
+ auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
+ // Assert the signature of the found function is same as expected
+ if (func) {
+ if (funcT != func.getFunctionType()) {
+ if (isReserved) {
+ func.emitError("redefinition of reserved function '" + name + "' of different type ")
+ .append(func.getFunctionType())
+ .append(" is prohibited");
+ exit(0);
+ } else {
+ func.emitError("redefinition of function '" + name + "' of different type ")
+ .append(funcT)
+ .append(" is prohibited");
+ exit(0);
+ }
+ }
return func;
+ }
OpBuilder b(moduleOp->getRegion(0));
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintI64,
IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintU64,
IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF64,
Float64Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -110,51 +126,51 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp
@@ -162,13 +178,13 @@ mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
{indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
return LLVM::lookupOrCreateFn(
moduleOp, kMemRefCopy,
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
>From bdb69fc838254c35ea55542c39dbf9392cd6d4b2 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 01:43:28 +0800
Subject: [PATCH 2/9] [mlir] Add test
---
mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 +++++++++++
1 file changed, 11 insertions(+)
create mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
new file mode 100644
index 00000000000000..f744e4f7635ea7
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 + 1)>
+module {
+ // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+ llvm.func @malloc(i64)
+ func.func @issue_120950() {
+ %alloc = memref.alloc() : memref<1024x64xf32, 1>
+ llvm.return
+ }
+}
>From d26a77d431b4b18ed5b320185a55f0984c6f3aea Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sat, 18 Jan 2025 02:12:39 +0800
Subject: [PATCH 3/9] [mlir] Reformat code
---
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 77 +++++++++++--------
1 file changed, 45 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index ecc31df40ea52f..757a1acf3626f6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,7 +48,8 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
StringRef name,
ArrayRef<Type> paramTypes,
- Type resultType, bool isVarArg, bool isReserved) {
+ Type resultType, bool isVarArg,
+ bool isReserved) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
@@ -58,14 +59,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
if (func) {
if (funcT != func.getFunctionType()) {
if (isReserved) {
- func.emitError("redefinition of reserved function '" + name + "' of different type ")
- .append(func.getFunctionType())
- .append(" is prohibited");
+ func.emitError("redefinition of reserved function '" + name +
+ "' of different type ")
+ .append(func.getFunctionType())
+ .append(" is prohibited");
exit(0);
} else {
- func.emitError("redefinition of function '" + name + "' of different type ")
- .append(funcT)
- .append(" is prohibited");
+ func.emitError("redefinition of function '" + name +
+ "' of different type ")
+ .append(funcT)
+ .append(" is prohibited");
exit(0);
}
}
@@ -78,39 +81,41 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintI64,
- IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ return lookupOrCreateFn(
+ moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintU64,
- IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ return lookupOrCreateFn(
+ moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintF32,
- Float32Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ return lookupOrCreateFn(
+ moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintF64,
- Float64Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ return lookupOrCreateFn(
+ moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -126,39 +131,46 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()),
+ false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()), false, true);
+ getVoidPtr(moduleOp->getContext()), false,
+ true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
- getVoidPtr(moduleOp->getContext()), false, true);
+ getVoidPtr(moduleOp->getContext()), false,
+ true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
@@ -170,15 +182,16 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()), false, true);
+ getVoidPtr(moduleOp->getContext()), false,
+ true);
}
LLVM::LLVMFuncOp
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
- return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
- {indexType, indexType},
- getVoidPtr(moduleOp->getContext()), false, true);
+ return LLVM::lookupOrCreateFn(
+ moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
>From 5ec38d6c09ff700dab72d27e3610c7866a23c7fc Mon Sep 17 00:00:00 2001
From: Luohao Wang <Luohaothu at users.noreply.github.com>
Date: Tue, 21 Jan 2025 17:39:13 +0800
Subject: [PATCH 4/9] [mlir] Wrapped return value of function lookup in
`FailureOr` for error handling
---
.../Conversion/LLVMCommon/PrintCallHelper.h | 2 +-
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 43 ++---
.../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 8 +-
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 6 +-
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 16 +-
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 14 +-
.../MemRefToLLVM/AllocLikeConversion.cpp | 14 +-
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 15 +-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 48 +++--
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 174 ++++++++++--------
10 files changed, 196 insertions(+), 144 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index c2742b6fc1d737..5af86956c0ad92 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -23,7 +23,7 @@ namespace LLVM {
/// Generate IR that prints the given string to stdout.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
-void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
StringRef symbolName, StringRef string,
const LLVMTypeConverter &typeConverter,
bool addNewline = true,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 3095c83b90db9e..473a69019d2399 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -16,7 +16,6 @@
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
-#include <optional>
namespace mlir {
class Location;
@@ -29,40 +28,42 @@ class ValueRange;
namespace LLVM {
class LLVMFuncOp;
-/// Helper functions to lookup or create the declaration for commonly used
+/// Helper functions to look up or create the declaration for commonly used
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
+/// Failure if an unexpected version of function is found.
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintStringFn(Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
+/// Return a failure if the FuncOp found has unexpected signature.
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {}, bool isVarArg = false,
bool isReserved = false);
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 9b5aeb3fef30b4..47d4474a5c28d7 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
+ if (failed(allocFuncOp))
+ return failure();
auto coroAlloc = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp, ValueRange{coroAlign, coroSize});
+ loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
@@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
// Free the memory.
auto freeFuncOp =
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
+ if (failed(freeFuncOp))
+ return failure();
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
ValueRange(coroMem.getResult()));
return success();
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index d0ffb94f3f96a9..cdcb613e04ab12 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -61,9 +61,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+ if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
*getTypeConverter(), /*addNewLine=*/false,
- /*runtimeFunctionName=*/"puts");
+ /*runtimeFunctionName=*/"puts").failed()) {
+ return failure();
+ }
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index a47a2872ceb073..10f72cda7706db 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Find the malloc and free, or declare them if necessary.
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
- LLVM::LLVMFuncOp freeFunc, mallocFunc;
- if (toDynamic)
+ FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
+ if (toDynamic) {
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
- if (!toDynamic)
+ if (failed(mallocFunc))
+ return failure();
+ }
+ if (!toDynamic) {
freeFunc = LLVM::lookupOrCreateFreeFn(module);
+ if (failed(freeFunc))
+ return failure();
+ }
unsigned unrankedMemrefPos = 0;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
@@ -293,7 +299,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Allocate memory, copy, and free the source if necessary.
Value memory =
toDynamic
- ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
+ ? builder.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
.getResult()
: builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
IntegerType::get(getContext(), 8),
@@ -302,7 +308,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
Value source = desc.memRefDescPtr(builder, loc);
builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
if (!toDynamic)
- builder.create<LLVM::CallOp>(loc, freeFunc, source);
+ builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
// Create a new descriptor. The same descriptor can be returned multiple
// times, attempting to modify its pointer can lead to memory leaks
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index bd7b401efec17a..607e1d65045523 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
return uniqueName;
}
-void mlir::LLVM::createPrintStrCall(
+LogicalResult mlir::LLVM::createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
std::optional<StringRef> runtimeFunctionName) {
@@ -59,8 +59,12 @@ void mlir::LLVM::createPrintStrCall(
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep =
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
- Operation *printer =
- LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
- builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
- gep);
+ if (auto printer =
+ LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); succeeded(printer)) {
+ builder.create<LLVM::CallOp>(loc, TypeRange(),
+ SymbolRefAttr::get(printer.value()), gep);
+ } else {
+ return failure();
+ }
+ return success();
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index a6408391b1330c..0ee92722157f35 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -15,7 +15,7 @@
using namespace mlir;
namespace {
-LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
+FailureOr<LLVM::LLVMFuncOp> getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
@@ -24,7 +24,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
return LLVM::lookupOrCreateMallocFn(module, indexType);
}
-LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
+FailureOr<LLVM::LLVMFuncOp> getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
@@ -80,10 +80,11 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
<< " to integer address space "
"failed. Consider adding memory space conversions.";
}
- LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
+ FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
- auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
+ if (failed(allocFuncOp)) return std::make_tuple(Value(), Value());
+ auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
Value allocatedPtr =
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -146,11 +147,12 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
Type elementPtrType = this->getElementPtrType(memRefType);
- LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
+ FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
+ if (failed(allocFuncOp)) return Value();
auto results = rewriter.create<LLVM::CallOp>(
- loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
+ loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
elementPtrType, *getTypeConverter());
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index f7542b8b3bc5c7..ac27e0dd09bdcd 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -42,8 +42,8 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) {
return !ShapedType::isDynamic(strideOrOffset);
}
-LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
- ModuleOp module) {
+FailureOr<LLVM::LLVMFuncOp> getFreeFn(const LLVMTypeConverter *typeConverter,
+ ModuleOp module) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
@@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
- LLVM::LLVMFuncOp freeFunc =
+ auto freeFunc =
getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
+ if (failed(freeFunc))
+ return failure();
Value allocatedPtr;
if (auto unrankedTy =
llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
@@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
allocatedPtr = MemRefDescriptor(adaptor.getMemref())
.allocatedPtr(rewriter, op.getLoc());
}
- rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
+ allocatedPtr);
return success();
}
};
@@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
- rewriter.create<LLVM::CallOp>(loc, copyFn,
+ if (failed(copyFn))
+ return failure();
+ rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
ValueRange{elemSize, sourcePtr, targetPtr});
// Restore stack used for descriptors
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a1e21cb524bd9a..79617506008fab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1546,24 +1546,32 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
auto punct = printOp.getPunctuation();
if (auto stringLiteral = printOp.getStringLiteral()) {
- LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
- *stringLiteral, *getTypeConverter(),
- /*addNewline=*/false);
+ if (LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+ *stringLiteral, *getTypeConverter(),
+ /*addNewline=*/false)
+ .failed()) {
+ return failure();
+ }
} else if (punct != PrintPunctuation::NoPunctuation) {
- emitCall(rewriter, printOp->getLoc(), [&] {
- switch (punct) {
- case PrintPunctuation::Close:
- return LLVM::lookupOrCreatePrintCloseFn(parent);
- case PrintPunctuation::Open:
- return LLVM::lookupOrCreatePrintOpenFn(parent);
- case PrintPunctuation::Comma:
- return LLVM::lookupOrCreatePrintCommaFn(parent);
- case PrintPunctuation::NewLine:
- return LLVM::lookupOrCreatePrintNewlineFn(parent);
- default:
- llvm_unreachable("unexpected punctuation");
- }
- }());
+ if (auto op = [&] -> FailureOr<LLVM::LLVMFuncOp> {
+ switch (punct) {
+ case PrintPunctuation::Close:
+ return LLVM::lookupOrCreatePrintCloseFn(parent);
+ case PrintPunctuation::Open:
+ return LLVM::lookupOrCreatePrintOpenFn(parent);
+ case PrintPunctuation::Comma:
+ return LLVM::lookupOrCreatePrintCommaFn(parent);
+ case PrintPunctuation::NewLine:
+ return LLVM::lookupOrCreatePrintNewlineFn(parent);
+ default:
+ llvm_unreachable("unexpected punctuation");
+ }
+ }();
+ succeeded(op))
+ emitCall(rewriter, printOp->getLoc(), op.value());
+ else {
+ return failure();
+ }
}
rewriter.eraseOp(printOp);
@@ -1588,7 +1596,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
// Make sure element type has runtime support.
PrintConversion conversion = PrintConversion::None;
- Operation *printer;
+ FailureOr<Operation *> printer;
if (printType.isF32()) {
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
} else if (printType.isF64()) {
@@ -1631,6 +1639,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
} else {
return failure();
}
+ if (failed(printer))
+ return failure();
switch (conversion) {
case PrintConversion::ZeroExt64:
@@ -1648,7 +1658,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
case PrintConversion::None:
break;
}
- emitCall(rewriter, loc, printer, value);
+ emitCall(rewriter, loc, printer.value(), value);
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 757a1acf3626f6..c2c87bc7544bd7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,11 +45,10 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
/// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
- StringRef name,
- ArrayRef<Type> paramTypes,
- Type resultType, bool isVarArg,
- bool isReserved) {
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes, Type resultType,
+ bool isVarArg, bool isReserved) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
@@ -63,14 +62,13 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
"' of different type ")
.append(func.getFunctionType())
.append(" is prohibited");
- exit(0);
} else {
func.emitError("redefinition of function '" + name +
"' of different type ")
.append(funcT)
.append(" is prohibited");
- exit(0);
}
+ return failure();
}
return func;
}
@@ -80,42 +78,58 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(
+namespace {
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateReservedFn(Operation *moduleOp,
+ StringRef name,
+ ArrayRef<Type> paramTypes,
+ Type resultType) {
+ return lookupOrCreateFn(moduleOp, name, paramTypes, resultType,
+ /*isVarArg=*/false, /*isReserved=*/true);
+}
+} // namespace
+
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintF16,
- IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()),
- false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
+ moduleOp, kPrintF16,
+ IntegerType::get(moduleOp->getContext(), 16), // bits!
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintBF16,
- IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()),
- false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
+ moduleOp, kPrintBF16,
+ IntegerType::get(moduleOp->getContext(), 16), // bits!
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
- return lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
- return lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -127,84 +141,88 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
return getCharPtr(context);
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
+FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
- return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
- getCharPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()),
- false, true);
+ return lookupOrCreateReservedFn(
+ moduleOp, runtimeFunctionName.value_or(kPrintString),
+ getCharPtr(moduleOp->getContext()),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintOpen, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()),
- false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
+ moduleOp, kPrintOpen, {},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintClose, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()),
- false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
+ moduleOp, kPrintClose, {},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintComma, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()),
- false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
+ moduleOp, kPrintComma, {},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
- return lookupOrCreateFn(moduleOp, kPrintNewline, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()),
- false, true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
+ moduleOp, kPrintNewline, {},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
- Type indexType) {
- return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()), false,
- true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) {
+ return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType,
+ getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
- Type indexType) {
- return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
- getVoidPtr(moduleOp->getContext()), false,
- true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) {
+ return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc,
+ {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
- return LLVM::lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
- Type indexType) {
- return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()), false,
- true);
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) {
+ return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType,
+ getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
- return LLVM::lookupOrCreateFn(
- moduleOp, kGenericAlignedAlloc, {indexType, indexType},
- getVoidPtr(moduleOp->getContext()), false, true);
+ return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc,
+ {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
- return LLVM::lookupOrCreateFn(
+FailureOr<LLVM::LLVMFuncOp>
+mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
+ return lookupOrCreateReservedFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp
+FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType) {
- return LLVM::lookupOrCreateFn(
+ return lookupOrCreateReservedFn(
moduleOp, kMemRefCopy,
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
- LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
>From 9079caff8b7bfa80bef19b2d33bdd7e361a8eb66 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Tue, 21 Jan 2025 18:16:52 +0800
Subject: [PATCH 5/9] [mlir] [Test] Moved & renamed test case
---
mlir/test/Conversion/MemRefToLLVM/invalid.mlir | 7 +++++++
mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir | 11 -----------
2 files changed, 7 insertions(+), 11 deletions(-)
delete mode 100644 mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
diff --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
index 40dd75af1dd770..1e12b83a24b5a7 100644
--- a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir
@@ -2,6 +2,13 @@
// Since the error is at an unknown location, we use FileCheck instead of
// -veri-y-diagnostics here
+// CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+llvm.func @malloc(i64)
+func.func @redef_reserved() {
+ %alloc = memref.alloc() : memref<1024x64xf32, 1>
+ llvm.return
+}
+
// CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions.
// CHECK-LABEL: @bad_address_space
func.func @bad_address_space(%a: memref<2xindex, "foo">) {
diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
deleted file mode 100644
index f744e4f7635ea7..00000000000000
--- a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
+++ /dev/null
@@ -1,11 +0,0 @@
-// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
-
-#map = affine_map<(d0) -> (d0 + 1)>
-module {
- // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
- llvm.func @malloc(i64)
- func.func @issue_120950() {
- %alloc = memref.alloc() : memref<1024x64xf32, 1>
- llvm.return
- }
-}
>From 475409c4fac865f26c94b751a9b0dbfcc937b83a Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Tue, 21 Jan 2025 18:29:22 +0800
Subject: [PATCH 6/9] Reformat code
---
.../Conversion/LLVMCommon/PrintCallHelper.h | 9 ++++---
.../mlir/Dialect/LLVMIR/FunctionCallUtils.h | 24 ++++++++++---------
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 8 ++++---
mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 3 ++-
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 3 ++-
.../MemRefToLLVM/AllocLikeConversion.cpp | 19 +++++++++------
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 12 +++++-----
7 files changed, 44 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
index 5af86956c0ad92..33402301115b73 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -23,11 +23,10 @@ namespace LLVM {
/// Generate IR that prints the given string to stdout.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
-LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
- StringRef symbolName, StringRef string,
- const LLVMTypeConverter &typeConverter,
- bool addNewline = true,
- std::optional<StringRef> runtimeFunctionName = {});
+LogicalResult createPrintStrCall(
+ OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
+ StringRef string, const LLVMTypeConverter &typeConverter,
+ bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 473a69019d2399..05e9fe9d58859c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -49,24 +49,26 @@ FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp,
+ Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
- Type indexType);
+ Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
- Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
- Type indexType);
+ Type indexType);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
- Type unrankedDescriptorType);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
+ Type unrankedDescriptorType);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
/// Return a failure if the FuncOp found has unexpected signature.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFn(Operation *moduleOp, StringRef name,
- ArrayRef<Type> paramTypes = {},
- Type resultType = {}, bool isVarArg = false,
- bool isReserved = false);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateFn(Operation *moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes = {}, Type resultType = {},
+ bool isVarArg = false, bool isReserved = false);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index cdcb613e04ab12..f2fc235fecb289 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -61,9 +61,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
- *getTypeConverter(), /*addNewLine=*/false,
- /*runtimeFunctionName=*/"puts").failed()) {
+ if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg",
+ op.getMsg(), *getTypeConverter(),
+ /*addNewLine=*/false,
+ /*runtimeFunctionName=*/"puts")
+ .failed()) {
return failure();
}
if (abortOnFailedAssert) {
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 10f72cda7706db..840bd3df61a063 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -299,7 +299,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Allocate memory, copy, and free the source if necessary.
Value memory =
toDynamic
- ? builder.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
+ ? builder
+ .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
.getResult()
: builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
IntegerType::get(getContext(), 8),
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 607e1d65045523..381e2ffea8eb29 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -60,7 +60,8 @@ LogicalResult mlir::LLVM::createPrintStrCall(
Value gep =
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
if (auto printer =
- LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); succeeded(printer)) {
+ LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
+ succeeded(printer)) {
builder.create<LLVM::CallOp>(loc, TypeRange(),
SymbolRefAttr::get(printer.value()), gep);
} else {
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 0ee92722157f35..1712d0b5844b88 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -15,8 +15,9 @@
using namespace mlir;
namespace {
-FailureOr<LLVM::LLVMFuncOp> getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
- Operation *module, Type indexType) {
+FailureOr<LLVM::LLVMFuncOp>
+getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+ Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
@@ -24,8 +25,9 @@ FailureOr<LLVM::LLVMFuncOp> getNotalignedAllocFn(const LLVMTypeConverter *typeCo
return LLVM::lookupOrCreateMallocFn(module, indexType);
}
-FailureOr<LLVM::LLVMFuncOp> getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
- Operation *module, Type indexType) {
+FailureOr<LLVM::LLVMFuncOp>
+getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
+ Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
@@ -83,8 +85,10 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
- if (failed(allocFuncOp)) return std::make_tuple(Value(), Value());
- auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
+ if (failed(allocFuncOp))
+ return std::make_tuple(Value(), Value());
+ auto results =
+ rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
Value allocatedPtr =
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -150,7 +154,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
- if (failed(allocFuncOp)) return Value();
+ if (failed(allocFuncOp))
+ return Value();
auto results = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index c2c87bc7544bd7..9df5c4554c2360 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -180,14 +180,14 @@ mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) {
return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) {
return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc,
- {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
@@ -200,15 +200,15 @@ mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) {
return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc,
- {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ {indexType, indexType},
+ getVoidPtr(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
>From 874a1cd68dd800c0b8a5f5263a24a7a36d41eae5 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sun, 26 Jan 2025 10:47:13 +0800
Subject: [PATCH 7/9] Make style fixes
---
.../Conversion/LLVMCommon/PrintCallHelper.cpp | 10 +++---
.../MemRefToLLVM/AllocLikeConversion.cpp | 7 ++--
.../Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 8 ++---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 34 +++++++++----------
.../Dialect/LLVMIR/IR/FunctionCallUtils.cpp | 22 +++++-------
5 files changed, 34 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 381e2ffea8eb29..deabb748b56522 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -59,13 +59,11 @@ LogicalResult mlir::LLVM::createPrintStrCall(
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep =
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
- if (auto printer =
+ FailureOr<LLVM::LLVMFuncOp> printer =
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
- succeeded(printer)) {
- builder.create<LLVM::CallOp>(loc, TypeRange(),
- SymbolRefAttr::get(printer.value()), gep);
- } else {
+ if(failed(printer))
return failure();
- }
+ builder.create<LLVM::CallOp>(loc, TypeRange(),
+ SymbolRefAttr::get(printer.value()), gep);
return success();
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index 1712d0b5844b88..c5b2e83df93dcb 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -14,8 +14,7 @@
using namespace mlir;
-namespace {
-FailureOr<LLVM::LLVMFuncOp>
+static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
@@ -25,7 +24,7 @@ getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
return LLVM::lookupOrCreateMallocFn(module, indexType);
}
-FailureOr<LLVM::LLVMFuncOp>
+static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
@@ -36,8 +35,6 @@ getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
}
-} // end namespace
-
Value AllocationOpLLVMLowering::createAligned(
ConversionPatternRewriter &rewriter, Location loc, Value input,
Value alignment) {
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index ac27e0dd09bdcd..af1dba4587dc1f 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -38,12 +38,12 @@ using namespace mlir;
namespace {
-bool isStaticStrideOrOffset(int64_t strideOrOffset) {
+static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
return !ShapedType::isDynamic(strideOrOffset);
}
-FailureOr<LLVM::LLVMFuncOp> getFreeFn(const LLVMTypeConverter *typeConverter,
- ModuleOp module) {
+static FailureOr<LLVM::LLVMFuncOp>
+getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
@@ -220,7 +220,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
- auto freeFunc =
+ FailureOr<LLVM::LLVMFuncOp> freeFunc =
getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
if (failed(freeFunc))
return failure();
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 79617506008fab..258374f71c7d5e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1553,25 +1553,23 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}
} else if (punct != PrintPunctuation::NoPunctuation) {
- if (auto op = [&] -> FailureOr<LLVM::LLVMFuncOp> {
- switch (punct) {
- case PrintPunctuation::Close:
- return LLVM::lookupOrCreatePrintCloseFn(parent);
- case PrintPunctuation::Open:
- return LLVM::lookupOrCreatePrintOpenFn(parent);
- case PrintPunctuation::Comma:
- return LLVM::lookupOrCreatePrintCommaFn(parent);
- case PrintPunctuation::NewLine:
- return LLVM::lookupOrCreatePrintNewlineFn(parent);
- default:
- llvm_unreachable("unexpected punctuation");
- }
- }();
- succeeded(op))
- emitCall(rewriter, printOp->getLoc(), op.value());
- else {
+ FailureOr<LLVM::LLVMFuncOp> op = [&]() {
+ switch (punct) {
+ case PrintPunctuation::Close:
+ return LLVM::lookupOrCreatePrintCloseFn(parent);
+ case PrintPunctuation::Open:
+ return LLVM::lookupOrCreatePrintOpenFn(parent);
+ case PrintPunctuation::Comma:
+ return LLVM::lookupOrCreatePrintCommaFn(parent);
+ case PrintPunctuation::NewLine:
+ return LLVM::lookupOrCreatePrintNewlineFn(parent);
+ default:
+ llvm_unreachable("unexpected punctuation");
+ }
+ }();
+ if (failed(op))
return failure();
- }
+ emitCall(rewriter, printOp->getLoc(), op.value());
}
rewriter.eraseOp(printOp);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 9df5c4554c2360..68d4426e653019 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -58,15 +58,12 @@ mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
if (func) {
if (funcT != func.getFunctionType()) {
if (isReserved) {
- func.emitError("redefinition of reserved function '" + name +
- "' of different type ")
- .append(func.getFunctionType())
- .append(" is prohibited");
+ func.emitError("redefinition of reserved function '")
+ << name << "' of different type " << func.getFunctionType()
+ << " is prohibited";
} else {
- func.emitError("redefinition of function '" + name +
- "' of different type ")
- .append(funcT)
- .append(" is prohibited");
+ func.emitError("redefinition of function '")
+ << name << "' of different type " << funcT << " is prohibited";
}
return failure();
}
@@ -78,15 +75,12 @@ mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
}
-namespace {
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateReservedFn(Operation *moduleOp,
- StringRef name,
- ArrayRef<Type> paramTypes,
- Type resultType) {
+static FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateReservedFn(Operation *moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes, Type resultType) {
return lookupOrCreateFn(moduleOp, name, paramTypes, resultType,
/*isVarArg=*/false, /*isReserved=*/true);
}
-} // namespace
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
>From 2d7dc5d888146a6fddb9c717bacd975dcf008553 Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sun, 26 Jan 2025 15:22:35 +0800
Subject: [PATCH 8/9] More style fixes
---
.../ControlFlowToLLVM/ControlFlowToLLVM.cpp | 12 ++++++------
.../Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 9 +++++----
2 files changed, 11 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index f2fc235fecb289..debfd003bd5b5e 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -61,13 +61,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg",
- op.getMsg(), *getTypeConverter(),
- /*addNewLine=*/false,
- /*runtimeFunctionName=*/"puts")
- .failed()) {
+ auto createResult = LLVM::createPrintStrCall(
+ rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
+ /*addNewLine=*/false,
+ /*runtimeFunctionName=*/"puts");
+ if (createResult.failed())
return failure();
- }
+
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 258374f71c7d5e..baed98c13adc7c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1546,12 +1546,13 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
auto punct = printOp.getPunctuation();
if (auto stringLiteral = printOp.getStringLiteral()) {
- if (LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+ auto createResult =
+ LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
*stringLiteral, *getTypeConverter(),
- /*addNewline=*/false)
- .failed()) {
+ /*addNewline=*/false);
+ if (createResult.failed())
return failure();
- }
+
} else if (punct != PrintPunctuation::NoPunctuation) {
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
switch (punct) {
>From a7f3308d74e5f4b4ce3bd7e5a4f35cb92f4f633d Mon Sep 17 00:00:00 2001
From: Luohao Wang <luohaothu at live.com>
Date: Sun, 26 Jan 2025 15:25:44 +0800
Subject: [PATCH 9/9] Format code
---
mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index deabb748b56522..337c01f01a7cc7 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -60,10 +60,10 @@ LogicalResult mlir::LLVM::createPrintStrCall(
Value gep =
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
FailureOr<LLVM::LLVMFuncOp> printer =
- LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
- if(failed(printer))
+ LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
+ if (failed(printer))
return failure();
builder.create<LLVM::CallOp>(loc, TypeRange(),
- SymbolRefAttr::get(printer.value()), gep);
+ SymbolRefAttr::get(printer.value()), gep);
return success();
}
More information about the Mlir-commits
mailing list