[flang-commits] [flang] 221f438 - [flang][OpenMP] Add support for complex reductions (#87488)
via flang-commits
flang-commits at lists.llvm.org
Mon Apr 8 02:18:18 PDT 2024
Author: Mats Petersson
Date: 2024-04-08T10:18:14+01:00
New Revision: 221f438af1c1292d787b58da99a5a7b371888456
URL: https://github.com/llvm/llvm-project/commit/221f438af1c1292d787b58da99a5a7b371888456
DIFF: https://github.com/llvm/llvm-project/commit/221f438af1c1292d787b58da99a5a7b371888456.diff
LOG: [flang][OpenMP] Add support for complex reductions (#87488)
This adds support for the complex type to the OpenMP reductions.
Note that some more work is needed to give decent error messages when complex
is used in ways that need client-supplied functions (e.g. MAX or MIN). At present, these
cases fail with a message that is not very user-friendly.
Added:
flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
flang/test/Lower/OpenMP/parallel-reduction-complex.f90
Modified:
flang/lib/Lower/OpenMP/ReductionProcessor.cpp
flang/lib/Lower/OpenMP/ReductionProcessor.h
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c1c94119fd9083..0453c01522779b 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -13,7 +13,9 @@
#include "ReductionProcessor.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
#include "flang/Lower/SymbolMap.h"
+#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
fir::FirOpBuilder &builder) {
type = fir::unwrapRefType(type);
if (!fir::isa_integer(type) && !fir::isa_real(type) &&
- !mlir::isa<fir::LogicalType>(type))
+ !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
TODO(loc, "Reduction of some types is not supported");
switch (redId) {
case ReductionIdentifier::MAX: {
@@ -175,6 +177,16 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
case ReductionIdentifier::OR:
case ReductionIdentifier::EQV:
case ReductionIdentifier::NEQV:
+ if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
+ mlir::Type realTy =
+ Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
+ mlir::Value initRe = builder.createRealConstant(
+ loc, realTy, getOperationIdentity(redId, loc));
+ mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
+
+ return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
+ initIm);
+ }
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
@@ -229,13 +241,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
break;
case ReductionIdentifier::ADD:
reductionOp =
- getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
+ fir::AddcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::MULTIPLY:
reductionOp =
- getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
+ fir::MulcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::AND: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index ee2732547fc288..7ea252fde3602e 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -97,6 +97,10 @@ class ReductionProcessor {
fir::FirOpBuilder &builder);
template <typename FloatOp, typename IntegerOp>
+ static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2);
+ template <typename FloatOp, typename IntegerOp, typename ComplexOp>
static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
mlir::Value op1, mlir::Value op2);
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
mlir::Value op1, mlir::Value op2) {
type = fir::unwrapRefType(type);
assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
+ "only integer, float and complex types are currently supported");
if (type.isIntOrIndex())
return builder.create<IntegerOp>(loc, op1, op2);
return builder.create<FloatOp>(loc, op1, op2);
}
+/// Emit IntegerOp, FloatOp or ComplexOp on (op1, op2), selected by \p type.
+template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+mlir::Value ReductionProcessor::getReductionOperation(
+    fir::FirOpBuilder &builder, mlir::Type type, mlir::Location loc,
+    mlir::Value op1, mlir::Value op2) {
+  // Parenthesized so the message string is and-ed with the whole type check.
+  assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
+         "only integer, float and complex types are currently supported");
+  if (type.isIntOrIndex())
+    return builder.create<IntegerOp>(loc, op1, op2);
+  if (fir::isa_real(type))
+    return builder.create<FloatOp>(loc, op1, op2);
+  return builder.create<ComplexOp>(loc, op1, op2);
+}
+
} // namespace omp
} // namespace lower
} // namespace Fortran
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
new file mode 100644
index 00000000000000..376defb8235814
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK: %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK: %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_mul
+!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64
+!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK: %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_complex_mul
+ complex(8) :: c
+ c = 0
+
+ !$omp parallel reduction(*:c)
+ c = c * cmplx(1, -2)
+ !$omp end parallel
+
+ print *, c
+end subroutine
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
new file mode 100644
index 00000000000000..bc5a6b475e2569
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK: %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_add
+!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C_INCR_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK: %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_complex_add
+ complex(8) :: c
+ c = 0
+
+ !$omp parallel reduction(+:c)
+ c = c + 1
+ !$omp end parallel
+
+ print *, c
+end subroutine
More information about the flang-commits
mailing list