[flang-commits] [flang] [flang][OpenMP] Add support for complex reductions (PR #87488)
Mats Petersson via flang-commits
flang-commits at lists.llvm.org
Fri Apr 5 03:33:01 PDT 2024
https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/87488
>From 91356512f3c18ee6c5c44ae2f91ae10e7a6c3dce Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 3 Apr 2024 13:43:30 +0100
Subject: [PATCH 1/3] [FLANG] Add support for complex OpenMP reductions
The SALMON application uses OpenMP reductions on complex values,
which were not supported in Flang. This adds basic support
for this functionality.
---
flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 23 +++++++++++++++----
flang/lib/Lower/OpenMP/ReductionProcessor.h | 21 ++++++++++++++++-
2 files changed, 38 insertions(+), 6 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c1c94119fd9083..f06209c0e62032 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -13,7 +13,9 @@
#include "ReductionProcessor.h"
#include "flang/Lower/AbstractConverter.h"
+#include "flang/Lower/ConvertType.h"
#include "flang/Lower/SymbolMap.h"
+#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -131,7 +133,7 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
fir::FirOpBuilder &builder) {
type = fir::unwrapRefType(type);
if (!fir::isa_integer(type) && !fir::isa_real(type) &&
- !mlir::isa<fir::LogicalType>(type))
+ !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
TODO(loc, "Reduction of some types is not supported");
switch (redId) {
case ReductionIdentifier::MAX: {
@@ -175,6 +177,17 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
case ReductionIdentifier::OR:
case ReductionIdentifier::EQV:
case ReductionIdentifier::NEQV:
+ if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
+ mlir::Type realTy =
+ Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
+ // mlir::FloatType realTy =
+ // mlir::dyn_cast<mlir::FloatType>(cplxTy.getElementType());
+ // const llvm::fltSemantics &sem = (realTy).getFloatSemantics();
+ mlir::Value init = builder.createRealConstant(
+ loc, realTy, getOperationIdentity(redId, loc));
+ return fir::factory::Complex{builder, loc}.createComplex(type, init,
+ init);
+ }
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
@@ -229,13 +242,13 @@ mlir::Value ReductionProcessor::createScalarCombiner(
break;
case ReductionIdentifier::ADD:
reductionOp =
- getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
+ fir::AddcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::MULTIPLY:
reductionOp =
- getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
- builder, type, loc, op1, op2);
+ getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
+ fir::MulcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::AND: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index ee2732547fc288..7ea252fde3602e 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -97,6 +97,10 @@ class ReductionProcessor {
fir::FirOpBuilder &builder);
template <typename FloatOp, typename IntegerOp>
+ static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+ mlir::Type type, mlir::Location loc,
+ mlir::Value op1, mlir::Value op2);
+ template <typename FloatOp, typename IntegerOp, typename ComplexOp>
static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
mlir::Value op1, mlir::Value op2);
@@ -136,12 +140,27 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
mlir::Value op1, mlir::Value op2) {
type = fir::unwrapRefType(type);
assert(type.isIntOrIndexOrFloat() &&
- "only integer and float types are currently supported");
+ "only integer, float and complex types are currently supported");
if (type.isIntOrIndex())
return builder.create<IntegerOp>(loc, op1, op2);
return builder.create<FloatOp>(loc, op1, op2);
}
+/// Combine op1 and op2 with IntegerOp, FloatOp or ComplexOp, chosen by type.
+template <typename FloatOp, typename IntegerOp, typename ComplexOp>
+mlir::Value ReductionProcessor::getReductionOperation(
+    fir::FirOpBuilder &builder, mlir::Type type, mlir::Location loc,
+    mlir::Value op1, mlir::Value op2) {
+  type = fir::unwrapRefType(type); // mirrors the FloatOp/IntegerOp overload
+  assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
+         "only integer, float and complex types are currently supported");
+  if (type.isIntOrIndex())
+    return builder.create<IntegerOp>(loc, op1, op2);
+  if (fir::isa_real(type))
+    return builder.create<FloatOp>(loc, op1, op2);
+  return builder.create<ComplexOp>(loc, op1, op2);
+}
+
} // namespace omp
} // namespace lower
} // namespace Fortran
>From 65f197c9bcb844bfab5d0cd462cad586e6443135 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 4 Apr 2024 13:56:29 +0100
Subject: [PATCH 2/3] Fix init value
---
flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 11 ++--
.../OpenMP/parallel-reduction-complex.f90 | 50 +++++++++++++++++++
2 files changed, 55 insertions(+), 6 deletions(-)
create mode 100644 flang/test/Lower/OpenMP/parallel-reduction-complex.f90
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index f06209c0e62032..0453c01522779b 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -180,13 +180,12 @@ ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
mlir::Type realTy =
Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
- // mlir::FloatType realTy =
- // mlir::dyn_cast<mlir::FloatType>(cplxTy.getElementType());
- // const llvm::fltSemantics &sem = (realTy).getFloatSemantics();
- mlir::Value init = builder.createRealConstant(
+ mlir::Value initRe = builder.createRealConstant(
loc, realTy, getOperationIdentity(redId, loc));
- return fir::factory::Complex{builder, loc}.createComplex(type, init,
- init);
+ mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
+
+ return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
+ initIm);
}
if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
new file mode 100644
index 00000000000000..bc5a6b475e2569
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK: %[[C0_1:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK: %[[RES:.*]] = fir.addc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_add
+!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_addEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C_INCR_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK: %[[RES:.+]] = fir.addc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_complex_add
+ complex(8) :: c
+ c = 0
+ ! Add 1 to the complex(8) variable c under an OpenMP '+' reduction.
+ !$omp parallel reduction(+:c)
+ c = c + 1
+ !$omp end parallel
+ ! Use the reduced value after the parallel region.
+ print *, c
+end subroutine
>From a659f44ffa3f629a2251b5312c1df4405eb1ad40 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 4 Apr 2024 14:19:24 +0100
Subject: [PATCH 3/3] Add tests for complex reduction
---
.../OpenMP/parallel-reduction-complex-mul.f90 | 50 +++++++++++++++++++
1 file changed, 50 insertions(+)
create mode 100644 flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
diff --git a/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90 b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
new file mode 100644
index 00000000000000..376defb8235814
--- /dev/null
+++ b/flang/test/Lower/OpenMP/parallel-reduction-complex-mul.f90
@@ -0,0 +1,50 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: omp.declare_reduction
+!CHECK-SAME: @[[RED_NAME:.*]] : !fir.complex<8> init {
+!CHECK: ^bb0(%{{.*}}: !fir.complex<8>):
+!CHECK: %[[C0_1:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C0_2:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[RES_1:.*]] = fir.insert_value %[[UNDEF]], %[[C0_1]], [0 : index]
+!CHECK: %[[RES_2:.*]] = fir.insert_value %[[RES_1]], %[[C0_2]], [1 : index]
+!CHECK: omp.yield(%[[RES_2]] : !fir.complex<8>)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<8>, %[[ARG1:.*]]: !fir.complex<8>):
+!CHECK: %[[RES:.*]] = fir.mulc %[[ARG0]], %[[ARG1]] {{.*}}: !fir.complex<8>
+!CHECK: omp.yield(%[[RES]] : !fir.complex<8>)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_complex_mul
+!CHECK: %[[CREF:.*]] = fir.alloca !fir.complex<8> {bindc_name = "c", {{.*}}}
+!CHECK: %[[C_DECL:.*]]:2 = hlfir.declare %[[CREF]] {uniq_name = "_QFsimple_complex_mulEc"} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[C_START_RE:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[C_START_IM:.*]] = arith.constant 0.000000e+00 : f64
+!CHECK: %[[UNDEF_1:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[VAL_1:.*]] = fir.insert_value %[[UNDEF_1]], %[[C_START_RE]], [0 : index]
+!CHECK: %[[VAL_2:.*]] = fir.insert_value %[[VAL_1]], %[[C_START_IM]], [1 : index]
+!CHECK: hlfir.assign %[[VAL_2]] to %[[C_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.parallel reduction(@[[RED_NAME]] %[[C_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<!fir.complex<8>>) {
+!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<!fir.complex<8>>) -> (!fir.ref<!fir.complex<8>>, !fir.ref<!fir.complex<8>>)
+!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<!fir.complex<8>>
+!CHECK: %[[C_INCR_RE:.*]] = arith.constant 1.000000e+00 : f64
+!CHECK: %[[C_INCR_IM:.*]] = arith.constant -2.000000e+00 : f64
+!CHECK: %[[UNDEF_2:.*]] = fir.undefined !fir.complex<8>
+!CHECK: %[[INCR_1:.*]] = fir.insert_value %[[UNDEF_2]], %[[C_INCR_RE]], [0 : index]
+!CHECK: %[[INCR_2:.*]] = fir.insert_value %[[INCR_1]], %[[C_INCR_IM]], [1 : index]
+!CHECK: %[[RES:.+]] = fir.mulc %[[LPRV]], %[[INCR_2]] {{.*}} : !fir.complex<8>
+!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : !fir.complex<8>, !fir.ref<!fir.complex<8>>
+!CHECK: omp.terminator
+!CHECK: }
+!CHECK: return
+subroutine simple_complex_mul
+ complex(8) :: c
+ c = 0
+ ! Multiply the complex(8) variable c by (1,-2) under an OpenMP '*' reduction.
+ !$omp parallel reduction(*:c)
+ c = c * cmplx(1, -2)
+ !$omp end parallel
+ ! NOTE(review): c starts at 0, so the printed product is 0 if this is ever run.
+ print *, c
+end subroutine
More information about the flang-commits
mailing list