[Mlir-commits] [mlir] [NFC][FuncOpToLLVM] refactor and move some utils to headers (PR #68665)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 9 23:38:38 PDT 2023
https://github.com/allatit23 created https://github.com/llvm/llvm-project/pull/68665
* refactor `convertFuncOpToLLVMFuncOp` to accept a `FunctionOpInterface` instead of a `func::FuncOp`
* move `convertFuncOpToLLVMFuncOp` to the corresponding public header, making it available to downstream projects.
>From 15fcfab294578ff6add4ad8e2d1aeecd7f0fba9b Mon Sep 17 00:00:00 2001
From: Allen Zhao <allzhao at nvidia.com>
Date: Mon, 9 Oct 2023 23:20:14 -0700
Subject: [PATCH] [NFC][FuncOpToLLVM] refactor and move some utils to headers
* refactor `convertFuncOpToLLVMFuncOp` to accept a `FunctionOpInterface` instead of a `func::FuncOp`
* move `convertFuncOpToLLVMFuncOp` to the corresponding public header, making it available to downstream projects.
---
.../Conversion/FuncToLLVM/ConvertFuncToLLVM.h | 15 +
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 284 +++++++++---------
2 files changed, 165 insertions(+), 134 deletions(-)
diff --git a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
index 21bd191aa9dc8c3..7ce9bb18c93ea43 100644
--- a/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
+++ b/mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h
@@ -14,13 +14,28 @@
#ifndef MLIR_CONVERSION_FUNCTOLLVM_CONVERTFUNCTOLLVM_H
#define MLIR_CONVERSION_FUNCTOLLVM_CONVERTFUNCTOLLVM_H
+#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
namespace mlir {
+namespace LLVM {
+class LLVMFuncOp;
+} // namespace LLVM
+
+class ConversionPatternRewriter;
class DialectRegistry;
class LLVMTypeConverter;
class RewritePatternSet;
class SymbolTable;
+/// Convert a FunctionOpInterface op with a FunctionType to an LLVMFuncOp using
+/// the provided LLVMTypeConverter. Returns failure if the conversion fails.
+FailureOr<LLVM::LLVMFuncOp>
+convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
+ ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter);
+
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
/// `emitCWrappers` is set, the pattern will also produce functions
/// that pass memref descriptors by pointer-to-structure in addition to the
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 4aacb47a7fe9cc1..3506f50916132dd 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -74,7 +74,7 @@ static bool shouldUseBarePtrCallConv(Operation *op,
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`.
-static void filterFuncAttributes(func::FuncOp func,
+static void filterFuncAttributes(FunctionOpInterface func,
SmallVectorImpl<NamedAttribute> &result) {
for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
if (attr.getName() == linkageAttrName ||
@@ -87,26 +87,26 @@ static void filterFuncAttributes(func::FuncOp func,
/// Propagate argument/results attributes.
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
- func::FuncOp funcOp,
+ FunctionOpInterface funcOp,
LLVM::LLVMFuncOp wrapperFuncOp) {
- auto argAttrs = funcOp.getArgAttrs();
+ auto argAttrs = funcOp.getAllArgAttrs();
if (!resultStructType) {
if (auto resAttrs = funcOp.getAllResultAttrs())
wrapperFuncOp.setAllResultAttrs(resAttrs);
if (argAttrs)
- wrapperFuncOp.setAllArgAttrs(*argAttrs);
+ wrapperFuncOp.setAllArgAttrs(argAttrs);
} else {
SmallVector<Attribute> argAttributes;
// Only modify the argument and result attributes when the result is now
// an argument.
if (argAttrs) {
argAttributes.push_back(builder.getDictionaryAttr({}));
- argAttributes.append(argAttrs->begin(), argAttrs->end());
+ argAttributes.append(argAttrs.begin(), argAttrs.end());
wrapperFuncOp.setAllArgAttrs(argAttributes);
}
}
- if (funcOp.getSymVisibilityAttr())
- wrapperFuncOp.setSymVisibility(funcOp.getSymVisibilityAttr());
+ cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
+ .setVisibility(funcOp.getVisibility());
}
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
@@ -119,9 +119,9 @@ static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
/// the extra arguments.
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
const LLVMTypeConverter &typeConverter,
- func::FuncOp funcOp,
+ FunctionOpInterface funcOp,
LLVM::LLVMFuncOp newFuncOp) {
- auto type = funcOp.getFunctionType();
+ auto type = cast<FunctionType>(funcOp.getFunctionType());
auto [wrapperFuncType, resultStructType] =
typeConverter.convertFunctionTypeCWrapper(type);
@@ -179,12 +179,13 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
/// corresponding to a memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
const LLVMTypeConverter &typeConverter,
- func::FuncOp funcOp,
+ FunctionOpInterface funcOp,
LLVM::LLVMFuncOp newFuncOp) {
OpBuilder::InsertionGuard guard(builder);
auto [wrapperType, resultStructType] =
- typeConverter.convertFunctionTypeCWrapper(funcOp.getFunctionType());
+ typeConverter.convertFunctionTypeCWrapper(
+ cast<FunctionType>(funcOp.getFunctionType()));
// This conversion can only fail if it could not convert one of the argument
// types. But since it has been applied to a non-wrapper function before, it
// should have failed earlier and not reach this point at all.
@@ -205,7 +206,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
builder.setInsertionPointToStart(newFuncOp.addEntryBlock());
// Get a ValueRange containing arguments.
- FunctionType type = funcOp.getFunctionType();
+ FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
SmallVector<Value, 8> args;
args.reserve(type.getNumInputs());
ValueRange wrapperArgsRange(newFuncOp.getArguments());
@@ -317,6 +318,140 @@ static void modifyFuncOpToUseBarePtrCallingConv(
}
}
+FailureOr<LLVM::LLVMFuncOp>
+mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
+ ConversionPatternRewriter &rewriter,
+ const LLVMTypeConverter &converter) {
+ // Only ops whose function type is a `FunctionType` are supported.
+ auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
+ if (!funcTy)
+ return rewriter.notifyMatchFailure(
+ funcOp, "Only support FunctionOpInterface with FunctionType");
+
+ // Convert the original function arguments. They are converted using the
+ // LLVMTypeConverter passed in as `converter`.
+ auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
+ TypeConverter::SignatureConversion result(funcOp.getNumArguments());
+ auto llvmType = converter.convertFunctionSignature(
+ funcTy, varargsAttr && varargsAttr.getValue(),
+ shouldUseBarePtrCallConv(funcOp, &converter), result);
+ if (!llvmType)
+ return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
+
+ // Create an LLVM function. Linkage defaults to external (MLIR functions have
+ // no linkage yet) unless overridden by an LLVM::LinkageAttr.
+ LLVM::Linkage linkage = LLVM::Linkage::External;
+ if (funcOp->hasAttr(linkageAttrName)) {
+ auto attr =
+ dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
+ if (!attr) {
+ funcOp->emitError() << "Contains " << linkageAttrName
+ << " attribute not of type LLVM::LinkageAttr";
+ return rewriter.notifyMatchFailure(
+ funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
+ }
+ linkage = attr.getLinkage();
+ }
+
+ SmallVector<NamedAttribute, 4> attributes;
+ filterFuncAttributes(funcOp, attributes);
+ auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
+ funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
+ /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
+ attributes);
+ cast<FunctionOpInterface>(newFuncOp.getOperation())
+ .setVisibility(funcOp.getVisibility());
+
+ // Create a memory effect attribute corresponding to readnone.
+ StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
+ if (funcOp->hasAttr(readnoneAttrName)) {
+ auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
+ if (!attr) {
+ funcOp->emitError() << "Contains " << readnoneAttrName
+ << " attribute not of type UnitAttr";
+ return rewriter.notifyMatchFailure(
+ funcOp, "Contains readnone attribute not of type UnitAttr");
+ }
+ auto memoryAttr = LLVM::MemoryEffectsAttr::get(
+ rewriter.getContext(),
+ {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
+ LLVM::ModRefInfo::NoModRef});
+ newFuncOp.setMemoryAttr(memoryAttr);
+ }
+
+ // Propagate argument/result attributes. Result attributes are copied only for
+ // single-result functions; argument attributes are remapped below.
+ if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
+ assert(!resAttrDicts.empty() && "expected array to be non-empty");
+ if (funcOp.getNumResults() == 1)
+ newFuncOp.setAllResultAttrs(resAttrDicts);
+ }
+ if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
+ SmallVector<Attribute> newArgAttrs(
+ cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
+ for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+ // Some LLVM IR attributes have a type attached to them. During FuncOp ->
+ // LLVMFuncOp conversion these types may have changed. Account for that
+ // change by converting attributes' types as well.
+ SmallVector<NamedAttribute, 4> convertedAttrs;
+ auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
+ convertedAttrs.reserve(attrsDict.size());
+ for (const NamedAttribute &attr : attrsDict) {
+ const auto convert = [&](const NamedAttribute &attr) {
+ return TypeAttr::get(converter.convertType(
+ cast<TypeAttr>(attr.getValue()).getValue()));
+ };
+ if (attr.getName().getValue() ==
+ LLVM::LLVMDialect::getByValAttrName()) {
+ convertedAttrs.push_back(rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
+ } else if (attr.getName().getValue() ==
+ LLVM::LLVMDialect::getByRefAttrName()) {
+ convertedAttrs.push_back(rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
+ } else if (attr.getName().getValue() ==
+ LLVM::LLVMDialect::getStructRetAttrName()) {
+ convertedAttrs.push_back(rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
+ } else if (attr.getName().getValue() ==
+ LLVM::LLVMDialect::getInAllocaAttrName()) {
+ convertedAttrs.push_back(rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
+ } else {
+ convertedAttrs.push_back(attr);
+ }
+ }
+ auto mapping = result.getInputMapping(i);
+ assert(mapping && "unexpected deletion of function argument");
+ // Only attach the new argument attributes if there is a one-to-one
+ // mapping from old to new types. Otherwise, attributes might be
+ // attached to types that they do not support.
+ if (mapping->size == 1) {
+ newArgAttrs[mapping->inputNo] =
+ DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
+ continue;
+ }
+ // TODO: Implement custom handling for types that expand to multiple
+ // function arguments.
+ for (size_t j = 0; j < mapping->size; ++j)
+ newArgAttrs[mapping->inputNo + j] =
+ DictionaryAttr::get(rewriter.getContext(), {});
+ }
+ if (!newArgAttrs.empty())
+ newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
+ }
+
+ rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
+ &result))) {
+ return rewriter.notifyMatchFailure(funcOp,
+ "region types conversion failed");
+ }
+
+ return newFuncOp;
+}
+
namespace {
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
@@ -328,128 +463,9 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
FailureOr<LLVM::LLVMFuncOp>
convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
- // Convert the original function arguments. They are converted using the
- // LLVMTypeConverter provided to this legalization pattern.
- auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
- TypeConverter::SignatureConversion result(funcOp.getNumArguments());
- auto llvmType = getTypeConverter()->convertFunctionSignature(
- funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
- shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result);
- if (!llvmType)
- return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
-
- // Create an LLVM function, use external linkage by default until MLIR
- // functions have linkage.
- LLVM::Linkage linkage = LLVM::Linkage::External;
- if (funcOp->hasAttr(linkageAttrName)) {
- auto attr =
- dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
- if (!attr) {
- funcOp->emitError() << "Contains " << linkageAttrName
- << " attribute not of type LLVM::LinkageAttr";
- return rewriter.notifyMatchFailure(
- funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
- }
- linkage = attr.getLinkage();
- }
-
- SmallVector<NamedAttribute, 4> attributes;
- filterFuncAttributes(funcOp, attributes);
- auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
- funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
- /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
- attributes);
- if (funcOp.getSymVisibilityAttr())
- newFuncOp.setSymVisibility(funcOp.getSymVisibilityAttr());
-
- // Create a memory effect attribute corresponding to readnone.
- StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
- if (funcOp->hasAttr(readnoneAttrName)) {
- auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
- if (!attr) {
- funcOp->emitError() << "Contains " << readnoneAttrName
- << " attribute not of type UnitAttr";
- return rewriter.notifyMatchFailure(
- funcOp, "Contains readnone attribute not of type UnitAttr");
- }
- auto memoryAttr = LLVM::MemoryEffectsAttr::get(
- rewriter.getContext(),
- {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
- LLVM::ModRefInfo::NoModRef});
- newFuncOp.setMemoryAttr(memoryAttr);
- }
-
- // Propagate argument/result attributes to all converted arguments/result
- // obtained after converting a given original argument/result.
- if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
- assert(!resAttrDicts.empty() && "expected array to be non-empty");
- if (funcOp.getNumResults() == 1)
- newFuncOp.setAllResultAttrs(resAttrDicts);
- }
- if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
- SmallVector<Attribute> newArgAttrs(
- cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
- for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
- // Some LLVM IR attribute have a type attached to them. During FuncOp ->
- // LLVMFuncOp conversion these types may have changed. Account for that
- // change by converting attributes' types as well.
- SmallVector<NamedAttribute, 4> convertedAttrs;
- auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
- convertedAttrs.reserve(attrsDict.size());
- for (const NamedAttribute &attr : attrsDict) {
- const auto convert = [&](const NamedAttribute &attr) {
- return TypeAttr::get(getTypeConverter()->convertType(
- cast<TypeAttr>(attr.getValue()).getValue()));
- };
- if (attr.getName().getValue() ==
- LLVM::LLVMDialect::getByValAttrName()) {
- convertedAttrs.push_back(rewriter.getNamedAttr(
- LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
- } else if (attr.getName().getValue() ==
- LLVM::LLVMDialect::getByRefAttrName()) {
- convertedAttrs.push_back(rewriter.getNamedAttr(
- LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
- } else if (attr.getName().getValue() ==
- LLVM::LLVMDialect::getStructRetAttrName()) {
- convertedAttrs.push_back(rewriter.getNamedAttr(
- LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
- } else if (attr.getName().getValue() ==
- LLVM::LLVMDialect::getInAllocaAttrName()) {
- convertedAttrs.push_back(rewriter.getNamedAttr(
- LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
- } else {
- convertedAttrs.push_back(attr);
- }
- }
- auto mapping = result.getInputMapping(i);
- assert(mapping && "unexpected deletion of function argument");
- // Only attach the new argument attributes if there is a one-to-one
- // mapping from old to new types. Otherwise, attributes might be
- // attached to types that they do not support.
- if (mapping->size == 1) {
- newArgAttrs[mapping->inputNo] =
- DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
- continue;
- }
- // TODO: Implement custom handling for types that expand to multiple
- // function arguments.
- for (size_t j = 0; j < mapping->size; ++j)
- newArgAttrs[mapping->inputNo + j] =
- DictionaryAttr::get(rewriter.getContext(), {});
- }
- if (!newArgAttrs.empty())
- newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
- }
-
- rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
- newFuncOp.end());
- if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
- &result))) {
- return rewriter.notifyMatchFailure(funcOp,
- "region types conversion failed");
- }
-
- return newFuncOp;
+ return mlir::convertFuncOpToLLVMFuncOp(
+ cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
+ *getTypeConverter());
}
};
More information about the Mlir-commits
mailing list