[flang-commits] [flang] 119c512 - [flang][openacc] Add support for complex add reduction

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Wed Jul 12 14:28:28 PDT 2023


Author: Valentin Clement
Date: 2023-07-12T14:28:22-07:00
New Revision: 119c512cb26030d58618800075428d748cf90948

URL: https://github.com/llvm/llvm-project/commit/119c512cb26030d58618800075428d748cf90948
DIFF: https://github.com/llvm/llvm-project/commit/119c512cb26030d58618800075428d748cf90948.diff

LOG: [flang][openacc] Add support for complex add reduction

Add lowering support for reductions with the add operator
on complex types.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D155007

Added: 
    

Modified: 
    flang/lib/Lower/OpenACC.cpp
    flang/test/Lower/OpenACC/acc-reduction.f90

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e30ff193ed2ec8..802bbb8c284e50 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -13,10 +13,12 @@
 #include "flang/Lower/OpenACC.h"
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
+#include "flang/Lower/ConvertType.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/Support/Utils.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/Complex.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/IntrinsicCall.h"
 #include "flang/Optimizer/Builder/Todo.h"
@@ -712,11 +714,17 @@ static mlir::Value genReductionInitValue(fir::FirOpBuilder &builder,
           loc, ty,
           builder.getFloatAttr(ty,
                                getReductionInitValue<llvm::APFloat>(op, ty)));
-  } else {
-    if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty))
-      return builder.create<mlir::arith::ConstantOp>(
-          loc, ty,
-          builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
+  } else if (auto floatTy = mlir::dyn_cast_or_null<mlir::FloatType>(ty)) {
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, ty,
+        builder.getFloatAttr(ty, getReductionInitValue<int64_t>(op, ty)));
+  } else if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty)) {
+    mlir::Type floatTy =
+        Fortran::lower::convertReal(builder.getContext(), cmplxTy.getFKind());
+    mlir::Value init = builder.createRealConstant(
+        loc, floatTy, getReductionInitValue<int64_t>(op, cmplxTy));
+    return fir::factory::Complex{builder, loc}.createComplex(cmplxTy.getFKind(),
+                                                             init, init);
   }
   if (auto refTy = mlir::dyn_cast<fir::ReferenceType>(ty)) {
     if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(refTy.getEleTy())) {
@@ -738,7 +746,7 @@ static mlir::Value genReductionInitValue(fir::FirOpBuilder &builder,
     }
   }
 
-  TODO(loc, "reduction type");
+  llvm::report_fatal_error("Unsupported OpenACC reduction type");
 }
 
 template <typename Op>
@@ -808,6 +816,8 @@ static mlir::Value genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
       return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
     if (mlir::isa<mlir::FloatType>(ty))
       return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
+    if (auto cmplxTy = mlir::dyn_cast_or_null<fir::ComplexType>(ty))
+      return builder.create<fir::AddcOp>(loc, value1, value2);
     TODO(loc, "reduction add type");
   }
 

diff  --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90
index b88777c79249d3..9e6fa5c01a6a57 100644
--- a/flang/test/Lower/OpenACC/acc-reduction.f90
+++ b/flang/test/Lower/OpenACC/acc-reduction.f90
@@ -2,6 +2,19 @@
 
 ! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
 
+! CHECK-LABEL: acc.reduction.recipe @reduction_add_z32 : !fir.complex<4> reduction_operator <add> init {
+! CHECK: ^bb0(%{{.*}}: !fir.complex<4>):
+! CHECK:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+! CHECK:   %[[UNDEF:.*]] = fir.undefined !fir.complex<4>
+! CHECK:   %[[UNDEF1:.*]] = fir.insert_value %[[UNDEF]], %[[CST]], [0 : index] : (!fir.complex<4>, f32) -> !fir.complex<4>
+! CHECK:   %[[UNDEF2:.*]] = fir.insert_value %[[UNDEF1]], %[[CST]], [1 : index] : (!fir.complex<4>, f32) -> !fir.complex<4>
+! CHECK:   acc.yield %[[UNDEF2]] : !fir.complex<4>
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0:.*]]: !fir.complex<4>, %[[ARG1:.*]]: !fir.complex<4>):
+! CHECK:   %[[COMBINED:.*]] = fir.addc %[[ARG0]], %[[ARG1]] : !fir.complex<4> 
+! CHECK:   acc.yield %[[COMBINED]] : !fir.complex<4>
+! CHECK: }
+
 ! CHECK-LABEL: acc.reduction.recipe @reduction_neqv_l32 : !fir.logical<4> reduction_operator <neqv> init {
 ! CHECK: ^bb0(%{{.*}}: !fir.logical<4>):
 ! CHECK:   %[[CST:.*]] = arith.constant false
@@ -729,3 +742,13 @@ subroutine acc_reduction_neqv()
 ! CHECK-LABEL: func.func @_QPacc_reduction_neqv()
 ! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
 ! CHECK: acc.parallel reduction(@reduction_neqv_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
+
+subroutine acc_reduction_add_cmplx()
+  complex :: c
+  !$acc parallel reduction(+:c)
+  !$acc end parallel
+end subroutine
+
+! CHECK-LABEL: func.func @_QPacc_reduction_add_cmplx()
+! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.complex<4>>) -> !fir.ref<!fir.complex<4>> {name = "c"}
+! CHECK: acc.parallel reduction(@reduction_add_z32 -> %[[RED]] : !fir.ref<!fir.complex<4>>)


        


More information about the flang-commits mailing list