[flang-commits] [flang] [flang][FIR] handle argument attributes in fir.call (PR #126711)

via flang-commits flang-commits at lists.llvm.org
Tue Feb 11 12:31:38 PST 2025


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/126711

From c83d2ed9d74de1235957afc5fc4c20a23b4f9d4f Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 11 Feb 2025 02:40:51 -0800
Subject: [PATCH 1/2] [flang][FIR] handle argument attributes in fir.call

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp |  6 ++++-
 flang/lib/Optimizer/Dialect/FIROps.cpp  | 31 +++++++++++++++----------
 flang/test/Fir/convert-to-llvm.fir      | 18 ++++++++++++++
 flang/test/Fir/fir-ops.fir              | 20 ++++++++++++++++
 4 files changed, 62 insertions(+), 13 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a495..6346ee0d35292 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,14 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
     mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
         attrConvert(call);
-    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
+    auto llvmCall = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
         call, resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
+    if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
+      llvmCall.setArgAttrsAttr(argAttrs);
+    if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
+      llvmCall.setArgAttrsAttr(resAttrs);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index fa83aa380e489..7e50622db08c9 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1121,11 +1121,12 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
 
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {fir::CallOp::getCalleeAttrNameStr(),
-                           getFastmathAttrName(), getProcedureAttrsAttrName()});
-  auto resultTypes{getResultTypes()};
-  llvm::SmallVector<mlir::Type> argTypes(
-      llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
-  p << " : " << mlir::FunctionType::get(getContext(), argTypes, resultTypes);
+                           getFastmathAttrName(), getProcedureAttrsAttrName(),
+                           getArgAttrsAttrName(), getResAttrsAttrName()});
+  p << " : ";
+  mlir::call_interface_impl::printFunctionSignature(
+      p, getArgs().drop_front(isDirect ? 0 : 1).getTypes(), getArgAttrsAttr(),
+      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
 }
 
 mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
@@ -1142,7 +1143,6 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                               attrs))
       return mlir::failure();
 
-  mlir::Type type;
   if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
     return mlir::failure();
 
@@ -1163,13 +1163,17 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                                                 fmfAttrName, attrs))
       return mlir::failure();
 
-  if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
-      parser.parseType(type))
+  if (parser.parseOptionalAttrDict(attrs) || parser.parseColon())
     return mlir::failure();
-
-  auto funcType = mlir::dyn_cast<mlir::FunctionType>(type);
-  if (!funcType)
+  llvm::SmallVector<mlir::Type> argTypes;
+  llvm::SmallVector<mlir::Type> resTypes;
+  llvm::SmallVector<mlir::DictionaryAttr> argAttrs;
+  llvm::SmallVector<mlir::DictionaryAttr> resultAttrs;
+  if (mlir::call_interface_impl::parseFunctionSignature(
+          parser, argTypes, argAttrs, resTypes, resultAttrs))
     return parser.emitError(parser.getNameLoc(), "expected function type");
+  mlir::FunctionType funcType =
+      mlir::FunctionType::get(parser.getContext(), argTypes, resTypes);
   if (isDirect) {
     if (parser.resolveOperands(operands, funcType.getInputs(),
                                parser.getNameLoc(), result.operands))
@@ -1183,8 +1187,11 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                                parser.getNameLoc(), result.operands))
       return mlir::failure();
   }
-  result.addTypes(funcType.getResults());
   result.attributes = attrs;
+  mlir::call_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, argAttrs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+  result.addTypes(funcType.getResults());
   return mlir::success();
 }
 
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 6d7a4a09918e5..a4c176a9e2ee8 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2853,3 +2853,21 @@ gpu.module @cuda_device_mod {
 // CHECK: llvm.func @malloc(i64) -> !llvm.ptr
 // CHECK: llvm.call @malloc
 // CHECK: lvm.call @free
+
+// -----
+
+func.func private @somefunc(i32, !fir.ref<i64>)
+
+// CHECK-LABEL: @test_call_arg_attrs_direct
+func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
+  // CHECK: llvm.call @somefunc(%{{.*}}, %{{.*}}) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+  return
+}
+
+// CHECK-LABEL: @test_call_arg_attrs_indirect
+func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
+  // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.signext}) -> (i16 {llvm.signext})
+  %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  return %0 : i16
+}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 5a30858511f0c..1bfcb3a9f3dc8 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -913,3 +913,23 @@ func.func @test_is_assumed_size(%arg0: !fir.class<!fir.array<*:none>>, %arg1 : !
 // CHECK-SAME: %[[B:.*]]: !fir.box<!fir.array<?xf32>>)
   // CHECK: fir.is_assumed_size %[[A]] : (!fir.class<!fir.array<*:none>>) -> i1
   // CHECK: fir.is_assumed_size %[[B]] : (!fir.box<!fir.array<?xf32>>) -> i1
+
+func.func private @somefunc(i32, !fir.ref<i64>)
+
+// CHECK-LABEL: @test_call_arg_attrs_direct
+// CHECK-SAME:    %[[VAL_0:.*]]: i32,
+// CHECK-SAME:    %[[VAL_1:.*]]: !fir.ref<i64>) {
+func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
+  // CHECK:  fir.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+  fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+  return
+}
+
+// CHECK-LABEL: @test_call_arg_attrs_indirect
+// CHECK-SAME:    %[[VAL_0:.*]]: i16,
+// CHECK-SAME:    %[[VAL_1:.*]]: (i16) -> i16) -> i16 {
+func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
+  // CHECK:  fir.call %[[VAL_1]](%[[VAL_0]]) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  return %0 : i16
+}

From b014e6dea7672c2b88e1f4c3a5602f1a87ab5525 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 11 Feb 2025 12:29:48 -0800
Subject: [PATCH 2/2] fix typo in codegen: set result attributes with setResAttrsAttr

---
 flang/lib/Optimizer/CodeGen/CodeGen.cpp | 2 +-
 flang/test/Fir/convert-to-llvm.fir      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 6346ee0d35292..f938d8d377465 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -596,7 +596,7 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
       llvmCall.setArgAttrsAttr(argAttrs);
     if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
-      llvmCall.setArgAttrsAttr(resAttrs);
+      llvmCall.setResAttrsAttr(resAttrs);
     return mlir::success();
   }
 };
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index a4c176a9e2ee8..c11cfd5d5faa1 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2867,7 +2867,7 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
 
 // CHECK-LABEL: @test_call_arg_attrs_indirect
 func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
-  // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.signext}) -> (i16 {llvm.signext})
+  // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
   %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
   return %0 : i16
 }



More information about the flang-commits mailing list