[flang-commits] [flang] [flang] Use DataLayout for computing type size in LoopVersioning. (PR #79778)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Sun Jan 28 19:41:39 PST 2024


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/79778

The existing type size computation in LoopVersioning does not work
for REAL*10, because the computed element size is 10 bytes,
which violates the power-of-two assertion.
It is better to use the DataLayout to compute the storage size
of each element of an array of the given type.


>From 2d05d03854c4a88ce9612598af09a3af39de98e8 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Sun, 28 Jan 2024 19:33:28 -0800
Subject: [PATCH] [flang] Use DataLayout for computing type size in
 LoopVersioning.

The existing type size computation in LoopVersioning does not work
for REAL*10, because the computed element size is 10 bytes,
which violates the power-of-two assertion.
It is better to use the DataLayout to compute the storage size
of each element of an array of the given type.
---
 .../include/flang/Optimizer/Dialect/FIRType.h | 12 +++
 flang/lib/Optimizer/CodeGen/Target.cpp        | 64 +-------------
 flang/lib/Optimizer/Dialect/FIRType.cpp       | 53 ++++++++++++
 .../Optimizer/Transforms/LoopVersioning.cpp   | 19 ++--
 flang/test/Transforms/loop-versioning.fir     | 86 ++++++++++++++++++-
 5 files changed, 168 insertions(+), 66 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h
index 75106b3028ac903..0fb8e6a442a3232 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRType.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRType.h
@@ -15,6 +15,7 @@
 
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/IR/Type.h"
 
@@ -465,6 +466,17 @@ inline bool isBoxProcAddressType(mlir::Type t) {
 std::string getTypeAsString(mlir::Type ty, const KindMapping &kindMap,
                             llvm::StringRef prefix = "");
 
+/// Return the size and ABI alignment (in bytes) of \p ty per the data layout
+/// \p dl; \p kindMap resolves FIR kinds. The size may not be a multiple of the
+/// alignment (e.g. f80 on x86-64): use llvm::alignTo to get an array stride.
+/// TODO: consider moving this to a DataLayoutTypeInterface implementation for
+/// FIR types, once target dependent size inquiries are accepted in lowering
+/// and the kind map can be expressed via mlir::DataLayoutEntryKey.
+std::pair<std::uint64_t, unsigned short>
+getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
+                        const mlir::DataLayout &dl,
+                        const fir::KindMapping &kindMap);
+
 } // namespace fir
 
 #endif // FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
diff --git a/flang/lib/Optimizer/CodeGen/Target.cpp b/flang/lib/Optimizer/CodeGen/Target.cpp
index a4df0b09177ab75..f2c47ffa8894127 100644
--- a/flang/lib/Optimizer/CodeGen/Target.cpp
+++ b/flang/lib/Optimizer/CodeGen/Target.cpp
@@ -59,63 +59,6 @@ static void typeTodo(const llvm::fltSemantics *sem, mlir::Location loc,
   }
 }
 
-/// Return the size and alignment of FIR types.
-/// TODO: consider moving this to a DataLayoutTypeInterface implementation
-/// for FIR types. It should first be ensured that it is OK to open the gate of
-/// target dependent type size inquiries in lowering. It would also not be
-/// straightforward given the need for a kind map that would need to be
-/// converted in terms of mlir::DataLayoutEntryKey.
-static std::pair<std::uint64_t, unsigned short>
-getSizeAndAlignment(mlir::Location loc, mlir::Type ty,
-                    const mlir::DataLayout &dl,
-                    const fir::KindMapping &kindMap) {
-  if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
-    llvm::TypeSize size = dl.getTypeSize(ty);
-    unsigned short alignment = dl.getTypeABIAlignment(ty);
-    return {size, alignment};
-  }
-  if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
-    auto [floatSize, floatAlign] =
-        getSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
-    return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
-  }
-  if (auto real = mlir::dyn_cast<fir::RealType>(ty))
-    return getSizeAndAlignment(loc, real.getFloatType(kindMap), dl, kindMap);
-
-  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
-    auto [eleSize, eleAlign] =
-        getSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
-
-    std::uint64_t size =
-        llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
-    return {size, eleAlign};
-  }
-  if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
-    std::uint64_t size = 0;
-    unsigned short align = 1;
-    for (auto component : recTy.getTypeList()) {
-      auto [compSize, compAlign] =
-          getSizeAndAlignment(loc, component.second, dl, kindMap);
-      size =
-          llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
-      align = std::max(align, compAlign);
-    }
-    return {size, align};
-  }
-  if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
-    mlir::Type intTy = mlir::IntegerType::get(
-        logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
-    return getSizeAndAlignment(loc, intTy, dl, kindMap);
-  }
-  if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
-    mlir::Type intTy = mlir::IntegerType::get(
-        character.getContext(),
-        kindMap.getCharacterBitsize(character.getFKind()));
-    return getSizeAndAlignment(loc, intTy, dl, kindMap);
-  }
-  TODO(loc, "computing size of a component");
-}
-
 namespace {
 template <typename S>
 struct GenericTarget : public CodeGenSpecifics {
@@ -489,7 +432,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
       }
       mlir::Type compType = component.second;
       auto [compSize, compAlign] =
-          getSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
+          fir::getTypeSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
       byteOffset = llvm::alignTo(byteOffset, compAlign);
       ArgClass LoComp, HiComp;
       classify(loc, compType, byteOffset, LoComp, HiComp);
@@ -510,7 +453,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
     mlir::Type eleTy = seqTy.getEleTy();
     const std::uint64_t arraySize = seqTy.getConstantArraySize();
     auto [eleSize, eleAlign] =
-        getSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
+        fir::getTypeSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
     std::uint64_t eleStorageSize = llvm::alignTo(eleSize, eleAlign);
     for (std::uint64_t i = 0; i < arraySize; ++i) {
       byteOffset = llvm::alignTo(byteOffset, eleAlign);
@@ -697,7 +640,8 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
   CodeGenSpecifics::Marshalling passOnTheStack(mlir::Location loc,
                                                mlir::Type ty) const {
     CodeGenSpecifics::Marshalling marshal;
-    auto sizeAndAlign = getSizeAndAlignment(loc, ty, getDataLayout(), kindMap);
+    auto sizeAndAlign =
+        fir::getTypeSizeAndAlignment(loc, ty, getDataLayout(), kindMap);
     // The stack is always 8 byte aligned (note 14 in 3.2.3).
     unsigned short align =
         std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 0e80110848fa805..9c8812276a0a472 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -12,6 +12,7 @@
 
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/ISO_Fortran_binding_wrapper.h"
+#include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Tools/PointerModels.h"
@@ -1339,3 +1340,55 @@ void FIROpsDialect::registerTypes() {
   fir::LLVMPointerType::attachInterface<
       OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
 }
+
+std::pair<std::uint64_t, unsigned short>
+fir::getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
+                             const mlir::DataLayout &dl,
+                             const fir::KindMapping &kindMap) {
+  if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
+    std::uint64_t size = dl.getTypeSize(ty).getFixedValue();
+    unsigned short alignment = dl.getTypeABIAlignment(ty);
+    return {size, alignment};
+  }
+  if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
+    auto [floatSize, floatAlign] =
+        getTypeSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
+    return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
+  }
+  if (auto real = mlir::dyn_cast<fir::RealType>(ty))
+    return getTypeSizeAndAlignment(loc, real.getFloatType(kindMap), dl,
+                                   kindMap);
+  // Arrays: elements are laid out at their size rounded up to alignment.
+  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
+    auto [eleSize, eleAlign] =
+        getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
+
+    std::uint64_t size =
+        llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
+    return {size, eleAlign};
+  }
+  if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
+    std::uint64_t size = 0;
+    unsigned short align = 1;
+    for (auto component : recTy.getTypeList()) {
+      auto [compSize, compAlign] =
+          getTypeSizeAndAlignment(loc, component.second, dl, kindMap);
+      size =
+          llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
+      align = std::max(align, compAlign);
+    }
+    return {size, align};
+  }
+  if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
+    mlir::Type intTy = mlir::IntegerType::get(
+        logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
+    return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
+  }
+  if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
+    mlir::Type intTy = mlir::IntegerType::get(
+        character.getContext(),
+        kindMap.getCharacterBitsize(character.getFKind()));
+    return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
+  }
+  TODO(loc, "computing size of a component");
+}
diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
index bca70a0a0d322f8..7cbd2dd1f897a5f 100644
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -49,6 +49,7 @@
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/Support/DataLayout.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dominance.h"
@@ -241,6 +242,13 @@ void LoopVersioningPass::runOnOperation() {
   mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
   fir::KindMapping kindMap = fir::getKindMapping(module);
   mlir::SmallVector<ArgInfo, 4> argsOfInterest;
+  std::optional<mlir::DataLayout> dl =
+      fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
+  if (!dl) {
+    mlir::emitError(module.getLoc(), "data layout attribute is required to "
+                                     "perform " DEBUG_TYPE " pass");
+    return signalPassFailure(); // dl is dereferenced below.
+  }
   for (auto &arg : args) {
     // Optional arguments must be checked for IsPresent before
     // looking for the bounds. They are unsupported for the time being.
@@ -256,11 +264,12 @@ void LoopVersioningPass::runOnOperation() {
           seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) {
         size_t typeSize = 0;
         mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType());
-        if (elementType.isa<mlir::FloatType>() ||
-            elementType.isa<mlir::IntegerType>())
-          typeSize = elementType.getIntOrFloatBitWidth() / 8;
-        else if (auto cty = elementType.dyn_cast<fir::ComplexType>())
-          typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8;
+        if (mlir::isa<mlir::FloatType, mlir::IntegerType, fir::ComplexType>(
+                elementType)) {
+          auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignment(
+              arg.getLoc(), elementType, *dl, kindMap);
+          typeSize = llvm::alignTo(eleSize, eleAlign); // Element stride.
+        }
         if (typeSize)
           argsOfInterest.push_back({arg, typeSize, rank, {}});
         else
diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir
index d1ad1510b0e899e..0c627416563477d 100644
--- a/flang/test/Transforms/loop-versioning.fir
+++ b/flang/test/Transforms/loop-versioning.fir
@@ -11,7 +11,7 @@
 //       sum = sum + a(i)
 //    end do
 //  end subroutine sum1d
-module {
+module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
   func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
     %decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
     %rebox = fir.rebox %decl : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
@@ -1556,5 +1556,89 @@ func.func @minloc(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "x"}, %ar
 // CHECK: fir.if %{{.*}} {
 // CHECK:   {{.*}} = arith.cmpi eq, %[[V17]], %c2147483647_i32
 
+// REAL*10 (f80) elements are 10 bytes with 16-byte alignment per the module
+// data layout, so the element size used must be 16 bytes (shift by 4).
+func.func @_QPtest_real10(%arg0: !fir.box<!fir.array<?x?xf80>> {fir.bindc_name = "a"}) -> f80 {
+  %c10 = arith.constant 10 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f80
+  %0 = fir.declare %arg0 {fortran_attrs = #fir.var_attrs<intent_in>, uniq_name = "_QFtest_real10Ea"} : (!fir.box<!fir.array<?x?xf80>>) -> !fir.box<!fir.array<?x?xf80>>
+  %1 = fir.rebox %0 : (!fir.box<!fir.array<?x?xf80>>) -> !fir.box<!fir.array<?x?xf80>>
+  %2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_real10Ei"}
+  %3 = fir.declare %2 {uniq_name = "_QFtest_real10Ei"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  %4 = fir.alloca f80 {bindc_name = "res", uniq_name = "_QFtest_real10Eres"}
+  %5 = fir.declare %4 {uniq_name = "_QFtest_real10Eres"} : (!fir.ref<f80>) -> !fir.ref<f80>
+  fir.store %cst to %5 : !fir.ref<f80>
+  %8 = fir.convert %c1 : (index) -> i32
+  %9:2 = fir.do_loop %arg1 = %c1 to %c10 step %c1 iter_args(%arg2 = %8) -> (index, i32) {
+    fir.store %arg2 to %3 : !fir.ref<i32>
+    %11 = fir.load %5 : !fir.ref<f80>
+    %12 = fir.load %3 : !fir.ref<i32>
+    %13 = fir.convert %12 : (i32) -> i64
+    %14 = fir.array_coor %1 %13, %13 : (!fir.box<!fir.array<?x?xf80>>, i64, i64) -> !fir.ref<f80>
+    %15 = fir.load %14 : !fir.ref<f80>
+    %16 = arith.addf %11, %15 fastmath<contract> : f80
+    fir.store %16 to %5 : !fir.ref<f80>
+    %17 = arith.addi %arg1, %c1 : index
+    %18 = fir.load %3 : !fir.ref<i32>
+    %19 = arith.addi %18, %8 : i32
+    fir.result %17, %19 : index, i32
+  }
+  fir.store %9#1 to %3 : !fir.ref<i32>
+  %10 = fir.load %5 : !fir.ref<f80>
+  return %10 : f80
+}
+// CHECK-LABEL:   func.func @_QPtest_real10(
+// CHECK:           fir.if
+// CHECK:             fir.do_loop
+// CHECK-DAG:           arith.shrsi %{{[^,]*}}, %[[SHIFT:.*]] : index
+// CHECK-DAG:           %[[SHIFT]] = arith.constant 4 : index
+// CHECK:             fir.result
+// CHECK:           } else {
+// CHECK:             fir.do_loop
+
+// COMPLEX(10) elements take 16 + 10 = 26 bytes with 16-byte alignment, so
+// the element size used must be 26 rounded up to 32 bytes (shift by 5).
+func.func @_QPtest_complex10(%arg0: !fir.box<!fir.array<?x?x!fir.complex<10>>> {fir.bindc_name = "a"}) -> !fir.complex<10> {
+  %c10 = arith.constant 10 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f80
+  %0 = fir.declare %arg0 {fortran_attrs = #fir.var_attrs<intent_in>, uniq_name = "_QFtest_complex10Ea"} : (!fir.box<!fir.array<?x?x!fir.complex<10>>>) -> !fir.box<!fir.array<?x?x!fir.complex<10>>>
+  %1 = fir.rebox %0 : (!fir.box<!fir.array<?x?x!fir.complex<10>>>) -> !fir.box<!fir.array<?x?x!fir.complex<10>>>
+  %2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest_complex10Ei"}
+  %3 = fir.declare %2 {uniq_name = "_QFtest_complex10Ei"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  %4 = fir.alloca !fir.complex<10> {bindc_name = "res", uniq_name = "_QFtest_complex10Eres"}
+  %5 = fir.declare %4 {uniq_name = "_QFtest_complex10Eres"} : (!fir.ref<!fir.complex<10>>) -> !fir.ref<!fir.complex<10>>
+  %8 = fir.undefined !fir.complex<10>
+  %9 = fir.insert_value %8, %cst, [0 : index] : (!fir.complex<10>, f80) -> !fir.complex<10>
+  %10 = fir.insert_value %9, %cst, [1 : index] : (!fir.complex<10>, f80) -> !fir.complex<10>
+  fir.store %10 to %5 : !fir.ref<!fir.complex<10>>
+  %11 = fir.convert %c1 : (index) -> i32
+  %12:2 = fir.do_loop %arg1 = %c1 to %c10 step %c1 iter_args(%arg2 = %11) -> (index, i32) {
+    fir.store %arg2 to %3 : !fir.ref<i32>
+    %14 = fir.load %5 : !fir.ref<!fir.complex<10>>
+    %15 = fir.load %3 : !fir.ref<i32>
+    %16 = fir.convert %15 : (i32) -> i64
+    %17 = fir.array_coor %1 %16, %16 : (!fir.box<!fir.array<?x?x!fir.complex<10>>>, i64, i64) -> !fir.ref<!fir.complex<10>>
+    %18 = fir.load %17 : !fir.ref<!fir.complex<10>>
+    %19 = fir.addc %14, %18 {fastmath = #arith.fastmath<contract>} : !fir.complex<10>
+    fir.store %19 to %5 : !fir.ref<!fir.complex<10>>
+    %20 = arith.addi %arg1, %c1 : index
+    %21 = fir.load %3 : !fir.ref<i32>
+    %22 = arith.addi %21, %11 : i32
+    fir.result %20, %22 : index, i32
+  }
+  fir.store %12#1 to %3 : !fir.ref<i32>
+  %13 = fir.load %5 : !fir.ref<!fir.complex<10>>
+  return %13 : !fir.complex<10>
+}
+// CHECK-LABEL:   func.func @_QPtest_complex10(
+// CHECK:           fir.if
+// CHECK:             fir.do_loop
+// CHECK-DAG:           arith.shrsi %{{[^,]*}}, %[[SHIFT:.*]] : index
+// CHECK-DAG:           %[[SHIFT]] = arith.constant 5 : index
+// CHECK:             fir.result
+// CHECK:           } else {
+// CHECK:             fir.do_loop
 
 } // End module



More information about the flang-commits mailing list