[Mlir-commits] [mlir] [mlir][LLVM] Improve lowering of `llvm.byval` function arguments (PR #100028)
Diego Caballero
llvmlistbot at llvm.org
Fri Jul 26 16:39:16 PDT 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/100028
>From 51d558e655ce2b9440134effe195137ac9ccbfd3 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 26 Jul 2024 16:22:13 -0700
Subject: [PATCH] [mlir][LLVM] Improve lowering of llvm.byval function
arguments
When a function argument is annotated with the `llvm.byval` attribute,
[LLVM expects](https://llvm.org/docs/LangRef.html#parameter-attributes)
the function argument type to be an `llvm.ptr`. For example:
```
func.func @foo(%arg0 : !llvm.ptr {llvm.byval = !llvm.struct<(i32)>}) {
...
}
```
Unfortunately, this makes the type conversion context-dependent, which is
something that the type conversion infrastructure (i.e., `LLVMTypeConverter`
in this particular case) doesn't support. For example, we may want to convert
`MyType` to `llvm.struct<(i32)>` in general, but to an `llvm.ptr` type only
when it's a function argument passed by value.
To fix this problem, this PR changes the FuncToLLVM conversion logic to always
generate an `llvm.ptr` when the function argument has an `llvm.byval` or
`llvm.byref` attribute. An `llvm.load` is inserted at the start of the function
body to retrieve the value expected by the argument's users.
---
.../Conversion/LLVMCommon/TypeConverter.h | 29 +++++---
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 39 +++++++++-
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 6 +-
.../Conversion/LLVMCommon/TypeConverter.cpp | 72 +++++++++++++++----
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 4 +-
.../test/Transforms/test-convert-func-op.mlir | 30 +++++++-
.../FuncToLLVM/TestConvertFuncOp.cpp | 16 ++++-
7 files changed, 166 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..007e6ba39b632 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,6 +21,7 @@
namespace mlir {
class DataLayoutAnalysis;
+class FunctionOpInterface;
class LowerToLLVMOptions;
namespace LLVM {
@@ -35,6 +36,7 @@ class LLVMTypeConverter : public TypeConverter {
/// Give structFuncArgTypeConverter access to memref-specific functions.
friend LogicalResult
structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
SmallVectorImpl<Type> &result);
public:
@@ -53,9 +55,10 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
- Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
- bool useBarePtrCallConv,
- SignatureConversion &result) const;
+ Type convertFunctionSignature(
+ FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+ ArrayRef<std::optional<NamedAttribute>> byValByRefArgAttr,
+ SignatureConversion &result) const;
/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
@@ -242,15 +245,23 @@ class LLVMTypeConverter : public TypeConverter {
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
-LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter,
- Type type,
- SmallVectorImpl<Type> &result);
+///
+/// Arguments that are not MemRef descriptors and whose `byValByRefArgAttr`
+/// holds an `llvm.byval` or `llvm.byref` attribute are converted to an
+/// `!llvm.ptr`, since LLVM requires such arguments to be pointers.
+LogicalResult
+structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
+ SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
-LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
- Type type,
- SmallVectorImpl<Type> &result);
+///
+/// If `byValByRefArgAttr` holds an `llvm.byval` or `llvm.byref` attribute,
+/// the argument is converted to an `!llvm.ptr` instead, regardless of `type`.
+LogicalResult
+barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
+ SmallVectorImpl<Type> &result);
+
+/// Populates `result`, which must be empty on entry, with one entry per
+/// argument of `funcOp`. Each entry is that argument's `llvm.byval` or
+/// `llvm.byref` attribute if present, or `std::nullopt` otherwise, so `result`
+/// can be indexed by argument position.
+void filterByValByRefArgAttributes(
+ FunctionOpInterface funcOp,
+ SmallVectorImpl<std::optional<NamedAttribute>> &result);
} // namespace mlir
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c1f6d8bc5b361..8d4645d46f069 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,6 +267,36 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}
+/// Inserts `llvm.load` ops in the function body to restore the expected pointee
+/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
+/// to LLVM pointer types.
+///
+/// `funcOp` is expected to be the already-converted function: annotated
+/// arguments must have LLVM pointer type, and their `llvm.byval`/`llvm.byref`
+/// attributes must carry the (converted) pointee type to load.
+static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
+                                          FunctionOpInterface funcOp) {
+  // Nothing to do for function declarations.
+  if (funcOp.isExternal())
+    return;
+
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
+
+  SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
+  filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
+  for (const auto &[arg, byValRefAttr] :
+       llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
+    // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
+    if (!byValRefAttr)
+      continue;
+
+    // Insert load to retrieve the actual argument passed by value/reference.
+    assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
+           "Expected LLVM pointer type for argument with "
+           "`llvm.byval`/`llvm.byref` attribute");
+    Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
+    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+    // During a dialect conversion, block-argument replacements are tracked in
+    // the conversion value mapping rather than applied eagerly. An eager RAUW
+    // on `arg` is not seen by patterns that run later and read their operands
+    // through the adaptor (e.g. the `func.return` lowering).
+    // `replaceUsesOfBlockArgument` records the replacement in that mapping.
+    rewriter.replaceUsesOfBlockArgument(arg, valueArg);
+  }
+}
+
FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
@@ -280,10 +310,12 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
+ SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+ filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = converter.convertFunctionSignature(
funcTy, varargsAttr && varargsAttr.getValue(),
- shouldUseBarePtrCallConv(funcOp, &converter), result);
+ shouldUseBarePtrCallConv(funcOp, &converter), byValByRefArgAttrs, result);
if (!llvmType)
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
@@ -398,6 +430,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
"region types conversion failed");
}
+ // Fix the type mismatch between the generated `llvm.ptr` and the expected
+ // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
+ // function arguments.
+ restoreByValByRefArgumentType(rewriter, newFuncOp);
+
if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6053e34f30a41..143f7b3071253 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -53,10 +53,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
-
+ SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+ filterByValByRefArgAttributes(gpuFuncOp, byValByRefArgAttrs);
Type funcType = getTypeConverter()->convertFunctionSignature(
gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
- getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
+ getTypeConverter()->getOptions().useBarePtrCallConv, byValByRefArgAttrs,
+ signatureConversion);
if (!funcType) {
return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
diag << "failed to convert function signature type for: "
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index d5df960928afb..38de4f2fdeac1 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -276,6 +276,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
// they are into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionSignature(
FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+ ArrayRef<std::optional<NamedAttribute>> byValByRefArgAtts,
LLVMTypeConverter::SignatureConversion &result) const {
// Select the argument converter depending on the calling convention.
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
@@ -284,7 +285,8 @@ Type LLVMTypeConverter::convertFunctionSignature(
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
- if (failed(funcArgConverter(*this, type, converted)))
+ if (failed(
+ funcArgConverter(*this, type, byValByRefArgAtts[idx], converted)))
return {};
result.addInputs(idx, converted);
}
@@ -659,9 +661,10 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
-LogicalResult
-mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
- SmallVectorImpl<Type> &result) {
+LogicalResult mlir::structFuncArgTypeConverter(
+ const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
+ SmallVectorImpl<Type> &result) {
if (auto memref = dyn_cast<MemRefType>(type)) {
// In signatures, Memref descriptors are expanded into lists of
// non-aggregate values.
@@ -679,23 +682,64 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
result.append(converted.begin(), converted.end());
return success();
}
- auto converted = converter.convertType(type);
- if (!converted)
- return failure();
+
+ /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
+ /// converted type is an LLVM pointer so that the LLVM argument passing
+ /// is correct.
+ Type converted;
+ if (byValByRefArgAttr.has_value() &&
+ (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
+ byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+ converted = LLVM::LLVMPointerType::get(type.getContext());
+ } else {
+ converted = converter.convertType(type);
+ if (!converted)
+ return failure();
+ }
+
result.push_back(converted);
return success();
}
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
-LogicalResult
-mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
- SmallVectorImpl<Type> &result) {
- auto llvmTy = converter.convertCallingConventionType(
- type, /*useBarePointerCallConv=*/true);
- if (!llvmTy)
- return failure();
+///
+/// Arguments annotated with `llvm.byval` or `llvm.byref` (as reported by
+/// `byValByRefArgAttr`) are lowered to an opaque `!llvm.ptr`, as required by
+/// the LLVM argument-passing conventions for those attributes.
+LogicalResult mlir::barePtrFuncArgTypeConverter(
+    const LLVMTypeConverter &converter, Type type,
+    std::optional<NamedAttribute> byValByRefArgAttr,
+    SmallVectorImpl<Type> &result) {
+  if (byValByRefArgAttr) {
+    StringAttr attrName = byValByRefArgAttr->getName();
+    if (attrName == LLVM::LLVMDialect::getByValAttrName() ||
+        attrName == LLVM::LLVMDialect::getByRefAttrName()) {
+      // The pointee type lives in the attribute; the argument is a pointer.
+      result.push_back(LLVM::LLVMPointerType::get(type.getContext()));
+      return success();
+    }
+  }
+
+  Type llvmTy = converter.convertCallingConventionType(
+      type, /*useBarePointerCallConv=*/true);
+  if (!llvmTy)
+    return failure();
result.push_back(llvmTy);
return success();
}
+
+/// Collects, for every argument of `funcOp`, its `llvm.byval` or `llvm.byref`
+/// attribute (or `std::nullopt` when it has neither) into `result`.
+void mlir::filterByValByRefArgAttributes(
+    FunctionOpInterface funcOp,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+  assert(result.empty() && "Unexpected non-empty output");
+  // Produce exactly one entry per argument so that callers can index `result`
+  // by argument position or zip it with the argument list. Pre-sizing the
+  // vector (instead of appending inside the loop) guarantees that invariant;
+  // appending a trailing `std::nullopt` after a match would add a second entry
+  // for annotated arguments and shift every later argument's slot.
+  result.resize(funcOp.getNumArguments(), std::nullopt);
+  for (unsigned argIdx : llvm::seq(funcOp.getNumArguments())) {
+    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+      if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
+          namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
+        // Record the first `llvm.byval`/`llvm.byref` attribute found.
+        result[argIdx] = namedAttr;
+        break;
+      }
+    }
+  }
+}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index da09384bfbe89..d5a9bc3783660 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1408,11 +1408,13 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
// Convert function signature. At the moment LLVMType converter is enough
// for currently supported types.
auto funcType = funcOp.getFunctionType();
+ SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+ filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
TypeConverter::SignatureConversion signatureConverter(
funcType.getNumInputs());
auto llvmType = typeConverter.convertFunctionSignature(
funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
- signatureConverter);
+ byValByRefArgAttrs, signatureConverter);
if (!llvmType)
return failure();
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 6e96703cda578..9f006a5a187c1 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op -split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,31 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
// CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
// CHECK-NEXT: llvm.return [[RES]]
+
+// -----
+
+// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byval
+func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byref
+func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that every annotated argument is converted when several arguments
+// carry `llvm.byval`, i.e. the per-argument attribute list stays aligned with
+// the argument list.
+
+// CHECK-LABEL: llvm.func @multiple_byval_args
+func.func @multiple_byval_args(%arg0: !test.smpla {llvm.byval = !test.smpla}, %arg1: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+  return %arg1 : !test.smpla
+}
+
+// CHECK-SAME: (%{{.*}}: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}, %[[ARG1:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index e25e890e2290a..75168dde93130 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
LogicalResult
matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
- returnOp->getOperands());
+    // Return the adaptor operands: they are the values produced by the type
+    // conversion (e.g. the `llvm.load` that stands in for an `llvm.byval`
+    // pointer argument), whereas `returnOp->getOperands()` still refers to
+    // the unconverted values. `func.return` itself has no results, so there
+    // are no result types to convert.
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
+                                                adaptor.getOperands());
return success();
}
};
+/// Type conversion registered by this test pass: lowers `!test.smpla` to a
+/// literal LLVM struct made of two `i8` members.
+static std::optional<Type>
+convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
+  MLIRContext *context = simpleTy.getContext();
+  Type i8Type = IntegerType::get(context, /*width=*/8);
+  return LLVM::LLVMStructType::getLiteral(context, {i8Type, i8Type});
+}
+
struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
LowerToLLVMOptions options(ctx);
// Populate type conversions.
LLVMTypeConverter typeConverter(ctx, options);
+ typeConverter.addConversion(convertSimpleATypeToStruct);
RewritePatternSet patterns(ctx);
patterns.add<FuncOpConversion>(typeConverter);
More information about the Mlir-commits
mailing list