[flang-commits] [flang] f51bdae - [Flang][OpenMP] Add support for OpenMP max reduction

Kiran Chandramohan via flang-commits flang-commits at lists.llvm.org
Tue Mar 14 14:49:03 PDT 2023


Author: Kiran Chandramohan
Date: 2023-03-14T21:38:08Z
New Revision: f51bdae4e3d603de81b7efab82889677dee276fa

URL: https://github.com/llvm/llvm-project/commit/f51bdae4e3d603de81b7efab82889677dee276fa
DIFF: https://github.com/llvm/llvm-project/commit/f51bdae4e3d603de81b7efab82889677dee276fa.diff

LOG: [Flang][OpenMP] Add support for OpenMP max reduction

This patch adds support for reductions using the max intrinsic on scalar
types. In Flang's default lowering flow, max is lowered to a compare
followed by a select. This pattern is matched and replaced with the
OpenMP dialect reduction operation.

Note: This is a temporary flow. The plan is to move to a flow where
the OpenMP reduction operation is inserted during lowering.

Reviewed By: do

Differential Revision: https://reviews.llvm.org/D145083

Added: 
    flang/test/Lower/OpenMP/wsloop-reduction-max.f90

Modified: 
    flang/lib/Lower/OpenMP.cpp

Removed: 
    flang/test/Lower/OpenMP/Todo/reduction-max.f90


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 1f6449eca301b..22e5ff322f6fb 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1121,22 +1121,43 @@ static int getOperationIdentity(llvm::StringRef reductionOpName,
 static Value getReductionInitValue(mlir::Location loc, mlir::Type type,
                                    llvm::StringRef reductionOpName,
                                    fir::FirOpBuilder &builder) {
-  if (type.isa<FloatType>())
+  assert((fir::isa_integer(type) || fir::isa_real(type) ||
+          type.isa<fir::LogicalType>()) &&
+         "only integer, logical and real types are currently supported");
+  if (reductionOpName.contains("max")) {
+    // The identity of a max reduction is the most negative finite value
+    // representable in the type.
+    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(
+          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+    }
+    // Build the minimum as an APInt rather than an int64_t so that integer
+    // types wider than 64 bits (such as i128) do not overflow.
+    unsigned bits = type.getIntOrFloatBitWidth();
+    llvm::APInt minInt = llvm::APInt::getSignedMinValue(bits);
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, type, builder.getIntegerAttr(type, minInt));
+  } else {
+    if (type.isa<FloatType>())
+      return builder.create<mlir::arith::ConstantOp>(
+          loc, type,
+          builder.getFloatAttr(
+              type, (double)getOperationIdentity(reductionOpName, loc)));
+
+    if (type.isa<fir::LogicalType>()) {
+      Value intConst = builder.create<mlir::arith::ConstantOp>(
+          loc, builder.getI1Type(),
+          builder.getIntegerAttr(builder.getI1Type(),
+                                 getOperationIdentity(reductionOpName, loc)));
+      return builder.createConvert(loc, type, intConst);
+    }
+
     return builder.create<mlir::arith::ConstantOp>(
         loc, type,
-        builder.getFloatAttr(
-            type, (double)getOperationIdentity(reductionOpName, loc)));
-
-  if (type.isa<fir::LogicalType>()) {
-    Value intConst = builder.create<mlir::arith::ConstantOp>(
-        loc, builder.getI1Type(),
-        builder.getIntegerAttr(builder.getI1Type(),
+        builder.getIntegerAttr(type,
                                getOperationIdentity(reductionOpName, loc)));
-    return builder.createConvert(loc, type, intConst);
   }
-  return builder.create<mlir::arith::ConstantOp>(
-      loc, type,
-      builder.getIntegerAttr(type, getOperationIdentity(reductionOpName, loc)));
 }
 
 template <typename FloatOp, typename IntegerOp>
@@ -1150,6 +1166,76 @@ static Value getReductionOperation(fir::FirOpBuilder &builder, mlir::Type type,
   return builder.create<FloatOp>(loc, op1, op2);
 }

+/// Creates an `omp.reduction.declare` op named `reductionOpName` for values
+/// of type `type` at the beginning of the module body. Its initializer
+/// region yields the value computed by getReductionInitValue; its combiner
+/// region gets one block with two `type` arguments and no operations, which
+/// the caller must populate (terminated by an omp.yield).
+/// Note: this leaves `builder`'s insertion point inside the new declaration,
+/// so callers should save and restore it (e.g. with an InsertionGuard).
+static omp::ReductionDeclareOp
+createMinimalReductionDecl(fir::FirOpBuilder &builder,
+                           llvm::StringRef reductionOpName, mlir::Type type,
+                           mlir::Location loc) {
+  mlir::ModuleOp module = builder.getModule();
+  mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+  mlir::omp::ReductionDeclareOp decl =
+      modBuilder.create<omp::ReductionDeclareOp>(loc, reductionOpName, type);
+  builder.createBlock(&decl.getInitializerRegion(),
+                      decl.getInitializerRegion().end(), {type}, {loc});
+  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+  Value init = getReductionInitValue(loc, type, reductionOpName, builder);
+  builder.create<omp::YieldOp>(loc, init);
+
+  builder.createBlock(&decl.getReductionRegion(),
+                      decl.getReductionRegion().end(), {type, type},
+                      {loc, loc});
+
+  return decl;
+}
+
+/// Creates an OpenMP reduction declaration for the intrinsic procedure
+/// `procDesignator` (currently only `max`) on scalars of type `type`, or
+/// returns the declaration already named `reductionOpName` in the module.
+/// The initializer yields the identity of the reduction and the combiner
+/// applies the intrinsic to the two combiner arguments.
+/// TODO: Support the remaining reduction intrinsics (min, iand, ior, ieor)
+/// and add an atomic region.
+static omp::ReductionDeclareOp
+createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+                    const Fortran::parser::ProcedureDesignator &procDesignator,
+                    mlir::Type type, mlir::Location loc) {
+  OpBuilder::InsertionGuard guard(builder);
+  mlir::ModuleOp module = builder.getModule();
+
+  auto decl =
+      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+  if (decl)
+    return decl;
+
+  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
+  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+  // Designators other than `max` (including ones that are not plain names)
+  // go to TODO rather than leaving reductionOp null for the omp.yield below.
+  Value reductionOp;
+  const auto *name{
+      Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)};
+  if (name && name->source == "max") {
+    reductionOp =
+        getReductionOperation<mlir::arith::MaxFOp, mlir::arith::MaxSIOp>(
+            builder, type, loc, op1, op2);
+  } else {
+    TODO(loc, "Reduction of some intrinsic operators is not supported");
+  }
+
+  builder.create<omp::YieldOp>(loc, reductionOp);
+  return decl;
+}
+
 /// Creates an OpenMP reduction declaration and inserts it into the provided
 /// symbol table. The declaration has a constant initializer with the neutral
 /// value `initValue`, and the reduction combiner carried over from `reduce`.
@@ -1160,23 +1235,13 @@ static omp::ReductionDeclareOp createReductionDecl(
     mlir::Type type, mlir::Location loc) {
   OpBuilder::InsertionGuard guard(builder);
   mlir::ModuleOp module = builder.getModule();
-  mlir::OpBuilder modBuilder(module.getBodyRegion());
+
   auto decl =
       module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-  if (!decl)
-    decl =
-        modBuilder.create<omp::ReductionDeclareOp>(loc, reductionOpName, type);
-  else
+  if (decl)
     return decl;
-  builder.createBlock(&decl.getInitializerRegion(),
-                      decl.getInitializerRegion().end(), {type}, {loc});
-  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-  Value init = getReductionInitValue(loc, type, reductionOpName, builder);
-  builder.create<omp::YieldOp>(loc, init);
 
-  builder.createBlock(&decl.getReductionRegion(),
-                      decl.getReductionRegion().end(), {type, type},
-                      {loc, loc});
+  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
@@ -1284,6 +1349,13 @@ getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) {
   return mlir::omp::ScheduleModifier::none;
 }
 
+static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+  // Produces <name>_i_<bits> for integer/index types, else <name>_f_<bits>.
+  llvm::StringRef typeTag = ty.isIntOrIndex() ? "_i_" : "_f_";
+  return (llvm::Twine(name) + typeTag + llvm::Twine(ty.getIntOrFloatBitWidth()))
+      .str();
+}
+
 static std::string getReductionName(
     Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
     mlir::Type ty) {
@@ -1305,10 +1377,7 @@ static std::string getReductionName(
     break;
   }
 
-  return (llvm::Twine(reductionName) +
-          (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
-          llvm::Twine(ty.getIntOrFloatBitWidth()))
-      .str();
+  return getReductionName(reductionName, ty);
 }
 
 static void genOMP(Fortran::lower::AbstractConverter &converter,
@@ -1443,9 +1512,34 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
             }
           }
         }
-      } else {
-        TODO(currentLocation,
-             "Reduction of intrinsic procedures is not supported");
+      } else if (auto reductionIntrinsic =
+                     std::get_if<Fortran::parser::ProcedureDesignator>(
+                         &redOperator.u)) {
+        if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+                reductionIntrinsic)}) {
+          if (name->source != "max") {
+            TODO(currentLocation,
+                 "Reduction of intrinsic procedures is not supported");
+          }
+          for (const auto &ompObject : objectList.v) {
+            if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+                    ompObject)}) {
+              if (const auto *symbol{name->symbol}) {
+                mlir::Value symVal = converter.getSymbolAddress(*symbol);
+                mlir::Type redType =
+                    symVal.getType().cast<fir::ReferenceType>().getEleTy();
+                reductionVars.push_back(symVal);
+                assert(redType.isIntOrIndexOrFloat() &&
+                       "Unsupported reduction type");
+                decl = createReductionDecl(
+                    firOpBuilder, getReductionName("max", redType),
+                    *reductionIntrinsic, redType, currentLocation);
+                reductionDeclSymbols.push_back(SymbolRefAttr::get(
+                    firOpBuilder.getContext(), decl.getSymName()));
+              }
+            }
+          }
+        }
       }
     } else if (const auto &simdlenClause =
                    std::get_if<Fortran::parser::OmpClause::Simdlen>(
@@ -2104,6 +2198,30 @@ void Fortran::lower::genOpenMPDeclarativeConstruct(
       ompDeclConstruct.u);
 }

+/// `reductionOp` is expected to be the arith.select produced when `max` is
+/// lowered as compare -> select. Returns the op defining one of its operands
+/// whose first or second operand is `loadVal` (the loaded reduction
+/// variable); that op is expected to be an arith.cmpi or arith.cmpf.
+/// Returns nullptr if no operand of `reductionOp` is defined by such an op.
+static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp,
+                                                  mlir::Value loadVal) {
+  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
+    mlir::Operation *compareOp = reductionOperand.getDefiningOp();
+    // Skip block arguments and ops that cannot be a binary comparison, so
+    // getOperand(0/1) below is always in range.
+    if (!compareOp || compareOp->getNumOperands() < 2)
+      continue;
+    if (compareOp->getOperand(0) == loadVal ||
+        compareOp->getOperand(1) == loadVal) {
+      assert((mlir::isa<mlir::arith::CmpIOp>(compareOp) ||
+              mlir::isa<mlir::arith::CmpFOp>(compareOp)) &&
+             "Expected comparison not found in reduction intrinsic");
+      return compareOp;
+    }
+  }
+  return nullptr;
+}
+
 // Generate an OpenMP reduction operation.
 // TODO: Currently assumes it is either an integer addition/multiplication
 // reduction, or a logical and reduction. Generalize this for various reduction
@@ -2170,6 +2279,40 @@ void Fortran::lower::genOpenMPReduction(
             }
           }
         }
+      } else if (auto reductionIntrinsic =
+                     std::get_if<Fortran::parser::ProcedureDesignator>(
+                         &redOperator.u)) {
+        if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+                reductionIntrinsic)}) {
+          if (name->source != "max") {
+            continue;
+          }
+          for (const auto &ompObject : objectList.v) {
+            if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+                    ompObject)}) {
+              if (const auto *symbol{name->symbol}) {
+                mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+                for (mlir::OpOperand &reductionValUse :
+                     reductionVal.getUses()) {
+                  if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
+                          reductionValUse.getOwner())) {
+                    mlir::Value loadVal = loadOp.getRes();
+                    // Max is lowered as a compare -> select.
+                    // Match the pattern here.
+                    mlir::Operation *reductionOp =
+                        findReductionChain(loadVal, &reductionVal);
+                    assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+                           "Selection Op not found in reduction intrinsic");
+                    mlir::Operation *compareOp =
+                        getCompareFromReductionOp(reductionOp, loadVal);
+                    updateReduction(compareOp, firOpBuilder, loadVal,
+                                    reductionVal);
+                  }
+                }
+              }
+            }
+          }
+        }
       }
     }
   }

diff  --git a/flang/test/Lower/OpenMP/Todo/reduction-max.f90 b/flang/test/Lower/OpenMP/Todo/reduction-max.f90
deleted file mode 100644
index e965e6860712e..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-max.f90
+++ /dev/null
@@ -1,16 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction of intrinsic procedures is not supported
-subroutine reduction_max(y)
-  integer :: x, y(:)
-  x = 0
-  !$omp parallel
-  !$omp do reduction(max:x)
-  do i=1, 100
-    x = max(x, y(i))
-  end do
-  !$omp end do
-  !$omp end parallel
-  print *, x
-end subroutine

diff  --git a/flang/test/Lower/OpenMP/wsloop-reduction-max.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-max.f90
new file mode 100644
index 0000000000000..0db5b7f813d74
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-max.f90
@@ -0,0 +1,66 @@
+! RUN: bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK: omp.reduction.declare @[[MAX_DECLARE_F:.*]] : f32 init {
+!CHECK:   %[[MINIMUM_VAL_F:.*]] = arith.constant -3.40282347E+38 : f32
+!CHECK:   omp.yield(%[[MINIMUM_VAL_F]] : f32)
+!CHECK: combiner
+!CHECK: ^bb0(%[[ARG0_F:.*]]: f32, %[[ARG1_F:.*]]: f32):
+!CHECK:   %[[COMB_VAL_F:.*]] = arith.maxf %[[ARG0_F]], %[[ARG1_F]] {{.*}}: f32
+!CHECK:   omp.yield(%[[COMB_VAL_F]] : f32)
+
+!CHECK: omp.reduction.declare @[[MAX_DECLARE_I:.*]] : i32 init {
+!CHECK:   %[[MINIMUM_VAL_I:.*]] = arith.constant -2147483648 : i32
+!CHECK:   omp.yield(%[[MINIMUM_VAL_I]] : i32)
+!CHECK: combiner
+!CHECK: ^bb0(%[[ARG0_I:.*]]: i32, %[[ARG1_I:.*]]: i32):
+!CHECK:   %[[COMB_VAL_I:.*]] = arith.maxsi %[[ARG0_I]], %[[ARG1_I]] : i32
+!CHECK:   omp.yield(%[[COMB_VAL_I]] : i32)
+
+!CHECK-LABEL: @_QPreduction_max_int
+!CHECK-SAME: %[[Y_BOX:.*]]: !fir.box<!fir.array<?xi32>>
+!CHECK:   %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFreduction_max_intEx"}
+!CHECK:   omp.parallel
+!CHECK:     omp.wsloop reduction(@[[MAX_DECLARE_I]] -> %[[X_REF]] : !fir.ref<i32>) for
+!CHECK:       %[[Y_I_REF:.*]] = fir.coordinate_of %[[Y_BOX]]
+!CHECK:       %[[Y_I:.*]] = fir.load %[[Y_I_REF]] : !fir.ref<i32>
+!CHECK:       omp.reduction %[[Y_I]], %[[X_REF]] : i32, !fir.ref<i32>
+!CHECK:       omp.yield
+!CHECK:     omp.terminator
+
+!CHECK-LABEL: @_QPreduction_max_real
+!CHECK-SAME: %[[Y_BOX:.*]]: !fir.box<!fir.array<?xf32>>
+!CHECK:   %[[X_REF:.*]] = fir.alloca f32 {bindc_name = "x", uniq_name = "_QFreduction_max_realEx"}
+!CHECK:   omp.parallel
+!CHECK:     omp.wsloop reduction(@[[MAX_DECLARE_F]] -> %[[X_REF]] : !fir.ref<f32>) for
+!CHECK:       %[[Y_I_REF:.*]] = fir.coordinate_of %[[Y_BOX]]
+!CHECK:       %[[Y_I:.*]] = fir.load %[[Y_I_REF]] : !fir.ref<f32>
+!CHECK:       omp.reduction %[[Y_I]], %[[X_REF]] : f32, !fir.ref<f32>
+!CHECK:       omp.yield
+!CHECK:     omp.terminator
+! Integer case: the reduction variable x is the first argument of max.
+subroutine reduction_max_int(y)
+  integer :: x, y(:)
+  x = 0
+  !$omp parallel
+  !$omp do reduction(max:x)
+  do i=1, 100
+    x = max(x, y(i))
+  end do
+  !$omp end do
+  !$omp end parallel
+  print *, x
+end subroutine
+! Real case: like reduction_max_int, but x is the second argument of max.
+subroutine reduction_max_real(y)
+  real :: x, y(:)
+  x = 0.0
+  !$omp parallel
+  !$omp do reduction(max:x)
+  do i=1, 100
+    x = max(y(i), x)
+  end do
+  !$omp end do
+  !$omp end parallel
+  print *, x
+end subroutine


        


More information about the flang-commits mailing list