[flang-commits] [flang] [flang] RISCV64 (lp64/lp64d) support for BIND(C) derived types (PR #198335)

Philipp Rados via flang-commits flang-commits at lists.llvm.org
Wed May 27 06:11:35 PDT 2026


https://github.com/prados-oc updated https://github.com/llvm/llvm-project/pull/198335

>From 72fc97805fc3d940ff6f8a16f23787b690556ef3 Mon Sep 17 00:00:00 2001
From: Philipp Rados <philipp.rados at openchip.com>
Date: Mon, 18 May 2026 16:13:24 +0200
Subject: [PATCH 1/4] [flang] RISCV64 support for BIND(C) derived types

---
 flang/lib/Optimizer/CodeGen/Target.cpp        | 303 ++++++++++++++++++
 .../test/Fir/struct-passing-riscv64-byval.fir |  90 ++++++
 flang/test/Fir/struct-return-riscv64.fir      | 173 ++++++++++
 3 files changed, 566 insertions(+)
 create mode 100644 flang/test/Fir/struct-passing-riscv64-byval.fir
 create mode 100644 flang/test/Fir/struct-return-riscv64.fir

diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 9b6c9be79120c..c45862ece7573 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -1389,10 +1389,18 @@ struct TargetSparcV9 : public GenericTarget<TargetSparcV9> {
 //===----------------------------------------------------------------------===//
 
 namespace {
+// RISCV64 calling convention specification:
+// https://github.com/riscv-non-isa/riscv-elf-psabi-doc/blob/master/riscv-cc.adoc#procedure-calling-convention
 struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
   using GenericTarget::GenericTarget;
 
   static constexpr int defaultWidth = 64;
+  // `defaultWidth` (64 bits) expressed in bytes.
+  static constexpr int defaultWidthBytes = defaultWidth / 8;
+  // TODO: Can't query ABI from inside TargetRewrite so assume the more common
+  // `lp64d` for now. Alternatively could check float-support from
+  // target-features, but that could be overridden by manually setting
+  // `-mabi=lp64`.
+  // When true, derived types are classified as if floating-point argument and
+  // result registers were available (hard-float ABI); when false, no
+  // floating-point registers are assumed and only the integer calling
+  // convention is used.
+  static constexpr bool hasHardFloatABI = true;
 
   CodeGenSpecifics::Marshalling
   complexArgumentType(mlir::Location loc, mlir::Type eleTy) const override {
@@ -1425,6 +1433,301 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
     }
     return marshal;
   }
+
+  // Reject, via TODO(), every component type whose RISCV64 marshalling for a
+  // BIND(C), VALUE derived type argument or type return is not implemented.
+  // Record and array components are checked recursively.
+  void checkValidTypeOrCrash(mlir::Location loc, mlir::Type type) const {
+    llvm::TypeSwitch<mlir::Type>(type)
+        .Case<mlir::IntegerType>([&](auto integerTy) {
+          // 128 bit int will be passed like any other 128bit struct as 2
+          // registers.
+          if (integerTy.getWidth() > 128)
+            TODO(loc,
+                 "integerType with width exceeding 128 bits is unsupported");
+        })
+        .Case<mlir::FloatType>([&](auto floatTy) {
+          if (floatTy.getWidth() > 64)
+            TODO(loc, "128 bit float is not supported by RISCV64");
+        })
+        .Case([&](mlir::ComplexType cmplx) {
+          // Only complex numbers made of IEEE single or double reals are
+          // handled.
+          const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType());
+          if (sem != &llvm::APFloat::IEEEsingle() &&
+              sem != &llvm::APFloat::IEEEdouble())
+            TODO(loc, "unsupported complex type(not IEEEsingle, IEEEdouble) "
+                      "as a structure component for BIND(C), "
+                      "VALUE derived type argument and type return");
+        })
+        .Case<fir::LogicalType, fir::CharacterType>([&](auto ty) {
+          // Always fine, characters with len>1 get already rejected before.
+        })
+        .Case([&](fir::RecordType recTy) {
+          // Bind by const reference: each TypeList entry carries the
+          // component name as a std::string that would otherwise be copied
+          // on every iteration.
+          for (const auto &[name, ty] : recTy.getTypeList())
+            checkValidTypeOrCrash(loc, ty);
+        })
+        .Case([&](fir::SequenceType seqTy) {
+          if (seqTy.hasDynamicExtents())
+            TODO(loc, "passing dynamic sequence argument to C by value is not "
+                      "supported");
+          checkValidTypeOrCrash(loc, seqTy.getElementType());
+        })
+        .Case([&](fir::VectorType vecTy) {
+          TODO(loc, "passing vector argument to C by value is not supported");
+        })
+        .Default([&](mlir::Type ty) {
+          if (!fir::conformsWithPassByRef(ty))
+            TODO(loc, "unsupported component type for BIND(C), VALUE derived "
+                      "type argument and type return");
+        });
+  }
+
+  // Marshal an aggregate that cannot be passed or returned in registers.
+  // The RISC-V psABI passes such aggregates by reference (the caller makes a
+  // copy and passes its address; LLVM's RISC-V backend lowers `byval` that
+  // way) and returns them through a caller-allocated `sret` buffer.
+  // `recAlign` is the natural alignment of `ty`.
+  CodeGenSpecifics::Marshalling
+  passOnTheStack(unsigned short recAlign, mlir::Type ty, bool isResult) const {
+    CodeGenSpecifics::Marshalling marshal;
+    // Unlike x86-64, the memory is not an 8-byte aligned argument stack slot
+    // but an object a C peer only aligns naturally. Claiming more alignment
+    // in the `align` attribute would let LLVM emit misaligned accesses, so
+    // use the natural alignment (at least 1).
+    unsigned short align = std::max(recAlign, static_cast<unsigned short>(1));
+    marshal.emplace_back(fir::ReferenceType::get(ty),
+                         AT{align, /*byval=*/!isResult, /*sret=*/isResult});
+    return marshal;
+  }
+
+  // Append the scalar leaves of `type` to `flatTypes`, in memory order, the
+  // way the RISC-V hardware floating-point calling convention sees them:
+  // records and arrays are flattened recursively (so arrays of complex,
+  // arrays of arrays and arrays of records are handled too) and a complex
+  // number counts as two reals of its element type.
+  static void appendFlattenedType(mlir::Type type,
+                                  std::vector<mlir::Type> &flatTypes) {
+    // TYPE(C_PTR) and TYPE(C_FUNPTR) interoperate with C pointers. Pointers
+    // are neither integers nor reals for the hardware floating-point calling
+    // convention, so keep the record as an opaque leaf instead of exposing
+    // its integer address component.
+    if (fir::isa_builtin_cptr_type(type)) {
+      flatTypes.push_back(type);
+      return;
+    }
+
+    // Flatten record type
+    if (auto recTy = mlir::dyn_cast<RecordType>(type)) {
+      for (const auto &[name, memberTy] : recTy.getTypeList())
+        appendFlattenedType(memberTy, flatTypes);
+      return;
+    }
+
+    // Flatten array type: repeat the flattened element once per element.
+    if (auto seqTy = mlir::dyn_cast<SequenceType>(type)) {
+      assert(!seqTy.hasDynamicExtents() &&
+             "dynamic sequences should have been caught before.");
+      std::vector<mlir::Type> elementTypes;
+      appendFlattenedType(seqTy.getElementType(), elementTypes);
+      if (!elementTypes.empty())
+        for (std::size_t i = 0, n = seqTy.getConstantArraySize(); i < n; ++i)
+          llvm::append_range(flatTypes, elementTypes);
+      return;
+    }
+
+    // Complex type is made up of 2 floats
+    if (auto compTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
+      flatTypes.push_back(compTy.getElementType());
+      flatTypes.push_back(compTy.getElementType());
+      return;
+    }
+
+    // Other types are already flat
+    flatTypes.push_back(type);
+  }
+
+  // Flatten a RecordType::TypeList containing more record types or array
+  // types into the list of its scalar leaves (see appendFlattenedType).
+  static std::vector<mlir::Type>
+  flattenTypeList(const RecordType::TypeList &types) {
+    std::vector<mlir::Type> flatTypes;
+    // Usually there are at least as many leaves as components.
+    flatTypes.reserve(types.size());
+    for (const auto &[name, type] : types)
+      appendFlattenedType(type, flatTypes);
+    return flatTypes;
+  }
+
+  // True when `ty` is a real no wider than `defaultWidth` (64) bits, i.e. one
+  // that fits in a single floating-point register.
+  static bool floatAndCanPassInRegister(const mlir::Type &ty) {
+    if (auto realTy = mlir::dyn_cast<mlir::FloatType>(ty))
+      return realTy.getWidth() <= defaultWidth;
+    return false;
+  }
+
+  // True when `ty` is an integer no wider than `defaultWidth` (64) bits, i.e.
+  // one that fits in a single integer register.
+  static bool integerAndCanPassInRegister(const mlir::Type &ty) {
+    if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(ty))
+      return intTy.getWidth() <= defaultWidth;
+    return false;
+  }
+
+  // Subtract from `gprArgs` / `fprArgs` the argument registers consumed by the
+  // already-marshalled `previousArguments`, following the RISC-V integer and
+  // hardware floating-point calling conventions. The counts never go below
+  // zero; arguments that no longer fit in registers go on the stack.
+  void checkAvailableRegisters(mlir::Location loc,
+                               const Marshalling &previousArguments,
+                               int &gprArgs, int &fprArgs) const {
+    auto consumeGPRs = [&gprArgs](int count) {
+      gprArgs = std::max(0, gprArgs - count);
+    };
+
+    for (const auto &[ty, attr] : previousArguments) {
+      if (gprArgs <= 0 && fprArgs <= 0)
+        break;
+
+      llvm::TypeSwitch<mlir::Type>(ty)
+          .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+            // Integers wider than XLEN take a register pair (or the last
+            // register plus the stack).
+            consumeGPRs(intTy.getWidth() > defaultWidth ? 2 : 1);
+          })
+          .Case<mlir::FloatType>([&](mlir::FloatType floatTy) {
+            // A real wider than FLEN, or any real once the floating-point
+            // argument registers are exhausted, is passed according to the
+            // integer calling convention and thus takes GPRs.
+            if (floatTy.getWidth() > defaultWidth)
+              consumeGPRs(2);
+            else if (fprArgs > 0)
+              fprArgs--;
+            else
+              consumeGPRs(1);
+          })
+          .Case<fir::SequenceType>([&](fir::SequenceType seqTy) {
+            auto sizeSeqTy = fir::getTypeSizeAndAlignmentOrCrash(
+                                 loc, seqTy, getDataLayout(), kindMap)
+                                 .first;
+            assert((sizeSeqTy <= 2 * defaultWidthBytes) &&
+                   "arrays can't be passed by value to bind(c) and "
+                   "if array is a record field it was marshalled before");
+            // One register, or two if the array spans more than XLEN bits.
+            consumeGPRs(sizeSeqTy <= defaultWidthBytes ? 1 : 2);
+          })
+          // NOTE: Tuples are only used to marshal result types and so can't
+          // appear in `previousArguments`.
+          .Default([&](mlir::Type) {
+            // Everything else occupies one integer register: addresses
+            // (including `byval` aggregates, which RISC-V passes by reference
+            // as a pointer to a caller-made copy, and `sret` buffers),
+            // LOGICAL and CHARACTER values, indices, procedure pointers...
+            consumeGPRs(1);
+          });
+    }
+  }
+
+  // Append to `marshal` the integer calling convention coercion of an
+  // aggregate of `recordSize` bytes (at most 2 x XLEN): a single integer
+  // rounded up to a power-of-two width when it fits in one register,
+  // otherwise an array of two XLEN integers.
+  CodeGenSpecifics::Marshalling
+  getIntCCArgs(mlir::MLIRContext *context,
+               CodeGenSpecifics::Marshalling marshal,
+               std::uint64_t recordSize) const {
+    assert(recordSize <= defaultWidthBytes * 2);
+
+    // NOTE: Clang doesn't handle split struct case when only a single register
+    // remains. In general it lets the code generator take care of properly
+    // handling excess integer register usage, do the same for Flang.
+    mlir::Type coercedTy;
+    if (recordSize > defaultWidthBytes) {
+      // Needs two registers.
+      coercedTy = fir::SequenceType::get(
+          {2}, mlir::IntegerType::get(context, defaultWidth));
+    } else {
+      // Fits in one register.
+      unsigned bitWidth = llvm::PowerOf2Ceil(recordSize * 8);
+      coercedTy = mlir::IntegerType::get(context, bitWidth);
+    }
+    marshal.emplace_back(coercedTy, AT{});
+    return marshal;
+  }
+
+  // Classify a BIND(C) derived type passed (isResult == false) or returned
+  // (isResult == true) by value following the RISC-V psABI. `gprArgs` and
+  // `fprArgs` are the numbers of integer and floating-point registers the
+  // convention provides; the ones already taken by `previousArguments` are
+  // subtracted before classification.
+  CodeGenSpecifics::Marshalling
+  classifyStruct(mlir::Location loc, fir::RecordType recTy, int gprArgs,
+                 int fprArgs, bool isResult,
+                 const Marshalling &previousArguments) const {
+    checkValidTypeOrCrash(loc, recTy);
+    auto [recordSize, recordAlign] = fir::getTypeSizeAndAlignmentOrCrash(
+        loc, recTy, getDataLayout(), kindMap);
+
+    CodeGenSpecifics::Marshalling marshal;
+    mlir::MLIRContext *context = recTy.getContext();
+
+    // This is odd and some targets reject it. The spec says to ignore it.
+    // IIRC Fortran does not allow empty structs and not all versions of C do.
+    // Try to do something sensible, rather than crashing.
+    if (recordSize == 0)
+      return passOnTheStack(recordAlign, recTy, isResult);
+
+    if (recordSize > 2 * defaultWidthBytes)
+      // This struct must go to the stack because it cannot be passed using only
+      // registers.
+      return passOnTheStack(recordAlign, recTy, isResult);
+
+    std::vector<mlir::Type> flattenedTypes =
+        flattenTypeList(recTy.getTypeList());
+    // C `_Bool` and `char` are integer types for the hardware floating-point
+    // calling convention, so LOGICAL and CHARACTER(1) components must be
+    // classified like integers of the same size (e.g. a REAL(C_FLOAT) and a
+    // LOGICAL(C_BOOL) component go in one FPR and one GPR, as in C).
+    for (mlir::Type &ty : flattenedTypes) {
+      if (mlir::isa<fir::LogicalType, fir::CharacterType>(ty)) {
+        std::uint64_t byteSize = fir::getTypeSizeAndAlignmentOrCrash(
+                                     loc, ty, getDataLayout(), kindMap)
+                                     .first;
+        ty = mlir::IntegerType::get(context, byteSize * 8);
+      }
+    }
+
+    checkAvailableRegisters(loc, previousArguments, gprArgs, fprArgs);
+
+    if (flattenedTypes.size() == 1 &&
+        floatAndCanPassInRegister(flattenedTypes[0])) {
+      // A struct containing just one floating-point real is passed as though it
+      // were a standalone floating-point real.
+      if (fprArgs) {
+        marshal.emplace_back(flattenedTypes[0], AT{});
+        return marshal;
+      }
+      return getIntCCArgs(context, marshal, recordSize);
+    }
+
+    if (flattenedTypes.size() == 2 &&
+        floatAndCanPassInRegister(flattenedTypes[0]) &&
+        floatAndCanPassInRegister(flattenedTypes[1])) {
+      // A struct containing two floating-point reals is passed in two
+      // floating-point registers, if neither real is
+      // more than ABI_FLEN bits wide and at least two floating-point argument
+      // registers are available. (The registers need not be an aligned pair.)
+      // Otherwise, it is passed according to the integer calling convention.
+      if (fprArgs > 1) {
+        if (isResult)
+          // Results have to be passed in single return type so use tuples.
+          marshal.emplace_back(
+              mlir::TupleType::get(context, mlir::TypeRange{flattenedTypes[0],
+                                                            flattenedTypes[1]}),
+              AT{});
+        else {
+          // Clang flattens this as two floats, so do the same.
+          marshal.emplace_back(flattenedTypes[0], AT{});
+          marshal.emplace_back(flattenedTypes[1], AT{});
+        }
+        return marshal;
+      }
+      return getIntCCArgs(context, marshal, recordSize);
+    }
+
+    if (flattenedTypes.size() == 2 &&
+        ((floatAndCanPassInRegister(flattenedTypes[0]) &&
+          integerAndCanPassInRegister(flattenedTypes[1])) ||
+         (integerAndCanPassInRegister(flattenedTypes[0]) &&
+          floatAndCanPassInRegister(flattenedTypes[1])))) {
+      // A struct containing one floating-point real and one integer (or
+      // bitfield), in either order, is passed in a floating-point register and
+      // an integer register, provided the floating-point real is no more than
+      // ABI_FLEN bits wide and the integer is no more than XLEN bits wide, and
+      // at least one floating-point argument register and at least one integer
+      // argument register is available. If the struct is passed in this manner,
+      // and the integer is narrower than XLEN bits, the remaining bits are
+      // unspecified. If the struct is not passed in this manner, then it is
+      // passed according to the integer calling convention.
+      if (gprArgs && fprArgs) {
+        if (isResult)
+          marshal.emplace_back(
+              mlir::TupleType::get(context, mlir::TypeRange{flattenedTypes[0],
+                                                            flattenedTypes[1]}),
+              AT{});
+        else {
+          // Clang flattens this as one float and one integer, so do the same.
+          marshal.emplace_back(flattenedTypes[0], AT{});
+          marshal.emplace_back(flattenedTypes[1], AT{});
+        }
+        return marshal;
+      }
+    }
+
+    return getIntCCArgs(context, marshal, recordSize);
+  }
+
+  // Marshal a BIND(C), VALUE derived type argument. Eight integer argument
+  // registers and, with a hard-float ABI, eight floating-point argument
+  // registers exist; classifyStruct subtracts those already used by
+  // `previousArguments`.
+  CodeGenSpecifics::Marshalling
+  structArgumentType(mlir::Location loc, fir::RecordType recTy,
+                     const Marshalling &previousArguments) const override {
+    return classifyStruct(loc, recTy, /*gprArgs=*/8,
+                          /*fprArgs=*/hasHardFloatABI ? 8 : 0,
+                          /*isResult=*/false, previousArguments);
+  }
+
+  // Marshal a BIND(C) derived type result. Two integer return registers and,
+  // with a hard-float ABI, two floating-point return registers exist; no
+  // argument registers are taken yet, hence the empty previous-argument list.
+  CodeGenSpecifics::Marshalling
+  structReturnType(mlir::Location loc, fir::RecordType recTy) const override {
+    return classifyStruct(loc, recTy, /*gprArgs=*/2,
+                          /*fprArgs=*/hasHardFloatABI ? 2 : 0,
+                          /*isResult=*/true, /*previousArguments=*/{});
+  }
 };
 } // namespace
 
diff --git a/flang/test/Fir/struct-passing-riscv64-byval.fir b/flang/test/Fir/struct-passing-riscv64-byval.fir
new file mode 100644
index 0000000000000..d62b02a51dd20
--- /dev/null
+++ b/flang/test/Fir/struct-passing-riscv64-byval.fir
@@ -0,0 +1,90 @@
+// Test RISCV64 lp64/lp64d ABI rewrite of struct passed by value (BIND(C), VALUE derived types).
+
+// TODO: There is currently no way to query the ABI kind (`lp64` vs. `lp64d`) from TargetRewritePass, so only check
+// `lp64d` (current default) for now. Once this changes in the future can enable `lp64` tests too.
+// Structs without real components are classified the same way under both ABIs, so their checks use the common
+// CHECK prefix and also run with `lp64d`.
+// SKIP: fir-opt --target-rewrite="target=riscv64-unknown-linux-gnu,target-abi=lp64" %s | FileCheck %s --check-prefixes=CHECK,CHECK-INT
+
+// RUN: fir-opt --target-rewrite="target=riscv64-unknown-linux-gnu" %s | FileCheck %s --check-prefixes=CHECK,CHECK-FLOAT
+
+module attributes {llvm.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", llvm.target_triple = "riscv64-unknown-linux-gnu"} {
+
+// ================================================
+
+// CHECK-LABEL: func.func private @single_i64(i64)
+// CHECK-LABEL: func.func private @single_i32i32(i64)
+// CHECK-LABEL: func.func private @single_i8i8i8(i32)
+// CHECK-LABEL: func.func private @single_i8i8i32(i64)
+func.func private @single_i64(!fir.type<single_i64{i:i64}>)
+func.func private @single_i32i32(!fir.type<single_i32i32{i:i32,j:i32}>)
+func.func private @single_i8i8i8(!fir.type<single_i8i8i8{i:i8,j:i8,k:i8}>)
+func.func private @single_i8i8i32(!fir.type<single_i8i8i32{i:i8,j:i8,k:i32}>)
+
+// ================================================
+
+// CHECK-LABEL: func.func private @double_i8i64(!fir.array<2xi64>)
+// CHECK-LABEL: func.func private @double_int128(!fir.array<2xi64>)
+// CHECK-INT-LABEL: func.func private @double_complex(!fir.array<2xi64>)
+// CHECK-FLOAT-LABEL: func.func private @double_complex(f64, f64)
+func.func private @double_i8i64(!fir.type<double_i8i64{i:i8,j:i64}>)
+func.func private @double_int128(!fir.type<double_int128{i:i128}>)
+func.func private @double_complex(!fir.type<double_complex{i:complex<f64>}>)
+
+// ================================================
+
+// CHECK-LABEL: func.func private @large_struct(
+// CHECK-SAME: !fir.ref<!fir.type<large_struct{i:i64,j:i64,k:i8}>> {{{.*}}, llvm.byval = !fir.type<large_struct{i:i64,j:i64,k:i8}>},
+// CHECK-SAME: i64)
+// CHECK-LABEL: func.func private @large_bad_alignment(
+// CHECK-SAME: !fir.ref<!fir.type<large_bad_alignment{i:i8,j:i64,k:i8}>> {{{.*}}, llvm.byval = !fir.type<large_bad_alignment{i:i8,j:i64,k:i8}>})
+func.func private @large_struct(!fir.type<large_struct{i:i64,j:i64,k:i8}>, i64)
+func.func private @large_bad_alignment(!fir.type<large_bad_alignment{i:i8,j:i64,k:i8}>)
+
+// ================================================
+
+// CHECK-INT-LABEL:   func.func private @single_float(i32)
+// CHECK-INT-LABEL:   func.func private @float_int(i64)
+// CHECK-FLOAT-LABEL: func.func private @single_float(f32)
+// CHECK-FLOAT-LABEL: func.func private @float_int(f32, i32)
+func.func private @single_float(!fir.type<single_float{f:f32}>)
+func.func private @float_int(!fir.type<float_int{f:f32,i:i32}>)
+
+// ================================================
+
+// CHECK-INT-LABEL:   func.func private @split_struct(i64, i64, i64, i64, i64, i64, i64, !fir.array<2xi64>, i64)
+// CHECK-FLOAT-LABEL: func.func private @split_struct(i64, i64, i64, i64, i64, i64, i64, !fir.array<2xi64>, i64)
+func.func private @split_struct(i64, i64, i64, i64, i64, i64, i64, !fir.type<split_struct{i:i8,j:i64}>, i64)
+
+// CHECK-INT-LABEL:   func.func private @split_struct_float(f64, f64, f64, f64, f64, f64, f64, !fir.array<2xi64>)
+// CHECK-FLOAT-LABEL: func.func private @split_struct_float(f64, f64, f64, f64, f64, f64, f64, f32, i64)
+func.func private @split_struct_float(f64, f64, f64, f64, f64, f64, f64, !fir.type<split_struct_float{i:f32,j:i64}>)
+
+// CHECK-INT-LABEL:   func.func private @split_struct_float2(f64, f64, f64, f64, f64, f64, f64, !fir.array<2xi64>)
+// CHECK-FLOAT-LABEL: func.func private @split_struct_float2(f64, f64, f64, f64, f64, f64, f64, !fir.array<2xi64>)
+func.func private @split_struct_float2(f64, f64, f64, f64, f64, f64, f64, !fir.type<split_struct_float2{i:f64,j:f32}>)
+
+// CHECK-INT-LABEL:   func.func private @split_struct_float3(f64, f64, f64, f64, f64, f64, f64, !fir.array<2xi64>, !fir.array<2xi64>)
+// CHECK-FLOAT-LABEL: func.func private @split_struct_float3(f64, f64, f64, f64, f64, f64, f64, !fir.array<2xi64>, f64, i32)
+func.func private @split_struct_float3(f64, f64, f64, f64, f64, f64, f64, !fir.type<split_struct_float3{i:f64,j:f32}>, !fir.type<split_struct_mixed{i:f64,j:i32}>)
+
+// CHECK-INT-LABEL:   func.func private @split_struct_float4(f64, f64, f64, f64, f64, f64, f64, !fir.array<2xi64>, !fir.array<2xi64>)
+// CHECK-FLOAT-LABEL: func.func private @split_struct_float4(f64, f64, f64, f64, f64, f64, f64, i64, f32, !fir.array<2xi64>)
+func.func private @split_struct_float4(f64, f64, f64, f64, f64, f64, f64, !fir.type<split_struct_float4{i:i64,j:f32}>, !fir.type<split_struct_mixed{i:f64,j:i32}>)
+
+// CHECK-INT-LABEL:   func.func private @no_remaining_gpr(i64, i64, i64, i64, i64, i64, i64, i64, i64)
+// CHECK-FLOAT-LABEL: func.func private @no_remaining_gpr(i64, i64, i64, i64, i64, i64, i64, i64, i64)
+func.func private @no_remaining_gpr(i64, i64, i64, i64, i64, i64, i64, i64, !fir.type<no_remaining_gpr{i:f32,j:i32}>)
+
+// CHECK-INT-LABEL:   func.func private @no_remaining_gpr_int128(i64, i64, i64, i64, i64, i64, i128, i64)
+// CHECK-FLOAT-LABEL: func.func private @no_remaining_gpr_int128(i64, i64, i64, i64, i64, i64, i128, i64)
+func.func private @no_remaining_gpr_int128(i64, i64, i64, i64, i64, i64, i128, !fir.type<no_remaining_gpr{i:f32,j:i32}>)
+
+// CHECK-INT-LABEL:   func.func private @single_remaining_gpr(i64, i64, i64, i64, i64, i64, i64, i64)
+// CHECK-FLOAT-LABEL: func.func private @single_remaining_gpr(i64, i64, i64, i64, i64, i64, i64, f32, i32)
+func.func private @single_remaining_gpr(i64, i64, i64, i64, i64, i64, i64, !fir.type<single_remaining_gpr{i:f32,j:i32}>)
+
+// CHECK-INT-LABEL:   func.func private @prev_struct_passed(!fir.array<2xi64>, i32)
+// CHECK-FLOAT-LABEL: func.func private @prev_struct_passed(!fir.array<2xi64>, i32)
+func.func private @prev_struct_passed(!fir.type<prev{i:i64,j:i64}>, !fir.type<current{i:i32}>)
+
+}
diff --git a/flang/test/Fir/struct-return-riscv64.fir b/flang/test/Fir/struct-return-riscv64.fir
new file mode 100644
index 0000000000000..1508631934280
--- /dev/null
+++ b/flang/test/Fir/struct-return-riscv64.fir
@@ -0,0 +1,173 @@
+// Test RISCV64 ABI rewrite of struct returned by value (BIND(C), VALUE derived types).
+
+// TODO: There is currently no way to query the ABI kind (`lp64` vs. `lp64d`) from TargetRewritePass, so only check
+// `lp64d` (current hardcoded default) for now. Once this changes in the future can enable `lp64` tests too.
+// Structs without real components are classified the same way under both ABIs, so their checks use the common
+// CHECK prefix and also run with `lp64d`. The call-site checks of the sret cases only run with `lp64`.
+// SKIP: fir-opt --target-rewrite="target=riscv64-unknown-linux-gnu,target-abi=lp64" %s | FileCheck --check-prefixes=CHECK,CHECK-INT %s
+
+// RUN: fir-opt --target-rewrite="target=riscv64-unknown-linux-gnu" %s | FileCheck --check-prefixes=CHECK,CHECK-FLOAT %s
+
+!single_struct = !fir.type<t1{i:i32,j:i16}>
+!double_struct = !fir.type<t2{i:i8,j:i32,k:i16}>
+!too_big = !fir.type<t3{i:i64,j:i64,k:i8}>
+!complex_double = !fir.type<t4{i:complex<f64>}>
+!mixed_float = !fir.type<t5{i:f32,j:i32}>
+
+module attributes {llvm.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", llvm.target_triple = "riscv64-unknown-linux-gnu"} {
+
+func.func private @single_return() -> !single_struct
+func.func @test_call_single_return(%arg0 : !fir.ref<!single_struct>) {
+  %out = fir.call @single_return() : () -> !single_struct
+  fir.store %out to %arg0 : !fir.ref<!single_struct>
+  return
+}
+
+// CHECK-LABEL:   func.func private @single_return() -> i64
+// CHECK-LABEL:   func.func @test_call_single_return(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t1{i:i32,j:i16}>>) {
+// CHECK:          %[[CALL_0:.*]] = fir.call @single_return() : () -> i64
+// CHECK:          %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK:          %[[ALLOCA_0:.*]] = fir.alloca i64
+// CHECK:          fir.store %[[CALL_0]] to %[[ALLOCA_0]] : !fir.ref<i64>
+// CHECK:          %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<i64>) -> !fir.ref<!fir.type<t1{i:i32,j:i16}>>
+// CHECK:          %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t1{i:i32,j:i16}>>
+// CHECK:          llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK:          fir.store %[[LOAD_0]] to %[[ARG0]] : !fir.ref<!fir.type<t1{i:i32,j:i16}>>
+// CHECK:          return
+// CHECK:        }
+
+// ================================================
+
+func.func private @double_return() -> !double_struct
+func.func @test_call_double_return(%arg0 : !fir.ref<!double_struct>) {
+  %out = fir.call @double_return() : () -> !double_struct
+  fir.store %out to %arg0 : !fir.ref<!double_struct>
+  return
+}
+
+// CHECK-LABEL:   func.func private @double_return() -> !fir.array<2xi64>
+// CHECK-LABEL:   func.func @test_call_double_return(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t2{i:i8,j:i32,k:i16}>>) {
+// CHECK:          %[[CALL_0:.*]] = fir.call @double_return() : () -> !fir.array<2xi64>
+// CHECK:          %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK:          %[[ALLOCA_0:.*]] = fir.alloca !fir.array<2xi64>
+// CHECK:          fir.store %[[CALL_0]] to %[[ALLOCA_0]] : !fir.ref<!fir.array<2xi64>>
+// CHECK:          %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<!fir.array<2xi64>>) -> !fir.ref<!fir.type<t2{i:i8,j:i32,k:i16}>>
+// CHECK:          %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t2{i:i8,j:i32,k:i16}>>
+// CHECK:          llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK:          fir.store %[[LOAD_0]] to %[[ARG0]] : !fir.ref<!fir.type<t2{i:i8,j:i32,k:i16}>>
+// CHECK:          return
+
+// ================================================
+
+func.func private @too_big_return() -> !too_big
+func.func @test_call_too_big_return(%arg0 : !fir.ref<!too_big>) {
+  %out = fir.call @too_big_return() : () -> !too_big
+  fir.store %out to %arg0 : !fir.ref<!too_big>
+  return
+}
+
+// CHECK-LABEL:   func.func private @too_big_return(!fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>> {llvm.align = 8 : i32, llvm.sret = !fir.type<t3{i:i64,j:i64,k:i8}>})
+// CHECK-LABEL:   func.func @test_call_too_big_return(
+// CHECK-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>) {
+// CHECK-INT:          %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-INT:          %[[ALLOCA_0:.*]] = fir.alloca !fir.type<t3{i:i64,j:i64,k:i8}>
+// CHECK-INT:          fir.call @too_big_return(%[[ALLOCA_0]]) : (!fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>> {llvm.align = 8 : i32, llvm.sret = !fir.type<t3{i:i64,j:i64,k:i8}>}) -> ()
+// CHECK-INT:          %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>) -> !fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>
+// CHECK-INT:          %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>
+// CHECK-INT:          llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK-INT:          fir.store %[[LOAD_0]] to %[[ARG0]] : !fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>
+// CHECK-INT:          return
+
+
+// ================================================
+
+func.func private @too_big_more_args(i64) -> !too_big
+func.func @test_call_too_big_more_args(%arg0 : i64, %arg1 : !fir.ref<!too_big>) {
+  %out = fir.call @too_big_more_args(%arg0) : (i64) -> !too_big
+  fir.store %out to %arg1 : !fir.ref<!too_big>
+  return
+}
+
+// CHECK-LABEL:   func.func private @too_big_more_args(!fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>> {llvm.align = 8 : i32, llvm.sret = !fir.type<t3{i:i64,j:i64,k:i8}>}, i64)
+// CHECK-LABEL:   func.func @test_call_too_big_more_args(
+// CHECK-SAME:      %[[ARG0:.*]]: i64,
+// CHECK-SAME:      %[[ARG1:.*]]: !fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>) {
+// CHECK-INT:          %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-INT:          %[[ALLOCA_0:.*]] = fir.alloca !fir.type<t3{i:i64,j:i64,k:i8}>
+// CHECK-INT:          fir.call @too_big_more_args(%[[ALLOCA_0]], %[[ARG0]]) : (!fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>> {llvm.align = 8 : i32, llvm.sret = !fir.type<t3{i:i64,j:i64,k:i8}>}, i64) -> ()
+// CHECK-INT:          %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>) -> !fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>
+// CHECK-INT:          %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>
+// CHECK-INT:          llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK-INT:          fir.store %[[LOAD_0]] to %[[ARG1]] : !fir.ref<!fir.type<t3{i:i64,j:i64,k:i8}>>
+// CHECK-INT:          return
+
+
+// ================================================
+
+func.func private @complex_type() -> !complex_double
+func.func @test_complex_type(%arg0 : !fir.ref<!complex_double>) {
+  %out = fir.call @complex_type() : () -> !complex_double
+  fir.store %out to %arg0 : !fir.ref<!complex_double>
+  return
+}
+
+
+// CHECK-FLOAT-LABEL:   func.func @test_complex_type(
+// CHECK-FLOAT-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t4{i:complex<f64>}>>) {
+// CHECK-FLOAT:           %[[CALL_0:.*]] = fir.call @complex_type() : () -> tuple<f64, f64>
+// CHECK-FLOAT:           %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-FLOAT:           %[[ALLOCA_0:.*]] = fir.alloca tuple<f64, f64>
+// CHECK-FLOAT:           fir.store %[[CALL_0]] to %[[ALLOCA_0]] : !fir.ref<tuple<f64, f64>>
+// CHECK-FLOAT:           %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<tuple<f64, f64>>) -> !fir.ref<!fir.type<t4{i:complex<f64>}>>
+// CHECK-FLOAT:           %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t4{i:complex<f64>}>>
+// CHECK-FLOAT:           llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK-FLOAT:           fir.store %[[LOAD_0]] to %[[ARG0]] : !fir.ref<!fir.type<t4{i:complex<f64>}>>
+// CHECK-FLOAT:           return
+
+// CHECK-INT-LABEL:   func.func @test_complex_type(
+// CHECK-INT-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t4{i:complex<f64>}>>) {
+// CHECK-INT:           %[[CALL_0:.*]] = fir.call @complex_type() : () -> !fir.array<2xi64>
+// CHECK-INT:           %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-INT:           %[[ALLOCA_0:.*]] = fir.alloca !fir.array<2xi64>
+// CHECK-INT:           fir.store %[[CALL_0]] to %[[ALLOCA_0]] : !fir.ref<!fir.array<2xi64>>
+// CHECK-INT:           %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<!fir.array<2xi64>>) -> !fir.ref<!fir.type<t4{i:complex<f64>}>>
+// CHECK-INT:           %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t4{i:complex<f64>}>>
+// CHECK-INT:           llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK-INT:           fir.store %[[LOAD_0]] to %[[ARG0]] : !fir.ref<!fir.type<t4{i:complex<f64>}>>
+// CHECK-INT:           return
+
+
+// ================================================
+
+func.func private @mixed_float_type() -> !mixed_float
+func.func @test_mixed_float_type(%arg0 : !fir.ref<!mixed_float>) {
+  %out = fir.call @mixed_float_type() : () -> !mixed_float
+  fir.store %out to %arg0 : !fir.ref<!mixed_float>
+  return
+}
+
+// CHECK-FLOAT-LABEL:   func.func @test_mixed_float_type(
+// CHECK-FLOAT-SAME:     %[[ARG0:.*]]: !fir.ref<!fir.type<t5{i:f32,j:i32}>>) {
+// CHECK-FLOAT:          %[[CALL_0:.*]] = fir.call @mixed_float_type() : () -> tuple<f32, i32>
+// CHECK-FLOAT:          %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-FLOAT:          %[[ALLOCA_0:.*]] = fir.alloca tuple<f32, i32>
+// CHECK-FLOAT:          fir.store %[[CALL_0]] to %[[ALLOCA_0]] : !fir.ref<tuple<f32, i32>>
+// CHECK-FLOAT:          %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<tuple<f32, i32>>) -> !fir.ref<!fir.type<t5{i:f32,j:i32}>>
+// CHECK-FLOAT:          %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t5{i:f32,j:i32}>>
+// CHECK-FLOAT:          llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK-FLOAT:          fir.store %[[LOAD_0]] to %[[ARG0]] : !fir.ref<!fir.type<t5{i:f32,j:i32}>>
+// CHECK-FLOAT:          return
+
+// CHECK-INT-LABEL:   func.func @test_mixed_float_type(
+// CHECK-INT-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t5{i:f32,j:i32}>>) {
+// CHECK-INT:           %[[CALL_0:.*]] = fir.call @mixed_float_type() : () -> i64
+// CHECK-INT:           %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-INT:           %[[ALLOCA_0:.*]] = fir.alloca i64
+// CHECK-INT:           fir.store %[[CALL_0]] to %[[ALLOCA_0]] : !fir.ref<i64>
+// CHECK-INT:           %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<i64>) -> !fir.ref<!fir.type<t5{i:f32,j:i32}>>
+// CHECK-INT:           %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t5{i:f32,j:i32}>>
+// CHECK-INT:           llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK-INT:           fir.store %[[LOAD_0]] to %[[ARG0]] : !fir.ref<!fir.type<t5{i:f32,j:i32}>>
+// CHECK-INT:           return
+
+}

>From d891713c12a4e902ea3a5d85457159e7d123332e Mon Sep 17 00:00:00 2001
From: Philipp Rados <philipp.rados at openchip.com>
Date: Wed, 20 May 2026 16:48:22 +0200
Subject: [PATCH 2/4] Extend flattenTypeList and add error checking

---
 flang/lib/Optimizer/CodeGen/Target.cpp        | 176 +++++++++---------
 .../test/Fir/struct-passing-riscv64-byval.fir |  11 +-
 flang/test/Fir/struct-return-riscv64.fir      |  24 +++
 3 files changed, 122 insertions(+), 89 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index c45862ece7573..c7f24e57e0f87 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -1434,102 +1434,98 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
     return marshal;
   }
 
-  void checkValidTypeOrCrash(mlir::Location loc, mlir::Type type) const {
+  CodeGenSpecifics::Marshalling
+  passOnTheStack(unsigned short recAlign, mlir::Type ty, bool isResult) const {
+    CodeGenSpecifics::Marshalling marshal;
+    // Memory passed on the stack is aligned to at least XLEN (8 bytes).
+    unsigned short align = std::max(recAlign, static_cast<unsigned short>(8));
+    marshal.emplace_back(fir::ReferenceType::get(ty),
+                         AT{align, /*byval=*/!isResult, /*sret=*/isResult});
+    return marshal;
+  }
+
+  // Flatten `type` into its scalar leaf types, which classifyStruct uses to
+  // pick the RISC-V register passing scheme: records and constant-size arrays
+  // are expanded, complex yields its element type twice, logicals and length-1
+  // characters become integers, and pass-by-ref-conforming types become i64.
+  // Unsupported component types are reported through TODO().
+  llvm::SmallVector<mlir::Type>
+  flattenTypeList(mlir::Location loc, const mlir::Type type) const {
+    llvm::SmallVector<mlir::Type> flatTypes;
+
     llvm::TypeSwitch<mlir::Type>(type)
-        .Case<mlir::IntegerType>([&](auto integerTy) {
-          // 128 bit int will be passed like any other 128bit struct as 2
-          // registers.
-          if (integerTy.getWidth() > 128)
+        .Case([&](mlir::IntegerType intTy) {
+          if (intTy.getWidth() <= 128)
+            flatTypes.push_back(intTy);
+          else
             TODO(loc,
                  "integerType with width exceeding 128 bits is unsupported");
         })
-        .Case<mlir::FloatType>([&](auto floatTy) {
-          if (floatTy.getWidth() > 64)
+        .Case([&](mlir::FloatType floatTy) {
+          if (floatTy.getWidth() <= 64)
+            flatTypes.push_back(floatTy);
+          else
             TODO(loc, "128 bit float is not supported by RISCV64");
         })
         .Case([&](mlir::ComplexType cmplx) {
           const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType());
-          if (sem != &llvm::APFloat::IEEEsingle() &&
-              sem != &llvm::APFloat::IEEEdouble())
-            TODO(loc, "unsupported complex type(not IEEEsingle, IEEEdouble) "
+          if (sem == &llvm::APFloat::IEEEsingle() ||
+              sem == &llvm::APFloat::IEEEdouble())
+            std::fill_n(std::back_inserter(flatTypes), 2,
+                        cmplx.getElementType());
+          else
+            TODO(loc, "unsupported complex type(not IEEEsingle, IEEEdouble) "
                       "as a structure component for BIND(C), "
                       "VALUE derived type argument and type return");
         })
-        .Case<fir::LogicalType, fir::CharacterType>([&](auto ty) {
-          // Always fine, characters with len>1 get already rejected before.
+        .Case([&](fir::LogicalType logicalTy) {
+          const unsigned width =
+              kindMap.getLogicalBitsize(logicalTy.getFKind());
+          flatTypes.push_back(mlir::IntegerType::get(type.getContext(), width));
         })
-        .Case([&](fir::RecordType recTy) {
-          for (auto [name, ty] : recTy.getTypeList())
-            checkValidTypeOrCrash(loc, ty);
+        .Case([&](fir::CharacterType charTy) {
+          if (charTy.getLen() == 1)
+            flatTypes.push_back(mlir::IntegerType::get(type.getContext(), 8));
+          else
+            TODO(loc,
+                 "fir.type value arg character components must have length 1");
         })
         .Case([&](fir::SequenceType seqTy) {
-          if (seqTy.hasDynamicExtents())
-            TODO(loc, "passing dynamic sequence argument to C by value is not "
-                      "supported");
-          checkValidTypeOrCrash(loc, seqTy.getElementType());
+          if (!seqTy.hasDynamicExtents()) {
+            const std::uint64_t numOfEle = seqTy.getConstantArraySize();
+            // Flatten the element type even when it is scalar, so array
+            // elements get the same validation (width limits) and
+            // logical->integer lowering as scalar components.
+            llvm::SmallVector<mlir::Type> subTypeList =
+                flattenTypeList(loc, seqTy.getEleTy());
+            for (std::uint64_t i = 0; i < numOfEle; ++i)
+              llvm::copy(subTypeList, std::back_inserter(flatTypes));
+          } else
+            TODO(loc, "unsupported dynamic extent sequence type as a structure "
+                      "component for BIND(C), "
+                      "VALUE derived type argument and type return");
+        })
+        .Case([&](fir::RecordType recTy) {
+          for (auto &component : recTy.getTypeList()) {
+            mlir::Type eleTy = component.second;
+            llvm::SmallVector<mlir::Type> subTypeList =
+                flattenTypeList(loc, eleTy);
+            if (subTypeList.size() != 0)
+              llvm::copy(subTypeList, std::back_inserter(flatTypes));
+          }
         })
         .Case([&](fir::VectorType vecTy) {
           TODO(loc, "passing vector argument to C by value is not supported");
         })
         .Default([&](mlir::Type ty) {
-          if (!fir::conformsWithPassByRef(ty))
+          if (fir::conformsWithPassByRef(ty))
+            flatTypes.push_back(
+                mlir::IntegerType::get(type.getContext(), defaultWidth));
+          else
             TODO(loc, "unsupported component type for BIND(C), VALUE derived "
                       "type argument and type return");
         });
-  }
-
-  CodeGenSpecifics::Marshalling
-  passOnTheStack(unsigned short recAlign, mlir::Type ty, bool isResult) const {
-    CodeGenSpecifics::Marshalling marshal;
-    // The stack is always 8 byte aligned
-    unsigned short align = std::max(recAlign, static_cast<unsigned short>(8));
-    marshal.emplace_back(fir::ReferenceType::get(ty),
-                         AT{align, /*byval=*/!isResult, /*sret=*/isResult});
-    return marshal;
-  }
-
-  // Flatten a RecordType::TypeList containing more record types or array types
-  static std::vector<mlir::Type>
-  flattenTypeList(const RecordType::TypeList &types) {
-    std::vector<mlir::Type> flatTypes;
-    // The flat list will be at least the same size as the non-flat list.
-    flatTypes.reserve(types.size());
-    for (auto [c, type] : types) {
-      // Flatten record type
-      if (auto recTy = mlir::dyn_cast<RecordType>(type)) {
-        auto subTypeList = flattenTypeList(recTy.getTypeList());
-        llvm::copy(subTypeList, std::back_inserter(flatTypes));
-        continue;
-      }
-
-      // Flatten array type
-      if (auto seqTy = mlir::dyn_cast<SequenceType>(type)) {
-        assert(!seqTy.hasDynamicExtents() &&
-               "dynamic sequences should have been caught before.");
-        std::size_t n = seqTy.getConstantArraySize();
-        auto eleTy = seqTy.getElementType();
-        // Flatten array of record types
-        if (auto recTy = mlir::dyn_cast<RecordType>(eleTy)) {
-          auto subTypeList = flattenTypeList(recTy.getTypeList());
-          for (std::size_t i = 0; i < n; ++i)
-            llvm::copy(subTypeList, std::back_inserter(flatTypes));
-        } else {
-          std::fill_n(std::back_inserter(flatTypes),
-                      seqTy.getConstantArraySize(), eleTy);
-        }
-        continue;
-      }
 
-      // Complex type is made up of 2 floats
-      if (auto compTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
-        flatTypes.push_back(compTy.getElementType());
-        flatTypes.push_back(compTy.getElementType());
-        continue;
-      }
-
-      // Other types are already flat
-      flatTypes.push_back(type);
-    }
     return flatTypes;
   }
 
@@ -1550,9 +1546,11 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
       if (gprArgs <= 0 && fprArgs <= 0)
         break;
 
-      // previous argument was passed by value and thus takes no registers.
-      if (attr.isByVal())
+      if (attr.isByVal()) {
+        if (gprArgs)
+          gprArgs--;
         continue;
+      }
 
       llvm::TypeSwitch<mlir::Type>(ty)
           .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
@@ -1588,22 +1586,22 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
           // NOTE: Tuples are only used to marshal result types and so can't
           // appear in `previousArguments`.
           .Default([&](mlir::Type ty) {
-            if (fir::conformsWithPassByRef(ty))
-              if (gprArgs)
-                gprArgs--;
+            if (fir::conformsWithPassByRef(ty) && gprArgs)
+              gprArgs--;
           });
     }
   }
 
   CodeGenSpecifics::Marshalling
   getIntCCArgs(mlir::MLIRContext *context,
-               CodeGenSpecifics::Marshalling marshal,
-               std::uint64_t recordSize) const {
+               CodeGenSpecifics::Marshalling marshal, std::uint64_t recordSize,
+               std::uint64_t recordAlign) const {
     assert(recordSize <= defaultWidthBytes * 2);
 
     // NOTE: Clang doesn't handle split struct case when only a single register
     // remains. In general it lets the code generator take care of properly
     // handling excess integer register usage, do the same for Flang.
+    // For more info see comment in: llvm/lib/Target/RISCV/RISCVCallingConv.cpp
     if (recordSize <= defaultWidthBytes) {
       // Pass this as an integer.
       int width = llvm::PowerOf2Ceil(recordSize * 8);
@@ -1611,6 +1609,13 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
       return marshal;
     }
 
+    // 2*XLEN-aligned records (e.g. i128): pass one i128, backend splits it.
+    if (recordAlign == 2 * defaultWidthBytes) {
+      marshal.emplace_back(mlir::IntegerType::get(context, 2 * defaultWidth),
+                           AT{});
+      return marshal;
+    }
+
     auto intTy = mlir::IntegerType::get(context, defaultWidth);
     marshal.emplace_back(fir::SequenceType::get({2}, intTy), AT{});
     return marshal;
@@ -1620,7 +1625,6 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
   classifyStruct(mlir::Location loc, fir::RecordType recTy, int gprArgs,
                  int fprArgs, bool isResult,
                  const Marshalling &previousArguments) const {
-    checkValidTypeOrCrash(loc, recTy);
     auto [recordSize, recordAlign] = fir::getTypeSizeAndAlignmentOrCrash(
         loc, recTy, getDataLayout(), kindMap);
 
@@ -1638,8 +1642,8 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
       // registers.
       return passOnTheStack(recordAlign, recTy, isResult);
 
-    const std::vector<mlir::Type> &flattenedTypes =
-        flattenTypeList(recTy.getTypeList());
+    const llvm::SmallVector<mlir::Type> &flattenedTypes =
+        flattenTypeList(loc, recTy);
 
     checkAvailableRegisters(loc, previousArguments, gprArgs, fprArgs);
 
@@ -1651,7 +1655,7 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
         marshal.emplace_back(flattenedTypes[0], AT{});
         return marshal;
       }
-      return getIntCCArgs(context, marshal, recordSize);
+      return getIntCCArgs(context, marshal, recordSize, recordAlign);
     }
 
     if (flattenedTypes.size() == 2 &&
@@ -1676,7 +1680,7 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
         }
         return marshal;
       }
-      return getIntCCArgs(context, marshal, recordSize);
+      return getIntCCArgs(context, marshal, recordSize, recordAlign);
     }
 
     if (flattenedTypes.size() == 2 &&
@@ -1708,7 +1712,7 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
       }
     }
 
-    return getIntCCArgs(context, marshal, recordSize);
+    return getIntCCArgs(context, marshal, recordSize, recordAlign);
   }
 
   CodeGenSpecifics::Marshalling
diff --git a/flang/test/Fir/struct-passing-riscv64-byval.fir b/flang/test/Fir/struct-passing-riscv64-byval.fir
index d62b02a51dd20..405dff49ce293 100644
--- a/flang/test/Fir/struct-passing-riscv64-byval.fir
+++ b/flang/test/Fir/struct-passing-riscv64-byval.fir
@@ -14,20 +14,19 @@ module attributes {llvm.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128
 // CHECK-INT-LABEL: func.func private @single_i32i32(i64)
 // CHECK-INT-LABEL: func.func private @single_i8i8i8(i32)
 // CHECK-INT-LABEL: func.func private @single_i8i8i32(i64)
+// CHECK-INT-LABEL: func.func private @single_i128(i128)
 func.func private @single_i64(!fir.type<single_i64{i:i64}>)
 func.func private @single_i32i32(!fir.type<single_i32i32{i:i32,j:i32}>)
 func.func private @single_i8i8i8(!fir.type<single_i8i8i8{i:i8,j:i8,k:i8}>)
 func.func private @single_i8i8i32(!fir.type<single_i8i8i32{i:i8,j:i8,k:i32}>)
+func.func private @single_int128(!fir.type<single_int128{i:i128}>)
 
 // ================================================
 
 // CHECK-INT-LABEL: func.func private @double_i8i64(!fir.array<2xi64>)
-// CHECK-INT-LABEL: func.func private @double_int128(!fir.array<2xi64>)
 // CHECK-INT-LABEL: func.func private @double_complex(!fir.array<2xi64>)
-// CHECK-FLOAT-LABEL: func.func private @double_int128(!fir.array<2xi64>)
 // CHECK-FLOAT-LABEL: func.func private @double_complex(f64, f64)
 func.func private @double_i8i64(!fir.type<double_i8i64{i:i8,j:i64}>)
-func.func private @double_int128(!fir.type<double_int128{i:i128}>)
 func.func private @double_complex(!fir.type<double_complex{i:complex<f64>}>)
 
 // ================================================
@@ -44,10 +43,16 @@ func.func private @large_bad_alignment(!fir.type<large_bad_alignment{i:i8,j:i64,
 
 // CHECK-INT-LABEL:   func.func private @single_float(i32)
 // CHECK-INT-LABEL:   func.func private @float_int(i64)
+// CHECK-INT-LABEL:   func.func private @float_char(i64)
+// CHECK-INT-LABEL:   func.func private @logical_double(!fir.array<2xi64>)
 // CHECK-FLOAT-LABEL: func.func private @single_float(f32)
 // CHECK-FLOAT-LABEL: func.func private @float_int(f32, i32)
+// CHECK-FLOAT-LABEL: func.func private @float_char(f32, i8)
+// CHECK-FLOAT-LABEL: func.func private @logical_double(i8, f64)
 func.func private @single_float(!fir.type<single_float{f:f32}>)
 func.func private @float_int(!fir.type<float_int{f:f32,i:i32}>)
+func.func private @float_char(!fir.type<float_char{f:f32,c:!fir.char<1>}>)
+func.func private @logical_double(!fir.type<logical_double{l:!fir.logical<1>,f:f64}>)
 
 // ================================================
 
diff --git a/flang/test/Fir/struct-return-riscv64.fir b/flang/test/Fir/struct-return-riscv64.fir
index 1508631934280..03401ee7b42f9 100644
--- a/flang/test/Fir/struct-return-riscv64.fir
+++ b/flang/test/Fir/struct-return-riscv64.fir
@@ -11,6 +11,7 @@
 !too_big = !fir.type<t3{i:i64,j:i64,k:i8}>
 !complex_double = !fir.type<t4{i:complex<f64>}>
 !mixed_float = !fir.type<t5{i:f32,j:i32}>
+!large_scalar = !fir.type<t6{i:i128}>
 
 module attributes {llvm.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", llvm.target_triple = "riscv64-unknown-linux-gnu"} {
 
@@ -79,6 +80,29 @@ func.func @test_call_too_big_return(%arg0 : !fir.ref<!too_big>) {
 // CHECK-INT           return
 
 
+// ================================================
+
+func.func private @large_scalar_return() -> !large_scalar
+func.func @test_call_large_scalar_return(%arg0 : !fir.ref<!large_scalar>) {
+  %out = fir.call @large_scalar_return() : () -> !large_scalar
+  fir.store %out to %arg0 : !fir.ref<!large_scalar>
+  return
+}
+
+// CHECK-INT-LABEL:   func.func private @large_scalar_return() -> i128
+// CHECK-INT-LABEL:   func.func @test_call_large_scalar_return(
+// CHECK-INT-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t6{i:i128}>>) {
+// CHECK-INT:           %[[CALL_0:.*]] = fir.call @large_scalar_return() : () -> i128
+// CHECK-INT:           %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-INT:           %[[ALLOCA_0:.*]] = fir.alloca i128
+// CHECK-INT:           fir.store %[[CALL_0]] to %[[ALLOCA_0]] : !fir.ref<i128>
+// CHECK-INT:           %[[CONVERT_0:.*]] = fir.convert %[[ALLOCA_0]] : (!fir.ref<i128>) -> !fir.ref<!fir.type<t6{i:i128}>>
+// CHECK-INT:           %[[LOAD_0:.*]] = fir.load %[[CONVERT_0]] : !fir.ref<!fir.type<t6{i:i128}>>
+// CHECK-INT:           llvm.intr.stackrestore %[[INTR_0]] : !llvm.ptr
+// CHECK-INT:           fir.store %[[LOAD_0]] to %[[ARG0]] : !fir.ref<!fir.type<t6{i:i128}>>
+// CHECK-INT:           return
+// CHECK-INT:         }
+
 // ================================================
 
 func.func private @too_big_more_args(i64) -> !too_big

>From 9415cf0704ada0a39f46bd68f536c06816185082 Mon Sep 17 00:00:00 2001
From: Philipp Rados <philipp.rados at openchip.com>
Date: Wed, 20 May 2026 17:08:04 +0200
Subject: [PATCH 3/4] Move flattenTypeList up to catch errors first

---
 flang/lib/Optimizer/CodeGen/Target.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index c7f24e57e0f87..f0e5cd9903ca2 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -1631,6 +1631,10 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
     CodeGenSpecifics::Marshalling marshal;
     mlir::MLIRContext *context = recTy.getContext();
 
+    // Flatten before any early return so unsupported types are diagnosed.
+    const llvm::SmallVector<mlir::Type> &flattenedTypes =
+        flattenTypeList(loc, recTy);
+
     // This is odd and some targets reject it. The spec says to ignore it.
     // IIRC Fortran does not allow empty structs and not all versions of C do.
     // Try to do something sensible, rather than crashing.
@@ -1642,9 +1646,6 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
       // registers.
       return passOnTheStack(recordAlign, recTy, isResult);
 
-    const llvm::SmallVector<mlir::Type> &flattenedTypes =
-        flattenTypeList(loc, recTy);
-
     checkAvailableRegisters(loc, previousArguments, gprArgs, fprArgs);
 
     if (flattenedTypes.size() == 1 &&

>From bce79872aece962df0ff0b64cb419f2b5c6ba0c4 Mon Sep 17 00:00:00 2001
From: Philipp Rados <philipp.rados at openchip.com>
Date: Thu, 21 May 2026 12:43:09 +0200
Subject: [PATCH 4/4] Determine ABI from target-features + triple

---
 flang/lib/Optimizer/CodeGen/Target.cpp        | 21 ++++++++++++-------
 .../test/Fir/struct-passing-riscv64-byval.fir | 11 ++++------
 flang/test/Fir/struct-return-riscv64.fir      | 13 +++++-------
 3 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index f0e5cd9903ca2..32a2c3a0dd680 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -1396,11 +1396,6 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
 
   static constexpr int defaultWidth = 64;
   static constexpr int defaultWidthBytes = defaultWidth / 8;
-  // TODO: Can't query ABI from inside TargetRewrite so assume the more common
-  // `lp64d` for now. Alternatively could check float-support from
-  // target-features, but that could be overridden by manually setting
-  // `-mabi=lp64`.
-  static constexpr bool hasHardFloatABI = true;
 
   CodeGenSpecifics::Marshalling
   complexArgumentType(mlir::Location loc, mlir::Type eleTy) const override {
@@ -1434,6 +1429,18 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
     return marshal;
   }
 
+  // NOTE: Currently only supports lp64/lp64d.
+  // TODO: Can't query target-abi from inside TargetRewrite so try to get it
+  // from target-features for now. Detailed logic from clang is explained in:
+  // riscv::getRISCVABI() in clang/lib/Driver/ToolChains/Arch/RISCV.cpp
+  bool hasHardFloatABI() const {
+    if (!targetFeatures.nullOrEmpty())
+      return targetFeatures.contains("+d");
+
+    // No features: as in clang, unknown OS -> lp64, else lp64d.
+    return triple.getOS() != llvm::Triple::UnknownOS;
+  }
+
   CodeGenSpecifics::Marshalling
   passOnTheStack(unsigned short recAlign, mlir::Type ty, bool isResult) const {
     CodeGenSpecifics::Marshalling marshal;
@@ -1720,7 +1727,7 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
   structArgumentType(mlir::Location loc, fir::RecordType recTy,
                      const Marshalling &previousArguments) const override {
     int gprArgs = 8;
-    int fprArgs = hasHardFloatABI ? 8 : 0;
+    int fprArgs = hasHardFloatABI() ? 8 : 0;
 
     return classifyStruct(loc, recTy, gprArgs, fprArgs, /*isResult=*/false,
                           previousArguments);
@@ -1729,7 +1736,7 @@ struct TargetRISCV64 : public GenericTarget<TargetRISCV64> {
   CodeGenSpecifics::Marshalling
   structReturnType(mlir::Location loc, fir::RecordType recTy) const override {
     int gprArgs = 2;
-    int fprArgs = hasHardFloatABI ? 2 : 0;
+    int fprArgs = hasHardFloatABI() ? 2 : 0;
 
     return classifyStruct(loc, recTy, gprArgs, fprArgs, /*isResult=*/true, {});
   }
diff --git a/flang/test/Fir/struct-passing-riscv64-byval.fir b/flang/test/Fir/struct-passing-riscv64-byval.fir
index 405dff49ce293..8ad022ea1a77d 100644
--- a/flang/test/Fir/struct-passing-riscv64-byval.fir
+++ b/flang/test/Fir/struct-passing-riscv64-byval.fir
@@ -1,10 +1,7 @@
-// Test RISCV64 lp64 ABI rewrite of struct passed by value (BIND(C), VALUE derived types).
+// Test RISCV64 lp64/lp64d ABI rewrite of struct passed by value (BIND(C), VALUE derived types).
 
-// TODO: There is currently no way to query the ABI kind (`lp64` vs. `lp64d`) from TargetRewritePass, so only check
-// `lp64d` (current default) for now. Once this changes in the future can enable `lp64` tests too.
-// SKIP: fir-opt --target-rewrite="target=riscv64-unknown-linux-gnu,target-abi=lp64" %s | FileCheck %s --check-prefixes=CHECK-INT
-
-// RUN: fir-opt --target-rewrite="target=riscv64-unknown-linux-gnu" %s | FileCheck %s --check-prefixes=CHECK-FLOAT
+// RUN: fir-opt --target-rewrite="target=riscv64" %s | FileCheck %s --check-prefixes=CHECK-INT
+// RUN: fir-opt --target-rewrite="target=riscv64 target-features=+d" %s | FileCheck %s --check-prefixes=CHECK-FLOAT
 
 module attributes {llvm.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128", llvm.target_triple = "riscv64-unknown-linux-gnu"} {
 
@@ -19,7 +16,7 @@ func.func private @single_i64(!fir.type<single_i64{i:i64}>)
 func.func private @single_i32i32(!fir.type<single_i32i32{i:i32,j:i32}>)
 func.func private @single_i8i8i8(!fir.type<single_i8i8i8{i:i8,j:i8,k:i8}>)
 func.func private @single_i8i8i32(!fir.type<single_i8i8i32{i:i8,j:i8,k:i32}>)
-func.func private @single_int128(!fir.type<single_int128{i:i128}>)
+func.func private @single_i128(!fir.type<single_i128{i:i128}>)
 
 // ================================================
 
diff --git a/flang/test/Fir/struct-return-riscv64.fir b/flang/test/Fir/struct-return-riscv64.fir
index 03401ee7b42f9..bc126093ee017 100644
--- a/flang/test/Fir/struct-return-riscv64.fir
+++ b/flang/test/Fir/struct-return-riscv64.fir
@@ -1,10 +1,7 @@
-// Test RISCV64 ABI rewrite of struct returned by value (BIND(C), VALUE derived types).
+// Test RISCV64 lp64/lp64d ABI rewrite of struct returned by value (BIND(C), VALUE derived types).
 
-// TODO: There is currently no way to query the ABI kind (`lp64` vs. `lp64d`) from TargetRewritePass, so only check
-// `lp64d` (current hardcoded default) for now. Once this changes in the future can enable `lp64` tests too.
-// SKIP: fir-opt --target-rewrite="target=riscv64-unknown-linux-gnu,target-abi=lp64" %s | FileCheck --check-prefixes=CHECK-INT %s
-
-// RUN: fir-opt --target-rewrite="target=riscv64-unknown-linux-gnu" %s | FileCheck --check-prefixes=CHECK-FLOAT %s
+// RUN: fir-opt --target-rewrite="target=riscv64" %s | FileCheck --check-prefixes=CHECK-INT %s
+// RUN: fir-opt --target-rewrite="target=riscv64 target-features=+d" %s | FileCheck --check-prefixes=CHECK-FLOAT %s
 
 !single_struct = !fir.type<t1{i:i32,j:i16}>
 !double_struct = !fir.type<t2{i:i8,j:i32,k:i16}>
@@ -137,7 +134,7 @@ func.func @test_complex_type(%arg0 : !fir.ref<!complex_double>) {
 
 
 // CHECK-FLOAT-LABEL:   func.func @test_complex_type(
-// CHECK-FLOAT-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t4{i:complex<f64>}>>) {
+// CHECK-FLOAT-SAME:      %[[ARG0:.*]]: !fir.ref<!fir.type<t4{i:complex<f64>}>>) attributes {llvm.target_features = #llvm.target_features<["+d"]>} {
 // CHECK-FLOAT:           %[[CALL_0:.*]] = fir.call @complex_type() : () -> tuple<f64, f64>
 // CHECK-FLOAT:           %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
 // CHECK-FLOAT:           %[[ALLOCA_0:.*]] = fir.alloca tuple<f64, f64>
@@ -171,7 +168,7 @@ func.func @test_mixed_float_type(%arg0 : !fir.ref<!mixed_float>) {
 }
 
 // CHECK-FLOAT-LABEL:   func.func @test_mixed_float_type(
-// CHECK-FLOAT-SAME:     %[[ARG0:.*]]: !fir.ref<!fir.type<t5{i:f32,j:i32}>>) {
+// CHECK-FLOAT-SAME:     %[[ARG0:.*]]: !fir.ref<!fir.type<t5{i:f32,j:i32}>>) attributes {llvm.target_features = #llvm.target_features<["+d"]>} {
 // CHECK-FLOAT:          %[[CALL_0:.*]] = fir.call @mixed_float_type() : () -> tuple<f32, i32>
 // CHECK-FLOAT:          %[[INTR_0:.*]] = llvm.intr.stacksave : !llvm.ptr
 // CHECK-FLOAT:          %[[ALLOCA_0:.*]] = fir.alloca tuple<f32, i32>



More information about the flang-commits mailing list