[flang-commits] [flang] [flang] lower std::complex value argument to tuple<f, f> (PR #110643)

via flang-commits flang-commits at lists.llvm.org
Tue Oct 1 08:43:54 PDT 2024


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/110643

>From 2086a44951a0537bcd2c902bc6699a96bbe6c11f Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 1 Oct 2024 01:11:30 -0700
Subject: [PATCH 1/2] [flang] lower std::complex value argument to tuple<f,f>

---
 .../Optimizer/Builder/Runtime/RTBuilder.h     | 32 ++++++++++++-------
 flang/unittests/Optimizer/RTBuilder.cpp       | 20 ++++++++++++
 2 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
index a103861f1510b8..e7facb7c280e97 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
@@ -400,17 +400,33 @@ constexpr TypeBuilderFunc getModel<bool &>() {
     return fir::ReferenceType::get(f(context));
   };
 }
+
+// Note about getModel<std::complex<T>>
+// Prefer passing/returning the complex by reference in the runtime to
+// avoid ABI issues.
+// C++ std::complex is not an intrinsic type, and it while it is storage
+// compatible with C/Fortran complex type, it follows the struct value passing
+// ABI rule, which may differ from how C complex are passed on some platforms.
 template <>
 constexpr TypeBuilderFunc getModel<std::complex<float>>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::ComplexType::get(mlir::FloatType::getF32(context));
+    mlir::Type floatTy = getModel<float>()(context);
+    return mlir::TupleType::get(context, {floatTy, floatTy});
   };
 }
+template <>
+constexpr TypeBuilderFunc getModel<std::complex<double>>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    mlir::Type floatTy = getModel<double>()(context);
+    return mlir::TupleType::get(context, {floatTy, floatTy});
+  };
+}
+
 template <>
 constexpr TypeBuilderFunc getModel<std::complex<float> &>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
-    TypeBuilderFunc f{getModel<std::complex<float>>()};
-    return fir::ReferenceType::get(f(context));
+    mlir::Type floatTy = getModel<float>()(context);
+    return fir::ReferenceType::get(mlir::ComplexType::get(floatTy));
   };
 }
 template <>
@@ -422,16 +438,10 @@ constexpr TypeBuilderFunc getModel<const std::complex<float> *>() {
   return getModel<std::complex<float> *>();
 }
 template <>
-constexpr TypeBuilderFunc getModel<std::complex<double>>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    return mlir::ComplexType::get(mlir::FloatType::getF64(context));
-  };
-}
-template <>
 constexpr TypeBuilderFunc getModel<std::complex<double> &>() {
   return [](mlir::MLIRContext *context) -> mlir::Type {
-    TypeBuilderFunc f{getModel<std::complex<double>>()};
-    return fir::ReferenceType::get(f(context));
+    mlir::Type floatTy = getModel<double>()(context);
+    return fir::ReferenceType::get(mlir::ComplexType::get(floatTy));
   };
 }
 template <>
diff --git a/flang/unittests/Optimizer/RTBuilder.cpp b/flang/unittests/Optimizer/RTBuilder.cpp
index d6cf96c4351c2b..47c12cfcbc4a08 100644
--- a/flang/unittests/Optimizer/RTBuilder.cpp
+++ b/flang/unittests/Optimizer/RTBuilder.cpp
@@ -9,6 +9,7 @@
 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
 #include "gtest/gtest.h"
 #include "flang/Optimizer/Support/InitFIR.h"
+#include <complex>
 
 // Check that it is possible to make a difference between complex runtime
 // function using C99 complex and C++ std::complex. This is important since
@@ -18,6 +19,7 @@
 
 // Fake runtime header to be introspected.
 c_float_complex_t c99_cacosf(c_float_complex_t);
+std::complex<float> cpp_runtime(std::complex<float>);
 
 TEST(RTBuilderTest, ComplexRuntimeInterface) {
   mlir::DialectRegistry registry;
@@ -34,3 +36,21 @@ TEST(RTBuilderTest, ComplexRuntimeInterface) {
   EXPECT_EQ(c99_cacosf_funcTy.getInput(0), cplx_ty);
   EXPECT_EQ(c99_cacosf_funcTy.getResult(0), cplx_ty);
 }
+
+TEST(RTBuilderTest, CppComplexRuntimeInterface) {
+  mlir::DialectRegistry registry;
+  fir::support::registerDialects(registry);
+  mlir::MLIRContext ctx(registry);
+  fir::support::loadDialects(ctx);
+  mlir::Type cpp_runtime_signature{
+      fir::runtime::RuntimeTableKey<decltype(cpp_runtime)>::getTypeModel()(
+          &ctx)};
+  auto cpp_runtime_funcTy =
+      mlir::cast<mlir::FunctionType>(cpp_runtime_signature);
+  EXPECT_EQ(cpp_runtime_funcTy.getNumInputs(), 1u);
+  EXPECT_EQ(cpp_runtime_funcTy.getNumResults(), 1u);
+  auto fp = mlir::FloatType::getF32(&ctx);
+  auto cplx_ty = mlir::TupleType::get(&ctx, {fp, fp});
+  EXPECT_EQ(cpp_runtime_funcTy.getInput(0), cplx_ty);
+  EXPECT_EQ(cpp_runtime_funcTy.getResult(0), cplx_ty);
+}

>From ccd01b86b76144d79f045a68d856b6096b6ed4e5 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 1 Oct 2024 08:42:48 -0700
Subject: [PATCH 2/2] remove std::complex value handling

---
 .../Optimizer/Builder/Runtime/RTBuilder.h     | 59 +++++++++++++------
 flang/unittests/Optimizer/RTBuilder.cpp       | 19 ------
 2 files changed, 40 insertions(+), 38 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
index e7facb7c280e97..6acabf53853911 100644
--- a/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h
@@ -401,26 +401,12 @@ constexpr TypeBuilderFunc getModel<bool &>() {
   };
 }
 
-// Note about getModel<std::complex<T>>
+// getModel<std::complex<T>> are not implemented on purpose.
 // Prefer passing/returning the complex by reference in the runtime to
 // avoid ABI issues.
 // C++ std::complex is not an intrinsic type, and it while it is storage
 // compatible with C/Fortran complex type, it follows the struct value passing
 // ABI rule, which may differ from how C complex are passed on some platforms.
-template <>
-constexpr TypeBuilderFunc getModel<std::complex<float>>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    mlir::Type floatTy = getModel<float>()(context);
-    return mlir::TupleType::get(context, {floatTy, floatTy});
-  };
-}
-template <>
-constexpr TypeBuilderFunc getModel<std::complex<double>>() {
-  return [](mlir::MLIRContext *context) -> mlir::Type {
-    mlir::Type floatTy = getModel<double>()(context);
-    return mlir::TupleType::get(context, {floatTy, floatTy});
-  };
-}
 
 template <>
 constexpr TypeBuilderFunc getModel<std::complex<float> &>() {
@@ -531,10 +517,45 @@ REDUCTION_VALUE_OPERATION_MODEL(double)
 REDUCTION_REF_OPERATION_MODEL(long double)
 REDUCTION_VALUE_OPERATION_MODEL(long double)
 
-REDUCTION_REF_OPERATION_MODEL(std::complex<float>)
-REDUCTION_VALUE_OPERATION_MODEL(std::complex<float>)
-REDUCTION_REF_OPERATION_MODEL(std::complex<double>)
-REDUCTION_VALUE_OPERATION_MODEL(std::complex<double>)
+// FIXME: the runtime is not using the correct ABIs when calling complex
+// callbacks. Lowering either needs to create wrappers or just have an inline
+// implementation for it. https://github.com/llvm/llvm-project/issues/110674
+template <>
+constexpr TypeBuilderFunc
+getModel<Fortran::runtime::ValueReductionOperation<std::complex<float>>>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    mlir::Type cplx = mlir::ComplexType::get(getModel<float>()(context));
+    auto refTy = fir::ReferenceType::get(cplx);
+    return mlir::FunctionType::get(context, {cplx, cplx}, refTy);
+  };
+}
+template <>
+constexpr TypeBuilderFunc
+getModel<Fortran::runtime::ValueReductionOperation<std::complex<double>>>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    mlir::Type cplx = mlir::ComplexType::get(getModel<double>()(context));
+    auto refTy = fir::ReferenceType::get(cplx);
+    return mlir::FunctionType::get(context, {cplx, cplx}, refTy);
+  };
+}
+template <>
+constexpr TypeBuilderFunc
+getModel<Fortran::runtime::ReferenceReductionOperation<std::complex<float>>>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    mlir::Type cplx = mlir::ComplexType::get(getModel<float>()(context));
+    auto refTy = fir::ReferenceType::get(cplx);
+    return mlir::FunctionType::get(context, {refTy, refTy}, refTy);
+  };
+}
+template <>
+constexpr TypeBuilderFunc getModel<
+    Fortran::runtime::ReferenceReductionOperation<std::complex<double>>>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    mlir::Type cplx = mlir::ComplexType::get(getModel<double>()(context));
+    auto refTy = fir::ReferenceType::get(cplx);
+    return mlir::FunctionType::get(context, {refTy, refTy}, refTy);
+  };
+}
 
 REDUCTION_CHAR_OPERATION_MODEL(char)
 REDUCTION_CHAR_OPERATION_MODEL(char16_t)
diff --git a/flang/unittests/Optimizer/RTBuilder.cpp b/flang/unittests/Optimizer/RTBuilder.cpp
index 47c12cfcbc4a08..2ccc4353f9ee40 100644
--- a/flang/unittests/Optimizer/RTBuilder.cpp
+++ b/flang/unittests/Optimizer/RTBuilder.cpp
@@ -19,7 +19,6 @@
 
 // Fake runtime header to be introspected.
 c_float_complex_t c99_cacosf(c_float_complex_t);
-std::complex<float> cpp_runtime(std::complex<float>);
 
 TEST(RTBuilderTest, ComplexRuntimeInterface) {
   mlir::DialectRegistry registry;
@@ -36,21 +35,3 @@ TEST(RTBuilderTest, ComplexRuntimeInterface) {
   EXPECT_EQ(c99_cacosf_funcTy.getInput(0), cplx_ty);
   EXPECT_EQ(c99_cacosf_funcTy.getResult(0), cplx_ty);
 }
-
-TEST(RTBuilderTest, CppComplexRuntimeInterface) {
-  mlir::DialectRegistry registry;
-  fir::support::registerDialects(registry);
-  mlir::MLIRContext ctx(registry);
-  fir::support::loadDialects(ctx);
-  mlir::Type cpp_runtime_signature{
-      fir::runtime::RuntimeTableKey<decltype(cpp_runtime)>::getTypeModel()(
-          &ctx)};
-  auto cpp_runtime_funcTy =
-      mlir::cast<mlir::FunctionType>(cpp_runtime_signature);
-  EXPECT_EQ(cpp_runtime_funcTy.getNumInputs(), 1u);
-  EXPECT_EQ(cpp_runtime_funcTy.getNumResults(), 1u);
-  auto fp = mlir::FloatType::getF32(&ctx);
-  auto cplx_ty = mlir::TupleType::get(&ctx, {fp, fp});
-  EXPECT_EQ(cpp_runtime_funcTy.getInput(0), cplx_ty);
-  EXPECT_EQ(cpp_runtime_funcTy.getResult(0), cplx_ty);
-}



More information about the flang-commits mailing list