[flang-commits] [flang] [flang][FIR] handle argument attributes in fir.call (PR #126711)

via flang-commits flang-commits at lists.llvm.org
Tue Feb 11 02:53:55 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-codegen

Author: None (jeanPerier)

<details>
<summary>Changes</summary>

Add pretty printer/parser for fir.call argument/result attributes and propagate them to llvm.call.

This will allow implementing the TODO about ABI-relevant argument attributes in indirect calls [here](https://github.com/llvm/llvm-project/blob/67f59a642f96cdb0b73b991b96b5ca05b988e50d/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp#L545).

---
Full diff: https://github.com/llvm/llvm-project/pull/126711.diff


4 Files Affected:

- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+5-1) 
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+19-12) 
- (modified) flang/test/Fir/convert-to-llvm.fir (+18) 
- (modified) flang/test/Fir/fir-ops.fir (+20) 


``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a4959e..6346ee0d35292c6 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,14 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
     // Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
     mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
         attrConvert(call);
-    rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
+    auto llvmCall = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
         call, resultTys, adaptor.getOperands(),
         addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
                              adaptor.getOperands().size()));
+    if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
+      llvmCall.setArgAttrsAttr(argAttrs);
+    if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
+      llvmCall.setResAttrsAttr(resAttrs);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index fa83aa380e489c3..7e50622db08c98b 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1121,11 +1121,12 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
 
   p.printOptionalAttrDict((*this)->getAttrs(),
                           {fir::CallOp::getCalleeAttrNameStr(),
-                           getFastmathAttrName(), getProcedureAttrsAttrName()});
-  auto resultTypes{getResultTypes()};
-  llvm::SmallVector<mlir::Type> argTypes(
-      llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
-  p << " : " << mlir::FunctionType::get(getContext(), argTypes, resultTypes);
+                           getFastmathAttrName(), getProcedureAttrsAttrName(),
+                           getArgAttrsAttrName(), getResAttrsAttrName()});
+  p << " : ";
+  mlir::call_interface_impl::printFunctionSignature(
+      p, getArgs().drop_front(isDirect ? 0 : 1).getTypes(), getArgAttrsAttr(),
+      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
 }
 
 mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
@@ -1142,7 +1143,6 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                               attrs))
       return mlir::failure();
 
-  mlir::Type type;
   if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
     return mlir::failure();
 
@@ -1163,13 +1163,17 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                                                 fmfAttrName, attrs))
       return mlir::failure();
 
-  if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
-      parser.parseType(type))
+  if (parser.parseOptionalAttrDict(attrs) || parser.parseColon())
     return mlir::failure();
-
-  auto funcType = mlir::dyn_cast<mlir::FunctionType>(type);
-  if (!funcType)
+  llvm::SmallVector<mlir::Type> argTypes;
+  llvm::SmallVector<mlir::Type> resTypes;
+  llvm::SmallVector<mlir::DictionaryAttr> argAttrs;
+  llvm::SmallVector<mlir::DictionaryAttr> resultAttrs;
+  if (mlir::call_interface_impl::parseFunctionSignature(
+          parser, argTypes, argAttrs, resTypes, resultAttrs))
     return parser.emitError(parser.getNameLoc(), "expected function type");
+  mlir::FunctionType funcType =
+      mlir::FunctionType::get(parser.getContext(), argTypes, resTypes);
   if (isDirect) {
     if (parser.resolveOperands(operands, funcType.getInputs(),
                                parser.getNameLoc(), result.operands))
@@ -1183,8 +1187,11 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
                                parser.getNameLoc(), result.operands))
       return mlir::failure();
   }
-  result.addTypes(funcType.getResults());
   result.attributes = attrs;
+  mlir::call_interface_impl::addArgAndResultAttrs(
+      parser.getBuilder(), result, argAttrs, resultAttrs,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+  result.addTypes(funcType.getResults());
   return mlir::success();
 }
 
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 6d7a4a09918e5a2..a4c176a9e2ee86e 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2853,3 +2853,21 @@ gpu.module @cuda_device_mod {
 // CHECK: llvm.func @malloc(i64) -> !llvm.ptr
 // CHECK: llvm.call @malloc
 // CHECK: lvm.call @free
+
+// -----
+
+func.func private @somefunc(i32, !fir.ref<i64>)
+
+// CHECK-LABEL: @test_call_arg_attrs_direct
+func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
+  // CHECK: llvm.call @somefunc(%{{.*}}, %{{.*}}) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+  fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+  return
+}
+
+// CHECK-LABEL: @test_call_arg_attrs_indirect
+func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
+  // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  return %0 : i16
+}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 5a30858511f0c9a..1bfcb3a9f3dc898 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -913,3 +913,23 @@ func.func @test_is_assumed_size(%arg0: !fir.class<!fir.array<*:none>>, %arg1 : !
 // CHECK-SAME: %[[B:.*]]: !fir.box<!fir.array<?xf32>>)
   // CHECK: fir.is_assumed_size %[[A]] : (!fir.class<!fir.array<*:none>>) -> i1
   // CHECK: fir.is_assumed_size %[[B]] : (!fir.box<!fir.array<?xf32>>) -> i1
+
+func.func private @somefunc(i32, !fir.ref<i64>)
+
+// CHECK-LABEL: @test_call_arg_attrs_direct
+// CHECK-SAME:    %[[VAL_0:.*]]: i32,
+// CHECK-SAME:    %[[VAL_1:.*]]: !fir.ref<i64>) {
+func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
+  // CHECK:  fir.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+  fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+  return
+}
+
+// CHECK-LABEL: @test_call_arg_attrs_indirect
+// CHECK-SAME:    %[[VAL_0:.*]]: i16,
+// CHECK-SAME:    %[[VAL_1:.*]]: (i16) -> i16) -> i16 {
+func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
+  // CHECK:  fir.call %[[VAL_1]](%[[VAL_0]]) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+  return %0 : i16
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/126711


More information about the flang-commits mailing list