[flang-commits] [flang] [Flang][OpenMP] Consider renames when processing reduction intrinsics (PR #70822)

via flang-commits flang-commits at lists.llvm.org
Tue Oct 31 09:21:39 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Kiran Chandramohan (kiranchandramohan)

<details>
<summary>Changes</summary>

Fixes #68654

Depends on https://github.com/llvm/llvm-project/pull/70790

---

The patch is 43.86 KiB and has been truncated to 20.00 KiB below. Full version: https://github.com/llvm/llvm-project/pull/70822.diff


4 Files Affected:

- (modified) flang/lib/Lower/OpenMP.cpp (+434-358) 
- (modified) flang/lib/Semantics/check-omp-structure.cpp (+11-9) 
- (modified) flang/lib/Semantics/resolve-directives.cpp (+15-8) 
- (added) flang/test/Lower/OpenMP/wsloop-reduction-max-2.f90 (+19) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 0faaae6c08e0476..d3db230ae958d9f 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -667,276 +667,444 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
       TODO(location, "OMPD_target_data MapOperand BoxType");
 }
 
-static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
-  return (llvm::Twine(name) +
-          (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-          llvm::Twine(ty.getIntOrFloatBitWidth()))
-      .str();
-}
+class ReductionProcessor {
+public:
+  enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
+  static IntrinsicProc
+  getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+                       getRealName(pd).ToString())
+                       .Case("max", IntrinsicProc::MAX)
+                       .Case("min", IntrinsicProc::MIN)
+                       .Case("iand", IntrinsicProc::IAND)
+                       .Case("ior", IntrinsicProc::IOR)
+                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Default(std::nullopt);
+    assert(redType && "Invalid Reduction");
+    return *redType;
+  }
+
+  static bool supportedIntrinsicProcReduction(
+      const Fortran::parser::ProcedureDesignator &pd) {
+    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+    assert(name && "Invalid Reduction Intrinsic.");
+    if (!name->symbol->GetUltimate().attrs().test(
+            Fortran::semantics::Attr::INTRINSIC))
+      return false;
+    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+                       getRealName(name).ToString())
+                       .Case("max", IntrinsicProc::MAX)
+                       .Case("min", IntrinsicProc::MIN)
+                       .Case("iand", IntrinsicProc::IAND)
+                       .Case("ior", IntrinsicProc::IOR)
+                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Default(std::nullopt);
+    if (redType)
+      return true;
+    return false;
+  }
+
+  static const Fortran::semantics::SourceName
+  getRealName(const Fortran::parser::Name *name) {
+    return name->symbol->GetUltimate().name();
+  }
 
-static std::string getReductionName(
-    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type ty) {
-  std::string reductionName;
+  static const Fortran::semantics::SourceName
+  getRealName(const Fortran::parser::ProcedureDesignator &pd) {
+    const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
+    assert(name && "Invalid Reduction Intrinsic.");
+    return getRealName(name);
+  }
 
-  switch (intrinsicOp) {
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    reductionName = "add_reduction";
-    break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    reductionName = "multiply_reduction";
-    break;
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-    return "and_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-    return "eqv_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-    return "or_reduction";
-  case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
-    return "neqv_reduction";
-  default:
-    reductionName = "other_reduction";
-    break;
+  static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+    return (llvm::Twine(name) +
+            (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
+            llvm::Twine(ty.getIntOrFloatBitWidth()))
+        .str();
+  }
+
+  static std::string getReductionName(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Type ty) {
+    std::string reductionName;
+
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+      reductionName = "add_reduction";
+      break;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+      reductionName = "multiply_reduction";
+      break;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+      return "and_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return "eqv_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+      return "or_reduction";
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return "neqv_reduction";
+    default:
+      reductionName = "other_reduction";
+      break;
+    }
+
+    return getReductionName(reductionName, ty);
+  }
+
+  /// This function returns the identity value of the operator \p
+  /// reductionOpName. For example:
+  ///    0 + x = x,
+  ///    1 * x = x
+  static int getOperationIdentity(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Location loc) {
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return 0;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return 1;
+    default:
+      TODO(loc, "Reduction of some intrinsic operators is not supported");
+    }
   }
 
-  return getReductionName(reductionName, ty);
-}
-
-/// This function returns the identity value of the operator \p reductionOpName.
-/// For example:
-///    0 + x = x,
-///    1 * x = x
-static int getOperationIdentity(llvm::StringRef reductionOpName,
-                                mlir::Location loc) {
-  if (reductionOpName.contains("add") || reductionOpName.contains("or") ||
-      reductionOpName.contains("neqv"))
-    return 0;
-  if (reductionOpName.contains("multiply") || reductionOpName.contains("and") ||
-      reductionOpName.contains("eqv"))
-    return 1;
-  TODO(loc, "Reduction of some intrinsic operators is not supported");
-}
-
-static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
-                                         llvm::StringRef reductionOpName,
-                                         fir::FirOpBuilder &builder) {
-  assert((fir::isa_integer(type) || fir::isa_real(type) ||
-          type.isa<fir::LogicalType>()) &&
-         "only integer, logical and real types are currently supported");
-  if (reductionOpName.contains("max")) {
-    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
-      const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(
-          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+  static mlir::Value getIntrinsicProcInitValue(
+      mlir::Location loc, mlir::Type type,
+      const Fortran::parser::ProcedureDesignator &procDesignator,
+      fir::FirOpBuilder &builder) {
+    assert((fir::isa_integer(type) || fir::isa_real(type) ||
+            type.isa<fir::LogicalType>()) &&
+           "only integer, logical and real types are currently supported");
+    switch (getReductionType(procDesignator)) {
+    case IntrinsicProc::MAX: {
+      if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+        const llvm::fltSemantics &sem = ty.getFloatSemantics();
+        return builder.createRealConstant(
+            loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+      }
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, minInt);
+    }
+    case IntrinsicProc::MIN: {
+      if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+        const llvm::fltSemantics &sem = ty.getFloatSemantics();
+        return builder.createRealConstant(
+            loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true));
+      }
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, maxInt);
     }
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, minInt);
-  } else if (reductionOpName.contains("min")) {
-    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
-      const llvm::fltSemantics &sem = ty.getFloatSemantics();
-      return builder.createRealConstant(
-          loc, type, llvm::APFloat::getSmallest(sem, /*Negative=*/true));
+    case IntrinsicProc::IOR: {
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, maxInt);
-  } else if (reductionOpName.contains("ior")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, zeroInt);
-  } else if (reductionOpName.contains("ieor")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, zeroInt);
-  } else if (reductionOpName.contains("iand")) {
-    unsigned bits = type.getIntOrFloatBitWidth();
-    int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
-    return builder.createIntegerConstant(loc, type, allOnInt);
-  } else {
+    case IntrinsicProc::IEOR: {
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, zeroInt);
+    }
+    case IntrinsicProc::IAND: {
+      unsigned bits = type.getIntOrFloatBitWidth();
+      int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
+      return builder.createIntegerConstant(loc, type, allOnInt);
+    }
+    }
+    llvm_unreachable("Unknown Reduction Intrinsic");
+  }
+
+  static mlir::Value getIntrinsicOpInitValue(
+      mlir::Location loc, mlir::Type type,
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      fir::FirOpBuilder &builder) {
     if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
-          builder.getFloatAttr(
-              type, (double)getOperationIdentity(reductionOpName, loc)));
+          builder.getFloatAttr(type,
+                               (double)getOperationIdentity(intrinsicOp, loc)));
 
     if (type.isa<fir::LogicalType>()) {
       mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
           loc, builder.getI1Type(),
           builder.getIntegerAttr(builder.getI1Type(),
-                                 getOperationIdentity(reductionOpName, loc)));
+                                 getOperationIdentity(intrinsicOp, loc)));
       return builder.createConvert(loc, type, intConst);
     }
 
     return builder.create<mlir::arith::ConstantOp>(
         loc, type,
-        builder.getIntegerAttr(type,
-                               getOperationIdentity(reductionOpName, loc)));
-  }
-}
-
-template <typename FloatOp, typename IntegerOp>
-static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
-                                         mlir::Type type, mlir::Location loc,
-                                         mlir::Value op1, mlir::Value op2) {
-  assert(type.isIntOrIndexOrFloat() &&
-         "only integer and float types are currently supported");
-  if (type.isIntOrIndex())
-    return builder.create<IntegerOp>(loc, op1, op2);
-  return builder.create<FloatOp>(loc, op1, op2);
-}
-
-static mlir::omp::ReductionDeclareOp
-createMinimalReductionDecl(fir::FirOpBuilder &builder,
-                           llvm::StringRef reductionOpName, mlir::Type type,
-                           mlir::Location loc) {
-  mlir::ModuleOp module = builder.getModule();
-  mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-  mlir::omp::ReductionDeclareOp decl =
-      modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName,
-                                                       type);
-  builder.createBlock(&decl.getInitializerRegion(),
-                      decl.getInitializerRegion().end(), {type}, {loc});
-  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-  mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder);
-  builder.create<mlir::omp::YieldOp>(loc, init);
-
-  builder.createBlock(&decl.getReductionRegion(),
-                      decl.getReductionRegion().end(), {type, type},
-                      {loc, loc});
-
-  return decl;
-}
-
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp
-createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-                    const Fortran::parser::ProcedureDesignator &procDesignator,
-                    mlir::Type type, mlir::Location loc) {
-  mlir::OpBuilder::InsertionGuard guard(builder);
-  mlir::ModuleOp module = builder.getModule();
-
-  auto decl =
-      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-  if (decl)
-    return decl;
+        builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+  }
+
+  template <typename FloatOp, typename IntegerOp>
+  static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
+                                           mlir::Type type, mlir::Location loc,
+                                           mlir::Value op1, mlir::Value op2) {
+    assert(type.isIntOrIndexOrFloat() &&
+           "only integer and float types are currently supported");
+    if (type.isIntOrIndex())
+      return builder.create<IntegerOp>(loc, op1, op2);
+    return builder.create<FloatOp>(loc, op1, op2);
+  }
+
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The declaration has a constant initializer with the neutral
+  /// value `initValue`, and the reduction combiner carried over from `reduce`.
+  /// TODO: Generalize this for non-integer types, add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      const Fortran::parser::ProcedureDesignator &procDesignator,
+      mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
+
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
 
-  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
-  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
 
-  mlir::Value reductionOp;
-  if (const auto *name{
-          Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) {
-    if (name->source == "max") {
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+    mlir::Value init =
+        getIntrinsicProcInitValue(loc, type, procDesignator, builder);
+    builder.create<mlir::omp::YieldOp>(loc, init);
+
+    builder.createBlock(&decl.getReductionRegion(),
+                        decl.getReductionRegion().end(), {type, type},
+                        {loc, loc});
+
+    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+    mlir::Value reductionOp;
+    switch (getReductionType(procDesignator)) {
+    case IntrinsicProc::MAX:
       reductionOp =
           getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
               builder, type, loc, op1, op2);
-    } else if (name->source == "min") {
+      break;
+    case IntrinsicProc::MIN:
       reductionOp =
           getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
               builder, type, loc, op1, op2);
-    } else if (name->source == "ior") {
+      break;
+    case IntrinsicProc::IOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
-    } else if (name->source == "ieor") {
+      break;
+    case IntrinsicProc::IEOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
-    } else if (name->source == "iand") {
+      break;
+    case IntrinsicProc::IAND:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
-    } else {
-      TODO(loc, "Reduction of some intrinsic operators is not supported");
+      break;
+    default:
+      llvm_unreachable(
+          "Reduction of some intrinsic operators is not supported");
     }
+
+    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+    return decl;
   }
 
-  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-  return decl;
-}
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The declaration has a constant initializer with the neutral
+  /// value `initValue`, and the reduction combiner carried over from `reduce`.
+  /// TODO: Generalize this for non-integer types, add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
+      mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
 
-/// Creates an OpenMP reduction declaration and inserts it into the provided
-/// symbol table. The declaration has a constant initializer with the neutral
-/// value `initValue`, and the reduction combiner carried over from `reduce`.
-/// TODO: Generalize this for non-integer types, add atomic region.
-static mlir::omp::ReductionDeclareOp createReductionDecl(
-    fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-    Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-    mlir::Type type, mlir::Location loc) {
-  mlir::OpBuilder::InsertionGuard guard(builder);
-  mlir::ModuleOp module = builder.getModule();
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
 
-  auto decl =
-      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-  if (decl)
-    return decl;
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
 
-  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
-  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRe...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/70822


More information about the flang-commits mailing list