[flang-commits] [flang] [flang] replace fir.complex usages with mlir complex (PR #110850)

via flang-commits flang-commits at lists.llvm.org
Wed Oct 2 07:41:38 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: None (jeanPerier)

<details>
<summary>Changes</summary>

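This is the core patch for the RFC to replace usages of fir.complex with the MLIR complex type: https://discourse.llvm.org/t/rfc-flang-replace-usages-of-fir-complex-by-mlir-complex-type/82292
Once this patch lands, the last remaining step is to remove fir.complex from the FIR types.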

---

Patch is 534.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/110850.diff


147 Files Affected:

- (modified) flang/include/flang/Optimizer/Builder/Complex.h (+3-4) 
- (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+1-1) 
- (modified) flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h (+4-2) 
- (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+4-4) 
- (modified) flang/include/flang/Optimizer/Dialect/FIRTypes.td (+5-4) 
- (modified) flang/include/flang/Optimizer/Support/Utils.h (+32-46) 
- (modified) flang/lib/Lower/Bridge.cpp (+2-4) 
- (modified) flang/lib/Lower/ConvertConstant.cpp (+5-11) 
- (modified) flang/lib/Lower/ConvertExpr.cpp (+2-2) 
- (modified) flang/lib/Lower/ConvertExprToHLFIR.cpp (+1-1) 
- (modified) flang/lib/Lower/ConvertType.cpp (+1-4) 
- (modified) flang/lib/Lower/ConvertVariable.cpp (+1-1) 
- (modified) flang/lib/Lower/IO.cpp (+4-4) 
- (modified) flang/lib/Lower/OpenACC.cpp (+7-8) 
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+1-1) 
- (modified) flang/lib/Lower/OpenMP/ReductionProcessor.cpp (+2-3) 
- (modified) flang/lib/Optimizer/Builder/Complex.cpp (+8-9) 
- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+1-4) 
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+14-42) 
- (modified) flang/lib/Optimizer/CodeGen/Target.cpp (+1-1) 
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+1-1) 
- (modified) flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp (+3-6) 
- (modified) flang/lib/Optimizer/Transforms/CufOpConversion.cpp (+3-3) 
- (modified) flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp (+3-10) 
- (modified) flang/lib/Optimizer/Transforms/LoopVersioning.cpp (+1-1) 
- (modified) flang/test/Fir/array-copies-pointers.fir (+2-2) 
- (modified) flang/test/Fir/compare.fir (+5-5) 
- (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+3-3) 
- (modified) flang/test/Fir/convert-to-llvm.fir (+33-33) 
- (modified) flang/test/Fir/convert.fir (+3-3) 
- (modified) flang/test/Fir/fir-ops.fir (+52-52) 
- (modified) flang/test/Fir/fir-types.fir (+2-2) 
- (modified) flang/test/Fir/rebox.fir (+5-5) 
- (modified) flang/test/Fir/struct-passing-x86-64-byval.fir (+8-8) 
- (modified) flang/test/Fir/target-complex16.f90 (+8-8) 
- (modified) flang/test/Fir/target-rewrite-arg-position.fir (+9-9) 
- (modified) flang/test/Fir/target-rewrite-complex-10-x86.fir (+18-18) 
- (modified) flang/test/Fir/target-rewrite-complex.fir (+616-616) 
- (modified) flang/test/Fir/target-rewrite-complex16.fir (+50-50) 
- (modified) flang/test/Fir/target-rewrite-selective.fir (+21-21) 
- (modified) flang/test/Fir/target.fir (+16-16) 
- (modified) flang/test/Fir/types-to-llvm.fir (+6-6) 
- (modified) flang/test/Fir/undo-complex-pattern.fir (+48-48) 
- (modified) flang/test/HLFIR/assign-codegen.fir (+11-11) 
- (modified) flang/test/HLFIR/assign.fir (+10-10) 
- (modified) flang/test/HLFIR/associate.fir (+9-9) 
- (modified) flang/test/HLFIR/assumed_shape_with_value_keyword.f90 (+12-12) 
- (modified) flang/test/HLFIR/designate-codegen-complex-part.fir (+31-31) 
- (modified) flang/test/HLFIR/designate.fir (+8-8) 
- (modified) flang/test/HLFIR/invalid.fir (+5-5) 
- (modified) flang/test/HLFIR/opt-scalar-assign.fir (+13-13) 
- (modified) flang/test/Intrinsics/math-codegen.fir (+69-69) 
- (modified) flang/test/Lower/HLFIR/array-ctor-as-inlined-temp.f90 (+13-13) 
- (modified) flang/test/Lower/HLFIR/assignment-intrinsics.f90 (+8-8) 
- (modified) flang/test/Lower/HLFIR/binary-ops.f90 (+41-41) 
- (modified) flang/test/Lower/HLFIR/calls-f77.f90 (+1-1) 
- (modified) flang/test/Lower/HLFIR/calls-percent-val-ref.f90 (+12-12) 
- (modified) flang/test/Lower/HLFIR/constant.f90 (+3-3) 
- (modified) flang/test/Lower/HLFIR/conversion-ops.f90 (+18-18) 
- (modified) flang/test/Lower/HLFIR/designators-component-ref.f90 (+2-2) 
- (modified) flang/test/Lower/HLFIR/designators.f90 (+8-8) 
- (modified) flang/test/Lower/HLFIR/entry_return.f90 (+26-26) 
- (modified) flang/test/Lower/HLFIR/implicit-call-mismatch.f90 (+1-1) 
- (modified) flang/test/Lower/HLFIR/implicit-type-conversion-allocatable.f90 (+7-7) 
- (modified) flang/test/Lower/HLFIR/initial-target-component.f90 (+3-3) 
- (modified) flang/test/Lower/HLFIR/intrinsic-dynamically-optional.f90 (+19-19) 
- (modified) flang/test/Lower/HLFIR/unary-ops.f90 (+9-9) 
- (modified) flang/test/Lower/HLFIR/user-defined-assignment.f90 (+5-5) 
- (modified) flang/test/Lower/HLFIR/vector-subscript-as-value.f90 (+3-3) 
- (modified) flang/test/Lower/Intrinsics/abs.f90 (+10-14) 
- (modified) flang/test/Lower/Intrinsics/acos.f90 (+2-2) 
- (modified) flang/test/Lower/Intrinsics/acos_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/acosh.f90 (+2-2) 
- (modified) flang/test/Lower/Intrinsics/acosh_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/asin.f90 (+2-2) 
- (modified) flang/test/Lower/Intrinsics/asin_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/asinh.f90 (+2-2) 
- (modified) flang/test/Lower/Intrinsics/asinh_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/atan.f90 (+4-4) 
- (modified) flang/test/Lower/Intrinsics/atan_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/atanh.f90 (+2-2) 
- (modified) flang/test/Lower/Intrinsics/atanh_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/cabs_real16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/cmplx.f90 (+20-20) 
- (modified) flang/test/Lower/Intrinsics/cos_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/cosh_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/dconjg.f90 (+6-6) 
- (modified) flang/test/Lower/Intrinsics/dimag.f90 (+3-3) 
- (modified) flang/test/Lower/Intrinsics/dot_product.f90 (+45-52) 
- (modified) flang/test/Lower/Intrinsics/dreal.f90 (+3-3) 
- (modified) flang/test/Lower/Intrinsics/exp.f90 (+20-28) 
- (modified) flang/test/Lower/Intrinsics/exp_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/log.f90 (+20-28) 
- (modified) flang/test/Lower/Intrinsics/log_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/pow_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/pow_complex16i.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/pow_complex16k.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/product.f90 (+8-10) 
- (modified) flang/test/Lower/Intrinsics/sin_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/sinh_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/sqrt_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/sum.f90 (+8-10) 
- (modified) flang/test/Lower/Intrinsics/tan_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/Intrinsics/tanh_complex16.f90 (+1-1) 
- (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+34-34) 
- (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 (+3-3) 
- (modified) flang/test/Lower/OpenMP/copyprivate.f90 (+8-8) 
- (modified) flang/test/Lower/OpenMP/lastprivate-allocatable.f90 (+7-7) 
- (modified) flang/test/Lower/OpenMP/parallel-firstprivate-clause-scalar.f90 (+13-13) 
- (modified) flang/test/Lower/OpenMP/parallel-private-clause.f90 (+3-3) 
- (modified) flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 (+17-17) 
- (modified) flang/test/Lower/OpenMP/parallel-reduction-complex.f90 (+17-17) 
- (modified) flang/test/Lower/OpenMP/private-commonblock.f90 (+15-15) 
- (modified) flang/test/Lower/OpenMP/task.f90 (+1-1) 
- (modified) flang/test/Lower/OpenMP/threadprivate-commonblock.f90 (+14-14) 
- (modified) flang/test/Lower/OpenMP/threadprivate-non-global.f90 (+11-11) 
- (modified) flang/test/Lower/OpenMP/threadprivate-real-logical-complex-derivedtype.f90 (+10-10) 
- (modified) flang/test/Lower/array-constructor-1.f90 (+4-4) 
- (modified) flang/test/Lower/array-elemental-calls-2.f90 (+4-4) 
- (modified) flang/test/Lower/array-elemental-subroutines.f90 (+3-3) 
- (modified) flang/test/Lower/array.f90 (+2-2) 
- (modified) flang/test/Lower/assignment.f90 (+50-50) 
- (modified) flang/test/Lower/basic-function.f90 (+12-12) 
- (modified) flang/test/Lower/call-bindc.f90 (+1-1) 
- (modified) flang/test/Lower/call-by-value.f90 (+4-4) 
- (modified) flang/test/Lower/complex-operations.f90 (+49-55) 
- (modified) flang/test/Lower/complex-real.f90 (+2-2) 
- (modified) flang/test/Lower/derived-assignments.f90 (+11-11) 
- (modified) flang/test/Lower/dummy-procedure.f90 (+4-4) 
- (modified) flang/test/Lower/entry-statement.f90 (+8-8) 
- (modified) flang/test/Lower/implicit-call-mismatch.f90 (+6-6) 
- (modified) flang/test/Lower/math-lowering/abs.f90 (+4-4) 
- (modified) flang/test/Lower/pointer.f90 (+3-3) 
- (modified) flang/test/Lower/polymorphic.f90 (+2-2) 
- (modified) flang/test/Lower/sqrt.f90 (+4-4) 
- (modified) flang/test/Lower/trigonometric-intrinsics.f90 (+24-24) 
- (modified) flang/test/Lower/vector-subscript-io.f90 (+5-5) 
- (modified) flang/test/Transforms/debug-90683.fir (+5-5) 
- (modified) flang/test/Transforms/debug-complex-1.fir (+10-10) 
- (modified) flang/test/Transforms/debug-derived-type-1.fir (+4-4) 
- (modified) flang/test/Transforms/loop-versioning.fir (+16-16) 
- (modified) flang/test/Transforms/omp-map-info-finalization.fir (+8-8) 
- (modified) flang/test/Transforms/simplifyintrinsics.fir (+11-11) 
- (modified) flang/unittests/Optimizer/Builder/ComplexTest.cpp (+2-9) 
- (modified) flang/unittests/Optimizer/Builder/Runtime/RuntimeCallTestBase.h (+4-4) 
- (modified) flang/unittests/Optimizer/FIRTypesTest.cpp (+2-2) 
- (modified) flang/unittests/Optimizer/RTBuilder.cpp (+1-1) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/Complex.h b/flang/include/flang/Optimizer/Builder/Complex.h
index 3286e41ee039e5..cd0a876a4cef00 100644
--- a/flang/include/flang/Optimizer/Builder/Complex.h
+++ b/flang/include/flang/Optimizer/Builder/Complex.h
@@ -32,13 +32,12 @@ class Complex {
   mlir::Type getComplexPartType(mlir::Value cplx) const;
   mlir::Type getComplexPartType(mlir::Type complexType) const;
 
-  /// Complex operation creation. They create MLIR operations.
-  mlir::Value createComplex(fir::KindTy kind, mlir::Value real,
-                            mlir::Value imag);
-
   /// Create a complex value.
   mlir::Value createComplex(mlir::Type complexType, mlir::Value real,
                             mlir::Value imag);
+  /// Create a complex value given the real and imag parts real type (which
+  /// must be the same).
+  mlir::Value createComplex(mlir::Value real, mlir::Value imag);
 
   /// Returns the Real/Imag part of \p cplx
   mlir::Value extractComplexPart(mlir::Value cplx, bool isImagPart) {
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 0b5b88f3981cfc..868a8b4e287424 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -668,7 +668,7 @@ static inline mlir::Type getTypeHelper(mlir::MLIRContext *context,
     r = builder.getRealType(kind);
     break;
   case ParamTypeId::Complex:
-    r = fir::ComplexType::get(context, kind);
+    r = mlir::ComplexType::get(builder.getRealType(kind));
     break;
   }
 
diff --git a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
index c16f6813c99ad2..66e11b5585d521 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
@@ -441,13 +441,15 @@ constexpr TypeBuilderFunc getModel<const std::complex<double> *>() {
 template <>
 constexpr TypeBuilderFunc getModel<c_float_complex_t>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
-    return fir::ComplexType::get(context, sizeof(float));
+    mlir::Type floatTy = getModel<float>()(context);
+    return mlir::ComplexType::get(floatTy);
   };
 }
 template <>
 constexpr TypeBuilderFunc getModel<c_double_complex_t>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
-    return fir::ComplexType::get(context, sizeof(double));
+    mlir::Type floatTy = getModel<double>()(context);
+    return mlir::ComplexType::get(floatTy);
   };
 }
 template <>
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index b34d7629613bad..2a84d8e986c5e3 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2611,13 +2611,13 @@ class fir_UnaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
 
 class ComplexUnaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
       fir_UnaryArithmeticOp<mnemonic, traits>,
-      Arguments<(ins fir_ComplexType:$operand)>;
+      Arguments<(ins AnyFirComplex:$operand)>;
 
 def fir_NegcOp : ComplexUnaryArithmeticOp<"negc">;
 
 class ComplexArithmeticOp<string mnemonic, list<Trait> traits = []> :
       fir_ArithmeticOp<mnemonic, traits>,
-      Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs,
+      Arguments<(ins AnyFirComplex:$lhs, AnyFirComplex:$rhs,
           DefaultValuedAttr<Arith_FastMathAttr,
                             "::mlir::arith::FastMathFlags::none">:$fastmath)>;
 
@@ -2641,8 +2641,8 @@ def fir_CmpcOp : fir_Op<"cmpc",
   }];
 
   let arguments = (ins
-      fir_ComplexType:$lhs,
-      fir_ComplexType:$rhs,
+      AnyFirComplex:$lhs,
+      AnyFirComplex:$rhs,
       DefaultValuedAttr<Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
 
   let results = (outs AnyLogicalLike);
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index ae984de63db428..af5f5ed2433dcd 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -601,12 +601,13 @@ def AnyRealLike : TypeConstraint<Or<[FloatLike.predicate,
     fir_RealType.predicate]>, "any real">;
 def AnyIntegerType : Type<AnyIntegerLike.predicate, "any integer">;
 
+def AnyFirComplexLike : TypeConstraint<Or<[fir_ComplexType.predicate,
+  AnyComplex.predicate]>, "any FIR or MLIR complex type">;
+def AnyFirComplex : Type<AnyFirComplexLike.predicate, "any FIR or MLIR complex type">;
+
 // Composable types
-// Note that we include both fir_ComplexType and AnyComplex, so we can use both
-// the FIR ComplexType and the MLIR ComplexType (the former is used to represent
-// Fortran complex and the latter for C++ std::complex).
 def AnyCompositeLike : TypeConstraint<Or<[fir_RecordType.predicate,
-    fir_SequenceType.predicate, fir_ComplexType.predicate, AnyComplex.predicate,
+    fir_SequenceType.predicate, AnyFirComplexLike.predicate,
     fir_VectorType.predicate, IsTupleTypePred, fir_CharacterType.predicate]>,
     "any composite">;
 
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index 06cf9c0be157c3..2e25ef5f19bbe1 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -84,25 +84,33 @@ inline std::string mlirTypeToString(mlir::Type type) {
   return result;
 }
 
-inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
-                                              mlir::Type type,
-                                              mlir::Location loc,
-                                              const llvm::Twine &name) {
+inline std::optional<int> mlirFloatTypeToKind(mlir::Type type) {
   if (type.isF16())
-    return "REAL(KIND=2)";
+    return 2;
   else if (type.isBF16())
-    return "REAL(KIND=3)";
-  else if (type.isTF32())
-    return "REAL(KIND=unknown)";
+    return 3;
   else if (type.isF32())
-    return "REAL(KIND=4)";
+    return 4;
   else if (type.isF64())
-    return "REAL(KIND=8)";
+    return 8;
   else if (type.isF80())
-    return "REAL(KIND=10)";
+    return 10;
   else if (type.isF128())
-    return "REAL(KIND=16)";
-  else if (type.isInteger(8))
+    return 16;
+  return std::nullopt;
+}
+
+inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
+                                              mlir::Type type,
+                                              mlir::Location loc,
+                                              const llvm::Twine &name) {
+  if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(type)) {
+    if (std::optional<int> kind = mlirFloatTypeToKind(type))
+      return "REAL(KIND="s + std::to_string(*kind) + ")";
+  } else if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
+    if (std::optional<int> kind = mlirFloatTypeToKind(cplxTy.getElementType()))
+      return "COMPLEX(KIND+"s + std::to_string(*kind) + ")";
+  } else if (type.isInteger(8))
     return "INTEGER(KIND=1)";
   else if (type.isInteger(16))
     return "INTEGER(KIND=2)";
@@ -112,18 +120,6 @@ inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
     return "INTEGER(KIND=8)";
   else if (type.isInteger(128))
     return "INTEGER(KIND=16)";
-  else if (type == fir::ComplexType::get(builder.getContext(), 2))
-    return "COMPLEX(KIND=2)";
-  else if (type == fir::ComplexType::get(builder.getContext(), 3))
-    return "COMPLEX(KIND=3)";
-  else if (type == fir::ComplexType::get(builder.getContext(), 4))
-    return "COMPLEX(KIND=4)";
-  else if (type == fir::ComplexType::get(builder.getContext(), 8))
-    return "COMPLEX(KIND=8)";
-  else if (type == fir::ComplexType::get(builder.getContext(), 10))
-    return "COMPLEX(KIND=10)";
-  else if (type == fir::ComplexType::get(builder.getContext(), 16))
-    return "COMPLEX(KIND=16)";
   else if (type == fir::LogicalType::get(builder.getContext(), 1))
     return "LOGICAL(KIND=1)";
   else if (type == fir::LogicalType::get(builder.getContext(), 2))
@@ -132,9 +128,9 @@ inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
     return "LOGICAL(KIND=4)";
   else if (type == fir::LogicalType::get(builder.getContext(), 8))
     return "LOGICAL(KIND=8)";
-  else
-    fir::emitFatalError(loc, "unsupported type in " + name + ": " +
-                                 fir::mlirTypeToString(type));
+
+  fir::emitFatalError(loc, "unsupported type in " + name + ": " +
+                               fir::mlirTypeToString(type));
 }
 
 inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
@@ -159,19 +155,13 @@ inline void intrinsicTypeTODO2(fir::FirOpBuilder &builder, mlir::Type type1,
 
 inline std::pair<Fortran::common::TypeCategory, KindMapping::KindTy>
 mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
-  if (type.isF16())
-    return {Fortran::common::TypeCategory::Real, 2};
-  else if (type.isBF16())
-    return {Fortran::common::TypeCategory::Real, 3};
-  else if (type.isF32())
-    return {Fortran::common::TypeCategory::Real, 4};
-  else if (type.isF64())
-    return {Fortran::common::TypeCategory::Real, 8};
-  else if (type.isF80())
-    return {Fortran::common::TypeCategory::Real, 10};
-  else if (type.isF128())
-    return {Fortran::common::TypeCategory::Real, 16};
-  else if (type.isInteger(8))
+  if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(type)) {
+    if (std::optional<int> kind = mlirFloatTypeToKind(type))
+      return {Fortran::common::TypeCategory::Real, *kind};
+  } else if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
+    if (std::optional<int> kind = mlirFloatTypeToKind(cplxTy.getElementType()))
+      return {Fortran::common::TypeCategory::Complex, *kind};
+  } else if (type.isInteger(8))
     return {Fortran::common::TypeCategory::Integer, 1};
   else if (type.isInteger(16))
     return {Fortran::common::TypeCategory::Integer, 2};
@@ -181,17 +171,13 @@ mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
     return {Fortran::common::TypeCategory::Integer, 8};
   else if (type.isInteger(128))
     return {Fortran::common::TypeCategory::Integer, 16};
-  else if (auto complexType = mlir::dyn_cast<fir::ComplexType>(type))
-    return {Fortran::common::TypeCategory::Complex, complexType.getFKind()};
   else if (auto logicalType = mlir::dyn_cast<fir::LogicalType>(type))
     return {Fortran::common::TypeCategory::Logical, logicalType.getFKind()};
   else if (auto charType = mlir::dyn_cast<fir::CharacterType>(type))
     return {Fortran::common::TypeCategory::Character, charType.getFKind()};
   else if (mlir::isa<fir::RecordType>(type))
     return {Fortran::common::TypeCategory::Derived, 0};
-  else
-    fir::emitFatalError(loc,
-                        "unsupported type: " + fir::mlirTypeToString(type));
+  fir::emitFatalError(loc, "unsupported type: " + fir::mlirTypeToString(type));
 }
 
 /// Find the fir.type_info that was created for this \p recordType in \p module,
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ebcb7613969661..0894a5903635e1 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -5323,10 +5323,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           LLVM_ATTRIBUTE_UNUSED auto getBitWidth = [this](mlir::Type ty) {
             // 15.6.2.6.3: differering result types should be integer, real,
             // complex or logical
-            if (auto cmplx = mlir::dyn_cast_or_null<fir::ComplexType>(ty)) {
-              fir::KindTy kind = cmplx.getFKind();
-              return 2 * builder->getKindMap().getRealBitsize(kind);
-            }
+            if (auto cmplx = mlir::dyn_cast_or_null<mlir::ComplexType>(ty))
+              return 2 * cmplx.getElementType().getIntOrFloatBitWidth();
             if (auto logical = mlir::dyn_cast_or_null<fir::LogicalType>(ty)) {
               fir::KindTy kind = logical.getFKind();
               return builder->getKindMap().getLogicalBitsize(kind);
diff --git a/flang/lib/Lower/ConvertConstant.cpp b/flang/lib/Lower/ConvertConstant.cpp
index 3361817ee27ee2..748be508235f17 100644
--- a/flang/lib/Lower/ConvertConstant.cpp
+++ b/flang/lib/Lower/ConvertConstant.cpp
@@ -152,9 +152,6 @@ class DenseGlobalBuilder {
                       : TC;
     attributeElementType = Fortran::lower::getFIRType(
         builder.getContext(), attrTc, KIND, std::nullopt);
-    if (auto firCTy = mlir::dyn_cast<fir::ComplexType>(attributeElementType))
-      attributeElementType =
-          mlir::ComplexType::get(firCTy.getEleType(builder.getKindMap()));
     for (auto element : constant.values())
       attributes.push_back(
           convertToAttribute<TC, KIND>(builder, element, attributeElementType));
@@ -264,14 +261,11 @@ static mlir::Value genScalarLit(
       return genRealConstant<KIND>(builder, loc, floatVal);
     }
   } else if constexpr (TC == Fortran::common::TypeCategory::Complex) {
-    mlir::Value realPart =
-        genScalarLit<Fortran::common::TypeCategory::Real, KIND>(builder, loc,
-                                                                value.REAL());
-    mlir::Value imagPart =
-        genScalarLit<Fortran::common::TypeCategory::Real, KIND>(builder, loc,
-                                                                value.AIMAG());
-    return fir::factory::Complex{builder, loc}.createComplex(KIND, realPart,
-                                                             imagPart);
+    mlir::Value real = genScalarLit<Fortran::common::TypeCategory::Real, KIND>(
+        builder, loc, value.REAL());
+    mlir::Value imag = genScalarLit<Fortran::common::TypeCategory::Real, KIND>(
+        builder, loc, value.AIMAG());
+    return fir::factory::Complex{builder, loc}.createComplex(real, imag);
   } else /*constexpr*/ {
     llvm_unreachable("unhandled constant");
   }
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 62a7615e1af13c..72c236be42ce3d 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -1127,7 +1127,7 @@ class ScalarExprLowering {
   ExtValue genval(const Fortran::evaluate::ComplexConstructor<KIND> &op) {
     mlir::Value realPartValue = genunbox(op.left());
     return fir::factory::Complex{builder, getLoc()}.createComplex(
-        KIND, realPartValue, genunbox(op.right()));
+        realPartValue, genunbox(op.right()));
   }
 
   template <int KIND>
@@ -5242,7 +5242,7 @@ class ArrayExprLowering {
     return [=](IterSpace iters) -> ExtValue {
       mlir::Value lhs = fir::getBase(lf(iters));
       mlir::Value rhs = fir::getBase(rf(iters));
-      return fir::factory::Complex{builder, loc}.createComplex(KIND, lhs, rhs);
+      return fir::factory::Complex{builder, loc}.createComplex(lhs, rhs);
     };
   }
 
diff --git a/flang/lib/Lower/ConvertExprToHLFIR.cpp b/flang/lib/Lower/ConvertExprToHLFIR.cpp
index 1933f38f735b57..98ecd156a65e3a 100644
--- a/flang/lib/Lower/ConvertExprToHLFIR.cpp
+++ b/flang/lib/Lower/ConvertExprToHLFIR.cpp
@@ -1218,7 +1218,7 @@ struct BinaryOp<Fortran::evaluate::ComplexConstructor<KIND>> {
                                          fir::FirOpBuilder &builder, const Op &,
                                          hlfir::Entity lhs, hlfir::Entity rhs) {
     mlir::Value res =
-        fir::factory::Complex{builder, loc}.createComplex(KIND, lhs, rhs);
+        fir::factory::Complex{builder, loc}.createComplex(lhs, rhs);
     return hlfir::EntityWithAttributes{res};
   }
 };
diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp
index a47fc99ea9f456..8664477b50078a 100644
--- a/flang/lib/Lower/ConvertType.cpp
+++ b/flang/lib/Lower/ConvertType.cpp
@@ -96,10 +96,7 @@ static mlir::Type genCharacterType(
 }
 
 static mlir::Type genComplexType(mlir::MLIRContext *context, int KIND) {
-  if (Fortran::evaluate::IsValidKindOfIntrinsicType(
-          Fortran::common::TypeCategory::Complex, KIND))
-    return fir::ComplexType::get(context, KIND);
-  return {};
+  return mlir::ComplexType::get(genRealType(context, KIND));
 }
 
 static mlir::Type
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index f76d44f5479d32..f9635408d6e8c6 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -519,7 +519,7 @@ static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
   if (mlir::isa<fir::SequenceType>(symTy) &&
       !Fortran::semantics::IsAllocatableOrPointer(sym)) {
     mlir::Type eleTy = mlir::cast<fir::SequenceType>(symTy).getEleTy();
-    if (mlir::isa<mlir::IntegerType, mlir::FloatType, fir::ComplexType,
+    if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType,
                   fir::LogicalType>(eleTy)) {
       const auto *details =
           sym.detailsIf<Fortran::semantics::ObjectEntityDetails>();
diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp
index 9e98b230b676f6..1894b0cfd1bec2 100644
--- a/flang/lib/Lower/IO.cpp
+++ b/flang/lib/Lower/IO.cpp
@@ -684,9 +684,9 @@ static mlir::func::FuncOp getOutputFunc(mlir::Location loc,
       return getIORuntimeFunc<mkIOKey(OutputReal64)>(loc, builder);
   }
   auto kindMap = fir::getKindMapping(builder.getModule());
-  if (auto ty = mlir::dyn_cast<fir::ComplexType>(type)) {
+  if (auto ty = mlir::dyn_cast<mlir::ComplexType>(type)) {
     // COMPLEX(KIND=k) corresponds to a pair of REAL(KIND=k).
-    auto width = kindMap.getRealBitsize(ty.getFKind());
+    auto width = mlir::cast<mlir::FloatType>(ty.getElementType()).getWidth();
     if (width == 32)
       return getIORuntimeFunc<mkIOKey(OutputComplex32)>(loc, builder);
     else if (width == 64)
@@ -788,8 +788,8 @@ static mlir::func::FuncOp getInputFunc(mlir::Location loc,
       return getIORuntimeFunc<mkIOKey(InputReal64)>(loc, builder);
   }
   auto kindMap = fir::getKindMapping(builder.getModule());
-  if (auto ty = mlir::dyn_cast<fir::ComplexType>(type)) {
-    auto width = kindMap.getRealBitsize(ty.getFKind());
+  if (auto ty = mlir::dyn_cast<mlir::ComplexType>(type)) {
+    auto width = mlir::cast<mlir::FloatType>(ty.getElementType()).getWidth();
     if (width == 32)
       return getIORuntimeFunc<mkIOKey(InputComplex32)>(loc, builder);
     else if (width == 64)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index c97398fc43f923..878dccc4ecbc4b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -998,7 +998,7 @@ static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
         builder.getIntegerAttr(ty, getReductionInitValue<llvm::APInt>(op, ty)));
   if (op == mlir::acc::ReductionOperator::AccMin ||
       op == mlir::acc::ReductionOperator::AccMax) {
-    if (mlir::isa<fir::ComplexType>(ty))
+    if (mlir::isa<mlir::ComplexType>(ty))
       llvm::report_fatal_error(
           "min/max reduction not supported for complex type");
     if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty))
@@ -1010,14 +1010,13 @@ static mlir::Value getReductionInitValue(fir::FirOpBuilder &builder,
     return builder.create<mlir::arith::ConstantOp>(
         loc, ty,
         builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
-  } else if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty)) {
-    mlir::Type floatTy =
-        Fortran::lower::convertReal(builder.getContext(), cmplxTy.getFKind());
+  } else if (auto cmplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(ty)) {
+    mlir::Type floatTy = cmplxTy.getElementType();
     mlir::Value realInit = builder.createRealConstant(
         loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy));
     mlir::Value imagInit = builder.createRealConstant(loc, floatTy, 0.0);
-    return fir::factory::Complex{builder, loc}.createComplex(
-        cmplxTy.getFKind(), realInit, imagInit);
+    return fir::factory::Complex{builder, loc}.createComplex(cmplxTy, realInit,
+                                                             imagInit);
   }
 
   if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
@@ -1136,7 +1135,7 @@ static mlir::Value genScalarCombiner(fir::FirOpBuilder &builder,
       return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
     if (mlir::isa<mlir::FloatType>(ty))
   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/110850


More information about the flang-commits mailing list