[Mlir-commits] [mlir] [mlir][LLVM] Improve lowering of `llvm.byval` function arguments (PR #100028)

Diego Caballero llvmlistbot at llvm.org
Fri Jul 26 16:39:16 PDT 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/100028

>From 51d558e655ce2b9440134effe195137ac9ccbfd3 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 26 Jul 2024 16:22:13 -0700
Subject: [PATCH] [mlir][LLVM] Improve lowering of llvm.byval function
 arguments

When a function argument is annotated with the `llvm.byval` attribute,
[LLVM expects](https://llvm.org/docs/LangRef.html#parameter-attributes)
the function argument type to be an `llvm.ptr`. For example:

```
func.func @foo(%arg0 : !llvm.ptr {llvm.byval = !llvm.struct<(i32)>}) {
  ...
}
```

Unfortunately, this makes the type conversion context-dependent, which is
something that the type conversion infrastructure (i.e., `LLVMTypeConverter`
in this particular case) doesn't support. For example, we may want to convert
`MyType` to `llvm.struct<(i32)>` in general, but to an `llvm.ptr` type only
when it's a function argument passed by value.

To fix this problem, this PR changes the FuncToLLVM conversion logic to always
generate an `llvm.ptr` when the function argument has an `llvm.byval` or
`llvm.byref` attribute. An `llvm.load` is inserted at the start of the function
body to retrieve the value expected by the argument's users.
---
 .../Conversion/LLVMCommon/TypeConverter.h     | 29 +++++---
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 39 +++++++++-
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   |  6 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 72 +++++++++++++++----
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    |  4 +-
 .../test/Transforms/test-convert-func-op.mlir | 30 +++++++-
 .../FuncToLLVM/TestConvertFuncOp.cpp          | 16 ++++-
 7 files changed, 166 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..007e6ba39b632 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,6 +21,7 @@
 namespace mlir {
 
 class DataLayoutAnalysis;
+class FunctionOpInterface;
 class LowerToLLVMOptions;
 
 namespace LLVM {
@@ -35,6 +36,7 @@ class LLVMTypeConverter : public TypeConverter {
   /// Give structFuncArgTypeConverter access to memref-specific functions.
   friend LogicalResult
   structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+                             std::optional<NamedAttribute> byValByRefArgAttr,
                              SmallVectorImpl<Type> &result);
 
 public:
@@ -53,9 +55,10 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
-  Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
-                                bool useBarePtrCallConv,
-                                SignatureConversion &result) const;
+  Type convertFunctionSignature(
+      FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+      ArrayRef<std::optional<NamedAttribute>> byValByRefArgAttr,
+      SignatureConversion &result) const;
 
   /// Convert a non-empty list of types to be returned from a function into an
   /// LLVM-compatible type. In particular, if more than one value is returned,
@@ -242,15 +245,23 @@ class LLVMTypeConverter : public TypeConverter {
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
 /// the rank and a pointer to a descriptor struct.
-LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter,
-                                         Type type,
-                                         SmallVectorImpl<Type> &result);
+LogicalResult
+structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+                           std::optional<NamedAttribute> byValByRefArgAttr,
+                           SmallVectorImpl<Type> &result);
 
 /// Callback to convert function argument types. It converts MemRef function
 /// arguments to bare pointers to the MemRef element type.
-LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
-                                          Type type,
-                                          SmallVectorImpl<Type> &result);
+LogicalResult
+barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+                            std::optional<NamedAttribute> byValByRefArgAttr,
+                            SmallVectorImpl<Type> &result);
+
+/// Fills the (initially empty) `result` with one entry per `funcOp` argument:
+/// its `llvm.byval`/`llvm.byref` attribute if present, else `std::nullopt`.
+void filterByValByRefArgAttributes(
+    FunctionOpInterface funcOp,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c1f6d8bc5b361..8d4645d46f069 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,6 +267,36 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
+/// Inserts `llvm.load` ops at the start of the function body to recover the
+/// pointee value of `llvm.byval`/`llvm.byref` arguments that were converted to
+/// `!llvm.ptr`, and redirects all other uses of each such argument to its load.
+static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
+                                          FunctionOpInterface funcOp) {
+  // Function declarations have no body to insert loads into.
+  if (funcOp.isExternal())
+    return;
+
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
+
+  SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
+  filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
+  for (const auto &[arg, byValRefAttr] :
+       llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
+    // Arguments without `llvm.byval`/`llvm.byref` need no fix-up.
+    if (!byValRefAttr)
+      continue;
+
+    // Load the pointee; the load type is the attribute's TypeAttr value.
+    assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
+           "Expected LLVM pointer type for argument with "
+           "`llvm.byval`/`llvm.byref` attribute");
+    Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
+    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+    rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
+  }
+}
+
 FailureOr<LLVM::LLVMFuncOp>
 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                 ConversionPatternRewriter &rewriter,
@@ -280,10 +310,12 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
   // Convert the original function arguments. They are converted using the
   // LLVMTypeConverter provided to this legalization pattern.
   auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
+  SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+  filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
   TypeConverter::SignatureConversion result(funcOp.getNumArguments());
   auto llvmType = converter.convertFunctionSignature(
       funcTy, varargsAttr && varargsAttr.getValue(),
-      shouldUseBarePtrCallConv(funcOp, &converter), result);
+      shouldUseBarePtrCallConv(funcOp, &converter), byValByRefArgAttrs, result);
   if (!llvmType)
     return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
 
@@ -398,6 +430,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                        "region types conversion failed");
   }
 
+  // Fix the type mismatch between the generated `llvm.ptr` and the expected
+  // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
+  // function arguments.
+  restoreByValByRefArgumentType(rewriter, newFuncOp);
+
   if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
     if (funcOp->getAttrOfType<UnitAttr>(
             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6053e34f30a41..143f7b3071253 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -53,10 +53,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   // Remap proper input types.
   TypeConverter::SignatureConversion signatureConversion(
       gpuFuncOp.front().getNumArguments());
-
+  SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+  filterByValByRefArgAttributes(gpuFuncOp, byValByRefArgAttrs);
   Type funcType = getTypeConverter()->convertFunctionSignature(
       gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
-      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
+      getTypeConverter()->getOptions().useBarePtrCallConv, byValByRefArgAttrs,
+      signatureConversion);
   if (!funcType) {
     return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
       diag << "failed to convert function signature type for: "
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index d5df960928afb..38de4f2fdeac1 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -276,6 +276,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
 // they are into an LLVM StructType in their order of appearance.
 Type LLVMTypeConverter::convertFunctionSignature(
     FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+    ArrayRef<std::optional<NamedAttribute>> byValByRefArgAtts,
     LLVMTypeConverter::SignatureConversion &result) const {
   // Select the argument converter depending on the calling convention.
   useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
@@ -284,7 +285,8 @@ Type LLVMTypeConverter::convertFunctionSignature(
   // Convert argument types one by one and check for errors.
   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
     SmallVector<Type, 8> converted;
-    if (failed(funcArgConverter(*this, type, converted)))
+    if (failed(
+            funcArgConverter(*this, type, byValByRefArgAtts[idx], converted)))
       return {};
     result.addInputs(idx, converted);
   }
@@ -659,9 +661,10 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
 /// the rank and a pointer to a descriptor struct.
-LogicalResult
-mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
-                                 SmallVectorImpl<Type> &result) {
+LogicalResult mlir::structFuncArgTypeConverter(
+    const LLVMTypeConverter &converter, Type type,
+    std::optional<NamedAttribute> byValByRefArgAttr,
+    SmallVectorImpl<Type> &result) {
   if (auto memref = dyn_cast<MemRefType>(type)) {
     // In signatures, Memref descriptors are expanded into lists of
     // non-aggregate values.
@@ -679,23 +682,64 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
     result.append(converted.begin(), converted.end());
     return success();
   }
-  auto converted = converter.convertType(type);
-  if (!converted)
-    return failure();
+
+  /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
+  /// converted type is an LLVM pointer so that the LLVM argument passing
+  /// is correct.
+  Type converted;
+  if (byValByRefArgAttr.has_value() &&
+      (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
+       byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+    converted = LLVM::LLVMPointerType::get(type.getContext());
+  } else {
+    converted = converter.convertType(type);
+    if (!converted)
+      return failure();
+  }
+
   result.push_back(converted);
   return success();
 }
 
 /// Callback to convert function argument types. It converts MemRef function
 /// arguments to bare pointers to the MemRef element type.
-LogicalResult
-mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
-                                  SmallVectorImpl<Type> &result) {
-  auto llvmTy = converter.convertCallingConventionType(
-      type, /*useBarePointerCallConv=*/true);
-  if (!llvmTy)
-    return failure();
+LogicalResult mlir::barePtrFuncArgTypeConverter(
+    const LLVMTypeConverter &converter, Type type,
+    std::optional<NamedAttribute> byValByRefArgAttr,
+    SmallVectorImpl<Type> &result) {
+  // Arguments carrying `llvm.byval` or `llvm.byref` become `!llvm.ptr`, as
+  // LLVM requires for these parameter attributes; all other arguments go
+  // through the bare-pointer calling convention conversion.
+  Type llvmTy;
+  if (byValByRefArgAttr.has_value() &&
+      (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
+       byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+    llvmTy = LLVM::LLVMPointerType::get(type.getContext());
+  } else {
+    llvmTy = converter.convertCallingConventionType(
+        type, /*useBarePointerCallConv=*/true);
+
+    if (!llvmTy)
+      return failure();
+  }

   result.push_back(llvmTy);
   return success();
 }
+
+void mlir::filterByValByRefArgAttributes(
+    FunctionOpInterface funcOp,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+  assert(result.empty() && "Unexpected non-empty output");
+  // Exactly one entry per argument, so callers can index `result` by position.
+  result.resize(funcOp.getNumArguments(), std::nullopt);
+  for (unsigned argIdx : llvm::seq(funcOp.getNumArguments())) {
+    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+      if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
+          namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
+        result[argIdx] = namedAttr;
+        break;
+      }
+    }
+  }
+}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index da09384bfbe89..d5a9bc3783660 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1408,11 +1408,13 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
     // Convert function signature. At the moment LLVMType converter is enough
     // for currently supported types.
     auto funcType = funcOp.getFunctionType();
+    SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+    filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
     TypeConverter::SignatureConversion signatureConverter(
         funcType.getNumInputs());
     auto llvmType = typeConverter.convertFunctionSignature(
         funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
-        signatureConverter);
+        byValByRefArgAttrs, signatureConverter);
     if (!llvmType)
       return failure();
 
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 6e96703cda578..9f006a5a187c1 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op -split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @add
 func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,52 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
 // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
 // CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
 // CHECK-NEXT: llvm.return [[RES]]
+
+// -----
+
+// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byval
+func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byref
+func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that each `llvm.byval` argument is converted to `llvm.ptr` and loaded
+// when a function has several of them interleaved with regular arguments.
+
+// CHECK-LABEL: llvm.func @multiple_byval
+func.func @multiple_byval(%arg0: !test.smpla {llvm.byval = !test.smpla},
+                          %arg1: i32,
+                          %arg2: !test.smpla {llvm.byval = !test.smpla})
+    -> !test.smpla {
+  return %arg2 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>},
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: i32,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>})
+// CHECK-SAME: -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG2]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index e25e890e2290a..75168dde93130 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
   LogicalResult
   matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
-                                                returnOp->getOperands());
+    // `func.return` has no results, so there are no result types to convert.
+    // Forward the adaptor operands rather than the original ones: they carry
+    // the converted values (e.g. the load emitted for a `llvm.byval` argument).
+
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
+                                                adaptor.getOperands());
     return success();
   }
 };
 
+static std::optional<Type>
+convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
+  // `!test.smpla` lowers to the literal struct `!llvm.struct<(i8, i8)>`.
+  Type i8Ty = IntegerType::get(simpleTy.getContext(), /*width=*/8);
+  return LLVM::LLVMStructType::getLiteral(simpleTy.getContext(), {i8Ty, i8Ty});
+}
+
 struct TestConvertFuncOp
     : public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
     LowerToLLVMOptions options(ctx);
     // Populate type conversions.
     LLVMTypeConverter typeConverter(ctx, options);
+    typeConverter.addConversion(convertSimpleATypeToStruct);
 
     RewritePatternSet patterns(ctx);
     patterns.add<FuncOpConversion>(typeConverter);



More information about the Mlir-commits mailing list