[flang-commits] [flang] [Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code (PR #70790)

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Tue Oct 31 04:57:11 PDT 2023


https://github.com/kiranchandramohan created https://github.com/llvm/llvm-project/pull/70790

Move reduction lowering code into a ReductionProcessor class. Create an enumeration for intrinsic procedure reductions.

>From 3911841d4a5ce7d5320c875d4fb49d9f969810f6 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Tue, 31 Oct 2023 10:43:30 +0000
Subject: [PATCH] [Flang][OpenMP] NFC: Minor refactoring of Reduction lowering
 code

Move reduction lowering code into a ReductionProcessor class.
Create an enumeration for intrinsic procedure reductions.
---
 flang/lib/Lower/OpenMP.cpp | 789 ++++++++++++++++++++-----------------
 1 file changed, 431 insertions(+), 358 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 0faaae6c08e0476..a804a2ce1780e7d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -667,276 +667,441 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
       TODO(location, "OMPD_target_data MapOperand BoxType");
 }
 
-static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
-  return (llvm::Twine(name) +
-          (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-          llvm::Twine(ty.getIntOrFloatBitWidth()))
-      .str();
-}
+class ReductionProcessor {
+public:
+  /// Intrinsic procedures supported as OpenMP reduction identifiers.
+  enum class IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
+
+  /// Returns the IntrinsicProc named by \p pd (after resolving it to its
+  /// ultimate symbol), or std::nullopt if \p pd is not a supported intrinsic.
+  static std::optional<IntrinsicProc>
+  lookupIntrinsicProc(const Fortran::parser::ProcedureDesignator &pd) {
+    return llvm::StringSwitch<std::optional<IntrinsicProc>>(
+               getRealName(pd).ToString())
+        .Case("max", IntrinsicProc::MAX)
+        .Case("min", IntrinsicProc::MIN)
+        .Case("iand", IntrinsicProc::IAND)
+        .Case("ior", IntrinsicProc::IOR)
+        .Case("ieor", IntrinsicProc::IEOR)
+        .Default(std::nullopt);
+  }
+
+  /// Returns the IntrinsicProc of \p pd, which must be a supported intrinsic.
+  static IntrinsicProc
+  getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+    std::optional<IntrinsicProc> redType = lookupIntrinsicProc(pd);
+    assert(redType && "Invalid Reduction");
+    return *redType;
+  }
+
+  /// Returns true if \p pd names one of the IntrinsicProc reductions.
+  static bool supportedIntrinsicProcReduction(
+      const Fortran::parser::ProcedureDesignator &pd) {
+    return lookupIntrinsicProc(pd).has_value();
+  }
+
+  /// Returns the name of the ultimate symbol that \p name resolves to.
+  static Fortran::semantics::SourceName
+  getRealName(const Fortran::parser::Name *name) {
+    return name->symbol->GetUltimate().name();
+  }
 
-static std::string getReductionName(
-    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type ty) {
-  std::string reductionName;
+  /// Returns the ultimate symbol name of the procedure designated by \p pd.
+  static Fortran::semantics::SourceName
+  getRealName(const Fortran::parser::ProcedureDesignator &pd) {
+    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+    assert(name && "Invalid Reduction Intrinsic.");
+    return getRealName(name);
+  }
 
-  switch (intrinsicOp) {
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    reductionName = "add_reduction";
-    break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    reductionName = "multiply_reduction";
-    break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-    return "and_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-    return "eqv_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-    return "or_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
-    return "neqv_reduction";
-  default:
-    reductionName = "other_reduction";
-    break;
+  /// Builds "<name>_i_<bits>" (integer/index \p ty) or "<name>_f_<bits>".
+  static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+    return (llvm::Twine(name) +
+            (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
+            llvm::Twine(ty.getIntOrFloatBitWidth()))
+        .str();
+  }
+
+  /// Reduction name for \p intrinsicOp; only non-logical ones encode \p ty.
+  static std::string getReductionName(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Type ty) {
+    std::string reductionName;
+
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+      reductionName = "add_reduction";
+      break;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+      reductionName = "multiply_reduction";
+      break;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+      return "and_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return "eqv_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+      return "or_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return "neqv_reduction";
+    default:
+      reductionName = "other_reduction";
+      break;
+    }
+
+    return getReductionName(reductionName, ty);
+  }
+
+  /// Returns the identity value of the intrinsic operator \p intrinsicOp,
+  /// e.g. 0 for + (0 + x = x) and 1 for * (1 * x = x). Logical identities
+  /// (.false. for .or./.neqv., .true. for .and./.eqv.) are returned as 0/1.
+  /// Emits a TODO for operators without reduction support.
+  static int getOperationIdentity(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Location loc) {
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return 0;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return 1;
+    default:
+      TODO(loc, "Reduction of some intrinsic operators is not supported");
+    }
   }
 
-  return getReductionName(reductionName, ty);
-}
-
-/// This function returns the identity value of the operator \p reductionOpName.
-/// For example:
-///    0 + x = x,
-///    1 * x = x
-static int getOperationIdentity(llvm::StringRef reductionOpName,
-                                mlir::Location loc) {
-  if (reductionOpName.contains("add") || reductionOpName.contains("or") ||
-      reductionOpName.contains("neqv"))
-    return 0;
-  if (reductionOpName.contains("multiply") || reductionOpName.contains("and") ||
-      reductionOpName.contains("eqv"))
-    return 1;
-  TODO(loc, "Reduction of some intrinsic operators is not supported");
-}
-
-static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
-                                         llvm::StringRef reductionOpName,
-                                         fir::FirOpBuilder &builder) {
-  assert((fir::isa_integer(type) || fir::isa_real(type) ||
-          type.isa<fir::LogicalType>()) &&
-         "only integer, logical and real types are currently supported");
-  if (reductionOpName.contains("max")) {
-    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
-      const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(
-          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+  /// Returns the initial value for a reduction by \p procDesignator over
+  /// \p type, e.g. the most negative value for MAX and most positive for MIN.
+  static mlir::Value getIntrinsicProcInitValue(
+      mlir::Location loc, mlir::Type type,
+      const Fortran::parser::ProcedureDesignator &procDesignator,
+      fir::FirOpBuilder &builder) {
+    assert((fir::isa_integer(type) || fir::isa_real(type) ||
+            type.isa<fir::LogicalType>()) &&
+           "only integer, logical and real types are currently supported");
+    switch (getReductionType(procDesignator)) {
+    case IntrinsicProc::MAX: {
+      if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+        const llvm::fltSemantics &sem = ty.getFloatSemantics();
+        return builder.createRealConstant(
+            loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+      }
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, minInt);
+    }
+    case IntrinsicProc::MIN: {
+      if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+        const llvm::fltSemantics &sem = ty.getFloatSemantics();
+        return builder.createRealConstant(
+            loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
+      }
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, maxInt);
     }
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, minInt);
-  } else if (reductionOpName.contains("min")) {
-    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
-      const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(
-          loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true));
+    case IntrinsicProc::IOR:
+    case IntrinsicProc::IEOR: {
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, maxInt);
-  } else if (reductionOpName.contains("ior")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, zeroInt);
-  } else if (reductionOpName.contains("ieor")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, zeroInt);
-  } else if (reductionOpName.contains("iand")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, allOnInt);
-  } else {
+    case IntrinsicProc::IAND: {
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, allOnInt);
+    }
+    }
+    llvm_unreachable("Unknown Reduction Intrinsic");
+  }
+
+  /// Returns the identity of \p intrinsicOp as a constant of \p type.
+  static mlir::Value getIntrinsicOpInitValue(
+      mlir::Location loc, mlir::Type type,
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      fir::FirOpBuilder &builder) {
     if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
-          builder.getFloatAttr(
-              type, (double)getOperationIdentity(reductionOpName, loc)));
+          builder.getFloatAttr(type,
+                               (double)getOperationIdentity(intrinsicOp, loc)));
 
     if (type.isa<fir::LogicalType>()) {
       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
           loc, builder.getI1Type(),
           builder.getIntegerAttr(builder.getI1Type(),
-                                 getOperationIdentity(reductionOpName, loc)));
+                                 getOperationIdentity(intrinsicOp, loc)));
       return builder.createConvert(loc, type, intConst);
     }
 
     return builder.create<mlir::arith::ConstantOp>(
         loc, type,
-        builder.getIntegerAttr(type,
-                               getOperationIdentity(reductionOpName, loc)));
-  }
-}
-
-template <typename FloatOp, typename IntegerOp>
-static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
-                                         mlir::Type type, mlir::Location loc,
-                                         mlir::Value op1, mlir::Value op2) {
-  assert(type.isIntOrIndexOrFloat() &&
-         "only integer and float types are currently supported");
-  if (type.isIntOrIndex())
-    return builder.create<IntegerOp>(loc, op1, op2);
-  return builder.create<FloatOp>(loc, op1, op2);
-}
-
-static mlir::omp::ReductionDeclareOp
-createMinimalReductionDecl(fir::FirOpBuilder &builder,
-                           llvm::StringRef reductionOpName, mlir::Type type,
-                           mlir::Location loc) {
-  mlir::ModuleOp module = builder.getModule();
-  mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-  mlir::omp::ReductionDeclareOp decl =
-      modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
-                                                       type);
-  builder.createBlock(&decl.getInitializerRegion(),
-                      decl.getInitializerRegion().end(), {type}, {loc});
-  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-  mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder);
-  builder.create<mlir::omp::YieldOp>(loc, init);
-
-  builder.createBlock(&decl.getReductionRegion(),
-                      decl.getReductionRegion().end(), {type, type},
-                      {loc, loc});
-
-  return decl;
-}
-
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp
-createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-                    const Fortran::parser::ProcedureDesignator &procDesignator,
-                    mlir::Type type, mlir::Location loc) {
-  mlir::OpBuilder::InsertionGuard guard(builder);
-  mlir::ModuleOp module = builder.getModule();
-
-  auto decl =
-      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-  if (decl)
-    return decl;
+        builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+  }
+
+  /// Builds IntegerOp(op1, op2) for integer/index types, else FloatOp.
+  template <typename FloatOp, typename IntegerOp>
+  static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+                                           mlir::Type type, mlir::Location loc,
+                                           mlir::Value op1, mlir::Value op2) {
+    assert(type.isIntOrIndexOrFloat() &&
+           "only integer and float types are currently supported");
+    if (type.isIntOrIndex())
+      return builder.create<IntegerOp>(loc, op1, op2);
+    return builder.create<FloatOp>(loc, op1, op2);
+  }
+
+  /// Returns the omp.reduction.declare named \p reductionOpName in the module,
+  /// creating it if absent: its initializer yields the initial value for
+  /// \p procDesignator and its combiner applies that intrinsic to two values.
+  /// TODO: add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      const Fortran::parser::ProcedureDesignator &procDesignator,
+      mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
+
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
 
-  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
-  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
 
-  mlir::Value reductionOp;
-  if (const auto *name{
-          Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) {
-    if (name->source == "max") {
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+    mlir::Value init =
+        getIntrinsicProcInitValue(loc, type, procDesignator, builder);
+    builder.create<mlir::omp::YieldOp>(loc, init);
+
+    builder.createBlock(&decl.getReductionRegion(),
+                        decl.getReductionRegion().end(), {type, type},
+                        {loc, loc});
+
+    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+    mlir::Value reductionOp;
+    switch (getReductionType(procDesignator)) {
+    case IntrinsicProc::MAX:
       reductionOp =
           getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
               builder, type, loc, op1, op2);
-    } else if (name->source == "min") {
+      break;
+    case IntrinsicProc::MIN:
       reductionOp =
           getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
               builder, type, loc, op1, op2);
-    } else if (name->source == "ior") {
+      break;
+    case IntrinsicProc::IOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
-    } else if (name->source == "ieor") {
+      break;
+    case IntrinsicProc::IEOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
-    } else if (name->source == "iand") {
+      break;
+    case IntrinsicProc::IAND:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
-    } else {
-      TODO(loc, "Reduction of some intrinsic operators is not supported");
+      break;
     }
+
+    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+    return decl;
   }
 
-  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-  return decl;
-}
+  /// Returns the omp.reduction.declare named \p reductionOpName in the module,
+  /// creating it if absent: its initializer yields the identity of
+  /// \p intrinsicOp and its combiner applies that operator to two values.
+  /// TODO: add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
 
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp createReductionDecl(
-    fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type type, mlir::Location loc) {
-  mlir::OpBuilder::InsertionGuard guard(builder);
-  mlir::ModuleOp module = builder.getModule();
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
 
-  auto decl =
-      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-  if (decl)
-    return decl;
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
 
-  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
-  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+    mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
+    builder.create<mlir::omp::YieldOp>(loc, init);
 
-  mlir::Value reductionOp;
-  switch (intrinsicOp) {
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    reductionOp =
-        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
-            builder, type, loc, op1, op2);
-    break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    reductionOp =
-        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
-            builder, type, loc, op1, op2);
-    break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
-    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
-    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+    builder.createBlock(&decl.getReductionRegion(),
+                        decl.getReductionRegion().end(), {type, type},
+                        {loc, loc});
 
-    mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
+    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
 
-    reductionOp = builder.createConvert(loc, type, andiOp);
-    break;
-  }
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
-    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
-    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+    mlir::Value reductionOp;
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+      reductionOp =
+          getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
+              builder, type, loc, op1, op2);
+      break;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+      reductionOp =
+          getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
+              builder, type, loc, op1, op2);
+      break;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+      mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+      mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
-    mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
+      mlir::Value andiOp =
+          builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
 
-    reductionOp = builder.createConvert(loc, type, oriOp);
-    break;
-  }
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
-    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
-    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+      reductionOp = builder.createConvert(loc, type, andiOp);
+      break;
+    }
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+      mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+      mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
-    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
+      mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
 
-    reductionOp = builder.createConvert(loc, type, cmpiOp);
-    break;
-  }
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
-    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
-    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+      reductionOp = builder.createConvert(loc, type, oriOp);
+      break;
+    }
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+      mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+      mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
-    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
-        loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
+      mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
 
-    reductionOp = builder.createConvert(loc, type, cmpiOp);
-    break;
-  }
-  default:
-    TODO(loc, "Reduction of some intrinsic operators is not supported");
+      reductionOp = builder.createConvert(loc, type, cmpiOp);
+      break;
+    }
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+      mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+      mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+
+      mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
+          loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
+
+      reductionOp = builder.createConvert(loc, type, cmpiOp);
+      break;
+    }
+    default:
+      TODO(loc, "Reduction of some intrinsic operators is not supported");
+    }
+
+    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+    return decl;
   }
 
-  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-  return decl;
-}
+  /// For each object in \p reduction, appends its address to \p reductionVars
+  /// and the symbol of its omp.reduction.declare to \p reductionDeclSymbols.
+  static void addReductionDecl(
+      mlir::Location currentLocation,
+      Fortran::lower::AbstractConverter &converter,
+      const Fortran::parser::OmpReductionClause &reduction,
+      llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+      llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
+    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+    mlir::omp::ReductionDeclareOp decl;
+    const auto &redOperator{
+        std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
+    const auto &objectList{
+        std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+    if (const auto &redDefinedOp =
+            std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+      const auto &intrinsicOp{
+          std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+              redDefinedOp->u)};
+      switch (intrinsicOp) {
+      case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+      case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+      case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+      case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+      case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+        break;
+
+      default:
+        TODO(currentLocation,
+             "Reduction of some intrinsic operators is not supported");
+        break;
+      }
+      for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+        if (const auto *name{
+                Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+          if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+            mlir::Value symVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+              symVal = declOp.getBase();
+            mlir::Type redType =
+                symVal.getType().cast<fir::ReferenceType>().getEleTy();
+            reductionVars.push_back(symVal);
+            if (redType.isa<fir::LogicalType>())
+              decl = createReductionDecl(
+                  firOpBuilder,
+                  getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
+                  intrinsicOp, redType, currentLocation);
+            else if (redType.isIntOrIndexOrFloat()) {
+              decl = createReductionDecl(firOpBuilder,
+                                         getReductionName(intrinsicOp, redType),
+                                         intrinsicOp, redType, currentLocation);
+            } else {
+              TODO(currentLocation, "Reduction of some types is not supported");
+            }
+            reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+                firOpBuilder.getContext(), decl.getSymName()));
+          }
+        }
+      }
+    } else if (const auto *reductionIntrinsic =
+                   std::get_if<Fortran::parser::ProcedureDesignator>(
+                       &redOperator.u)) {
+      if (!supportedIntrinsicProcReduction(*reductionIntrinsic))
+        TODO(currentLocation,
+             "Reduction of some intrinsic procedures is not supported");
+      for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+        if (const auto *name{
+                Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+          if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+            mlir::Value symVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+              symVal = declOp.getBase();
+            mlir::Type redType =
+                symVal.getType().cast<fir::ReferenceType>().getEleTy();
+            reductionVars.push_back(symVal);
+            assert(redType.isIntOrIndexOrFloat() &&
+                   "Unsupported reduction type");
+            decl = createReductionDecl(
+                firOpBuilder,
+                getReductionName(getRealName(*reductionIntrinsic).ToString(),
+                                 redType),
+                *reductionIntrinsic, redType, currentLocation);
+            reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+                firOpBuilder.getContext(), decl.getSymName()));
+          }
+        }
+      }
+    }
+  }
+};
 
 static mlir::omp::ScheduleModifier
 translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) {
@@ -1119,101 +1284,6 @@ static mlir::Value getIfClauseOperand(
                                     ifVal);
 }
 
-/// Creates a reduction declaration and associates it with an OpenMP block
-/// directive.
-static void
-addReductionDecl(mlir::Location currentLocation,
-                 Fortran::lower::AbstractConverter &converter,
-                 const Fortran::parser::OmpReductionClause &reduction,
-                 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
-                 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::omp::ReductionDeclareOp decl;
-  const auto &redOperator{
-      std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
-  const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
-  if (const auto &redDefinedOp =
-          std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
-    const auto &intrinsicOp{
-        std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
-            redDefinedOp->u)};
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
-      break;
-
-    default:
-      TODO(currentLocation,
-           "Reduction of some intrinsic operators is not supported");
-      break;
-    }
-    for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-      if (const auto *name{
-              Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-        if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-          mlir::Value symVal = converter.getSymbolAddress(*symbol);
-          if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
-            symVal = declOp.getBase();
-          mlir::Type redType =
-              symVal.getType().cast<fir::ReferenceType>().getEleTy();
-          reductionVars.push_back(symVal);
-          if (redType.isa<fir::LogicalType>())
-            decl = createReductionDecl(
-                firOpBuilder,
-                getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
-                intrinsicOp, redType, currentLocation);
-          else if (redType.isIntOrIndexOrFloat()) {
-            decl = createReductionDecl(firOpBuilder,
-                                       getReductionName(intrinsicOp, redType),
-                                       intrinsicOp, redType, currentLocation);
-          } else {
-            TODO(currentLocation, "Reduction of some types is not supported");
-          }
-          reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
-              firOpBuilder.getContext(), decl.getSymName()));
-        }
-      }
-    }
-  } else if (const auto *reductionIntrinsic =
-                 std::get_if<Fortran::parser::ProcedureDesignator>(
-                     &redOperator.u)) {
-    if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
-            reductionIntrinsic)}) {
-      if ((name->source != "max") && (name->source != "min") &&
-          (name->source != "ior") && (name->source != "ieor") &&
-          (name->source != "iand")) {
-        TODO(currentLocation,
-             "Reduction of intrinsic procedures is not supported");
-      }
-      std::string intrinsicOp = name->ToString();
-      for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-        if (const auto *name{
-                Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-          if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-            mlir::Value symVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
-              symVal = declOp.getBase();
-            mlir::Type redType =
-                symVal.getType().cast<fir::ReferenceType>().getEleTy();
-            reductionVars.push_back(symVal);
-            assert(redType.isIntOrIndexOrFloat() &&
-                   "Unsupported reduction type");
-            decl = createReductionDecl(
-                firOpBuilder, getReductionName(intrinsicOp, redType),
-                *reductionIntrinsic, redType, currentLocation);
-            reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
-                firOpBuilder.getContext(), decl.getSymName()));
-          }
-        }
-      }
-    }
-  }
-}
-
 static void
 addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
                    const Fortran::parser::OmpObjectList &useDeviceClause,
@@ -1801,8 +1871,9 @@ bool ClauseProcessor::processReduction(
   return findRepeatableClause<ClauseTy::Reduction>(
       [&](const ClauseTy::Reduction *reductionClause,
           const Fortran::parser::CharBlock &) {
-        addReductionDecl(currentLocation, converter, reductionClause->v,
-                         reductionVars, reductionDeclSymbols);
+        ReductionProcessor rp;
+        rp.addReductionDecl(currentLocation, converter, reductionClause->v,
+                            reductionVars, reductionDeclSymbols);
       });
 }
 
@@ -3400,48 +3471,50 @@ void Fortran::lower::genOpenMPReduction(
       } else if (const auto *reductionIntrinsic =
                      std::get_if<Fortran::parser::ProcedureDesignator>(
                          &redOperator.u)) {
-        if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
-                reductionIntrinsic)}) {
-          std::string redName = name->ToString();
-          if ((name->source != "max") && (name->source != "min") &&
-              (name->source != "ior") && (name->source != "ieor") &&
-              (name->source != "iand")) {
-            continue;
-          }
-          for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-            if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
-                    ompObject)}) {
-              if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-                mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-                if (auto declOp =
-                        reductionVal.getDefiningOp<hlfir::DeclareOp>())
-                  reductionVal = declOp.getBase();
-                for (const mlir::OpOperand &reductionValUse :
-                     reductionVal.getUses()) {
-                  if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
-                          reductionValUse.getOwner())) {
-                    mlir::Value loadVal = loadOp.getRes();
-                    // Max is lowered as a compare -> select.
-                    // Match the pattern here.
-                    mlir::Operation *reductionOp =
-                        findReductionChain(loadVal, &reductionVal);
-                    if (reductionOp == nullptr)
-                      continue;
-
-                    if (redName == "max" || redName == "min") {
-                      assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
-                             "Selection Op not found in reduction intrinsic");
-                      mlir::Operation *compareOp =
-                          getCompareFromReductionOp(reductionOp, loadVal);
-                      updateReduction(compareOp, firOpBuilder, loadVal,
-                                      reductionVal);
-                    }
-                    if (redName == "ior" || redName == "ieor" ||
-                        redName == "iand") {
+        if (!ReductionProcessor::supportedIntrinsicProcReduction(
+                *reductionIntrinsic))
+          continue;
+        ReductionProcessor::IntrinsicProc redIntrinsicProc =
+            ReductionProcessor::getReductionType(*reductionIntrinsic);
+        for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+          if (const auto *name{
+                  Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+            if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+              mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+              if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+                reductionVal = declOp.getBase();
+              for (const mlir::OpOperand &reductionValUse :
+                   reductionVal.getUses()) {
+                if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
+                        reductionValUse.getOwner())) {
+                  mlir::Value loadVal = loadOp.getRes();
+                  // Max is lowered as a compare -> select.
+                  // Match the pattern here.
+                  mlir::Operation *reductionOp =
+                      findReductionChain(loadVal, &reductionVal);
+                  if (reductionOp == nullptr)
+                    continue;
+
+                  if (redIntrinsicProc ==
+                          ReductionProcessor::IntrinsicProc::MAX ||
+                      redIntrinsicProc ==
+                          ReductionProcessor::IntrinsicProc::MIN) {
+                    assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+                           "Selection Op not found in reduction intrinsic");
+                    mlir::Operation *compareOp =
+                        getCompareFromReductionOp(reductionOp, loadVal);
+                    updateReduction(compareOp, firOpBuilder, loadVal,
+                                    reductionVal);
+                  }
+                  if (redIntrinsicProc ==
+                          ReductionProcessor::IntrinsicProc::IOR ||
+                      redIntrinsicProc ==
+                          ReductionProcessor::IntrinsicProc::IEOR ||
+                      redIntrinsicProc ==
+                          ReductionProcessor::IntrinsicProc::IAND) {
 
-                      updateReduction(reductionOp, firOpBuilder, loadVal,
-                                      reductionVal);
-                    }
+                    updateReduction(reductionOp, firOpBuilder, loadVal,
+                                    reductionVal);
                   }
                 }
               }



More information about the flang-commits mailing list