[flang-commits] [flang] Revert "[Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code" (PR #73139)

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Wed Nov 22 07:47:19 PST 2023


https://github.com/kiranchandramohan created https://github.com/llvm/llvm-project/pull/73139

Reverts llvm/llvm-project#70790 to fix CI failure (https://lab.llvm.org/buildbot/#/builders/268/builds/2884)

From e54390f62e1c76c1239e9955d498dd8432d63c5f Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiranchandramohan at gmail.com>
Date: Wed, 22 Nov 2023 15:46:50 +0000
Subject: [PATCH] Revert "[Flang][OpenMP] NFC: Minor refactoring of Reduction
 lowering code (#70790)"

This reverts commit 8c02b34e3b9b1e2596651959ba76c66a7afaf545.
---
 flang/lib/Lower/OpenMP.cpp | 793 +++++++++++++++++--------------------
 1 file changed, 360 insertions(+), 433 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 6267231d7fbe253..f6a61ba3a528e32 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -685,441 +685,276 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
       TODO(location, "OMPD_target_data MapOperand BoxType");
 }
 
-class ReductionProcessor {
-public:
-  enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
-  static IntrinsicProc
-  getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
-    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
-                       getRealName(pd).ToString())
-                       .Case("max", IntrinsicProc::MAX)
-                       .Case("min", IntrinsicProc::MIN)
-                       .Case("iand", IntrinsicProc::IAND)
-                       .Case("ior", IntrinsicProc::IOR)
-                       .Case("ieor", IntrinsicProc::IEOR)
-                       .Default(std::nullopt);
-    assert(redType && "Invalid Reduction");
-    return *redType;
-  }
-
-  static bool supportedIntrinsicProcReduction(
-      const Fortran::parser::ProcedureDesignator &pd) {
-    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
-    assert(name && "Invalid Reduction Intrinsic.");
-    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
-                       getRealName(name).ToString())
-                       .Case("max", IntrinsicProc::MAX)
-                       .Case("min", IntrinsicProc::MIN)
-                       .Case("iand", IntrinsicProc::IAND)
-                       .Case("ior", IntrinsicProc::IOR)
-                       .Case("ieor", IntrinsicProc::IEOR)
-                       .Default(std::nullopt);
-    if (redType)
-      return true;
-    return false;
-  }
-
-  static const Fortran::semantics::SourceName
-  getRealName(const Fortran::parser::Name *name) {
-    return name->symbol->GetUltimate().name();
-  }
-
-  static const Fortran::semantics::SourceName
-  getRealName(const Fortran::parser::ProcedureDesignator &pd) {
-    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
-    assert(name && "Invalid Reduction Intrinsic.");
-    return getRealName(name);
-  }
-
-  static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
-    return (llvm::Twine(name) +
-            (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-            llvm::Twine(ty.getIntOrFloatBitWidth()))
-        .str();
-  }
-
-  static std::string getReductionName(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type ty) {
-    std::string reductionName;
-
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-      reductionName = "add_reduction";
-      break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-      reductionName = "multiply_reduction";
-      break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-      return "and_reduction";
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-      return "eqv_reduction";
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-      return "or_reduction";
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
-      return "neqv_reduction";
-    default:
-      reductionName = "other_reduction";
-      break;
-    }
-
-    return getReductionName(reductionName, ty);
-  }
+static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+  return (llvm::Twine(name) +
+          (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
+          llvm::Twine(ty.getIntOrFloatBitWidth()))
+      .str();
+}
 
-  /// This function returns the identity value of the operator \p
-  /// reductionOpName. For example:
-  ///    0 + x = x,
-  ///    1 * x = x
-  static int getOperationIdentity(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Location loc) {
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
-      return 0;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-      return 1;
-    default:
-      TODO(loc, "Reduction of some intrinsic operators is not supported");
-    }
-  }
+static std::string getReductionName(
+    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+    mlir::Type ty) {
+  std::string reductionName;
 
-  static mlir::Value getIntrinsicProcInitValue(
-      mlir::Location loc, mlir::Type type,
-      const Fortran::parser::ProcedureDesignator &procDesignator,
-      fir::FirOpBuilder &builder) {
-    assert((fir::isa_integer(type) || fir::isa_real(type) ||
-            type.isa<fir::LogicalType>()) &&
-           "only integer, logical and real types are currently supported");
-    switch (getReductionType(procDesignator)) {
-    case IntrinsicProc::MAX: {
-      if (auto ty = type.dyn_cast<mlir::FloatType>()) {
-        const llvm::fltSemantics &sem = ty.getFloatSemantics();
-        return builder.createRealConstant(
-            loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
-      }
-      unsigned bits = type.getIntOrFloatBitWidth();
-      int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
-      return builder.createIntegerConstant(loc, type, minInt);
-    }
-    case IntrinsicProc::MIN: {
-      if (auto ty = type.dyn_cast<mlir::FloatType>()) {
-        const llvm::fltSemantics &sem = ty.getFloatSemantics();
-        return builder.createRealConstant(
-            loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
-      }
-      unsigned bits = type.getIntOrFloatBitWidth();
-      int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
-      return builder.createIntegerConstant(loc, type, maxInt);
-    }
-    case IntrinsicProc::IOR: {
-      unsigned bits = type.getIntOrFloatBitWidth();
-      int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
-      return builder.createIntegerConstant(loc, type, zeroInt);
-    }
-    case IntrinsicProc::IEOR: {
-      unsigned bits = type.getIntOrFloatBitWidth();
-      int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
-      return builder.createIntegerConstant(loc, type, zeroInt);
-    }
-    case IntrinsicProc::IAND: {
-      unsigned bits = type.getIntOrFloatBitWidth();
-      int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
-      return builder.createIntegerConstant(loc, type, allOnInt);
-    }
-    }
-    llvm_unreachable("Unknown Reduction Intrinsic");
+  switch (intrinsicOp) {
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    reductionName = "add_reduction";
+    break;
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    reductionName = "multiply_reduction";
+    break;
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+    return "and_reduction";
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+    return "eqv_reduction";
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+    return "or_reduction";
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+    return "neqv_reduction";
+  default:
+    reductionName = "other_reduction";
+    break;
   }
 
-  static mlir::Value getIntrinsicOpInitValue(
-      mlir::Location loc, mlir::Type type,
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      fir::FirOpBuilder &builder) {
+  return getReductionName(reductionName, ty);
+}
+
+/// This function returns the identity value of the operator \p reductionOpName.
+/// For example:
+///    0 + x = x,
+///    1 * x = x
+static int getOperationIdentity(llvm::StringRef reductionOpName,
+                                mlir::Location loc) {
+  if (reductionOpName.contains("add") || reductionOpName.contains("or") ||
+      reductionOpName.contains("neqv"))
+    return 0;
+  if (reductionOpName.contains("multiply") || reductionOpName.contains("and") ||
+      reductionOpName.contains("eqv"))
+    return 1;
+  TODO(loc, "Reduction of some intrinsic operators is not supported");
+}
+
+static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
+                                         llvm::StringRef reductionOpName,
+                                         fir::FirOpBuilder &builder) {
+  assert((fir::isa_integer(type) || fir::isa_real(type) ||
+          type.isa<fir::LogicalType>()) &&
+         "only integer, logical and real types are currently supported");
+  if (reductionOpName.contains("max")) {
+    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(
+          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+    }
+    unsigned bits = type.getIntOrFloatBitWidth();
+    int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
+    return builder.createIntegerConstant(loc, type, minInt);
+  } else if (reductionOpName.contains("min")) {
+    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(
+          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
+    }
+    unsigned bits = type.getIntOrFloatBitWidth();
+    int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+    return builder.createIntegerConstant(loc, type, maxInt);
+  } else if (reductionOpName.contains("ior")) {
+    unsigned bits = type.getIntOrFloatBitWidth();
+    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+    return builder.createIntegerConstant(loc, type, zeroInt);
+  } else if (reductionOpName.contains("ieor")) {
+    unsigned bits = type.getIntOrFloatBitWidth();
+    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+    return builder.createIntegerConstant(loc, type, zeroInt);
+  } else if (reductionOpName.contains("iand")) {
+    unsigned bits = type.getIntOrFloatBitWidth();
+    int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
+    return builder.createIntegerConstant(loc, type, allOnInt);
+  } else {
     if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
-          builder.getFloatAttr(type,
-                               (double)getOperationIdentity(intrinsicOp, loc)));
+          builder.getFloatAttr(
+              type, (double)getOperationIdentity(reductionOpName, loc)));
 
     if (type.isa<fir::LogicalType>()) {
       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
           loc, builder.getI1Type(),
           builder.getIntegerAttr(builder.getI1Type(),
-                                 getOperationIdentity(intrinsicOp, loc)));
+                                 getOperationIdentity(reductionOpName, loc)));
       return builder.createConvert(loc, type, intConst);
     }
 
     return builder.create<mlir::arith::ConstantOp>(
         loc, type,
-        builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
-  }
-
-  template <typename FloatOp, typename IntegerOp>
-  static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
-                                           mlir::Type type, mlir::Location loc,
-                                           mlir::Value op1, mlir::Value op2) {
-    assert(type.isIntOrIndexOrFloat() &&
-           "only integer and float types are currently supported");
-    if (type.isIntOrIndex())
-      return builder.create<IntegerOp>(loc, op1, op2);
-    return builder.create<FloatOp>(loc, op1, op2);
-  }
-
-  /// Creates an OpenMP reduction declaration and inserts it into the provided
-  /// symbol table. The declaration has a constant initializer with the neutral
-  /// value `initValue`, and the reduction combiner carried over from `reduce`.
-  /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      const Fortran::parser::ProcedureDesignator &procDesignator,
-      mlir::Type type, mlir::Location loc) {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    mlir::ModuleOp module = builder.getModule();
-
-    auto decl =
-        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-    if (decl)
-      return decl;
+        builder.getIntegerAttr(type,
+                               getOperationIdentity(reductionOpName, loc)));
+  }
+}
+
+template <typename FloatOp, typename IntegerOp>
+static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+                                         mlir::Type type, mlir::Location loc,
+                                         mlir::Value op1, mlir::Value op2) {
+  assert(type.isIntOrIndexOrFloat() &&
+         "only integer and float types are currently supported");
+  if (type.isIntOrIndex())
+    return builder.create<IntegerOp>(loc, op1, op2);
+  return builder.create<FloatOp>(loc, op1, op2);
+}
+
+static mlir::omp::ReductionDeclareOp
+createMinimalReductionDecl(fir::FirOpBuilder &builder,
+                           llvm::StringRef reductionOpName, mlir::Type type,
+                           mlir::Location loc) {
+  mlir::ModuleOp module = builder.getModule();
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+  mlir::omp::ReductionDeclareOp decl =
+      modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
+                                                       type);
+  builder.createBlock(&decl.getInitializerRegion(),
+                      decl.getInitializerRegion().end(), {type}, {loc});
+  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+  mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder);
+  builder.create<mlir::omp::YieldOp>(loc, init);
+
+  builder.createBlock(&decl.getReductionRegion(),
+                      decl.getReductionRegion().end(), {type, type},
+                      {loc, loc});
+
+  return decl;
+}
+
+/// Creates an OpenMP reduction declaration and inserts it into the provided
+/// symbol table. The declaration has a constant initializer with the neutral
+/// value `initValue`, and the reduction combiner carried over from `reduce`.
+/// TODO: Generalize this for non-integer types, add atomic region.
+static mlir::omp::ReductionDeclareOp
+createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+                    const Fortran::parser::ProcedureDesignator &procDesignator,
+                    mlir::Type type, mlir::Location loc) {
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::ModuleOp module = builder.getModule();
+
+  auto decl =
+      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+  if (decl)
+    return decl;
 
-    mlir::OpBuilder modBuilder(module.getBodyRegion());
+  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
+  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
 
-    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
-        loc, reductionOpName, type);
-    builder.createBlock(&decl.getInitializerRegion(),
-                        decl.getInitializerRegion().end(), {type}, {loc});
-    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-    mlir::Value init =
-        getIntrinsicProcInitValue(loc, type, procDesignator, builder);
-    builder.create<mlir::omp::YieldOp>(loc, init);
-
-    builder.createBlock(&decl.getReductionRegion(),
-                        decl.getReductionRegion().end(), {type, type},
-                        {loc, loc});
-
-    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
-    mlir::Value reductionOp;
-    switch (getReductionType(procDesignator)) {
-    case IntrinsicProc::MAX:
+  mlir::Value reductionOp;
+  if (const auto *name{
+          Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) {
+    if (name->source == "max") {
       reductionOp =
           getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
               builder, type, loc, op1, op2);
-      break;
-    case IntrinsicProc::MIN:
+    } else if (name->source == "min") {
       reductionOp =
           getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
               builder, type, loc, op1, op2);
-      break;
-    case IntrinsicProc::IOR:
+    } else if (name->source == "ior") {
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
-      break;
-    case IntrinsicProc::IEOR:
+    } else if (name->source == "ieor") {
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
-      break;
-    case IntrinsicProc::IAND:
+    } else if (name->source == "iand") {
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
-      break;
-    default:
-      llvm_unreachable(
-          "Reduction of some intrinsic operators is not supported");
+    } else {
+      TODO(loc, "Reduction of some intrinsic operators is not supported");
     }
-
-    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-    return decl;
   }
 
-  /// Creates an OpenMP reduction declaration and inserts it into the provided
-  /// symbol table. The declaration has a constant initializer with the neutral
-  /// value `initValue`, and the reduction combiner carried over from `reduce`.
-  /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type type, mlir::Location loc) {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    mlir::ModuleOp module = builder.getModule();
-
-    auto decl =
-        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-    if (decl)
-      return decl;
-
-    mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
-        loc, reductionOpName, type);
-    builder.createBlock(&decl.getInitializerRegion(),
-                        decl.getInitializerRegion().end(), {type}, {loc});
-    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-    mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
-    builder.create<mlir::omp::YieldOp>(loc, init);
+  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  return decl;
+}
 
-    builder.createBlock(&decl.getReductionRegion(),
-                        decl.getReductionRegion().end(), {type, type},
-                        {loc, loc});
+/// Creates an OpenMP reduction declaration and inserts it into the provided
+/// symbol table. The declaration has a constant initializer with the neutral
+/// value `initValue`, and the reduction combiner carried over from `reduce`.
+/// TODO: Generalize this for non-integer types, add atomic region.
+static mlir::omp::ReductionDeclareOp createReductionDecl(
+    fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+    mlir::Type type, mlir::Location loc) {
+  mlir::OpBuilder::InsertionGuard guard(builder);
+  mlir::ModuleOp module = builder.getModule();
 
-    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+  auto decl =
+      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+  if (decl)
+    return decl;
 
-    mlir::Value reductionOp;
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-      reductionOp =
-          getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
-              builder, type, loc, op1, op2);
-      break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-      reductionOp =
-          getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
-              builder, type, loc, op1, op2);
-      break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
-      mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
-      mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
+  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
 
-      mlir::Value andiOp =
-          builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
+  mlir::Value reductionOp;
+  switch (intrinsicOp) {
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    reductionOp =
+        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
+            builder, type, loc, op1, op2);
+    break;
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    reductionOp =
+        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
+            builder, type, loc, op1, op2);
+    break;
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
-      reductionOp = builder.createConvert(loc, type, andiOp);
-      break;
-    }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
-      mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
-      mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+    mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
 
-      mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
+    reductionOp = builder.createConvert(loc, type, andiOp);
+    break;
+  }
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
-      reductionOp = builder.createConvert(loc, type, oriOp);
-      break;
-    }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
-      mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
-      mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+    mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
 
-      mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
-          loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
+    reductionOp = builder.createConvert(loc, type, oriOp);
+    break;
+  }
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
-      reductionOp = builder.createConvert(loc, type, cmpiOp);
-      break;
-    }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
-      mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
-      mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
+    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
 
-      mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
-          loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
+    reductionOp = builder.createConvert(loc, type, cmpiOp);
+    break;
+  }
+  case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
+    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
-      reductionOp = builder.createConvert(loc, type, cmpiOp);
-      break;
-    }
-    default:
-      TODO(loc, "Reduction of some intrinsic operators is not supported");
-    }
+    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
 
-    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-    return decl;
+    reductionOp = builder.createConvert(loc, type, cmpiOp);
+    break;
   }
-
-  /// Creates a reduction declaration and associates it with an OpenMP block
-  /// directive.
-  static void addReductionDecl(
-      mlir::Location currentLocation,
-      Fortran::lower::AbstractConverter &converter,
-      const Fortran::parser::OmpReductionClause &reduction,
-      llvm::SmallVectorImpl<mlir::Value> &reductionVars,
-      llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
-    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-    mlir::omp::ReductionDeclareOp decl;
-    const auto &redOperator{
-        std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
-    const auto &objectList{
-        std::get<Fortran::parser::OmpObjectList>(reduction.t)};
-    if (const auto &redDefinedOp =
-            std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
-      const auto &intrinsicOp{
-          std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
-              redDefinedOp->u)};
-      switch (intrinsicOp) {
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
-        break;
-
-      default:
-        TODO(currentLocation,
-             "Reduction of some intrinsic operators is not supported");
-        break;
-      }
-      for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-        if (const auto *name{
-                Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-          if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-            mlir::Value symVal = converter.getSymbolAddress(*symbol);
-            if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
-              symVal = declOp.getBase();
-            mlir::Type redType =
-                symVal.getType().cast<fir::ReferenceType>().getEleTy();
-            reductionVars.push_back(symVal);
-            if (redType.isa<fir::LogicalType>())
-              decl = createReductionDecl(
-                  firOpBuilder,
-                  getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
-                  intrinsicOp, redType, currentLocation);
-            else if (redType.isIntOrIndexOrFloat()) {
-              decl = createReductionDecl(firOpBuilder,
-                                         getReductionName(intrinsicOp, redType),
-                                         intrinsicOp, redType, currentLocation);
-            } else {
-              TODO(currentLocation, "Reduction of some types is not supported");
-            }
-            reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
-                firOpBuilder.getContext(), decl.getSymName()));
-          }
-        }
-      }
-    } else if (const auto *reductionIntrinsic =
-                   std::get_if<Fortran::parser::ProcedureDesignator>(
-                       &redOperator.u)) {
-      if (ReductionProcessor::supportedIntrinsicProcReduction(
-              *reductionIntrinsic)) {
-        for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-          if (const auto *name{
-                  Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-            if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-              mlir::Value symVal = converter.getSymbolAddress(*symbol);
-              if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
-                symVal = declOp.getBase();
-              mlir::Type redType =
-                  symVal.getType().cast<fir::ReferenceType>().getEleTy();
-              reductionVars.push_back(symVal);
-              assert(redType.isIntOrIndexOrFloat() &&
-                     "Unsupported reduction type");
-              decl = createReductionDecl(
-                  firOpBuilder,
-                  getReductionName(getRealName(*reductionIntrinsic).ToString(),
-                                   redType),
-                  *reductionIntrinsic, redType, currentLocation);
-              reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
-                  firOpBuilder.getContext(), decl.getSymName()));
-            }
-          }
-        }
-      }
-    }
+  default:
+    TODO(loc, "Reduction of some intrinsic operators is not supported");
   }
-};
+
+  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  return decl;
+}
 
 static mlir::omp::ScheduleModifier
 translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) {
@@ -1302,6 +1137,101 @@ static mlir::Value getIfClauseOperand(
                                     ifVal);
 }
 
+/// Creates a reduction declaration and associates it with an OpenMP block
+/// directive.
+static void
+addReductionDecl(mlir::Location currentLocation,
+                 Fortran::lower::AbstractConverter &converter,
+                 const Fortran::parser::OmpReductionClause &reduction,
+                 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+                 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::omp::ReductionDeclareOp decl;
+  const auto &redOperator{
+      std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
+  const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+  if (const auto &redDefinedOp =
+          std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+    const auto &intrinsicOp{
+        std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+            redDefinedOp->u)};
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      break;
+
+    default:
+      TODO(currentLocation,
+           "Reduction of some intrinsic operators is not supported");
+      break;
+    }
+    for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+      if (const auto *name{
+              Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+        if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+          mlir::Value symVal = converter.getSymbolAddress(*symbol);
+          if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+            symVal = declOp.getBase();
+          mlir::Type redType =
+              symVal.getType().cast<fir::ReferenceType>().getEleTy();
+          reductionVars.push_back(symVal);
+          if (redType.isa<fir::LogicalType>())
+            decl = createReductionDecl(
+                firOpBuilder,
+                getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
+                intrinsicOp, redType, currentLocation);
+          else if (redType.isIntOrIndexOrFloat()) {
+            decl = createReductionDecl(firOpBuilder,
+                                       getReductionName(intrinsicOp, redType),
+                                       intrinsicOp, redType, currentLocation);
+          } else {
+            TODO(currentLocation, "Reduction of some types is not supported");
+          }
+          reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+              firOpBuilder.getContext(), decl.getSymName()));
+        }
+      }
+    }
+  } else if (const auto *reductionIntrinsic =
+                 std::get_if<Fortran::parser::ProcedureDesignator>(
+                     &redOperator.u)) {
+    if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+            reductionIntrinsic)}) {
+      if ((name->source != "max") && (name->source != "min") &&
+          (name->source != "ior") && (name->source != "ieor") &&
+          (name->source != "iand")) {
+        TODO(currentLocation,
+             "Reduction of intrinsic procedures is not supported");
+      }
+      std::string intrinsicOp = name->ToString();
+      for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+        if (const auto *name{
+                Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
+          if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+            mlir::Value symVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+              symVal = declOp.getBase();
+            mlir::Type redType =
+                symVal.getType().cast<fir::ReferenceType>().getEleTy();
+            reductionVars.push_back(symVal);
+            assert(redType.isIntOrIndexOrFloat() &&
+                   "Unsupported reduction type");
+            decl = createReductionDecl(
+                firOpBuilder, getReductionName(intrinsicOp, redType),
+                *reductionIntrinsic, redType, currentLocation);
+            reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+                firOpBuilder.getContext(), decl.getSymName()));
+          }
+        }
+      }
+    }
+  }
+}
+
 static void
 addUseDeviceClause(Fortran::lower::AbstractConverter &converter,
                    const Fortran::parser::OmpObjectList &useDeviceClause,
@@ -1898,9 +1828,8 @@ bool ClauseProcessor::processReduction(
   return findRepeatableClause<ClauseTy::Reduction>(
       [&](const ClauseTy::Reduction *reductionClause,
           const Fortran::parser::CharBlock &) {
-        ReductionProcessor rp;
-        rp.addReductionDecl(currentLocation, converter, reductionClause->v,
-                            reductionVars, reductionDeclSymbols);
+        addReductionDecl(currentLocation, converter, reductionClause->v,
+                         reductionVars, reductionDeclSymbols);
       });
 }
 
@@ -3736,50 +3665,48 @@ void Fortran::lower::genOpenMPReduction(
       } else if (const auto *reductionIntrinsic =
                      std::get_if<Fortran::parser::ProcedureDesignator>(
                          &redOperator.u)) {
-        if (!ReductionProcessor::supportedIntrinsicProcReduction(
-                *reductionIntrinsic))
-          continue;
-        ReductionProcessor::IntrinsicProc redIntrinsicProc =
-            ReductionProcessor::getReductionType(*reductionIntrinsic);
-        for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
-          if (const auto *name{
-                  Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
-            if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
-              mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
-              if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
-                reductionVal = declOp.getBase();
-              for (const mlir::OpOperand &reductionValUse :
-                   reductionVal.getUses()) {
-                if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
-                        reductionValUse.getOwner())) {
-                  mlir::Value loadVal = loadOp.getRes();
-                  // Max is lowered as a compare -> select.
-                  // Match the pattern here.
-                  mlir::Operation *reductionOp =
-                      findReductionChain(loadVal, &reductionVal);
-                  if (reductionOp == nullptr)
-                    continue;
-
-                  if (redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::MAX ||
-                      redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::MIN) {
-                    assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
-                           "Selection Op not found in reduction intrinsic");
-                    mlir::Operation *compareOp =
-                        getCompareFromReductionOp(reductionOp, loadVal);
-                    updateReduction(compareOp, firOpBuilder, loadVal,
-                                    reductionVal);
-                  }
-                  if (redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::IOR ||
-                      redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::IEOR ||
-                      redIntrinsicProc ==
-                          ReductionProcessor::IntrinsicProc::IAND) {
+        if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+                reductionIntrinsic)}) {
+          std::string redName = name->ToString();
+          if ((name->source != "max") && (name->source != "min") &&
+              (name->source != "ior") && (name->source != "ieor") &&
+              (name->source != "iand")) {
+            continue;
+          }
+          for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
+            if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+                    ompObject)}) {
+              if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+                mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+                if (auto declOp =
+                        reductionVal.getDefiningOp<hlfir::DeclareOp>())
+                  reductionVal = declOp.getBase();
+                for (const mlir::OpOperand &reductionValUse :
+                     reductionVal.getUses()) {
+                  if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
+                          reductionValUse.getOwner())) {
+                    mlir::Value loadVal = loadOp.getRes();
+                    // Max is lowered as a compare -> select.
+                    // Match the pattern here.
+                    mlir::Operation *reductionOp =
+                        findReductionChain(loadVal, &reductionVal);
+                    if (reductionOp == nullptr)
+                      continue;
+
+                    if (redName == "max" || redName == "min") {
+                      assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+                             "Selection Op not found in reduction intrinsic");
+                      mlir::Operation *compareOp =
+                          getCompareFromReductionOp(reductionOp, loadVal);
+                      updateReduction(compareOp, firOpBuilder, loadVal,
+                                      reductionVal);
+                    }
+                    if (redName == "ior" || redName == "ieor" ||
+                        redName == "iand") {
 
-                    updateReduction(reductionOp, firOpBuilder, loadVal,
-                                    reductionVal);
+                      updateReduction(reductionOp, firOpBuilder, loadVal,
+                                      reductionVal);
+                    }
                   }
                 }
               }



More information about the flang-commits mailing list