[flang-commits] [flang] [mlir] [flang][AIX] BIND(C) derived type alignment for AIX (PR #121505)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 2 09:19:58 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kelvin Li (kkwli)

<details>
<summary>Changes</summary>

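This patch handles the AIX alignment requirement for components of a `bind(c)` derived type that are of real (or complex) type with a kind larger than 4 bytes. On AIX, such a component is aligned on a 4-byte boundary rather than its natural alignment, unless it is the first component of the type. To preserve this layout, lowering generates packed struct types with explicit padding for `bind(c)` derived types on AIX.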

---
Full diff: https://github.com/llvm/llvm-project/pull/121505.diff


9 Files Affected:

- (modified) flang/include/flang/Optimizer/CodeGen/TypeConverter.h (+1-1) 
- (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+3) 
- (modified) flang/lib/Lower/ConvertType.cpp (+42) 
- (modified) flang/lib/Optimizer/CodeGen/TypeConverter.cpp (+6-4) 
- (modified) flang/lib/Optimizer/Dialect/FIRType.cpp (+9-1) 
- (modified) flang/lib/Semantics/compute-offsets.cpp (+77-5) 
- (added) flang/test/Lower/derived-types-bindc.f90 (+44) 
- (added) flang/test/Semantics/offsets04.f90 (+105) 
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+15) 


``````````diff
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 7c317ddeea1fa4..20270d41b1e9a1 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -62,7 +62,7 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
   // fir.type<name(p : TY'...){f : TY...}>  -->  llvm<"%name = { ty... }">
   std::optional<llvm::LogicalResult>
   convertRecordType(fir::RecordType derived,
-                    llvm::SmallVectorImpl<mlir::Type> &results);
+                    llvm::SmallVectorImpl<mlir::Type> &results, bool isPacked);
 
   // Is an extended descriptor needed given the element type of a fir.box type ?
   // Extended descriptors are required for derived types.
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 3919c9191c2122..4d832a49236bf1 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -346,6 +346,9 @@ def fir_RecordType : FIR_Type<"Record", "type"> {
     void finalize(llvm::ArrayRef<TypePair> lenPList,
                   llvm::ArrayRef<TypePair> typeList);
 
+    bool isPacked() const;
+    void pack(bool);
+
     detail::RecordTypeStorage const *uniqueKey() const;
   }];
 }
diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp
index 452ddda426fa10..40dbaa28ce7cac 100644
--- a/flang/lib/Lower/ConvertType.cpp
+++ b/flang/lib/Lower/ConvertType.cpp
@@ -20,6 +20,8 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
 
 #define DEBUG_TYPE "flang-lower-type"
 
@@ -385,9 +387,19 @@ struct TypeBuilderImpl {
     // with dozens of components/parents (modern Fortran).
     derivedTypeInConstruction.try_emplace(&derivedScope, rec);
 
+    auto targetTriple{llvm::Triple(
+        llvm::Triple::normalize(llvm::sys::getDefaultTargetTriple()))};
+    // Always generate packed FIR struct type for bind(c) derived type for AIX
+    if (targetTriple.getOS() == llvm::Triple::OSType::AIX &&
+        tySpec.typeSymbol().attrs().test(Fortran::semantics::Attr::BIND_C) &&
+        !IsIsoCType(&tySpec)) {
+      rec.pack(true);
+    }
+
     // Gather the record type fields.
     // (1) The data components.
     if (converter.getLoweringOptions().getLowerToHighLevelFIR()) {
+      size_t prev_offset{0};
       // In HLFIR the parent component is the first fir.type component.
       for (const auto &componentName :
            typeSymbol.get<Fortran::semantics::DerivedTypeDetails>()
@@ -397,7 +409,37 @@ struct TypeBuilderImpl {
                "failed to find derived type component symbol");
         const Fortran::semantics::Symbol &component = scopeIter->second.get();
         mlir::Type ty = genSymbolType(component);
+        if (rec.isPacked()) {
+          auto compSize{component.size()};
+          auto compOffset{component.offset()};
+
+          if (prev_offset < compOffset) {
+            size_t pad{compOffset - prev_offset};
+            mlir::Type i8Ty{mlir::IntegerType::get(context, 8)};
+            fir::SequenceType::Shape shape{static_cast<int64_t>(pad)};
+            mlir::Type padTy{fir::SequenceType::get(shape, i8Ty)};
+            prev_offset += pad;
+            cs.emplace_back("", padTy);
+          }
+          prev_offset += compSize;
+        }
         cs.emplace_back(converter.getRecordTypeFieldName(component), ty);
+        if (rec.isPacked()) {
+          // For the last component, determine if any padding is needed.
+          if (componentName ==
+              typeSymbol.get<Fortran::semantics::DerivedTypeDetails>()
+                  .componentNames()
+                  .back()) {
+            auto compEnd{component.offset() + component.size()};
+            if (compEnd < derivedScope.size()) {
+              size_t pad{derivedScope.size() - compEnd};
+              mlir::Type i8Ty{mlir::IntegerType::get(context, 8)};
+              fir::SequenceType::Shape shape{static_cast<int64_t>(pad)};
+              mlir::Type padTy{fir::SequenceType::get(shape, i8Ty)};
+              cs.emplace_back("", padTy);
+            }
+          }
+        }
       }
     } else {
       for (const auto &component :
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index c23203efcd3df2..0eace903720f03 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -82,7 +82,7 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
       [&](fir::PointerType pointer) { return convertPointerLike(pointer); });
   addConversion(
       [&](fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
-        return convertRecordType(derived, results);
+        return convertRecordType(derived, results, derived.isPacked());
       });
   addConversion(
       [&](fir::ReferenceType ref) { return convertPointerLike(ref); });
@@ -133,8 +133,10 @@ mlir::Type LLVMTypeConverter::indexType() const {
 }
 
 // fir.type<name(p : TY'...){f : TY...}>  -->  llvm<"%name = { ty... }">
-std::optional<llvm::LogicalResult> LLVMTypeConverter::convertRecordType(
-    fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
+std::optional<llvm::LogicalResult>
+LLVMTypeConverter::convertRecordType(fir::RecordType derived,
+                                     llvm::SmallVectorImpl<mlir::Type> &results,
+                                     bool isPacked) {
   auto name = fir::NameUniquer::dropTypeConversionMarkers(derived.getName());
   auto st = mlir::LLVM::LLVMStructType::getIdentified(&getContext(), name);
 
@@ -156,7 +158,7 @@ std::optional<llvm::LogicalResult> LLVMTypeConverter::convertRecordType(
     else
       members.push_back(mlir::cast<mlir::Type>(convertType(mem.second)));
   }
-  if (mlir::failed(st.setBody(members, /*isPacked=*/false)))
+  if (mlir::failed(st.setBody(members, isPacked)))
     return mlir::failure();
   results.push_back(st);
   return mlir::success();
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index cba7fa64128502..ea06eb092ed918 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -165,16 +165,20 @@ struct RecordTypeStorage : public mlir::TypeStorage {
     setTypeList(typeList);
   }
 
+  bool isPacked() const { return packed; }
+  void pack(bool p) { packed = p; }
+
 protected:
   std::string name;
   bool finalized;
+  bool packed;
   std::vector<RecordType::TypePair> lens;
   std::vector<RecordType::TypePair> types;
 
 private:
   RecordTypeStorage() = delete;
   explicit RecordTypeStorage(llvm::StringRef name)
-      : name{name}, finalized{false} {}
+      : name{name}, finalized{false}, packed{false} {}
 };
 
 } // namespace detail
@@ -973,6 +977,10 @@ RecordType::TypeList fir::RecordType::getLenParamList() const {
 
 bool fir::RecordType::isFinalized() const { return getImpl()->isFinalized(); }
 
+void fir::RecordType::pack(bool p) { getImpl()->pack(p); }
+
+bool fir::RecordType::isPacked() const { return getImpl()->isPacked(); }
+
 detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
   return getImpl();
 }
diff --git a/flang/lib/Semantics/compute-offsets.cpp b/flang/lib/Semantics/compute-offsets.cpp
index 94640fa30baa54..7d516b3e8df54a 100644
--- a/flang/lib/Semantics/compute-offsets.cpp
+++ b/flang/lib/Semantics/compute-offsets.cpp
@@ -17,6 +17,8 @@
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "flang/Semantics/type.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
 #include <algorithm>
 #include <vector>
 
@@ -51,9 +53,12 @@ class ComputeOffsetsHelper {
   SymbolAndOffset Resolve(const SymbolAndOffset &);
   std::size_t ComputeOffset(const EquivalenceObject &);
   // Returns amount of padding that was needed for alignment
-  std::size_t DoSymbol(Symbol &);
+  std::size_t DoSymbol(Symbol &,
+                       std::optional<const size_t> newAlign = std::nullopt);
   SizeAndAlignment GetSizeAndAlignment(const Symbol &, bool entire);
   std::size_t Align(std::size_t, std::size_t);
+  std::optional<size_t> CompAlignment(const Symbol &);
+  std::optional<size_t> HasSpecialAlign(const Symbol &, Scope &);
 
   SemanticsContext &context_;
   std::size_t offset_{0};
@@ -65,6 +70,60 @@ class ComputeOffsetsHelper {
       equivalenceBlock_;
 };
 
+static bool isReal8OrLarger(const Fortran::semantics::DeclTypeSpec *type) {
+  return ((type->IsNumeric(common::TypeCategory::Real) ||
+           type->IsNumeric(common::TypeCategory::Complex)) &&
+          evaluate::ToInt64(type->numericTypeSpec().kind()) > 4);
+}
+
+std::optional<size_t> ComputeOffsetsHelper::CompAlignment(const Symbol &sym) {
+  size_t max_align{0};
+  bool contain_double{false};
+  const auto derivedTypeSpec{sym.GetType()->AsDerived()};
+  DirectComponentIterator directs{*derivedTypeSpec};
+  for (auto it = directs.begin(); it != directs.end(); ++it) {
+    auto type{it->GetType()};
+    auto s{GetSizeAndAlignment(*it, true)};
+    if (isReal8OrLarger(type)) {
+      max_align = std::max(max_align, 4UL);
+      contain_double = true;
+    } else if (type->AsDerived()) {
+      if (const auto newAlgin = CompAlignment(*it)) {
+        max_align = std::max(max_align, s.alignment);
+      } else {
+        return std::nullopt;
+      }
+    } else {
+      max_align = std::max(max_align, s.alignment);
+    }
+  }
+
+  if (contain_double)
+    return max_align;
+  else
+    return std::nullopt;
+}
+
+std::optional<size_t> ComputeOffsetsHelper::HasSpecialAlign(const Symbol &sym,
+                                                            Scope &scope) {
+  // On AIX, if the component that is not the first component and is
+  // a float of 8 bytes or larger, it has the 4-byte alignment.
+  // Only set the special alignment for bind(c) derived type on that platform.
+  if (const auto type = sym.GetType()) {
+    auto &symOwner{sym.owner()};
+    if (symOwner.symbol() && symOwner.IsDerivedType() &&
+        symOwner.symbol()->attrs().HasAny({semantics::Attr::BIND_C}) &&
+        &sym != &(*scope.GetSymbols().front())) {
+      if (isReal8OrLarger(type)) {
+        return 4UL;
+      } else if (type->AsDerived()) {
+        return CompAlignment(sym);
+      }
+    }
+  }
+  return std::nullopt;
+}
+
 void ComputeOffsetsHelper::Compute(Scope &scope) {
   for (Scope &child : scope.children()) {
     ComputeOffsets(context_, child);
@@ -113,7 +172,15 @@ void ComputeOffsetsHelper::Compute(Scope &scope) {
     if (!FindCommonBlockContaining(*symbol) &&
         dependents_.find(symbol) == dependents_.end() &&
         equivalenceBlock_.find(symbol) == equivalenceBlock_.end()) {
-      DoSymbol(*symbol);
+
+      std::optional<size_t> newAlign{std::nullopt};
+      // Handle special alignment requirement for AIX
+      auto triple{llvm::Triple(llvm::Triple::normalize(
+         llvm::sys::getDefaultTargetTriple()))};
+      if (triple.getOS() == llvm::Triple::OSType::AIX) {
+        newAlign = HasSpecialAlign(*symbol, scope);
+      }
+      DoSymbol(*symbol, newAlign);
       if (auto *generic{symbol->detailsIf<GenericDetails>()}) {
         if (Symbol * specific{generic->specific()};
             specific && !FindCommonBlockContaining(*specific)) {
@@ -313,7 +380,8 @@ std::size_t ComputeOffsetsHelper::ComputeOffset(
   return result;
 }
 
-std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol) {
+std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol,
+                                           std::optional<const size_t> newAlign) {
   if (!symbol.has<ObjectEntityDetails>() && !symbol.has<ProcEntityDetails>()) {
     return 0;
   }
@@ -322,12 +390,16 @@ std::size_t ComputeOffsetsHelper::DoSymbol(Symbol &symbol) {
     return 0;
   }
   std::size_t previousOffset{offset_};
-  offset_ = Align(offset_, s.alignment);
+  size_t alignVal{s.alignment};
+  if (newAlign) {
+    alignVal = newAlign.value();
+  }
+  offset_ = Align(offset_, alignVal);
   std::size_t padding{offset_ - previousOffset};
   symbol.set_size(s.size);
   symbol.set_offset(offset_);
   offset_ += s.size;
-  alignment_ = std::max(alignment_, s.alignment);
+  alignment_ = std::max(alignment_, alignVal);
   return padding;
 }
 
diff --git a/flang/test/Lower/derived-types-bindc.f90 b/flang/test/Lower/derived-types-bindc.f90
new file mode 100644
index 00000000000000..309b2b7f5f4929
--- /dev/null
+++ b/flang/test/Lower/derived-types-bindc.f90
@@ -0,0 +1,44 @@
+! Test padding for BIND(C) derived types lowering for AIX target
+! RUN: %flang_fc1 -emit-llvm %s -o - | FileCheck %s
+
+! REQUIRES: target={{.+}}-aix{{.*}}
+
+subroutine s1()
+  use, intrinsic :: iso_c_binding
+  type, bind(c) :: t0
+    character(c_char) :: x1
+    real(c_double) :: x2
+  end type
+  type(t0) :: xt0
+! CHECK-DAG: %_QFs1Tt0 = type <{ [1 x i8], [3 x i8], double }>
+
+  type, bind(c) :: t1
+    integer(c_short) :: x1
+    real(c_double) :: x2
+  end type
+  type(t1) :: xt1
+! CHECK-DAG: %_QFs1Tt1 = type <{ i16, [2 x i8], double }>
+
+  type, bind(c) :: t2
+    integer(c_short) :: x1
+    real(c_double) :: x2
+    character(c_char) :: x3
+  end type
+  type(t2) :: xt2
+! CHECK-DAG: %_QFs1Tt2 = type <{ i16, [2 x i8], double, [1 x i8], [3 x i8] }>
+
+  type, bind(c) :: t3
+    character(c_char) :: x1
+    complex(c_double_complex) :: x2
+  end type
+  type(t3) :: xt3
+! CHECK-DAG: %_QFs1Tt3 = type <{ [1 x i8], [3 x i8], { double, double } }>
+
+  type, bind(c) :: t4
+    integer(c_short) :: x1
+    complex(c_double_complex) :: x2
+    character(c_char) :: x3
+  end type
+  type(t4) :: xt4
+! CHECK-DAG: %_QFs1Tt4 = type <{ i16, [2 x i8], { double, double }, [1 x i8], [3 x i8] }>
+end subroutine s1
diff --git a/flang/test/Semantics/offsets04.f90 b/flang/test/Semantics/offsets04.f90
new file mode 100644
index 00000000000000..d0d871a981c175
--- /dev/null
+++ b/flang/test/Semantics/offsets04.f90
@@ -0,0 +1,105 @@
+!RUN: %flang_fc1 -fdebug-dump-symbols %s | FileCheck %s
+
+!REQUIRES: target={{.+}}-aix{{.*}}
+
+! Size and alignment of bind(c) derived types
+subroutine s1()
+  use, intrinsic :: iso_c_binding
+  type, bind(c) :: dt1
+    character(c_char) :: x1    !CHECK: x1 size=1 offset=0:
+    real(c_double) :: x2       !CHECK: x2 size=8 offset=4:
+  end type
+  type, bind(c) :: dt2
+    character(c_char) :: x1(9) !CHECK: x1 size=9 offset=0:
+    real(c_double) :: x2       !CHECK: x2 size=8 offset=12:
+  end type
+  type, bind(c) :: dt3
+    integer(c_short) :: x1     !CHECK: x1 size=2 offset=0:
+    real(c_double) :: x2       !CHECK: x2 size=8 offset=4:
+  end type
+  type, bind(c) :: dt4
+    integer(c_int) :: x1       !CHECK: x1 size=4 offset=0:
+    real(c_double) :: x2       !CHECK: x2 size=8 offset=4:
+  end type
+  type, bind(c) :: dt5
+    real(c_double) :: x1       !CHECK: x1 size=8 offset=0:
+    real(c_double) :: x2       !CHECK: x2 size=8 offset=8:
+  end type
+  type, bind(c) :: dt6
+    integer(c_long) :: x1      !CHECK: x1 size=8 offset=0:
+    character(c_char) :: x2    !CHECK: x2 size=1 offset=8:
+    real(c_double) :: x3       !CHECK: x3 size=8 offset=12:
+  end type
+  type, bind(c) :: dt7
+    integer(c_long) :: x1      !CHECK: x1 size=8 offset=0:
+    integer(c_long) :: x2      !CHECK: x2 size=8 offset=8:
+    character(c_char) :: x3    !CHECK: x3 size=1 offset=16:
+    real(c_double) :: x4       !CHECK: x4 size=8 offset=20:
+  end type
+  type, bind(c) :: dt8
+    character(c_char) :: x1         !CHECK: x1 size=1 offset=0:
+    complex(c_double_complex) :: x2 !CHECK: x2 size=16 offset=4:
+  end type
+end subroutine
+
+subroutine s2()
+  use, intrinsic :: iso_c_binding
+  type, bind(c) :: dt10
+    character(c_char) :: x1
+    real(c_double) :: x2
+  end type
+  type, bind(c) :: dt11
+    type(dt10) :: y1           !CHECK: y1 size=12 offset=0:
+    real(c_double) :: y2       !CHECK: y2 size=8 offset=12:
+  end type
+  type, bind(c) :: dt12
+    character(c_char) :: y1    !CHECK: y1 size=1 offset=0:
+    type(dt10) :: y2           !CHECK: y2 size=12 offset=4:
+    character(c_char) :: y3    !CHECK: y3 size=1 offset=16:
+  end type
+  type, bind(c) :: dt13
+    integer(c_short) :: y1     !CHECK: y1 size=2 offset=0:
+    type(dt10) :: y2           !CHECK: y2 size=12 offset=4:
+    character(c_char) :: y3    !CHECK: y3 size=1 offset=16:
+  end type
+
+  type, bind(c) :: dt20
+    character(c_char) :: x1
+    integer(c_short) :: x2
+  end type
+  type, bind(c) :: dt21
+    real(c_double) :: y1       !CHECK: y1 size=8 offset=0:
+    type(dt20) :: y2           !CHECK: y2 size=4 offset=8:
+    real(c_double) :: y3       !CHECK: y3 size=8 offset=12:
+  end type
+
+  type, bind(c) :: dt30
+    character(c_char) :: x1
+    character(c_char) :: x2
+  end type
+  type, bind(c) :: dt31
+     integer(c_long) :: y1     !CHECK: y1 size=8 offset=0:
+     type(dt30) :: y2          !CHECK: y2 size=2 offset=8:
+     real(c_double) :: y3      !CHECK: y3 size=8 offset=12:
+  end type
+
+  type, bind(c) :: dt40
+    integer(c_short) :: x1
+    real(c_double) :: x2
+  end type
+  type, bind(c) :: dt41
+    real(c_double) :: y1       !CHECK: y1 size=8 offset=0:
+    type(dt40) :: y2           !CHECK: y2 size=12 offset=8:
+    real(c_double) :: y3       !CHECK: y3 size=8 offset=20:
+  end type
+
+  type, bind(c) :: dt50
+    integer(c_short) :: x1
+    complex(c_double_complex) :: x2
+  end type
+  type, bind(c) :: dt51
+    real(c_double) :: y1            !CHECK: y1 size=8 offset=0:
+    type(dt50) :: y2                !CHECK: y2 size=20 offset=8:
+    complex(c_double_complex) :: y3 !CHECK: y3 size=16 offset=28:
+  end type
+end subroutine
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 28150e886913e3..99826cdebb343d 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -391,6 +391,21 @@ class TypeConverter {
       return callback(derivedType, results);
     };
   }
+  /// With callback of form: `std::optional<LogicalResult>(
+  ///     T, SmallVectorImpl<Type> &, bool)`.
+  template <typename T, typename FnT>
+  std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &, bool>,
+                   ConversionCallbackFn>
+  wrapCallback(FnT &&callback) const {
+    return [callback = std::forward<FnT>(callback)](
+               Type type, SmallVectorImpl<Type> &results,
+               bool isPacked) -> std::optional<LogicalResult> {
+      T derivedType = dyn_cast<T>(type);
+      if (!derivedType)
+        return std::nullopt;
+      return callback(derivedType, results, isPacked);
+    };
+  }
 
   /// Register a type conversion.
   void registerConversion(ConversionCallbackFn callback) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/121505


More information about the flang-commits mailing list