[flang-commits] [flang] [mlir] [flang][AIX] BIND(C) derived type alignment for AIX (PR #121505)
via flang-commits
flang-commits at lists.llvm.org
Thu Jan 2 09:19:58 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Kelvin Li (kkwli)
<details>
<summary>Changes</summary>
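This patch handles the AIX alignment requirement for components of `bind(c)` derived types whose type is real (or complex) with a kind larger than 4 bytes: unless such a component is the first component of the type, it is aligned on a 4-byte boundary instead of its natural alignment.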
---
Full diff: https://github.com/llvm/llvm-project/pull/121505.diff
9 Files Affected:
- (modified) flang/include/flang/Optimizer/CodeGen/TypeConverter.h (+1-1)
- (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+3)
- (modified) flang/lib/Lower/ConvertType.cpp (+42)
- (modified) flang/lib/Optimizer/CodeGen/TypeConverter.cpp (+6-4)
- (modified) flang/lib/Optimizer/Dialect/FIRType.cpp (+9-1)
- (modified) flang/lib/Semantics/compute-offsets.cpp (+77-5)
- (added) flang/test/Lower/derived-types-bindc.f90 (+44)
- (added) flang/test/Semantics/offsets04.f90 (+105)
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+15)
``````````diff
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 7c317ddeea1fa4..20270d41b1e9a1 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -62,7 +62,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
std::optional<llvm::LogicalResult>
convertRecordType(fir::RecordType derived,
- llvm::SmallVectorImpl<mlir::Type> &results);
+ llvm::SmallVectorImpl<mlir::Type> &results, bool isPacked);
// Is an extended descriptor needed given the element type of a fir.box type ?
// Extended descriptors are required for derived types.
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 3919c9191c2122..4d832a49236bf1 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -346,6 +346,9 @@ def fir_RecordType : FIR_Type<"Record", "type"> {
void finalize(llvm::ArrayRef<TypePair> lenPList,
llvm::ArrayRef<TypePair> typeList);
+ bool isPacked() const;
+ void pack(bool);
+
detail::RecordTypeStorage const *uniqueKey() const;
}];
}
diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp
index 452ddda426fa10..40dbaa28ce7cac 100644
--- a/flang/lib/Lower/ConvertType.cpp
+++ b/flang/lib/Lower/ConvertType.cpp
@@ -20,6 +20,8 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
#define DEBUG_TYPE "flang-lower-type"
@@ -385,9 +387,19 @@ struct TypeBuilderImpl {
// with dozens of components/parents (modern Fortran).
derivedTypeInConstruction.try_emplace(&derivedScope, rec);
+ auto targetTriple{llvm::Triple(
+ llvm::Triple::normalize(llvm::sys::getDefaultTargetTriple()))};
+ // Always generate packed FIR struct type for bind(c) derived type for AIX
+ if (targetTriple.getOS() == llvm::Triple::OSType::AIX &&
+ tySpec.typeSymbol().attrs().test(Fortran::semantics::Attr::BIND_C) &&
+ !IsIsoCType(&tySpec)) {
+ rec.pack(true);
+ }
+
// Gather the record type fields.
// (1) The data components.
if (converter.getLoweringOptions().getLowerToHighLevelFIR()) {
+ size_t prev_offset{0};
// In HLFIR the parent component is the first fir.type component.
for (const auto &componentName :
typeSymbol.get<Fortran::semantics::DerivedTypeDetails>()
@@ -397,7 +409,37 @@ struct TypeBuilderImpl {
"failed to find derived type component symbol");
const Fortran::semantics::Symbol &component = scopeIter->second.get();
mlir::Type ty = genSymbolType(component);
+ if (rec.isPacked()) {
+ auto compSize{component.size()};
+ auto compOffset{component.offset()};
+
+ if (prev_offset < compOffset) {
+ size_t pad{compOffset - prev_offset};
+ mlir::Type i8Ty{mlir::IntegerType::get(context, 8)};
+ fir::SequenceType::Shape shape{static_cast<int64_t>(pad)};
+ mlir::Type padTy{fir::SequenceType::get(shape, i8Ty)};
+ prev_offset += pad;
+ cs.emplace_back("", padTy);
+ }
+ prev_offset += compSize;
+ }
cs.emplace_back(converter.getRecordTypeFieldName(component), ty);
+ if (rec.isPacked()) {
+ // For the last component, determine if any padding is needed.
+ if (componentName ==
+ typeSymbol.get<Fortran::semantics::DerivedTypeDetails>()
+ .componentNames()
+ .back()) {
+ auto compEnd{component.offset() + component.size()};
+ if (compEnd < derivedScope.size()) {
+ size_t pad{derivedScope.size() - compEnd};
+ mlir::Type i8Ty{mlir::IntegerType::get(context, 8)};
+ fir::SequenceType::Shape shape{static_cast<int64_t>(pad)};
+ mlir::Type padTy{fir::SequenceType::get(shape, i8Ty)};
+ cs.emplace_back("", padTy);
+ }
+ }
+ }
}
} else {
for (const auto &component :
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index c23203efcd3df2..0eace903720f03 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -82,7 +82,7 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
[&](fir::PointerType pointer) { return convertPointerLike(pointer); });
addConversion(
[&](fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
- return convertRecordType(derived, results);
+ return convertRecordType(derived, results, derived.isPacked());
});
addConversion(
[&](fir::ReferenceType ref) { return convertPointerLike(ref); });
@@ -133,8 +133,10 @@ mlir::Type LLVMTypeConverter::indexType() const {
}
// fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
-std::optional<llvm::LogicalResult> LLVMTypeConverter::convertRecordType(
- fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
+std::optional<llvm::LogicalResult>
+LLVMTypeConverter::convertRecordType(fir::RecordType derived,
+ llvm::SmallVectorImpl<mlir::Type> &results,
+ bool isPacked) {
auto name = fir::NameUniquer::dropTypeConversionMarkers(derived.getName());
auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
@@ -156,7 +158,7 @@ std::optional<llvm::LogicalResult> LLVMTypeConverter::convertRecordType(
else
members.push_back(mlir::cast<mlir::Type>(convertType(mem.second)));
}
- if (mlir::failed(st.setBody(members, /*isPacked=*/false)))
+ if (mlir::failed(st.setBody(members, isPacked)))
return mlir::failure();
results.push_back(st);
return mlir::success();
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index cba7fa64128502..ea06eb092ed918 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -165,16 +165,20 @@ struct RecordTypeStorage : public mlir::TypeStorage {
setTypeList(typeList);
}
+  /// Reports whether this record has been marked packed via pack().
+  bool isPacked() const { return this->packed; }
+  /// Marks (or unmarks) the record as packed, i.e. to be laid out without
+  /// implicit padding when it is converted to an LLVM struct.
+  void pack(bool packedFlag) { this->packed = packedFlag; }
+
protected:
std::string name;
bool finalized;
+ bool packed;
std::vector<RecordType::TypePair> lens;
std::vector<RecordType::TypePair> types;
private:
RecordTypeStorage() = delete;
explicit RecordTypeStorage(llvm::StringRef name)
- : name{name}, finalized{false} {}
+ : name{name}, finalized{false}, packed{false} {}
};
} // namespace detail
@@ -973,6 +977,10 @@ RecordType::TypeList fir::RecordType::getLenParamList() const {
bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); }
+/// Returns the packed flag held in this record type's storage.
+bool fir::RecordType::isPacked() const {
+  return getImpl()->isPacked();
+}
+
+/// Sets the packed flag in this record type's storage object (getImpl()),
+/// which is shared by every handle to this uniqued record type.
+void fir::RecordType::pack(bool packedFlag) {
+  getImpl()->pack(packedFlag);
+}
+
detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
return getImpl();
}
diff --git a/flang/lib/Semantics/compute-offsets.cpp b/flang/lib/Semantics/compute-offsets.cpp
index 94640fa30baa54..7d516b3e8df54a 100644
--- a/flang/lib/Semantics/compute-offsets.cpp
+++ b/flang/lib/Semantics/compute-offsets.cpp
@@ -17,6 +17,8 @@
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "flang/Semantics/type.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
#include <algorithm>
#include <vector>
@@ -51,9 +53,12 @@ class ComputeOffsetsHelper {
SymbolAndOffset Resolve(const SymbolAndOffset &);
std::size_t ComputeOffset(const EquivalenceObject &);
// Returns amount of padding that was needed for alignment
- std::size_t DoSymbol(Symbol &);
+ std::size_t DoSymbol(Symbol &,
+ std::optional<const size_t> newAlign = std::nullopt);
SizeAndAlignment GetSizeAndAlignment(const Symbol &, bool entire);
std::size_t Align(std::size_t, std::size_t);
+ std::optional<size_t> CompAlignment(const Symbol &);
+ std::optional<size_t> HasSpecialAlign(const Symbol &, Scope &);
SemanticsContext &context_;
std::size_t offset_{0};
@@ -65,6 +70,60 @@ class ComputeOffsetsHelper {
equivalenceBlock_;
};
+// Returns true when `type` is a REAL or COMPLEX type whose kind is greater
+// than 4 (e.g. REAL(8), COMPLEX(8)); these are the floating-point components
+// that get a reduced 4-byte alignment in BIND(C) derived types on AIX.
+// Returns false for a null `type` (a symbol without a declared type) and when
+// the kind cannot be evaluated to an integer constant.
+static bool isReal8OrLarger(const Fortran::semantics::DeclTypeSpec *type) {
+  if (!type ||
+      !(type->IsNumeric(common::TypeCategory::Real) ||
+          type->IsNumeric(common::TypeCategory::Complex))) {
+    return false;
+  }
+  const auto kind{evaluate::ToInt64(type->numericTypeSpec().kind())};
+  return kind && *kind > 4;
+}
+
+// Computes the AIX alignment of the derived-type component `sym` when it is
+// laid out after the first component of a BIND(C) type:
+//  - REAL/COMPLEX subcomponents of kind > 4 count as 4-byte aligned;
+//  - a nested derived-type subcomponent contributes its own adjusted
+//    alignment (computed recursively) when it contains such a value, and its
+//    natural alignment otherwise;
+//  - every other subcomponent keeps its natural alignment.
+// Returns std::nullopt when no REAL/COMPLEX subcomponent of kind > 4 exists
+// (or `sym` is not of derived type), i.e. the natural alignment applies.
+std::optional<size_t> ComputeOffsetsHelper::CompAlignment(const Symbol &sym) {
+  const auto *symType{sym.GetType()};
+  const auto *derivedTypeSpec{symType ? symType->AsDerived() : nullptr};
+  if (!derivedTypeSpec) {
+    return std::nullopt;
+  }
+  std::size_t maxAlign{0};
+  bool containsDouble{false};
+  DirectComponentIterator directs{*derivedTypeSpec};
+  for (const Symbol &component : directs) {
+    const auto *compType{component.GetType()};
+    if (compType && isReal8OrLarger(compType)) {
+      // Explicit std::size_t so both std::max arguments have the same type on
+      // every host, whatever size_t is a typedef for.
+      maxAlign = std::max<std::size_t>(maxAlign, 4);
+      containsDouble = true;
+    } else if (compType && compType->AsDerived()) {
+      // Use the nested type's adjusted alignment rather than its natural one,
+      // and keep scanning when it has no REAL/COMPLEX of kind > 4.
+      if (const auto nested{CompAlignment(component)}) {
+        maxAlign = std::max(maxAlign, *nested);
+        containsDouble = true;
+      } else {
+        maxAlign = std::max(
+            maxAlign, GetSizeAndAlignment(component, true).alignment);
+      }
+    } else {
+      maxAlign = std::max(
+          maxAlign, GetSizeAndAlignment(component, true).alignment);
+    }
+  }
+  if (!containsDouble) {
+    return std::nullopt;
+  }
+  return maxAlign;
+}
+
+// AIX alignment override for a component of a BIND(C) derived type. A
+// component that is not the type's first component gets:
+//  - 4-byte alignment if it is REAL/COMPLEX of kind > 4;
+//  - the alignment computed by CompAlignment if it is of derived type.
+// Returns std::nullopt when the natural alignment of `sym` should be used.
+// `scope` is the scope whose symbols are being laid out.
+std::optional<size_t> ComputeOffsetsHelper::HasSpecialAlign(const Symbol &sym,
+    Scope &scope) {
+  const auto *type{sym.GetType()};
+  if (!type) {
+    return std::nullopt;
+  }
+  const auto &owner{sym.owner()};
+  const auto *ownerSymbol{owner.symbol()};
+  if (!ownerSymbol || !owner.IsDerivedType() ||
+      !ownerSymbol->attrs().HasAny({semantics::Attr::BIND_C})) {
+    return std::nullopt;
+  }
+  const bool isFloat{isReal8OrLarger(type)};
+  if (!isFloat && !type->AsDerived()) {
+    return std::nullopt;
+  }
+  // The first component keeps its natural alignment. This check runs last so
+  // the scope's symbol list is only fetched for candidate components.
+  const auto &symbols{scope.GetSymbols()};
+  if (symbols.empty() || &sym == &*symbols.front()) {
+    return std::nullopt;
+  }
+  if (isFloat) {
+    return std::size_t{4};
+  }
+  return CompAlignment(sym);
+}
+
void ComputeOffsetsHelper::Compute(Scope &scope) {
for (Scope &child : scope.children()) {
ComputeOffsets(context_, child);
@@ -113,7 +172,15 @@ void ComputeOffsetsHelper::Compute(Scope &scope) {
if (!FindCommonBlockContaining(*symbol) &&
dependents_.find(symbol) == dependents_.end() &&
equivalenceBlock_.find(symbol) == equivalenceBlock_.end()) {
- DoSymbol(*symbol);
+
+ std::optional<size_t> newAlign{std::nullopt};
+ // Handle special alignment requirement for AIX
+ auto triple{llvm::Triple(llvm::Triple::normalize(
+ llvm::sys::getDefaultTargetTriple()))};
+ if (triple.getOS() == llvm::Triple::OSType::AIX) {
+ newAlign = HasSpecialAlign(*symbol, scope);
+ }
+ DoSymbol(*symbol, newAlign);
if (auto *generic{symbol->detailsIf<GenericDetails>()}) {
if (Symbol * specific{generic->specific()};
specific && !FindCommonBlockContaining(*specific)) {
@@ -313,7 +380,8 @@ std::size_t ComputeOffsetsHelper::ComputeOffset(
return result;
}
-std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol) {
+std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol,
+ std::optional<const size_t> newAlign) {
if (!symbol.has<ObjectEntityDetails>() && !symbol.has<ProcEntityDetails>()) {
return 0;
}
@@ -322,12 +390,16 @@ std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol) {
return 0;
}
std::size_t previousOffset{offset_};
- offset_ = Align(offset_, s.alignment);
+ size_t alignVal{s.alignment};
+ if (newAlign) {
+ alignVal = newAlign.value();
+ }
+ offset_ = Align(offset_, alignVal);
std::size_t padding{offset_ - previousOffset};
symbol.set_size(s.size);
symbol.set_offset(offset_);
offset_ += s.size;
- alignment_ = std::max(alignment_, s.alignment);
+ alignment_ = std::max(alignment_, alignVal);
return padding;
}
diff --git a/flang/test/Lower/derived-types-bindc.f90 b/flang/test/Lower/derived-types-bindc.f90
new file mode 100644
index 00000000000000..309b2b7f5f4929
--- /dev/null
+++ b/flang/test/Lower/derived-types-bindc.f90
@@ -0,0 +1,44 @@
+! Test padding for BIND(C) derived types lowering for AIX target
+! RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s
+
+! REQUIRES: target={{.+}}-aix{{.*}}
+
+subroutine s1()
+ ! Each BIND(C) type places a REAL(8)/COMPLEX(8) component after a smaller
+ ! one; on AIX that component sits at byte offset 4, so the expected packed
+ ! LLVM struct types spell out the [N x i8] padding (including tail padding).
+ use, intrinsic :: iso_c_binding
+ type, bind(c) :: t0
+ character(c_char) :: x1
+ real(c_double) :: x2
+ end type
+ type(t0) :: xt0
+! CHECK-DAG: %_QFs1Tt0 = type <{ [1 x i8], [3 x i8], double }>
+
+ type, bind(c) :: t1
+ integer(c_short) :: x1
+ real(c_double) :: x2
+ end type
+ type(t1) :: xt1
+! CHECK-DAG: %_QFs1Tt1 = type <{ i16, [2 x i8], double }>
+
+ type, bind(c) :: t2
+ integer(c_short) :: x1
+ real(c_double) :: x2
+ character(c_char) :: x3
+ end type
+ type(t2) :: xt2
+! CHECK-DAG: %_QFs1Tt2 = type <{ i16, [2 x i8], double, [1 x i8], [3 x i8] }>
+
+ type, bind(c) :: t3
+ character(c_char) :: x1
+ complex(c_double_complex) :: x2
+ end type
+ type(t3) :: xt3
+! CHECK-DAG: %_QFs1Tt3 = type <{ [1 x i8], [3 x i8], { double, double } }>
+
+ type, bind(c) :: t4
+ integer(c_short) :: x1
+ complex(c_double_complex) :: x2
+ character(c_char) :: x3
+ end type
+ type(t4) :: xt4
+! CHECK-DAG: %_QFs1Tt4 = type <{ i16, [2 x i8], { double, double }, [1 x i8], [3 x i8] }>
+end subroutine s1
diff --git a/flang/test/Semantics/offsets04.f90 b/flang/test/Semantics/offsets04.f90
new file mode 100644
index 00000000000000..d0d871a981c175
--- /dev/null
+++ b/flang/test/Semantics/offsets04.f90
@@ -0,0 +1,105 @@
+!RUN: %flang_fc1 -fdebug-dump-symbols %s | FileCheck %s
+
+!REQUIRES: target={{.+}}-aix{{.*}}
+
+! Size and alignment of bind(c) derived types
+subroutine s1()
+ ! A REAL(c_double)/COMPLEX(c_double_complex) component that is not the first
+ ! component is placed at the next 4-byte boundary (e.g. offset 4 after a
+ ! single character); a leading one (dt5%x1) starts at offset 0.
+ use, intrinsic :: iso_c_binding
+ type, bind(c) :: dt1
+ character(c_char) :: x1 !CHECK: x1 size=1 offset=0:
+ real(c_double) :: x2 !CHECK: x2 size=8 offset=4:
+ end type
+ type, bind(c) :: dt2
+ character(c_char) :: x1(9) !CHECK: x1 size=9 offset=0:
+ real(c_double) :: x2 !CHECK: x2 size=8 offset=12:
+ end type
+ type, bind(c) :: dt3
+ integer(c_short) :: x1 !CHECK: x1 size=2 offset=0:
+ real(c_double) :: x2 !CHECK: x2 size=8 offset=4:
+ end type
+ type, bind(c) :: dt4
+ integer(c_int) :: x1 !CHECK: x1 size=4 offset=0:
+ real(c_double) :: x2 !CHECK: x2 size=8 offset=4:
+ end type
+ type, bind(c) :: dt5
+ real(c_double) :: x1 !CHECK: x1 size=8 offset=0:
+ real(c_double) :: x2 !CHECK: x2 size=8 offset=8:
+ end type
+ type, bind(c) :: dt6
+ integer(c_long) :: x1 !CHECK: x1 size=8 offset=0:
+ character(c_char) :: x2 !CHECK: x2 size=1 offset=8:
+ real(c_double) :: x3 !CHECK: x3 size=8 offset=12:
+ end type
+ type, bind(c) :: dt7
+ integer(c_long) :: x1 !CHECK: x1 size=8 offset=0:
+ integer(c_long) :: x2 !CHECK: x2 size=8 offset=8:
+ character(c_char) :: x3 !CHECK: x3 size=1 offset=16:
+ real(c_double) :: x4 !CHECK: x4 size=8 offset=20:
+ end type
+ type, bind(c) :: dt8
+ character(c_char) :: x1 !CHECK: x1 size=1 offset=0:
+ complex(c_double_complex) :: x2 !CHECK: x2 size=16 offset=4:
+ end type
+end subroutine
+
+subroutine s2()
+ ! Nested BIND(C) types: REAL(c_double)/COMPLEX(c_double_complex) values,
+ ! including those inside a derived-type component, are placed at 4-byte
+ ! aligned offsets unless they come first in the enclosing type.
+ use, intrinsic :: iso_c_binding
+ type, bind(c) :: dt10
+ character(c_char) :: x1
+ real(c_double) :: x2
+ end type
+ type, bind(c) :: dt11
+ type(dt10) :: y1 !CHECK: y1 size=12 offset=0:
+ real(c_double) :: y2 !CHECK: y2 size=8 offset=12:
+ end type
+ type, bind(c) :: dt12
+ character(c_char) :: y1 !CHECK: y1 size=1 offset=0:
+ type(dt10) :: y2 !CHECK: y2 size=12 offset=4:
+ character(c_char) :: y3 !CHECK: y3 size=1 offset=16:
+ end type
+ type, bind(c) :: dt13
+ integer(c_short) :: y1 !CHECK: y1 size=2 offset=0:
+ type(dt10) :: y2 !CHECK: y2 size=12 offset=4:
+ character(c_char) :: y3 !CHECK: y3 size=1 offset=16:
+ end type
+
+ type, bind(c) :: dt20
+ character(c_char) :: x1
+ integer(c_short) :: x2
+ end type
+ type, bind(c) :: dt21
+ real(c_double) :: y1 !CHECK: y1 size=8 offset=0:
+ type(dt20) :: y2 !CHECK: y2 size=4 offset=8:
+ real(c_double) :: y3 !CHECK: y3 size=8 offset=12:
+ end type
+
+ type, bind(c) :: dt30
+ character(c_char) :: x1
+ character(c_char) :: x2
+ end type
+ type, bind(c) :: dt31
+ integer(c_long) :: y1 !CHECK: y1 size=8 offset=0:
+ type(dt30) :: y2 !CHECK: y2 size=2 offset=8:
+ real(c_double) :: y3 !CHECK: y3 size=8 offset=12:
+ end type
+
+ type, bind(c) :: dt40
+ integer(c_short) :: x1
+ real(c_double) :: x2
+ end type
+ type, bind(c) :: dt41
+ real(c_double) :: y1 !CHECK: y1 size=8 offset=0:
+ type(dt40) :: y2 !CHECK: y2 size=12 offset=8:
+ real(c_double) :: y3 !CHECK: y3 size=8 offset=20:
+ end type
+
+ type, bind(c) :: dt50
+ integer(c_short) :: x1
+ complex(c_double_complex) :: x2
+ end type
+ type, bind(c) :: dt51
+ real(c_double) :: y1 !CHECK: y1 size=8 offset=0:
+ type(dt50) :: y2 !CHECK: y2 size=20 offset=8:
+ complex(c_double_complex) :: y3 !CHECK: y3 size=16 offset=28:
+ end type
+end subroutine
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 28150e886913e3..99826cdebb343d 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -391,6 +391,21 @@ class TypeConverter {
return callback(derivedType, results);
};
}
+  // There is intentionally no wrapCallback overload for callbacks of the form
+  // `std::optional<LogicalResult>(T, SmallVectorImpl<Type> &, bool)`: every
+  // wrapped callback must be convertible to ConversionCallbackFn, whose
+  // signature is `(Type, SmallVectorImpl<Type> &)`. A conversion that needs
+  // extra state (such as whether a struct is packed) must capture it or
+  // recompute it from `T` inside a two-argument callback.
/// Register a type conversion.
void registerConversion(ConversionCallbackFn callback) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/121505
More information about the flang-commits
mailing list