[flang-commits] [flang] [flang] correctly deal with bind(c) derived type result ABI (PR #111678)
via flang-commits
flang-commits at lists.llvm.org
Wed Oct 9 06:47:38 PDT 2024
https://github.com/jeanPerier created https://github.com/llvm/llvm-project/pull/111678
Derived type results of BIND(C) functions should be returned according to the C ABI for returning the related C struct type.
This was not happening because the abstract-result pass was forcing the Fortran ABI for all derived type results.
Use the bind_c attribute that was added on call/func/dispatch in FIR to prevent such a rewrite in the abstract-result pass, and update the target-rewrite pass to deal with the struct return ABI.
So far, the target-specific part of the target-rewrite pass is only implemented for X86-64, according to the "System V Application Binary Interface AMD64 v1"; the other targets will hit a TODO, just like for BIND(C), VALUE derived type arguments.
This intends to deal with https://github.com/llvm/llvm-project/issues/102113.
@kiranchandramohan and @pawosm-arm, I hesitated about making this a silent TODO, but I see that for a simple struct like `type, bind(c) :: t; real x; end type`, the ABI is currently also incorrect on ARM when compared with the C equivalent (flang used to return the result via memory, while clang returns it in a register).
>From d540e1fffe558883f0ff810358442670f3bf5f9b Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 9 Oct 2024 06:31:49 -0700
Subject: [PATCH] [flang] correctly deal with bind(c) derived type result ABI
---
.../include/flang/Optimizer/CodeGen/Target.h | 5 +
.../flang/Optimizer/Dialect/FIROpsSupport.h | 21 +++
flang/lib/Optimizer/CodeGen/Target.cpp | 72 ++++++++-
flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 137 ++++++++++++++----
.../Optimizer/Transforms/AbstractResult.cpp | 65 ++++++++-
flang/test/Fir/abstract-results-bindc.fir | 43 ++++++
flang/test/Fir/struct-return-x86-64.fir | 120 +++++++++++++++
7 files changed, 423 insertions(+), 40 deletions(-)
create mode 100644 flang/test/Fir/abstract-results-bindc.fir
create mode 100644 flang/test/Fir/struct-return-x86-64.fir
diff --git a/flang/include/flang/Optimizer/CodeGen/Target.h b/flang/include/flang/Optimizer/CodeGen/Target.h
index a7161152a5c323..3b38583511927a 100644
--- a/flang/include/flang/Optimizer/CodeGen/Target.h
+++ b/flang/include/flang/Optimizer/CodeGen/Target.h
@@ -126,6 +126,11 @@ class CodeGenSpecifics {
structArgumentType(mlir::Location loc, fir::RecordType recTy,
const Marshalling &previousArguments) const = 0;
+ /// Type representation of a `fir.type<T>` type argument when returned by
+ /// value. Such value may need to be converted to a hidden reference argument.
+ virtual Marshalling structReturnType(mlir::Location loc,
+ fir::RecordType eleTy) const = 0;
+
/// Type representation of a `boxchar<n>` type argument when passed by value.
/// An argument value may need to be passed as a (safe) reference argument.
///
diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
index cdbefdb2341485..fb7b1d16f62f3a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
+++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h
@@ -177,6 +177,27 @@ inline mlir::NamedAttribute getAdaptToByRefAttr(Builder &builder) {
}
bool isDummyArgument(mlir::Value v);
+
+template <fir::FortranProcedureFlagsEnum Flag>
+inline bool hasProcedureAttr(fir::FortranProcedureFlagsEnumAttr flags) {
+ return flags && bitEnumContainsAny(flags.getValue(), Flag);
+}
+
+template <fir::FortranProcedureFlagsEnum Flag>
+inline bool hasProcedureAttr(mlir::Operation *op) {
+ if (auto firCallOp = mlir::dyn_cast<fir::CallOp>(op))
+ return hasProcedureAttr<Flag>(firCallOp.getProcedureAttrsAttr());
+ if (auto firCallOp = mlir::dyn_cast<fir::DispatchOp>(op))
+ return hasProcedureAttr<Flag>(firCallOp.getProcedureAttrsAttr());
+ return hasProcedureAttr<Flag>(
+ op->getAttrOfType<fir::FortranProcedureFlagsEnumAttr>(
+ getFortranProcedureFlagsAttrName()));
+}
+
+inline bool hasBindcAttr(mlir::Operation *op) {
+ return hasProcedureAttr<fir::FortranProcedureFlagsEnum::bind_c>(op);
+}
+
} // namespace fir
#endif // FORTRAN_OPTIMIZER_DIALECT_FIROPSSUPPORT_H
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index a12b59413f4456..c47bda2187a684 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -100,6 +100,11 @@ struct GenericTarget : public CodeGenSpecifics {
TODO(loc, "passing VALUE BIND(C) derived type for this target");
}
+ CodeGenSpecifics::Marshalling
+ structReturnType(mlir::Location loc, fir::RecordType ty) const override {
+ TODO(loc, "returning BIND(C) derived type for this target");
+ }
+
CodeGenSpecifics::Marshalling
integerArgumentType(mlir::Location loc,
mlir::IntegerType argTy) const override {
@@ -533,7 +538,8 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
/// When \p recTy is a one field record type that can be passed
/// like the field on its own, returns the field type. Returns
/// a null type otherwise.
- mlir::Type passAsFieldIfOneFieldStruct(fir::RecordType recTy) const {
+ mlir::Type passAsFieldIfOneFieldStruct(fir::RecordType recTy,
+ bool allowComplex = false) const {
auto typeList = recTy.getTypeList();
if (typeList.size() != 1)
return {};
@@ -541,6 +547,8 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
if (mlir::isa<mlir::FloatType, mlir::IntegerType, fir::LogicalType>(
fieldType))
return fieldType;
+ if (allowComplex && mlir::isa<mlir::ComplexType>(fieldType))
+ return fieldType;
if (mlir::isa<fir::CharacterType>(fieldType)) {
// Only CHARACTER(1) are expected in BIND(C) contexts, which is the only
// contexts where derived type may be passed in registers.
@@ -593,7 +601,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
postMerge(byteOffset, Lo, Hi);
if (Lo == ArgClass::Memory || Lo == ArgClass::X87 ||
Lo == ArgClass::ComplexX87)
- return passOnTheStack(loc, recTy);
+ return passOnTheStack(loc, recTy, /*isResult=*/false);
int neededIntRegisters = 0;
int neededSSERegisters = 0;
if (Lo == ArgClass::SSE)
@@ -609,7 +617,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
// all in registers or all on the stack).
if (!hasEnoughRegisters(loc, neededIntRegisters, neededSSERegisters,
previousArguments))
- return passOnTheStack(loc, recTy);
+ return passOnTheStack(loc, recTy, /*isResult=*/false);
if (auto fieldType = passAsFieldIfOneFieldStruct(recTy)) {
CodeGenSpecifics::Marshalling marshal;
@@ -641,9 +649,61 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
return marshal;
}
+ CodeGenSpecifics::Marshalling
+ structReturnType(mlir::Location loc, fir::RecordType recTy) const override {
+ std::uint64_t byteOffset = 0;
+ ArgClass Lo, Hi;
+ Lo = Hi = ArgClass::NoClass;
+ byteOffset = classifyStruct(loc, recTy, byteOffset, Lo, Hi);
+ mlir::MLIRContext *context = recTy.getContext();
+ postMerge(byteOffset, Lo, Hi);
+ if (Lo == ArgClass::Memory)
+ return passOnTheStack(loc, recTy, /*isResult=*/true);
+
+ // Note that X87/ComplexX87 are passed in memory, but returned via %st0
+ // %st1 registers. Here, they are returned as fp80 or {fp80, fp80} by
+ // passAsFieldIfOneFieldStruct, and LLVM will use the expected registers.
+
+ // Note that {_Complex long double} is not 100% clear from an ABI
+ // perspective because the aggregate post merger rules say it should be
+ // passed in memory because it is bigger than 2 eight bytes. This has the
+ // funny effect of
+ // {_Complex long double} return to be dealt with differently than
+ // _Complex long double. ICC, NVC, and Clang return the struct in memory,
+ // GCC does not. The code here follows ICC and Clang because that seems to
+ // be in line with the standard (nothing in the section about return says
+ // that the step 5. of the aggregate classification should not be done for
+ // the classification of the result).
+
+ if (auto fieldType =
+ passAsFieldIfOneFieldStruct(recTy, /*allowComplex=*/true)) {
+ if (auto complexType = mlir::dyn_cast<mlir::ComplexType>(fieldType))
+ return complexReturnType(loc, complexType.getElementType());
+ CodeGenSpecifics::Marshalling marshal;
+ marshal.emplace_back(fieldType, AT{});
+ return marshal;
+ }
+
+ if (Hi == ArgClass::NoClass || Hi == ArgClass::SSEUp) {
+ // Return a single integer or floating point argument.
+ mlir::Type lowType = pickLLVMArgType(loc, context, Lo, byteOffset);
+ CodeGenSpecifics::Marshalling marshal;
+ marshal.emplace_back(lowType, AT{});
+ return marshal;
+ }
+ // Will be returned in two different registers. Generate {lowTy, HiTy} for
+ // the LLVM IR result type.
+ CodeGenSpecifics::Marshalling marshal;
+ mlir::Type lowType = pickLLVMArgType(loc, context, Lo, 8u);
+ mlir::Type hiType = pickLLVMArgType(loc, context, Hi, byteOffset - 8u);
+ marshal.emplace_back(mlir::TupleType::get(context, {lowType, hiType}),
+ AT{});
+ return marshal;
+ }
+
/// Marshal an argument that must be passed on the stack.
- CodeGenSpecifics::Marshalling passOnTheStack(mlir::Location loc,
- mlir::Type ty) const {
+ CodeGenSpecifics::Marshalling
+ passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
CodeGenSpecifics::Marshalling marshal;
auto sizeAndAlign =
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
@@ -651,7 +711,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
unsigned short align =
std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
marshal.emplace_back(fir::ReferenceType::get(ty),
- AT{align, /*byval=*/true, /*sret=*/false});
+ AT{align, /*byval=*/!isResult, /*sret=*/isResult});
return marshal;
}
};
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index fd56fd6bf50f44..04a3ea684642c8 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -142,20 +142,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
mlir::ModuleOp getModule() { return getOperation(); }
- template <typename A, typename B, typename C>
+ template <typename Ty, typename Callback>
std::optional<std::function<mlir::Value(mlir::Operation *)>>
- rewriteCallComplexResultType(
- mlir::Location loc, A ty, B &newResTys,
- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, C &newOpers,
- mlir::Value &savedStackPtr) {
- if (noComplexConversion) {
- newResTys.push_back(ty);
- return std::nullopt;
- }
- auto m = specifics->complexReturnType(loc, ty.getElementType());
- // Currently targets mandate COMPLEX is a single aggregate or packed
- // scalar, including the sret case.
- assert(m.size() == 1 && "target of complex return not supported");
+ rewriteCallResultType(mlir::Location loc, mlir::Type originalResTy,
+ Ty &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ Callback &newOpers, mlir::Value &savedStackPtr,
+ fir::CodeGenSpecifics::Marshalling &m) {
+ // Currently, targets mandate COMPLEX or STRUCT is a single aggregate or
+ // packed scalar, including the sret case.
+ assert(m.size() == 1 && "return type not supported on this target");
auto resTy = std::get<mlir::Type>(m[0]);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
if (attr.isSRet()) {
@@ -170,7 +166,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newInTyAndAttrs.push_back(m[0]);
newOpers.push_back(stack);
return [=](mlir::Operation *) -> mlir::Value {
- auto memTy = fir::ReferenceType::get(ty);
+ auto memTy = fir::ReferenceType::get(originalResTy);
auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack);
return rewriter->create<fir::LoadOp>(loc, cast);
};
@@ -180,11 +176,41 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// We are going to generate an alloca, so save the stack pointer.
if (!savedStackPtr)
savedStackPtr = genStackSave(loc);
- return this->convertValueInMemory(loc, call->getResult(0), ty,
+ return this->convertValueInMemory(loc, call->getResult(0), originalResTy,
/*inputMayBeBigger=*/true);
};
}
+ template <typename Ty, typename Callback>
+ std::optional<std::function<mlir::Value(mlir::Operation *)>>
+ rewriteCallComplexResultType(
+ mlir::Location loc, mlir::ComplexType ty, Ty &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
+ mlir::Value &savedStackPtr) {
+ if (noComplexConversion) {
+ newResTys.push_back(ty);
+ return std::nullopt;
+ }
+ auto m = specifics->complexReturnType(loc, ty.getElementType());
+ return rewriteCallResultType(loc, ty, newResTys, newInTyAndAttrs, newOpers,
+ savedStackPtr, m);
+ }
+
+ template <typename Ty, typename Callback>
+ std::optional<std::function<mlir::Value(mlir::Operation *)>>
+ rewriteCallStructResultType(
+ mlir::Location loc, fir::RecordType recTy, Ty &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs, Callback &newOpers,
+ mlir::Value &savedStackPtr) {
+ if (noStructConversion) {
+ newResTys.push_back(recTy);
+ return std::nullopt;
+ }
+ auto m = specifics->structReturnType(loc, recTy);
+ return rewriteCallResultType(loc, recTy, newResTys, newInTyAndAttrs,
+ newOpers, savedStackPtr, m);
+ }
+
void passArgumentOnStackOrWithNewType(
mlir::Location loc, fir::CodeGenSpecifics::TypeAndAttr newTypeAndAttr,
mlir::Type oldType, mlir::Value oper,
@@ -356,6 +382,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newInTyAndAttrs, newOpers,
savedStackPtr);
})
+ .template Case<fir::RecordType>([&](fir::RecordType recTy) {
+ wrap = rewriteCallStructResultType(loc, recTy, newResTys,
+ newInTyAndAttrs, newOpers,
+ savedStackPtr);
+ })
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
} else if (fnTy.getResults().size() > 1) {
TODO(loc, "multiple results not supported yet");
@@ -562,6 +593,24 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
+ template <typename Ty>
+ void
+ lowerStructSignatureRes(mlir::Location loc, fir::RecordType recTy,
+ Ty &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
+ if (noComplexConversion) {
+ newResTys.push_back(recTy);
+ return;
+ } else {
+ for (auto &tup : specifics->structReturnType(loc, recTy)) {
+ if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
+ newInTyAndAttrs.push_back(tup);
+ else
+ newResTys.push_back(std::get<mlir::Type>(tup));
+ }
+ }
+ }
+
void
lowerStructSignatureArg(mlir::Location loc, fir::RecordType recTy,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
@@ -595,6 +644,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
+ .Case<fir::RecordType>([&](fir::RecordType ty) {
+ lowerStructSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
+ })
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
}
llvm::SmallVector<mlir::Type> trailingInTys;
@@ -696,7 +748,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
for (auto ty : func.getResults())
if ((mlir::isa<fir::BoxCharType>(ty) && !noCharacterConversion) ||
(fir::isa_complex(ty) && !noComplexConversion) ||
- (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv)) {
+ (mlir::isa<mlir::IntegerType>(ty) && hasCCallingConv) ||
+ (mlir::isa<fir::RecordType>(ty) && !noStructConversion)) {
LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
return false;
}
@@ -770,6 +823,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
rewriter->getUnitAttr()));
newResTys.push_back(retTy);
})
+ .Case<fir::RecordType>([&](fir::RecordType recTy) {
+ doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
+ })
.Default([&](mlir::Type ty) { newResTys.push_back(ty); });
// Saved potential shift in argument. Handling of result can add arguments
@@ -1062,21 +1118,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
return false;
}
- /// Convert a complex return value. This can involve converting the return
- /// value to a "hidden" first argument or packing the complex into a wide
- /// GPR.
template <typename Ty, typename FIXUPS>
- void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
- Ty &newResTys,
- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
- FIXUPS &fixups) {
- if (noComplexConversion) {
- newResTys.push_back(cmplx);
- return;
- }
- auto m =
- specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
- assert(m.size() == 1);
+ void doReturn(mlir::func::FuncOp func, Ty &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
+ assert(m.size() == 1 &&
+ "expect result to be turned into single argument or result so far");
auto &tup = m[0];
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
auto argTy = std::get<mlir::Type>(tup);
@@ -1117,6 +1164,36 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newResTys.push_back(argTy);
}
+ /// Convert a complex return value. This can involve converting the return
+ /// value to a "hidden" first argument or packing the complex into a wide
+ /// GPR.
+ template <typename Ty, typename FIXUPS>
+ void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
+ Ty &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ FIXUPS &fixups) {
+ if (noComplexConversion) {
+ newResTys.push_back(cmplx);
+ return;
+ }
+ auto m =
+ specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
+ doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
+ }
+
+ template <typename Ty, typename FIXUPS>
+ void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
+ Ty &newResTys,
+ fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ FIXUPS &fixups) {
+ if (noStructConversion) {
+ newResTys.push_back(recTy);
+ return;
+ }
+ auto m = specifics->structReturnType(func.getLoc(), recTy);
+ doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
+ }
+
template <typename FIXUPS>
void
createFuncOpArgFixups(mlir::func::FuncOp func,
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 7299ff80121e13..c0ec820d87ed44 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -32,6 +32,33 @@ using namespace mlir;
namespace fir {
namespace {
+// Helper to only build the symbol table if needed because its build time is
+// linear on the number of symbols in the module.
+struct LazySymbolTable {
+ LazySymbolTable(mlir::Operation *op)
+ : module{op->getParentOfType<mlir::ModuleOp>()} {}
+ void build() {
+ if (table)
+ return;
+ table = std::make_unique<mlir::SymbolTable>(module);
+ }
+
+ template <typename T>
+ T lookup(llvm::StringRef name) {
+ build();
+ return table->lookup<T>(name);
+ }
+
+private:
+ std::unique_ptr<mlir::SymbolTable> table;
+ mlir::ModuleOp module;
+};
+
+bool hasScalarDerivedResult(mlir::FunctionType funTy) {
+ return funTy.getNumResults() == 1 &&
+ mlir::isa<fir::RecordType>(funTy.getResult(0));
+}
+
static mlir::Type getResultArgumentType(mlir::Type resultType,
bool shouldBoxResult) {
return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
@@ -190,7 +217,14 @@ class SaveResultOpConversion
llvm::LogicalResult
matchAndRewrite(fir::SaveResultOp op,
mlir::PatternRewriter &rewriter) const override {
- rewriter.eraseOp(op);
+ mlir::Operation *call = op.getValue().getDefiningOp();
+ if (mlir::isa<fir::RecordType>(op.getValue().getType()) && call &&
+ fir::hasBindcAttr(call)) {
+ rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValue(),
+ op.getMemref());
+ } else {
+ rewriter.eraseOp(op);
+ }
return mlir::success();
}
};
@@ -300,6 +334,12 @@ class AbstractResultOpt
auto *context = &getContext();
// Convert function type itself if it has an abstract result.
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
+ // Scalar derived result of BIND(C) function must be returned according
+ // to the C struct return ABI which is target dependent and implemented in
+ // the target-rewrite pass.
+ if (hasScalarDerivedResult(funcTy) &&
+ fir::hasBindcAttr(func.getOperation()))
+ return;
if (hasAbstractResult(funcTy)) {
if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
func.setType(getCPtrFunctionType(funcTy));
@@ -395,6 +435,8 @@ class AbstractResultOpt
return;
}
+ LazySymbolTable symbolTable(op);
+
mlir::RewritePatternSet patterns(context);
mlir::ConversionTarget target = *context;
const bool shouldBoxResult = this->passResultAsBox.getValue();
@@ -409,14 +451,29 @@ class AbstractResultOpt
mlir::func::FuncDialect>();
target.addIllegalOp<fir::SaveResultOp>();
target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
- return !hasAbstractResult(call.getFunctionType());
+ mlir::FunctionType funTy = call.getFunctionType();
+ if (hasScalarDerivedResult(funTy) &&
+ fir::hasBindcAttr(call.getOperation()))
+ return true;
+ return !hasAbstractResult(funTy);
});
- target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
- if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType()))
+ target.addDynamicallyLegalOp<fir::AddrOfOp>([&symbolTable](
+ fir::AddrOfOp addrOf) {
+ if (auto funTy = mlir::dyn_cast<mlir::FunctionType>(addrOf.getType())) {
+ if (hasScalarDerivedResult(funTy)) {
+ auto func = symbolTable.lookup<mlir::func::FuncOp>(
+ addrOf.getSymbol().getRootReference().getValue());
+ return func && fir::hasBindcAttr(func.getOperation());
+ }
return !hasAbstractResult(funTy);
+ }
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
+ mlir::FunctionType funTy = dispatch.getFunctionType();
+ if (hasScalarDerivedResult(funTy) &&
+ fir::hasBindcAttr(dispatch.getOperation()))
+ return true;
return !hasAbstractResult(dispatch.getFunctionType());
});
diff --git a/flang/test/Fir/abstract-results-bindc.fir b/flang/test/Fir/abstract-results-bindc.fir
new file mode 100644
index 00000000000000..9b26730f7d2923
--- /dev/null
+++ b/flang/test/Fir/abstract-results-bindc.fir
@@ -0,0 +1,43 @@
+// Test that bind_c derived type results are not moved to a hidden argument
+// by the abstract-result pass. They will be dealt with according to the C
+// struct returning ABI for the target in the target-rewrite pass.
+// RUN: fir-opt %s --abstract-result | FileCheck %s
+
+!t = !fir.type<t{i:f32, j: i32, k: f32}>
+
+func.func private @foo() -> !t attributes {fir.proc_attrs = #fir.proc_attrs<bind_c>}
+
+func.func @test_call(%x: !fir.ref<!t>) {
+ %0 = fir.call @foo() proc_attrs<bind_c> : () -> !t
+ fir.save_result %0 to %x : !t, !fir.ref<!t>
+ return
+}
+
+func.func @test_addr_of() -> (() -> !t) {
+ %0 = fir.address_of(@foo) : () -> !t
+ return %0 : () -> !t
+}
+
+func.func @test_dispatch(%x: !fir.ref<!t>, %y : !fir.class<!fir.type<somet>>) {
+ %0 = fir.dispatch "bar"(%y : !fir.class<!fir.type<somet>>) (%y : !fir.class<!fir.type<somet>>) -> !t proc_attrs<bind_c> {pass_arg_pos = 0 : i32}
+ fir.save_result %0 to %x : !t, !fir.ref<!t>
+ return
+}
+
+// CHECK-LABEL: func.func @test_call(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<t{i:f32,j:i32,k:f32}>>) {
+// CHECK: %[[VAL_1:.*]] = fir.call @foo() proc_attrs<bind_c> : () -> !fir.type<t{i:f32,j:i32,k:f32}>
+// CHECK: fir.store %[[VAL_1]] to %[[VAL_0]] : !fir.ref<!fir.type<t{i:f32,j:i32,k:f32}>>
+// CHECK: return
+// CHECK: }
+// CHECK-LABEL: func.func @test_addr_of() -> (() -> !fir.type<t{i:f32,j:i32,k:f32}>) {
+// CHECK: %[[VAL_0:.*]] = fir.address_of(@foo) : () -> !fir.type<t{i:f32,j:i32,k:f32}>
+// CHECK: return %[[VAL_0]] : () -> !fir.type<t{i:f32,j:i32,k:f32}>
+// CHECK: }
+// CHECK-LABEL: func.func @test_dispatch(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<t{i:f32,j:i32,k:f32}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.class<!fir.type<somet>>) {
+// CHECK: %[[VAL_2:.*]] = fir.dispatch "bar"(%[[VAL_1]] : !fir.class<!fir.type<somet>>) (%[[VAL_1]] : !fir.class<!fir.type<somet>>) -> !fir.type<t{i:f32,j:i32,k:f32}> proc_attrs <bind_c> {pass_arg_pos = 0 : i32}
+// CHECK: fir.store %[[VAL_2]] to %[[VAL_0]] : !fir.ref<!fir.type<t{i:f32,j:i32,k:f32}>>
+// CHECK: return
+// CHECK: }
diff --git a/flang/test/Fir/struct-return-x86-64.fir b/flang/test/Fir/struct-return-x86-64.fir
new file mode 100644
index 00000000000000..f4c2add69ff7e9
--- /dev/null
+++ b/flang/test/Fir/struct-return-x86-64.fir
@@ -0,0 +1,120 @@
+// Test X86-64 ABI rewrite of struct returned by value (BIND(C), VALUE derived types).
+// REQUIRES: x86-registered-target
+// RUN: fir-opt --target-rewrite %s | FileCheck %s
+
+!fits_in_reg = !fir.type<t1{i:f32,j:i32,k:f32}>
+!too_big = !fir.type<t2{i:!fir.array<5xf32>}>
+
+module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+
+ func.func private @test_inreg() -> !fits_in_reg
+ func.func @test_call_inreg(%arg0: !fir.ref<!fits_in_reg>) {
+ %0 = fir.call @test_inreg() : () -> !fits_in_reg
+ fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
+ return
+ }
+ func.func @test_addr_of_inreg() -> (() -> ()) {
+ %0 = fir.address_of(@test_inreg) : () -> !fits_in_reg
+ %1 = fir.convert %0 : (() -> !fits_in_reg) -> (() -> ())
+ return %1 : () -> ()
+ }
+ func.func @test_dispatch_inreg(%arg0: !fir.ref<!fits_in_reg>, %arg1: !fir.class<!fir.type<somet>>) {
+ %0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !fits_in_reg {pass_arg_pos = 0 : i32}
+ fir.store %0 to %arg0 : !fir.ref<!fits_in_reg>
+ return
+ }
+
+ func.func private @test_sret() -> !too_big
+ func.func @test_call_sret(%arg0: !fir.ref<!too_big>) {
+ %0 = fir.call @test_sret() : () -> !too_big
+ fir.store %0 to %arg0 : !fir.ref<!too_big>
+ return
+ }
+ func.func @test_addr_of_sret() -> (() -> ()) {
+ %0 = fir.address_of(@test_sret) : () -> !too_big
+ %1 = fir.convert %0 : (() -> !too_big) -> (() -> ())
+ return %1 : () -> ()
+ }
+ func.func @test_dispatch_sret(%arg0: !fir.ref<!too_big>, %arg1: !fir.class<!fir.type<somet>>) {
+ %0 = fir.dispatch "bar"(%arg1 : !fir.class<!fir.type<somet>>) (%arg1 : !fir.class<!fir.type<somet>>) -> !too_big {pass_arg_pos = 0 : i32}
+ fir.store %0 to %arg0 : !fir.ref<!too_big>
+ return
+ }
+ func.func private @test_fp_80() -> !fir.type<t3{i:f80}>
+ func.func private @test_complex_80() -> !fir.type<t4{i:complex<f80>}>
+ func.func private @test_two_fp_80() -> !fir.type<t5{i:f80,j:f80}>
+ func.func private @test_fp128() -> !fir.type<t6{i:f128}>
+}
+
+// CHECK-LABEL: func.func private @test_inreg() -> tuple<i64, f32>
+
+// CHECK-LABEL: func.func @test_call_inreg(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>) {
+// CHECK: %[[VAL_1:.*]] = fir.call @test_inreg() : () -> tuple<i64, f32>
+// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK: %[[VAL_3:.*]] = fir.alloca tuple<i64, f32>
+// CHECK: fir.store %[[VAL_1]] to %[[VAL_3]] : !fir.ref<tuple<i64, f32>>
+// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<tuple<i64, f32>>) -> !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
+// CHECK: fir.store %[[VAL_5]] to %[[VAL_0]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func @test_addr_of_inreg() -> (() -> ()) {
+// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_inreg) : () -> tuple<i64, f32>
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> tuple<i64, f32>) -> (() -> ())
+// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: }
+
+// CHECK-LABEL: func.func @test_dispatch_inreg(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.class<!fir.type<somet>>) {
+// CHECK: %[[VAL_2:.*]] = fir.dispatch "bar"(%[[VAL_1]] : !fir.class<!fir.type<somet>>) (%[[VAL_1]] : !fir.class<!fir.type<somet>>) -> tuple<i64, f32> {pass_arg_pos = 0 : i32}
+// CHECK: %[[VAL_3:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK: %[[VAL_4:.*]] = fir.alloca tuple<i64, f32>
+// CHECK: fir.store %[[VAL_2]] to %[[VAL_4]] : !fir.ref<tuple<i64, f32>>
+// CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_4]] : (!fir.ref<tuple<i64, f32>>) -> !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+// CHECK: llvm.intr.stackrestore %[[VAL_3]] : !llvm.ptr
+// CHECK: fir.store %[[VAL_6]] to %[[VAL_0]] : !fir.ref<!fir.type<t1{i:f32,j:i32,k:f32}>>
+// CHECK: return
+// CHECK: }
+// CHECK: func.func private @test_sret(!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>> {llvm.align = 8 : i32, llvm.sret = !fir.type<t2{i:!fir.array<5xf32>}>})
+
+// CHECK-LABEL: func.func @test_call_sret(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) {
+// CHECK: %[[VAL_1:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK: %[[VAL_2:.*]] = fir.alloca !fir.type<t2{i:!fir.array<5xf32>}>
+// CHECK: fir.call @test_sret(%[[VAL_2]]) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
+// CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>
+// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>
+// CHECK: llvm.intr.stackrestore %[[VAL_1]] : !llvm.ptr
+// CHECK: fir.store %[[VAL_4]] to %[[VAL_0]] : !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func @test_addr_of_sret() -> (() -> ()) {
+// CHECK: %[[VAL_0:.*]] = fir.address_of(@test_sret) : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()
+// CHECK: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : ((!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> ()) -> (() -> ())
+// CHECK: return %[[VAL_1]] : () -> ()
+// CHECK: }
+
+// CHECK-LABEL: func.func @test_dispatch_sret(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.class<!fir.type<somet>>) {
+// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK: %[[VAL_3:.*]] = fir.alloca !fir.type<t2{i:!fir.array<5xf32>}>
+// CHECK: fir.dispatch "bar"(%[[VAL_1]] : !fir.class<!fir.type<somet>>) (%[[VAL_3]], %[[VAL_1]] : !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>, !fir.class<!fir.type<somet>>) {pass_arg_pos = 1 : i32}
+// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>) -> !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>
+// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]] : !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>
+// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
+// CHECK: fir.store %[[VAL_5]] to %[[VAL_0]] : !fir.ref<!fir.type<t2{i:!fir.array<5xf32>}>>
+// CHECK: return
+// CHECK: }
+
+
+// CHECK: func.func private @test_fp_80() -> f80
+// CHECK: func.func private @test_complex_80(!fir.ref<!fir.type<t4{i:complex<f80>}>> {llvm.align = 16 : i32, llvm.sret = !fir.type<t4{i:complex<f80>}>})
+// CHECK: func.func private @test_two_fp_80(!fir.ref<!fir.type<t5{i:f80,j:f80}>> {llvm.align = 16 : i32, llvm.sret = !fir.type<t5{i:f80,j:f80}>})
+// CHECK: func.func private @test_fp128() -> f128
More information about the flang-commits
mailing list