[flang-commits] [flang] 221f438 - [flang][OpenMP] Add support for complex reductions (#87488)

via flang-commits flang-commits at lists.llvm.org
Mon Apr 8 02:18:18 PDT 2024


Author: Mats Petersson
Date: 2024-04-08T10:18:14+01:00
New Revision: 221f438af1c1292d787b58da99a5a7b371888456

URL: https://github.com/llvm/llvm-project/commit/221f438af1c1292d787b58da99a5a7b371888456
DIFF: https://github.com/llvm/llvm-project/commit/221f438af1c1292d787b58da99a5a7b371888456.diff

LOG: [flang][OpenMP] Add support for complex reductions (#87488)

This adds support for the complex type to OpenMP reductions.

Note that more work is needed to give decent error messages when complex
is used in ways that need client-supplied functions (e.g. MAX or MIN). At present
such uses fail with a message that is not very user-friendly.

Added: 
    flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
    flang/test/Lower/OpenMP/parallel-reduction-complex.f90

Modified: 
    flang/lib/Lower/OpenMP/ReductionProcessor.cpp
    flang/lib/Lower/OpenMP/ReductionProcessor.h

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c1c94119fd9083..0453c01522779b 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -13,7 +13,9 @@
 #include "ReductionProcessor.h"
 
 #include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
 #include "flang/Lower/SymbolMap.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
@@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                           fir::FirOpBuilder &builder) {
   type = fir::unwrapRefType(type);
   if (!fir::isa_integer(type) && !fir::isa_real(type) &&
-      !mlir::isa<fir::LogicalType>(type))
+      !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
     TODO(loc, "Reduction of some types is not supported");
   switch (redId) {
   case ReductionIdentifier::MAX: {
@@ -175,6 +177,16 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
   case ReductionIdentifier::OR:
   case ReductionIdentifier::EQV:
   case ReductionIdentifier::NEQV:
+    if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
+      mlir::Type realTy =
+          Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
+      mlir::Value initRe = builder.createRealConstant(
+          loc, realTy, getOperationIdentity(redId, loc));
+      mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
+
+      return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
+                                                               initIm);
+    }
     if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
@@ -229,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
     break;
   case ReductionIdentifier::ADD:
     reductionOp =
-        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
-            builder, type, loc, op1, op2);
+        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
+                              fir::AddcOp>(builder, type, loc, op1, op2);
     break;
   case ReductionIdentifier::MULTIPLY:
     reductionOp =
-        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
-            builder, type, loc, op1, op2);
+        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
+                              fir::MulcOp>(builder, type, loc, op1, op2);
     break;
   case ReductionIdentifier::AND: {
     mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);

diff  --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index ee2732547fc288..7ea252fde3602e 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -97,6 +97,10 @@ class ReductionProcessor {
                                            fir::FirOpBuilder &builder);
 
   template <typename FloatOp, typename IntegerOp>
+  static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+                                           mlir::Type type, mlir::Location loc,
+                                           mlir::Value op1, mlir::Value op2);
+  template <typename FloatOp, typename IntegerOp, typename ComplexOp>
   static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
                                            mlir::Type type, mlir::Location loc,
                                            mlir::Value op1, mlir::Value op2);
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
                                           mlir::Value op1, mlir::Value op2) {
   type = fir::unwrapRefType(type);
   assert(type.isIntOrIndexOrFloat() &&
-         "only integer and float types are currently supported");
+         "only integer, float and complex types are currently supported");
   if (type.isIntOrIndex())
     return builder.create<IntegerOp>(loc, op1, op2);
   return builder.create<FloatOp>(loc, op1, op2);
 }
 
+// Emits IntegerOp, FloatOp or ComplexOp on (op1, op2) according to `type`.
+template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+mlir::Value
+ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
+                                          mlir::Type type, mlir::Location loc,
+                                          mlir::Value op1, mlir::Value op2) {
+  assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
+         "only integer, float and complex types are currently supported");
+  if (type.isIntOrIndex())
+    return builder.create<IntegerOp>(loc, op1, op2);
+  if (fir::isa_real(type))
+    return builder.create<FloatOp>(loc, op1, op2);
+  return builder.create<ComplexOp>(loc, op1, op2);
+}
+
 } // namespace omp
 } // namespace lower
 } // namespace Fortran

diff  --git a/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
new file mode 100644
index 00000000000000..376defb8235814
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK:  %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK:  %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK:  %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK:  omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK:  %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK:  omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_mul
+!CHECK:  %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK:  %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK:  %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK:  %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK:  %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK:  hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK:  omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK:    %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK:    %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK:    %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64
+!CHECK:    %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK:    %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK:    %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK:    %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK:    hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_complex_mul
+    complex(8) :: c  ! reduction variable (checked above as !fir.complex<8>)
+    c = 0  ! NOTE(review): with a 0 start the product stays 0; only the IR is checked
+
+    !$omp parallel reduction(*:c)
+    c = c * cmplx(1, -2)  ! each thread multiplies its private copy by (1,-2)
+    !$omp end parallel
+
+    print *, c
+end subroutine

diff  --git a/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
new file mode 100644
index 00000000000000..bc5a6b475e2569
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK:  %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK:  %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK:  omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK:  %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK:  omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_add
+!CHECK:  %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK:  %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK:  %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:  %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK:  %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK:  %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK:  hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK:  omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK:    %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK:    %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK:    %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK:    %[[C_INCR_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK:    %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK:    %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK:    %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK:    %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK:    hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK:    omp.terminator
+!CHECK:  }
+!CHECK: return
+subroutine simple_complex_add
+    complex(8) :: c  ! reduction variable (checked above as !fir.complex<8>)
+    c = 0
+
+    !$omp parallel reduction(+:c)
+    c = c + 1  ! each thread adds 1 to its private copy; lowered to fir.addc
+    !$omp end parallel
+
+    print *, c
+end subroutine


        


More information about the flang-commits mailing list