[flang-commits] [flang] [flang][FIR] handle argument attributes in fir.call (PR #126711)
via flang-commits
flang-commits at lists.llvm.org
Tue Feb 11 12:31:38 PST 2025
https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/126711
>From c83d2ed9d74de1235957afc5fc4c20a23b4f9d4f Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 11 Feb 2025 02:40:51 -0800
Subject: [PATCH 1/2] [flang][FIR] handle argument attributes in fir.call
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 6 ++++-
flang/lib/Optimizer/Dialect/FIROps.cpp | 31 +++++++++++++++----------
flang/test/Fir/convert-to-llvm.fir | 18 ++++++++++++++
flang/test/Fir/fir-ops.fir | 20 ++++++++++++++++
4 files changed, 62 insertions(+), 13 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index cb4eb8303a495..6346ee0d35292 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -589,10 +589,14 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
- rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
+ auto llvmCall = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
call, resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
+ if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
+ llvmCall.setArgAttrsAttr(argAttrs);
+ if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
+ llvmCall.setArgAttrsAttr(resAttrs);
return mlir::success();
}
};
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index fa83aa380e489..7e50622db08c9 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1121,11 +1121,12 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
p.printOptionalAttrDict((*this)->getAttrs(),
{fir::CallOp::getCalleeAttrNameStr(),
- getFastmathAttrName(), getProcedureAttrsAttrName()});
- auto resultTypes{getResultTypes()};
- llvm::SmallVector<mlir::Type> argTypes(
- llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
- p << " : " << mlir::FunctionType::get(getContext(), argTypes, resultTypes);
+ getFastmathAttrName(), getProcedureAttrsAttrName(),
+ getArgAttrsAttrName(), getResAttrsAttrName()});
+ p << " : ";
+ mlir::call_interface_impl::printFunctionSignature(
+ p, getArgs().drop_front(isDirect ? 0 : 1).getTypes(), getArgAttrsAttr(),
+ /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
@@ -1142,7 +1143,6 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
attrs))
return mlir::failure();
- mlir::Type type;
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
return mlir::failure();
@@ -1163,13 +1163,17 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
fmfAttrName, attrs))
return mlir::failure();
- if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
- parser.parseType(type))
+ if (parser.parseOptionalAttrDict(attrs) || parser.parseColon())
return mlir::failure();
-
- auto funcType = mlir::dyn_cast<mlir::FunctionType>(type);
- if (!funcType)
+ llvm::SmallVector<mlir::Type> argTypes;
+ llvm::SmallVector<mlir::Type> resTypes;
+ llvm::SmallVector<mlir::DictionaryAttr> argAttrs;
+ llvm::SmallVector<mlir::DictionaryAttr> resultAttrs;
+ if (mlir::call_interface_impl::parseFunctionSignature(
+ parser, argTypes, argAttrs, resTypes, resultAttrs))
return parser.emitError(parser.getNameLoc(), "expected function type");
+ mlir::FunctionType funcType =
+ mlir::FunctionType::get(parser.getContext(), argTypes, resTypes);
if (isDirect) {
if (parser.resolveOperands(operands, funcType.getInputs(),
parser.getNameLoc(), result.operands))
@@ -1183,8 +1187,11 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
parser.getNameLoc(), result.operands))
return mlir::failure();
}
- result.addTypes(funcType.getResults());
result.attributes = attrs;
+ mlir::call_interface_impl::addArgAndResultAttrs(
+ parser.getBuilder(), result, argAttrs, resultAttrs,
+ getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+ result.addTypes(funcType.getResults());
return mlir::success();
}
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 6d7a4a09918e5..a4c176a9e2ee8 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2853,3 +2853,21 @@ gpu.module @cuda_device_mod {
// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
// CHECK: llvm.call @malloc
// CHECK: lvm.call @free
+
+// -----
+
+func.func private @somefunc(i32, !fir.ref<i64>)
+
+// CHECK-LABEL: @test_call_arg_attrs_direct
+func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
+ // CHECK: llvm.call @somefunc(%{{.*}}, %{{.*}}) : (i32, !llvm.ptr {llvm.byval = i64}) -> ()
+ fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+ return
+}
+
+// CHECK-LABEL: @test_call_arg_attrs_indirect
+func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
+ // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.signext}) -> (i16 {llvm.signext})
+ %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ return %0 : i16
+}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 5a30858511f0c..1bfcb3a9f3dc8 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -913,3 +913,23 @@ func.func @test_is_assumed_size(%arg0: !fir.class<!fir.array<*:none>>, %arg1 : !
// CHECK-SAME: %[[B:.*]]: !fir.box<!fir.array<?xf32>>)
// CHECK: fir.is_assumed_size %[[A]] : (!fir.class<!fir.array<*:none>>) -> i1
// CHECK: fir.is_assumed_size %[[B]] : (!fir.box<!fir.array<?xf32>>) -> i1
+
+func.func private @somefunc(i32, !fir.ref<i64>)
+
+// CHECK-LABEL: @test_call_arg_attrs_direct
+// CHECK-SAME: %[[VAL_0:.*]]: i32,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<i64>) {
+func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
+ // CHECK: fir.call @somefunc(%[[VAL_0]], %[[VAL_1]]) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+ fir.call @somefunc(%arg0, %arg1) : (i32, !fir.ref<i64> {llvm.byval = i64}) -> ()
+ return
+}
+
+// CHECK-LABEL: @test_call_arg_attrs_indirect
+// CHECK-SAME: %[[VAL_0:.*]]: i16,
+// CHECK-SAME: %[[VAL_1:.*]]: (i16) -> i16) -> i16 {
+func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
+ // CHECK: fir.call %[[VAL_1]](%[[VAL_0]]) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ %0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
+ return %0 : i16
+}
>From b014e6dea7672c2b88e1f4c3a5602f1a87ab5525 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 11 Feb 2025 12:29:48 -0800
Subject: [PATCH 2/2] fix typo in codegen: set result attributes with setResAttrsAttr
---
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 2 +-
flang/test/Fir/convert-to-llvm.fir | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 6346ee0d35292..f938d8d377465 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -596,7 +596,7 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
if (mlir::ArrayAttr argAttrs = call.getArgAttrsAttr())
llvmCall.setArgAttrsAttr(argAttrs);
if (mlir::ArrayAttr resAttrs = call.getResAttrsAttr())
- llvmCall.setArgAttrsAttr(resAttrs);
+ llvmCall.setResAttrsAttr(resAttrs);
return mlir::success();
}
};
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index a4c176a9e2ee8..c11cfd5d5faa1 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -2867,7 +2867,7 @@ func.func @test_call_arg_attrs_direct(%arg0: i32, %arg1: !fir.ref<i64>) {
// CHECK-LABEL: @test_call_arg_attrs_indirect
func.func @test_call_arg_attrs_indirect(%arg0: i16, %arg1: (i16)-> i16) -> i16 {
- // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.signext}) -> (i16 {llvm.signext})
+ // CHECK: llvm.call %arg1(%{{.*}}) : !llvm.ptr, (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
%0 = fir.call %arg1(%arg0) : (i16 {llvm.noundef, llvm.signext}) -> (i16 {llvm.signext})
return %0 : i16
}
More information about the flang-commits mailing list