[Mlir-commits] [mlir] [mlir][LLVM] Add `OpBuilder &` to `lookupOrCreateFn` functions (PR #136421)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Apr 19 03:01:26 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
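These functions are called from lowering patterns. All IR modifications in a pattern must be performed through the provided rewriter, but these functions used to instantiate their own `OpBuilder`, bypassing the provided rewriter. This patch adds an `OpBuilder &` parameter to every `lookupOrCreate*Fn` function so that callers can pass in the pattern's rewriter, and updates all call sites accordingly.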
---
Patch is 25.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136421.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h (+35-22)
- (modified) mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (+2-2)
- (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+2-2)
- (modified) mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp (+1-1)
- (modified) mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp (+12-12)
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+8-6)
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+11-11)
- (modified) mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (+49-42)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 05e9fe9d58859..4a7ec6f2efe64 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -33,40 +33,53 @@ class LLVMFuncOp;
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
/// Failure if an unexpected version of function is found.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
+ Operation *moduleOp);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreatePrintStringFn(Operation *moduleOp,
+lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp,
- Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
- Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
- Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
+ Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
+lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
+lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
+ Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
+ Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
/// Return a failure if the FuncOp found has unexpected signature.
FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateFn(Operation *moduleOp, StringRef name,
+lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {}, Type resultType = {},
bool isVarArg = false, bool isReserved = false);
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 47d4474a5c28d..c95e375ce9afe 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -395,7 +395,7 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
- op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
+ rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
if (failed(allocFuncOp))
return failure();
auto coroAlloc = rewriter.create<LLVM::CallOp>(
@@ -432,7 +432,7 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
// Free the memory.
auto freeFuncOp =
- LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
+ LLVM::lookupOrCreateFreeFn(rewriter, op->getParentOfType<ModuleOp>());
if (failed(freeFuncOp))
return failure();
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 32bfd72475569..706ae1e72e2b3 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -278,12 +278,12 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
if (toDynamic) {
- mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
+ mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
if (failed(mallocFunc))
return failure();
}
if (!toDynamic) {
- freeFunc = LLVM::lookupOrCreateFreeFn(module);
+ freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
if (failed(freeFunc))
return failure();
}
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 337c01f01a7cc..2815e05b3e11b 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -60,7 +60,7 @@ LogicalResult mlir::LLVM::createPrintStrCall(
Value gep =
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
FailureOr<LLVM::LLVMFuncOp> printer =
- LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
+ LLVM::lookupOrCreatePrintStringFn(builder, moduleOp, runtimeFunctionName);
if (failed(printer))
return failure();
builder.create<LLVM::CallOp>(loc, TypeRange(),
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index bad209a4ddecf..e9b79983696aa 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -15,24 +15,24 @@
using namespace mlir;
static FailureOr<LLVM::LLVMFuncOp>
-getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
- Type indexType) {
+getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+ Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
+ return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);
- return LLVM::lookupOrCreateMallocFn(module, indexType);
+ return LLVM::lookupOrCreateMallocFn(b, module, indexType);
}
static FailureOr<LLVM::LLVMFuncOp>
-getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
- Type indexType) {
+getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+ Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
+ return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);
- return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
+ return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
}
Value AllocationOpLLVMLowering::createAligned(
@@ -75,8 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
Type elementPtrType = this->getElementPtrType(memRefType);
assert(elementPtrType && "could not compute element ptr type");
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
- getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
- getIndexType());
+ rewriter, getTypeConverter(),
+ op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
if (failed(allocFuncOp))
return std::make_tuple(Value(), Value());
auto results =
@@ -144,8 +144,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
Type elementPtrType = this->getElementPtrType(memRefType);
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
- getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
- getIndexType());
+ rewriter, getTypeConverter(),
+ op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
if (failed(allocFuncOp))
return Value();
auto results = rewriter.create<LLVM::CallOp>(
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cb4317ef1bcec..9c219d8a3d8cb 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -43,13 +43,14 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
}
static FailureOr<LLVM::LLVMFuncOp>
-getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
+getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+ ModuleOp module) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
- return LLVM::lookupOrCreateGenericFreeFn(module);
+ return LLVM::lookupOrCreateGenericFreeFn(b, module);
- return LLVM::lookupOrCreateFreeFn(module);
+ return LLVM::lookupOrCreateFreeFn(b, module);
}
struct AllocOpLowering : public AllocLikeOpLLVMLowering {
@@ -223,8 +224,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
- FailureOr<LLVM::LLVMFuncOp> freeFunc =
- getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
+ FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
+ rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
if (failed(freeFunc))
return failure();
Value allocatedPtr;
@@ -834,7 +835,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
// potential alignment
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
- op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
+ rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
+ sourcePtr.getType());
if (failed(copyFn))
return failure();
rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1a35d08196459..076e5512f375b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1570,13 +1570,13 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
switch (punct) {
case PrintPunctuation::Close:
- return LLVM::lookupOrCreatePrintCloseFn(parent);
+ return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
case PrintPunctuation::Open:
- return LLVM::lookupOrCreatePrintOpenFn(parent);
+ return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
case PrintPunctuation::Comma:
- return LLVM::lookupOrCreatePrintCommaFn(parent);
+ return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
case PrintPunctuation::NewLine:
- return LLVM::lookupOrCreatePrintNewlineFn(parent);
+ return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
default:
llvm_unreachable("unexpected punctuation");
}
@@ -1610,17 +1610,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
PrintConversion conversion = PrintConversion::None;
FailureOr<Operation *> printer;
if (printType.isF32()) {
- printer = LLVM::lookupOrCreatePrintF32Fn(parent);
+ printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
} else if (printType.isF64()) {
- printer = LLVM::lookupOrCreatePrintF64Fn(parent);
+ printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
} else if (printType.isF16()) {
conversion = PrintConversion::Bitcast16; // bits!
- printer = LLVM::lookupOrCreatePrintF16Fn(parent);
+ printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
} else if (printType.isBF16()) {
conversion = PrintConversion::Bitcast16; // bits!
- printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
+ printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
} else if (printType.isIndex()) {
- printer = LLVM::lookupOrCreatePrintU64Fn(parent);
+ printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1630,7 +1630,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
- printer = LLVM::lookupOrCreatePrintU64Fn(parent);
+ printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
} else {
return failure();
}
@@ -1643,7 +1643,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
- printer = LLVM::lookupOrCreatePrintI64Fn(parent);
+ printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
} else {
return failure();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 68d4426e65301..1b4a8f496d3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -46,7 +46,7 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
/// Generic print function lookupOrCreate helper.
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
+mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes, Type resultType,
bool isVarArg, bool isReserved) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
@@ -69,60 +69,63 @@ mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
}
return func;
}
- OpBuilder b(moduleOp->getRegion(0));
+
+ OpBuilder::InsertionGuard g(b);
+ assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
+ b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
}
static FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateReservedFn(Operation *moduleOp, StringRef name,
+lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes, Type resultType) {
- return lookupOrCreateFn(moduleOp, name, paramTypes, resultType,
+ return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
/*isVarArg=*/false, /*isReserved=*/true);
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) {
return lookupOrCreateReservedFn(
- moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
+ b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) {
return lookupOrCreateReservedFn(
- moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
+ b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) {
return lookupOrCreateReservedFn(
- moduleOp, kPrintF16,
+ b, moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) {
return lookupOrCreateReservedFn(
- moduleOp, kPrintBF16,
+ b, moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) {
return lookupOrCreateReservedFn(
- moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
+ b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) {
return lookupOrCreateReservedFn(
- moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
+ b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
@@ -136,87 +139,91 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
}
F...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/136421
More information about the Mlir-commits mailing list