[Mlir-commits] [mlir] 79c2094 - [mlir][LLVMIR] Parse some type attributes for LLVM function parameters

Alexander Batashev llvmlistbot at llvm.org
Thu Aug 25 01:08:48 PDT 2022


Author: Alexander Batashev
Date: 2022-08-25T11:06:51+03:00
New Revision: 79c2094881c503773bec9f3bcfbd717f7f3a027a

URL: https://github.com/llvm/llvm-project/commit/79c2094881c503773bec9f3bcfbd717f7f3a027a
DIFF: https://github.com/llvm/llvm-project/commit/79c2094881c503773bec9f3bcfbd717f7f3a027a.diff

LOG: [mlir][LLVMIR] Parse some type attributes for LLVM function parameters

With the transition to opaque pointers, type information has been
transferred to function parameter attributes. This patch adds correct
parsing for some of those attributes and fixes some tests that
previously used UnitAttr for them.

Differential Revision: https://reviews.llvm.org/D132366

Added: 
    mlir/test/Target/LLVMIR/Import/func-attrs.ll

Modified: 
    flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
    flang/test/Fir/target-rewrite-arg-position.fir
    flang/test/Fir/target-rewrite-boxchar.fir
    flang/test/Fir/target-rewrite-complex.fir
    flang/test/Fir/target-rewrite-complex16.fir
    flang/test/Fir/target.fir
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/func.mlir
    mlir/test/Target/LLVMIR/llvmir-invalid.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index b5c7374de4928..5bcf79ea83c06 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -756,7 +756,7 @@ class TargetRewrite : public fir::TargetRewriteBase<TargetRewrite> {
   }
 
   inline bool functionArgIsSRet(unsigned index, mlir::func::FuncOp func) {
-    if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret"))
+    if (auto attr = func.getArgAttrOfType<mlir::TypeAttr>(index, "llvm.sret"))
       return true;
     return false;
   }
@@ -782,16 +782,22 @@ class TargetRewrite : public fir::TargetRewriteBase<TargetRewrite> {
       if (auto align = attr.getAlignment())
         fixups.emplace_back(
             FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
-              func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
+              auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
+                  func.getFunctionType().getInput(argNo));
+              func.setArgAttr(argNo, "llvm.sret",
+                              mlir::TypeAttr::get(elemType));
               func.setArgAttr(argNo, "llvm.align",
                               rewriter->getIntegerAttr(
                                   rewriter->getIntegerType(32), align));
             });
       else
-        fixups.emplace_back(
-            FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
-              func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr());
-            });
+        fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
+                            [=](mlir::func::FuncOp func) {
+                              auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
+                                  func.getFunctionType().getInput(argNo));
+                              func.setArgAttr(argNo, "llvm.sret",
+                                              mlir::TypeAttr::get(elemType));
+                            });
       newInTys.push_back(argTy);
       return;
     } else {
@@ -833,7 +839,10 @@ class TargetRewrite : public fir::TargetRewriteBase<TargetRewrite> {
           fixups.emplace_back(
               FixupTy::Codes::ArgumentAsLoad, argNo,
               [=](mlir::func::FuncOp func) {
-                func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr());
+                auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
+                    func.getFunctionType().getInput(argNo));
+                func.setArgAttr(argNo, "llvm.byval",
+                                mlir::TypeAttr::get(elemType));
                 func.setArgAttr(argNo, "llvm.align",
                                 rewriter->getIntegerAttr(
                                     rewriter->getIntegerType(32), align));
@@ -841,8 +850,10 @@ class TargetRewrite : public fir::TargetRewriteBase<TargetRewrite> {
         else
           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
                               [=](mlir::func::FuncOp func) {
+                                auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
+                                    func.getFunctionType().getInput(argNo));
                                 func.setArgAttr(argNo, "llvm.byval",
-                                                rewriter->getUnitAttr());
+                                                mlir::TypeAttr::get(elemType));
                               });
       } else {
         if (auto align = attr.getAlignment())

diff  --git a/flang/test/Fir/target-rewrite-arg-position.fir b/flang/test/Fir/target-rewrite-arg-position.fir
index e5359ee5540fb..44d5981fdf82d 100644
--- a/flang/test/Fir/target-rewrite-arg-position.fir
+++ b/flang/test/Fir/target-rewrite-arg-position.fir
@@ -16,7 +16,7 @@ func.func @_QFPf(%arg0: !fir.ref<tuple<!fir.ref<i32>>> {fir.host_assoc}) -> !fir
 }
 
 // CHECK-LABEL: func.func @_QFPf
-// CHECK-SAME:    %{{.*}}: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret}, %arg1: !fir.ref<tuple<!fir.ref<i32>>> {fir.host_assoc, llvm.nest}) {
+// CHECK-SAME:    %{{.*}}: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret = tuple<!fir.real<16>, !fir.real<16>>}, %arg1: !fir.ref<tuple<!fir.ref<i32>>> {fir.host_assoc, llvm.nest}) {
 
 // -----
 

diff  --git a/flang/test/Fir/target-rewrite-boxchar.fir b/flang/test/Fir/target-rewrite-boxchar.fir
index 7ec1c42f9bca0..e1df9e06adc40 100644
--- a/flang/test/Fir/target-rewrite-boxchar.fir
+++ b/flang/test/Fir/target-rewrite-boxchar.fir
@@ -27,10 +27,10 @@ func.func @boxcharparams(%arg0 : !fir.boxchar<1>, %arg1 : !fir.boxchar<1>) -> i6
 // Test that we rewrite the signatures and bodies of functions that return a
 // boxchar.
 // INT32-LABEL: @boxcharsret
-// INT32-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref<!fir.char<1,?>> {llvm.sret}, [[ARG1:%[0-9A-Za-z]+]]: i32, [[ARG2:%[0-9A-Za-z]+]]: !fir.ref<!fir.char<1,?>>, [[ARG3:%[0-9A-Za-z]+]]: i32)
+// INT32-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref<!fir.char<1,?>> {llvm.sret = !fir.char<1,?>}, [[ARG1:%[0-9A-Za-z]+]]: i32, [[ARG2:%[0-9A-Za-z]+]]: !fir.ref<!fir.char<1,?>>, [[ARG3:%[0-9A-Za-z]+]]: i32)
 // INT64-LABEL: @boxcharsret
-// INT64-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref<!fir.char<1,?>> {llvm.sret}, [[ARG1:%[0-9A-Za-z]+]]: i64, [[ARG2:%[0-9A-Za-z]+]]: !fir.ref<!fir.char<1,?>>, [[ARG3:%[0-9A-Za-z]+]]: i64)
-func.func @boxcharsret(%arg0 : !fir.boxchar<1> {llvm.sret}, %arg1 : !fir.boxchar<1>) {
+// INT64-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref<!fir.char<1,?>> {llvm.sret = !fir.char<1,?>}, [[ARG1:%[0-9A-Za-z]+]]: i64, [[ARG2:%[0-9A-Za-z]+]]: !fir.ref<!fir.char<1,?>>, [[ARG3:%[0-9A-Za-z]+]]: i64)
+func.func @boxcharsret(%arg0 : !fir.boxchar<1> {llvm.sret = !fir.char<1,?>}, %arg1 : !fir.boxchar<1>) {
   // INT32-DAG: [[B0:%[0-9]+]] = fir.emboxchar [[ARG0]], [[ARG1]] : (!fir.ref<!fir.char<1,?>>, i32) -> !fir.boxchar<1>
   // INT32-DAG: [[B1:%[0-9]+]] = fir.emboxchar [[ARG2]], [[ARG3]] : (!fir.ref<!fir.char<1,?>>, i32) -> !fir.boxchar<1>
   // INT32-DAG: fir.unboxchar [[B0]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.array<?x!fir.char<1>>>, i64)
@@ -57,10 +57,10 @@ func.func @boxcharsret(%arg0 : !fir.boxchar<1> {llvm.sret}, %arg1 : !fir.boxchar
 // Test that we rewrite the signatures of functions with a sret parameter and
 // several other parameters.
 // INT32-LABEL: @boxcharmultiple
-// INT32-SAME: ({{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>> {llvm.sret}, {{%[0-9A-Za-z]+}}: i32, {{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>>, {{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>>, {{%[0-9A-Za-z]+}}: i32, {{%[0-9A-Za-z]+}}: i32)
+// INT32-SAME: ({{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>> {llvm.sret = !fir.char<1,?>}, {{%[0-9A-Za-z]+}}: i32, {{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>>, {{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>>, {{%[0-9A-Za-z]+}}: i32, {{%[0-9A-Za-z]+}}: i32)
 // INT64-LABEL: @boxcharmultiple
-// INT64-SAME: ({{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>> {llvm.sret}, {{%[0-9A-Za-z]+}}: i64, {{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>>, {{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>>, {{%[0-9A-Za-z]+}}: i64, {{%[0-9A-Za-z]+}}: i64)
-func.func @boxcharmultiple(%arg0 : !fir.boxchar<1> {llvm.sret}, %arg1 : !fir.boxchar<1>, %arg2 : !fir.boxchar<1>) {
+// INT64-SAME: ({{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>> {llvm.sret = !fir.char<1,?>}, {{%[0-9A-Za-z]+}}: i64, {{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>>, {{%[0-9A-Za-z]+}}: !fir.ref<!fir.char<1,?>>, {{%[0-9A-Za-z]+}}: i64, {{%[0-9A-Za-z]+}}: i64)
+func.func @boxcharmultiple(%arg0 : !fir.boxchar<1> {llvm.sret = !fir.char<1,?>}, %arg1 : !fir.boxchar<1>, %arg2 : !fir.boxchar<1>) {
   return
 }
 

diff  --git a/flang/test/Fir/target-rewrite-complex.fir b/flang/test/Fir/target-rewrite-complex.fir
index b83d582178140..121edd698942f 100644
--- a/flang/test/Fir/target-rewrite-complex.fir
+++ b/flang/test/Fir/target-rewrite-complex.fir
@@ -53,7 +53,7 @@ func.func @returncomplex4() -> !fir.complex<4> {
 // Test that we rewrite the signature and body of a function that returns a
 // complex<8>.
 // I32-LABEL:func @returncomplex8
-// I32-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref<tuple<!fir.real<8>, !fir.real<8>>>  {llvm.align = 4 : i32, llvm.sret})
+// I32-SAME: ([[ARG0:%[0-9A-Za-z]+]]: !fir.ref<tuple<!fir.real<8>, !fir.real<8>>>  {llvm.align = 4 : i32, llvm.sret = tuple<!fir.real<8>, !fir.real<8>>})
 // X64-LABEL: func @returncomplex8() -> tuple<!fir.real<8>, !fir.real<8>>
 // AARCH64-LABEL: func @returncomplex8() -> tuple<!fir.real<8>, !fir.real<8>>
 // PPC-LABEL: func @returncomplex8() -> tuple<!fir.real<8>, !fir.real<8>>
@@ -96,7 +96,7 @@ func.func @returncomplex8() -> !fir.complex<8> {
 }
 
 // Test that we rewrite the signature of a function that accepts a complex<4>.
-// I32-LABEL: func private @paramcomplex4(!fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval})
+// I32-LABEL: func private @paramcomplex4(!fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple<!fir.real<4>, !fir.real<4>>})
 // X64-LABEL: func private @paramcomplex4(!fir.vector<2:!fir.real<4>>)
 // AARCH64-LABEL: func private @paramcomplex4(!fir.array<2x!fir.real<4>>)
 // PPC-LABEL: func private @paramcomplex4(!fir.real<4>, !fir.real<4>)
@@ -156,7 +156,7 @@ func.func @callcomplex4() {
 }
 
 // Test that we rewrite the signature of a function that accepts a complex<8>.
-// I32-LABEL: func private @paramcomplex8(!fir.ref<tuple<!fir.real<8>, !fir.real<8>>> {llvm.align = 4 : i32, llvm.byval})
+// I32-LABEL: func private @paramcomplex8(!fir.ref<tuple<!fir.real<8>, !fir.real<8>>> {llvm.align = 4 : i32, llvm.byval = tuple<!fir.real<8>, !fir.real<8>>})
 // X64-LABEL: func private @paramcomplex8(!fir.real<8>, !fir.real<8>)
 // AARCH64-LABEL: func private @paramcomplex8(!fir.array<2x!fir.real<8>>)
 // PPC-LABEL: func private @paramcomplex8(!fir.real<8>, !fir.real<8>)
@@ -212,14 +212,14 @@ func.func @callcomplex8() {
 }
 
 // Test multiple complex<4> parameters and arguments
-// I32-LABEL: func private @calleemultipleparamscomplex4(!fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}, !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}, !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval})
+// I32-LABEL: func private @calleemultipleparamscomplex4(!fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple<!fir.real<4>, !fir.real<4>>}, !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple<!fir.real<4>, !fir.real<4>>}, !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple<!fir.real<4>, !fir.real<4>>})
 // X64-LABEL: func private @calleemultipleparamscomplex4(!fir.vector<2:!fir.real<4>>, !fir.vector<2:!fir.real<4>>, !fir.vector<2:!fir.real<4>>)
 // AARCH64-LABEL: func private @calleemultipleparamscomplex4(!fir.array<2x!fir.real<4>>, !fir.array<2x!fir.real<4>>, !fir.array<2x!fir.real<4>>)
 // PPC-LABEL: func private @calleemultipleparamscomplex4(!fir.real<4>, !fir.real<4>, !fir.real<4>, !fir.real<4>, !fir.real<4>, !fir.real<4>)
 func.func private @calleemultipleparamscomplex4(!fir.complex<4>, !fir.complex<4>, !fir.complex<4>) -> ()
 
 // I32-LABEL: func @multipleparamscomplex4
-// I32-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}, [[Z2:%[0-9A-Za-z]+]]: !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval}, [[Z3:%[0-9A-Za-z]+]]: !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval})
+// I32-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple<!fir.real<4>, !fir.real<4>>}, [[Z2:%[0-9A-Za-z]+]]: !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple<!fir.real<4>, !fir.real<4>>}, [[Z3:%[0-9A-Za-z]+]]: !fir.ref<tuple<!fir.real<4>, !fir.real<4>>> {llvm.align = 4 : i32, llvm.byval = tuple<!fir.real<4>, !fir.real<4>>})
 // X64-LABEL: func @multipleparamscomplex4
 // X64-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.vector<2:!fir.real<4>>, [[Z2:%[0-9A-Za-z]+]]: !fir.vector<2:!fir.real<4>>, [[Z3:%[0-9A-Za-z]+]]: !fir.vector<2:!fir.real<4>>)
 // AARCH64-LABEL: func @multipleparamscomplex4
@@ -329,7 +329,7 @@ func.func @multipleparamscomplex4(%z1 : !fir.complex<4>, %z2 : !fir.complex<4>,
 // and returns MLIR complex<f32>.
 
 // I32-LABEL: func private @mlircomplexf32
-// I32-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.ref<tuple<f32, f32>> {llvm.align = 4 : i32, llvm.byval}, [[Z2:%[0-9A-Za-z]+]]: !fir.ref<tuple<f32, f32>> {llvm.align = 4 : i32, llvm.byval})
+// I32-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.ref<tuple<f32, f32>> {llvm.align = 4 : i32, llvm.byval = tuple<f32, f32>}, [[Z2:%[0-9A-Za-z]+]]: !fir.ref<tuple<f32, f32>> {llvm.align = 4 : i32, llvm.byval = tuple<f32, f32>})
 // I32-SAME: -> i64
 // X64-LABEL: func private @mlircomplexf32
 // X64-SAME: ([[Z1:%[0-9A-Za-z]+]]: !fir.vector<2:f32>, [[Z2:%[0-9A-Za-z]+]]: !fir.vector<2:f32>)

diff  --git a/flang/test/Fir/target-rewrite-complex16.fir b/flang/test/Fir/target-rewrite-complex16.fir
index 9c4f0af7dd77b..00eff6c582908 100644
--- a/flang/test/Fir/target-rewrite-complex16.fir
+++ b/flang/test/Fir/target-rewrite-complex16.fir
@@ -48,7 +48,7 @@ func.func @addrof() {
 }
 
 // CHECK-LABEL:   func.func @returncomplex16(
-// CHECK-SAME:      %[[VAL_0:.*]]: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret}) {
+// CHECK-SAME:      %[[VAL_0:.*]]: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.sret = tuple<!fir.real<16>, !fir.real<16>>}) {
 // CHECK:           %[[VAL_1:.*]] = fir.undefined !fir.complex<16>
 // CHECK:           %[[VAL_2:.*]] = arith.constant 2.000000e+00 : f128
 // CHECK:           %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (f128) -> !fir.real<16>
@@ -61,7 +61,7 @@ func.func @addrof() {
 // CHECK:           fir.store %[[VAL_8]] to %[[VAL_9]] : !fir.ref<!fir.complex<16>>
 // CHECK:           return
 // CHECK:         }
-// CHECK:         func.func private @paramcomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval})
+// CHECK:         func.func private @paramcomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
 
 // CHECK-LABEL:   func.func @callcomplex16() {
 // CHECK:           %[[VAL_0:.*]] = fir.alloca tuple<!fir.real<16>, !fir.real<16>>
@@ -74,10 +74,10 @@ func.func @addrof() {
 // CHECK:           fir.call @paramcomplex16(%[[VAL_4]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
 // CHECK:           return
 // CHECK:         }
-// CHECK:         func.func private @calleemultipleparamscomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval})
+// CHECK:         func.func private @calleemultipleparamscomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
 
 // CHECK-LABEL:   func.func @multipleparamscomplex16(
-// CHECK-SAME:       %[[VAL_0:.*]]: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}, %[[VAL_1:.*]]: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}, %[[VAL_2:.*]]: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval}) {
+// CHECK-SAME:       %[[VAL_0:.*]]: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, %[[VAL_1:.*]]: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, %[[VAL_2:.*]]: !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}) {
 // CHECK:           %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
 // CHECK:           %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<!fir.complex<16>>
 // CHECK:           %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
@@ -98,7 +98,7 @@ func.func @addrof() {
 // CHECK:         }
 
 // CHECK-LABEL:   func.func private @mlircomplexf128(
-// CHECK-SAME:      %[[VAL_0:.*]]: !fir.ref<tuple<f128, f128>> {llvm.align = 16 : i32, llvm.sret}, %[[VAL_1:.*]]: !fir.ref<tuple<f128, f128>> {llvm.align = 16 : i32, llvm.byval},  %[[VAL_2:.*]]: !fir.ref<tuple<f128, f128>> {llvm.align = 16 : i32, llvm.byval}) {
+// CHECK-SAME:      %[[VAL_0:.*]]: !fir.ref<tuple<f128, f128>> {llvm.align = 16 : i32, llvm.sret = tuple<f128, f128>}, %[[VAL_1:.*]]: !fir.ref<tuple<f128, f128>> {llvm.align = 16 : i32, llvm.byval = tuple<f128, f128>},  %[[VAL_2:.*]]: !fir.ref<tuple<f128, f128>> {llvm.align = 16 : i32, llvm.byval = tuple<f128, f128>}) {
 // CHECK:           %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
 // CHECK:           %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<complex<f128>>
 // CHECK:           %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>

diff  --git a/flang/test/Fir/target.fir b/flang/test/Fir/target.fir
index 39798d0a23581..831c75379aa07 100644
--- a/flang/test/Fir/target.fir
+++ b/flang/test/Fir/target.fir
@@ -115,7 +115,7 @@ func.func @char1lensum(%arg0 : !fir.boxchar<1>, %arg1 : !fir.boxchar<1>) -> i64
 // I32-LABEL: define void @char1copy(ptr sret(i8) %0, i32 %1, ptr %2, i32 %3)
 // I64-LABEL: define void @char1copy(ptr sret(i8) %0, i64 %1, ptr %2, i64 %3)
 // PPC-LABEL: define void @char1copy(ptr sret(i8) %0, i64 %1, ptr %2, i64 %3)
-func.func @char1copy(%arg0 : !fir.boxchar<1> {llvm.sret}, %arg1 : !fir.boxchar<1>) {
+func.func @char1copy(%arg0 : !fir.boxchar<1> {llvm.sret = !fir.char<1, ?>}, %arg1 : !fir.boxchar<1>) {
   // I32-DAG: %[[p0:.*]] = insertvalue { ptr, i32 } undef, ptr %2, 0
   // I32-DAG: = insertvalue { ptr, i32 } %[[p0]], i32 %3, 1
   // I32-DAG: %[[p1:.*]] = insertvalue { ptr, i32 } undef, ptr %0, 0

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index bd6d8a1c7ad25..0933ad41f3530 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -45,6 +45,10 @@ def LLVM_Dialect : Dialect {
     static StringRef getLoopOptionsAttrName() { return "options"; }
     static StringRef getAccessGroupsAttrName() { return "access_groups"; }
     static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
+    static StringRef getByValAttrName() { return "llvm.byval"; }
+    static StringRef getByRefAttrName() { return "llvm.byref"; }
+    static StringRef getStructRetAttrName() { return "llvm.sret"; }
+    static StringRef getInAllocaAttrName() { return "llvm.inalloca"; }
 
     /// Verifies if the attribute is a well-formed value for "llvm.struct_attrs"
     static LogicalResult verifyStructAttr(

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 1d7e28291fc23..4b80bcb90cc3e 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -23,10 +23,13 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -34,6 +37,7 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
@@ -311,10 +315,42 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
       SmallVector<Attribute, 4> newArgAttrs(
           llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
       for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+        // Some LLVM IR attributes have a type attached to them. During FuncOp ->
+        // LLVMFuncOp conversion these types may have changed. Account for that
+        // change by converting attributes' types as well.
+        SmallVector<NamedAttribute, 4> convertedAttrs;
+        auto attrsDict = argAttrDicts[i].cast<DictionaryAttr>();
+        convertedAttrs.reserve(attrsDict.size());
+        for (const NamedAttribute &attr : attrsDict) {
+          const auto convert = [&](const NamedAttribute &attr) {
+            return TypeAttr::get(getTypeConverter()->convertType(
+                attr.getValue().cast<TypeAttr>().getValue()));
+          };
+          if (attr.getName().getValue() ==
+              LLVM::LLVMDialect::getByValAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
+          } else if (attr.getName().getValue() ==
+                     LLVM::LLVMDialect::getByRefAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
+          } else if (attr.getName().getValue() ==
+                     LLVM::LLVMDialect::getStructRetAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
+          } else if (attr.getName().getValue() ==
+                     LLVM::LLVMDialect::getInAllocaAttrName()) {
+            convertedAttrs.push_back(rewriter.getNamedAttr(
+                LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
+          } else {
+            convertedAttrs.push_back(attr);
+          }
+        }
         auto mapping = result.getInputMapping(i);
         assert(mapping && "unexpected deletion of function argument");
         for (size_t j = 0; j < mapping->size; ++j)
-          newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
+          newArgAttrs[mapping->inputNo + j] =
+              DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
       }
       attributes.push_back(
           rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 160d9396c5b00..88dc8f7db9d85 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -1195,6 +1195,36 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
       UnknownLoc::get(context), f->getName(), functionType,
       convertLinkageFromLLVM(f->getLinkage()), dsoLocal, cconv);
 
+  for (const auto &arg : llvm::enumerate(functionType.getParams())) {
+    llvm::SmallVector<NamedAttribute, 1> argAttrs;
+    if (auto *type = f->getParamByValType(arg.index())) {
+      auto mlirType = processType(type);
+      argAttrs.push_back(
+          NamedAttribute(b.getStringAttr(LLVMDialect::getByValAttrName()),
+                         TypeAttr::get(mlirType)));
+    }
+    if (auto *type = f->getParamByRefType(arg.index())) {
+      auto mlirType = processType(type);
+      argAttrs.push_back(
+          NamedAttribute(b.getStringAttr(LLVMDialect::getByRefAttrName()),
+                         TypeAttr::get(mlirType)));
+    }
+    if (auto *type = f->getParamStructRetType(arg.index())) {
+      auto mlirType = processType(type);
+      argAttrs.push_back(
+          NamedAttribute(b.getStringAttr(LLVMDialect::getStructRetAttrName()),
+                         TypeAttr::get(mlirType)));
+    }
+    if (auto *type = f->getParamInAllocaType(arg.index())) {
+      auto mlirType = processType(type);
+      argAttrs.push_back(
+          NamedAttribute(b.getStringAttr(LLVMDialect::getInAllocaAttrName()),
+                         TypeAttr::get(mlirType)));
+    }
+
+    fop.setArgAttrs(arg.index(), argAttrs);
+  }
+
   if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f))
     fop->setAttr(b.getStringAttr("personality"), personality);
   else if (f->hasPersonalityFn())

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 0464cfc825c1f..07054a9cacb33 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -840,23 +840,57 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
                            .addAlignmentAttr(llvm::Align(attr.getInt())));
     }
 
-    if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.sret")) {
+    if (auto attr = func.getArgAttrOfType<TypeAttr>(
+            argIdx, LLVMDialect::getStructRetAttrName())) {
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMPointerType>();
       if (!argTy)
         return func.emitError(
             "llvm.sret attribute attached to LLVM non-pointer argument");
-      llvmArg.addAttrs(
-          llvm::AttrBuilder(llvmArg.getContext())
-              .addStructRetAttr(convertType(argTy.getElementType())));
+      if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
+        return func.emitError("llvm.sret attribute attached to LLVM pointer "
+                              "argument of a different type");
+      llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
+                           .addStructRetAttr(convertType(attr.getValue())));
     }
 
-    if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.byval")) {
+    if (auto attr = func.getArgAttrOfType<TypeAttr>(
+            argIdx, LLVMDialect::getByValAttrName())) {
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMPointerType>();
       if (!argTy)
         return func.emitError(
             "llvm.byval attribute attached to LLVM non-pointer argument");
+      if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
+        return func.emitError("llvm.byval attribute attached to LLVM pointer "
+                              "argument of a different type");
+      llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
+                           .addByValAttr(convertType(attr.getValue())));
+    }
+
+    if (auto attr = func.getArgAttrOfType<TypeAttr>(
+            argIdx, LLVMDialect::getByRefAttrName())) {
+      auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMPointerType>();
+      if (!argTy)
+        return func.emitError(
+            "llvm.byref attribute attached to LLVM non-pointer argument");
+      if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
+        return func.emitError("llvm.byref attribute attached to LLVM pointer "
+                              "argument of a different type");
+      llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
+                           .addByRefAttr(convertType(attr.getValue())));
+    }
+
+    if (auto attr = func.getArgAttrOfType<TypeAttr>(
+            argIdx, LLVMDialect::getInAllocaAttrName())) {
+      auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMPointerType>();
+      if (!argTy)
+        return func.emitError(
+            "llvm.inalloca attribute attached to LLVM non-pointer argument");
+      if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
+        return func.emitError(
+            "llvm.inalloca attribute attached to LLVM pointer "
+            "argument of a different type");
       llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                           .addByValAttr(convertType(argTy.getElementType())));
+                           .addInAllocaAttr(convertType(attr.getValue())));
     }
 
     if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.nest")) {

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 9c2b973a67aad..da46908782ace 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -93,9 +93,9 @@ module {
     llvm.return
   }
 
-  // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret})
-  // LOCINFO: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret} loc("some_source_loc"))
-  llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret} loc("some_source_loc")) {
+  // CHECK: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret = i32})
+  // LOCINFO: llvm.func @sretattr(%{{.*}}: !llvm.ptr<i32> {llvm.sret = i32} loc("some_source_loc"))
+  llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret = i32} loc("some_source_loc")) {
     llvm.return
   }
 

diff  --git a/mlir/test/Target/LLVMIR/Import/func-attrs.ll b/mlir/test/Target/LLVMIR/Import/func-attrs.ll
new file mode 100644
index 0000000000000..4008f12780d5c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/func-attrs.ll
@@ -0,0 +1,6 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK: llvm.func @foo(%arg0: !llvm.ptr {llvm.byval = i64}, %arg1: !llvm.ptr {llvm.byref = i64}, %arg2: !llvm.ptr {llvm.sret = i64}, %arg3: !llvm.ptr {llvm.inalloca = i64})
+define void @foo(ptr byval(i64) %arg0, ptr byref(i64) %arg1, ptr sret(i64) %arg2, ptr inalloca(i64) %arg3) {
+  ret void
+}

diff  --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index ba23c8700c48d..e9dde1b37d05d 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -15,12 +15,19 @@ llvm.func @invalid_noalias(%arg0 : f32 {llvm.noalias}) -> f32 {
 // -----
 
 // expected-error @+1 {{llvm.sret attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_sret(%arg0 : f32 {llvm.sret}) -> f32 {
+llvm.func @invalid_sret(%arg0 : f32 {llvm.sret = f32}) -> f32 {
   llvm.return %arg0 : f32
 }
 
 // -----
 
+// expected-error @+1 {{llvm.sret attribute attached to LLVM pointer argument of a different type}}
+llvm.func @invalid_sret(%arg0 : !llvm.ptr<f32> {llvm.sret = i32}) -> !llvm.ptr<f32> {
+  llvm.return %arg0 : !llvm.ptr<f32>
+}
+
+// -----
+
 // expected-error @+1 {{llvm.nest attribute attached to LLVM non-pointer argument}}
 llvm.func @invalid_nest(%arg0 : f32 {llvm.nest}) -> f32 {
   llvm.return %arg0 : f32
@@ -28,12 +35,47 @@ llvm.func @invalid_nest(%arg0 : f32 {llvm.nest}) -> f32 {
 // -----
 
 // expected-error @+1 {{llvm.byval attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_byval(%arg0 : f32 {llvm.byval}) -> f32 {
+llvm.func @invalid_byval(%arg0 : f32 {llvm.byval = f32}) -> f32 {
+  llvm.return %arg0 : f32
+}
+
+// -----
+
+// expected-error @+1 {{llvm.byval attribute attached to LLVM pointer argument of a different type}}
+llvm.func @invalid_sret(%arg0 : !llvm.ptr<f32> {llvm.byval = i32}) -> !llvm.ptr<f32> {
+  llvm.return %arg0 : !llvm.ptr<f32>
+}
+
+// -----
+
+// expected-error @+1 {{llvm.byref attribute attached to LLVM non-pointer argument}}
+llvm.func @invalid_byval(%arg0 : f32 {llvm.byref = f32}) -> f32 {
   llvm.return %arg0 : f32
 }
 
 // -----
 
+// expected-error @+1 {{llvm.byref attribute attached to LLVM pointer argument of a different type}}
+llvm.func @invalid_sret(%arg0 : !llvm.ptr<f32> {llvm.byref = i32}) -> !llvm.ptr<f32> {
+  llvm.return %arg0 : !llvm.ptr<f32>
+}
+
+// -----
+
+// expected-error @+1 {{llvm.inalloca attribute attached to LLVM non-pointer argument}}
+llvm.func @invalid_byval(%arg0 : f32 {llvm.inalloca = f32}) -> f32 {
+  llvm.return %arg0 : f32
+}
+
+// -----
+
+// expected-error @+1 {{llvm.inalloca attribute attached to LLVM pointer argument of a different type}}
+llvm.func @invalid_sret(%arg0 : !llvm.ptr<f32> {llvm.inalloca = i32}) -> !llvm.ptr<f32> {
+  llvm.return %arg0 : !llvm.ptr<f32>
+}
+
+// -----
+
 // expected-error @+1 {{llvm.align attribute attached to LLVM non-pointer argument}}
 llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
   llvm.return %arg0 : f32

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 3b68340cf7629..f7b021133465d 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1050,12 +1050,12 @@ llvm.func @llvm_noalias(%arg0: !llvm.ptr<f32> {llvm.noalias}) {
 }
 
 // CHECK-LABEL: define void @byvalattr(ptr byval(i32) %
-llvm.func @byvalattr(%arg0: !llvm.ptr<i32> {llvm.byval}) {
+llvm.func @byvalattr(%arg0: !llvm.ptr<i32> {llvm.byval = i32}) {
   llvm.return
 }
 
 // CHECK-LABEL: define void @sretattr(ptr sret(i32) %
-llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret}) {
+llvm.func @sretattr(%arg0: !llvm.ptr<i32> {llvm.sret = i32}) {
   llvm.return
 }
 


        


More information about the Mlir-commits mailing list