[Mlir-commits] [mlir] 5837fdc - [mlir][llvm] Pass struct results as parameter in c wrapper
Stephan Herhut
llvmlistbot at llvm.org
Wed Mar 17 04:59:15 PDT 2021
Author: Stephan Herhut
Date: 2021-03-17T12:58:52+01:00
New Revision: 5837fdc4ccc4d61e9eb7b6d310760c0be2e52124
URL: https://github.com/llvm/llvm-project/commit/5837fdc4ccc4d61e9eb7b6d310760c0be2e52124
DIFF: https://github.com/llvm/llvm-project/commit/5837fdc4ccc4d61e9eb7b6d310760c0be2e52124.diff
LOG: [mlir][llvm] Pass struct results as parameter in c wrapper
Returning structs directly in LLVM does not necessarily match the C ABI of
the platform. This may happen to work on Linux, but for small structs it
breaks on Windows. With this change, the wrappers work independently of the platform.
Differential Revision: https://reviews.llvm.org/D98725
Added:
Modified:
mlir/docs/LLVMDialectMemRefConvention.md
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
Removed:
################################################################################
diff --git a/mlir/docs/LLVMDialectMemRefConvention.md b/mlir/docs/LLVMDialectMemRefConvention.md
index 78ec6fb00752..d528f0c01c21 100644
--- a/mlir/docs/LLVMDialectMemRefConvention.md
+++ b/mlir/docs/LLVMDialectMemRefConvention.md
@@ -232,29 +232,40 @@ struct MemRefDescriptor {
};
```
+Furthermore, we also rewrite function results to pointer parameters if the
+rewritten function result has a struct type. The special result parameter is
+added as the first parameter and is of pointer-to-struct type.
+
If enabled, the option will do the following. For _external_ functions declared
in the MLIR module.
1. Declare a new function `_mlir_ciface_<original name>` where memref arguments
are converted to pointer-to-struct and the remaining arguments are converted
- as usual.
-1. Add a body to the original function (making it non-external) that
- 1. allocates a memref descriptor,
- 1. populates it, and
- 1. passes the pointer to it into the newly declared interface function,
+ as usual. Results are converted to a special argument if they are of struct
+ type.
+2. Add a body to the original function (making it non-external) that
+ 1. allocates memref descriptors,
+ 2. populates them,
+ 3. potentially allocates space for the result struct, and
+ 4. passes the pointers to these into the newly declared interface function,
then
- 1. collects the result of the call and returns it to the caller.
+ 5. collects the result of the call (potentially from the result struct),
+ and
+ 6. returns it to the caller.
For (non-external) functions defined in the MLIR module.
1. Define a new function `_mlir_ciface_<original name>` where memref arguments
are converted to pointer-to-struct and the remaining arguments are converted
- as usual.
-1. Populate the body of the newly defined function with IR that
+ as usual. Results are converted to a special argument if they are of struct
+ type.
+2. Populate the body of the newly defined function with IR that
1. loads descriptors from pointers;
- 1. unpacks descriptor into individual non-aggregate values;
- 1. passes these values into the original function;
- 1. collects the result of the call and returns it to the caller.
+ 2. unpacks descriptor into individual non-aggregate values;
+ 3. passes these values into the original function;
+ 4. collects the results of the call and
+ 5. either copies the results into the result struct or returns them to the
+ caller.
Examples:
@@ -342,6 +353,49 @@ llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr) {
}
```
+```mlir
+func @foo(%arg0: memref<?x?xf32>) -> memref<?x?xf32> {
+ return %arg0 : memref<?x?xf32>
+}
+
+// Gets converted into the following
+// (using type alias for brevity):
+!llvm.memref_2d = type !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+ array<2xi64>, array<2xi64>)>
+!llvm.memref_2d_ptr = type !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
+ array<2xi64>, array<2xi64>)>>
+
+// Function with unpacked arguments.
+llvm.func @foo(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>, %arg2: i64,
+ %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64)
+ -> !llvm.memref_2d {
+ %0 = llvm.mlir.undef : !llvm.memref_2d
+ %1 = llvm.insertvalue %arg0, %0[0] : !llvm.memref_2d
+ %2 = llvm.insertvalue %arg1, %1[1] : !llvm.memref_2d
+ %3 = llvm.insertvalue %arg2, %2[2] : !llvm.memref_2d
+ %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.memref_2d
+ %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.memref_2d
+ %6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm.memref_2d
+ %7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.memref_2d
+ llvm.return %7 : !llvm.memref_2d
+}
+
+// Interface function callable from C.
+llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr, %arg1: !llvm.memref_2d_ptr) {
+ %0 = llvm.load %arg1 : !llvm.memref_2d_ptr
+ %1 = llvm.extractvalue %0[0] : !llvm.memref_2d
+ %2 = llvm.extractvalue %0[1] : !llvm.memref_2d
+ %3 = llvm.extractvalue %0[2] : !llvm.memref_2d
+ %4 = llvm.extractvalue %0[3, 0] : !llvm.memref_2d
+ %5 = llvm.extractvalue %0[3, 1] : !llvm.memref_2d
+ %6 = llvm.extractvalue %0[4, 0] : !llvm.memref_2d
+ %7 = llvm.extractvalue %0[4, 1] : !llvm.memref_2d
+ %8 = llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
+ : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64, i64, i64) -> !llvm.memref_2d
+ llvm.store %8, %arg0 : !llvm.memref_2d_ptr
+ llvm.return
+}
+```
Rationale: Introducing auxiliary functions for C-compatible interfaces is
preferred to modifying the calling convention since it will minimize the effect
of C compatibility on intra-module calls or calls between MLIR-generated
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index c37528178e83..84052c6676e4 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -116,8 +116,10 @@ class LLVMTypeConverter : public TypeConverter {
OpBuilder &builder);
/// Converts the function type to a C-compatible format, in particular using
- /// pointers to memref descriptors for arguments.
- Type convertFunctionTypeCWrapper(FunctionType type);
+ /// pointers to memref descriptors for arguments. Also converts the return
+ /// type to a pointer argument if it is a struct. The second element of the
+ /// returned pair is true if this was the case.
+ std::pair<Type, bool> convertFunctionTypeCWrapper(FunctionType type);
/// Returns the data layout to use during and after conversion.
const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index de1df34eaa5d..91e520e3ca62 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -253,8 +253,24 @@ Type LLVMTypeConverter::convertFunctionSignature(
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
-Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
+std::pair<Type, bool>
+LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
SmallVector<Type, 4> inputs;
+ bool resultIsNowArg = false;
+
+ Type resultType = type.getNumResults() == 0
+ ? LLVM::LLVMVoidType::get(&getContext())
+ : unwrap(packFunctionResults(type.getResults()));
+ if (!resultType)
+ return {};
+
+ if (auto structType = resultType.dyn_cast<LLVM::LLVMStructType>()) {
+ // Struct types cannot be safely returned via C interface. Make this a
+ // pointer argument, instead.
+ inputs.push_back(LLVM::LLVMPointerType::get(structType));
+ resultType = LLVM::LLVMVoidType::get(&getContext());
+ resultIsNowArg = true;
+ }
for (Type t : type.getInputs()) {
auto converted = convertType(t);
@@ -265,13 +281,7 @@ Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
inputs.push_back(converted);
}
- Type resultType = type.getNumResults() == 0
- ? LLVM::LLVMVoidType::get(&getContext())
- : unwrap(packFunctionResults(type.getResults()));
- if (!resultType)
- return {};
-
- return LLVM::LLVMFunctionType::get(resultType, inputs);
+ return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg};
}
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
@@ -1212,8 +1222,11 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. This function can be called from C
/// by passing a pointer to a C struct corresponding to a memref descriptor.
+/// Similarly, returned memrefs are passed via a pointer to a C struct that is
+/// passed as an additional (first) argument.
/// Internally, the auxiliary function unpacks the descriptor into individual
-/// components and forwards them to `newFuncOp`.
+/// components, forwards them to `newFuncOp`, and stores the results into
+/// the extra result argument.
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
LLVMTypeConverter &typeConverter,
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
@@ -1221,17 +1234,21 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
attributes);
+ Type wrapperFuncType;
+ bool resultIsNowArg;
+ std::tie(wrapperFuncType, resultIsNowArg) =
+ typeConverter.convertFunctionTypeCWrapper(type);
auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
- typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
- attributes);
+ wrapperFuncType, LLVM::Linkage::External, attributes);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
SmallVector<Value, 8> args;
+ size_t argOffset = resultIsNowArg ? 1 : 0;
for (auto &en : llvm::enumerate(type.getInputs())) {
- Value arg = wrapperFuncOp.getArgument(en.index());
+ Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
@@ -1243,28 +1260,40 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
continue;
}
- args.push_back(wrapperFuncOp.getArgument(en.index()));
+ args.push_back(arg);
}
+
auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
- rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
+
+ if (resultIsNowArg) {
+ rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
+ wrapperFuncOp.getArgument(0));
+ rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
+ } else {
+ rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
+ }
}
/// Creates an auxiliary function with pointer-to-memref-descriptor-struct
/// arguments instead of unpacked arguments. Creates a body for the (external)
/// `newFuncOp` that allocates a memref descriptor on stack, packs the
/// individual arguments into this descriptor and passes a pointer to it into
-/// the auxiliary function. This auxiliary external function is now compatible
-/// with functions defined in C using pointers to C structs corresponding to a
-/// memref descriptor.
+/// the auxiliary function. If the result of the function cannot be directly
+/// returned, we write it to a special first argument that provides a pointer
+/// to a corresponding struct. This auxiliary external function is now
+/// compatible with functions defined in C using pointers to C structs
+/// corresponding to a memref descriptor.
static void wrapExternalFunction(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
OpBuilder::InsertionGuard guard(builder);
- Type wrapperType =
+ Type wrapperType;
+ bool resultIsNowArg;
+ std::tie(wrapperType, resultIsNowArg) =
typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
// This conversion can only fail if it could not convert one of the argument
- // types. But since it has been applies to a non-wrapper function before, it
+ // types. But since it has been applied to a non-wrapper function before, it
// should have failed earlier and not reach this point at all.
assert(wrapperType && "unexpected type conversion failure");
@@ -1285,6 +1314,17 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
args.reserve(type.getNumInputs());
ValueRange wrapperArgsRange(newFuncOp.getArguments());
+ if (resultIsNowArg) {
+ // Allocate the struct on the stack and pass the pointer.
+ Type resultType =
+ wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
+ Value one = builder.create<LLVM::ConstantOp>(
+ loc, typeConverter.convertType(builder.getIndexType()),
+ builder.getIntegerAttr(builder.getIndexType(), 1));
+ Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
+ args.push_back(result);
+ }
+
// Iterate over the inputs of the original function and pack values into
// memref descriptors if the original type is a memref.
for (auto &en : llvm::enumerate(type.getInputs())) {
@@ -1322,7 +1362,13 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
assert(wrapperArgsRange.empty() && "did not map some of the arguments");
auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
- builder.create<LLVM::ReturnOp>(loc, call.getResults());
+
+ if (resultIsNowArg) {
+ Value result = builder.create<LLVM::LoadOp>(loc, args.front());
+ builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
+ } else {
+ builder.create<LLVM::ReturnOp>(loc, call.getResults());
+ }
}
namespace {
diff --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
index 6db4106d6e69..e0fc24ee1333 100644
--- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
@@ -144,7 +144,7 @@ func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
}
// CHECK-LABEL: llvm.func @return_var_memref
-func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
+func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes { llvm.emit_c_interface } {
// Match the construction of the unranked descriptor.
// CHECK: %[[ALLOCA:.*]] = llvm.alloca
// CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
@@ -177,6 +177,10 @@ func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
return %0 : memref<*xf32>
}
+// Check that the result memref is passed as parameter
+// CHECK-LABEL: @_mlir_ciface_return_var_memref
+// CHECK-SAME: (%{{.*}}: !llvm.ptr<struct<(i64, ptr<i8>)>>, %{{.*}}: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>>)
+
// CHECK-LABEL: llvm.func @return_two_var_memref_caller
func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
// Only check that we create two different descriptors using different
@@ -206,7 +210,7 @@ func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
}
// CHECK-LABEL: llvm.func @return_two_var_memref
-func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) {
+func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) attributes { llvm.emit_c_interface } {
// Match the construction of the unranked descriptor.
// CHECK: %[[ALLOCA:.*]] = llvm.alloca
// CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
@@ -240,3 +244,8 @@ func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*x
return %0, %0 : memref<*xf32>, memref<*xf32>
}
+// Check that the result memrefs are passed as parameter
+// CHECK-LABEL: @_mlir_ciface_return_two_var_memref
+// CHECK-SAME: (%{{.*}}: !llvm.ptr<struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>>,
+// CHECK-SAME: %{{.*}}: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>>)
+
More information about the Mlir-commits
mailing list