[flang-commits] [flang] [flang] AArch64 ABI for BIND(C) VALUE parameters (PR #118305)
via flang-commits
flang-commits at lists.llvm.org
Mon Dec 2 06:52:48 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: David Truby (DavidTruby)
<details>
<summary>Changes</summary>
This patch implements AArch64 (AAPCS64) argument passing for derived-type dummy
arguments with the VALUE attribute in BIND(C) procedures.
---
Full diff: https://github.com/llvm/llvm-project/pull/118305.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/CodeGen/Target.cpp (+116-24)
- (added) flang/test/Fir/struct-passing-aarch64-byval.fir (+73)
``````````diff
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index f7bffbf53c190e..0d865ee09535a3 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -826,7 +826,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
return marshal;
}
- // Flatten a RecordType::TypeList containing more record types or array types
+ // Flatten a RecordType::TypeList that may contain nested record types or array types
static std::optional<std::vector<mlir::Type>>
flattenTypeList(const RecordType::TypeList &types) {
std::vector<mlir::Type> flatTypes;
@@ -870,51 +870,143 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
- // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
- // HFA is a record type with up to 4 floating-point members of the same type.
+ // Determine if the type is a Homogeneous Floating-point Aggregate (HFA). An
+ // HFA is a record type with up to 4 floating-point members of the same type.
+ // Returns the number of flattened floating-point members, which is the
+ // number of SIMD/FP registers the HFA occupies, or std::nullopt if the type
+ // is not an HFA.
- static bool isHFA(fir::RecordType ty) {
+ static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
RecordType::TypeList types = ty.getTypeList();
if (types.empty() || types.size() > 4)
- return false;
+ return std::nullopt;
std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
- if (!flatTypes || flatTypes->size() > 4) {
+ // Also reject an empty flattened list so that front() below is always valid.
+ if (!flatTypes || flatTypes->empty() || flatTypes->size() > 4) {
- return false;
+ return std::nullopt;
}
if (!isa_real(flatTypes->front())) {
- return false;
+ return std::nullopt;
+ }
+
+ // All members must have the same floating-point type.
+ if (!llvm::all_equal(*flatTypes))
+ return std::nullopt;
+ return static_cast<int>(flatTypes->size());
+ }
+
+ // Argument-register usage of a value: how many registers it needs and which
+ // register bank they are taken from. For records, n == 0 means that no
+ // registers are used and the value is passed on the stack.
+ struct NRegs {
+ int n = 0; // Number of registers needed.
+ bool isSimd = false; // SIMD/FP registers if true, general-purpose otherwise.
+ };
+
+ // Registers needed to pass or return the record `type` in registers. An HFA
+ // uses one SIMD/FP register per member. Any other record of at most 16 bytes
+ // uses one general-purpose register per 8 bytes. Larger records need no
+ // registers (NRegs{}), which means they are passed in memory.
+ NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
+ std::optional<int> hfaMembers = usedRegsForHFA(type);
+ if (hfaMembers.has_value())
+ return NRegs{*hfaMembers, /*isSimd=*/true};
+
+ const auto sizeAndAlign = fir::getTypeSizeAndAlignmentOrCrash(
+ loc, type, getDataLayout(), kindMap);
+ const auto byteSize = sizeAndAlign.first;
+ if (byteSize > 16)
+ return NRegs{}; // Too large for registers: pass on the stack.
+
+ // Round up to a whole number of 8-byte general-purpose registers.
+ return NRegs{static_cast<int>((byteSize + 7) / 8), /*isSimd=*/false};
+ }
+
+ // Number and class of argument registers occupied by an argument that has
+ // already been marshalled to type `type`. hasEnoughRegisters uses this to
+ // account for the registers consumed by earlier arguments.
+ NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
+ return llvm::TypeSwitch<mlir::Type, NRegs>(type)
+ .Case<mlir::IntegerType>([&](auto intTy) {
+ return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
+ })
+ // Index values are 64 bits wide on AArch64 and need one general-purpose register.
+ .Case<mlir::IndexType>([&](auto) { return NRegs{1, false}; })
+ .Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
+ .Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
+ .Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
+ .Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
+ .Case<fir::SequenceType>([&](auto ty) {
+ // Scale by the total element count (the product of all extents), not
+ // just the first extent, so multi-dimensional arrays are fully counted.
+ NRegs nregs = usedRegsForType(loc, ty.getEleTy());
+ for (auto extent : ty.getShape()) {
+ if (extent == fir::SequenceType::getUnknownExtent()) {
+ TODO(loc, "passing array of unknown extent to C by value is not supported");
+ return NRegs{};
+ }
+ nregs.n *= static_cast<int>(extent);
+ }
+ return nregs;
+ })
+ .Case<fir::RecordType>(
+ [&](auto ty) { return usedRegsForRecordType(loc, ty); })
+ .Case<fir::VectorType>([&](auto) {
+ TODO(loc, "passing vector argument to C by value is not supported");
+ return NRegs{};
+ })
+ // Without a default, TypeSwitch "falls off the end" (asserts) on any other
+ // type. Earlier arguments may be references, pointers or descriptors, for
+ // example non-VALUE dummies. Those are passed as one address in a GPR.
+ .Default([&](mlir::Type ty) {
+ if (fir::conformsWithPassByRef(ty))
+ return NRegs{1, false};
+ TODO(loc, "unsupported argument type when counting AArch64 argument registers");
+ return NRegs{};
+ });
+ }
+
+ // Returns true if a record argument of type `type` can still be passed in
+ // registers, after subtracting the registers already used by the arguments
+ // in `previousArguments`. AAPCS64 provides eight general-purpose and eight
+ // SIMD/FP argument registers.
+ bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
+ const Marshalling &previousArguments) const {
+ int availIntRegisters = 8;
+ int availSIMDRegisters = 8;
+
+ // Count the registers already used by previous arguments. The two register
+ // classes are tracked independently, so do not stop early when one class is
+ // exhausted: later arguments may still consume registers of the other class.
+ for (const auto &[argTy, attr] : previousArguments) {
+ if (attr.isByVal())
+ continue; // Previous argument passed on the stack
+
+ NRegs used = usedRegsForType(loc, argTy);
+ if (used.isSimd)
+ availSIMDRegisters -= used.n;
+ else
+ availIntRegisters -= used.n;
}
- return llvm::all_equal(*flatTypes);
+ NRegs needed = usedRegsForRecordType(loc, type);
+
+ if (needed.isSimd)
+ return needed.n <= availSIMDRegisters;
+
+ return needed.n <= availIntRegisters;
+ }
+
+ // Marshal `ty` in memory as a reference. Arguments use a byval reference and
+ // results use an sret reference. The alignment is at least 8 bytes because
+ // the stack is always 8-byte aligned.
+ CodeGenSpecifics::Marshalling
+ passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
+ const unsigned short naturalAlign =
+ fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap)
+ .second;
+ constexpr unsigned short minStackAlign = 8;
+ const unsigned short align = std::max(naturalAlign, minStackAlign);
+ const bool isByVal = !isResult;
+
+ CodeGenSpecifics::Marshalling marshal;
+ marshal.emplace_back(fir::ReferenceType::get(ty),
+ AT{align, /*byval=*/isByVal, /*sret=*/isResult});
+ return marshal;
}
// AArch64 procedure call ABI:
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
+ // Shared lowering of record arguments and results. Records that need no
+ // registers (non-HFAs larger than 16 bytes) are passed in memory via
+ // passOnTheStack: byval for arguments, sret for results. All other records
+ // use the register-based form chosen below.
CodeGenSpecifics::Marshalling
- structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+ structType(mlir::Location loc, fir::RecordType type, bool isResult) const {
+ NRegs nregs = usedRegsForRecordType(loc, type);
+
+ // If the type needs no registers it must need to be passed on the stack
+ if (nregs.n == 0)
+ return passOnTheStack(loc, type, isResult);
+
CodeGenSpecifics::Marshalling marshal;
- if (isHFA(ty)) {
- // Just return the existing record type
- marshal.emplace_back(ty, AT{});
- return marshal;
+ // HFAs keep their record type (one SIMD/FP register per member, as counted
+ // by usedRegsForHFA). Other records are coerced to an array of i64 with one
+ // element per general-purpose register.
+ mlir::Type pcsType;
+ if (nregs.isSimd) {
+ pcsType = type;
+ } else {
+ pcsType = fir::SequenceType::get(
+ nregs.n, mlir::IntegerType::get(type.getContext(), 64));
}
- auto [size, align] =
- fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+ marshal.emplace_back(pcsType, AT{});
+ return marshal;
+ }
- // return in registers if size <= 16 bytes
- if (size <= 16) {
- std::size_t dwordSize = (size + 7) / 8;
- auto newTy = fir::SequenceType::get(
- dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
- marshal.emplace_back(newTy, AT{});
- return marshal;
+ // Marshal a record passed by value. If the registers it would need are no
+ // longer available after the previous arguments, pass it on the stack
+ // (byval). Otherwise lower it like any other record via structType.
+ CodeGenSpecifics::Marshalling
+ structArgumentType(mlir::Location loc, fir::RecordType ty,
+ const Marshalling &previousArguments) const override {
+ if (!hasEnoughRegisters(loc, ty, previousArguments)) {
+ return passOnTheStack(loc, ty, /*isResult=*/false);
}
- unsigned short stackAlign = std::max<unsigned short>(align, 8u);
- marshal.emplace_back(fir::ReferenceType::get(ty),
- AT{stackAlign, false, true});
- return marshal;
+ return structType(loc, ty, /*isResult=*/false);
+ }
+
+ // Marshal a record function result. Records that do not fit in registers are
+ // returned through an sret reference (see passOnTheStack); all others are
+ // returned in registers.
+ CodeGenSpecifics::Marshalling
+ structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+ constexpr bool isResult = true;
+ return structType(loc, ty, isResult);
}
};
} // namespace
diff --git a/flang/test/Fir/struct-passing-aarch64-byval.fir b/flang/test/Fir/struct-passing-aarch64-byval.fir
new file mode 100644
index 00000000000000..27143459dde2f2
--- /dev/null
+++ b/flang/test/Fir/struct-passing-aarch64-byval.fir
@@ -0,0 +1,73 @@
+// Test the AArch64 (AAPCS64) target-rewrite of derived types passed by value (BIND(C) procedures, VALUE dummies).
+// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
+// Non-HFA structs of at most 16 bytes are coerced to one i64 per 8 bytes.
+// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>)
+func.func private @small_i32(!fir.type<small_i32{i:i32,j:i32,k:i32}>)
+// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>)
+func.func private @small_i64(!fir.type<small_i64{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>)
+func.func private @small_mixed(!fir.type<small_mixed{i:i64,j:f32,k:i32}>)
+// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>)
+func.func private @small_non_hfa(!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>)
+// Homogeneous Floating-point Aggregates (up to 4 identical FP members) keep their record type.
+// CHECK-LABEL: func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
+func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
+// CHECK-LABEL: func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
+func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
+// CHECK-LABEL: func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
+func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
+// CHECK-LABEL: func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
+func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
+// CHECK-LABEL: func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+// Several struct arguments that together still fit in the available argument registers.
+// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>)
+func.func private @multi_small_integer(!fir.type<small_i32{i:i32,j:i32,k:i32}>, !fir.type<small_i64{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
+// CHECK-LABEL: func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>, !fir.array<2xi64>, !fir.type<hfa_f32{i:f32,j:f32}>, !fir.array<2xi64>)
+func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>,!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>,!fir.type<hfa_f32{i:f32,j:f32}>,!fir.type<small_i64{i:i64,j:i64}>)
+// Arguments that exactly fill the 8 general-purpose and/or 8 SIMD/FP argument registers.
+// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>)
+func.func private @int_max(!fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+// CHECK-LABEL: func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>)
+func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+ !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+ !fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>)
+
+// One struct too many: the last argument no longer fits in registers and is passed byval on the stack.
+// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.array<2xi64>,
+// CHECK-SAME: !fir.ref<!fir.type<int_max{i:i64,j:i64}>> {{{.*}}, llvm.byval = !fir.type<int_max{i:i64,j:i64}>})
+func.func private @too_many_int(!fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>,
+ !fir.type<int_max{i:i64,j:i64}>)
+// CHECK-LABEL: func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+// CHECK-SAME: !fir.ref<!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>> {{{.*}}, llvm.byval = !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>})
+func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+ !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
+ !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
+// Non-HFA structs larger than 16 bytes are always passed byval on the stack.
+// CHECK-LABEL: func.func private @too_big(!fir.ref<!fir.type<too_big{i:!fir.array<5xi32>}>> {{{.*}}, llvm.byval = !fir.type<too_big{i:!fir.array<5xi32>}>})
+func.func private @too_big(!fir.type<too_big{i:!fir.array<5xi32>}>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/118305
More information about the flang-commits
mailing list