[flang-commits] [flang] [flang] AArch64 support for BIND(C) derived return types (PR #114051)

David Truby via flang-commits flang-commits at lists.llvm.org
Tue Nov 12 10:40:22 PST 2024


https://github.com/DavidTruby updated https://github.com/llvm/llvm-project/pull/114051

>From 717429e5e5dfd8e22787822f23e192258960f050 Mon Sep 17 00:00:00 2001
From: David Truby <david.truby at arm.com>
Date: Tue, 29 Oct 2024 12:55:22 +0000
Subject: [PATCH 1/3] [flang] AArch64 support for BIND(C) derived return types

This patch adds support for BIND(C) derived types as return values
matching the AArch64 Procedure Call Standard for C.

Support for BIND(C) derived types as value parameters will be in a
separate patch.
---
 flang/lib/Optimizer/CodeGen/Target.cpp   |  42 ++++++
 flang/test/Fir/struct-return-aarch64.fir | 156 +++++++++++++++++++++++
 2 files changed, 198 insertions(+)
 create mode 100644 flang/test/Fir/struct-return-aarch64.fir

diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 6c148dffb0e55a..15ffdb74ef51d6 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -825,6 +825,48 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     }
     return marshal;
   }
+
+  static bool isHFA(fir::RecordType ty) {
+    auto types = ty.getTypeList();
+    if (types.empty() || types.size() > 4) {
+      return false;
+    }
+
+    if (!isa_real(types.front().second)) {
+      types.front().second.dump();
+      return false;
+    }
+
+    return llvm::all_equal(llvm::make_second_range(types));
+  }
+
+  CodeGenSpecifics::Marshalling
+  structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+    CodeGenSpecifics::Marshalling marshal;
+
+    if (isHFA(ty)) {
+      auto newTy = fir::SequenceType::get({ty.getNumFields()}, ty.getType(0));
+      marshal.emplace_back(newTy, AT{});
+      return marshal;
+    }
+
+    auto [size, align] =
+        fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
+
+    // return in registers if size <= 16 bytes
+    if (size <= 16) {
+      auto dwordSize = (size + 7) / 8;
+      auto newTy = fir::SequenceType::get(
+          dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
+      marshal.emplace_back(newTy, AT{});
+      return marshal;
+    }
+
+    unsigned short stackAlign = std::max<unsigned short>(align, 8u);
+    marshal.emplace_back(fir::ReferenceType::get(ty),
+                         AT{stackAlign, false, true});
+    return marshal;
+  }
 };
 } // namespace
 
diff --git a/flang/test/Fir/struct-return-aarch64.fir b/flang/test/Fir/struct-return-aarch64.fir
new file mode 100644
index 00000000000000..96f2f9999b3435
--- /dev/null
+++ b/flang/test/Fir/struct-return-aarch64.fir
@@ -0,0 +1,156 @@
+// Test AArch64 ABI rewrite of struct returned by value (BIND(C), VALUE derived types).
+// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
+
+!composite = !fir.type<t1{i:f32,j:i32,k:f32}>
+// CHECK-LABEL: func.func private @test_composite() -> !fir.array<2xi64>
+func.func private @test_composite() -> !composite
+// CHECK-LABEL: func.func @test_call_composite(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>)
+func.func @test_call_composite(%arg0 : !fir.ref<!composite>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_composite() : () -> !fir.array<2xi64>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xi64>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<2xi64>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<2xi64>>) -> !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_composite() : () -> !composite
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+  fir.store %out to %arg0 : !fir.ref<!composite>
+  // CHECK: return
+  return
+}
+
+!hfa_f16 = !fir.type<t2{x:f16, y:f16}>
+// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.array<2xf16>
+func.func private @test_hfa_f16() -> !hfa_f16
+// CHECK-LABEL: func.func @test_call_hfa_f16(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t2{x:f16,y:f16}>>) {
+func.func @test_call_hfa_f16(%arg0 : !fir.ref<!hfa_f16>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.array<2xf16>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xf16>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<2xf16>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<2xf16>>) -> !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f16() : () -> !hfa_f16
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_f16>
+  return
+}
+
+!hfa_f32 = !fir.type<t3{w:f32, x:f32, y:f32, z:f32}>
+// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.array<4xf32>
+func.func private @test_hfa_f32() -> !hfa_f32
+// CHECK-LABEL: func.func @test_call_hfa_f32(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>) {
+func.func @test_call_hfa_f32(%arg0 : !fir.ref<!hfa_f32>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.array<4xf32>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf32>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xf32>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f32() : () -> !hfa_f32
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_f32>
+  return
+}
+
+!hfa_f64 = !fir.type<t4{x:f64, y:f64, z:f64}>
+// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.array<3xf64>
+func.func private @test_hfa_f64() -> !hfa_f64
+// CHECK-LABEL: func.func @test_call_hfa_f64(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>)
+func.func @test_call_hfa_f64(%arg0 : !fir.ref<!hfa_f64>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.array<3xf64>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<3xf64>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<3xf64>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<3xf64>>) -> !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f64() : () -> !hfa_f64
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_f64>
+  return
+}
+
+!hfa_f128 = !fir.type<t5{w:f128, x:f128, y:f128, z:f128}>
+// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.array<4xf128>
+func.func private @test_hfa_f128() -> !hfa_f128
+// CHECK-LABEL: func.func @test_call_hfa_f128(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>) {
+func.func @test_call_hfa_f128(%arg0 : !fir.ref<!hfa_f128>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.array<4xf128>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf128>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xf128>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xf128>>) -> !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_f128() : () -> !hfa_f128
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_f128>
+  return
+}
+
+!hfa_bf16 = !fir.type<t6{w:bf16, x:bf16, y:bf16, z:bf16}>
+// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.array<4xbf16>
+func.func private @test_hfa_bf16() -> !hfa_bf16
+// CHECK-LABEL: func.func @test_call_hfa_bf16(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>) {
+func.func @test_call_hfa_bf16(%arg0 : !fir.ref<!hfa_bf16>) {
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.array<4xbf16>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xbf16>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xbf16>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xbf16>>) -> !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_hfa_bf16() : () -> !hfa_bf16
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+  fir.store %out to %arg0 : !fir.ref<!hfa_bf16>
+  return
+}
+
+!too_big = !fir.type<t7{x:i64, y:i64, z:i64}>
+// CHECK-LABEL: func.func private @test_too_big(!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+// CHECK-SAME:    {llvm.align = 8 : i32, llvm.sret = !fir.type<t7{x:i64,y:i64,z:i64}>})
+func.func private @test_too_big() -> !too_big
+// CHECK-LABEL: func.func @test_call_too_big(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) {
+func.func @test_call_too_big(%arg0 : !fir.ref<!too_big>) {
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARG:.*]] = fir.alloca !fir.type<t7{x:i64,y:i64,z:i64}>
+  // CHECK: fir.call @test_too_big(%[[ARG]]) : (!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) -> ()
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>) -> !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_too_big() : () -> !too_big
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t7{x:i64,y:i64,z:i64}>>
+  fir.store %out to %arg0 : !fir.ref<!too_big>
+  return
+}
+
+
+!too_big_hfa = !fir.type<t8{i:!fir.array<5xf32>}>
+// CHECK-LABEL: func.func private @test_too_big_hfa(!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+// CHECK-SAME:    {llvm.align = 8 : i32, llvm.sret = !fir.type<t8{i:!fir.array<5xf32>}>})
+func.func private @test_too_big_hfa() -> !too_big_hfa
+// CHECK-LABEL: func.func @test_call_too_big_hfa(
+// CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) {
+func.func @test_call_too_big_hfa(%arg0 : !fir.ref<!too_big_hfa>) {
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARG:.*]] = fir.alloca !fir.type<t8{i:!fir.array<5xf32>}>
+  // CHECK: fir.call @test_too_big_hfa(%[[ARG]]) : (!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) -> ()
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARG]] : (!fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>) -> !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  %out = fir.call @test_too_big_hfa() : () -> !too_big_hfa
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t8{i:!fir.array<5xf32>}>>
+  fir.store %out to %arg0 : !fir.ref<!too_big_hfa>
+  return
+}

>From da590f308a86a6485871c2e821244f9d94ff7ad6 Mon Sep 17 00:00:00 2001
From: David Truby <david.truby at arm.com>
Date: Wed, 30 Oct 2024 12:54:44 +0000
Subject: [PATCH 2/3] Fixes for review

---
 flang/lib/Optimizer/CodeGen/Target.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 15ffdb74ef51d6..31a6a6f3aaa130 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -826,6 +826,8 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     return marshal;
   }
 
+  // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
+  // HFA is a record type with up to 4 floating-point members of the same type.
   static bool isHFA(fir::RecordType ty) {
     auto types = ty.getTypeList();
     if (types.empty() || types.size() > 4) {
@@ -833,13 +835,14 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     }
 
     if (!isa_real(types.front().second)) {
-      types.front().second.dump();
       return false;
     }
 
     return llvm::all_equal(llvm::make_second_range(types));
   }
 
+  // AArch64 procedure call ABI:
+  // https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
   CodeGenSpecifics::Marshalling
   structReturnType(mlir::Location loc, fir::RecordType ty) const override {
     CodeGenSpecifics::Marshalling marshal;
@@ -855,7 +858,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
 
     // return in registers if size <= 16 bytes
     if (size <= 16) {
-      auto dwordSize = (size + 7) / 8;
+      std::size_t dwordSize = (size + 7) / 8;
       auto newTy = fir::SequenceType::get(
           dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
       marshal.emplace_back(newTy, AT{});

>From 67e144028c77d0265028df03c26327b5361be1f5 Mon Sep 17 00:00:00 2001
From: David Truby <david.truby at arm.com>
Date: Tue, 12 Nov 2024 18:39:49 +0000
Subject: [PATCH 3/3] Fix isHFA for nested structure types

---
 flang/lib/Optimizer/CodeGen/Target.cpp   |  58 +++++++++--
 flang/test/Fir/struct-return-aarch64.fir | 123 ++++++++++++++++++-----
 2 files changed, 150 insertions(+), 31 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 31a6a6f3aaa130..06b19c0031e434 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -826,19 +826,65 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     return marshal;
   }
 
+  // Flatten a RecordType::TypeList containing more record types or array types
+  static std::optional<std::vector<mlir::Type>>
+  flattenTypeList(const RecordType::TypeList &types) {
+    std::vector<mlir::Type> flatTypes;
+    // The flat list will be at least the same size as the non-flat list.
+    flatTypes.reserve(types.size());
+    for (auto [c, type] : types) {
+      // Flatten record type
+      if (auto recTy = mlir::dyn_cast<RecordType>(type)) {
+        auto subTypeList = flattenTypeList(recTy.getTypeList());
+        if (!subTypeList)
+          return std::nullopt;
+        llvm::copy(*subTypeList, std::back_inserter(flatTypes));
+        continue;
+      }
+
+      // Flatten array type
+      if (auto seqTy = mlir::dyn_cast<SequenceType>(type)) {
+        if (seqTy.hasDynamicExtents())
+          return std::nullopt;
+        std::size_t n = seqTy.getConstantArraySize();
+        auto eleTy = seqTy.getElementType();
+        // Flatten array of record types
+        if (auto recTy = mlir::dyn_cast<RecordType>(eleTy)) {
+          auto subTypeList = flattenTypeList(recTy.getTypeList());
+          if (!subTypeList)
+            return std::nullopt;
+          for (std::size_t i = 0; i < n; ++i)
+            llvm::copy(*subTypeList, std::back_inserter(flatTypes));
+        } else {
+          std::fill_n(std::back_inserter(flatTypes),
+                      seqTy.getConstantArraySize(), eleTy);
+        }
+        continue;
+      }
+
+      // Other types are already flat
+      flatTypes.push_back(type);
+    }
+    return flatTypes;
+  }
+
   // Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
   // HFA is a record type with up to 4 floating-point members of the same type.
   static bool isHFA(fir::RecordType ty) {
-    auto types = ty.getTypeList();
-    if (types.empty() || types.size() > 4) {
+    RecordType::TypeList types = ty.getTypeList();
+    if (types.empty() || types.size() > 4)
+      return false;
+
+    std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
+    if (!flatTypes || flatTypes->size() > 4) {
       return false;
     }
 
-    if (!isa_real(types.front().second)) {
+    if (!isa_real(flatTypes->front())) {
       return false;
     }
 
-    return llvm::all_equal(llvm::make_second_range(types));
+    return llvm::all_equal(*flatTypes);
   }
 
   // AArch64 procedure call ABI:
@@ -848,8 +894,8 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
     CodeGenSpecifics::Marshalling marshal;
 
     if (isHFA(ty)) {
-      auto newTy = fir::SequenceType::get({ty.getNumFields()}, ty.getType(0));
-      marshal.emplace_back(newTy, AT{});
+      // Just return the existing record type
+      marshal.emplace_back(ty, AT{});
       return marshal;
     }
 
diff --git a/flang/test/Fir/struct-return-aarch64.fir b/flang/test/Fir/struct-return-aarch64.fir
index 96f2f9999b3435..8b75c2cac7b6be 100644
--- a/flang/test/Fir/struct-return-aarch64.fir
+++ b/flang/test/Fir/struct-return-aarch64.fir
@@ -22,16 +22,16 @@ func.func @test_call_composite(%arg0 : !fir.ref<!composite>) {
 }
 
 !hfa_f16 = !fir.type<t2{x:f16, y:f16}>
-// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.array<2xf16>
+// CHECK-LABEL: func.func private @test_hfa_f16() -> !fir.type<t2{x:f16,y:f16}>
 func.func private @test_hfa_f16() -> !hfa_f16
 // CHECK-LABEL: func.func @test_call_hfa_f16(
 // CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t2{x:f16,y:f16}>>) {
 func.func @test_call_hfa_f16(%arg0 : !fir.ref<!hfa_f16>) {
-  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.array<2xf16>
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f16() : () -> !fir.type<t2{x:f16,y:f16}>
   // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
-  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<2xf16>
-  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<2xf16>>
-  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<2xf16>>) -> !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t2{x:f16,y:f16}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t2{x:f16,y:f16}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t2{x:f16,y:f16}>>) -> !fir.ref<!fir.type<t2{x:f16,y:f16}>>
   // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t2{x:f16,y:f16}>>
   // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
   %out = fir.call @test_hfa_f16() : () -> !hfa_f16
@@ -41,16 +41,16 @@ func.func @test_call_hfa_f16(%arg0 : !fir.ref<!hfa_f16>) {
 }
 
 !hfa_f32 = !fir.type<t3{w:f32, x:f32, y:f32, z:f32}>
-// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.array<4xf32>
+// CHECK-LABEL: func.func private @test_hfa_f32() -> !fir.type<t3{w:f32,x:f32,y:f32,z:f32}>
 func.func private @test_hfa_f32() -> !hfa_f32
 // CHECK-LABEL: func.func @test_call_hfa_f32(
 // CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>) {
 func.func @test_call_hfa_f32(%arg0 : !fir.ref<!hfa_f32>) {
-  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.array<4xf32>
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f32() : () -> !fir.type<t3{w:f32,x:f32,y:f32,z:f32}>
   // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
-  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf32>
-  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xf32>>
-  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t3{w:f32,x:f32,y:f32,z:f32}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>) -> !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
   // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t3{w:f32,x:f32,y:f32,z:f32}>>
   // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
   %out = fir.call @test_hfa_f32() : () -> !hfa_f32
@@ -60,16 +60,16 @@ func.func @test_call_hfa_f32(%arg0 : !fir.ref<!hfa_f32>) {
 }
 
 !hfa_f64 = !fir.type<t4{x:f64, y:f64, z:f64}>
-// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.array<3xf64>
+// CHECK-LABEL: func.func private @test_hfa_f64() -> !fir.type<t4{x:f64,y:f64,z:f64}>
 func.func private @test_hfa_f64() -> !hfa_f64
 // CHECK-LABEL: func.func @test_call_hfa_f64(
 // CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>)
 func.func @test_call_hfa_f64(%arg0 : !fir.ref<!hfa_f64>) {
-  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.array<3xf64>
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f64() : () -> !fir.type<t4{x:f64,y:f64,z:f64}>
   // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
-  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<3xf64>
-  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<3xf64>>
-  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<3xf64>>) -> !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t4{x:f64,y:f64,z:f64}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>) -> !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
   // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t4{x:f64,y:f64,z:f64}>>
   // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
   %out = fir.call @test_hfa_f64() : () -> !hfa_f64
@@ -79,16 +79,16 @@ func.func @test_call_hfa_f64(%arg0 : !fir.ref<!hfa_f64>) {
 }
 
 !hfa_f128 = !fir.type<t5{w:f128, x:f128, y:f128, z:f128}>
-// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.array<4xf128>
+// CHECK-LABEL: func.func private @test_hfa_f128() -> !fir.type<t5{w:f128,x:f128,y:f128,z:f128}>
 func.func private @test_hfa_f128() -> !hfa_f128
 // CHECK-LABEL: func.func @test_call_hfa_f128(
 // CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>) {
 func.func @test_call_hfa_f128(%arg0 : !fir.ref<!hfa_f128>) {
-  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.array<4xf128>
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_f128() : () -> !fir.type<t5{w:f128,x:f128,y:f128,z:f128}>
   // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
-  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xf128>
-  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xf128>>
-  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xf128>>) -> !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t5{w:f128,x:f128,y:f128,z:f128}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>) -> !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
   // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t5{w:f128,x:f128,y:f128,z:f128}>>
   // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
   %out = fir.call @test_hfa_f128() : () -> !hfa_f128
@@ -98,16 +98,16 @@ func.func @test_call_hfa_f128(%arg0 : !fir.ref<!hfa_f128>) {
 }
 
 !hfa_bf16 = !fir.type<t6{w:bf16, x:bf16, y:bf16, z:bf16}>
-// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.array<4xbf16>
+// CHECK-LABEL: func.func private @test_hfa_bf16() -> !fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>
 func.func private @test_hfa_bf16() -> !hfa_bf16
 // CHECK-LABEL: func.func @test_call_hfa_bf16(
 // CHECK-SAME:    %[[ARG0:.*]]: !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>) {
 func.func @test_call_hfa_bf16(%arg0 : !fir.ref<!hfa_bf16>) {
-  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.array<4xbf16>
+  // CHECK: %[[OUT:.*]] = fir.call @test_hfa_bf16() : () -> !fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>
   // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
-  // CHECK: %[[ARR:.*]] = fir.alloca !fir.array<4xbf16>
-  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.array<4xbf16>>
-  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.array<4xbf16>>) -> !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>) -> !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
   // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t6{w:bf16,x:bf16,y:bf16,z:bf16}>>
   // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
   %out = fir.call @test_hfa_bf16() : () -> !hfa_bf16
@@ -154,3 +154,76 @@ func.func @test_call_too_big_hfa(%arg0 : !fir.ref<!too_big_hfa>) {
   fir.store %out to %arg0 : !fir.ref<!too_big_hfa>
   return
 }
+
+!nested_hfa_first = !fir.type<t9{s:!hfa_f16,c:f16}>
+// CHECK-LABEL: func.func private @test_nested_hfa_first() -> !fir.type<t9{s:!fir.type<t2{x:f16,y:f16}>,c:f16}>
+func.func private @test_nested_hfa_first() -> !nested_hfa_first
+// CHECK-LABEL: func.func @test_call_nested_hfa_first(%arg0: !fir.ref<!fir.type<t9{s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>) {
+func.func @test_call_nested_hfa_first(%arg0 : !fir.ref<!nested_hfa_first>) {
+  %out = fir.call @test_nested_hfa_first() : () -> !nested_hfa_first
+  // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_first() : () -> !fir.type<t9{s:!fir.type<t2{x:f16,y:f16}>,c:f16}>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t9{s:!fir.type<t2{x:f16,y:f16}>,c:f16}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t9{s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t9{s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t9{s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  fir.store %out to %arg0 : !fir.ref<!nested_hfa_first>
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t9{s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>
+  return
+}
+
+
+!nested_hfa_middle = !fir.type<t10{a:f16,s:!hfa_f16,c:f16}>
+// CHECK-LABEL: func.func private @test_nested_hfa_middle() -> !fir.type<t10{a:f16,s:!fir.type<t2{x:f16,y:f16}>,c:f16}>
+func.func private @test_nested_hfa_middle() -> !nested_hfa_middle
+// CHECK-LABEL: func.func @test_call_nested_hfa_middle(%arg0: !fir.ref<!fir.type<t10{a:f16,s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>) {
+func.func @test_call_nested_hfa_middle(%arg0 : !fir.ref<!nested_hfa_middle>) {
+  %out = fir.call @test_nested_hfa_middle() : () -> !nested_hfa_middle
+  // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_middle() : () -> !fir.type<t10{a:f16,s:!fir.type<t2{x:f16,y:f16}>,c:f16}>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t10{a:f16,s:!fir.type<t2{x:f16,y:f16}>,c:f16}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t10{a:f16,s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t10{a:f16,s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t10{a:f16,s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  fir.store %out to %arg0 : !fir.ref<!nested_hfa_middle>
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t10{a:f16,s:!fir.type<t2{x:f16,y:f16}>,c:f16}>>
+  return
+}
+
+!nested_hfa_end = !fir.type<t11{a:f16,s:!hfa_f16}>
+// CHECK-LABEL: func.func private @test_nested_hfa_end() -> !fir.type<t11{a:f16,s:!fir.type<t2{x:f16,y:f16}>}>
+func.func private @test_nested_hfa_end() -> !nested_hfa_end
+// CHECK-LABEL: func.func @test_call_nested_hfa_end(%arg0: !fir.ref<!fir.type<t11{a:f16,s:!fir.type<t2{x:f16,y:f16}>}>>) {
+func.func @test_call_nested_hfa_end(%arg0 : !fir.ref<!nested_hfa_end>) {
+  %out = fir.call @test_nested_hfa_end() : () -> !nested_hfa_end
+  // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_end() : () -> !fir.type<t11{a:f16,s:!fir.type<t2{x:f16,y:f16}>}>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t11{a:f16,s:!fir.type<t2{x:f16,y:f16}>}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t11{a:f16,s:!fir.type<t2{x:f16,y:f16}>}>>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t11{a:f16,s:!fir.type<t2{x:f16,y:f16}>}>>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t11{a:f16,s:!fir.type<t2{x:f16,y:f16}>}>>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  fir.store %out to %arg0 : !fir.ref<!nested_hfa_end>
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t11{a:f16,s:!fir.type<t2{x:f16,y:f16}>}>>
+  return
+}
+
+!nested_hfa_array = !fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+// CHECK-LABEL: func.func private @test_nested_hfa_array() -> !fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+func.func private @test_nested_hfa_array() -> !nested_hfa_array
+// CHECK-LABEL: func.func @test_call_nested_hfa_array(%arg0: !fir.ref<!fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+func.func @test_call_nested_hfa_array(%arg0 : !fir.ref<!nested_hfa_array>) {
+  %out = fir.call @test_nested_hfa_array() : () -> !nested_hfa_array
+  // CHECK: %[[OUT:.*]] = fir.call @test_nested_hfa_array() : () -> !fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+  // CHECK: %[[STACK:.*]] = llvm.intr.stacksave : !llvm.ptr
+  // CHECK: %[[ARR:.*]] = fir.alloca !fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+  // CHECK: fir.store %[[OUT]] to %[[ARR]] : !fir.ref<!fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+  // CHECK: %[[CVT:.*]] = fir.convert %[[ARR]] : (!fir.ref<!fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+  // CHECK: %[[LD:.*]] = fir.load %[[CVT]] : !fir.ref<!fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+  // CHECK: llvm.intr.stackrestore %[[STACK]] : !llvm.ptr
+  fir.store %out to %arg0 : !fir.ref<!nested_hfa_array>
+  // CHECK: fir.store %[[LD]] to %[[ARG0]] : !fir.ref<!fir.type<t12{a:!fir.array<2xf32>,b:f32}>
+  return
+}



More information about the flang-commits mailing list