[Mlir-commits] [mlir] [mlir][LLVM] Improve lowering of `llvm.byval` function arguments (PR #100028)
Diego Caballero
llvmlistbot at llvm.org
Thu Aug 1 13:04:17 PDT 2024
https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/100028
>From eb7cec3712345e5260142886e1fa608ad57ca69b Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 26 Jul 2024 16:22:13 -0700
Subject: [PATCH 1/4] [mlir][LLVM] Improve lowering of llvm.byval function
arguments
When a function argument is annotated with the `llvm.byval` attribute,
[LLVM expects](https://llvm.org/docs/LangRef.html#parameter-attributes)
the function argument type to be an `!llvm.ptr`. For example:
```
func.func @foo(%arg0 : !llvm.ptr {llvm.byval = !llvm.struct<(i32)>}) {
...
}
```
Unfortunately, this makes the type conversion context-dependent, which is
something the type conversion infrastructure (i.e., `LLVMTypeConverter` in
this particular case) doesn't support. For example, we may want to convert
`MyType` to `!llvm.struct<(i32)>` in general, but to `!llvm.ptr` only when
it is a function argument passed by value.
To fix this problem, this PR changes the FuncToLLVM conversion logic to always
generate an `!llvm.ptr` when the function argument has an `llvm.byval` or
`llvm.byref` attribute. An `llvm.load` is then inserted at the start of the
function body to retrieve the value expected by the argument's users.
---
.../Conversion/LLVMCommon/TypeConverter.h | 29 +++++---
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 39 +++++++++-
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 6 +-
.../Conversion/LLVMCommon/TypeConverter.cpp | 71 +++++++++++++++----
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 4 +-
.../test/Transforms/test-convert-func-op.mlir | 30 +++++++-
.../FuncToLLVM/TestConvertFuncOp.cpp | 16 ++++-
7 files changed, 165 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..007e6ba39b632 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,6 +21,7 @@
namespace mlir {
class DataLayoutAnalysis;
+class FunctionOpInterface;
class LowerToLLVMOptions;
namespace LLVM {
@@ -35,6 +36,7 @@ class LLVMTypeConverter : public TypeConverter {
/// Give structFuncArgTypeConverter access to memref-specific functions.
friend LogicalResult
structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
SmallVectorImpl<Type> &result);
public:
@@ -53,9 +55,10 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
- Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
- bool useBarePtrCallConv,
- SignatureConversion &result) const;
+ Type convertFunctionSignature(
+ FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+ ArrayRef<std::optional<NamedAttribute>> byValByRefArgAttr,
+ SignatureConversion &result) const;
/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
@@ -242,15 +245,23 @@ class LLVMTypeConverter : public TypeConverter {
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
-LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter,
- Type type,
- SmallVectorImpl<Type> &result);
+LogicalResult
+structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
+ SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
-LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
- Type type,
- SmallVectorImpl<Type> &result);
+LogicalResult
+barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
+ SmallVectorImpl<Type> &result);
+
+/// Returns in `result` the `llvm.byval` or `llvm.byref` attributes, if
+/// present, or an empty attribute for each function argument.
+void filterByValByRefArgAttributes(
+ FunctionOpInterface funcOp,
+ SmallVectorImpl<std::optional<NamedAttribute>> &result);
} // namespace mlir
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c1f6d8bc5b361..8d4645d46f069 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,6 +267,36 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}
+/// Inserts `llvm.load` ops in the function body to restore the expected pointee
+/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
+/// to LLVM pointer types.
+static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
+ FunctionOpInterface funcOp) {
+ // Nothing to do for function declarations.
+ if (funcOp.isExternal())
+ return;
+
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
+
+ SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
+ filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
+ for (const auto &[arg, byValRefAttr] :
+ llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
+ // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
+ if (!byValRefAttr)
+ continue;
+
+ // Insert load to retrieve the actual argument passed by value/reference.
+ assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
+ "Expected LLVM pointer type for argument with "
+ "`llvm.byval`/`llvm.byref` attribute");
+ Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
+ auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+ rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
+ }
+}
+
FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
@@ -280,10 +310,12 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
+ SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+ filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = converter.convertFunctionSignature(
funcTy, varargsAttr && varargsAttr.getValue(),
- shouldUseBarePtrCallConv(funcOp, &converter), result);
+ shouldUseBarePtrCallConv(funcOp, &converter), byValByRefArgAttrs, result);
if (!llvmType)
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
@@ -398,6 +430,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
"region types conversion failed");
}
+ // Fix the type mismatch between the generated `llvm.ptr` and the expected
+ // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
+ // function arguments.
+ restoreByValByRefArgumentType(rewriter, newFuncOp);
+
if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6053e34f30a41..143f7b3071253 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -53,10 +53,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
-
+ SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+ filterByValByRefArgAttributes(gpuFuncOp, byValByRefArgAttrs);
Type funcType = getTypeConverter()->convertFunctionSignature(
gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
- getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
+ getTypeConverter()->getOptions().useBarePtrCallConv, byValByRefArgAttrs,
+ signatureConversion);
if (!funcType) {
return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
diag << "failed to convert function signature type for: "
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index d5df960928afb..c62096fdc853d 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -276,6 +276,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
// they are into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionSignature(
FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+ ArrayRef<std::optional<NamedAttribute>> byValByRefArgAtts,
LLVMTypeConverter::SignatureConversion &result) const {
// Select the argument converter depending on the calling convention.
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
@@ -284,7 +285,8 @@ Type LLVMTypeConverter::convertFunctionSignature(
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
- if (failed(funcArgConverter(*this, type, converted)))
+ if (failed(
+ funcArgConverter(*this, type, byValByRefArgAtts[idx], converted)))
return {};
result.addInputs(idx, converted);
}
@@ -659,9 +661,10 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
-LogicalResult
-mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
- SmallVectorImpl<Type> &result) {
+LogicalResult mlir::structFuncArgTypeConverter(
+ const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
+ SmallVectorImpl<Type> &result) {
if (auto memref = dyn_cast<MemRefType>(type)) {
// In signatures, Memref descriptors are expanded into lists of
// non-aggregate values.
@@ -679,23 +682,63 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
result.append(converted.begin(), converted.end());
return success();
}
- auto converted = converter.convertType(type);
- if (!converted)
- return failure();
+
+ /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
+ /// converted type is an LLVM pointer so that the LLVM argument passing
+ /// is correct.
+ Type converted;
+ if (byValByRefArgAttr.has_value() &&
+ (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
+ byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+ converted = LLVM::LLVMPointerType::get(type.getContext());
+ } else {
+ converted = converter.convertType(type);
+ if (!converted)
+ return failure();
+ }
+
result.push_back(converted);
return success();
}
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
-LogicalResult
-mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
- SmallVectorImpl<Type> &result) {
- auto llvmTy = converter.convertCallingConventionType(
- type, /*useBarePointerCallConv=*/true);
- if (!llvmTy)
- return failure();
+LogicalResult mlir::barePtrFuncArgTypeConverter(
+ const LLVMTypeConverter &converter, Type type,
+ std::optional<NamedAttribute> byValByRefArgAttr,
+ SmallVectorImpl<Type> &result) {
+ /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
+ /// converted type is an LLVM pointer so that the LLVM argument passing
+ /// convention is correct.
+ Type llvmTy;
+ if (byValByRefArgAttr.has_value() &&
+ (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
+ byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+ llvmTy = LLVM::LLVMPointerType::get(type.getContext());
+ } else {
+ llvmTy = converter.convertCallingConventionType(
+ type, /*useBarePointerCallConv=*/true);
+
+ if (!llvmTy)
+ return failure();
+ }
result.push_back(llvmTy);
return success();
}
+
+void mlir::filterByValByRefArgAttributes(
+ FunctionOpInterface funcOp,
+ SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+ assert(result.empty() && "Unexpected non-empty output");
+ result.resize(funcOp.getNumArguments(), std::nullopt);
+ for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+ for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+ if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
+ namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
+ result[argIdx] = namedAttr;
+ break;
+ }
+ }
+ }
+}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index da09384bfbe89..d5a9bc3783660 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1408,11 +1408,13 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
// Convert function signature. At the moment LLVMType converter is enough
// for currently supported types.
auto funcType = funcOp.getFunctionType();
+ SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+ filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
TypeConverter::SignatureConversion signatureConverter(
funcType.getNumInputs());
auto llvmType = typeConverter.convertFunctionSignature(
funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
- signatureConverter);
+ byValByRefArgAttrs, signatureConverter);
if (!llvmType)
return failure();
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 6e96703cda578..9f006a5a187c1 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op -split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,31 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
// CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
// CHECK-NEXT: llvm.return [[RES]]
+
+// -----
+
+// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byval
+func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+ return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byref
+func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
+ return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
\ No newline at end of file
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index e25e890e2290a..75168dde93130 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
LogicalResult
matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
- returnOp->getOperands());
+ SmallVector<Type> resTys;
+ if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
+ adaptor.getOperands());
return success();
}
};
+static std::optional<Type>
+convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
+ MLIRContext *ctx = simpleTy.getContext();
+ SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
+ return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
+}
+
struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
LowerToLLVMOptions options(ctx);
// Populate type conversions.
LLVMTypeConverter typeConverter(ctx, options);
+ typeConverter.addConversion(convertSimpleATypeToStruct);
RewritePatternSet patterns(ctx);
patterns.add<FuncOpConversion>(typeConverter);
>From 2160c7c51954fb76153a455fa1ec24a2882ba90d Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 30 Jul 2024 16:02:20 -0700
Subject: [PATCH 2/4] Revert "[mlir][LLVM] Improve lowering of llvm.byval
function arguments"
This reverts commit eb7cec3712345e5260142886e1fa608ad57ca69b.
---
.../Conversion/LLVMCommon/TypeConverter.h | 29 +++-----
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 39 +---------
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 6 +-
.../Conversion/LLVMCommon/TypeConverter.cpp | 71 ++++---------------
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 4 +-
.../test/Transforms/test-convert-func-op.mlir | 30 +-------
.../FuncToLLVM/TestConvertFuncOp.cpp | 16 +----
7 files changed, 30 insertions(+), 165 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 007e6ba39b632..e228229302cff 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,7 +21,6 @@
namespace mlir {
class DataLayoutAnalysis;
-class FunctionOpInterface;
class LowerToLLVMOptions;
namespace LLVM {
@@ -36,7 +35,6 @@ class LLVMTypeConverter : public TypeConverter {
/// Give structFuncArgTypeConverter access to memref-specific functions.
friend LogicalResult
structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
- std::optional<NamedAttribute> byValByRefArgAttr,
SmallVectorImpl<Type> &result);
public:
@@ -55,10 +53,9 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
- Type convertFunctionSignature(
- FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
- ArrayRef<std::optional<NamedAttribute>> byValByRefArgAttr,
- SignatureConversion &result) const;
+ Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
+ bool useBarePtrCallConv,
+ SignatureConversion &result) const;
/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
@@ -245,23 +242,15 @@ class LLVMTypeConverter : public TypeConverter {
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
-LogicalResult
-structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
- std::optional<NamedAttribute> byValByRefArgAttr,
- SmallVectorImpl<Type> &result);
+LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter,
+ Type type,
+ SmallVectorImpl<Type> &result);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
-LogicalResult
-barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
- std::optional<NamedAttribute> byValByRefArgAttr,
- SmallVectorImpl<Type> &result);
-
-/// Returns in `result` the `llvm.byval` or `llvm.byref` attributes, if
-/// present, or an empty attribute for each function argument.
-void filterByValByRefArgAttributes(
- FunctionOpInterface funcOp,
- SmallVectorImpl<std::optional<NamedAttribute>> &result);
+LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
+ Type type,
+ SmallVectorImpl<Type> &result);
} // namespace mlir
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 8d4645d46f069..c1f6d8bc5b361 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,36 +267,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}
-/// Inserts `llvm.load` ops in the function body to restore the expected pointee
-/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
-/// to LLVM pointer types.
-static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
- FunctionOpInterface funcOp) {
- // Nothing to do for function declarations.
- if (funcOp.isExternal())
- return;
-
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
-
- SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
- filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
- for (const auto &[arg, byValRefAttr] :
- llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
- // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
- if (!byValRefAttr)
- continue;
-
- // Insert load to retrieve the actual argument passed by value/reference.
- assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
- "Expected LLVM pointer type for argument with "
- "`llvm.byval`/`llvm.byref` attribute");
- Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
- auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
- rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
- }
-}
-
FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
@@ -310,12 +280,10 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
- SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
- filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = converter.convertFunctionSignature(
funcTy, varargsAttr && varargsAttr.getValue(),
- shouldUseBarePtrCallConv(funcOp, &converter), byValByRefArgAttrs, result);
+ shouldUseBarePtrCallConv(funcOp, &converter), result);
if (!llvmType)
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
@@ -430,11 +398,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
"region types conversion failed");
}
- // Fix the type mismatch between the generated `llvm.ptr` and the expected
- // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
- // function arguments.
- restoreByValByRefArgumentType(rewriter, newFuncOp);
-
if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 143f7b3071253..6053e34f30a41 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -53,12 +53,10 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
- SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
- filterByValByRefArgAttributes(gpuFuncOp, byValByRefArgAttrs);
+
Type funcType = getTypeConverter()->convertFunctionSignature(
gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
- getTypeConverter()->getOptions().useBarePtrCallConv, byValByRefArgAttrs,
- signatureConversion);
+ getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
if (!funcType) {
return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
diag << "failed to convert function signature type for: "
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index c62096fdc853d..d5df960928afb 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -276,7 +276,6 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
// they are into an LLVM StructType in their order of appearance.
Type LLVMTypeConverter::convertFunctionSignature(
FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
- ArrayRef<std::optional<NamedAttribute>> byValByRefArgAtts,
LLVMTypeConverter::SignatureConversion &result) const {
// Select the argument converter depending on the calling convention.
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
@@ -285,8 +284,7 @@ Type LLVMTypeConverter::convertFunctionSignature(
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
- if (failed(
- funcArgConverter(*this, type, byValByRefArgAtts[idx], converted)))
+ if (failed(funcArgConverter(*this, type, converted)))
return {};
result.addInputs(idx, converted);
}
@@ -661,10 +659,9 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedmemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
-LogicalResult mlir::structFuncArgTypeConverter(
- const LLVMTypeConverter &converter, Type type,
- std::optional<NamedAttribute> byValByRefArgAttr,
- SmallVectorImpl<Type> &result) {
+LogicalResult
+mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ SmallVectorImpl<Type> &result) {
if (auto memref = dyn_cast<MemRefType>(type)) {
// In signatures, Memref descriptors are expanded into lists of
// non-aggregate values.
@@ -682,63 +679,23 @@ LogicalResult mlir::structFuncArgTypeConverter(
result.append(converted.begin(), converted.end());
return success();
}
-
- /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
- /// converted type is an LLVM pointer so that the LLVM argument passing
- /// is correct.
- Type converted;
- if (byValByRefArgAttr.has_value() &&
- (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
- byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
- converted = LLVM::LLVMPointerType::get(type.getContext());
- } else {
- converted = converter.convertType(type);
- if (!converted)
- return failure();
- }
-
+ auto converted = converter.convertType(type);
+ if (!converted)
+ return failure();
result.push_back(converted);
return success();
}
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
-LogicalResult mlir::barePtrFuncArgTypeConverter(
- const LLVMTypeConverter &converter, Type type,
- std::optional<NamedAttribute> byValByRefArgAttr,
- SmallVectorImpl<Type> &result) {
- /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
- /// converted type is an LLVM pointer so that the LLVM argument passing
- /// convention is correct.
- Type llvmTy;
- if (byValByRefArgAttr.has_value() &&
- (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
- byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
- llvmTy = LLVM::LLVMPointerType::get(type.getContext());
- } else {
- llvmTy = converter.convertCallingConventionType(
- type, /*useBarePointerCallConv=*/true);
-
- if (!llvmTy)
- return failure();
- }
+LogicalResult
+mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+ SmallVectorImpl<Type> &result) {
+ auto llvmTy = converter.convertCallingConventionType(
+ type, /*useBarePointerCallConv=*/true);
+ if (!llvmTy)
+ return failure();
result.push_back(llvmTy);
return success();
}
-
-void mlir::filterByValByRefArgAttributes(
- FunctionOpInterface funcOp,
- SmallVectorImpl<std::optional<NamedAttribute>> &result) {
- assert(result.empty() && "Unexpected non-empty output");
- result.resize(funcOp.getNumArguments(), std::nullopt);
- for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
- for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
- if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
- namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
- result[argIdx] = namedAttr;
- break;
- }
- }
- }
-}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index d5a9bc3783660..da09384bfbe89 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1408,13 +1408,11 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
// Convert function signature. At the moment LLVMType converter is enough
// for currently supported types.
auto funcType = funcOp.getFunctionType();
- SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
- filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
TypeConverter::SignatureConversion signatureConverter(
funcType.getNumInputs());
auto llvmType = typeConverter.convertFunctionSignature(
funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
- byValByRefArgAttrs, signatureConverter);
+ signatureConverter);
if (!llvmType)
return failure();
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 9f006a5a187c1..6e96703cda578 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,31 +10,3 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
// CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
// CHECK-NEXT: llvm.return [[RES]]
-
-// -----
-
-// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
-// value is retrieved within the `llvm.func`.
-
-// CHECK-LABEL: llvm.func @byval
-func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
- return %arg0 : !test.smpla
-}
-
-// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
-// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
-// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
-
-// -----
-
-// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
-// value is retrieved within the `llvm.func`.
-
-// CHECK-LABEL: llvm.func @byref
-func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
- return %arg0 : !test.smpla
-}
-
-// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
-// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
-// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
\ No newline at end of file
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index 75168dde93130..e25e890e2290a 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,23 +47,12 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
LogicalResult
matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- SmallVector<Type> resTys;
- if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
- return failure();
-
- rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
- adaptor.getOperands());
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
+ returnOp->getOperands());
return success();
}
};
-static std::optional<Type>
-convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
- MLIRContext *ctx = simpleTy.getContext();
- SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
- return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
-}
-
struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -85,7 +74,6 @@ struct TestConvertFuncOp
LowerToLLVMOptions options(ctx);
// Populate type conversions.
LLVMTypeConverter typeConverter(ctx, options);
- typeConverter.addConversion(convertSimpleATypeToStruct);
RewritePatternSet patterns(ctx);
patterns.add<FuncOpConversion>(typeConverter);
>From 9e59d4a61dc80879026f945e70c9944e47122c7d Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Wed, 31 Jul 2024 17:32:25 -0700
Subject: [PATCH 3/4] Second approach
---
.../Conversion/LLVMCommon/TypeConverter.h | 15 +++++
.../mlir/Transforms/DialectConversion.h | 4 ++
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 49 ++++++++++++++++-
.../Conversion/LLVMCommon/TypeConverter.cpp | 55 +++++++++++++++++++
.../Transforms/Utils/DialectConversion.cpp | 8 +++
.../test/Transforms/test-convert-func-op.mlir | 30 +++++++++-
.../FuncToLLVM/TestConvertFuncOp.cpp | 16 +++++-
7 files changed, 171 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..a0cc42d4d09f9 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,6 +21,7 @@
namespace mlir {
class DataLayoutAnalysis;
+class FunctionOpInterface;
class LowerToLLVMOptions;
namespace LLVM {
@@ -57,6 +58,13 @@ class LLVMTypeConverter : public TypeConverter {
bool useBarePtrCallConv,
SignatureConversion &result) const;
+ /// Replace the type of 1:1-mapped `llvm.byval`/`llvm.byref` arguments with an
+ /// LLVM pointer in `signatureConv` and return the resulting function type.
+ LLVM::LLVMFunctionType materializePtrForByValByRefFuncArgs(
+ LLVM::LLVMFunctionType funcType,
+ ArrayRef<std::optional<NamedAttribute>> byValRefArgAttrs,
+ LLVMTypeConverter::SignatureConversion &signatureConv) const;
+
/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to
@@ -252,6 +260,13 @@ LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
+/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
+/// function arguments. Returns an empty container if none of these attributes
+/// are found in any of the attributes.
+void filterByValByRefArgAttributes(
+ FunctionOpInterface funcOp,
+ SmallVectorImpl<std::optional<NamedAttribute>> &result);
+
} // namespace mlir
#endif // MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a51b00271f0ae..7f0983f10ff82 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -96,6 +96,10 @@ class TypeConverter {
/// value. This drops the original argument.
void remapInput(unsigned origInputNo, Value replacement);
+ /// Replace the converted type of original input `origInputNo`, which must
+ /// have been remapped to exactly one new input.
+ void replaceRemappedInputType(unsigned origInputNo, Type type);
+
private:
/// Remap an input of the original signature with a range of types in the
/// new signature.
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c1f6d8bc5b361..cca5aac31264b 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,6 +267,36 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}
+/// Inserts `llvm.load` ops in the function body to restore the expected pointee
+/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
+/// to LLVM pointer types.
+static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
+ FunctionOpInterface funcOp) {
+ // Nothing to do for function declarations.
+ if (funcOp.isExternal())
+ return;
+
+ ConversionPatternRewriter::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
+
+ SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
+ filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
+ for (const auto &[arg, byValRefAttr] :
+ llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
+ // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
+ if (!byValRefAttr)
+ continue;
+
+ // Insert load to retrieve the actual argument passed by value/reference.
+ assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
+ "Expected LLVM pointer type for argument with "
+ "`llvm.byval`/`llvm.byref` attribute");
+ Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
+ auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+ rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
+ }
+}
+
FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
@@ -281,12 +311,20 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
- auto llvmType = converter.convertFunctionSignature(
- funcTy, varargsAttr && varargsAttr.getValue(),
- shouldUseBarePtrCallConv(funcOp, &converter), result);
+ auto llvmType =
+ cast_or_null<LLVM::LLVMFunctionType>(converter.convertFunctionSignature(
+ funcTy, varargsAttr && varargsAttr.getValue(),
+ shouldUseBarePtrCallConv(funcOp, &converter), result));
if (!llvmType)
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
+ // Make sure the type of `llvm.byval` and `llvm.byref` arguments are converted
+ // to LLVM pointer types.
+ SmallVector<std::optional<NamedAttribute>> byValByRefArgs;
+ filterByValByRefArgAttributes(funcOp, byValByRefArgs);
+ llvmType = converter.materializePtrForByValByRefFuncArgs(
+ llvmType, byValByRefArgs, result);
+
// Create an LLVM function, use external linkage by default until MLIR
// functions have linkage.
LLVM::Linkage linkage = LLVM::Linkage::External;
@@ -398,6 +436,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
"region types conversion failed");
}
+ // Fix the type mismatch between the generated `llvm.ptr` and the expected
+ // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
+ // function arguments.
+ restoreByValByRefArgumentType(rewriter, newFuncOp);
+
if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index d5df960928afb..7c05e5f1a330e 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -302,6 +302,40 @@ Type LLVMTypeConverter::convertFunctionSignature(
isVariadic);
}
+LLVM::LLVMFunctionType LLVMTypeConverter::materializePtrForByValByRefFuncArgs(
+    LLVM::LLVMFunctionType funcType,
+    ArrayRef<std::optional<NamedAttribute>> byValRefArgAttrs,
+    LLVMTypeConverter::SignatureConversion &signatureConv) const {
+  if (byValRefArgAttrs.empty())
+    return funcType;
+
+  // Replace the type of `llvm.byval`/`llvm.byref` arguments with `!llvm.ptr`.
+  for (int inArgIdx : llvm::seq(byValRefArgAttrs.size())) {
+    auto inAttr = byValRefArgAttrs[inArgIdx];
+    if (!inAttr)
+      continue;
+
+    StringRef inAttrName = inAttr->getName().getValue();
+    if (inAttrName != LLVM::LLVMDialect::getByValAttrName() &&
+        inAttrName != LLVM::LLVMDialect::getByRefAttrName())
+      continue;
+
+    auto mapping = signatureConv.getInputMapping(inArgIdx);
+    assert(mapping && "unexpected deletion of function argument");
+    // Replace the argument type with an LLVM pointer type, only for 1:1
+    // mappings. Pass the *original* argument index (`inArgIdx`), not
+    // `mapping->inputNo`: the latter is the position in the converted
+    // signature and differs once an earlier argument maps to several types.
+    if (mapping->size == 1)
+      signatureConv.replaceRemappedInputType(
+          inArgIdx, LLVM::LLVMPointerType::get(&getContext()));
+  }
+
+  return LLVM::LLVMFunctionType::get(funcType.getReturnType(),
+                                     signatureConv.getConvertedTypes(),
+                                     funcType.isVarArg());
+}
+
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
@@ -699,3 +733,24 @@ mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
result.push_back(llvmTy);
return success();
}
+
+void mlir::filterByValByRefArgAttributes(
+ FunctionOpInterface funcOp,
+ SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+ assert(result.empty() && "Unexpected non-empty output");
+ result.resize(funcOp.getNumArguments(), std::nullopt);
+ bool hasByValByRefAttrs = false;
+ for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+ for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+ if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
+ namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
+ hasByValByRefAttrs = true;
+ result[argIdx] = namedAttr;
+ break;
+ }
+ }
+ }
+
+ if (!hasByValByRefAttrs)
+ result.clear();
+}
\ No newline at end of file
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f26aa0a1516a6..99882a306c0ea 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2958,6 +2958,14 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
InputMapping{origInputNo, /*size=*/0, replacementValue};
}
+void TypeConverter::SignatureConversion::replaceRemappedInputType(
+    unsigned origInputNo, Type type) {
+  const std::optional<InputMapping> &mapping = remappedInputs[origInputNo];
+  assert(mapping && "Expected remapped input");
+  assert(mapping->size == 1 && "Can't replace 1->N remapped input");
+  argTypes[mapping->inputNo] = type;
+}
+
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) const {
{
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 6e96703cda578..e85f4dc990759 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,31 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
// CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
// CHECK-NEXT: llvm.return [[RES]]
+
+// -----
+
+// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byval
+func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+ return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byref
+func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
+ return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
\ No newline at end of file
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index e25e890e2290a..75168dde93130 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
LogicalResult
matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
- returnOp->getOperands());
+ SmallVector<Type> resTys;
+ if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
+ adaptor.getOperands());
return success();
}
};
+static std::optional<Type>
+convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
+ MLIRContext *ctx = simpleTy.getContext();
+ SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
+ return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
+}
+
struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
LowerToLLVMOptions options(ctx);
// Populate type conversions.
LLVMTypeConverter typeConverter(ctx, options);
+ typeConverter.addConversion(convertSimpleATypeToStruct);
RewritePatternSet patterns(ctx);
patterns.add<FuncOpConversion>(typeConverter);
>From e9bc978e0eae784ec0cbaf8e1fd5c5cbdf1d8b19 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Thu, 1 Aug 2024 12:53:18 -0700
Subject: [PATCH 4/4] Fixes and cleanup
---
.../Conversion/LLVMCommon/TypeConverter.h | 7 ---
.../mlir/Transforms/DialectConversion.h | 5 ++
mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 61 +++++++++++++++----
.../Conversion/LLVMCommon/TypeConverter.cpp | 21 -------
.../Transforms/Utils/DialectConversion.cpp | 8 +++
.../test/Transforms/test-convert-func-op.mlir | 2 +-
6 files changed, 62 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index a0cc42d4d09f9..ff556b6a096c2 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -260,13 +260,6 @@ LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
Type type,
SmallVectorImpl<Type> &result);
-/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
-/// function arguments. Returns an empty container if none of these attributes
-/// are found in any of the attributes.
-void filterByValByRefArgAttributes(
- FunctionOpInterface funcOp,
- SmallVectorImpl<std::optional<NamedAttribute>> &result);
-
} // namespace mlir
#endif // MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 7f0983f10ff82..fd0bb64f722f3 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -17,6 +17,7 @@
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"
+
#include <type_traits>
namespace mlir {
@@ -75,6 +76,10 @@ class TypeConverter {
/// Return the argument types for the new signature.
ArrayRef<Type> getConvertedTypes() const { return argTypes; }
+    /// Get the converted type for the given argument only if there is a
+    /// one-to-one mapping for it. Otherwise, return std::nullopt.
+    std::optional<Type> getConvertedType(unsigned input) const;
+
/// Get the input mapping for the given argument.
std::optional<InputMapping> getInputMapping(unsigned input) const {
return remappedInputs[input];
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index cca5aac31264b..9610ba303ffd1 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -83,6 +83,38 @@ static void filterFuncAttributes(FunctionOpInterface func,
}
}
+/// Returns the `llvm.byval` or `llvm.byref` attribute of each function argument
+/// whose converted type (taken from `signatureConv`) is not already an LLVM
+/// pointer. Only arguments whose type was converted one-to-one are considered.
+/// The result is indexed by *original* argument number and is left empty when
+/// no argument qualifies.
+static void filterByValRefNonPtrAttrs(
+    FunctionOpInterface funcOp,
+    const TypeConverter::SignatureConversion &signatureConv,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+  assert(result.empty() && "Unexpected non-empty output");
+  result.resize(funcOp.getNumArguments(), std::nullopt);
+  bool foundByValByRefAttrs = false;
+  for (unsigned argIdx : llvm::seq(funcOp.getNumArguments())) {
+    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+      if (namedAttr.getName() != LLVM::LLVMDialect::getByValAttrName() &&
+          namedAttr.getName() != LLVM::LLVMDialect::getByRefAttrName())
+        continue;
+      // The check below depends only on the argument, so once a byval/byref
+      // attribute is seen there is nothing more to learn from this argument.
+      std::optional<Type> convType = signatureConv.getConvertedType(argIdx);
+      if (convType && !isa<LLVM::LLVMPointerType>(*convType)) {
+        foundByValByRefAttrs = true;
+        result[argIdx] = namedAttr;
+      }
+      break;
+    }
+  }
+
+  if (!foundByValByRefAttrs)
+    result.clear();
+}
+
/// Propagate argument/results attributes.
static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
FunctionOpInterface funcOp,
@@ -270,8 +302,10 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
/// Inserts `llvm.load` ops in the function body to restore the expected pointee
/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
/// to LLVM pointer types.
-static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
- FunctionOpInterface funcOp) {
+static void restoreByValRefArgumentType(
+ ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
+ ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
+ LLVM::LLVMFuncOp funcOp) {
// Nothing to do for function declarations.
if (funcOp.isExternal())
return;
@@ -279,10 +313,8 @@ static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
- SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
- filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
for (const auto &[arg, byValRefAttr] :
- llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
+ llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) {
// Skip argument if no `llvm.byval` or `llvm.byref` attribute.
if (!byValRefAttr)
continue;
@@ -291,7 +323,9 @@ static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
"Expected LLVM pointer type for argument with "
"`llvm.byval`/`llvm.byref` attribute");
- Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
+ Type resTy = typeConverter.convertType(
+ cast<TypeAttr>(byValRefAttr->getValue()).getValue());
+
auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
}
@@ -318,12 +352,12 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
if (!llvmType)
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
- // Make sure the type of `llvm.byval` and `llvm.byref` arguments are converted
- // to LLVM pointer types.
- SmallVector<std::optional<NamedAttribute>> byValByRefArgs;
- filterByValByRefArgAttributes(funcOp, byValByRefArgs);
+ // Replace the type of `llvm.byval` and `llvm.byref` arguments that were not
+ // converted to an LLVM pointer type.
+ SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
+ filterByValRefNonPtrAttrs(funcOp, result, byValRefNonPtrAttrs);
llvmType = converter.materializePtrForByValByRefFuncArgs(
- llvmType, byValByRefArgs, result);
+ llvmType, byValRefNonPtrAttrs, result);
// Create an LLVM function, use external linkage by default until MLIR
// functions have linkage.
@@ -436,10 +470,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
"region types conversion failed");
}
- // Fix the type mismatch between the generated `llvm.ptr` and the expected
+ // Fix the type mismatch between the materialized `llvm.ptr` and the expected
// pointee type in the function body when converting `llvm.byval`/`llvm.byref`
// function arguments.
- restoreByValByRefArgumentType(rewriter, newFuncOp);
+ restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
+ newFuncOp);
if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
if (funcOp->getAttrOfType<UnitAttr>(
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 7c05e5f1a330e..8f9490a6139c0 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -733,24 +733,3 @@ mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
result.push_back(llvmTy);
return success();
}
-
-void mlir::filterByValByRefArgAttributes(
- FunctionOpInterface funcOp,
- SmallVectorImpl<std::optional<NamedAttribute>> &result) {
- assert(result.empty() && "Unexpected non-empty output");
- result.resize(funcOp.getNumArguments(), std::nullopt);
- bool hasByValByRefAttrs = false;
- for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
- for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
- if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
- namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
- hasByValByRefAttrs = true;
- result[argIdx] = namedAttr;
- break;
- }
- }
- }
-
- if (!hasByValByRefAttrs)
- result.clear();
-}
\ No newline at end of file
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 99882a306c0ea..52e7d1a793dad 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2929,6 +2929,14 @@ LogicalResult OperationConverter::legalizeChangedResultType(
// Type Conversion
//===----------------------------------------------------------------------===//
+std::optional<Type>
+TypeConverter::SignatureConversion::getConvertedType(unsigned input) const {
+  std::optional<InputMapping> mapping = getInputMapping(input);
+  if (mapping && mapping->size == 1) // A unique type exists only for 1:1.
+    return argTypes[mapping->inputNo];
+  return std::nullopt;
+}
+
void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
ArrayRef<Type> types) {
assert(!types.empty() && "expected valid types");
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index e85f4dc990759..180f16a32991b 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -37,4 +37,4 @@ func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
-// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
\ No newline at end of file
+// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
More information about the Mlir-commits
mailing list