[Mlir-commits] [mlir] [mlir][LLVM] Add `OpBuilder &` to `lookupOrCreateFn` functions (PR #136421)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Apr 19 03:01:27 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

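These functions are called from lowering patterns. All IR modifications in a pattern must be performed through the provided rewriter. However, these functions used to instantiate their own `OpBuilder`, which bypassed the rewriter. This change adds an `OpBuilder &` parameter to each `lookupOrCreate*Fn` function, so callers can pass in their rewriter (or builder) and the function declarations are created through it.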


---

The patch is 25.06 KiB; it is truncated to 20.00 KiB below. Full version: https://github.com/llvm/llvm-project/pull/136421.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h (+35-22) 
- (modified) mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp (+2-2) 
- (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+2-2) 
- (modified) mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp (+1-1) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp (+12-12) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+8-6) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+11-11) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (+49-42) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 05e9fe9d58859..4a7ec6f2efe64 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -33,40 +33,53 @@ class LLVMFuncOp;
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
 /// Failure if an unexpected version of function is found.
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(OpBuilder &b,
+                                                     Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(OpBuilder &b,
+                                                     Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(OpBuilder &b,
+                                                     Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(OpBuilder &b,
+                                                      Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(OpBuilder &b,
+                                                     Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(OpBuilder &b,
+                                                     Operation *moduleOp);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreatePrintStringFn(Operation *moduleOp,
+lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp,
                             std::optional<StringRef> runtimeFunctionName = {});
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp,
-                                                   Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
-                                                         Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
-                                                         Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(OpBuilder &b,
+                                                      Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(OpBuilder &b,
+                                                       Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(OpBuilder &b,
+                                                       Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(OpBuilder &b,
+                                                         Operation *moduleOp);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType);
-FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
+lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
+lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(OpBuilder &b,
+                                                 Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
+                                    Type indexType);
+FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(OpBuilder &b,
+                                                        Operation *moduleOp);
+FailureOr<LLVM::LLVMFuncOp>
+lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType,
                            Type unrankedDescriptorType);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 /// Return a failure if the FuncOp found has unexpected signature.
 FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateFn(Operation *moduleOp, StringRef name,
+lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                  ArrayRef<Type> paramTypes = {}, Type resultType = {},
                  bool isVarArg = false, bool isReserved = false);
 
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 47d4474a5c28d..c95e375ce9afe 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -395,7 +395,7 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
 
     // Allocate memory for the coroutine frame.
     auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
-        op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
+        rewriter, op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
     if (failed(allocFuncOp))
       return failure();
     auto coroAlloc = rewriter.create<LLVM::CallOp>(
@@ -432,7 +432,7 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
 
     // Free the memory.
     auto freeFuncOp =
-        LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
+        LLVM::lookupOrCreateFreeFn(rewriter, op->getParentOfType<ModuleOp>());
     if (failed(freeFuncOp))
       return failure();
     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 32bfd72475569..706ae1e72e2b3 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -278,12 +278,12 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
   FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
   if (toDynamic) {
-    mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
+    mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
     if (failed(mallocFunc))
       return failure();
   }
   if (!toDynamic) {
-    freeFunc = LLVM::lookupOrCreateFreeFn(module);
+    freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
     if (failed(freeFunc))
       return failure();
   }
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
index 337c01f01a7cc..2815e05b3e11b 100644
--- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -60,7 +60,7 @@ LogicalResult mlir::LLVM::createPrintStrCall(
   Value gep =
       builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
   FailureOr<LLVM::LLVMFuncOp> printer =
-      LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
+      LLVM::lookupOrCreatePrintStringFn(builder, moduleOp, runtimeFunctionName);
   if (failed(printer))
     return failure();
   builder.create<LLVM::CallOp>(loc, TypeRange(),
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index bad209a4ddecf..e9b79983696aa 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -15,24 +15,24 @@
 using namespace mlir;
 
 static FailureOr<LLVM::LLVMFuncOp>
-getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
-                     Type indexType) {
+getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+                     Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
   if (useGenericFn)
-    return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
+    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);
 
-  return LLVM::lookupOrCreateMallocFn(module, indexType);
+  return LLVM::lookupOrCreateMallocFn(b, module, indexType);
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
-getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
-                  Type indexType) {
+getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+                  Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
-    return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType);
+    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);
 
-  return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
+  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
 }
 
 Value AllocationOpLLVMLowering::createAligned(
@@ -75,8 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
   Type elementPtrType = this->getElementPtrType(memRefType);
   assert(elementPtrType && "could not compute element ptr type");
   FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
-      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
-      getIndexType());
+      rewriter, getTypeConverter(),
+      op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
   if (failed(allocFuncOp))
     return std::make_tuple(Value(), Value());
   auto results =
@@ -144,8 +144,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
 
   Type elementPtrType = this->getElementPtrType(memRefType);
   FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
-      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
-      getIndexType());
+      rewriter, getTypeConverter(),
+      op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
   if (failed(allocFuncOp))
     return Value();
   auto results = rewriter.create<LLVM::CallOp>(
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index cb4317ef1bcec..9c219d8a3d8cb 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -43,13 +43,14 @@ static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
-getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
+getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
+          ModuleOp module) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
-    return LLVM::lookupOrCreateGenericFreeFn(module);
+    return LLVM::lookupOrCreateGenericFreeFn(b, module);
 
-  return LLVM::lookupOrCreateFreeFn(module);
+  return LLVM::lookupOrCreateFreeFn(b, module);
 }
 
 struct AllocOpLowering : public AllocLikeOpLLVMLowering {
@@ -223,8 +224,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
-    FailureOr<LLVM::LLVMFuncOp> freeFunc =
-        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
+    FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
+        rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
     if (failed(freeFunc))
       return failure();
     Value allocatedPtr;
@@ -834,7 +835,8 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     // potential alignment
     auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
-        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
+        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
+        sourcePtr.getType());
     if (failed(copyFn))
       return failure();
     rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1a35d08196459..076e5512f375b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1570,13 +1570,13 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
       FailureOr<LLVM::LLVMFuncOp> op = [&]() {
         switch (punct) {
         case PrintPunctuation::Close:
-          return LLVM::lookupOrCreatePrintCloseFn(parent);
+          return LLVM::lookupOrCreatePrintCloseFn(rewriter, parent);
         case PrintPunctuation::Open:
-          return LLVM::lookupOrCreatePrintOpenFn(parent);
+          return LLVM::lookupOrCreatePrintOpenFn(rewriter, parent);
         case PrintPunctuation::Comma:
-          return LLVM::lookupOrCreatePrintCommaFn(parent);
+          return LLVM::lookupOrCreatePrintCommaFn(rewriter, parent);
         case PrintPunctuation::NewLine:
-          return LLVM::lookupOrCreatePrintNewlineFn(parent);
+          return LLVM::lookupOrCreatePrintNewlineFn(rewriter, parent);
         default:
           llvm_unreachable("unexpected punctuation");
         }
@@ -1610,17 +1610,17 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     PrintConversion conversion = PrintConversion::None;
     FailureOr<Operation *> printer;
     if (printType.isF32()) {
-      printer = LLVM::lookupOrCreatePrintF32Fn(parent);
+      printer = LLVM::lookupOrCreatePrintF32Fn(rewriter, parent);
     } else if (printType.isF64()) {
-      printer = LLVM::lookupOrCreatePrintF64Fn(parent);
+      printer = LLVM::lookupOrCreatePrintF64Fn(rewriter, parent);
     } else if (printType.isF16()) {
       conversion = PrintConversion::Bitcast16; // bits!
-      printer = LLVM::lookupOrCreatePrintF16Fn(parent);
+      printer = LLVM::lookupOrCreatePrintF16Fn(rewriter, parent);
     } else if (printType.isBF16()) {
       conversion = PrintConversion::Bitcast16; // bits!
-      printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
+      printer = LLVM::lookupOrCreatePrintBF16Fn(rewriter, parent);
     } else if (printType.isIndex()) {
-      printer = LLVM::lookupOrCreatePrintU64Fn(parent);
+      printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
     } else if (auto intTy = dyn_cast<IntegerType>(printType)) {
       // Integers need a zero or sign extension on the operand
       // (depending on the source type) as well as a signed or
@@ -1630,7 +1630,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
         if (width <= 64) {
           if (width < 64)
             conversion = PrintConversion::ZeroExt64;
-          printer = LLVM::lookupOrCreatePrintU64Fn(parent);
+          printer = LLVM::lookupOrCreatePrintU64Fn(rewriter, parent);
         } else {
           return failure();
         }
@@ -1643,7 +1643,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
             conversion = PrintConversion::ZeroExt64;
           else if (width < 64)
             conversion = PrintConversion::SignExt64;
-          printer = LLVM::lookupOrCreatePrintI64Fn(parent);
+          printer = LLVM::lookupOrCreatePrintI64Fn(rewriter, parent);
         } else {
           return failure();
         }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 68d4426e65301..1b4a8f496d3d0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -46,7 +46,7 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
 /// Generic print function lookupOrCreate helper.
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
+mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                              ArrayRef<Type> paramTypes, Type resultType,
                              bool isVarArg, bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
@@ -69,60 +69,63 @@ mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
     }
     return func;
   }
-  OpBuilder b(moduleOp->getRegion(0));
+
+  OpBuilder::InsertionGuard g(b);
+  assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
+  b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
       LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
 }
 
 static FailureOr<LLVM::LLVMFuncOp>
-lookupOrCreateReservedFn(Operation *moduleOp, StringRef name,
+lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                          ArrayRef<Type> paramTypes, Type resultType) {
-  return lookupOrCreateFn(moduleOp, name, paramTypes, resultType,
+  return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
                           /*isVarArg=*/false, /*isReserved=*/true);
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp) {
   return lookupOrCreateReservedFn(
-      moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
+      b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp) {
   return lookupOrCreateReservedFn(
-      moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
+      b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp) {
   return lookupOrCreateReservedFn(
-      moduleOp, kPrintF16,
+      b, moduleOp, kPrintF16,
       IntegerType::get(moduleOp->getContext(), 16), // bits!
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp) {
   return lookupOrCreateReservedFn(
-      moduleOp, kPrintBF16,
+      b, moduleOp, kPrintBF16,
       IntegerType::get(moduleOp->getContext(), 16), // bits!
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp) {
   return lookupOrCreateReservedFn(
-      moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
+      b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 FailureOr<LLVM::LLVMFuncOp>
-mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
+mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp) {
   return lookupOrCreateReservedFn(
-      moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
+      b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
@@ -136,87 +139,91 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
 }
 
 F...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/136421


More information about the Mlir-commits mailing list