[flang-commits] [flang] [flang][OpenMP] Add support for complex reductions (PR #87488)

Mats Petersson via flang-commits flang-commits at lists.llvm.org
Fri Apr 5 03:33:01 PDT 2024


https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/87488

>From 91356512f3c18ee6c5c44ae2f91ae10e7a6c3dce Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 3 Apr 2024 13:43:30 +0100
Subject: [PATCH 1/3] [FLANG] Add support for complex OpenMP reductions

The SALMON application uses OpenMP reductions on complex values,
which were not previously supported in Flang. This adds basic
support for this functionality.
---
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 23 +++++++++++++++----
 flang/lib/Lower/OpenMP/ReductionProcessor.h   | 21 ++++++++++++++++-
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c1c94119fd9083..f06209c0e62032 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -13,7 +13,9 @@
 #include "ReductionProcessor.h"
 
 #include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
 #include "flang/Lower/SymbolMap.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
@@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                           fir::FirOpBuilder &builder) {
   type = fir::unwrapRefType(type);
   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
-      !mlir::isa<fir::LogicalType>(type))
+      !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
     TODO(loc, "Reduction of some types is not supported");
   switch (redId) {
   case ReductionIdentifier::MAX: {
@@ -175,6 +177,17 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
   case ReductionIdentifier::OR:
   case ReductionIdentifier::EQV:
   case ReductionIdentifier::NEQV:
+    if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
+      mlir::Type realTy =
+          Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
+      //      mlir::FloatType realTy =
+      //      mlir::dyn_cast<mlir::FloatType>(cplxTy.getElementType());
+      //      const llvm::fltSemantics &sem = (realTy).getFloatSemantics();
+      mlir::Value init = builder.createRealConstant(
+          loc, realTy, getOperationIdentity(redId, loc));
+      return fir::factory::Complex{builder, loc}.createComplex(type, init,
+                                                               init);
+    }
     if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
@@ -229,13 +242,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
     break;
   case ReductionIdentifier::ADD:
     reductionOp =
-        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
-            builder, type, loc, op1, op2);
+        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
+                              fir::AddcOp>(builder, type, loc, op1, op2);
     break;
   case ReductionIdentifier::MULTIPLY:
     reductionOp =
-        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
-            builder, type, loc, op1, op2);
+        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
+                              fir::MulcOp>(builder, type, loc, op1, op2);
     break;
   case ReductionIdentifier::AND: {
     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index ee2732547fc288..7ea252fde3602e 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -97,6 +97,10 @@ class ReductionProcessor {
                                            fir::FirOpBuilder &builder);
 
   template <typename FloatOp, typename IntegerOp>
+  static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+                                           mlir::Type type, mlir::Location loc,
+                                           mlir::Value op1, mlir::Value op2);
+  template <typename FloatOp, typename IntegerOp, typename ComplexOp>
   static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
                                            mlir::Type type, mlir::Location loc,
                                            mlir::Value op1, mlir::Value op2);
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
                                           mlir::Value op1, mlir::Value op2) {
   type = fir::unwrapRefType(type);
   assert(type.isIntOrIndexOrFloat() &&
-         "only integer and float types are currently supported");
+         "only integer, float and complex types are currently supported");
   if (type.isIntOrIndex())
     return builder.create<IntegerOp>(loc, op1, op2);
   return builder.create<FloatOp>(loc, op1, op2);
 }
 
+template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+mlir::Value
+ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
+                                          mlir::Type type, mlir::Location loc,
+                                          mlir::Value op1, mlir::Value op2) {
+  assert(type.isIntOrIndexOrFloat() ||
+         fir::isa_complex(type) &&
+             "only integer, float and complex types are currently supported");
+  if (type.isIntOrIndex())
+    return builder.create<IntegerOp>(loc, op1, op2);
+  if (fir::isa_real(type))
+    return builder.create<FloatOp>(loc, op1, op2);
+  return builder.create<ComplexOp>(loc, op1, op2);
+}
+
 } // namespace omp
 } // namespace lower
 } // namespace Fortran

>From 65f197c9bcb844bfab5d0cd462cad586e6443135 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 4 Apr 2024 13:56:29 +0100
Subject: [PATCH 2/3] Fix init value

---
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 11 ++--
 .../OpenMP/parallel-reduction-complex.f90     | 50 +++++++++++++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/parallel-reduction-complex.f90

diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index f06209c0e62032..0453c01522779b 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -180,13 +180,12 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
     if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
       mlir::Type realTy =
           Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
-      //      mlir::FloatType realTy =
-      //      mlir::dyn_cast<mlir::FloatType>(cplxTy.getElementType());
-      //      const llvm::fltSemantics &sem = (realTy).getFloatSemantics();
-      mlir::Value init = builder.createRealConstant(
+      mlir::Value initRe = builder.createRealConstant(
           loc, realTy, getOperationIdentity(redId, loc));
-      return fir::factory::Complex{builder, loc}.createComplex(type, init,
-                                                               init);
+      mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
+
+      return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
+                                                               initIm);
     }
     if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
new file mode 100644
index 00000000000000..bc5a6b475e2569
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK:  %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK:  %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK:  omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK:  %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK:  omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_add
+!CHECK:  %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK:  %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK:  %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK:  %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK:  %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK:  hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK:  omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK:    %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK:    %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK:    %[[C_INCR_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:    %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK:    %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK:    %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK:    %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK:    hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_complex_add
+    complex(8) :: c
+    c = 0
+
+    !$omp parallel reduction(+:c)
+    c = c + 1
+    !$omp end parallel
+
+    print *, c
+end subroutine

>From a659f44ffa3f629a2251b5312c1df4405eb1ad40 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 4 Apr 2024 14:19:24 +0100
Subject: [PATCH 3/3] Add tests for complex reduction

---
 .../OpenMP/parallel-reduction-complex-mul.f90 | 50 +++++++++++++++++++
 1 file changed, 50 insertions(+)
 create mode 100644 flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90

diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
new file mode 100644
index 00000000000000..376defb8235814
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK:  %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK:  %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK:  %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK:  omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK:  %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK:  omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_mul
+!CHECK:  %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK:  %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK:  %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK:  %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK:  %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK:  hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK:  omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK:    %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK:    %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK:    %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64
+!CHECK:    %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK:    %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK:    %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK:    %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK:    hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_complex_mul
+    complex(8) :: c
+    c = 0
+
+    !$omp parallel reduction(*:c)
+    c = c * cmplx(1, -2)
+    !$omp end parallel
+
+    print *, c
+end subroutine



More information about the flang-commits mailing list