[Mlir-commits] [mlir] 5837fdc - [mlir][llvm] Pass struct results as parameter in c wrapper

Stephan Herhut llvmlistbot at llvm.org
Wed Mar 17 04:59:15 PDT 2021


Author: Stephan Herhut
Date: 2021-03-17T12:58:52+01:00
New Revision: 5837fdc4ccc4d61e9eb7b6d310760c0be2e52124

URL: https://github.com/llvm/llvm-project/commit/5837fdc4ccc4d61e9eb7b6d310760c0be2e52124
DIFF: https://github.com/llvm/llvm-project/commit/5837fdc4ccc4d61e9eb7b6d310760c0be2e52124.diff

LOG: [mlir][llvm] Pass struct results as parameter in c wrapper

Returning structs directly in LLVM does not necessarily align with the C ABI of
the platform. This might happen to work on Linux but for small structs this
breaks on Windows. With this change, the wrappers work platform independently.

Differential Revision: https://reviews.llvm.org/D98725

Added: 
    

Modified: 
    mlir/docs/LLVMDialectMemRefConvention.md
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/test/Conversion/StandardToLLVM/calling-convention.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/LLVMDialectMemRefConvention.md b/mlir/docs/LLVMDialectMemRefConvention.md
index 78ec6fb00752..d528f0c01c21 100644
--- a/mlir/docs/LLVMDialectMemRefConvention.md
+++ b/mlir/docs/LLVMDialectMemRefConvention.md
@@ -232,29 +232,40 @@ struct MemRefDescriptor {
 };
 ```
 
+Furthermore, function results are rewritten to pointer parameters if the
+converted function result has a struct type. The special result parameter is
+added as the first parameter and has pointer-to-struct type.
+
 If enabled, the option will do the following. For _external_ functions declared
 in the MLIR module.
 
 1.  Declare a new function `_mlir_ciface_<original name>` where memref arguments
     are converted to pointer-to-struct and the remaining arguments are converted
-    as usual.
-1.  Add a body to the original function (making it non-external) that
-    1.  allocates a memref descriptor,
-    1.  populates it, and
-    1.  passes the pointer to it into the newly declared interface function,
+    as usual. Results are converted to a special argument if they are of struct
+    type.
+2.  Add a body to the original function (making it non-external) that
+    1.  allocates memref descriptors,
+    2.  populates them,
+    3.  potentially allocates space for the result struct, and
+    4.  passes the pointers to these into the newly declared interface function,
         then
-    1.  collects the result of the call and returns it to the caller.
+    5.  collects the result of the call (potentially from the result struct),
+        and
+    6.  returns it to the caller.
 
 For (non-external) functions defined in the MLIR module.
 
 1.  Define a new function `_mlir_ciface_<original name>` where memref arguments
     are converted to pointer-to-struct and the remaining arguments are converted
-    as usual.
-1.  Populate the body of the newly defined function with IR that
+    as usual. Results are converted to a special argument if they are of struct
+    type.
+2.  Populate the body of the newly defined function with IR that
     1.  loads descriptors from pointers;
-    1.  unpacks descriptor into individual non-aggregate values;
-    1.  passes these values into the original function;
-    1.  collects the result of the call and returns it to the caller.
+    2.  unpacks descriptor into individual non-aggregate values;
+    3.  passes these values into the original function;
+    4.  collects the results of the call and
+    5.  either copies the results into the result struct or returns them to the
+        caller.
 
 Examples:
 
@@ -342,6 +353,49 @@ llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr) {
 }
 ```
 
+```mlir
+func @foo(%arg0: memref<?x?xf32>) -> memref<?x?xf32> {
+  return %arg0 : memref<?x?xf32>
+}
+
+// Gets converted into the following
+// (using type alias for brevity):
+!llvm.memref_2d = type !llvm.struct<(ptr<f32>, ptr<f32>, i64,
+                                     array<2xi64>, array<2xi64>)>
+!llvm.memref_2d_ptr = type !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64,
+                                             array<2xi64>, array<2xi64>)>>
+
+// Function with unpacked arguments.
+llvm.func @foo(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>, %arg2: i64,
+               %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64)
+    -> !llvm.memref_2d {
+  %0 = llvm.mlir.undef : !llvm.memref_2d
+  %1 = llvm.insertvalue %arg0, %0[0] : !llvm.memref_2d
+  %2 = llvm.insertvalue %arg1, %1[1] : !llvm.memref_2d
+  %3 = llvm.insertvalue %arg2, %2[2] : !llvm.memref_2d
+  %4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.memref_2d
+  %5 = llvm.insertvalue %arg5, %4[4, 0] : !llvm.memref_2d
+  %6 = llvm.insertvalue %arg4, %5[3, 1] : !llvm.memref_2d
+  %7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.memref_2d
+  llvm.return %7 : !llvm.memref_2d
+}
+
+// Interface function callable from C; the result is written through %arg0.
+llvm.func @_mlir_ciface_foo(%arg0: !llvm.memref_2d_ptr, %arg1: !llvm.memref_2d_ptr) {
+  %0 = llvm.load %arg1 : !llvm.memref_2d_ptr
+  %1 = llvm.extractvalue %0[0] : !llvm.memref_2d
+  %2 = llvm.extractvalue %0[1] : !llvm.memref_2d
+  %3 = llvm.extractvalue %0[2] : !llvm.memref_2d
+  %4 = llvm.extractvalue %0[3, 0] : !llvm.memref_2d
+  %5 = llvm.extractvalue %0[3, 1] : !llvm.memref_2d
+  %6 = llvm.extractvalue %0[4, 0] : !llvm.memref_2d
+  %7 = llvm.extractvalue %0[4, 1] : !llvm.memref_2d
+  %8 = llvm.call @foo(%1, %2, %3, %4, %5, %6, %7)
+    : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64, i64, i64) -> !llvm.memref_2d
+  llvm.store %8, %arg0 : !llvm.memref_2d_ptr
+  llvm.return
+}
+```
 Rationale: Introducing auxiliary functions for C-compatible interfaces is
 preferred to modifying the calling convention since it will minimize the effect
 of C compatibility on intra-module calls or calls between MLIR-generated

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index c37528178e83..84052c6676e4 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -116,8 +116,10 @@ class LLVMTypeConverter : public TypeConverter {
                                    OpBuilder &builder);
 
   /// Converts the function type to a C-compatible format, in particular using
-  /// pointers to memref descriptors for arguments.
-  Type convertFunctionTypeCWrapper(FunctionType type);
+  /// pointers to memref descriptors for arguments. Also converts the return
+  /// type to a pointer argument if it is a struct; the boolean in the returned
+  /// pair is true in that case. Returns a null type on conversion failure.
+  std::pair<Type, bool> convertFunctionTypeCWrapper(FunctionType type);
 
   /// Returns the data layout to use during and after conversion.
   const llvm::DataLayout &getDataLayout() { return options.dataLayout; }

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index de1df34eaa5d..91e520e3ca62 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -253,8 +253,24 @@ Type LLVMTypeConverter::convertFunctionSignature(
 
 /// Converts the function type to a C-compatible format, in particular using
 /// pointers to memref descriptors for arguments.
-Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
+std::pair<Type, bool>
+LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
   SmallVector<Type, 4> inputs;
+  bool resultIsNowArg = false;
+
+  Type resultType = type.getNumResults() == 0
+                        ? LLVM::LLVMVoidType::get(&getContext())
+                        : unwrap(packFunctionResults(type.getResults()));
+  if (!resultType)
+    return {};
+
+  if (auto structType = resultType.dyn_cast<LLVM::LLVMStructType>()) {
+    // Struct types cannot be safely returned via C interface. Make this a
+    // pointer argument, instead.
+    inputs.push_back(LLVM::LLVMPointerType::get(structType));
+    resultType = LLVM::LLVMVoidType::get(&getContext());
+    resultIsNowArg = true;
+  }
 
   for (Type t : type.getInputs()) {
     auto converted = convertType(t);
@@ -265,13 +281,7 @@ Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
     inputs.push_back(converted);
   }
 
-  Type resultType = type.getNumResults() == 0
-                        ? LLVM::LLVMVoidType::get(&getContext())
-                        : unwrap(packFunctionResults(type.getResults()));
-  if (!resultType)
-    return {};
-
-  return LLVM::LLVMFunctionType::get(resultType, inputs);
+  return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg};
 }
 
 static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
@@ -1212,8 +1222,11 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
 /// arguments instead of unpacked arguments. This function can be called from C
 /// by passing a pointer to a C struct corresponding to a memref descriptor.
+/// Similarly, struct-typed results (such as returned memrefs) are written to a
+/// C struct whose pointer is passed as an additional first argument.
 /// Internally, the auxiliary function unpacks the descriptor into individual
-/// components and forwards them to `newFuncOp`.
+/// components and forwards them to `newFuncOp`; a struct-typed result is
+/// stored through the extra first argument.
 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
                                    LLVMTypeConverter &typeConverter,
                                    FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
@@ -1221,17 +1234,21 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
   SmallVector<NamedAttribute, 4> attributes;
   filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
                        attributes);
+  Type wrapperFuncType;
+  bool resultIsNowArg;
+  std::tie(wrapperFuncType, resultIsNowArg) =
+      typeConverter.convertFunctionTypeCWrapper(type);
   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
-      typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External,
-      attributes);
+      wrapperFuncType, LLVM::Linkage::External, attributes);
 
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock());
 
   SmallVector<Value, 8> args;
+  size_t argOffset = resultIsNowArg ? 1 : 0;
   for (auto &en : llvm::enumerate(type.getInputs())) {
-    Value arg = wrapperFuncOp.getArgument(en.index());
+    Value arg = wrapperFuncOp.getArgument(en.index() + argOffset);
     if (auto memrefType = en.value().dyn_cast<MemRefType>()) {
       Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
       MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
@@ -1243,28 +1260,40 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
       continue;
     }
 
-    args.push_back(wrapperFuncOp.getArgument(en.index()));
+    args.push_back(arg);
   }
+
   auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
-  rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
+
+  if (resultIsNowArg) {
+    rewriter.create<LLVM::StoreOp>(loc, call.getResult(0),
+                                   wrapperFuncOp.getArgument(0));
+    rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
+  } else {
+    rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
+  }
 }
 
 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
 /// arguments instead of unpacked arguments. Creates a body for the (external)
 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
 /// individual arguments into this descriptor and passes a pointer to it into
-/// the auxiliary function. This auxiliary external function is now compatible
-/// with functions defined in C using pointers to C structs corresponding to a
-/// memref descriptor.
+/// the auxiliary function. A struct-typed result is not returned directly:
+/// instead, a stack-allocated struct is passed by pointer as the first argument
+/// of the auxiliary function and the result is loaded from it after the call.
+/// This auxiliary external function is now compatible with functions defined
+/// in C using pointers to C structs corresponding to a memref descriptor.
 static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                  LLVMTypeConverter &typeConverter,
                                  FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
   OpBuilder::InsertionGuard guard(builder);
 
-  Type wrapperType =
+  Type wrapperType;
+  bool resultIsNowArg;
+  std::tie(wrapperType, resultIsNowArg) =
       typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
   // This conversion can only fail if it could not convert one of the argument
-  // types. But since it has been applies to a non-wrapper function before, it
+  // types. But since it has been applied to a non-wrapper function before, it
   // should have failed earlier and not reach this point at all.
   assert(wrapperType && "unexpected type conversion failure");
 
@@ -1285,6 +1314,17 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   args.reserve(type.getNumInputs());
   ValueRange wrapperArgsRange(newFuncOp.getArguments());
 
+  if (resultIsNowArg) {
+    // Allocate the struct on the stack and pass the pointer.
+    Type resultType =
+        wrapperType.cast<LLVM::LLVMFunctionType>().getParamType(0);
+    Value one = builder.create<LLVM::ConstantOp>(
+        loc, typeConverter.convertType(builder.getIndexType()),
+        builder.getIntegerAttr(builder.getIndexType(), 1));
+    Value result = builder.create<LLVM::AllocaOp>(loc, resultType, one);
+    args.push_back(result);
+  }
+
   // Iterate over the inputs of the original function and pack values into
   // memref descriptors if the original type is a memref.
   for (auto &en : llvm::enumerate(type.getInputs())) {
@@ -1322,7 +1362,13 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   assert(wrapperArgsRange.empty() && "did not map some of the arguments");
 
   auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
-  builder.create<LLVM::ReturnOp>(loc, call.getResults());
+
+  if (resultIsNowArg) {
+    Value result = builder.create<LLVM::LoadOp>(loc, args.front());
+    builder.create<LLVM::ReturnOp>(loc, ValueRange{result});
+  } else {
+    builder.create<LLVM::ReturnOp>(loc, call.getResults());
+  }
 }
 
 namespace {

diff  --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
index 6db4106d6e69..e0fc24ee1333 100644
--- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir
@@ -144,7 +144,7 @@ func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
 }
 
 // CHECK-LABEL: llvm.func @return_var_memref
-func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
+func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes { llvm.emit_c_interface } {
   // Match the construction of the unranked descriptor.
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca
   // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
@@ -177,6 +177,10 @@ func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> {
   return %0 : memref<*xf32>
 }
 
+// Check that the result memref is passed as parameter
+// CHECK-LABEL: @_mlir_ciface_return_var_memref
+// CHECK-SAME: (%{{.*}}: !llvm.ptr<struct<(i64, ptr<i8>)>>, %{{.*}}: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>>)
+
 // CHECK-LABEL: llvm.func @return_two_var_memref_caller
 func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
   // Only check that we create two different descriptors using different
@@ -206,7 +210,7 @@ func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) {
 }
 
 // CHECK-LABEL: llvm.func @return_two_var_memref
-func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) {
+func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) attributes { llvm.emit_c_interface } {
   // Match the construction of the unranked descriptor.
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca
   // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]]
@@ -240,3 +244,8 @@ func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*x
   return %0, %0 : memref<*xf32>, memref<*xf32>
 }
 
+// Check that the result memrefs are passed as parameter
+// CHECK-LABEL: @_mlir_ciface_return_two_var_memref
+// CHECK-SAME: (%{{.*}}: !llvm.ptr<struct<(struct<(i64, ptr<i8>)>, struct<(i64, ptr<i8>)>)>>,
+// CHECK-SAME: %{{.*}}: !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>>)
+


        


More information about the Mlir-commits mailing list