[Mlir-commits] [mlir] e332c22 - [mlir][LLVM] NFC - Refactor a lookupOrCreateFn to reuse common function creation.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Feb 11 07:55:55 PST 2021
Author: Nicolas Vasilache
Date: 2021-02-11T15:52:33Z
New Revision: e332c22cdf54d16974f445ba4345cd2a76d7fc6a
URL: https://github.com/llvm/llvm-project/commit/e332c22cdf54d16974f445ba4345cd2a76d7fc6a
DIFF: https://github.com/llvm/llvm-project/commit/e332c22cdf54d16974f445ba4345cd2a76d7fc6a.diff
LOG: [mlir][LLVM] NFC - Refactor a lookupOrCreateFn to reuse common function creation.
Differential revision: https://reviews.llvm.org/D96488
Added:
mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
Modified:
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
new file mode 100644
index 000000000000..7efff9774cd5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -0,0 +1,63 @@
+//===- FunctionCallUtils.h - Utilities for C function calls -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares helper functions to call common simple C functions in
+// LLVMIR (e.g. among others to support printing and debugging).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
+#define MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+
+namespace LLVM {
+class LLVMFuncOp;
+
+/// Helpers to lookup or create the declaration of commonly used external C
+/// functions. If `moduleOp` already holds an LLVMFuncOp with the expected name
+/// it is returned as-is (its signature is not checked); otherwise a declaration
+/// is inserted into `moduleOp`. Call them via `createLLVMCall`. The functions
+/// must be implemented separately (e.g. in a support runtime or in the libc).
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+ Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
+
+/// Look up `name` in `moduleOp` or declare it as `resultType(paramTypes)`.
+LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+ ArrayRef<Type> paramTypes = {},
+ Type resultType = {});
+
+/// Helper wrapper to create a call to `fn` with `args` and `resultTypes`.
+Operation::result_range createLLVMCall(OpBuilder &b, Location loc,
+ LLVM::LLVMFuncOp fn,
+ ValueRange args = {},
+ ArrayRef<Type> resultTypes = {});
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_FUNCTIONCALLUTILS_H_
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index ea0a4259637c..da28ecbfc035 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -14,6 +14,7 @@
#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@@ -1793,31 +1794,6 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}
- // Creates a call to an allocation function with params and casts the
- // resulting void pointer to ptrType.
- Value createAllocCall(Location loc, StringRef name, Type ptrType,
- ArrayRef<Value> params, ModuleOp module,
- ConversionPatternRewriter &rewriter) const {
- SmallVector<Type, 2> paramTypes;
- auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
- if (!allocFuncOp) {
- for (Value param : params)
- paramTypes.push_back(param.getType());
- auto allocFuncType =
- LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes);
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(module.getBody());
- allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
- name, allocFuncType);
- }
- auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp);
- auto allocatedPtr = rewriter
- .create<LLVM::CallOp>(loc, getVoidPtrType(),
- allocFuncSymbol, params)
- .getResult(0);
- return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr);
- }
-
/// Allocates the underlying buffer. Returns the allocated pointer and the
/// aligned pointer.
virtual std::tuple<Value, Value>
@@ -1909,9 +1885,12 @@ struct AllocOpLowering : public AllocLikeOpLowering {
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
Type elementPtrType = this->getElementPtrType(memRefType);
+ auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
+ allocOp->getParentOfType<ModuleOp>(), getIndexType());
+ auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
+ getVoidPtrType());
Value allocatedPtr =
- createAllocCall(loc, "malloc", elementPtrType, {sizeBytes},
- allocOp->getParentOfType<ModuleOp>(), rewriter);
+ rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
Value alignedPtr = allocatedPtr;
if (alignment) {
@@ -1991,9 +1970,13 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering {
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
Type elementPtrType = this->getElementPtrType(memRefType);
- Value allocatedPtr = createAllocCall(
- loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes},
- allocOp->getParentOfType<ModuleOp>(), rewriter);
+ auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
+ allocOp->getParentOfType<ModuleOp>(), getIndexType());
+ auto results =
+ createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
+ getVoidPtrType());
+ Value allocatedPtr =
+ rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]);
return std::make_tuple(allocatedPtr, allocatedPtr);
}
@@ -2056,31 +2039,17 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
// Get frequently used types.
MLIRContext *context = builder.getContext();
- auto voidType = LLVM::LLVMVoidType::get(context);
Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
auto i1Type = IntegerType::get(context, 1);
Type indexType = typeConverter.getIndexType();
// Find the malloc and free, or declare them if necessary.
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
- auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
- if (!mallocFunc && toDynamic) {
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(module.getBody());
- mallocFunc = builder.create<LLVM::LLVMFuncOp>(
- builder.getUnknownLoc(), "malloc",
- LLVM::LLVMFunctionType::get(voidPtrType, llvm::makeArrayRef(indexType),
- /*isVarArg=*/false));
- }
- auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free");
- if (!freeFunc && !toDynamic) {
- OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(module.getBody());
- freeFunc = builder.create<LLVM::LLVMFuncOp>(
- builder.getUnknownLoc(), "free",
- LLVM::LLVMFunctionType::get(voidType, llvm::makeArrayRef(voidPtrType),
- /*isVarArg=*/false));
- }
+ LLVM::LLVMFuncOp freeFunc, mallocFunc;
+ if (toDynamic)
+ mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
+ if (!toDynamic)
+ freeFunc = LLVM::lookupOrCreateFreeFn(module);
// Initialize shared constants.
Value zero =
@@ -2217,17 +2186,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
DeallocOp::Adaptor transformed(operands);
// Insert the `free` declaration if it is not already present.
- auto freeFunc =
- op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
- if (!freeFunc) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(
- op->getParentOfType<ModuleOp>().getBody());
- freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
- rewriter.getUnknownLoc(), "free",
- LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType()));
- }
-
+ auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
MemRefDescriptor memref(transformed.memref());
Value casted = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), getVoidPtrType(),
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 683de815a54e..54cdd9cfde60 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
@@ -1311,11 +1312,14 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
Type eltType = vectorType ? vectorType.getElementType() : printType;
Operation *printer;
if (eltType.isF32()) {
- printer = getPrintFloat(printOp);
+ printer =
+ LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType<ModuleOp>());
} else if (eltType.isF64()) {
- printer = getPrintDouble(printOp);
+ printer =
+ LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType<ModuleOp>());
} else if (eltType.isIndex()) {
- printer = getPrintU64(printOp);
+ printer =
+ LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType<ModuleOp>());
} else if (auto intTy = eltType.dyn_cast<IntegerType>()) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
@@ -1325,7 +1329,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
if (width <= 64) {
if (width < 64)
conversion = PrintConversion::ZeroExt64;
- printer = getPrintU64(printOp);
+ printer = LLVM::lookupOrCreatePrintU64Fn(
+ printOp->getParentOfType<ModuleOp>());
} else {
return failure();
}
@@ -1338,7 +1343,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
conversion = PrintConversion::ZeroExt64;
else if (width < 64)
conversion = PrintConversion::SignExt64;
- printer = getPrintI64(printOp);
+ printer = LLVM::lookupOrCreatePrintI64Fn(
+ printOp->getParentOfType<ModuleOp>());
} else {
return failure();
}
@@ -1351,7 +1357,9 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
int64_t rank = vectorType ? vectorType.getRank() : 0;
emitRanks(rewriter, printOp, adaptor.source(), vectorType, printer, rank,
conversion);
- emitCall(rewriter, printOp->getLoc(), getPrintNewline(printOp));
+ emitCall(rewriter, printOp->getLoc(),
+ LLVM::lookupOrCreatePrintNewlineFn(
+ printOp->getParentOfType<ModuleOp>()));
rewriter.eraseOp(printOp);
return success();
}
@@ -1386,8 +1394,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return;
}
- emitCall(rewriter, loc, getPrintOpen(op));
- Operation *printComma = getPrintComma(op);
+ emitCall(rewriter, loc,
+ LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType<ModuleOp>()));
+ Operation *printComma =
+ LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType<ModuleOp>());
int64_t dim = vectorType.getDimSize(0);
for (int64_t d = 0; d < dim; ++d) {
auto reducedType =
@@ -1401,7 +1411,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
if (d != dim - 1)
emitCall(rewriter, loc, printComma);
}
- emitCall(rewriter, loc, getPrintClose(op));
+ emitCall(rewriter, loc,
+ LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType<ModuleOp>()));
}
// Helper to emit a call.
@@ -1410,46 +1421,6 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
rewriter.create<LLVM::CallOp>(loc, TypeRange(),
rewriter.getSymbolRefAttr(ref), params);
}
-
- // Helper for printer method declaration (first hit) and lookup.
- static Operation *getPrint(Operation *op, StringRef name,
- ArrayRef<Type> params) {
- auto module = op->getParentOfType<ModuleOp>();
- auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
- if (func)
- return func;
- OpBuilder moduleBuilder(module.getBodyRegion());
- return moduleBuilder.create<LLVM::LLVMFuncOp>(
- op->getLoc(), name,
- LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()),
- params));
- }
-
- // Helpers for method names.
- Operation *getPrintI64(Operation *op) const {
- return getPrint(op, "printI64", IntegerType::get(op->getContext(), 64));
- }
- Operation *getPrintU64(Operation *op) const {
- return getPrint(op, "printU64", IntegerType::get(op->getContext(), 64));
- }
- Operation *getPrintFloat(Operation *op) const {
- return getPrint(op, "printF32", Float32Type::get(op->getContext()));
- }
- Operation *getPrintDouble(Operation *op) const {
- return getPrint(op, "printF64", Float64Type::get(op->getContext()));
- }
- Operation *getPrintOpen(Operation *op) const {
- return getPrint(op, "printOpen", {});
- }
- Operation *getPrintClose(Operation *op) const {
- return getPrint(op, "printClose", {});
- }
- Operation *getPrintComma(Operation *op) const {
- return getPrint(op, "printComma", {});
- }
- Operation *getPrintNewline(Operation *op) const {
- return getPrint(op, "printNewline", {});
- }
};
/// Progressive lowering of ExtractStridedSliceOp to either:
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index c2f88d06062c..337e3c48d1fc 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -1,6 +1,7 @@
add_subdirectory(Transforms)
add_mlir_dialect_library(MLIRLLVMIR
+ IR/FunctionCallUtils.cpp
IR/LLVMDialect.cpp
IR/LLVMTypes.cpp
IR/LLVMTypeSyntax.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
new file mode 100644
index 000000000000..a43c2251c2d9
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -0,0 +1,125 @@
+//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helper functions to call common simple C functions in
+// LLVMIR (e.g. among others to support printing and debugging).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+/// Symbol names of the external C functions declared on demand by the
+/// lookupOrCreate*Fn helpers in this file. Their definitions are not provided
+/// here and must be implemented separately (e.g. as part of a support runtime
+/// library or as part of the libc).
+static constexpr llvm::StringRef kPrintI64 = "printI64";
+static constexpr llvm::StringRef kPrintU64 = "printU64";
+static constexpr llvm::StringRef kPrintF32 = "printF32";
+static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintOpen = "printOpen";
+static constexpr llvm::StringRef kPrintClose = "printClose";
+static constexpr llvm::StringRef kPrintComma = "printComma";
+static constexpr llvm::StringRef kPrintNewline = "printNewline";
+static constexpr llvm::StringRef kMalloc = "malloc";
+static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
+static constexpr llvm::StringRef kFree = "free";
+
+/// Look up `name` or declare it; a null `resultType` (the default) means void.
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+                                              ArrayRef<Type> paramTypes,
+                                              Type resultType) {
+  if (auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name))
+    return func;
+  if (!resultType)
+    resultType = LLVM::LLVMVoidType::get(moduleOp->getContext());
+  OpBuilder b(moduleOp.getBodyRegion());
+  auto fnType = LLVM::LLVMFunctionType::get(resultType, paramTypes);
+  return b.create<LLVM::LLVMFuncOp>(moduleOp->getLoc(), name, fnType);
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintI64,
+ IntegerType::get(moduleOp->getContext(), 64),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintU64,
+ IntegerType::get(moduleOp->getContext(), 64),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintF32,
+ Float32Type::get(moduleOp->getContext()),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintF64,
+ Float64Type::get(moduleOp->getContext()),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintOpen, {},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintClose, {},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintComma, {},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(moduleOp, kPrintNewline, {},
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+                                                    Type indexType) {
+  auto i8Ty = IntegerType::get(moduleOp->getContext(), 8);
+  return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
+                                LLVM::LLVMPointerType::get(i8Ty));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+                                                          Type indexType) {
+  auto i8Ty = IntegerType::get(moduleOp->getContext(), 8);
+  return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
+                                LLVM::LLVMPointerType::get(i8Ty));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+  MLIRContext *ctx = moduleOp->getContext();
+  Type voidPtrTy = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
+  return LLVM::lookupOrCreateFn(moduleOp, kFree, voidPtrTy,
+                                LLVM::LLVMVoidType::get(ctx));
+}
+
+/// Builds an LLVM::CallOp to `fn` at `loc` with operands `args` and result
+/// types `resultTypes`, and returns the results of the new call.
+Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
+                                                   LLVM::LLVMFuncOp fn,
+                                                   ValueRange args,
+                                                   ArrayRef<Type> resultTypes) {
+  return b.create<LLVM::CallOp>(loc, resultTypes, b.getSymbolRefAttr(fn), args)
+      ->getResults();
+}
More information about the Mlir-commits
mailing list