[flang-commits] [flang] [Flang][OpenMP] NFC: Refactor reduction code (PR #79876)

via flang-commits flang-commits at lists.llvm.org
Mon Jan 29 10:15:34 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Kiran Chandramohan (kiranchandramohan)

<details>
<summary>Changes</summary>

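Introduces a new enumeration, `ReductionIdentifier`, that lists all Fortran reduction identifiers and is used to unify the previously separate code paths for intrinsic-procedure and intrinsic-operator reductions. Moves the combiner code generation into a separate function (`createScalarCombiner`) for possible future reuse in an array context.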

---

The patch is 23.40 KiB and has been truncated to 20.00 KiB below. Full version: https://github.com/llvm/llvm-project/pull/79876.diff


1 File Affected:

- (modified) flang/lib/Lower/OpenMP.cpp (+170-167) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 7dd25f75d9eb76f..52d222f3d601f6a 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -726,21 +726,59 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
 
 class ReductionProcessor {
 public:
-  enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
-  static IntrinsicProc
+  // TODO: Move this enumeration to the OpenMP dialect
+  enum ReductionIdentifier {
+    ID,
+    USER_DEF_OP,
+    ADD,
+    SUBTRACT,
+    MULTIPLY,
+    AND,
+    OR,
+    EQV,
+    NEQV,
+    MAX,
+    MIN,
+    IAND,
+    IOR,
+    IEOR
+  };
+  static ReductionIdentifier
   getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
-    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+    auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
                        getRealName(pd).ToString())
-                       .Case("max", IntrinsicProc::MAX)
-                       .Case("min", IntrinsicProc::MIN)
-                       .Case("iand", IntrinsicProc::IAND)
-                       .Case("ior", IntrinsicProc::IOR)
-                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Case("max", ReductionIdentifier::MAX)
+                       .Case("min", ReductionIdentifier::MIN)
+                       .Case("iand", ReductionIdentifier::IAND)
+                       .Case("ior", ReductionIdentifier::IOR)
+                       .Case("ieor", ReductionIdentifier::IEOR)
                        .Default(std::nullopt);
     assert(redType && "Invalid Reduction");
     return *redType;
   }
 
+  static ReductionIdentifier getReductionType(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+      return ReductionIdentifier::ADD;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+      return ReductionIdentifier::SUBTRACT;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+      return ReductionIdentifier::MULTIPLY;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+      return ReductionIdentifier::AND;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return ReductionIdentifier::EQV;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+      return ReductionIdentifier::OR;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return ReductionIdentifier::NEQV;
+    default:
+      llvm_unreachable("unexpected intrinsic operator in reduction");
+    }
+  }
+
   static bool supportedIntrinsicProcReduction(
       const Fortran::parser::ProcedureDesignator &pd) {
     const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
@@ -748,13 +786,13 @@ class ReductionProcessor {
     if (!name->symbol->GetUltimate().attrs().test(
             Fortran::semantics::Attr::INTRINSIC))
       return false;
-    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+    auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
                        getRealName(name).ToString())
-                       .Case("max", IntrinsicProc::MAX)
-                       .Case("min", IntrinsicProc::MIN)
-                       .Case("iand", IntrinsicProc::IAND)
-                       .Case("ior", IntrinsicProc::IOR)
-                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Case("max", ReductionIdentifier::MAX)
+                       .Case("min", ReductionIdentifier::MIN)
+                       .Case("iand", ReductionIdentifier::IAND)
+                       .Case("ior", ReductionIdentifier::IOR)
+                       .Case("ieor", ReductionIdentifier::IEOR)
                        .Default(std::nullopt);
     if (redType)
       return true;
@@ -812,32 +850,30 @@ class ReductionProcessor {
   /// reductionOpName. For example:
   ///    0 + x = x,
   ///    1 * x = x
-  static int getOperationIdentity(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Location loc) {
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+  static int getOperationIdentity(ReductionIdentifier redId,
+                                  mlir::Location loc) {
+    switch (redId) {
+    case ReductionIdentifier::ADD:
+    case ReductionIdentifier::OR:
+    case ReductionIdentifier::NEQV:
       return 0;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+    case ReductionIdentifier::MULTIPLY:
+    case ReductionIdentifier::AND:
+    case ReductionIdentifier::EQV:
       return 1;
     default:
       TODO(loc, "Reduction of some intrinsic operators is not supported");
     }
   }
 
-  static mlir::Value getIntrinsicProcInitValue(
-      mlir::Location loc, mlir::Type type,
-      const Fortran::parser::ProcedureDesignator &procDesignator,
-      fir::FirOpBuilder &builder) {
+  static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
+                                           ReductionIdentifier redId,
+                                           fir::FirOpBuilder &builder) {
     assert((fir::isa_integer(type) || fir::isa_real(type) ||
             type.isa<fir::LogicalType>()) &&
            "only integer, logical and real types are currently supported");
-    switch (getReductionType(procDesignator)) {
-    case IntrinsicProc::MAX: {
+    switch (redId) {
+    case ReductionIdentifier::MAX: {
       if (auto ty = type.dyn_cast<mlir::FloatType>()) {
         const llvm::fltSemantics &sem = ty.getFloatSemantics();
         return builder.createRealConstant(
@@ -847,7 +883,7 @@ class ReductionProcessor {
       int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, minInt);
     }
-    case IntrinsicProc::MIN: {
+    case ReductionIdentifier::MIN: {
       if (auto ty = type.dyn_cast<mlir::FloatType>()) {
         const llvm::fltSemantics &sem = ty.getFloatSemantics();
         return builder.createRealConstant(
@@ -857,46 +893,50 @@ class ReductionProcessor {
       int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, maxInt);
     }
-    case IntrinsicProc::IOR: {
+    case ReductionIdentifier::IOR: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    case IntrinsicProc::IEOR: {
+    case ReductionIdentifier::IEOR: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    case IntrinsicProc::IAND: {
+    case ReductionIdentifier::IAND: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, allOnInt);
     }
-    }
-    llvm_unreachable("Unknown Reduction Intrinsic");
-  }
+    case ReductionIdentifier::ADD:
+    case ReductionIdentifier::MULTIPLY:
+    case ReductionIdentifier::AND:
+    case ReductionIdentifier::OR:
+    case ReductionIdentifier::EQV:
+    case ReductionIdentifier::NEQV:
+      if (type.isa<mlir::FloatType>())
+        return builder.create<mlir::arith::ConstantOp>(
+            loc, type,
+            builder.getFloatAttr(type,
+                                 (double)getOperationIdentity(redId, loc)));
+
+      if (type.isa<fir::LogicalType>()) {
+        mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
+            loc, builder.getI1Type(),
+            builder.getIntegerAttr(builder.getI1Type(),
+                                   getOperationIdentity(redId, loc)));
+        return builder.createConvert(loc, type, intConst);
+      }
 
-  static mlir::Value getIntrinsicOpInitValue(
-      mlir::Location loc, mlir::Type type,
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      fir::FirOpBuilder &builder) {
-    if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
-          builder.getFloatAttr(type,
-                               (double)getOperationIdentity(intrinsicOp, loc)));
-
-    if (type.isa<fir::LogicalType>()) {
-      mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
-          loc, builder.getI1Type(),
-          builder.getIntegerAttr(builder.getI1Type(),
-                                 getOperationIdentity(intrinsicOp, loc)));
-      return builder.createConvert(loc, type, intConst);
+          builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
+    case ReductionIdentifier::ID:
+    case ReductionIdentifier::USER_DEF_OP:
+    case ReductionIdentifier::SUBTRACT:
+      TODO(loc, "Reduction of some identifier types is not supported");
     }
-
-    return builder.create<mlir::arith::ConstantOp>(
-        loc, type,
-        builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+    llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
   }
 
   template <typename FloatOp, typename IntegerOp>
@@ -910,118 +950,46 @@ class ReductionProcessor {
     return builder.create<FloatOp>(loc, op1, op2);
   }
 
-  /// Creates an OpenMP reduction declaration and inserts it into the provided
-  /// symbol table. The declaration has a constant initializer with the neutral
-  /// value `initValue`, and the reduction combiner carried over from `reduce`.
-  /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      const Fortran::parser::ProcedureDesignator &procDesignator,
-      mlir::Type type, mlir::Location loc) {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    mlir::ModuleOp module = builder.getModule();
-
-    auto decl =
-        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-    if (decl)
-      return decl;
-
-    mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
-        loc, reductionOpName, type);
-    builder.createBlock(&decl.getInitializerRegion(),
-                        decl.getInitializerRegion().end(), {type}, {loc});
-    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-    mlir::Value init =
-        getIntrinsicProcInitValue(loc, type, procDesignator, builder);
-    builder.create<mlir::omp::YieldOp>(loc, init);
-
-    builder.createBlock(&decl.getReductionRegion(),
-                        decl.getReductionRegion().end(), {type, type},
-                        {loc, loc});
-
-    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
+  static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
+                                          mlir::Location loc,
+                                          ReductionIdentifier redId,
+                                          mlir::Type type, mlir::Value op1,
+                                          mlir::Value op2) {
     mlir::Value reductionOp;
-    switch (getReductionType(procDesignator)) {
-    case IntrinsicProc::MAX:
+    switch (redId) {
+    case ReductionIdentifier::MAX:
       reductionOp =
           getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
               builder, type, loc, op1, op2);
       break;
-    case IntrinsicProc::MIN:
+    case ReductionIdentifier::MIN:
       reductionOp =
           getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
               builder, type, loc, op1, op2);
       break;
-    case IntrinsicProc::IOR:
+    case ReductionIdentifier::IOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
       break;
-    case IntrinsicProc::IEOR:
+    case ReductionIdentifier::IEOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
       break;
-    case IntrinsicProc::IAND:
+    case ReductionIdentifier::IAND:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
       break;
-    }
-
-    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-    return decl;
-  }
-
-  /// Creates an OpenMP reduction declaration and inserts it into the provided
-  /// symbol table. The declaration has a constant initializer with the neutral
-  /// value `initValue`, and the reduction combiner carried over from `reduce`.
-  /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type type, mlir::Location loc) {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    mlir::ModuleOp module = builder.getModule();
-
-    auto decl =
-        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-    if (decl)
-      return decl;
-
-    mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
-        loc, reductionOpName, type);
-    builder.createBlock(&decl.getInitializerRegion(),
-                        decl.getInitializerRegion().end(), {type}, {loc});
-    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-    mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
-    builder.create<mlir::omp::YieldOp>(loc, init);
-
-    builder.createBlock(&decl.getReductionRegion(),
-                        decl.getReductionRegion().end(), {type, type},
-                        {loc, loc});
-
-    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
-    mlir::Value reductionOp;
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case ReductionIdentifier::ADD:
       reductionOp =
           getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
               builder, type, loc, op1, op2);
       break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case ReductionIdentifier::MULTIPLY:
       reductionOp =
           getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
               builder, type, loc, op1, op2);
       break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+    case ReductionIdentifier::AND: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1031,7 +999,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, andiOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+    case ReductionIdentifier::OR: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1040,7 +1008,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, oriOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+    case ReductionIdentifier::EQV: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1050,7 +1018,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, cmpiOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+    case ReductionIdentifier::NEQV: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1064,7 +1032,46 @@ class ReductionProcessor {
       TODO(loc, "Reduction of some intrinsic operators is not supported");
     }
 
+    return reductionOp;
+  }
+
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The declaration has a constant initializer with the neutral
+  /// value `initValue`, and the reduction combiner carried over from `reduce`.
+  /// TODO: Generalize this for non-integer types, add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
+
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
+
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+    mlir::Value init = getReductionInitValue(loc, type, redId, builder);
+    builder.create<mlir::omp::YieldOp>(loc, init);
+
+    builder.createBlock(&decl.getReductionRegion(),
+                        decl.getReductionRegion().end(), {type, type},
+                        {loc, loc});
+
+    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+    mlir::Value reductionOp =
+        createScalarCombiner(builder, loc, redId, type, op1, op2);
     builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+
     return decl;
   }
 
@@ -1087,15 +1094,15 @@ class ReductionProcessor {
       const auto &intrinsicOp{
           std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
               redDefinedOp->u)};
-      switch (intrinsicOp) {
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      ReductionIdentifier redId = getReductionType(intrinsicOp);
+      switch (redId) {
+      case ReductionIdentifier::ADD:
+      case ReductionIdentifier::MULTIPLY:
+      case ReductionIdentifier::AND:
+      case ReductionIdentifier::EQV:
+      case ReductionIdentifier::OR:
+      case ReductionIdentifier::NEQV:
         break;
-
       default:
         TODO(currentLocation,
              "Reduction of some intrinsic operators is not supported");
@@ -1115,11 +1122,11 @@ class ReductionProcessor {
               decl = createReductionDecl(
                   firOpBuilder,
                   getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
-                  intrinsicOp, redType, currentLocation);
+                  redId, redType, currentLocation);
             else if (redType.isIntOrIndexOrFloat()) {
               decl = createReductionDecl(firOpBuilder,
                                ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/79876


More information about the flang-commits mailing list