[Mlir-commits] [mlir] [mlir][LLVM] Improve lowering of `llvm.byval` function arguments (PR #100028)

Diego Caballero llvmlistbot at llvm.org
Thu Aug 1 13:01:05 PDT 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/100028

>From eb7cec3712345e5260142886e1fa608ad57ca69b Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Fri, 26 Jul 2024 16:22:13 -0700
Subject: [PATCH 1/4] [mlir][LLVM] Improve lowering of llvm.byval function
 arguments

When a function argument is annotated with the `llvm.byval` attribute,
[LLVM expects](https://llvm.org/docs/LangRef.html#parameter-attributes)
the function argument type to be an `llvm.ptr`. For example:

```
func.func @foo(%args0 : !llvm.ptr {llvm.byval = !llvm.struct<(i32)>}) {
  ...
}
```

Unfortunately, this makes the type conversion context-dependent, which is
something that the type conversion infrastructure (i.e., `LLVMTypeConverter`
in this particular case) doesn't support. For example, we may want to convert
`MyType` to `llvm.struct<(i32)>` in general, but to an `llvm.ptr` type only
when it's a function argument passed by value.

To fix this problem, this PR changes the FuncToLLVM conversion logic to always
generate an `llvm.ptr` when the function argument has an `llvm.byval` or
`llvm.byref` attribute. An `llvm.load` is inserted at the start of the function
body to retrieve the value expected by the argument's users.
---
 .../Conversion/LLVMCommon/TypeConverter.h     | 29 +++++---
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 39 +++++++++-
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   |  6 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 71 +++++++++++++++----
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    |  4 +-
 .../test/Transforms/test-convert-func-op.mlir | 30 +++++++-
 .../FuncToLLVM/TestConvertFuncOp.cpp          | 16 ++++-
 7 files changed, 165 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..007e6ba39b632 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,6 +21,7 @@
 namespace mlir {
 
 class DataLayoutAnalysis;
+class FunctionOpInterface;
 class LowerToLLVMOptions;
 
 namespace LLVM {
@@ -35,6 +36,7 @@ class LLVMTypeConverter : public TypeConverter {
   /// Give structFuncArgTypeConverter access to memref-specific functions.
   friend LogicalResult
   structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+                             std::optional<NamedAttribute> byValByRefArgAttr,
                              SmallVectorImpl<Type> &result);
 
 public:
@@ -53,9 +55,10 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
-  Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
-                                bool useBarePtrCallConv,
-                                SignatureConversion &result) const;
+  Type convertFunctionSignature(
+      FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+      ArrayRef<std::optional<NamedAttribute>> byValByRefArgAttr,
+      SignatureConversion &result) const;
 
   /// Convert a non-empty list of types to be returned from a function into an
   /// LLVM-compatible type. In particular, if more than one value is returned,
@@ -242,15 +245,23 @@ class LLVMTypeConverter : public TypeConverter {
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
 /// the rank and a pointer to a descriptor struct.
-LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter,
-                                         Type type,
-                                         SmallVectorImpl<Type> &result);
+LogicalResult
+structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+                           std::optional<NamedAttribute> byValByRefArgAttr,
+                           SmallVectorImpl<Type> &result);
 
 /// Callback to convert function argument types. It converts MemRef function
 /// arguments to bare pointers to the MemRef element type.
-LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
-                                          Type type,
-                                          SmallVectorImpl<Type> &result);
+LogicalResult
+barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+                            std::optional<NamedAttribute> byValByRefArgAttr,
+                            SmallVectorImpl<Type> &result);
+
+/// Returns in `result` the `llvm.byval` or `llvm.byref` attributes, if
+/// present, or an empty attribute for each function argument.
+void filterByValByRefArgAttributes(
+    FunctionOpInterface funcOp,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c1f6d8bc5b361..8d4645d46f069 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,6 +267,36 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
+/// Inserts `llvm.load` ops in the function body to restore the expected pointee
+/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
+/// to LLVM pointer types.
+static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
+                                          FunctionOpInterface funcOp) {
+  // Nothing to do for function declarations.
+  if (funcOp.isExternal())
+    return;
+
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
+
+  SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
+  filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
+  for (const auto &[arg, byValRefAttr] :
+       llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
+    // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
+    if (!byValRefAttr)
+      continue;
+
+    // Insert load to retrieve the actual argument passed by value/reference.
+    assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
+           "Expected LLVM pointer type for argument with "
+           "`llvm.byval`/`llvm.byref` attribute");
+    Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
+    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+    rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
+  }
+}
+
 FailureOr<LLVM::LLVMFuncOp>
 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                 ConversionPatternRewriter &rewriter,
@@ -280,10 +310,12 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
   // Convert the original function arguments. They are converted using the
   // LLVMTypeConverter provided to this legalization pattern.
   auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
+  SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+  filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
   TypeConverter::SignatureConversion result(funcOp.getNumArguments());
   auto llvmType = converter.convertFunctionSignature(
       funcTy, varargsAttr && varargsAttr.getValue(),
-      shouldUseBarePtrCallConv(funcOp, &converter), result);
+      shouldUseBarePtrCallConv(funcOp, &converter), byValByRefArgAttrs, result);
   if (!llvmType)
     return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
 
@@ -398,6 +430,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                        "region types conversion failed");
   }
 
+  // Fix the type mismatch between the generated `llvm.ptr` and the expected
+  // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
+  // function arguments.
+  restoreByValByRefArgumentType(rewriter, newFuncOp);
+
   if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
     if (funcOp->getAttrOfType<UnitAttr>(
             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6053e34f30a41..143f7b3071253 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -53,10 +53,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   // Remap proper input types.
   TypeConverter::SignatureConversion signatureConversion(
       gpuFuncOp.front().getNumArguments());
-
+  SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+  filterByValByRefArgAttributes(gpuFuncOp, byValByRefArgAttrs);
   Type funcType = getTypeConverter()->convertFunctionSignature(
       gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
-      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
+      getTypeConverter()->getOptions().useBarePtrCallConv, byValByRefArgAttrs,
+      signatureConversion);
   if (!funcType) {
     return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
       diag << "failed to convert function signature type for: "
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index d5df960928afb..c62096fdc853d 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -276,6 +276,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
 // they are into an LLVM StructType in their order of appearance.
 Type LLVMTypeConverter::convertFunctionSignature(
     FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
+    ArrayRef<std::optional<NamedAttribute>> byValByRefArgAtts,
     LLVMTypeConverter::SignatureConversion &result) const {
   // Select the argument converter depending on the calling convention.
   useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
@@ -284,7 +285,8 @@ Type LLVMTypeConverter::convertFunctionSignature(
   // Convert argument types one by one and check for errors.
   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
     SmallVector<Type, 8> converted;
-    if (failed(funcArgConverter(*this, type, converted)))
+    if (failed(
+            funcArgConverter(*this, type, byValByRefArgAtts[idx], converted)))
       return {};
     result.addInputs(idx, converted);
   }
@@ -659,9 +661,10 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
 /// the rank and a pointer to a descriptor struct.
-LogicalResult
-mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
-                                 SmallVectorImpl<Type> &result) {
+LogicalResult mlir::structFuncArgTypeConverter(
+    const LLVMTypeConverter &converter, Type type,
+    std::optional<NamedAttribute> byValByRefArgAttr,
+    SmallVectorImpl<Type> &result) {
   if (auto memref = dyn_cast<MemRefType>(type)) {
     // In signatures, Memref descriptors are expanded into lists of
     // non-aggregate values.
@@ -679,23 +682,63 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
     result.append(converted.begin(), converted.end());
     return success();
   }
-  auto converted = converter.convertType(type);
-  if (!converted)
-    return failure();
+
+  /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
+  /// converted type is an LLVM pointer so that the LLVM argument passing
+  /// is correct.
+  Type converted;
+  if (byValByRefArgAttr.has_value() &&
+      (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
+       byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+    converted = LLVM::LLVMPointerType::get(type.getContext());
+  } else {
+    converted = converter.convertType(type);
+    if (!converted)
+      return failure();
+  }
+
   result.push_back(converted);
   return success();
 }
 
 /// Callback to convert function argument types. It converts MemRef function
 /// arguments to bare pointers to the MemRef element type.
-LogicalResult
-mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
-                                  SmallVectorImpl<Type> &result) {
-  auto llvmTy = converter.convertCallingConventionType(
-      type, /*useBarePointerCallConv=*/true);
-  if (!llvmTy)
-    return failure();
+LogicalResult mlir::barePtrFuncArgTypeConverter(
+    const LLVMTypeConverter &converter, Type type,
+    std::optional<NamedAttribute> byValByRefArgAttr,
+    SmallVectorImpl<Type> &result) {
+  /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
+  /// converted type is an LLVM pointer so that the LLVM argument passing
+  /// convention is correct.
+  Type llvmTy;
+  if (byValByRefArgAttr.has_value() &&
+      (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
+       byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
+    llvmTy = LLVM::LLVMPointerType::get(type.getContext());
+  } else {
+    llvmTy = converter.convertCallingConventionType(
+        type, /*useBarePointerCallConv=*/true);
+
+    if (!llvmTy)
+      return failure();
+  }
 
   result.push_back(llvmTy);
   return success();
 }
+
+void mlir::filterByValByRefArgAttributes(
+    FunctionOpInterface funcOp,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+  assert(result.empty() && "Unexpected non-empty output");
+  result.resize(funcOp.getNumArguments(), std::nullopt);
+  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+      if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
+          namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
+        result[argIdx] = namedAttr;
+        break;
+      }
+    }
+  }
+}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index da09384bfbe89..d5a9bc3783660 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1408,11 +1408,13 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
     // Convert function signature. At the moment LLVMType converter is enough
     // for currently supported types.
     auto funcType = funcOp.getFunctionType();
+    SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
+    filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
     TypeConverter::SignatureConversion signatureConverter(
         funcType.getNumInputs());
     auto llvmType = typeConverter.convertFunctionSignature(
         funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
-        signatureConverter);
+        byValByRefArgAttrs, signatureConverter);
     if (!llvmType)
       return failure();
 
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 6e96703cda578..9f006a5a187c1 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op -split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @add
 func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,31 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
 // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
 // CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
 // CHECK-NEXT: llvm.return [[RES]]
+
+// -----
+
+// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byval
+func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byref
+func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
\ No newline at end of file
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index e25e890e2290a..75168dde93130 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
   LogicalResult
   matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
-                                                returnOp->getOperands());
+    SmallVector<Type> resTys;
+    if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
+                                                adaptor.getOperands());
     return success();
   }
 };
 
+static std::optional<Type>
+convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
+  MLIRContext *ctx = simpleTy.getContext();
+  SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
+  return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
+}
+
 struct TestConvertFuncOp
     : public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
     LowerToLLVMOptions options(ctx);
     // Populate type conversions.
     LLVMTypeConverter typeConverter(ctx, options);
+    typeConverter.addConversion(convertSimpleATypeToStruct);
 
     RewritePatternSet patterns(ctx);
     patterns.add<FuncOpConversion>(typeConverter);

>From 2160c7c51954fb76153a455fa1ec24a2882ba90d Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Tue, 30 Jul 2024 16:02:20 -0700
Subject: [PATCH 2/4] Revert "[mlir][LLVM] Improve lowering of llvm.byval
 function arguments"

This reverts commit eb7cec3712345e5260142886e1fa608ad57ca69b.
---
 .../Conversion/LLVMCommon/TypeConverter.h     | 29 +++-----
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 39 +---------
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   |  6 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 71 ++++---------------
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    |  4 +-
 .../test/Transforms/test-convert-func-op.mlir | 30 +-------
 .../FuncToLLVM/TestConvertFuncOp.cpp          | 16 +----
 7 files changed, 30 insertions(+), 165 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index 007e6ba39b632..e228229302cff 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,7 +21,6 @@
 namespace mlir {
 
 class DataLayoutAnalysis;
-class FunctionOpInterface;
 class LowerToLLVMOptions;
 
 namespace LLVM {
@@ -36,7 +35,6 @@ class LLVMTypeConverter : public TypeConverter {
   /// Give structFuncArgTypeConverter access to memref-specific functions.
   friend LogicalResult
   structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
-                             std::optional<NamedAttribute> byValByRefArgAttr,
                              SmallVectorImpl<Type> &result);
 
 public:
@@ -55,10 +53,9 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
-  Type convertFunctionSignature(
-      FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
-      ArrayRef<std::optional<NamedAttribute>> byValByRefArgAttr,
-      SignatureConversion &result) const;
+  Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
+                                bool useBarePtrCallConv,
+                                SignatureConversion &result) const;
 
   /// Convert a non-empty list of types to be returned from a function into an
   /// LLVM-compatible type. In particular, if more than one value is returned,
@@ -245,23 +242,15 @@ class LLVMTypeConverter : public TypeConverter {
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
 /// the rank and a pointer to a descriptor struct.
-LogicalResult
-structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
-                           std::optional<NamedAttribute> byValByRefArgAttr,
-                           SmallVectorImpl<Type> &result);
+LogicalResult structFuncArgTypeConverter(const LLVMTypeConverter &converter,
+                                         Type type,
+                                         SmallVectorImpl<Type> &result);
 
 /// Callback to convert function argument types. It converts MemRef function
 /// arguments to bare pointers to the MemRef element type.
-LogicalResult
-barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
-                            std::optional<NamedAttribute> byValByRefArgAttr,
-                            SmallVectorImpl<Type> &result);
-
-/// Returns in `result` the `llvm.byval` or `llvm.byref` attributes, if
-/// present, or an empty attribute for each function argument.
-void filterByValByRefArgAttributes(
-    FunctionOpInterface funcOp,
-    SmallVectorImpl<std::optional<NamedAttribute>> &result);
+LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
+                                          Type type,
+                                          SmallVectorImpl<Type> &result);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 8d4645d46f069..c1f6d8bc5b361 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,36 +267,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
-/// Inserts `llvm.load` ops in the function body to restore the expected pointee
-/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
-/// to LLVM pointer types.
-static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
-                                          FunctionOpInterface funcOp) {
-  // Nothing to do for function declarations.
-  if (funcOp.isExternal())
-    return;
-
-  ConversionPatternRewriter::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
-
-  SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
-  filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
-  for (const auto &[arg, byValRefAttr] :
-       llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
-    // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
-    if (!byValRefAttr)
-      continue;
-
-    // Insert load to retrieve the actual argument passed by value/reference.
-    assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
-           "Expected LLVM pointer type for argument with "
-           "`llvm.byval`/`llvm.byref` attribute");
-    Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
-    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
-    rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
-  }
-}
-
 FailureOr<LLVM::LLVMFuncOp>
 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                 ConversionPatternRewriter &rewriter,
@@ -310,12 +280,10 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
   // Convert the original function arguments. They are converted using the
   // LLVMTypeConverter provided to this legalization pattern.
   auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
-  SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
-  filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
   TypeConverter::SignatureConversion result(funcOp.getNumArguments());
   auto llvmType = converter.convertFunctionSignature(
       funcTy, varargsAttr && varargsAttr.getValue(),
-      shouldUseBarePtrCallConv(funcOp, &converter), byValByRefArgAttrs, result);
+      shouldUseBarePtrCallConv(funcOp, &converter), result);
   if (!llvmType)
     return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
 
@@ -430,11 +398,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                        "region types conversion failed");
   }
 
-  // Fix the type mismatch between the generated `llvm.ptr` and the expected
-  // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
-  // function arguments.
-  restoreByValByRefArgumentType(rewriter, newFuncOp);
-
   if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
     if (funcOp->getAttrOfType<UnitAttr>(
             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 143f7b3071253..6053e34f30a41 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -53,12 +53,10 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
   // Remap proper input types.
   TypeConverter::SignatureConversion signatureConversion(
       gpuFuncOp.front().getNumArguments());
-  SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
-  filterByValByRefArgAttributes(gpuFuncOp, byValByRefArgAttrs);
+
   Type funcType = getTypeConverter()->convertFunctionSignature(
       gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
-      getTypeConverter()->getOptions().useBarePtrCallConv, byValByRefArgAttrs,
-      signatureConversion);
+      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
   if (!funcType) {
     return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
       diag << "failed to convert function signature type for: "
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index c62096fdc853d..d5df960928afb 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -276,7 +276,6 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
 // they are into an LLVM StructType in their order of appearance.
 Type LLVMTypeConverter::convertFunctionSignature(
     FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
-    ArrayRef<std::optional<NamedAttribute>> byValByRefArgAtts,
     LLVMTypeConverter::SignatureConversion &result) const {
   // Select the argument converter depending on the calling convention.
   useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
@@ -285,8 +284,7 @@ Type LLVMTypeConverter::convertFunctionSignature(
   // Convert argument types one by one and check for errors.
   for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
     SmallVector<Type, 8> converted;
-    if (failed(
-            funcArgConverter(*this, type, byValByRefArgAtts[idx], converted)))
+    if (failed(funcArgConverter(*this, type, converted)))
       return {};
     result.addInputs(idx, converted);
   }
@@ -661,10 +659,9 @@ LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
 /// argument to a list of non-aggregate types containing descriptor
 /// information, and an UnrankedmemRef function argument to a list containing
 /// the rank and a pointer to a descriptor struct.
-LogicalResult mlir::structFuncArgTypeConverter(
-    const LLVMTypeConverter &converter, Type type,
-    std::optional<NamedAttribute> byValByRefArgAttr,
-    SmallVectorImpl<Type> &result) {
+LogicalResult
+mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+                                 SmallVectorImpl<Type> &result) {
   if (auto memref = dyn_cast<MemRefType>(type)) {
     // In signatures, Memref descriptors are expanded into lists of
     // non-aggregate values.
@@ -682,63 +679,23 @@ LogicalResult mlir::structFuncArgTypeConverter(
     result.append(converted.begin(), converted.end());
     return success();
   }
-
-  /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
-  /// converted type is an LLVM pointer so that the LLVM argument passing
-  /// is correct.
-  Type converted;
-  if (byValByRefArgAttr.has_value() &&
-      (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
-       byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
-    converted = LLVM::LLVMPointerType::get(type.getContext());
-  } else {
-    converted = converter.convertType(type);
-    if (!converted)
-      return failure();
-  }
-
+  auto converted = converter.convertType(type);
+  if (!converted)
+    return failure();
   result.push_back(converted);
   return success();
 }
 
 /// Callback to convert function argument types. It converts MemRef function
 /// arguments to bare pointers to the MemRef element type.
-LogicalResult mlir::barePtrFuncArgTypeConverter(
-    const LLVMTypeConverter &converter, Type type,
-    std::optional<NamedAttribute> byValByRefArgAttr,
-    SmallVectorImpl<Type> &result) {
-  /// If the argument has the `llvm.byval` or `llvm.byref` attribute, the
-  /// converted type is an LLVM pointer so that the LLVM argument passing
-  /// convention is correct.
-  Type llvmTy;
-  if (byValByRefArgAttr.has_value() &&
-      (byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByValAttrName() ||
-       byValByRefArgAttr->getName() == LLVM::LLVMDialect::getByRefAttrName())) {
-    llvmTy = LLVM::LLVMPointerType::get(type.getContext());
-  } else {
-    llvmTy = converter.convertCallingConventionType(
-        type, /*useBarePointerCallConv=*/true);
-
-    if (!llvmTy)
-      return failure();
-  }
+LogicalResult
+mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
+                                  SmallVectorImpl<Type> &result) {
+  auto llvmTy = converter.convertCallingConventionType(
+      type, /*useBarePointerCallConv=*/true);
+  if (!llvmTy)
+    return failure();
 
   result.push_back(llvmTy);
   return success();
 }
-
-void mlir::filterByValByRefArgAttributes(
-    FunctionOpInterface funcOp,
-    SmallVectorImpl<std::optional<NamedAttribute>> &result) {
-  assert(result.empty() && "Unexpected non-empty output");
-  result.resize(funcOp.getNumArguments(), std::nullopt);
-  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
-    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
-      if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
-          namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
-        result[argIdx] = namedAttr;
-        break;
-      }
-    }
-  }
-}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index d5a9bc3783660..da09384bfbe89 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1408,13 +1408,11 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
     // Convert function signature. At the moment LLVMType converter is enough
     // for currently supported types.
     auto funcType = funcOp.getFunctionType();
-    SmallVector<std::optional<NamedAttribute>> byValByRefArgAttrs;
-    filterByValByRefArgAttributes(funcOp, byValByRefArgAttrs);
     TypeConverter::SignatureConversion signatureConverter(
         funcType.getNumInputs());
     auto llvmType = typeConverter.convertFunctionSignature(
         funcType, /*isVariadic=*/false, /*useBarePtrCallConv=*/false,
-        byValByRefArgAttrs, signatureConverter);
+        signatureConverter);
     if (!llvmType)
       return failure();
 
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 9f006a5a187c1..6e96703cda578 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
 
 // CHECK-LABEL: llvm.func @add
 func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,31 +10,3 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
 // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
 // CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
 // CHECK-NEXT: llvm.return [[RES]]
-
-// -----
-
-// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
-// value is retrieved within the `llvm.func`.
-
-// CHECK-LABEL: llvm.func @byval
-func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
-  return %arg0 : !test.smpla
-}
-
-// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
-//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
-//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
-
-// -----
-
-// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
-// value is retrieved within the `llvm.func`.
-
-// CHECK-LABEL: llvm.func @byref
-func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
-  return %arg0 : !test.smpla
-}
-
-// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
-//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
-//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
\ No newline at end of file
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index 75168dde93130..e25e890e2290a 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,23 +47,12 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
   LogicalResult
   matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Type> resTys;
-    if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
-                                                adaptor.getOperands());
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
+                                                returnOp->getOperands());
     return success();
   }
 };
 
-static std::optional<Type>
-convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
-  MLIRContext *ctx = simpleTy.getContext();
-  SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
-  return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
-}
-
 struct TestConvertFuncOp
     : public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -85,7 +74,6 @@ struct TestConvertFuncOp
     LowerToLLVMOptions options(ctx);
     // Populate type conversions.
     LLVMTypeConverter typeConverter(ctx, options);
-    typeConverter.addConversion(convertSimpleATypeToStruct);
 
     RewritePatternSet patterns(ctx);
     patterns.add<FuncOpConversion>(typeConverter);

>From 9e59d4a61dc80879026f945e70c9944e47122c7d Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Wed, 31 Jul 2024 17:32:25 -0700
Subject: [PATCH 3/4] Second approach

---
 .../Conversion/LLVMCommon/TypeConverter.h     | 15 +++++
 .../mlir/Transforms/DialectConversion.h       |  4 ++
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 49 ++++++++++++++++-
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 55 +++++++++++++++++++
 .../Transforms/Utils/DialectConversion.cpp    |  8 +++
 .../test/Transforms/test-convert-func-op.mlir | 30 +++++++++-
 .../FuncToLLVM/TestConvertFuncOp.cpp          | 16 +++++-
 7 files changed, 171 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index e228229302cff..a0cc42d4d09f9 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -21,6 +21,7 @@
 namespace mlir {
 
 class DataLayoutAnalysis;
+class FunctionOpInterface;
 class LowerToLLVMOptions;
 
 namespace LLVM {
@@ -57,6 +58,13 @@ class LLVMTypeConverter : public TypeConverter {
                                 bool useBarePtrCallConv,
                                 SignatureConversion &result) const;
 
+  /// Replace the type of `llvm.byval` and `llvm.byref` function arguments with
+  /// an LLVM pointer type in the function signature.
+  LLVM::LLVMFunctionType materializePtrForByValByRefFuncArgs(
+      LLVM::LLVMFunctionType funcType,
+      ArrayRef<std::optional<NamedAttribute>> byValRefArgAttrs,
+      LLVMTypeConverter::SignatureConversion &signatureConv) const;
+
   /// Convert a non-empty list of types to be returned from a function into an
   /// LLVM-compatible type. In particular, if more than one value is returned,
   /// create an LLVM dialect structure type with elements that correspond to
@@ -252,6 +260,13 @@ LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
                                           Type type,
                                           SmallVectorImpl<Type> &result);
 
+/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
+/// function arguments. Returns an empty container if none of these attributes
+/// are found in any of the attributes.
+void filterByValByRefArgAttributes(
+    FunctionOpInterface funcOp,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a51b00271f0ae..7f0983f10ff82 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -96,6 +96,10 @@ class TypeConverter {
     /// value. This drops the original argument.
     void remapInput(unsigned origInputNo, Value replacement);
 
+    /// Replace the type of an input that has been previously remapped to a new
+    /// single input.
+    void replaceRemappedInputType(unsigned origInputNo, Type type);
+
   private:
     /// Remap an input of the original signature with a range of types in the
     /// new signature.
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index c1f6d8bc5b361..cca5aac31264b 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -267,6 +267,36 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
+/// Inserts `llvm.load` ops in the function body to restore the expected pointee
+/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
+/// to LLVM pointer types.
+static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
+                                          FunctionOpInterface funcOp) {
+  // Nothing to do for function declarations.
+  if (funcOp.isExternal())
+    return;
+
+  ConversionPatternRewriter::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
+
+  SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
+  filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
+  for (const auto &[arg, byValRefAttr] :
+       llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
+    // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
+    if (!byValRefAttr)
+      continue;
+
+    // Insert load to retrieve the actual argument passed by value/reference.
+    assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
+           "Expected LLVM pointer type for argument with "
+           "`llvm.byval`/`llvm.byref` attribute");
+    Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
+    auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
+    rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
+  }
+}
+
 FailureOr<LLVM::LLVMFuncOp>
 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                 ConversionPatternRewriter &rewriter,
@@ -281,12 +311,20 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
   // LLVMTypeConverter provided to this legalization pattern.
   auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
   TypeConverter::SignatureConversion result(funcOp.getNumArguments());
-  auto llvmType = converter.convertFunctionSignature(
-      funcTy, varargsAttr && varargsAttr.getValue(),
-      shouldUseBarePtrCallConv(funcOp, &converter), result);
+  auto llvmType =
+      cast_or_null<LLVM::LLVMFunctionType>(converter.convertFunctionSignature(
+          funcTy, varargsAttr && varargsAttr.getValue(),
+          shouldUseBarePtrCallConv(funcOp, &converter), result));
   if (!llvmType)
     return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
 
+  // Make sure the type of `llvm.byval` and `llvm.byref` arguments are converted
+  // to LLVM pointer types.
+  SmallVector<std::optional<NamedAttribute>> byValByRefArgs;
+  filterByValByRefArgAttributes(funcOp, byValByRefArgs);
+  llvmType = converter.materializePtrForByValByRefFuncArgs(
+      llvmType, byValByRefArgs, result);
+
   // Create an LLVM function, use external linkage by default until MLIR
   // functions have linkage.
   LLVM::Linkage linkage = LLVM::Linkage::External;
@@ -398,6 +436,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                        "region types conversion failed");
   }
 
+  // Fix the type mismatch between the generated `llvm.ptr` and the expected
+  // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
+  // function arguments.
+  restoreByValByRefArgumentType(rewriter, newFuncOp);
+
   if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
     if (funcOp->getAttrOfType<UnitAttr>(
             LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index d5df960928afb..7c05e5f1a330e 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -302,6 +302,40 @@ Type LLVMTypeConverter::convertFunctionSignature(
                                      isVariadic);
 }
 
+LLVM::LLVMFunctionType LLVMTypeConverter::materializePtrForByValByRefFuncArgs(
+    LLVM::LLVMFunctionType funcType,
+    ArrayRef<std::optional<NamedAttribute>> byValRefArgAttrs,
+    LLVMTypeConverter::SignatureConversion &signatureConv) const {
+  if (byValRefArgAttrs.empty())
+    return funcType;
+
+  // Replace the type of `llvm.byval` and `llvm.byref` arguments with an LLVM
+  // pointer type in the signature conversion.
+  for (int inArgIdx : llvm::seq(byValRefArgAttrs.size())) {
+    auto inAttr = byValRefArgAttrs[inArgIdx];
+    if (!inAttr)
+      continue;
+
+    StringRef inAttrName = inAttr->getName().getValue();
+    if (inAttrName != LLVM::LLVMDialect::getByValAttrName() &&
+        inAttrName != LLVM::LLVMDialect::getByRefAttrName())
+      continue;
+
+    auto mapping = signatureConv.getInputMapping(inArgIdx);
+    assert(mapping && "unexpected deletion of function argument");
+    // Replace the argument type with an LLVM pointer type. Only do so if there
+    // is a one-to-one mapping from old to new types.
+    if (mapping->size == 1) {
+      signatureConv.replaceRemappedInputType(
+          mapping->inputNo, LLVM::LLVMPointerType::get(&getContext()));
+    }
+  }
+
+  return LLVM::LLVMFunctionType::get(funcType.getReturnType(),
+                                     signatureConv.getConvertedTypes(),
+                                     funcType.isVarArg());
+}
+
 /// Converts the function type to a C-compatible format, in particular using
 /// pointers to memref descriptors for arguments.
 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
@@ -699,3 +733,24 @@ mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
   result.push_back(llvmTy);
   return success();
 }
+
+void mlir::filterByValByRefArgAttributes(
+    FunctionOpInterface funcOp,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+  assert(result.empty() && "Unexpected non-empty output");
+  result.resize(funcOp.getNumArguments(), std::nullopt);
+  bool hasByValByRefAttrs = false;
+  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+      if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
+          namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
+        hasByValByRefAttrs = true;
+        result[argIdx] = namedAttr;
+        break;
+      }
+    }
+  }
+
+  if (!hasByValByRefAttrs)
+    result.clear();
+}
\ No newline at end of file
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f26aa0a1516a6..99882a306c0ea 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2958,6 +2958,14 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
       InputMapping{origInputNo, /*size=*/0, replacementValue};
 }
 
+void TypeConverter::SignatureConversion::replaceRemappedInputType(
+    unsigned origInputNo, Type type) {
+  auto inputMap = remappedInputs[origInputNo];
+  assert(inputMap && "Expected remapped input");
+  assert(inputMap->size == 1 && "Can't replace 1->N remapped input");
+  argTypes[inputMap->inputNo] = type;
+}
+
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
   {
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 6e96703cda578..e85f4dc990759 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @add
 func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,31 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
 // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
 // CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
 // CHECK-NEXT: llvm.return [[RES]]
+
+// -----
+
+// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byval
+func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
+
+// -----
+
+// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
+// value is retrieved within the `llvm.func`.
+
+// CHECK-LABEL: llvm.func @byref
+func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
+  return %arg0 : !test.smpla
+}
+
+// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
+//      CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
+//      CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
\ No newline at end of file
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index e25e890e2290a..75168dde93130 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
   LogicalResult
   matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
-                                                returnOp->getOperands());
+    SmallVector<Type> resTys;
+    if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
+                                                adaptor.getOperands());
     return success();
   }
 };
 
+static std::optional<Type>
+convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
+  MLIRContext *ctx = simpleTy.getContext();
+  SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
+  return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
+}
+
 struct TestConvertFuncOp
     : public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
     LowerToLLVMOptions options(ctx);
     // Populate type conversions.
     LLVMTypeConverter typeConverter(ctx, options);
+    typeConverter.addConversion(convertSimpleATypeToStruct);
 
     RewritePatternSet patterns(ctx);
     patterns.add<FuncOpConversion>(typeConverter);

>From ae9f5d7228e0131e8bc824aaaccab02ed52113b4 Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Thu, 1 Aug 2024 12:53:18 -0700
Subject: [PATCH 4/4] Fixes and cleanup

---
 .../Conversion/LLVMCommon/TypeConverter.h     |  7 ---
 .../mlir/Transforms/DialectConversion.h       |  5 ++
 mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 63 +++++++++++++++----
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 21 -------
 .../Transforms/Utils/DialectConversion.cpp    |  8 +++
 5 files changed, 63 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index a0cc42d4d09f9..ff556b6a096c2 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -260,13 +260,6 @@ LogicalResult barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter,
                                           Type type,
                                           SmallVectorImpl<Type> &result);
 
-/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
-/// function arguments. Returns an empty container if none of these attributes
-/// are found in any of the attributes.
-void filterByValByRefArgAttributes(
-    FunctionOpInterface funcOp,
-    SmallVectorImpl<std::optional<NamedAttribute>> &result);
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 7f0983f10ff82..fd0bb64f722f3 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -17,6 +17,7 @@
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/StringMap.h"
+
 #include <type_traits>
 
 namespace mlir {
@@ -75,6 +76,10 @@ class TypeConverter {
     /// Return the argument types for the new signature.
     ArrayRef<Type> getConvertedTypes() const { return argTypes; }
 
+    /// Get the converted type for the given argument only if there is a
+    /// one-to-one mapping for it. Otherwise, return std::nullopt.
+    std::optional<Type> getConvertedType(unsigned input) const;
+
     /// Get the input mapping for the given argument.
     std::optional<InputMapping> getInputMapping(unsigned input) const {
       return remappedInputs[input];
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index cca5aac31264b..4a4a415557ccf 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -83,6 +83,38 @@ static void filterFuncAttributes(FunctionOpInterface func,
   }
 }
 
+/// Collects into `result` the `llvm.byval` or `llvm.byref` attribute of each
+/// function argument whose converted type (one-to-one mapping only) is not
+/// already an LLVM pointer. `result` is left empty if no argument needs such
+/// a pointer materialization.
+static void filterByValRefNonPtrAttrs(
+    FunctionOpInterface funcOp,
+    const TypeConverter::SignatureConversion &signatureConv,
+    SmallVectorImpl<std::optional<NamedAttribute>> &result) {
+  assert(result.empty() && "Unexpected non-empty output");
+  result.resize(funcOp.getNumArguments(), std::nullopt);
+  bool foundByValByRefAttrs = false;
+  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
+    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
+      if (namedAttr.getName() != LLVM::LLVMDialect::getByValAttrName() &&
+          namedAttr.getName() != LLVM::LLVMDialect::getByRefAttrName())
+        continue;
+      // Retrieve the converted type from the converted signature and check
+      // that it is not already an LLVM pointer. Arguments without a 1:1
+      // mapping (dropped or 1:N) have no single type to replace; skip them.
+      auto convType = signatureConv.getConvertedType(argIdx);
+      if (!convType || isa<LLVM::LLVMPointerType>(*convType))
+        continue;
+      foundByValByRefAttrs = true;
+      result[argIdx] = namedAttr;
+      break;
+    }
+  }
+
+  if (!foundByValByRefAttrs)
+    result.clear();
+}
+
 /// Propagate argument/results attributes.
 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
                                  FunctionOpInterface funcOp,
@@ -270,8 +302,10 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
 /// Inserts `llvm.load` ops in the function body to restore the expected pointee
 /// value from `llvm.byval`/`llvm.byref` function arguments that were converted
 /// to LLVM pointer types.
-static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
-                                          FunctionOpInterface funcOp) {
+static void restoreByValRefArgumentType(
+    ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
+    ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
+    LLVM::LLVMFuncOp funcOp) {
   // Nothing to do for function declarations.
   if (funcOp.isExternal())
     return;
@@ -279,10 +313,8 @@ static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
   ConversionPatternRewriter::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
 
-  SmallVector<std::optional<NamedAttribute>> byValRefArgAttrs;
-  filterByValByRefArgAttributes(funcOp, byValRefArgAttrs);
   for (const auto &[arg, byValRefAttr] :
-       llvm::zip(funcOp.getArguments(), byValRefArgAttrs)) {
+       llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) {
     // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
     if (!byValRefAttr)
       continue;
@@ -291,7 +323,9 @@ static void restoreByValByRefArgumentType(ConversionPatternRewriter &rewriter,
     assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
            "Expected LLVM pointer type for argument with "
            "`llvm.byval`/`llvm.byref` attribute");
-    Type resTy = cast<TypeAttr>(byValRefAttr->getValue()).getValue();
+    Type resTy = typeConverter.convertType(
+        cast<TypeAttr>(byValRefAttr->getValue()).getValue());
+
     auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
     rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
   }
@@ -318,12 +352,14 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
   if (!llvmType)
     return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
 
-  // Make sure the type of `llvm.byval` and `llvm.byref` arguments are converted
-  // to LLVM pointer types.
-  SmallVector<std::optional<NamedAttribute>> byValByRefArgs;
-  filterByValByRefArgAttributes(funcOp, byValByRefArgs);
+  // Replace the type of `llvm.byval` and `llvm.byref` arguments that were not
+  // converted to an LLVM pointer type.
+  SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
+  filterByValRefNonPtrAttrs(funcOp, result, byValRefNonPtrAttrs);
   llvmType = converter.materializePtrForByValByRefFuncArgs(
-      llvmType, byValByRefArgs, result);
+      llvmType, byValRefNonPtrAttrs, result);
+  // `byValRefNonPtrAttrs` is also used below to insert the loads that restore
+  // the pointee values of these arguments in the converted function body.
 
   // Create an LLVM function, use external linkage by default until MLIR
   // functions have linkage.
@@ -436,10 +472,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                        "region types conversion failed");
   }
 
-  // Fix the type mismatch between the generated `llvm.ptr` and the expected
+  // Fix the type mismatch between the materialized `llvm.ptr` and the expected
   // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
   // function arguments.
-  restoreByValByRefArgumentType(rewriter, newFuncOp);
+  restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
+                              newFuncOp);
 
   if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
     if (funcOp->getAttrOfType<UnitAttr>(
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 7c05e5f1a330e..8f9490a6139c0 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -733,24 +733,3 @@ mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
   result.push_back(llvmTy);
   return success();
 }
-
-void mlir::filterByValByRefArgAttributes(
-    FunctionOpInterface funcOp,
-    SmallVectorImpl<std::optional<NamedAttribute>> &result) {
-  assert(result.empty() && "Unexpected non-empty output");
-  result.resize(funcOp.getNumArguments(), std::nullopt);
-  bool hasByValByRefAttrs = false;
-  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
-    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
-      if (namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
-          namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName()) {
-        hasByValByRefAttrs = true;
-        result[argIdx] = namedAttr;
-        break;
-      }
-    }
-  }
-
-  if (!hasByValByRefAttrs)
-    result.clear();
-}
\ No newline at end of file
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 99882a306c0ea..52e7d1a793dad 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2929,6 +2929,14 @@ LogicalResult OperationConverter::legalizeChangedResultType(
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
+std::optional<Type>
+TypeConverter::SignatureConversion::getConvertedType(unsigned input) const {
+  auto mapping = getInputMapping(input);
+  if (!mapping || mapping->size != 1)
+    return std::nullopt;
+  return getConvertedTypes()[mapping->inputNo];
+}
+
 void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
                                                    ArrayRef<Type> types) {
   assert(!types.empty() && "expected valid types");



More information about the Mlir-commits mailing list