[Mlir-commits] [mlir] e332c22 - [mlir][LLVM] NFC - Refactor a lookupOrCreateFn to reuse common function creation.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Feb 11 07:55:55 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-11T15:52:33Z
New Revision: e332c22cdf54d16974f445ba4345cd2a76d7fc6a

URL: https://github.com/llvm/llvm-project/commit/e332c22cdf54d16974f445ba4345cd2a76d7fc6a
DIFF: https://github.com/llvm/llvm-project/commit/e332c22cdf54d16974f445ba4345cd2a76d7fc6a.diff

LOG: [mlir][LLVM] NFC - Refactor a lookupOrCreateFn to reuse common function creation.

Differential revision: https://reviews.llvm.org/D96488

Added: 
    mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
    mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
new file mode 100644
index 000000000000..7efff9774cd5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -0,0 +1,63 @@
+//===- FunctionCallUtils.h - Utilities for C function calls -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares helper functions to call common simple C functions in
+// LLVMIR (e.g. among others to support printing and debugging).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
+#define MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+
+namespace LLVM {
+class LLVMFuncOp;
+
+/// Helper functions to lookup or create the declaration for commonly used
+/// external C function calls. Such ops can then be invoked by creating a CallOp
+/// with the proper arguments via `createLLVMCall`.
+/// The list of functions provided here must be implemented separately (e.g. as
+/// part of a support runtime library or as part of the libc).
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+                                              Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
+
+/// Lookup `name` in `moduleOp`, or declare it as `resultType`(`paramTypes`).
+LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+                                  ArrayRef<Type> paramTypes = {},
+                                  Type resultType = {});
+
+/// Helper wrapper to create a call to `fn` with `args` and `resultTypes`.
+Operation::result_range createLLVMCall(OpBuilder &b, Location loc,
+                                       LLVM::LLVMFuncOp fn,
+                                       ValueRange args = {},
+                                       ArrayRef<Type> resultTypes = {});
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index ea0a4259637c..da28ecbfc035 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "../PassDetail.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
@@ -1793,31 +1794,6 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
     return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
   }
 
-  // Creates a call to an allocation function with params and casts the
-  // resulting void pointer to ptrType.
-  Value createAllocCall(Location loc, StringRef name, Type ptrType,
-                        ArrayRef<Value> params, ModuleOp module,
-                        ConversionPatternRewriter &rewriter) const {
-    SmallVector<Type, 2> paramTypes;
-    auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
-    if (!allocFuncOp) {
-      for (Value param : params)
-        paramTypes.push_back(param.getType());
-      auto allocFuncType =
-          LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes);
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(module.getBody());
-      allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
-                                                      name, allocFuncType);
-    }
-    auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
-    auto allocatedPtr = rewriter
-                            .create<LLVM::CallOp>(loc, getVoidPtrType(),
-                                                  allocFuncSymbol, params)
-                            .getResult(0);
-    return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
-  }
-
   /// Allocates the underlying buffer. Returns the allocated pointer and the
   /// aligned pointer.
   virtual std::tuple<Value, Value>
@@ -1909,9 +1885,12 @@ struct AllocOpLowering : public AllocLikeOpLowering {
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     Type elementPtrType = this->getElementPtrType(memRefType);
+    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
+        allocOp->getParentOfType<ModuleOp>(), getIndexType());
+    auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
+                                  getVoidPtrType());
     Value allocatedPtr =
-        createAllocCall(loc, "malloc", elementPtrType, {sizeBytes},
-                        allocOp->getParentOfType<ModuleOp>(), rewriter);
+        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
 
     Value alignedPtr = allocatedPtr;
     if (alignment) {
@@ -1991,9 +1970,13 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
 
     Type elementPtrType = this->getElementPtrType(memRefType);
-    Value allocatedPtr = createAllocCall(
-        loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes},
-        allocOp->getParentOfType<ModuleOp>(), rewriter);
+    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
+        allocOp->getParentOfType<ModuleOp>(), getIndexType());
+    auto results =
+        createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
+                       getVoidPtrType());
+    Value allocatedPtr =
+        rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
 
     return std::make_tuple(allocatedPtr, allocatedPtr);
   }
@@ -2056,31 +2039,17 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
 
   // Get frequently used types.
   MLIRContext *context = builder.getContext();
-  auto voidType = LLVM::LLVMVoidType::get(context);
   Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
   auto i1Type = IntegerType::get(context, 1);
   Type indexType = typeConverter.getIndexType();
 
   // Find the malloc and free, or declare them if necessary.
   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
-  auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
-  if (!mallocFunc && toDynamic) {
-    OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToStart(module.getBody());
-    mallocFunc = builder.create<LLVM::LLVMFuncOp>(
-        builder.getUnknownLoc(), "malloc",
-        LLVM::LLVMFunctionType::get(voidPtrType, llvm::makeArrayRef(indexType),
-                                    /*isVarArg=*/false));
-  }
-  auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free");
-  if (!freeFunc && !toDynamic) {
-    OpBuilder::InsertionGuard guard(builder);
-    builder.setInsertionPointToStart(module.getBody());
-    freeFunc = builder.create<LLVM::LLVMFuncOp>(
-        builder.getUnknownLoc(), "free",
-        LLVM::LLVMFunctionType::get(voidType, llvm::makeArrayRef(voidPtrType),
-                                    /*isVarArg=*/false));
-  }
+  LLVM::LLVMFuncOp freeFunc, mallocFunc;
+  if (toDynamic)
+    mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
+  if (!toDynamic)
+    freeFunc = LLVM::lookupOrCreateFreeFn(module);
 
   // Initialize shared constants.
   Value zero =
@@ -2217,17 +2186,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
     DeallocOp::Adaptor transformed(operands);
 
     // Insert the `free` declaration if it is not already present.
-    auto freeFunc =
-        op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
-    if (!freeFunc) {
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(
-          op->getParentOfType<ModuleOp>().getBody());
-      freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
-          rewriter.getUnknownLoc(), "free",
-          LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType()));
-    }
-
+    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
     MemRefDescriptor memref(transformed.memref());
     Value casted = rewriter.create<LLVM::BitcastOp>(
         op.getLoc(), getVoidPtrType(),

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 683de815a54e..54cdd9cfde60 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
@@ -1311,11 +1312,14 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     Type eltType = vectorType ? vectorType.getElementType() : printType;
     Operation *printer;
     if (eltType.isF32()) {
-      printer = getPrintFloat(printOp);
+      printer =
+          LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
     } else if (eltType.isF64()) {
-      printer = getPrintDouble(printOp);
+      printer =
+          LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
     } else if (eltType.isIndex()) {
-      printer = getPrintU64(printOp);
+      printer =
+          LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
     } else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
       // Integers need a zero or sign extension on the operand
       // (depending on the source type) as well as a signed or
@@ -1325,7 +1329,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
         if (width <= 64) {
           if (width < 64)
             conversion = PrintConversion::ZeroExt64;
-          printer = getPrintU64(printOp);
+          printer = LLVM::lookupOrCreatePrintU64Fn(
+              printOp->getParentOfType<ModuleOp>());
         } else {
           return failure();
         }
@@ -1338,7 +1343,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
             conversion = PrintConversion::ZeroExt64;
           else if (width < 64)
             conversion = PrintConversion::SignExt64;
-          printer = getPrintI64(printOp);
+          printer = LLVM::lookupOrCreatePrintI64Fn(
+              printOp->getParentOfType<ModuleOp>());
         } else {
           return failure();
         }
@@ -1351,7 +1357,9 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     int64_t rank = vectorType ? vectorType.getRank() : 0;
     emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
               conversion);
-    emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
+    emitCall(rewriter, printOp->getLoc(),
+             LLVM::lookupOrCreatePrintNewlineFn(
+                 printOp->getParentOfType<ModuleOp>()));
     rewriter.eraseOp(printOp);
     return success();
   }
@@ -1386,8 +1394,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
       return;
     }
 
-    emitCall(rewriter, loc, getPrintOpen(op));
-    Operation *printComma = getPrintComma(op);
+    emitCall(rewriter, loc,
+             LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
+    Operation *printComma =
+        LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
     int64_t dim = vectorType.getDimSize(0);
     for (int64_t d = 0; d < dim; ++d) {
       auto reducedType =
@@ -1401,7 +1411,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
       if (d != dim - 1)
         emitCall(rewriter, loc, printComma);
     }
-    emitCall(rewriter, loc, getPrintClose(op));
+    emitCall(rewriter, loc,
+             LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
   }
 
   // Helper to emit a call.
@@ -1410,46 +1421,6 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     rewriter.create<LLVM::CallOp>(loc, TypeRange(),
                                   rewriter.getSymbolRefAttr(ref), params);
   }
-
-  // Helper for printer method declaration (first hit) and lookup.
-  static Operation *getPrint(Operation *op, StringRef name,
-                             ArrayRef<Type> params) {
-    auto module = op->getParentOfType<ModuleOp>();
-    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
-    if (func)
-      return func;
-    OpBuilder moduleBuilder(module.getBodyRegion());
-    return moduleBuilder.create<LLVM::LLVMFuncOp>(
-        op->getLoc(), name,
-        LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
-                                    params));
-  }
-
-  // Helpers for method names.
-  Operation *getPrintI64(Operation *op) const {
-    return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
-  }
-  Operation *getPrintU64(Operation *op) const {
-    return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
-  }
-  Operation *getPrintFloat(Operation *op) const {
-    return getPrint(op, "printF32", Float32Type::get(op->getContext()));
-  }
-  Operation *getPrintDouble(Operation *op) const {
-    return getPrint(op, "printF64", Float64Type::get(op->getContext()));
-  }
-  Operation *getPrintOpen(Operation *op) const {
-    return getPrint(op, "printOpen", {});
-  }
-  Operation *getPrintClose(Operation *op) const {
-    return getPrint(op, "printClose", {});
-  }
-  Operation *getPrintComma(Operation *op) const {
-    return getPrint(op, "printComma", {});
-  }
-  Operation *getPrintNewline(Operation *op) const {
-    return getPrint(op, "printNewline", {});
-  }
 };
 
 /// Progressive lowering of ExtractStridedSliceOp to either:

diff  --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index c2f88d06062c..337e3c48d1fc 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(Transforms)
 
 add_mlir_dialect_library(MLIRLLVMIR
+  IR/FunctionCallUtils.cpp
   IR/LLVMDialect.cpp
   IR/LLVMTypes.cpp
   IR/LLVMTypeSyntax.cpp

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
new file mode 100644
index 000000000000..a43c2251c2d9
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -0,0 +1,125 @@
+//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helper functions to call common simple C functions in
+// LLVMIR (e.g. among others to support printing and debugging).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+/// Names of the commonly used external C functions declared by the helpers
+/// below. These functions are not defined here and must be implemented
+/// separately (e.g. as part of a support runtime library or as part of the
+/// libc).
+static constexpr llvm::StringLiteral kPrintI64 = "printI64";
+static constexpr llvm::StringLiteral kPrintU64 = "printU64";
+static constexpr llvm::StringLiteral kPrintF32 = "printF32";
+static constexpr llvm::StringLiteral kPrintF64 = "printF64";
+static constexpr llvm::StringLiteral kPrintOpen = "printOpen";
+static constexpr llvm::StringLiteral kPrintClose = "printClose";
+static constexpr llvm::StringLiteral kPrintComma = "printComma";
+static constexpr llvm::StringLiteral kPrintNewline = "printNewline";
+static constexpr llvm::StringLiteral kMalloc = "malloc";
+static constexpr llvm::StringLiteral kAlignedAlloc = "aligned_alloc";
+static constexpr llvm::StringLiteral kFree = "free";
+
+/// Generic lookupOrCreate helper; a null `resultType` declares a void function.
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+                                              ArrayRef<Type> paramTypes,
+                                              Type resultType) {
+  if (auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
+    return func;
+  if (!resultType)
+    resultType = LLVM::LLVMVoidType::get(moduleOp->getContext());
+  OpBuilder b(moduleOp.getBodyRegion());
+  auto fnType = LLVM::LLVMFunctionType::get(resultType, paramTypes);
+  return b.create<LLVM::LLVMFuncOp>(moduleOp->getLoc(), name, fnType);
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+  auto *ctx = moduleOp->getContext();
+  return lookupOrCreateFn(moduleOp, kPrintI64, IntegerType::get(ctx, 64),
+                          LLVMVoidType::get(ctx));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+  auto *ctx = moduleOp->getContext();
+  return lookupOrCreateFn(moduleOp, kPrintU64, IntegerType::get(ctx, 64),
+                          LLVMVoidType::get(ctx));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+  auto *ctx = moduleOp->getContext();
+  return lookupOrCreateFn(moduleOp, kPrintF32, Float32Type::get(ctx),
+                          LLVMVoidType::get(ctx));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+  auto *ctx = moduleOp->getContext();
+  return lookupOrCreateFn(moduleOp, kPrintF64, Float64Type::get(ctx),
+                          LLVMVoidType::get(ctx));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+  Type voidTy = LLVMVoidType::get(moduleOp->getContext());
+  return lookupOrCreateFn(moduleOp, kPrintOpen, /*paramTypes=*/{}, voidTy);
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+  Type voidTy = LLVMVoidType::get(moduleOp->getContext());
+  return lookupOrCreateFn(moduleOp, kPrintClose, /*paramTypes=*/{}, voidTy);
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+  Type voidTy = LLVMVoidType::get(moduleOp->getContext());
+  return lookupOrCreateFn(moduleOp, kPrintComma, /*paramTypes=*/{}, voidTy);
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+  Type voidTy = LLVMVoidType::get(moduleOp->getContext());
+  return lookupOrCreateFn(moduleOp, kPrintNewline, /*paramTypes=*/{}, voidTy);
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+                                                    Type indexType) {
+  return lookupOrCreateFn(
+      moduleOp, kMalloc, /*paramTypes=*/indexType,
+      LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+                                                          Type indexType) {
+  return lookupOrCreateFn(
+      moduleOp, kAlignedAlloc, /*paramTypes=*/{indexType, indexType},
+      LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+  auto *ctx = moduleOp->getContext();
+  Type i8PtrTy = LLVMPointerType::get(IntegerType::get(ctx, 8));
+  return lookupOrCreateFn(moduleOp, kFree, /*paramTypes=*/i8PtrTy,
+                          /*resultType=*/LLVMVoidType::get(ctx));
+}
+
+Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
+                                                   LLVM::LLVMFuncOp fn,
+                                                   ValueRange args,
+                                                   ArrayRef<Type> resultTypes) {
+  // `args` are the SSA operand values of the call (not types).
+  auto call =
+      b.create<LLVM::CallOp>(loc, resultTypes, b.getSymbolRefAttr(fn), args);
+  return call->getResults();
+}


        


More information about the Mlir-commits mailing list