[Mlir-commits] [mlir] 2ac2e9a - [mlir][LLVM] Improve lowering of `llvm.byval` function arguments (#100028)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 8 19:27:58 PDT 2024
Author: Diego Caballero
Date: 2024-08-08T19:27:54-07:00
New Revision: 2ac2e9a5b6c97cbf267db1ef322ed21ebceb2aba
URL: https://github.com/llvm/llvm-project/commit/2ac2e9a5b6c97cbf267db1ef322ed21ebceb2aba
DIFF: https://github.com/llvm/llvm-project/commit/2ac2e9a5b6c97cbf267db1ef322ed21ebceb2aba.diff
LOG: [mlir][LLVM] Improve lowering of `llvm.byval` function arguments (#100028)
When a function argument is annotated with the `llvm.byval` attribute,
[LLVM expects](https://llvm.org/docs/LangRef.html#parameter-attributes)
the function argument type to be an `llvm.ptr`. For example:
```
func.func (%args0 : llvm.ptr {llvm.byval = !llvm.struct<(i32)>}) {
...
}
```
Unfortunately, this makes the type conversion context-dependent, which
is something that the type conversion infrastructure (i.e.,
`LLVMTypeConverter` in this particular case) doesn't support. For
example, we may want to convert `MyType` to `llvm.struct<(i32)>` in
general, but to an `llvm.ptr` type only when it's a function argument
passed by value.
To fix this problem, this PR changes the FuncToLLVM conversion logic to
generate an `llvm.ptr` when the function argument has a `llvm.byval`
attribute. An `llvm.load` is inserted into the function to retrieve the
value expected by the argument users.
Added:
Modified:
mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/test/Transforms/test-convert-func-op.mlir
mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..d79b90f840ce8 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,6 +21,7 @@
namespace mlir {
class DataLayoutAnalysis;
+class FunctionOpInterface;
class LowerToLLVMOptions;
namespace LLVM {
@@ -50,13 +51,25 @@ class LLVMTypeConverter : public TypeConverter {
LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
const DataLayoutAnalysis *analysis = nullptr);
- /// Convert a function type. The arguments and results are converted one by
+ /// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
bool useBarePtrCallConv,
SignatureConversion &result) const;
+ /// Convert a function type. The arguments and results are converted one by
+ /// one and results are packed into a wrapped LLVM IR structure type. `result`
+ /// is populated with argument mapping. Converted types of `llvm.byval` and
+ /// `llvm.byref` function arguments which are not LLVM pointers are overridden
+ /// with LLVM pointers. Overridden arguments are returned in
+ /// `byValRefNonPtrAttrs`.
+ Type convertFunctionSignature(FunctionOpInterface funcOp, bool isVariadic,
+ bool useBarePtrCallConv,
+ LLVMTypeConverter::SignatureConversion &result,
+ SmallVectorImpl<std::optional<NamedAttribute>>
+ &byValRefNonPtrAttrs) const;
+
/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to
@@ -159,12 +172,26 @@ class LLVMTypeConverter : public TypeConverter {
SmallVector<Type> &getCurrentThreadRecursiveStack();
private:
- /// Convert a function type. The arguments and results are converted one by
- /// one. Additionally, if the function returns more than one value, pack the
+ /// Convert a function type. The arguments and results are converted one by
+ /// one. Additionally, if the function returns more than one value, pack the
/// results into an LLVM IR structure type so that the converted function type
/// returns at most one result.
Type convertFunctionType(FunctionType type) const;
+ /// Common implementation for `convertFunctionSignature` methods. Convert a
+ /// function type. The arguments and results are converted one by one and
+ /// results are packed into a wrapped LLVM IR structure type. `result` is
+ /// populated with argument mapping. If `byValRefNonPtrAttrs` is provided,
+ /// converted types of `llvm.byval` and `llvm.byref` function arguments which
+ /// are not LLVM pointers are overridden with LLVM pointers. `llvm.byval` and
+ /// `llvm.byref` arguments that were already converted to LLVM pointer types
+ /// are removed from `byValRefNonPtrAttrs`.
+ Type convertFunctionSignatureImpl(
+ FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+ LLVMTypeConverter::SignatureConversion &result,
+ SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs)
+ const;
+
/// Convert the index type. Uses llvmModule data layout to create an integer
/// of the pointer bitwidth.
Type convertIndexType(IndexType type) const;
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c1f6d8bc5b361..4c2e8682285c5 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,6 +267,38 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}
+/// Inserts `llvm.load` ops in the function body to restore the expected pointee
+/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
+/// to LLVM pointer types.
+static void restoreByValRefArgumentType(
+ ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
+ ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
+ LLVM::LLVMFuncOp funcOp) {
+ // Nothing to do for function declarations.
+ if (funcOp.isExternal())
+ return;
+
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
+
+ for (const auto &[arg, byValRefAttr] :
+ llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) {
+ // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
+ if (!byValRefAttr)
+ continue;
+
+ // Insert load to retrieve the actual argument passed by value/reference.
+ assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
+ "Expected LLVM pointer type for argument with "
+ "`llvm.byval`/`llvm.byref` attribute");
+ Type resTy = typeConverter.convertType(
+ cast<TypeAttr>(byValRefAttr->getValue()).getValue());
+
+ auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+ rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
+ }
+}
+
FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
@@ -280,10 +312,14 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
+ // Gather `llvm.byval` and `llvm.byref` arguments whose type conversion was
+ // overridden with an LLVM pointer type for later processing.
+ SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = converter.convertFunctionSignature(
- funcTy, varargsAttr && varargsAttr.getValue(),
- shouldUseBarePtrCallConv(funcOp, &converter), result);
+ funcOp, varargsAttr && varargsAttr.getValue(),
+ shouldUseBarePtrCallConv(funcOp, &converter), result,
+ byValRefNonPtrAttrs);
if (!llvmType)
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
@@ -398,6 +434,12 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
"region types conversion failed");
}
+ // Fix the type mismatch between the materialized `llvm.ptr` and the expected
+ // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
+ // function arguments.
+ restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
+ newFuncOp);
+
if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 17be4d91ee054..5313a64ed47e3 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -270,13 +270,42 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
return LLVM::LLVMPointerType::get(type.getContext());
}
+/// Collects into `result` the `llvm.byval` or `llvm.byref` attribute of each
+/// function argument (`std::nullopt` for arguments that have neither). `result`
+/// is left empty if none of these attributes are found in any of the arguments.
+static void
+filterByValRefArgAttrs(FunctionOpInterface funcOp,
+ SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+ assert(result.empty() && "Unexpected non-empty output");
+ result.resize(funcOp.getNumArguments(), std::nullopt);
+ bool foundByValByRefAttrs = false;
+ for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+ for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+ if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
+ namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+ foundByValByRefAttrs = true;
+ result[argIdx] = namedAttr;
+ break;
+ }
+ }
+ }
+
+ if (!foundByValByRefAttrs)
+ result.clear();
+}
+
// Function types are converted to LLVM Function types by recursively converting
-// argument and result types. If MLIR Function has zero results, the LLVM
-// Function has one VoidType result. If MLIR Function has more than one result,
+// argument and result types. If MLIR Function has zero results, the LLVM
+// Function has one VoidType result. If MLIR Function has more than one result,
// they are packed into an LLVM StructType in their order of appearance.
-Type LLVMTypeConverter::convertFunctionSignature(
+// If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
+// `llvm.byref` function arguments which are not LLVM pointers are overridden
+// with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
+// converted to LLVM pointer types are removed from `byValRefNonPtrAttrs`.
+Type LLVMTypeConverter::convertFunctionSignatureImpl(
FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
- LLVMTypeConverter::SignatureConversion &result) const {
+ LLVMTypeConverter::SignatureConversion &result,
+ SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
// Select the argument converter depending on the calling convention.
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
@@ -286,6 +315,19 @@ Type LLVMTypeConverter::convertFunctionSignature(
SmallVector<Type, 8> converted;
if (failed(funcArgConverter(*this, type, converted)))
return {};
+
+ // Rewrite converted type of `llvm.byval` or `llvm.byref` function
+ // argument that was not converted to an LLVM pointer type.
+ if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
+ converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
+ // If the argument was already converted to an LLVM pointer type, we stop
+ // tracking it as it doesn't need more processing.
+ if (isa<LLVM::LLVMPointerType>(converted[0]))
+ (*byValRefNonPtrAttrs)[idx] = std::nullopt;
+ else
+ converted[0] = LLVM::LLVMPointerType::get(&getContext());
+ }
+
result.addInputs(idx, converted);
}
@@ -302,6 +344,27 @@ Type LLVMTypeConverter::convertFunctionSignature(
isVariadic);
}
+Type LLVMTypeConverter::convertFunctionSignature(
+ FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+ LLVMTypeConverter::SignatureConversion &result) const {
+ return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
+ result,
+ /*byValRefNonPtrAttrs=*/nullptr);
+}
+
+Type LLVMTypeConverter::convertFunctionSignature(
+ FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
+ LLVMTypeConverter::SignatureConversion &result,
+ SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
+ // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
+ // that were not converted to LLVM pointer types will be returned for further
+ // processing.
+ filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
+ auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
+ return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
+ result, &byValRefNonPtrAttrs);
+}
+
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 6e96703cda578..180f16a32991b 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,31 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
// CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
// CHECK-NEXT: llvm.return [[RES]]
+
+// -----
+
+// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byval
+func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+ return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byref
+func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
+ return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index e25e890e2290a..75168dde93130 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
LogicalResult
matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
- returnOp->getOperands());
+ SmallVector<Type> resTys;
+ if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
+ adaptor.getOperands());
return success();
}
};
+static std::optional<Type>
+convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
+ MLIRContext *ctx = simpleTy.getContext();
+ SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
+ return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
+}
+
struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
LowerToLLVMOptions options(ctx);
// Populate type conversions.
LLVMTypeConverter typeConverter(ctx, options);
+ typeConverter.addConversion(convertSimpleATypeToStruct);
RewritePatternSet patterns(ctx);
patterns.add<FuncOpConversion>(typeConverter);
More information about the Mlir-commits
mailing list