[flang-commits] [flang] c8517f1 - [flang] Add support for dense complex constants

Leandro Lupori via flang-commits flang-commits at lists.llvm.org
Wed Aug 30 06:51:54 PDT 2023


Author: Leandro Lupori
Date: 2023-08-30T10:51:02-03:00
New Revision: c8517f17525e641432b916defb2aa94077bf1b09

URL: https://github.com/llvm/llvm-project/commit/c8517f17525e641432b916defb2aa94077bf1b09
DIFF: https://github.com/llvm/llvm-project/commit/c8517f17525e641432b916defb2aa94077bf1b09.diff

LOG: [flang] Add support for dense complex constants

Add support for representing complex array constants with an MLIR
dense attribute. This improves compile time and greatly reduces
memory usage of programs with large complex array constants.

Fixes https://github.com/llvm/llvm-project/issues/63610

Reviewed By: vzakhari

Differential Revision: https://reviews.llvm.org/D155951

Added: 
    

Modified: 
    flang/lib/Lower/ConvertConstant.cpp
    flang/lib/Lower/ConvertVariable.cpp
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/test/Lower/array.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/ConvertConstant.cpp b/flang/lib/Lower/ConvertConstant.cpp
index 4eff94bf5c6b7a..d1887b610c8fae 100644
--- a/flang/lib/Lower/ConvertConstant.cpp
+++ b/flang/lib/Lower/ConvertConstant.cpp
@@ -61,12 +61,24 @@ static mlir::Attribute convertToAttribute(
   } else if constexpr (TC == Fortran::common::TypeCategory::Logical) {
     return builder.getIntegerAttr(type, value.IsTrue());
   } else {
-    static_assert(TC == Fortran::common::TypeCategory::Real,
-                  "type values cannot be converted to attributes");
-    std::string str = value.DumpHexadecimal();
-    auto floatVal =
-        consAPFloat(builder.getKindMap().getFloatSemantics(KIND), str);
-    return builder.getFloatAttr(type, floatVal);
+    auto getFloatAttr = [&](const auto &value, mlir::Type type) {
+      std::string str = value.DumpHexadecimal();
+      auto floatVal =
+          consAPFloat(builder.getKindMap().getFloatSemantics(KIND), str);
+      return builder.getFloatAttr(type, floatVal);
+    };
+
+    if constexpr (TC == Fortran::common::TypeCategory::Real) {
+      return getFloatAttr(value, type);
+    } else {
+      static_assert(TC == Fortran::common::TypeCategory::Complex,
+                    "type values cannot be converted to attributes");
+      mlir::Type eleTy = mlir::cast<mlir::ComplexType>(type).getElementType();
+      llvm::SmallVector<mlir::Attribute, 2> attrs = {
+          getFloatAttr(value.REAL(), eleTy),
+          getFloatAttr(value.AIMAG(), eleTy)};
+      return builder.getArrayAttr(attrs);
+    }
   }
   return {};
 }
@@ -75,12 +87,11 @@ namespace {
 /// Helper class to lower an array constant to a global with an MLIR dense
 /// attribute.
 ///
-/// If we have an array of integer, real, or logical, then we can
+/// If we have an array of integer, real, complex, or logical, then we can
 /// create a global array with the dense attribute.
 ///
-/// The mlir tensor type can only handle integer, real, or logical. It
-/// does not currently support nested structures which is required for
-/// complex.
+/// The mlir tensor type can only handle integer, real, complex, or logical.
+/// It does not currently support nested structures.
 class DenseGlobalBuilder {
 public:
   static fir::GlobalOp tryCreating(fir::FirOpBuilder &builder,
@@ -98,6 +109,8 @@ class DenseGlobalBuilder {
             [&](const Fortran::evaluate::Expr<Fortran::evaluate::SomeReal> &x) {
               globalBuilder.tryConvertingToAttributes(builder, x);
             },
+            [&](const Fortran::evaluate::Expr<Fortran::evaluate::SomeComplex> &
+                    x) { globalBuilder.tryConvertingToAttributes(builder, x); },
             [](const auto &) {},
         },
         initExpr.u);
@@ -133,6 +146,9 @@ class DenseGlobalBuilder {
                       : TC;
     attributeElementType = Fortran::lower::getFIRType(
         builder.getContext(), attrTc, KIND, std::nullopt);
+    if (auto firCTy = mlir::dyn_cast<fir::ComplexType>(attributeElementType))
+      attributeElementType =
+          mlir::ComplexType::get(firCTy.getEleType(builder.getKindMap()));
     for (auto element : constant.values())
       attributes.push_back(
           convertToAttribute<TC, KIND>(builder, element, attributeElementType));
@@ -544,7 +560,8 @@ genOutlineArrayLit(Fortran::lower::AbstractConverter &converter,
     // always possible.
     if constexpr (T::category == Fortran::common::TypeCategory::Logical ||
                   T::category == Fortran::common::TypeCategory::Integer ||
-                  T::category == Fortran::common::TypeCategory::Real) {
+                  T::category == Fortran::common::TypeCategory::Real ||
+                  T::category == Fortran::common::TypeCategory::Complex) {
       global = DenseGlobalBuilder::tryCreating(
           builder, loc, arrayTy, globalName, builder.createInternalLinkage(),
           true, constant);

diff  --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index 20c99558b1a9d8..15731f14557c6c 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -431,12 +431,13 @@ static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
 
   // If this is an array, check to see if we can use a dense attribute
   // with a tensor mlir type. This optimization currently only supports
-  // Fortran arrays of integer, real, or logical. The tensor type does
-  // not support nested structures which are needed for complex numbers.
+  // Fortran arrays of integer, real, complex, or logical. The tensor
+  // type does not support nested structures.
   if (symTy.isa<fir::SequenceType>() &&
       !Fortran::semantics::IsAllocatableOrPointer(sym)) {
     mlir::Type eleTy = symTy.cast<fir::SequenceType>().getEleTy();
-    if (eleTy.isa<mlir::IntegerType, mlir::FloatType, fir::LogicalType>()) {
+    if (eleTy.isa<mlir::IntegerType, mlir::FloatType, fir::ComplexType,
+                  fir::LogicalType>()) {
       const auto *details =
           sym.detailsIf<Fortran::semantics::ObjectEntityDetails>();
       if (details->init()) {

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 76c15547429195..04e40cc759315f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2892,6 +2892,42 @@ struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
   }
 };
 
+// Check whether attr's element type is compatible with ty's element type.
+//
+// attr's tensor element type is converted to an LLVM type and compared with
+// the innermost (non-array) element type of ty.
+//
+// Only integer and floating point (including complex) dense attributes are
+// supported. Also, attr is expected to have a TensorType and ty is expected
+// to be of LLVMArrayType. If attr is null or any of the previous conditions
+// is false, then the specified attr and ty are not supported by this
+// function and are assumed to be compatible (true is returned).
+static inline bool attributeTypeIsCompatible(mlir::MLIRContext *ctx,
+                                             mlir::Attribute attr,
+                                             mlir::Type ty) {
+  // Get attr's LLVM element type; unsupported attributes are accepted as is.
+  if (!attr)
+    return true;
+  auto intOrFpEleAttr = mlir::dyn_cast<mlir::DenseIntOrFPElementsAttr>(attr);
+  if (!intOrFpEleAttr)
+    return true;
+  auto tensorTy = mlir::dyn_cast<mlir::TensorType>(intOrFpEleAttr.getType());
+  if (!tensorTy)
+    return true;
+  mlir::Type attrEleTy =
+      mlir::LLVMTypeConverter(ctx).convertType(tensorTy.getElementType());
+
+  // Get ty's innermost element type by peeling off nested LLVM array types.
+  auto arrTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(ty);
+  if (!arrTy)
+    return true;
+  mlir::Type eleTy = arrTy.getElementType();
+  while ((arrTy = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(eleTy)))
+    eleTy = arrTy.getElementType();
+
+  return attrEleTy == eleTy;
+}
+
 /// Lower `fir.global` operation to `llvm.global` operation.
 /// `fir.insert_on_range` operations are replaced with constant dense attribute
 /// if they are applied on the full range.
@@ -2906,6 +2942,7 @@ struct GlobalOpConversion : public FIROpConversion<fir::GlobalOp> {
       tyAttr = tyAttr.cast<mlir::LLVM::LLVMPointerType>().getElementType();
     auto loc = global.getLoc();
     mlir::Attribute initAttr = global.getInitVal().value_or(mlir::Attribute());
+    assert(attributeTypeIsCompatible(global.getContext(), initAttr, tyAttr));
     auto linkage = convertLinkage(global.getLinkName());
     auto isConst = global.getConstant().has_value();
     auto g = rewriter.create<mlir::LLVM::GlobalOp>(

diff  --git a/flang/test/Lower/array.f90 b/flang/test/Lower/array.f90
index c64da68f302721..852f5605892800 100644
--- a/flang/test/Lower/array.f90
+++ b/flang/test/Lower/array.f90
@@ -103,13 +103,20 @@ subroutine range()
   real, dimension(2,3) ::  a1
   integer, dimension(3,4) :: a2
   integer, dimension(2,3,4) :: a3
- 
+  complex, dimension(2,3) :: c0, c1
+
   a0 = (/1, 2, 3, 3, 3, 3, 3, 3, 3, 3/)
   a1 = reshape((/3.5, 3.5, 3.5, 3.5, 3.5, 3.5/), shape(a1))
   a2 = reshape((/1, 3, 3, 5, 3, 3, 3, 3, 9, 9, 9, 8/), shape(a2))
   a3 = reshape((/1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12/), shape(a3))
+
+  c0 = reshape((/(1.0, 1.5), (2.0, 2.5), (3.0, 3.5), (4.0, 4.5), (5.0, 5.5), (6.0, 6.5)/), shape(c0))
+  data c1/6 * (0.0, 0.0)/
 end subroutine range
 
+! c1 data
+! CHECK: fir.global internal @_QFrangeEc1(dense<(0.000000e+00,0.000000e+00)> : tensor<3x2xcomplex<f32>>) : !fir.array<2x3x!fir.complex<4>>
+
 ! a0 array constructor
 ! CHECK: fir.global internal @_QQro.10xi4.{{.*}}(dense<[1, 2, 3, 3, 3, 3, 3, 3, 3, 3]> : tensor<10xi32>) constant : !fir.array<10xi32>
 
@@ -122,6 +129,9 @@ end subroutine range
 ! a3 array constructor
 ! CHECK: fir.global internal @_QQro.2x3x4xi4.{{.*}}(dense<{{\[\[\[1, 1], \[2, 2], \[3, 3]], \[\[4, 4], \[5, 5], \[6, 6]], \[\[7, 7], \[8, 8], \[9, 9]], \[\[10, 10], \[11, 11], \[12, 12]]]}}> : tensor<4x3x2xi32>) constant : !fir.array<2x3x4xi32>
 
+! c0 array constructor
+! CHECK: fir.global internal @_QQro.2x3xz4.{{.*}}(dense<{{\[}}[(1.000000e+00,1.500000e+00), (2.000000e+00,2.500000e+00)], [(3.000000e+00,3.500000e+00), (4.000000e+00,4.500000e+00)], [(5.000000e+00,5.500000e+00), (6.000000e+00,6.500000e+00)]]> : tensor<3x2xcomplex<f32>>) constant : !fir.array<2x3x!fir.complex<4>>
+
 ! CHECK-LABEL rangeGlobal
 subroutine rangeGlobal()
 ! CHECK: fir.global internal @_QFrangeglobal{{.*}}(dense<[1, 1, 2, 2, 3, 3]> : tensor<6xi32>) : !fir.array<6xi32>


        


More information about the flang-commits mailing list