[flang-commits] [flang] [Flang][OpenMP] NFC: Refactor reduction code (PR #79876)

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Fri Feb 9 08:44:01 PST 2024


https://github.com/kiranchandramohan updated https://github.com/llvm/llvm-project/pull/79876

From 62fe53f265ac12e6b82cd12f3d5907a91fe261ed Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Mon, 29 Jan 2024 18:02:20 +0000
Subject: [PATCH 1/2] [Flang][OpenMP] NFC: Refactor reduction code

Introduces a new enumeration listing all Fortran reduction identifiers.
Moves the scalar combiner code generation into a separate function so it
can be reused for array reductions in the future.
---
 flang/lib/Lower/OpenMP.cpp | 337 +++++++++++++++++++------------------
 1 file changed, 170 insertions(+), 167 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index ad4cffc707535f..a0f2a11f550dec 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -731,21 +731,59 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
 
 class ReductionProcessor {
 public:
-  enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
-  static IntrinsicProc
+  // TODO: Move this enumeration to the OpenMP dialect
+  enum ReductionIdentifier {
+    ID,
+    USER_DEF_OP,
+    ADD,
+    SUBTRACT,
+    MULTIPLY,
+    AND,
+    OR,
+    EQV,
+    NEQV,
+    MAX,
+    MIN,
+    IAND,
+    IOR,
+    IEOR
+  };
+  static ReductionIdentifier
   getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
-    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+    auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
                        getRealName(pd).ToString())
-                       .Case("max", IntrinsicProc::MAX)
-                       .Case("min", IntrinsicProc::MIN)
-                       .Case("iand", IntrinsicProc::IAND)
-                       .Case("ior", IntrinsicProc::IOR)
-                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Case("max", ReductionIdentifier::MAX)
+                       .Case("min", ReductionIdentifier::MIN)
+                       .Case("iand", ReductionIdentifier::IAND)
+                       .Case("ior", ReductionIdentifier::IOR)
+                       .Case("ieor", ReductionIdentifier::IEOR)
                        .Default(std::nullopt);
     assert(redType && "Invalid Reduction");
     return *redType;
   }
 
+  static ReductionIdentifier getReductionType(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+      return ReductionIdentifier::ADD;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+      return ReductionIdentifier::SUBTRACT;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+      return ReductionIdentifier::MULTIPLY;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+      return ReductionIdentifier::AND;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return ReductionIdentifier::EQV;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+      return ReductionIdentifier::OR;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return ReductionIdentifier::NEQV;
+    default:
+      llvm_unreachable("unexpected intrinsic operator in reduction");
+    }
+  }
+
   static bool supportedIntrinsicProcReduction(
       const Fortran::parser::ProcedureDesignator &pd) {
     const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
@@ -753,13 +791,13 @@ class ReductionProcessor {
     if (!name->symbol->GetUltimate().attrs().test(
             Fortran::semantics::Attr::INTRINSIC))
       return false;
-    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+    auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
                        getRealName(name).ToString())
-                       .Case("max", IntrinsicProc::MAX)
-                       .Case("min", IntrinsicProc::MIN)
-                       .Case("iand", IntrinsicProc::IAND)
-                       .Case("ior", IntrinsicProc::IOR)
-                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Case("max", ReductionIdentifier::MAX)
+                       .Case("min", ReductionIdentifier::MIN)
+                       .Case("iand", ReductionIdentifier::IAND)
+                       .Case("ior", ReductionIdentifier::IOR)
+                       .Case("ieor", ReductionIdentifier::IEOR)
                        .Default(std::nullopt);
     if (redType)
       return true;
@@ -817,32 +855,30 @@ class ReductionProcessor {
   /// reductionOpName. For example:
   ///    0 + x = x,
   ///    1 * x = x
-  static int getOperationIdentity(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Location loc) {
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+  static int getOperationIdentity(ReductionIdentifier redId,
+                                  mlir::Location loc) {
+    switch (redId) {
+    case ReductionIdentifier::ADD:
+    case ReductionIdentifier::OR:
+    case ReductionIdentifier::NEQV:
       return 0;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+    case ReductionIdentifier::MULTIPLY:
+    case ReductionIdentifier::AND:
+    case ReductionIdentifier::EQV:
       return 1;
     default:
       TODO(loc, "Reduction of some intrinsic operators is not supported");
     }
   }
 
-  static mlir::Value getIntrinsicProcInitValue(
-      mlir::Location loc, mlir::Type type,
-      const Fortran::parser::ProcedureDesignator &procDesignator,
-      fir::FirOpBuilder &builder) {
+  static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
+                                           ReductionIdentifier redId,
+                                           fir::FirOpBuilder &builder) {
     assert((fir::isa_integer(type) || fir::isa_real(type) ||
             type.isa<fir::LogicalType>()) &&
            "only integer, logical and real types are currently supported");
-    switch (getReductionType(procDesignator)) {
-    case IntrinsicProc::MAX: {
+    switch (redId) {
+    case ReductionIdentifier::MAX: {
       if (auto ty = type.dyn_cast<mlir::FloatType>()) {
         const llvm::fltSemantics &sem = ty.getFloatSemantics();
         return builder.createRealConstant(
@@ -852,7 +888,7 @@ class ReductionProcessor {
       int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, minInt);
     }
-    case IntrinsicProc::MIN: {
+    case ReductionIdentifier::MIN: {
       if (auto ty = type.dyn_cast<mlir::FloatType>()) {
         const llvm::fltSemantics &sem = ty.getFloatSemantics();
         return builder.createRealConstant(
@@ -862,46 +898,50 @@ class ReductionProcessor {
       int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, maxInt);
     }
-    case IntrinsicProc::IOR: {
+    case ReductionIdentifier::IOR: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    case IntrinsicProc::IEOR: {
+    case ReductionIdentifier::IEOR: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    case IntrinsicProc::IAND: {
+    case ReductionIdentifier::IAND: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, allOnInt);
     }
-    }
-    llvm_unreachable("Unknown Reduction Intrinsic");
-  }
+    case ReductionIdentifier::ADD:
+    case ReductionIdentifier::MULTIPLY:
+    case ReductionIdentifier::AND:
+    case ReductionIdentifier::OR:
+    case ReductionIdentifier::EQV:
+    case ReductionIdentifier::NEQV:
+      if (type.isa<mlir::FloatType>())
+        return builder.create<mlir::arith::ConstantOp>(
+            loc, type,
+            builder.getFloatAttr(type,
+                                 (double)getOperationIdentity(redId, loc)));
+
+      if (type.isa<fir::LogicalType>()) {
+        mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
+            loc, builder.getI1Type(),
+            builder.getIntegerAttr(builder.getI1Type(),
+                                   getOperationIdentity(redId, loc)));
+        return builder.createConvert(loc, type, intConst);
+      }
 
-  static mlir::Value getIntrinsicOpInitValue(
-      mlir::Location loc, mlir::Type type,
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      fir::FirOpBuilder &builder) {
-    if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
-          builder.getFloatAttr(type,
-                               (double)getOperationIdentity(intrinsicOp, loc)));
-
-    if (type.isa<fir::LogicalType>()) {
-      mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
-          loc, builder.getI1Type(),
-          builder.getIntegerAttr(builder.getI1Type(),
-                                 getOperationIdentity(intrinsicOp, loc)));
-      return builder.createConvert(loc, type, intConst);
+          builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
+    case ReductionIdentifier::ID:
+    case ReductionIdentifier::USER_DEF_OP:
+    case ReductionIdentifier::SUBTRACT:
+      TODO(loc, "Reduction of some identifier types is not supported");
     }
-
-    return builder.create<mlir::arith::ConstantOp>(
-        loc, type,
-        builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+    llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
   }
 
   template <typename FloatOp, typename IntegerOp>
@@ -915,118 +955,46 @@ class ReductionProcessor {
     return builder.create<FloatOp>(loc, op1, op2);
   }
 
-  /// Creates an OpenMP reduction declaration and inserts it into the provided
-  /// symbol table. The declaration has a constant initializer with the neutral
-  /// value `initValue`, and the reduction combiner carried over from `reduce`.
-  /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      const Fortran::parser::ProcedureDesignator &procDesignator,
-      mlir::Type type, mlir::Location loc) {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    mlir::ModuleOp module = builder.getModule();
-
-    auto decl =
-        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-    if (decl)
-      return decl;
-
-    mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
-        loc, reductionOpName, type);
-    builder.createBlock(&decl.getInitializerRegion(),
-                        decl.getInitializerRegion().end(), {type}, {loc});
-    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-    mlir::Value init =
-        getIntrinsicProcInitValue(loc, type, procDesignator, builder);
-    builder.create<mlir::omp::YieldOp>(loc, init);
-
-    builder.createBlock(&decl.getReductionRegion(),
-                        decl.getReductionRegion().end(), {type, type},
-                        {loc, loc});
-
-    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
+  static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
+                                          mlir::Location loc,
+                                          ReductionIdentifier redId,
+                                          mlir::Type type, mlir::Value op1,
+                                          mlir::Value op2) {
     mlir::Value reductionOp;
-    switch (getReductionType(procDesignator)) {
-    case IntrinsicProc::MAX:
+    switch (redId) {
+    case ReductionIdentifier::MAX:
       reductionOp =
           getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
               builder, type, loc, op1, op2);
       break;
-    case IntrinsicProc::MIN:
+    case ReductionIdentifier::MIN:
       reductionOp =
           getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
               builder, type, loc, op1, op2);
       break;
-    case IntrinsicProc::IOR:
+    case ReductionIdentifier::IOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
       break;
-    case IntrinsicProc::IEOR:
+    case ReductionIdentifier::IEOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
       break;
-    case IntrinsicProc::IAND:
+    case ReductionIdentifier::IAND:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
       break;
-    }
-
-    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-    return decl;
-  }
-
-  /// Creates an OpenMP reduction declaration and inserts it into the provided
-  /// symbol table. The declaration has a constant initializer with the neutral
-  /// value `initValue`, and the reduction combiner carried over from `reduce`.
-  /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type type, mlir::Location loc) {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    mlir::ModuleOp module = builder.getModule();
-
-    auto decl =
-        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-    if (decl)
-      return decl;
-
-    mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
-        loc, reductionOpName, type);
-    builder.createBlock(&decl.getInitializerRegion(),
-                        decl.getInitializerRegion().end(), {type}, {loc});
-    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-    mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
-    builder.create<mlir::omp::YieldOp>(loc, init);
-
-    builder.createBlock(&decl.getReductionRegion(),
-                        decl.getReductionRegion().end(), {type, type},
-                        {loc, loc});
-
-    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
-    mlir::Value reductionOp;
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case ReductionIdentifier::ADD:
       reductionOp =
           getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
               builder, type, loc, op1, op2);
       break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case ReductionIdentifier::MULTIPLY:
       reductionOp =
           getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
               builder, type, loc, op1, op2);
       break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+    case ReductionIdentifier::AND: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1036,7 +1004,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, andiOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+    case ReductionIdentifier::OR: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1045,7 +1013,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, oriOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+    case ReductionIdentifier::EQV: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1055,7 +1023,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, cmpiOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+    case ReductionIdentifier::NEQV: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1069,7 +1037,46 @@ class ReductionProcessor {
       TODO(loc, "Reduction of some intrinsic operators is not supported");
     }
 
+    return reductionOp;
+  }
+
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The declaration has a constant initializer with the neutral
+  /// value `initValue`, and the reduction combiner carried over from `reduce`.
+  /// TODO: Generalize this for non-integer types, add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
+
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
+
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+    mlir::Value init = getReductionInitValue(loc, type, redId, builder);
+    builder.create<mlir::omp::YieldOp>(loc, init);
+
+    builder.createBlock(&decl.getReductionRegion(),
+                        decl.getReductionRegion().end(), {type, type},
+                        {loc, loc});
+
+    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+    mlir::Value reductionOp =
+        createScalarCombiner(builder, loc, redId, type, op1, op2);
     builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+
     return decl;
   }
 
@@ -1092,15 +1099,15 @@ class ReductionProcessor {
       const auto &intrinsicOp{
           std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
               redDefinedOp->u)};
-      switch (intrinsicOp) {
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      ReductionIdentifier redId = getReductionType(intrinsicOp);
+      switch (redId) {
+      case ReductionIdentifier::ADD:
+      case ReductionIdentifier::MULTIPLY:
+      case ReductionIdentifier::AND:
+      case ReductionIdentifier::EQV:
+      case ReductionIdentifier::OR:
+      case ReductionIdentifier::NEQV:
         break;
-
       default:
         TODO(currentLocation,
              "Reduction of some intrinsic operators is not supported");
@@ -1120,11 +1127,11 @@ class ReductionProcessor {
               decl = createReductionDecl(
                   firOpBuilder,
                   getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
-                  intrinsicOp, redType, currentLocation);
+                  redId, redType, currentLocation);
             else if (redType.isIntOrIndexOrFloat()) {
               decl = createReductionDecl(firOpBuilder,
                                          getReductionName(intrinsicOp, redType),
-                                         intrinsicOp, redType, currentLocation);
+                                         redId, redType, currentLocation);
             } else {
               TODO(currentLocation, "Reduction of some types is not supported");
             }
@@ -1138,6 +1145,8 @@ class ReductionProcessor {
                        &redOperator.u)) {
       if (ReductionProcessor::supportedIntrinsicProcReduction(
               *reductionIntrinsic)) {
+        ReductionProcessor::ReductionIdentifier redId =
+            ReductionProcessor::getReductionType(*reductionIntrinsic);
         for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
           if (const auto *name{
                   Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
@@ -1154,7 +1163,7 @@ class ReductionProcessor {
                   firOpBuilder,
                   getReductionName(getRealName(*reductionIntrinsic).ToString(),
                                    redType),
-                  *reductionIntrinsic, redType, currentLocation);
+                  redId, redType, currentLocation);
               reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
                   firOpBuilder.getContext(), decl.getSymName()));
             }
@@ -4174,7 +4183,7 @@ void Fortran::lower::genOpenMPReduction(
         if (!ReductionProcessor::supportedIntrinsicProcReduction(
                 *reductionIntrinsic))
           continue;
-        ReductionProcessor::IntrinsicProc redIntrinsicProc =
+        ReductionProcessor::ReductionIdentifier redId =
             ReductionProcessor::getReductionType(*reductionIntrinsic);
         for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
           if (const auto *name{
@@ -4195,10 +4204,8 @@ void Fortran::lower::genOpenMPReduction(
                   if (reductionOp == nullptr)
                     continue;
 
-                  if (redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::MAX ||
-                      redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::MIN) {
+                  if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
+                      redId == ReductionProcessor::ReductionIdentifier::MIN) {
                     assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
                            "Selection Op not found in reduction intrinsic");
                     mlir::Operation *compareOp =
@@ -4206,13 +4213,9 @@ void Fortran::lower::genOpenMPReduction(
                     updateReduction(compareOp, firOpBuilder, loadVal,
                                     reductionVal);
                   }
-                  if (redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::IOR ||
-                      redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::IEOR ||
-                      redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::IAND) {
-
+                  if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
+                      redId == ReductionProcessor::ReductionIdentifier::IEOR ||
+                      redId == ReductionProcessor::ReductionIdentifier::IAND) {
                     updateReduction(reductionOp, firOpBuilder, loadVal,
                                     reductionVal);
                   }

From ed67a15c20dab2b54753e9d481d1bf681bc4a424 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Fri, 9 Feb 2024 16:24:22 +0000
Subject: [PATCH 2/2] [Flang][OpenMP] Address review comments

---
 flang/lib/Lower/OpenMP.cpp | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index a0f2a11f550dec..fd18b212bad515 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -791,17 +791,14 @@ class ReductionProcessor {
     if (!name->symbol->GetUltimate().attrs().test(
             Fortran::semantics::Attr::INTRINSIC))
       return false;
-    auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
-                       getRealName(name).ToString())
-                       .Case("max", ReductionIdentifier::MAX)
-                       .Case("min", ReductionIdentifier::MIN)
-                       .Case("iand", ReductionIdentifier::IAND)
-                       .Case("ior", ReductionIdentifier::IOR)
-                       .Case("ieor", ReductionIdentifier::IEOR)
-                       .Default(std::nullopt);
-    if (redType)
-      return true;
-    return false;
+    auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
+                       .Case("max", true)
+                       .Case("min", true)
+                       .Case("iand", true)
+                       .Case("ior", true)
+                       .Case("ieor", true)
+                       .Default(false);
+    return redType;
   }
 
   static const Fortran::semantics::SourceName



More information about the flang-commits mailing list