[Mlir-commits] [mlir] 162f757 - [mlir][LLVM] Add an attribute to control use of bare-pointer calling convention.
Mahesh Ravishankar
llvmlistbot at llvm.org
Thu Apr 6 09:20:14 PDT 2023
Author: Mahesh Ravishankar
Date: 2023-04-06T16:19:56Z
New Revision: 162f7572067d7d2d70202f5ff42532adf6f75517
URL: https://github.com/llvm/llvm-project/commit/162f7572067d7d2d70202f5ff42532adf6f75517
DIFF: https://github.com/llvm/llvm-project/commit/162f7572067d7d2d70202f5ff42532adf6f75517.diff
LOG: [mlir][LLVM] Add an attribute to control use of bare-pointer calling convention.
Currently the use of the bare-pointer calling convention is controlled
globally through an option in the `LLVMTypeConverter`. To allow more
fine-grained control, use an attribute on a function (`llvm.bareptr`) to
select the calling convention; the global option, when set, still applies.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D147494
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index b13b88d6773a8..600575139dbe5 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -53,20 +53,23 @@ class LLVMTypeConverter : public TypeConverter {
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
+ bool useBarePtrCallConv,
SignatureConversion &result);
/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one value is
/// returned, create an LLVM IR structure type with elements that correspond
/// to each of the MLIR types converted with `convertType`.
- Type packFunctionResults(TypeRange types);
+ Type packFunctionResults(TypeRange types,
+ bool useBarePointerCallConv = false);
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
- Type convertCallingConventionType(Type type);
+ Type convertCallingConventionType(Type type,
+ bool useBarePointerCallConv = false);
/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
@@ -95,8 +98,8 @@ class LLVMTypeConverter : public TypeConverter {
/// of the platform-specific C/C++ ABI lowering related to struct argument
/// passing.
SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands,
- OpBuilder &builder);
+ ValueRange operands, OpBuilder &builder,
+ bool useBarePtrCallConv = false);
/// Promote the LLVM struct representation of one MemRef descriptor to stack
/// and use pointer to struct to avoid the complexity of the platform-specific
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 7200b2b3ea9af..86394aa969bb3 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -58,6 +58,14 @@ using namespace mlir;
static constexpr StringRef varargsAttrName = "func.varargs";
static constexpr StringRef linkageAttrName = "llvm.linkage";
+static constexpr StringRef barePtrAttrName = "llvm.bareptr"; // Checked via hasAttr; the value is ignored.
+
+/// Return `true` if `op` should use the bare-pointer calling convention.
+static bool shouldUseBarePtrCallConv(Operation *op,
+ LLVMTypeConverter *typeConverter) {
+ return (op && op->hasAttr(barePtrAttrName)) || // Per-op opt-in; `op` may be null.
+ typeConverter->getOptions().useBarePtrCallConv; // Converter-wide option applies to every op.
+}
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
@@ -267,6 +275,55 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}
+/// Rewrite the entry block of `funcOp` so each ranked `memref` argument, which
+/// arrives as a bare pointer, is rebuilt into a `MemRefDescriptor` up front.
+static void modifyFuncOpToUseBarePtrCallingConv(
+ ConversionPatternRewriter &rewriter, Location loc, // NOTE(review): `loc` is unused (shadowed in the loop below).
+ LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
+ TypeRange oldArgTypes) { // Argument types before conversion.
+ if (funcOp.getBody().empty()) // No body to rewrite.
+ return;
+
+ // Promote bare pointers from memref arguments to memref descriptors at the
+ // beginning of the function so that all the memrefs in the function have a
+ // uniform representation.
+ Block *entryBlock = &funcOp.getBody().front();
+ auto blockArgs = entryBlock->getArguments();
+ assert(blockArgs.size() == oldArgTypes.size() &&
+ "The number of arguments and types doesn't match");
+
+ OpBuilder::InsertionGuard guard(rewriter); // Restore the insertion point on exit.
+ rewriter.setInsertionPointToStart(entryBlock);
+ for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+ BlockArgument arg = std::get<0>(it);
+ Type argTy = std::get<1>(it);
+
+ // Unranked memrefs are not supported in the bare pointer calling
+ // convention. We should have bailed out before in the presence of
+ // unranked memrefs.
+ assert(!argTy.isa<UnrankedMemRefType>() &&
+ "Unranked memref is not supported");
+ auto memrefTy = argTy.dyn_cast<MemRefType>();
+ if (!memrefTy)
+ continue; // Non-memref arguments need no promotion.
+
+ // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+ // memref descriptor (unranked memrefs are rejected above) and replace the
+ // placeholder with the last instruction of the memref descriptor.
+ // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+ // MemRef descriptor instructions. We may want to have a utility in the
+ // rewriter to properly handle this use case.
+ Location loc = funcOp.getLoc();
+ auto placeholder = rewriter.create<LLVM::UndefOp>(
+ loc, typeConverter.convertType(memrefTy));
+ rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+ Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
+ memrefTy, arg);
+ rewriter.replaceOp(placeholder, {desc});
+ }
+}
+
namespace {
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
@@ -284,7 +341,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
- result);
+ shouldUseBarePtrCallConv(funcOp, getTypeConverter()), result);
if (!llvmType)
return nullptr;
@@ -415,89 +472,24 @@ struct FuncOpConversion : public FuncOpConversionBase {
if (!newFuncOp)
return failure();
- if (funcOp->getAttrOfType<UnitAttr>(
- LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
- if (newFuncOp.isVarArg())
- return funcOp->emitError("C interface for variadic functions is not "
- "supported yet.");
-
- if (newFuncOp.isExternal())
- wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
- funcOp, newFuncOp);
- else
- wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
- funcOp, newFuncOp);
- }
-
- rewriter.eraseOp(funcOp);
- return success();
- }
-};
-
-/// FuncOp legalization pattern that converts MemRef arguments to bare pointers
-/// to the MemRef element type. This will impact the calling convention and ABI.
-struct BarePtrFuncOpConversion : public FuncOpConversionBase {
- using FuncOpConversionBase::FuncOpConversionBase;
-
- LogicalResult
- matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
-
- // TODO: bare ptr conversion could be handled by argument materialization
- // and most of the code below would go away. But to do this, we would need a
- // way to distinguish between FuncOp and other regions in the
- // addArgumentMaterialization hook.
+ if (!shouldUseBarePtrCallConv(funcOp, this->getTypeConverter())) {
+ if (funcOp->getAttrOfType<UnitAttr>(
+ LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
+ if (newFuncOp.isVarArg())
+ return funcOp->emitError("C interface for variadic functions is not "
+ "supported yet.");
- // Store the type of memref-typed arguments before the conversion so that we
- // can promote them to MemRef descriptor at the beginning of the function.
- SmallVector<Type, 8> oldArgTypes =
- llvm::to_vector<8>(funcOp.getFunctionType().getInputs());
-
- auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
- if (!newFuncOp)
- return failure();
- if (newFuncOp.getBody().empty()) {
- rewriter.eraseOp(funcOp);
- return success();
- }
-
- // Promote bare pointers from memref arguments to memref descriptors at the
- // beginning of the function so that all the memrefs in the function have a
- // uniform representation.
- Block *entryBlock = &newFuncOp.getBody().front();
- auto blockArgs = entryBlock->getArguments();
- assert(blockArgs.size() == oldArgTypes.size() &&
- "The number of arguments and types doesn't match");
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(entryBlock);
- for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
- BlockArgument arg = std::get<0>(it);
- Type argTy = std::get<1>(it);
-
- // Unranked memrefs are not supported in the bare pointer calling
- // convention. We should have bailed out before in the presence of
- // unranked memrefs.
- assert(!argTy.isa<UnrankedMemRefType>() &&
- "Unranked memref is not supported");
- auto memrefTy = argTy.dyn_cast<MemRefType>();
- if (!memrefTy)
- continue;
-
- // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
- // or unranked memref descriptor and replace placeholder with the last
- // instruction of the memref descriptor.
- // TODO: The placeholder is needed to avoid replacing barePtr uses in the
- // MemRef descriptor instructions. We may want to have a utility in the
- // rewriter to properly handle this use case.
- Location loc = funcOp.getLoc();
- auto placeholder = rewriter.create<LLVM::UndefOp>(
- loc, getTypeConverter()->convertType(memrefTy));
- rewriter.replaceUsesOfBlockArgument(arg, placeholder);
-
- Value desc = MemRefDescriptor::fromStaticShape(
- rewriter, loc, *getTypeConverter(), memrefTy, arg);
- rewriter.replaceOp(placeholder, {desc});
+ if (newFuncOp.isExternal())
+ wrapExternalFunction(rewriter, funcOp.getLoc(), *getTypeConverter(),
+ funcOp, newFuncOp);
+ else
+ wrapForExternalCallers(rewriter, funcOp.getLoc(), *getTypeConverter(),
+ funcOp, newFuncOp);
+ }
+ } else {
+ modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp.getLoc(),
+ *getTypeConverter(), newFuncOp,
+ funcOp.getFunctionType().getInputs());
}
rewriter.eraseOp(funcOp);
@@ -535,23 +527,24 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;
- LogicalResult
- matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ LogicalResult matchAndRewriteImpl(CallOpType callOp,
+ typename CallOpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter,
+ bool useBarePtrCallConv = false) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
if (numResults != 0) {
- if (!(packedResult =
- this->getTypeConverter()->packFunctionResults(resultTypes)))
+ if (!(packedResult = this->getTypeConverter()->packFunctionResults(
+ resultTypes, useBarePtrCallConv)))
return failure();
}
auto promoted = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
- adaptor.getOperands(), rewriter);
+ adaptor.getOperands(), rewriter, useBarePtrCallConv);
auto newOp = rewriter.create<LLVM::CallOp>(
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promoted, callOp->getAttrs());
@@ -570,7 +563,7 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
}
}
- if (this->getTypeConverter()->getOptions().useBarePtrCallConv) {
+ if (useBarePtrCallConv) {
// For the bare-ptr calling convention, promote memref results to
// descriptors.
assert(results.size() == resultTypes.size() &&
@@ -590,11 +583,28 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
using Super::Super;
+ /// Choose the calling convention from the resolved callee, if there is one.
+ LogicalResult
+ matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ bool useBarePtrCallConv = false; // NOTE(review): stays false if the callee cannot be resolved, even with the global option set, while packFunctionResults/promoteOperands still honor that option - confirm.
+ if (Operation *callee = SymbolTable::lookupNearestSymbolFrom(
+ callOp, callOp.getCalleeAttr())) {
+ useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter());
+ }
+ return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
+ }
};
struct CallIndirectOpLowering
: public CallOpInterfaceLowering<func::CallIndirectOp> {
using Super::Super;
+ /// Indirect calls have no callee symbol, so no `llvm.bareptr` attribute is consulted.
+ LogicalResult
+ matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter); // NOTE(review): relies on the `false` default, so result promotion ignores the global option although result packing honors it - verify.
+ }
};
struct UnrealizedConversionCastOpLowering
@@ -640,7 +650,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;
- if (getTypeConverter()->getOptions().useBarePtrCallConv) {
+ auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ bool useBarePtrCallConv =
+ shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
+ if (useBarePtrCallConv) {
// For the bare-ptr calling convention, extract the aligned pointer to
// be returned from the memref descriptor.
for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
@@ -649,7 +662,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
if (oldTy.isa<MemRefType>() && getTypeConverter()->canConvertToBarePtr(
oldTy.cast<BaseMemRefType>())) {
MemRefDescriptor memrefDesc(newOperand);
- newOperand = memrefDesc.alignedPtr(rewriter, loc);
+ newOperand = memrefDesc.allocatedPtr(rewriter, loc);
} else if (oldTy.isa<UnrankedMemRefType>()) {
// Unranked memref is not supported in the bare pointer calling
// convention.
@@ -673,8 +686,8 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
// Otherwise, we need to pack the arguments into an LLVM struct type before
// returning.
- auto packedType =
- getTypeConverter()->packFunctionResults(op.getOperandTypes());
+ auto packedType = getTypeConverter()->packFunctionResults(
+ op.getOperandTypes(), useBarePtrCallConv);
if (!packedType) {
return rewriter.notifyMatchFailure(op, "could not convert result types");
}
@@ -692,10 +705,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
void mlir::populateFuncToLLVMFuncOpConversionPattern(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
- if (converter.getOptions().useBarePtrCallConv)
- patterns.add<BarePtrFuncOpConversion>(converter);
- else
- patterns.add<FuncOpConversion>(converter);
+ patterns.add<FuncOpConversion>(converter); // Handles both calling conventions.
}
void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index ec0d240040d1e..82c73b5f4dd2e 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -47,7 +47,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
Type funcType = getTypeConverter()->convertFunctionSignature(
- gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
+ gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
+ getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
// Create the new function operation. Only copy those attributes that are
// not specific to function modeling.
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e24be1dfdf6b9..833ea36ecf7bd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -209,8 +209,8 @@ Type LLVMTypeConverter::convertComplexType(ComplexType type) {
// pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
SignatureConversion conversion(type.getNumInputs());
- Type converted =
- convertFunctionSignature(type, /*isVariadic=*/false, conversion);
+ Type converted = convertFunctionSignature(
+ type, /*isVariadic=*/false, options.useBarePtrCallConv, conversion);
if (!converted)
return {};
return getPointerType(converted);
@@ -221,12 +221,12 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
// Function has one VoidType result. If MLIR Function has more than one result,
// they are into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionSignature(
- FunctionType funcTy, bool isVariadic,
+ FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
LLVMTypeConverter::SignatureConversion &result) {
// Select the argument converter depending on the calling convention.
- auto funcArgConverter = options.useBarePtrCallConv
- ? barePtrFuncArgTypeConverter
- : structFuncArgTypeConverter;
+ useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
+ auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
+ : structFuncArgTypeConverter;
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
@@ -238,9 +238,10 @@ Type LLVMTypeConverter::convertFunctionSignature(
// If function does not return anything, create the void result type,
// if it returns on element, convert it, otherwise pack the result types into
// a struct.
- Type resultType = funcTy.getNumResults() == 0
- ? LLVM::LLVMVoidType::get(&getContext())
- : packFunctionResults(funcTy.getResults());
+ Type resultType =
+ funcTy.getNumResults() == 0
+ ? LLVM::LLVMVoidType::get(&getContext())
+ : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
if (!resultType)
return {};
return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
@@ -472,8 +473,9 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(Type type) {
- if (options.useBarePtrCallConv)
+Type LLVMTypeConverter::convertCallingConventionType(Type type,
+ bool useBarePtrCallConv) {
+ if (useBarePtrCallConv)
if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
return convertMemRefToBarePtr(memrefTy);
@@ -498,16 +500,18 @@ void LLVMTypeConverter::promoteBarePtrsToDescriptors(
/// supported LLVM IR type. In particular, if more than one value is returned,
/// create an LLVM IR structure type with elements that correspond to each of
/// the MLIR types converted with `convertType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types) {
+Type LLVMTypeConverter::packFunctionResults(TypeRange types,
+ bool useBarePtrCallConv) {
assert(!types.empty() && "expected non-empty list of type");
+ useBarePtrCallConv |= options.useBarePtrCallConv;
if (types.size() == 1)
- return convertCallingConventionType(types.front());
+ return convertCallingConventionType(types.front(), useBarePtrCallConv);
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(types.size());
for (auto t : types) {
- auto converted = convertCallingConventionType(t);
+ auto converted = convertCallingConventionType(t, useBarePtrCallConv);
if (!converted || !LLVM::isCompatibleType(converted))
return {};
resultTypes.push_back(converted);
@@ -530,17 +534,18 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
return allocated;
}
-SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(Location loc,
- ValueRange opOperands,
- ValueRange operands,
- OpBuilder &builder) {
+SmallVector<Value, 4>
+LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
+ ValueRange operands, OpBuilder &builder,
+ bool useBarePtrCallConv) {
SmallVector<Value, 4> promotedOperands;
promotedOperands.reserve(operands.size());
+ useBarePtrCallConv |= options.useBarePtrCallConv;
for (auto it : llvm::zip(opOperands, operands)) {
auto operand = std::get<0>(it);
auto llvmOperand = std::get<1>(it);
- if (options.useBarePtrCallConv) {
+ if (useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
@@ -603,7 +608,8 @@ LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result) {
- auto llvmTy = converter.convertCallingConventionType(type);
+ auto llvmTy =
+ converter.convertCallingConventionType(type, /*useBarePtrCallConv=*/true);
if (!llvmTy)
return failure();
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 2cdce91806068..b93894757daa5 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -338,7 +338,8 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
auto dstType = typeConverter.convertType(op.getPointer().getType());
if (!dstType)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable());
+ rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
+ op.getVariable());
return success();
}
};
@@ -582,7 +583,8 @@ class CompositeExtractPattern
}
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
- op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices()));
+ op, adaptor.getComposite(),
+ LLVM::convertArrayToIndices(op.getIndices()));
return success();
}
};
@@ -1146,7 +1148,8 @@ class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
Block *falseBlock = condBrOp.getFalseBlock();
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
- condBrOp.getTrueTargetOperands(), falseBlock,
+ condBrOp.getTrueTargetOperands(),
+ falseBlock,
condBrOp.getFalseTargetOperands());
rewriter.inlineRegionBefore(op.getBody(), continueBlock);
@@ -1329,7 +1332,8 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
TypeConverter::SignatureConversion signatureConverter(
funcType.getNumInputs());
auto llvmType = typeConverter.convertFunctionSignature(
- funcType, /*isVariadic=*/false, signatureConverter);
+ funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
+ signatureConverter);
if (!llvmType)
return failure();
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
index daa824d84ba74..b1c065e0f1f8d 100644
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -242,3 +242,67 @@ func.func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memr
// CHECK-LABEL: @_mlir_ciface_return_two_var_memref
// CHECK-SAME: (%{{.*}}: !llvm.ptr,
// CHECK-SAME: %{{.*}}: !llvm.ptr)
+
+// CHECK-LABEL: llvm.func @bare_ptr_calling_conv(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME: -> !llvm.ptr
+func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 : index, %arg3 : f32)
+ -> (memref<4x3xf32>) attributes { llvm.bareptr } {
+ // CHECK: %[[UNDEF_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[INSERT_ALLOCPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[UNDEF_DESC]][0]
+ // CHECK: %[[INSERT_ALIGNEDPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_ALLOCPTR]][1]
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[INSERT_OFFSET:.*]] = llvm.insertvalue %[[C0]], %[[INSERT_ALIGNEDPTR]][2]
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+ // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertvalue %[[C4]], %[[INSERT_OFFSET]][3, 0]
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+ // CHECK: %[[INSERT_STRIDE0:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_DIM0]][4, 0]
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+ // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_STRIDE0]][3, 1]
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1]
+ // Uses of %arg0 read the aligned pointer of the rebuilt descriptor.
+ // CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
+ // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]]
+ // CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
+ memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
+ // The memref result is returned as the original bare pointer.
+ // CHECK: llvm.return %[[ARG0]]
+ return %arg0 : memref<4x3xf32>
+}
+
+// CHECK-LABEL: llvm.func @bare_ptr_calling_conv_multiresult(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr
+// CHECK-SAME: -> !llvm.struct<(f32, ptr)>
+func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 : index, %arg3 : f32)
+ -> (f32, memref<4x3xf32>) attributes { llvm.bareptr } {
+ // CHECK: %[[UNDEF_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[INSERT_ALLOCPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[UNDEF_DESC]][0]
+ // CHECK: %[[INSERT_ALIGNEDPTR:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_ALLOCPTR]][1]
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[INSERT_OFFSET:.*]] = llvm.insertvalue %[[C0]], %[[INSERT_ALIGNEDPTR]][2]
+ // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : index) : i64
+ // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertvalue %[[C4]], %[[INSERT_OFFSET]][3, 0]
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+ // CHECK: %[[INSERT_STRIDE0:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_DIM0]][4, 0]
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64
+ // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertvalue %[[C3]], %[[INSERT_STRIDE0]][3, 1]
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: %[[INSERT_STRIDE1:.*]] = llvm.insertvalue %[[C1]], %[[INSERT_DIM1]][4, 1]
+ // Uses of %arg0 read the aligned pointer of the rebuilt descriptor.
+ // CHECK: %[[ALIGNEDPTR:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
+ // CHECK: %[[STOREPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR]]
+ // CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
+ memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
+ // The f32 result is loaded through the same rebuilt descriptor.
+ // CHECK: %[[ALIGNEDPTR0:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][1]
+ // CHECK: %[[LOADPTR:.*]] = llvm.getelementptr %[[ALIGNEDPTR0]]
+ // CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]]
+ %0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32>
+ // Both results are packed into a struct; the memref slot holds the bare pointer.
+ // CHECK: %[[RETURN_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(f32, ptr)>
+ // CHECK: %[[INSERT_RETURN0:.*]] = llvm.insertvalue %[[RETURN0]], %[[RETURN_DESC]][0]
+ // CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_RETURN0]][1]
+ // CHECK: llvm.return %[[INSERT_RETURN1]]
+ return %0, %arg0 : f32, memref<4x3xf32>
+}
diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
index 8663ce8cbbf2f..956c298123db2 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir
@@ -27,7 +27,7 @@ func.func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32>
// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : i64
// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr
return %static : memref<32x18xf32>
}
@@ -56,7 +56,7 @@ func.func @check_static_return_with_offset(%static : memref<32x18xf32, strided<[
// BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : i64
// BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr
return %static : memref<32x18xf32, strided<[22,1], offset: 7>>
}
@@ -82,7 +82,7 @@ func.func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> {
// BAREPTR-NEXT: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64
// BAREPTR-NEXT: %[[outDesc:.*]] = llvm.insertvalue %[[c1]], %[[desc6]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
%res = call @foo(%in) : (memref<10xi8>) -> (memref<20xi8>)
- // BAREPTR-NEXT: %[[res:.*]] = llvm.extractvalue %[[outDesc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // BAREPTR-NEXT: %[[res:.*]] = llvm.extractvalue %[[outDesc]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
// BAREPTR-NEXT: llvm.return %[[res]] : !llvm.ptr
return %res : memref<20xi8>
}
More information about the Mlir-commits mailing list