[flang-commits] [flang] [flang] Add struct passing target rewrite hooks and partial X86-64 impl (PR #74829)
via flang-commits
flang-commits at lists.llvm.org
Fri Dec 8 07:45:31 PST 2023
https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/74829
>From ba69844c02c6b0d56a9ba4318f8fa13df9172078 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Fri, 8 Dec 2023 03:16:39 -0800
Subject: [PATCH 1/3] [flang] Add struct passing target rewrite hooks and
partial X86-64 impl
In the context of C/Fortran interoperability (BIND(C)), it is possible
to give the VALUE attribute to a BIND(C) derived type dummy, which
according to Fortran 2018 18.3.6 - 2. (4) implies that it must be passed
like the equivalent C structure value. The way C structure values are
passed is ABI dependent.
LLVM does not implement the C struct ABI passing for LLVM aggregate type
arguments. It is up to the front-end, like clang is doing, to split the
struct into registers or pass the struct on the stack (llvm "byval") as
required by the target ABI.
So the logic for C struct passing sits in clang. Using it from flang
requires setting up a lot of clang context and bridging the FIR/MLIR
representation to the clang AST representation for function signatures (in
both directions). It is a non-trivial task.
See https://stackoverflow.com/questions/39438033/passing-structs-by-value-in-llvm-ir/75002581#75002581.
Since BIND(C) structs are rather limited compared to generic C structs
(e.g. no bit fields), it is easier to provide a limited implementation
of the C struct passing ABI for the cases that matter to Fortran.
This patch:
- Updates the generic target rewrite pass to keep track of both the new
argument type and attributes. The motivation for this is to be able
to tell if a previously marshalled argument is passed in memory (it is
a C pointer), or if it is being passed on the stack (it has the LLVM
byval attribute).
- Adds an entry point in the target specific codegen to marshal struct
arguments, and uses it in the generic target rewrite pass.
- Implements limited support for the X86-64 case. So far, the support
can tell whether a struct must be passed in registers or on the stack,
and handles the stack case. The register case is left TODO in
this patch.
The X86-64 ABI implemented is the System V ABI for AMD64 version 1.0.
---
.../flang/Optimizer/CodeGen/CGPasses.td | 5 +-
.../include/flang/Optimizer/CodeGen/Target.h | 34 +-
.../flang/Optimizer/CodeGen/TypeConverter.h | 6 +-
.../flang/Optimizer/Dialect/FIRTypes.td | 10 +
.../flang/Optimizer/Support/DataLayout.h | 12 +
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 14 +-
flang/lib/Optimizer/CodeGen/Target.cpp | 336 ++++++++++++-
flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 459 +++++++++++-------
flang/lib/Optimizer/CodeGen/TypeConverter.cpp | 5 +-
flang/lib/Optimizer/Dialect/FIRType.cpp | 6 +
flang/lib/Optimizer/Support/DataLayout.cpp | 13 +
flang/test/Fir/recursive-type-tco.fir | 4 +-
.../test/Fir/struct-passing-x86-64-byval.fir | 103 ++++
13 files changed, 817 insertions(+), 190 deletions(-)
create mode 100644 flang/test/Fir/struct-passing-x86-64-byval.fir
diff --git a/flang/include/flang/Optimizer/CodeGen/CGPasses.td b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
index 0014298a27a22..5e47119582776 100644
--- a/flang/include/flang/Optimizer/CodeGen/CGPasses.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGPasses.td
@@ -23,7 +23,7 @@ def FIRToLLVMLowering : Pass<"fir-to-llvm-ir", "mlir::ModuleOp"> {
will also convert ops in the standard and FIRCG dialects.
}];
let constructor = "::fir::createFIRToLLVMPass()";
- let dependentDialects = ["mlir::LLVM::LLVMDialect"];
+ let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::DLTIDialect"];
let options = [
Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
"Override module's target triple.">,
@@ -53,7 +53,8 @@ def TargetRewritePass : Pass<"target-rewrite", "mlir::ModuleOp"> {
representations that may differ based on the target machine.
}];
let constructor = "::fir::createFirTargetRewritePass()";
- let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect" ];
+ let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect",
+ "mlir::DLTIDialect" ];
let options = [
Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
"Override module's target triple.">,
diff --git a/flang/include/flang/Optimizer/CodeGen/Target.h b/flang/include/flang/Optimizer/CodeGen/Target.h
index acffe6c1cec9c..c3ef521ced120 100644
--- a/flang/include/flang/Optimizer/CodeGen/Target.h
+++ b/flang/include/flang/Optimizer/CodeGen/Target.h
@@ -13,6 +13,7 @@
#ifndef FORTRAN_OPTMIZER_CODEGEN_TARGET_H
#define FORTRAN_OPTMIZER_CODEGEN_TARGET_H
+#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/TargetParser/Triple.h"
@@ -20,6 +21,10 @@
#include <tuple>
#include <vector>
+namespace mlir {
+class DataLayout;
+}
+
namespace fir {
namespace details {
@@ -62,14 +67,20 @@ class Attributes {
class CodeGenSpecifics {
public:
using Attributes = details::Attributes;
- using Marshalling = std::vector<std::tuple<mlir::Type, Attributes>>;
+ using TypeAndAttr = std::tuple<mlir::Type, Attributes>;
+ using Marshalling = std::vector<TypeAndAttr>;
+
+ static std::unique_ptr<CodeGenSpecifics> get(mlir::MLIRContext *ctx,
+ llvm::Triple &&trp,
+ KindMapping &&kindMap,
+ const mlir::DataLayout &dl);
- static std::unique_ptr<CodeGenSpecifics>
- get(mlir::MLIRContext *ctx, llvm::Triple &&trp, KindMapping &&kindMap);
+ static TypeAndAttr getTypeAndAttr(mlir::Type t) { return TypeAndAttr{t, {}}; }
CodeGenSpecifics(mlir::MLIRContext *ctx, llvm::Triple &&trp,
- KindMapping &&kindMap)
- : context{*ctx}, triple{std::move(trp)}, kindMap{std::move(kindMap)} {}
+ KindMapping &&kindMap, const mlir::DataLayout &dl)
+ : context{*ctx}, triple{std::move(trp)}, kindMap{std::move(kindMap)},
+ dataLayout{&dl} {}
CodeGenSpecifics() = delete;
virtual ~CodeGenSpecifics() {}
@@ -90,6 +101,13 @@ class CodeGenSpecifics {
/// Type presentation of a `boxchar<n>` type value in memory.
virtual mlir::Type boxcharMemoryType(mlir::Type eleTy) const = 0;
+ /// Type representation of a `fir.type<T>` type argument when passed by
+ /// value. It may have to be split into several arguments, or be passed
+ /// as a byval reference argument (on the stack).
+ virtual Marshalling
+ structArgumentType(mlir::Location loc, fir::RecordType recTy,
+ const Marshalling &previousArguments) const = 0;
+
/// Type representation of a `boxchar<n>` type argument when passed by value.
/// An argument value may need to be passed as a (safe) reference argument.
///
@@ -143,10 +161,16 @@ class CodeGenSpecifics {
// Returns width in bits of C/C++ 'int' type size.
virtual unsigned char getCIntTypeWidth() const = 0;
+  /// Return the data layout used for type size and alignment queries.
+  /// It must have been provided when constructing the CodeGenSpecifics.
+  const mlir::DataLayout &getDataLayout() const {
+    assert(dataLayout != nullptr && "dataLayout must be set");
+    return *dataLayout;
+  }
+
protected:
mlir::MLIRContext &context;
llvm::Triple triple;
KindMapping kindMap;
+ const mlir::DataLayout *dataLayout = nullptr;
};
} // namespace fir
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 9ce756bdfd966..396c136392555 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -39,6 +39,10 @@ static constexpr unsigned kDimLowerBoundPos = 0;
static constexpr unsigned kDimExtentPos = 1;
static constexpr unsigned kDimStridePos = 2;
+namespace mlir {
+class DataLayout;
+}
+
namespace fir {
/// FIR type converter
@@ -46,7 +50,7 @@ namespace fir {
class LLVMTypeConverter : public mlir::LLVMTypeConverter {
public:
LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
- bool forceUnifiedTBAATree);
+ bool forceUnifiedTBAATree, const mlir::DataLayout &);
// i32 is used here because LLVM wants i32 constants when indexing into struct
// types. Indexing into other aggregate types is more flexible.
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 51608e3c1d63e..2a2f50720859e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -326,6 +326,8 @@ def fir_RealType : FIR_Type<"Real", "real"> {
let extraClassDeclaration = [{
using KindTy = unsigned;
+ // Get MLIR float type with same semantics.
+ mlir::Type getFloatType(const fir::KindMapping &kindMap) const;
}];
let genVerifyDecl = 1;
@@ -495,6 +497,14 @@ def fir_SequenceType : FIR_Type<"Sequence", "array"> {
static constexpr Extent getUnknownExtent() {
return mlir::ShapedType::kDynamic;
}
+
+    /// Return the number of elements of an array type with a constant
+    /// shape, i.e. the product of all of its extents. The shape must not
+    /// have any dynamic extent.
+    std::uint64_t getConstantArraySize() {
+      assert(!hasDynamicExtents() && "array type must have constant shape");
+      std::uint64_t numElements = 1;
+      for (auto dim : getShape())
+        numElements *= static_cast<std::uint64_t>(dim);
+      return numElements;
+    }
}];
}
diff --git a/flang/include/flang/Optimizer/Support/DataLayout.h b/flang/include/flang/Optimizer/Support/DataLayout.h
index 88ff575a8ff08..d21576bb95f79 100644
--- a/flang/include/flang/Optimizer/Support/DataLayout.h
+++ b/flang/include/flang/Optimizer/Support/DataLayout.h
@@ -13,6 +13,9 @@
#ifndef FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
#define FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include <optional>
+
namespace mlir {
class ModuleOp;
}
@@ -34,6 +37,15 @@ void setMLIRDataLayout(mlir::ModuleOp mlirModule, const llvm::DataLayout &dl);
/// nothing.
void setMLIRDataLayoutFromAttributes(mlir::ModuleOp mlirModule,
bool allowDefaultLayout);
+
+/// Create mlir::DataLayout from the data layout information on the
+/// mlir::Module. Creates the data layout information attributes with
+/// setMLIRDataLayoutFromAttributes if the DLTI attribute is not yet set. If no
+/// information is present at all and \p allowDefaultLayout is false, returns
+/// std::nullopt.
+std::optional<mlir::DataLayout>
+getOrSetDataLayout(mlir::ModuleOp mlirModule, bool allowDefaultLayout = false);
+
} // namespace fir::support
#endif // FORTRAN_OPTIMIZER_SUPPORT_DATALAYOUT_H
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index bf175c8ebadee..c0f3ea3241a77 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -16,6 +16,7 @@
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Optimizer/Support/Utils.h"
@@ -34,6 +35,7 @@
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -3820,10 +3822,20 @@ class FIRToLLVMLowering
if (mlir::failed(runPipeline(mathConvertionPM, mod)))
return signalPassFailure();
+ std::optional<mlir::DataLayout> dl =
+ fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
+ if (!dl) {
+ mlir::emitError(mod.getLoc(),
+ "module operation must carry a data layout attribute "
+ "to generate llvm IR from FIR");
+ signalPassFailure();
+ return;
+ }
+
auto *context = getModule().getContext();
fir::LLVMTypeConverter typeConverter{getModule(),
options.applyTBAA || applyTBAA,
- options.forceUnifiedTBAATree};
+ options.forceUnifiedTBAATree, *dl};
mlir::RewritePatternSet pattern(context);
pattern.insert<
AbsentOpConversion, AddcOpConversion, AddrOfOpConversion,
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index 112f56e268c3b..ea10486a6b34c 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -18,6 +18,7 @@
#include "flang/Optimizer/Support/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeRange.h"
+#include "llvm/ADT/TypeSwitch.h"
#define DEBUG_TYPE "flang-codegen-target"
@@ -58,6 +59,62 @@ static void typeTodo(const llvm::fltSemantics *sem, mlir::Location loc,
}
}
+/// Return the size and alignment (both in bytes) of FIR types.
+/// The size of derived types includes tail padding, so that it is a multiple
+/// of their alignment (like C `sizeof`).
+/// TODO: consider moving this to a DataLayoutTypeInterface implementation
+/// for FIR types. It should first be ensured that it is OK to open the gate of
+/// target dependent type size inquiries in lowering. It would also not be
+/// straightforward given the need for a kind map that would need to be
+/// converted in terms of mlir::DataLayoutEntryKey.
+static std::pair<std::uint64_t, unsigned short>
+getSizeAndAlignment(mlir::Location loc, mlir::Type ty,
+                    const mlir::DataLayout &dl,
+                    const fir::KindMapping &kindMap) {
+  if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
+    llvm::TypeSize size = dl.getTypeSize(ty);
+    unsigned short alignment = dl.getTypeABIAlignment(ty);
+    return {size, alignment};
+  }
+  if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
+    auto [floatSize, floatAlign] =
+        getSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
+    return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
+  }
+  if (auto real = mlir::dyn_cast<fir::RealType>(ty))
+    return getSizeAndAlignment(loc, real.getFloatType(kindMap), dl, kindMap);
+
+  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
+    auto [eleSize, eleAlign] =
+        getSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
+
+    std::uint64_t size =
+        llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
+    return {size, eleAlign};
+  }
+  if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
+    std::uint64_t size = 0;
+    // A record is as aligned as its most aligned component. Seeding this
+    // with anything bigger than 1 would over-align records made of small
+    // components, corrupting the offsets of records nested in other records
+    // and the stride of arrays of records.
+    unsigned short align = 1;
+    for (const auto &component : recTy.getTypeList()) {
+      auto [compSize, compAlign] =
+          getSizeAndAlignment(loc, component.second, dl, kindMap);
+      size =
+          llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
+      align = std::max(align, compAlign);
+    }
+    // Tail padding: the size of a record is a multiple of its alignment.
+    return {llvm::alignTo(size, align), align};
+  }
+  if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
+    mlir::Type intTy = mlir::IntegerType::get(
+        logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
+    return getSizeAndAlignment(loc, intTy, dl, kindMap);
+  }
+  if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
+    // One character has the storage of an integer of the character kind
+    // bit size; a CHARACTER(LEN=n) entity holds n of them.
+    mlir::Type intTy = mlir::IntegerType::get(
+        character.getContext(),
+        kindMap.getCharacterBitsize(character.getFKind()));
+    auto [charSize, charAlign] = getSizeAndAlignment(loc, intTy, dl, kindMap);
+    if (character.hasConstantLen())
+      charSize *= static_cast<std::uint64_t>(character.getLen());
+    return {charSize, charAlign};
+  }
+  TODO(loc, "computing size of a component");
+}
+
namespace {
template <typename S>
struct GenericTarget : public CodeGenSpecifics {
@@ -95,6 +152,12 @@ struct GenericTarget : public CodeGenSpecifics {
return marshal;
}
+  /// Default hook for passing a BIND(C) derived type by VALUE: not supported
+  /// unless a target overrides it.
+  Marshalling structArgumentType(mlir::Location loc, fir::RecordType,
+                                 const Marshalling &) const override {
+    TODO(loc, "passing VALUE BIND(C) derived type for this target");
+  }
+
CodeGenSpecifics::Marshalling
integerArgumentType(mlir::Location loc,
mlir::IntegerType argTy) const override {
@@ -318,6 +381,251 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
}
return marshal;
}
+
+  /// X86-64 argument classes from System V ABI version 1.0 section 3.2.3.
+  /// Each eightbyte of an argument is given one of these classes.
+  /// This is a plain (unscoped) enum: the enumerators are also used
+  /// unqualified in the classification helpers.
+  enum ArgClass {
+    Integer = 0, // Counted against the general purpose registers.
+    SSE,         // Counted against the SSE registers.
+    SSEUp,       // Upper part of an SSE value; not counted as an extra register.
+    X87,         // Aggregates with this class are passed on the stack.
+    X87Up,       // Must follow X87, otherwise post-merge turns it into Memory.
+    ComplexX87,  // Aggregates with this class are passed on the stack.
+    NoClass,     // Initial class of an eightbyte before classification.
+    Memory       // The whole argument is passed on the stack.
+  };
+
+  /// Classify an argument type or a field of an aggregate type argument.
+  /// See System V ABI version 1.0 section 3.2.3.
+  /// Lo and Hi are first reset to NoClass; on return they hold the classes
+  /// of the lower eightbyte (bytes [0, 8)) and the upper eightbyte
+  /// (bytes [8, 16)). \p byteOffset is the offset of \p type inside the
+  /// enclosing aggregate: it selects the eightbyte in which a value fitting
+  /// in a single eightbyte lands.
+  /// If this is called for an aggregate field, the caller is responsible to
+  /// do the post-merge.
+  void classify(mlir::Location loc, mlir::Type type, std::uint64_t byteOffset,
+                ArgClass &Lo, ArgClass &Hi) const {
+    Hi = Lo = ArgClass::NoClass;
+    ArgClass &current = byteOffset < 8 ? Lo : Hi;
+    // System V AMD64 ABI 3.2.3. version 1.0
+    llvm::TypeSwitch<mlir::Type>(type)
+        .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          if (intTy.getWidth() == 128)
+            Hi = Lo = ArgClass::Integer;
+          else
+            current = ArgClass::Integer;
+        })
+        .template Case<mlir::FloatType, fir::RealType>([&](mlir::Type floatTy) {
+          const auto *sem = &floatToSemantics(kindMap, floatTy);
+          if (sem == &llvm::APFloat::x87DoubleExtended()) {
+            Lo = ArgClass::X87;
+            Hi = ArgClass::X87Up;
+          } else if (sem == &llvm::APFloat::IEEEquad()) {
+            Lo = ArgClass::SSE;
+            Hi = ArgClass::SSEUp;
+          } else {
+            current = ArgClass::SSE;
+          }
+        })
+        .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
+          const auto *sem = &floatToSemantics(kindMap, cmplx.getElementType());
+          if (sem == &llvm::APFloat::x87DoubleExtended()) {
+            current = ArgClass::ComplexX87;
+          } else {
+            // Other complex types are classified like an array of two reals.
+            fir::SequenceType::Shape shape{2};
+            classifyArray(loc,
+                          fir::SequenceType::get(shape, cmplx.getElementType()),
+                          byteOffset, Lo, Hi);
+          }
+        })
+        .template Case<fir::LogicalType>([&](fir::LogicalType logical) {
+          if (kindMap.getLogicalBitsize(logical.getFKind()) == 128)
+            Hi = Lo = ArgClass::Integer;
+          else
+            current = ArgClass::Integer;
+        })
+        .template Case<fir::CharacterType>(
+            [&](fir::CharacterType character) { current = ArgClass::Integer; })
+        .template Case<fir::SequenceType>([&](fir::SequenceType seqTy) {
+          // Array component.
+          classifyArray(loc, seqTy, byteOffset, Lo, Hi);
+        })
+        .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+          // Component that is a derived type.
+          classifyStruct(loc, recTy, byteOffset, Lo, Hi);
+        })
+        .template Case<fir::VectorType>([&](fir::VectorType vecTy) {
+          // Previously marshalled SSE eight byte for a previous struct
+          // argument.
+          auto *sem = fir::isa_real(vecTy.getEleTy())
+                          ? &floatToSemantics(kindMap, vecTy.getEleTy())
+                          : nullptr;
+          // Not expecting to hit this todo in standard code (it would
+          // require some vector type extension).
+          if (!(sem == &llvm::APFloat::IEEEsingle() && vecTy.getLen() <= 2) &&
+              !(sem == &llvm::APFloat::IEEEhalf() && vecTy.getLen() <= 4))
+            TODO(loc, "passing vector argument to C by value");
+          current = ArgClass::SSE;
+        })
+        .Default([&](mlir::Type ty) {
+          if (fir::conformsWithPassByRef(ty))
+            current = ArgClass::Integer; // Pointers.
+          else
+            TODO(loc, "unsupported component type for BIND(C), VALUE derived "
+                      "type argument");
+        });
+  }
+
+  // Classify fields of a derived type starting at \p byteOffset. Returns the
+  // new offset. Post-merge is left to the caller.
+  std::uint64_t classifyStruct(mlir::Location loc, fir::RecordType recTy,
+                               std::uint64_t byteOffset, ArgClass &Lo,
+                               ArgClass &Hi) const {
+    for (const auto &component : recTy.getTypeList()) {
+      mlir::Type compType = component.second;
+      auto [compSize, compAlign] =
+          getSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
+      byteOffset = llvm::alignTo(byteOffset, compAlign);
+      // Only the two first eightbytes (Lo and Hi) are tracked. A component
+      // starting at or after byte 16 (once aligned) means the aggregate is
+      // larger than two eightbytes, which is passed in memory since it cannot
+      // be a single SSE/SSEUP... vector entity (3.2.3 post merger cleanup).
+      // Checking before aligning, or with "> 16", would let e.g. a second
+      // REAL(16) component at offset 16 be merged into Hi as if it fit.
+      if (byteOffset >= 16) {
+        Lo = Hi = ArgClass::Memory;
+        return byteOffset;
+      }
+      ArgClass LoComp, HiComp;
+      classify(loc, compType, byteOffset, LoComp, HiComp);
+      Lo = mergeClass(Lo, LoComp);
+      Hi = mergeClass(Hi, HiComp);
+      byteOffset = byteOffset + llvm::alignTo(compSize, compAlign);
+      if (Lo == ArgClass::Memory || Hi == ArgClass::Memory)
+        return byteOffset;
+    }
+    return byteOffset;
+  }
+
+  // Classify the elements of a constant size array type starting at
+  // \p byteOffset, merging them into \p Lo and \p Hi. Post-merge is left to
+  // the caller.
+  void classifyArray(mlir::Location loc, fir::SequenceType seqTy,
+                     std::uint64_t byteOffset, ArgClass &Lo,
+                     ArgClass &Hi) const {
+    mlir::Type eleTy = seqTy.getEleTy();
+    const std::uint64_t arraySize = seqTy.getConstantArraySize();
+    auto [eleSize, eleAlign] =
+        getSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
+    std::uint64_t eleStorageSize = llvm::alignTo(eleSize, eleAlign);
+    for (std::uint64_t i = 0; i < arraySize; ++i) {
+      byteOffset = llvm::alignTo(byteOffset, eleAlign);
+      // An element starting at or after byte 16 means the aggregate is larger
+      // than the two tracked eightbytes: it must be passed in memory. With
+      // "> 16", the second element of a COMPLEX(16) or REAL(16) array would
+      // be merged into Hi as if it fit in registers.
+      if (byteOffset >= 16) {
+        Lo = Hi = ArgClass::Memory;
+        return;
+      }
+      ArgClass LoComp, HiComp;
+      classify(loc, eleTy, byteOffset, LoComp, HiComp);
+      Lo = mergeClass(Lo, LoComp);
+      Hi = mergeClass(Hi, HiComp);
+      byteOffset = byteOffset + eleStorageSize;
+      if (Lo == ArgClass::Memory || Hi == ArgClass::Memory)
+        return;
+    }
+  }
+
+  // Goes through the previously marshalled arguments and counts the
+  // register occupancy to check if there are enough registers left for an
+  // argument needing \p neededIntRegisters general purpose registers and
+  // \p neededSSERegisters SSE registers.
+  bool hasEnoughRegisters(mlir::Location loc, int neededIntRegisters,
+                          int neededSSERegisters,
+                          const Marshalling &previousArguments) const {
+    int availIntRegisters = 6;
+    int availSSERegisters = 8;
+    for (const auto &typeAndAttr : previousArguments) {
+      const auto &attr = std::get<Attributes>(typeAndAttr);
+      // Arguments passed on the stack do not use registers. sret arguments
+      // are NOT skipped: the hidden result address is passed in %rdi as if
+      // it were the first argument, so it consumes an INTEGER register (its
+      // type is expected to be a reference, classified as Integer below).
+      if (attr.isByVal())
+        continue;
+      ArgClass Lo, Hi;
+      Lo = Hi = ArgClass::NoClass;
+      classify(loc, std::get<mlir::Type>(typeAndAttr), 0, Lo, Hi);
+      // post merge is not needed here since previous aggregate arguments
+      // were marshalled into simpler arguments.
+      if (Lo == ArgClass::Integer)
+        --availIntRegisters;
+      else if (Lo == ArgClass::SSE)
+        --availSSERegisters;
+      if (Hi == ArgClass::Integer)
+        --availIntRegisters;
+      else if (Hi == ArgClass::SSE)
+        --availSSERegisters;
+    }
+    return availSSERegisters >= neededSSERegisters &&
+           availIntRegisters >= neededIntRegisters;
+  }
+
+  /// Argument class merging as described in System V ABI 3.2.3 point 4.
+  /// \p accum is the class accumulated so far for an eightbyte and \p field
+  /// the class of the next field (part) lying in that eightbyte.
+  ArgClass mergeClass(ArgClass accum, ArgClass field) const {
+    assert((accum != ArgClass::Memory && accum != ArgClass::ComplexX87) &&
+           "Invalid accumulated classification during merge.");
+    if (accum == field || field == ArgClass::NoClass)
+      return accum;
+    if (field == ArgClass::Memory)
+      return ArgClass::Memory;
+    if (accum == ArgClass::NoClass)
+      return field;
+    if (accum == ArgClass::Integer || field == ArgClass::Integer)
+      return ArgClass::Integer;
+    if (field == ArgClass::X87 || field == ArgClass::X87Up ||
+        field == ArgClass::ComplexX87 || accum == ArgClass::X87 ||
+        accum == ArgClass::X87Up)
+      return ArgClass::Memory;
+    return ArgClass::SSE;
+  }
+
+  /// Apply the post merger cleanup of System V ABI 3.2.3 point 5 to the
+  /// eightbyte classes \p Lo and \p Hi of an aggregate of \p byteSize bytes.
+  void postMerge(std::uint64_t byteSize, ArgClass &Lo, ArgClass &Hi) const {
+    const ArgClass lo = Lo;
+    const ArgClass hi = Hi;
+    const bool hiIsMemory = hi == ArgClass::Memory;
+    const bool orphanX87Up = hi == ArgClass::X87Up && lo != ArgClass::X87;
+    const bool tooBigForRegisters =
+        byteSize > 16 && !(lo == ArgClass::SSE && hi == ArgClass::SSEUp);
+    if (hiIsMemory || orphanX87Up || tooBigForRegisters)
+      Lo = ArgClass::Memory;
+    // An SSEUP eightbyte not preceded by SSE is converted to SSE.
+    if (hi == ArgClass::SSEUp && lo != ArgClass::SSE)
+      Hi = ArgClass::SSE;
+  }
+
+  /// Marshal a derived type passed by value like a C struct.
+  /// The struct is classified; it is passed on the stack when its class
+  /// requires memory or when the registers left after \p previousArguments
+  /// cannot hold all of it. Splitting it into registers is not implemented.
+  CodeGenSpecifics::Marshalling
+  structArgumentType(mlir::Location loc, fir::RecordType recTy,
+                     const Marshalling &previousArguments) const override {
+    ArgClass Lo = ArgClass::NoClass;
+    ArgClass Hi = ArgClass::NoClass;
+    const std::uint64_t byteSize =
+        classifyStruct(loc, recTy, /*byteOffset=*/0, Lo, Hi);
+    postMerge(byteSize, Lo, Hi);
+    switch (Lo) {
+    case ArgClass::Memory:
+    case ArgClass::X87:
+    case ArgClass::ComplexX87:
+      return passOnTheStack(loc, recTy);
+    default:
+      break;
+    }
+    // One register per INTEGER or SSE eightbyte (SSEUP is not counted).
+    const int neededIntRegisters = static_cast<int>(Lo == ArgClass::Integer) +
+                                   static_cast<int>(Hi == ArgClass::Integer);
+    const int neededSSERegisters = static_cast<int>(Lo == ArgClass::SSE) +
+                                   static_cast<int>(Hi == ArgClass::SSE);
+    // C struct should not be split into LLVM registers if LLVM codegen is not
+    // able to later assign actual registers to all of them (struct passing is
+    // all in registers or all on the stack).
+    if (!hasEnoughRegisters(loc, neededIntRegisters, neededSSERegisters,
+                            previousArguments))
+      return passOnTheStack(loc, recTy);
+    // TODO, marshal the struct into registers.
+    TODO(loc, "passing BIND(C), VALUE derived type in registers on X86-64");
+  }
+
+  /// Marshal an argument that must be passed on the stack: it becomes a
+  /// reference to \p ty carrying the byval attribute.
+  CodeGenSpecifics::Marshalling passOnTheStack(mlir::Location loc,
+                                               mlir::Type ty) const {
+    CodeGenSpecifics::Marshalling marshal;
+    auto sizeAndAlign = getSizeAndAlignment(loc, ty, getDataLayout(), kindMap);
+    // Stack arguments are rounded up to eightbytes (System V ABI 3.2.3), so
+    // the byval alignment is at least 8 whatever the type alignment is (this
+    // matches what clang emits for x86-64 byval arguments).
+    const unsigned short align =
+        std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
+    marshal.emplace_back(fir::ReferenceType::get(ty),
+                         AT{/*align=*/align, /*byval=*/true, /*sret=*/false});
+    return marshal;
+  }
};
} // namespace
@@ -726,51 +1034,51 @@ struct TargetLoongArch64 : public GenericTarget<TargetLoongArch64> {
// TODO: Add other targets to this file as needed.
std::unique_ptr<fir::CodeGenSpecifics>
fir::CodeGenSpecifics::get(mlir::MLIRContext *ctx, llvm::Triple &&trp,
- KindMapping &&kindMap) {
+ KindMapping &&kindMap, const mlir::DataLayout &dl) {
switch (trp.getArch()) {
default:
break;
case llvm::Triple::ArchType::x86:
if (trp.isOSWindows())
return std::make_unique<TargetI386Win>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
else
return std::make_unique<TargetI386>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::x86_64:
if (trp.isOSWindows())
return std::make_unique<TargetX86_64Win>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
else
return std::make_unique<TargetX86_64>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::aarch64:
return std::make_unique<TargetAArch64>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::ppc64:
return std::make_unique<TargetPPC64>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::ppc64le:
return std::make_unique<TargetPPC64le>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::sparc:
return std::make_unique<TargetSparc>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::sparcv9:
return std::make_unique<TargetSparcV9>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::riscv64:
return std::make_unique<TargetRISCV64>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::amdgcn:
return std::make_unique<TargetAMDGPU>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::nvptx64:
return std::make_unique<TargetNVPTX>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
case llvm::Triple::ArchType::loongarch64:
return std::make_unique<TargetLoongArch64>(ctx, std::move(trp),
- std::move(kindMap));
+ std::move(kindMap), dl);
}
TODO(mlir::UnknownLoc::get(ctx), "target not implemented");
}
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 241d1fa84e23d..277f3e447ed16 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -25,6 +25,8 @@
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -87,9 +89,23 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (!forcedTargetTriple.empty())
fir::setTargetTriple(mod, forcedTargetTriple);
- auto specifics = fir::CodeGenSpecifics::get(
- mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod));
- setMembers(specifics.get(), &rewriter);
+ // TargetRewrite will require querying the type storage sizes, if it was
+ // not set already, create a DataLayoutSpec for the ModuleOp now.
+ std::optional<mlir::DataLayout> dl =
+ fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/true);
+ if (!dl) {
+ mlir::emitError(mod.getLoc(),
+ "module operation must carry a data layout attribute "
+ "to perform target ABI rewrites on FIR");
+ signalPassFailure();
+ return;
+ }
+
+ auto specifics =
+ fir::CodeGenSpecifics::get(mod.getContext(), fir::getTargetTriple(mod),
+ fir::getKindMapping(mod), *dl);
+
+ setMembers(specifics.get(), &rewriter, &*dl);
// We may need to call stacksave/stackrestore later, so
// create the FuncOps beforehand.
@@ -127,9 +143,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
template <typename A, typename B, typename C>
std::optional<std::function<mlir::Value(mlir::Operation *)>>
- rewriteCallComplexResultType(mlir::Location loc, A ty, B &newResTys,
- B &newInTys, C &newOpers,
- mlir::Value &savedStackPtr) {
+ rewriteCallComplexResultType(
+ mlir::Location loc, A ty, B &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, C &newOpers,
+ mlir::Value &savedStackPtr) {
if (noComplexConversion) {
newResTys.push_back(ty);
return std::nullopt;
@@ -149,7 +166,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
savedStackPtr = genStackSave(loc);
mlir::Value stack =
rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy));
- newInTys.push_back(resTy);
+ newInTyAndAttrs.push_back(m[0]);
newOpers.push_back(stack);
return [=](mlir::Operation *) -> mlir::Value {
auto memTy = fir::ReferenceType::get(ty);
@@ -170,39 +187,49 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
};
}
- template <typename A, typename B, typename C>
- void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
- C &newOpers, mlir::Value &savedStackPtr) {
+  /// Pass \p oper, a value of type \p oldType, as described by
+  /// \p newTypeAndAttr, appending the new operand to \p newOpers.
+  /// With the byval attribute, the value is stored into a temporary whose
+  /// address is passed (converted to the new type if needed). Otherwise the
+  /// value is reinterpreted as the new type through a temporary of that type.
+  /// Temporaries are allocated, so the stack pointer is saved in
+  /// \p savedStackPtr first if it was not already.
+  void passArgumentOnStackOrWithNewType(
+      mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
+      mlir::Type oldType, mlir::Value oper,
+      llvm::SmallVectorImpl<mlir::Value> &newOpers,
+      mlir::Value &savedStackPtr) {
+    auto [newType, newAttr] = newTypeAndAttr;
+    if (!savedStackPtr)
+      savedStackPtr = genStackSave(loc);
+    if (!newAttr.isByVal()) {
+      auto tmp = rewriter->create<fir::AllocaOp>(loc, newType);
+      auto tmpAsOldRef = rewriter->create<fir::ConvertOp>(
+          loc, fir::ReferenceType::get(oldType), tmp);
+      rewriter->create<fir::StoreOp>(loc, oper, tmpAsOldRef);
+      newOpers.push_back(rewriter->create<fir::LoadOp>(loc, tmp));
+      return;
+    }
+    mlir::Value spill = rewriter->create<fir::AllocaOp>(loc, oldType);
+    rewriter->create<fir::StoreOp>(loc, oper, spill);
+    if (spill.getType() != newType)
+      spill = rewriter->create<fir::ConvertOp>(loc, newType, spill);
+    newOpers.push_back(spill);
+  }
+
+ template <typename CPLX>
+ void rewriteCallComplexInputType(
+ mlir::Location loc, CPLX ty, mlir::Value oper,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ llvm::SmallVectorImpl<mlir::Value> &newOpers,
+ mlir::Value &savedStackPtr) {
if (noComplexConversion) {
- newInTys.push_back(ty);
+ newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(ty));
newOpers.push_back(oper);
return;
}
- auto *ctx = ty.getContext();
- mlir::Location loc = mlir::UnknownLoc::get(ctx);
- if (auto *op = oper.getDefiningOp())
- loc = op->getLoc();
auto m = specifics->complexArgumentType(loc, ty.getElementType());
if (m.size() == 1) {
// COMPLEX is a single aggregate
- auto resTy = std::get<mlir::Type>(m[0]);
- auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
- auto oldRefTy = fir::ReferenceType::get(ty);
- // We are going to generate an alloca, so save the stack pointer.
- if (!savedStackPtr)
- savedStackPtr = genStackSave(loc);
- if (attr.isByVal()) {
- auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
- rewriter->create<fir::StoreOp>(loc, oper, mem);
- newOpers.push_back(rewriter->create<fir::ConvertOp>(loc, resTy, mem));
- } else {
- auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
- auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
- rewriter->create<fir::StoreOp>(loc, oper, cast);
- newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
- }
- newInTys.push_back(resTy);
+ passArgumentOnStackOrWithNewType(loc, m[0], ty, oper, newOpers,
+ savedStackPtr);
+ newInTyAndAttrs.push_back(m[0]);
} else {
assert(m.size() == 2);
// COMPLEX is split into 2 separate arguments
@@ -214,12 +241,34 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto idx = rewriter->getIntegerAttr(iTy, index);
auto val = rewriter->create<fir::ExtractValueOp>(
loc, ty, oper, rewriter->getArrayAttr(idx));
- newInTys.push_back(ty);
+ newInTyAndAttrs.push_back(tup);
newOpers.push_back(val);
}
}
}
+ /// Marshal a BIND(C), VALUE derived type operand of a call according to the
+ /// target ABI. The target specifics are given the arguments already
+ /// marshalled in \p newInTyAndAttrs and return how the struct must be passed.
+ /// Only the case where the struct maps to a single new argument is supported
+ /// (passed with a new type, or on the stack with the byval attribute).
+ void rewriteCallStructInputType(
+     mlir::Location loc, fir::RecordType recTy, mlir::Value oper,
+     fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+     llvm::SmallVectorImpl<mlir::Value> &newOpers,
+     mlir::Value &savedStackPtr) {
+   auto structArgs =
+       specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
+   if (structArgs.size() != 1)
+     TODO(loc, "splitting BIND(C), VALUE derived type into several arguments");
+   passArgumentOnStackOrWithNewType(loc, structArgs[0], recTy, oper, newOpers,
+                                    savedStackPtr);
+   // Record the type/attributes of the operand just pushed to newOpers. This
+   // keeps the new call signature in sync with the operand list, and lets
+   // later structArgumentType queries account for this argument.
+   newInTyAndAttrs.push_back(structArgs[0]);
+ }
+
+ /// Returns true when at least one marshalled argument carries the byval or
+ /// sret attribute, i.e. an argument or the result is passed on the stack.
+ static bool hasByValOrSRetArgs(
+     const fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
+   for (const auto &typeAndAttr : newInTyAndAttrs) {
+     const auto &attrs =
+         std::get<fir::CodeGenSpecifics::Attributes>(typeAndAttr);
+     if (attrs.isByVal() || attrs.isSRet())
+       return true;
+   }
+   return false;
+ }
+
// Convert fir.call and fir.dispatch Ops.
template <typename A>
void convertCallOp(A callOp) {
@@ -227,7 +276,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto loc = callOp.getLoc();
rewriter->setInsertionPoint(callOp);
llvm::SmallVector<mlir::Type> newResTys;
- llvm::SmallVector<mlir::Type> newInTys;
+ fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
llvm::SmallVector<mlir::Value> newOpers;
mlir::Value savedStackPtr = nullptr;
@@ -236,7 +285,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
int dropFront = 0;
if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
if (!callOp.getCallee()) {
- newInTys.push_back(fnTy.getInput(0));
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(fnTy.getInput(0)));
newOpers.push_back(callOp.getOperand(0));
dropFront = 1;
}
@@ -250,12 +300,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
mlir::Type ty = fnTy.getResult(0);
llvm::TypeSwitch<mlir::Type>(ty)
.template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
- wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys,
- newOpers, savedStackPtr);
+ wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
+ newInTyAndAttrs, newOpers,
+ savedStackPtr);
})
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
- wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys,
- newOpers, savedStackPtr);
+ wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
+ newInTyAndAttrs, newOpers,
+ savedStackPtr);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
} else if (fnTy.getResults().size() > 1) {
@@ -276,7 +328,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
bool sret;
if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
if (noCharacterConversion) {
- newInTys.push_back(boxTy);
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
newOpers.push_back(oper);
return;
}
@@ -304,18 +357,22 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
trailingInTys.push_back(argTy);
trailingOpers.push_back(unbox.getResult(idx));
} else {
- newInTys.push_back(argTy);
+ newInTyAndAttrs.push_back(e.value());
newOpers.push_back(unbox.getResult(idx));
}
}
})
.template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
- rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers,
- savedStackPtr);
+ rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
+ newOpers, savedStackPtr);
})
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
- rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers,
- savedStackPtr);
+ rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
+ newOpers, savedStackPtr);
+ })
+ .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ rewriteCallStructInputType(loc, recTy, oper, newInTyAndAttrs,
+ newOpers, savedStackPtr);
})
.template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
@@ -344,12 +401,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto [funcPointer, len] =
fir::factory::extractCharacterProcedureTuple(builder, loc,
oper);
- newInTys.push_back(funcPointerType);
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(funcPointerType));
newOpers.push_back(funcPointer);
trailingInTys.push_back(lenType);
trailingOpers.push_back(len);
} else {
- newInTys.push_back(tuple);
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(tuple));
newOpers.push_back(oper);
}
})
@@ -358,11 +417,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (callOp.getPassArgPos() && *callOp.getPassArgPos() == index)
passArgShift = newOpers.size() - *callOp.getPassArgPos();
}
- newInTys.push_back(ty);
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(ty));
newOpers.push_back(oper);
});
}
- newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
+
+ llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
+ newInTypes.insert(newInTypes.end(), trailingInTys.begin(),
+ trailingInTys.end());
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
llvm::SmallVector<mlir::Value, 1> newCallResults;
@@ -372,10 +435,15 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newCall =
rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
} else {
- // Force new type on the input operand.
+ // TODO: the llvm dialect must be updated to propagate argument
+ // attributes for indirect calls. See:
+ // https://discourse.llvm.org/t/should-llvm-callop-be-able-to-carry-argument-attributes-for-indirect-calls/75431
+ if (hasByValOrSRetArgs(newInTyAndAttrs))
+ TODO(loc,
+ "passing argument or result on the stack in indirect calls");
newOpers[0].setType(mlir::FunctionType::get(
callOp.getContext(),
- mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
+ mlir::TypeRange{newInTypes}.drop_front(dropFront), newResTys));
newCall = rewriter->create<A>(loc, newResTys, newOpers);
}
LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
@@ -419,47 +487,69 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// Result type fixup for fir::ComplexType and mlir::ComplexType
template <typename A, typename B>
- void lowerComplexSignatureRes(mlir::Location loc, A cmplx, B &newResTys,
- B &newInTys) {
+ void lowerComplexSignatureRes(
+ mlir::Location loc, A cmplx, B &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
if (noComplexConversion) {
newResTys.push_back(cmplx);
- } else {
- for (auto &tup :
- specifics->complexReturnType(loc, cmplx.getElementType())) {
- auto argTy = std::get<mlir::Type>(tup);
- if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
- newInTys.push_back(argTy);
- else
- newResTys.push_back(argTy);
- }
+ return;
+ }
+ // Parts of the complex result returned through memory (sret) become extra
+ // input arguments (type and attributes kept); other parts remain results.
+ for (auto &tup :
+ specifics->complexReturnType(loc, cmplx.getElementType())) {
+ auto argTy = std::get<mlir::Type>(tup);
+ if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
+ newInTyAndAttrs.push_back(tup);
+ else
+ newResTys.push_back(argTy);
}
}
// Argument type fixup for fir::ComplexType and mlir::ComplexType
- template <typename A, typename B>
- void lowerComplexSignatureArg(mlir::Location loc, A cmplx, B &newInTys) {
- if (noComplexConversion)
- newInTys.push_back(cmplx);
- else
- for (auto &tup :
- specifics->complexArgumentType(loc, cmplx.getElementType()))
- newInTys.push_back(std::get<mlir::Type>(tup));
+ template <typename A>
+ void lowerComplexSignatureArg(
+     mlir::Location loc, A cmplx,
+     fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
+   // Complex conversion disabled: keep the complex type unchanged.
+   if (noComplexConversion) {
+     newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx));
+     return;
+   }
+   // Otherwise append every (type, attributes) pair the target uses to pass
+   // a complex of this element type.
+   for (const auto &typeAndAttr :
+        specifics->complexArgumentType(loc, cmplx.getElementType()))
+     newInTyAndAttrs.push_back(typeAndAttr);
+ }
+
+ /// Append to \p newInTyAndAttrs the (type, attributes) pairs the target uses
+ /// to pass the derived type \p recTy. The arguments already marshalled are
+ /// handed to the target so that it can take them into account.
+ void
+ lowerStructSignatureArg(mlir::Location loc, fir::RecordType recTy,
+                         fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
+   fir::CodeGenSpecifics::Marshalling structArgs =
+       specifics->structArgumentType(loc, recTy, newInTyAndAttrs);
+   for (const auto &typeAndAttr : structArgs)
+     newInTyAndAttrs.push_back(typeAndAttr);
+ }
+
+ /// Project a marshalling list onto its types, dropping the attributes.
+ llvm::SmallVector<mlir::Type>
+ toTypeList(const fir::CodeGenSpecifics::Marshalling &marshalled) {
+   llvm::SmallVector<mlir::Type> types;
+   types.reserve(marshalled.size());
+   for (const auto &entry : marshalled)
+     types.push_back(std::get<mlir::Type>(entry));
+   return types;
+ }
/// Taking the address of a function. Modify the signature as needed.
void convertAddrOp(fir::AddrOfOp addrOp) {
rewriter->setInsertionPoint(addrOp);
auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
+ fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
llvm::SmallVector<mlir::Type> newResTys;
- llvm::SmallVector<mlir::Type> newInTys;
auto loc = addrOp.getLoc();
for (mlir::Type ty : addrTy.getResults()) {
llvm::TypeSwitch<mlir::Type>(ty)
.Case<fir::ComplexType>([&](fir::ComplexType ty) {
- lowerComplexSignatureRes(loc, ty, newResTys, newInTys);
+ lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
- lowerComplexSignatureRes(loc, ty, newResTys, newInTys);
+ lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
}
@@ -468,37 +558,49 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
llvm::TypeSwitch<mlir::Type>(ty)
.Case<fir::BoxCharType>([&](auto box) {
if (noCharacterConversion) {
- newInTys.push_back(box);
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(box));
} else {
for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
auto argTy = std::get<mlir::Type>(tup);
- llvm::SmallVector<mlir::Type> &vec =
- attr.isAppend() ? trailingInTys : newInTys;
- vec.push_back(argTy);
+ if (attr.isAppend())
+ trailingInTys.push_back(argTy);
+ else
+ newInTyAndAttrs.push_back(tup);
}
}
})
.Case<fir::ComplexType>([&](fir::ComplexType ty) {
- lowerComplexSignatureArg(loc, ty, newInTys);
+ lowerComplexSignatureArg(loc, ty, newInTyAndAttrs);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
- lowerComplexSignatureArg(loc, ty, newInTys);
+ lowerComplexSignatureArg(loc, ty, newInTyAndAttrs);
})
.Case<mlir::TupleType>([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
- newInTys.push_back(tuple.getType(0));
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0)));
trailingInTys.push_back(tuple.getType(1));
} else {
- newInTys.push_back(ty);
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(ty));
}
})
- .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
+ .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ lowerStructSignatureArg(loc, recTy, newInTyAndAttrs);
+ })
+ .Default([&](mlir::Type ty) {
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(ty));
+ });
}
+ llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
// append trailing input types
- newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
+ newInTypes.insert(newInTypes.end(), trailingInTys.begin(),
+ trailingInTys.end());
// replace this op with a new one with the updated signature
- auto newTy = rewriter->getFunctionType(newInTys, newResTys);
+ auto newTy = rewriter->getFunctionType(newInTypes, newResTys);
auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy,
addrOp.getSymbol());
replaceOp(addrOp, newOp.getResult());
@@ -542,7 +644,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (((ty.isa<fir::BoxCharType>() || fir::isCharacterProcedureTuple(ty)) &&
!noCharacterConversion) ||
(fir::isa_complex(ty) && !noComplexConversion) ||
- (ty.isa<mlir::IntegerType>() && hasCCallingConv)) {
+ (ty.isa<mlir::IntegerType>() && hasCCallingConv) ||
+ ty.isa<fir::RecordType>()) {
LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
return false;
}
@@ -566,7 +669,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
return;
llvm::SmallVector<mlir::Type> newResTys;
- llvm::SmallVector<mlir::Type> newInTys;
+ fir::CodeGenSpecifics::Marshalling newInTyAndAttrs;
llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> savedAttrs;
llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> extraAttrs;
llvm::SmallVector<FixupTy> fixups;
@@ -590,13 +693,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (noComplexConversion)
newResTys.push_back(cmplx);
else
- doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
+ doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
if (noComplexConversion)
newResTys.push_back(cmplx);
else
- doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
+ doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
})
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
@@ -616,7 +719,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// Saved potential shift in argument. Handling of result can add arguments
// at the beginning of the function signature.
- unsigned argumentShift = newInTys.size();
+ unsigned argumentShift = newInTyAndAttrs.size();
// Convert arguments
llvm::SmallVector<mlir::Type> trailingTys;
@@ -626,7 +729,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
llvm::TypeSwitch<mlir::Type>(ty)
.Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
if (noCharacterConversion) {
- newInTys.push_back(boxTy);
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
} else {
// Convert a CHARACTER argument type. This can involve separating
// the pointer and the LEN into two arguments and moving the LEN
@@ -643,44 +747,40 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
} else {
if (sret) {
fixups.emplace_back(FixupTy::Codes::CharPair,
- newInTys.size(), index);
+ newInTyAndAttrs.size(), index);
} else {
fixups.emplace_back(FixupTy::Codes::Trailing,
- newInTys.size(), trailingTys.size());
+ newInTyAndAttrs.size(),
+ trailingTys.size());
}
- newInTys.push_back(argTy);
+ newInTyAndAttrs.push_back(tup);
}
}
}
})
.Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
- if (noComplexConversion)
- newInTys.push_back(cmplx);
- else
- doComplexArg(func, cmplx, newInTys, fixups);
+ doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
- if (noComplexConversion)
- newInTys.push_back(cmplx);
- else
- doComplexArg(func, cmplx, newInTys, fixups);
+ doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
})
.Case<mlir::TupleType>([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
- newInTys.size(), trailingTys.size());
- newInTys.push_back(tuple.getType(0));
+ newInTyAndAttrs.size(), trailingTys.size());
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(tuple.getType(0)));
trailingTys.push_back(tuple.getType(1));
} else {
- newInTys.push_back(ty);
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(ty));
}
})
.Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
- auto argTy = std::get<mlir::Type>(m[0]);
- auto argNo = newInTys.size();
+ auto argNo = newInTyAndAttrs.size();
llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
if (!extensionAttrName.empty() &&
isFuncWithCCallingConvention(func))
@@ -691,14 +791,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
mlir::UnitAttr::get(func.getContext()));
});
- newInTys.push_back(argTy);
+ newInTyAndAttrs.push_back(m[0]);
+ })
+ .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ doStructArg(func, recTy, newInTyAndAttrs, fixups);
})
- .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
+ .Default([&](mlir::Type ty) {
+ newInTyAndAttrs.push_back(
+ fir::CodeGenSpecifics::getTypeAndAttr(ty));
+ });
if (func.getArgAttrOfType<mlir::UnitAttr>(index,
fir::getHostAssocAttrName())) {
extraAttrs.push_back(
- {newInTys.size() - 1,
+ {newInTyAndAttrs.size() - 1,
rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
}
}
@@ -712,12 +818,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
int offset = 0;
for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
const auto &fixup = fixups[i];
+ mlir::Type fixupType =
+ fixup.index < newInTyAndAttrs.size()
+ ? std::get<mlir::Type>(newInTyAndAttrs[fixup.index])
+ : mlir::Type{};
switch (fixup.code) {
case FixupTy::Codes::ArgumentAsLoad: {
// Argument was pass-by-value, but is now pass-by-reference and
// possibly with a different element type.
- auto newArg = func.front().insertArgument(fixup.index,
- newInTys[fixup.index], loc);
+ auto newArg =
+ func.front().insertArgument(fixup.index, fixupType, loc);
rewriter->setInsertionPointToStart(&func.front());
auto oldArgTy =
fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
@@ -732,14 +842,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto oldArgTy =
fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
// If type did not change, keep the original argument.
- if (newInTys[fixup.index] == oldArgTy)
+ if (fixupType == oldArgTy)
break;
- auto newArg = func.front().insertArgument(fixup.index,
- newInTys[fixup.index], loc);
+ auto newArg =
+ func.front().insertArgument(fixup.index, fixupType, loc);
rewriter->setInsertionPointToStart(&func.front());
- auto mem =
- rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
+ auto mem = rewriter->create<fir::AllocaOp>(loc, fixupType);
rewriter->create<fir::StoreOp>(loc, newArg, mem);
auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem);
mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
@@ -753,8 +862,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
case FixupTy::Codes::CharPair: {
// The FIR boxchar argument has been split into a pair of distinct
// arguments that are in juxtaposition to each other.
- auto newArg = func.front().insertArgument(fixup.index,
- newInTys[fixup.index], loc);
+ auto newArg =
+ func.front().insertArgument(fixup.index, fixupType, loc);
if (fixup.second == 1) {
rewriter->setInsertionPointToStart(&func.front());
auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
@@ -768,8 +877,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
case FixupTy::Codes::ReturnAsStore: {
// The value being returned is now being returned in memory (callee
// stack space) through a hidden reference argument.
- auto newArg = func.front().insertArgument(fixup.index,
- newInTys[fixup.index], loc);
+ auto newArg =
+ func.front().insertArgument(fixup.index, fixupType, loc);
offset++;
func.walk([&](mlir::func::ReturnOp ret) {
rewriter->setInsertionPoint(ret);
@@ -801,8 +910,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
case FixupTy::Codes::Split: {
// The FIR argument has been split into a pair of distinct arguments
// that are in juxtaposition to each other. (For COMPLEX value.)
- auto newArg = func.front().insertArgument(fixup.index,
- newInTys[fixup.index], loc);
+ auto newArg =
+ func.front().insertArgument(fixup.index, fixupType, loc);
if (fixup.second == 1) {
rewriter->setInsertionPointToStart(&func.front());
auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
@@ -825,8 +934,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// The first part of the pair appears in the original argument
// position. The second part of the pair is appended after all the
// original arguments. (Boxchar arguments.)
- auto newBufArg = func.front().insertArgument(
- fixup.index, newInTys[fixup.index], loc);
+ auto newBufArg =
+ func.front().insertArgument(fixup.index, fixupType, loc);
auto newLenArg =
func.front().addArgument(trailingTys[fixup.second], loc);
auto boxTy = oldArgTys[fixup.index - offset];
@@ -841,8 +950,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// pair of distinct arguments. The first part of the pair appears in
// the original argument position. The second part of the pair is
// appended after all the original arguments.
- auto newProcPointerArg = func.front().insertArgument(
- fixup.index, newInTys[fixup.index], loc);
+ auto newProcPointerArg =
+ func.front().insertArgument(fixup.index, fixupType, loc);
auto newLenArg =
func.front().addArgument(trailingTys[fixup.second], loc);
auto tupleType = oldArgTys[fixup.index - offset];
@@ -857,10 +966,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
+ llvm::SmallVector<mlir::Type> newInTypes = toTypeList(newInTyAndAttrs);
// Set the new type and finalize the arguments, etc.
- newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
+ newInTypes.insert(newInTypes.end(), trailingTys.begin(), trailingTys.end());
auto newFuncTy =
- mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
+ mlir::FunctionType::get(func.getContext(), newInTypes, newResTys);
LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
func.setType(newFuncTy);
@@ -899,7 +1009,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
/// GPR.
template <typename A, typename B, typename C>
void doComplexReturn(mlir::func::FuncOp func, A cmplx, B &newResTys,
- B &newInTys, C &fixups) {
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ C &fixups) {
if (noComplexConversion) {
newResTys.push_back(cmplx);
return;
@@ -911,7 +1022,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
auto argTy = std::get<mlir::Type>(tup);
if (attr.isSRet()) {
- unsigned argNo = newInTys.size();
+ unsigned argNo = newInTyAndAttrs.size();
if (auto align = attr.getAlignment())
fixups.emplace_back(
FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
@@ -931,42 +1042,35 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
func.setArgAttr(argNo, "llvm.sret",
mlir::TypeAttr::get(elemType));
});
- newInTys.push_back(argTy);
+ newInTyAndAttrs.push_back(tup);
return;
- } else {
- if (auto align = attr.getAlignment())
- fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size(),
- [=](mlir::func::FuncOp func) {
- func.setArgAttr(
- newResTys.size(), "llvm.align",
- rewriter->getIntegerAttr(
- rewriter->getIntegerType(32), align));
- });
- else
- fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
}
+ if (auto align = attr.getAlignment())
+ fixups.emplace_back(
+ FixupTy::Codes::ReturnType, newResTys.size(),
+ [=](mlir::func::FuncOp func) {
+ func.setArgAttr(
+ newResTys.size(), "llvm.align",
+ rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
+ });
+ else
+ fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
newResTys.push_back(argTy);
}
- /// Convert a complex argument value. This can involve storing the value to
- /// a temporary memory location or factoring the value into two distinct
- /// arguments.
- template <typename A, typename B, typename C>
- void doComplexArg(mlir::func::FuncOp func, A cmplx, B &newInTys, C &fixups) {
- if (noComplexConversion) {
- newInTys.push_back(cmplx);
- return;
- }
- auto m =
- specifics->complexArgumentType(func.getLoc(), cmplx.getElementType());
- const auto fixupCode =
- m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
- for (auto e : llvm::enumerate(m)) {
+ template <typename FIXUPS>
+ void
+ createFuncOpArgFixups(mlir::func::FuncOp func,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ fir::CodeGenSpecifics::Marshalling &argsInTys,
+ FIXUPS &fixups) {
+ const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
+ : FixupTy::Codes::ArgumentType;
+ for (auto e : llvm::enumerate(argsInTys)) {
auto &tup = e.value();
auto index = e.index();
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
- auto argTy = std::get<mlir::Type>(tup);
- auto argNo = newInTys.size();
+ auto argNo = newInTyAndAttrs.size();
if (attr.isByVal()) {
if (auto align = attr.getAlignment())
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
@@ -981,7 +1085,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
rewriter->getIntegerType(32), align));
});
else
- fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
+ fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
+ newInTyAndAttrs.size(),
[=](mlir::func::FuncOp func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
@@ -999,8 +1104,33 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
else
fixups.emplace_back(fixupCode, argNo, index);
}
- newInTys.push_back(argTy);
+ newInTyAndAttrs.push_back(tup);
+ }
+ }
+
+ /// Convert a complex argument value. This can involve storing the value to
+ /// a temporary memory location or factoring the value into two distinct
+ /// arguments.
+ /// When conversion applies, the target's (type, attributes) pairs are
+ /// appended to newInTyAndAttrs and the matching fixups are added to fixups.
+ template <typename A, typename B>
+ void doComplexArg(mlir::func::FuncOp func, A cmplx,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ B &fixups) {
+ // Complex conversion disabled: keep the complex type, no fixup needed.
+ if (noComplexConversion) {
+ newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx));
+ return;
}
+ auto cplxArgs =
+ specifics->complexArgumentType(func.getLoc(), cmplx.getElementType());
+ createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
+ }
+
+ /// Convert a derived type argument of a function definition. The target
+ /// decides how the struct is passed (given the arguments already
+ /// marshalled); the corresponding argument fixups are then registered.
+ template <typename FIXUPS>
+ void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
+                  fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+                  FIXUPS &fixups) {
+   mlir::Location funcLoc = func.getLoc();
+   fir::CodeGenSpecifics::Marshalling structArgs =
+       specifics->structArgumentType(funcLoc, recTy, newInTyAndAttrs);
+   createFuncOpArgFixups(func, newInTyAndAttrs, structArgs, fixups);
+ }
private:
@@ -1011,12 +1141,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
op->erase();
}
+ /// Install the target specifics, builder, and data layout for the rewrite.
- inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r) {
+ inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r,
+ mlir::DataLayout *dl) {
specifics = s;
rewriter = r;
+ dataLayout = dl;
}
+ /// Reset the members installed by setMembers to null (nothing is freed).
- inline void clearMembers() { setMembers(nullptr, nullptr); }
+ inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }
// Inserts a call to llvm.stacksave at the current insertion
// point and the given location. Returns the call's result Value.
@@ -1032,9 +1164,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
fir::CodeGenSpecifics *specifics = nullptr;
mlir::OpBuilder *rewriter = nullptr;
+ mlir::DataLayout *dataLayout = nullptr;
mlir::func::FuncOp stackSaveFn = nullptr;
mlir::func::FuncOp stackRestoreFn = nullptr;
-}; // namespace
+};
} // namespace
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 1d48592ec6ac2..209c586411f41 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -27,12 +27,13 @@
namespace fir {
LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
- bool forceUnifiedTBAATree)
+ bool forceUnifiedTBAATree,
+ const mlir::DataLayout &dl)
: mlir::LLVMTypeConverter(module.getContext()),
kindMapping(getKindMapping(module)),
specifics(CodeGenSpecifics::get(module.getContext(),
getTargetTriple(module),
- getKindMapping(module))),
+ getKindMapping(module), dl)),
tbaaBuilder(std::make_unique<TBAABuilder>(module->getContext(), applyTBAA,
forceUnifiedTBAATree)) {
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 730317a9bc238..d0c7bae674b6c 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -849,6 +849,12 @@ fir::RealType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
return mlir::success();
}
+mlir::Type fir::RealType::getFloatType(const fir::KindMapping &kindMap) const {
+  // Translate this REAL's kind into the floating-point type selected for it
+  // by the kind mapping.
+  auto kind = getFKind();
+  return fir::fromRealTypeID(getContext(), kindMap.getRealTypeID(kind), kind);
+}
+
//===----------------------------------------------------------------------===//
// RecordType
//===----------------------------------------------------------------------===//
diff --git a/flang/lib/Optimizer/Support/DataLayout.cpp b/flang/lib/Optimizer/Support/DataLayout.cpp
index 5cd9c01e8ce00..93a3b92d08105 100644
--- a/flang/lib/Optimizer/Support/DataLayout.cpp
+++ b/flang/lib/Optimizer/Support/DataLayout.cpp
@@ -45,3 +45,16 @@ void fir::support::setMLIRDataLayoutFromAttributes(mlir::ModuleOp mlirModule,
llvm::DataLayout llvmDataLayout("");
fir::support::setMLIRDataLayout(mlirModule, llvmDataLayout);
}
+
+/// Return the MLIR data layout of \p mlirModule. If the module has no data
+/// layout spec yet, one is first derived from the module attributes (whether
+/// a default layout may be used is controlled by \p allowDefaultLayout).
+/// Returns std::nullopt when the module still has no data layout spec.
+std::optional<mlir::DataLayout>
+fir::support::getOrSetDataLayout(mlir::ModuleOp mlirModule,
+                                 bool allowDefaultLayout) {
+  if (mlirModule.getDataLayoutSpec())
+    return mlir::DataLayout(mlirModule);
+  fir::support::setMLIRDataLayoutFromAttributes(mlirModule, allowDefaultLayout);
+  if (!mlirModule.getDataLayoutSpec())
+    return std::nullopt;
+  return mlir::DataLayout(mlirModule);
+}
diff --git a/flang/test/Fir/recursive-type-tco.fir b/flang/test/Fir/recursive-type-tco.fir
index 9933f727af12b..6fd222f26547a 100644
--- a/flang/test/Fir/recursive-type-tco.fir
+++ b/flang/test/Fir/recursive-type-tco.fir
@@ -5,7 +5,7 @@
// CHECK-LABEL: %t = type { ptr }
!t = !fir.type<t {p : !fir.ptr<!fir.type<t>>}>
-// CHECK-LABEL: @a(%t %{{.*}})
-func.func @a(%a : !t) {
+// CHECK-LABEL: @a({ %t } %{{.*}})
+func.func @a(%a : tuple<!t>) {
return
}
diff --git a/flang/test/Fir/struct-passing-x86-64-byval.fir b/flang/test/Fir/struct-passing-x86-64-byval.fir
new file mode 100644
index 0000000000000..791545a371608
--- /dev/null
+++ b/flang/test/Fir/struct-passing-x86-64-byval.fir
@@ -0,0 +1,103 @@
+// Test X86-64 ABI rewrite of structs passed by value (BIND(C), VALUE derived types).
+// This file tests cases where the struct must be passed on the stack according
+// to the System V ABI.
+// REQUIRES: x86-registered-target
+// RUN: tco --target=x86_64-unknown-linux-gnu %s | FileCheck %s
+
+module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+
+// Integer structs that are too big to be passed in registers.
+func.func @takes_toobig(%arg0: !fir.type<toobig{i:!fir.array<5xi32>}>) {
+ return
+}
+func.func @takes_toobig_align16(%arg0: !fir.type<toobig_align16{i:i128,j:i8}>) {
+ return
+}
+// Structs that would fit in integer registers, but the preceding arguments
+// leave too few integer registers available.
+func.func @not_enough_int_reg_1(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: !fir.type<fits_in_1_int_reg{i:i32,j:i32}>) {
+ return
+}
+func.func @not_enough_int_reg_1b(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>, %arg2: !fir.ref<i32>, %arg3: !fir.ref<i32>, %arg4: !fir.ref<i32>, %arg5: !fir.ref<i32>, %arg6: !fir.type<fits_in_1_int_reg{i:i32,j:i32}>) {
+ return
+}
+func.func @not_enough_int_reg_2(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: !fir.type<fits_in_2_int_reg{i:i64,j:i64}>) {
+ return
+}
+// Floating point counterparts: structs that are too big, and structs that
+// would fit in SSE registers when too few of them are left.
+func.func @ftakes_toobig(%arg0: !fir.type<ftoobig{i:!fir.array<5xf32>}>) {
+ return
+}
+func.func @ftakes_toobig_align16(%arg0: !fir.type<ftoobig_align16{i:f128,j:f32}>) {
+ return
+}
+func.func @not_enough_sse_reg_1(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: !fir.type<fits_in_1_sse_reg{i:f32,j:f32}>) {
+ return
+}
+func.func @not_enough_sse_reg_1b(%arg0: !fir.complex<4>, %arg1: !fir.complex<4>, %arg2: !fir.complex<4>, %arg3: !fir.complex<4>, %arg4: !fir.complex<4>, %arg5: !fir.complex<4>, %arg6: !fir.complex<4>, %arg7: !fir.complex<4>, %arg8: !fir.type<fits_in_1_sse_reg{i:f32,j:f32}>) {
+ return
+}
+func.func @not_enough_sse_reg_1c(%arg0: !fir.complex<8>, %arg1: !fir.complex<8>, %arg2: !fir.complex<8>, %arg3: !fir.complex<8>, %arg4: !fir.type<fits_in_1_sse_reg{i:f32,j:f32}>) {
+ return
+}
+func.func @not_enough_sse_reg_2(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32, %arg7: !fir.type<fits_in_2_sse_reg{i:f64,j:f64}>) {
+ return
+}
+// Structs with x87 (80-bit real) components.
+func.func @test_contains_x87(%arg0: !fir.type<contains_x87{i:f80}>) {
+ return
+}
+func.func @test_contains_complex_x87(%arg0: !fir.type<contains_complex_x87{i:!fir.complex<10>}>) {
+ return
+}
+// Structs that exceed 16 bytes through nesting or alignment padding.
+func.func @test_nested_toobig(%arg0: !fir.type<nested_toobig{x:!fir.array<3x!fir.type<fits_in_1_int_reg{i:i32,j:i32}>>}>) {
+ return
+}
+func.func @test_badly_aligned(%arg0: !fir.type<badly_aligned{x:f32,y:f64,z:f32}>) {
+ return
+}
+// Structs with LOGICAL, COMPLEX, and CHARACTER components.
+func.func @test_logical_toobig(%arg0: !fir.type<logical_too_big{l:!fir.array<17x!fir.logical<1>>}>) {
+ return
+}
+func.func @l_not_enough_int_reg(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: !fir.type<l_fits_in_2_int_reg{l:!fir.array<4x!fir.logical<4>>}>) {
+ return
+}
+func.func @test_complex_toobig(%arg0: !fir.type<complex_too_big{c:!fir.array<2x!fir.complex<8>>}>) {
+ return
+}
+func.func @cplx_not_enough_sse_reg_1(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: !fir.type<cplx_fits_in_1_sse_reg{c:!fir.complex<4>}>) {
+ return
+}
+func.func @test_char_too_big(%arg0: !fir.type<char_too_big{c:!fir.array<17x!fir.char<1>>}>) {
+ return
+}
+func.func @char_not_enough_int_reg_1(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: !fir.type<char_fits_in_1_int_reg{c:!fir.array<8x!fir.char<1>>}>) {
+ return
+}
+// Structs mixing INTEGER and REAL components.
+func.func @mix_not_enough_int_reg_1(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: !fir.type<mix_in_1_int_reg{x:f32,i:i32}>) {
+ return
+}
+func.func @mix_not_enough_sse_reg_2(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: !fir.type<mix_in_1_int_reg_1_sse_reg{i:!fir.array<2xi32>,x:!fir.array<2xf32>}>) {
+ return
+}
+func.func private @_QPuse_it(!fir.ref<!fir.array<5xi32>>)
+}
+
+// CHECK: define void @takes_toobig(ptr byval(%toobig) align 8 %{{.*}}) {
+// CHECK: define void @takes_toobig_align16(ptr byval(%toobig_align16) align 16 %{{.*}}) {
+// CHECK: define void @not_enough_int_reg_1(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, ptr byval(%fits_in_1_int_reg) align 8 %{{.*}}) {
+// CHECK: define void @not_enough_int_reg_1b(ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, ptr byval(%fits_in_1_int_reg) align 8 %{{.*}}) {
+// CHECK: define void @not_enough_int_reg_2(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, ptr byval(%fits_in_2_int_reg) align 8 %{{.*}}) {
+// CHECK: define void @ftakes_toobig(ptr byval(%ftoobig) align 8 %{{.*}}) {
+// CHECK: define void @ftakes_toobig_align16(ptr byval(%ftoobig_align16) align 16 %{{.*}}) {
+// CHECK: define void @not_enough_sse_reg_1(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, ptr byval(%fits_in_1_sse_reg) align 8 %{{.*}}) {
+// CHECK: define void @not_enough_sse_reg_1b(<2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}}, <2 x float> %{{.*}}, ptr byval(%fits_in_1_sse_reg) align 8 %{{.*}}) {
+// CHECK: define void @not_enough_sse_reg_1c(double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}}, double %{{.*}}, ptr byval(%fits_in_1_sse_reg) align 8 %{{.*}}) {
+// CHECK: define void @not_enough_sse_reg_2(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, ptr byval(%fits_in_2_sse_reg) align 8 %{{.*}}) {
+// CHECK: define void @test_contains_x87(ptr byval(%contains_x87) align 16 %{{.*}}) {
+// CHECK: define void @test_contains_complex_x87(ptr byval(%contains_complex_x87) align 16 %{{.*}}) {
+// CHECK: define void @test_nested_toobig(ptr byval(%nested_toobig) align 8 %{{.*}}) {
+// CHECK: define void @test_badly_aligned(ptr byval(%badly_aligned) align 8 %{{.*}}) {
+// CHECK: define void @test_logical_toobig(ptr byval(%logical_too_big) align 8 %{{.*}}) {
+// CHECK: define void @l_not_enough_int_reg(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, ptr byval(%l_fits_in_2_int_reg) align 8 %{{.*}}) {
+// CHECK: define void @test_complex_toobig(ptr byval(%complex_too_big) align 8 %{{.*}}) {
+// CHECK: define void @cplx_not_enough_sse_reg_1(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, ptr byval(%cplx_fits_in_1_sse_reg) align 8 %{{.*}}) {
+// CHECK: define void @test_char_too_big(ptr byval(%char_too_big) align 8 %{{.*}}) {
+// CHECK: define void @char_not_enough_int_reg_1(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, ptr byval(%char_fits_in_1_int_reg) align 8 %{{.*}}) {
+// CHECK: define void @mix_not_enough_int_reg_1(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, ptr byval(%mix_in_1_int_reg) align 8 %{{.*}}) {
+// CHECK: define void @mix_not_enough_sse_reg_2(float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, ptr byval(%mix_in_1_int_reg_1_sse_reg) align 8 %{{.*}}) {
>From 6eb256e8be6947a52976400eb8c5de8997cb8e20 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Fri, 8 Dec 2023 05:38:01 -0800
Subject: [PATCH 2/3] use getCharacterBitsize instead of getLogicalBitsize
---
flang/lib/Optimizer/CodeGen/Target.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index ea10486a6b34c..ed6637d4da012 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -107,9 +107,10 @@ getSizeAndAlignment(mlir::Location loc, mlir::Type ty,
logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
return getSizeAndAlignment(loc, intTy, dl, kindMap);
}
- if (auto logical = mlir::dyn_cast<fir::CharacterType>(ty)) {
+ if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
mlir::Type intTy = mlir::IntegerType::get(
- logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
+ character.getContext(),
+ kindMap.getCharacterBitsize(character.getFKind()));
return getSizeAndAlignment(loc, intTy, dl, kindMap);
}
TODO(loc, "computing size of a component");
>From 65548db10841caebf519616e1ddf24b6c5670e87 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Fri, 8 Dec 2023 07:07:14 -0800
Subject: [PATCH 3/3] fix typo in assert
---
flang/lib/Optimizer/CodeGen/Target.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index ed6637d4da012..b1395f064db45 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -556,7 +556,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
/// Argument class merging as described in System V ABI 3.2.3 point 4.
ArgClass mergeClass(ArgClass accum, ArgClass field) const {
- assert((Accum != ArgClass::Memory && Accum != ArgClass::ComplexX87) &&
+ assert((accum != ArgClass::Memory && accum != ArgClass::ComplexX87) &&
"Invalid accumulated classification during merge.");
if (accum == field || field == NoClass)
return accum;
More information about the flang-commits
mailing list