[Mlir-commits] [mlir] 2ac2e9a - [mlir][LLVM] Improve lowering of `llvm.byval` function arguments (#100028)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 8 19:27:58 PDT 2024


Author: Diego Caballero
Date: 2024-08-08T19:27:54-07:00
New Revision: 2ac2e9a5b6c97cbf267db1ef322ed21ebceb2aba

URL: https://github.com/llvm/llvm-project/commit/2ac2e9a5b6c97cbf267db1ef322ed21ebceb2aba
DIFF: https://github.com/llvm/llvm-project/commit/2ac2e9a5b6c97cbf267db1ef322ed21ebceb2aba.diff

LOG: [mlir][LLVM] Improve lowering of `llvm.byval` function arguments (#100028)

When a function argument is annotated with the `llvm.byval` attribute,
[LLVM expects](https://llvm.org/docs/LangRef.html#parameter-attributes)
the function argument type to be an `llvm.ptr`. For example:

```
func.func @foo(%args0 : !llvm.ptr {llvm.byval = !llvm.struct<(i32)>}) {
  ...
}
```

Unfortunately, this makes the type conversion context-dependent, which
is something that the type conversion infrastructure (i.e.,
`LLVMTypeConverter` in this particular case) doesn't support. For
example, we may want to convert `MyType` to `llvm.struct<(i32)>` in
general, but to an `llvm.ptr` type only when it's a function argument
passed by value.

To fix this problem, this PR changes the FuncToLLVM conversion logic to
generate an `llvm.ptr` when the function argument has a `llvm.byval`
attribute. An `llvm.load` is inserted into the function to retrieve the
value expected by the argument users.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/test/Transforms/test-convert-func-op.mlir
    mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..d79b90f840ce8 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,6 +21,7 @@
 namespace mlir {
 
 class DataLayoutAnalysis;
+class FunctionOpInterface;
 class LowerToLLVMOptions;
 
 namespace LLVM {
@@ -50,13 +51,25 @@ class LLVMTypeConverter : public TypeConverter {
   LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
                     const DataLayoutAnalysis *analysis = nullptr);
 
-  /// Convert a function type.  The arguments and results are converted one by
+  /// Convert a function type. The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
   Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
                                 bool useBarePtrCallConv,
                                 SignatureConversion &result) const;
 
+  /// Convert a function type. The arguments and results are converted one by
+  /// one and results are packed into a wrapped LLVM IR structure type. `result`
+  /// is populated with argument mapping. Converted types of `llvm.byval` and
+  /// `llvm.byref` function arguments which are not LLVM pointers are overridden
+  /// with LLVM pointers. Overridden arguments are returned in
+  /// `byValRefNonPtrAttrs`.
+  Type convertFunctionSignature(FunctionOpInterface funcOp, bool isVariadic,
+                                bool useBarePtrCallConv,
+                                LLVMTypeConverter::SignatureConversion &result,
+                                SmallVectorImpl<std::optional<NamedAttribute>>
+                                    &byValRefNonPtrAttrs) const;
+
   /// Convert a non-empty list of types to be returned from a function into an
   /// LLVM-compatible type. In particular, if more than one value is returned,
   /// create an LLVM dialect structure type with elements that correspond to
@@ -159,12 +172,26 @@ class LLVMTypeConverter : public TypeConverter {
   SmallVector<Type> &getCurrentThreadRecursiveStack();
 
 private:
-  /// Convert a function type.  The arguments and results are converted one by
-  /// one.  Additionally, if the function returns more than one value, pack the
+  /// Convert a function type. The arguments and results are converted one by
+  /// one. Additionally, if the function returns more than one value, pack the
   /// results into an LLVM IR structure type so that the converted function type
   /// returns at most one result.
   Type convertFunctionType(FunctionType type) const;
 
+  /// Common implementation for `convertFunctionSignature` methods. Convert a
+  /// function type. The arguments and results are converted one by one and
+  /// results are packed into a wrapped LLVM IR structure type. `result` is
+  /// populated with argument mapping. If `byValRefNonPtrAttrs` is provided,
+  /// converted types of `llvm.byval` and `llvm.byref` function arguments which
+  /// are not LLVM pointers are overridden with LLVM pointers. `llvm.byval` and
+  /// `llvm.byref` arguments that were already converted to LLVM pointer types
+  /// are removed from `byValRefNonPtrAttrs`.
+  Type convertFunctionSignatureImpl(
+      FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+      LLVMTypeConverter::SignatureConversion &result,
+      SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs)
+      const;
+
   /// Convert the index type.  Uses llvmModule data layout to create an integer
   /// of the pointer bitwidth.
   Type convertIndexType(IndexType type) const;

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c1f6d8bc5b361..4c2e8682285c5 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,6 +267,38 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
+/// Inserts `llvm.load` ops in the function body to restore the expected pointee
+/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
+/// to LLVM pointer types.
+static void restoreByValRefArgumentType(
+    ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
+    ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
+    LLVM::LLVMFuncOp funcOp) {
+  // Nothing to do for function declarations.
+  if (funcOp.isExternal())
+    return;
+
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
+
+  for (const auto &[arg, byValRefAttr] :
+       llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) {
+    // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
+    if (!byValRefAttr)
+      continue;
+
+    // Insert load to retrieve the actual argument passed by value/reference.
+    assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
+           "Expected LLVM pointer type for argument with "
+           "`llvm.byval`/`llvm.byref` attribute");
+    Type resTy = typeConverter.convertType(
+        cast<TypeAttr>(byValRefAttr->getValue()).getValue());
+
+    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+    rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
+  }
+}
+
 FailureOr<LLVM::LLVMFuncOp>
 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                 ConversionPatternRewriter &rewriter,
@@ -280,10 +312,14 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
   // Convert the original function arguments. They are converted using the
   // LLVMTypeConverter provided to this legalization pattern.
   auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
+  // Gather `llvm.byval` and `llvm.byref` arguments whose type conversion was
+  // overridden with an LLVM pointer type for later processing.
+  SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
   TypeConverter::SignatureConversion result(funcOp.getNumArguments());
   auto llvmType = converter.convertFunctionSignature(
-      funcTy, varargsAttr && varargsAttr.getValue(),
-      shouldUseBarePtrCallConv(funcOp, &converter), result);
+      funcOp, varargsAttr && varargsAttr.getValue(),
+      shouldUseBarePtrCallConv(funcOp, &converter), result,
+      byValRefNonPtrAttrs);
   if (!llvmType)
     return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
 
@@ -398,6 +434,12 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                        "region types conversion failed");
   }
 
+  // Fix the type mismatch between the materialized `llvm.ptr` and the expected
+  // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
+  // function arguments.
+  restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
+                              newFuncOp);
+
   if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
     if (funcOp->getAttrOfType<UnitAttr>(
             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 17be4d91ee054..5313a64ed47e3 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -270,13 +270,42 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
   return LLVM::LLVMPointerType::get(type.getContext());
 }
 
+/// Fills `result` with the `llvm.byval` or `llvm.byref` attribute of each
+/// function argument, or `std::nullopt` for arguments that have neither.
+/// `result` is left empty if no argument carries either attribute.
+static void
+filterByValRefArgAttrs(FunctionOpInterface funcOp,
+                       SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+  assert(result.empty() && "Unexpected non-empty output");
+  result.resize(funcOp.getNumArguments(), std::nullopt);
+  bool foundByValByRefAttrs = false;
+  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+      if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
+           namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+        foundByValByRefAttrs = true;
+        result[argIdx] = namedAttr;
+        break;
+      }
+    }
+  }
+
+  if (!foundByValByRefAttrs)
+    result.clear();
+}
+
 // Function types are converted to LLVM Function types by recursively converting
-// argument and result types.  If MLIR Function has zero results, the LLVM
-// Function has one VoidType result.  If MLIR Function has more than one result,
+// argument and result types. If MLIR Function has zero results, the LLVM
+// Function has one VoidType result. If MLIR Function has more than one result,
 // they are into an LLVM StructType in their order of appearance.
-Type LLVMTypeConverter::convertFunctionSignature(
+// If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
+// `llvm.byref` function arguments which are not LLVM pointers are overridden
+// with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
+// converted to LLVM pointer types are removed from `byValRefNonPtrAttrs`.
+Type LLVMTypeConverter::convertFunctionSignatureImpl(
     FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
-    LLVMTypeConverter::SignatureConversion &result) const {
+    LLVMTypeConverter::SignatureConversion &result,
+    SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
   // Select the argument converter depending on the calling convention.
   useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
   auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
@@ -286,6 +315,19 @@ Type LLVMTypeConverter::convertFunctionSignature(
     SmallVector<Type, 8> converted;
     if (failed(funcArgConverter(*this, type, converted)))
       return {};
+
+    // Rewrite converted type of `llvm.byval` or `llvm.byref` function
+    // argument that was not converted to an LLVM pointer type.
+    if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
+        converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
+      // If the argument was already converted to an LLVM pointer type, we stop
+      // tracking it as it doesn't need more processing.
+      if (isa<LLVM::LLVMPointerType>(converted[0]))
+        (*byValRefNonPtrAttrs)[idx] = std::nullopt;
+      else
+        converted[0] = LLVM::LLVMPointerType::get(&getContext());
+    }
+
     result.addInputs(idx, converted);
   }
 
@@ -302,6 +344,27 @@ Type LLVMTypeConverter::convertFunctionSignature(
                                      isVariadic);
 }
 
+Type LLVMTypeConverter::convertFunctionSignature(
+    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+    LLVMTypeConverter::SignatureConversion &result) const {
+  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
+                                      result,
+                                      /*byValRefNonPtrAttrs=*/nullptr);
+}
+
+Type LLVMTypeConverter::convertFunctionSignature(
+    FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
+    LLVMTypeConverter::SignatureConversion &result,
+    SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
+  // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
+  // that were not converted to LLVM pointer types will be returned for further
+  // processing.
+  filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
+  auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
+  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
+                                      result, &byValRefNonPtrAttrs);
+}
+
 /// Converts the function type to a C-compatible format, in particular using
 /// pointers to memref descriptors for arguments.
 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>

diff  --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 6e96703cda578..180f16a32991b 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @add
 func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,31 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
 // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
 // CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
 // CHECK-NEXT: llvm.return [[RES]]
+
+// -----
+
+// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byval
+func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byref
+func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>

diff  --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index e25e890e2290a..75168dde93130 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
   LogicalResult
   matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
-                                                returnOp->getOperands());
+    SmallVector<Type> resTys;
+    if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
+                                                adaptor.getOperands());
     return success();
   }
 };
 
+static std::optional<Type>
+convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
+  MLIRContext *ctx = simpleTy.getContext();
+  SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
+  return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
+}
+
 struct TestConvertFuncOp
     : public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
     LowerToLLVMOptions options(ctx);
     // Populate type conversions.
     LLVMTypeConverter typeConverter(ctx, options);
+    typeConverter.addConversion(convertSimpleATypeToStruct);
 
     RewritePatternSet patterns(ctx);
     patterns.add<FuncOpConversion>(typeConverter);


        


More information about the Mlir-commits mailing list