[flang-commits] [flang] 59ceb7d - [flang][openacc] Initial reduction clause lowering

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Thu Jun 1 06:15:36 PDT 2023


Author: Valentin Clement
Date: 2023-06-01T22:15:28+09:00
New Revision: 59ceb7dd9a02f9c6a1342b3c282d1ddfa028ce34

URL: https://github.com/llvm/llvm-project/commit/59ceb7dd9a02f9c6a1342b3c282d1ddfa028ce34
DIFF: https://github.com/llvm/llvm-project/commit/59ceb7dd9a02f9c6a1342b3c282d1ddfa028ce34.diff

LOG: [flang][openacc] Initial reduction clause lowering

Add initial support to lower the reduction clause to its representation in MLIR.

This patch adds support for the addition of integer and real scalar types. Other
operators and types will be added with follow-up patches.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D151564

Added: 
    flang/test/Lower/OpenACC/acc-reduction.f90

Modified: 
    flang/include/flang/Lower/OpenACC.h
    flang/lib/Lower/OpenACC.cpp
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index cd876ea2517aa..7546e842787f0 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -13,6 +13,8 @@
 #ifndef FORTRAN_LOWER_OPENACC_H
 #define FORTRAN_LOWER_OPENACC_H
 
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+
 namespace llvm {
 class StringRef;
 }
@@ -21,9 +23,6 @@ namespace mlir {
 class Location;
 class Type;
 class OpBuilder;
-namespace acc {
-class PrivateRecipeOp;
-}
 } // namespace mlir
 
 namespace Fortran {
@@ -57,6 +56,12 @@ mlir::acc::PrivateRecipeOp createOrGetPrivateRecipe(mlir::OpBuilder &,
                                                     llvm::StringRef,
                                                     mlir::Location, mlir::Type);
 
+/// Get an acc.reduction.recipe op for the given type or create it if it does
+/// not exist yet.
+mlir::acc::ReductionRecipeOp
+createOrGetReductionRecipe(mlir::OpBuilder &, llvm::StringRef, mlir::Location,
+                           mlir::Type, mlir::acc::ReductionOperator);
+
 } // namespace lower
 } // namespace Fortran
 

diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 306b799f36432..f33298763faa4 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -22,7 +22,6 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/tools.h"
-#include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "llvm/Frontend/OpenACC/ACC.h.inc"
 
 // Special value for * passed in device_type or gang clauses.
@@ -526,6 +525,132 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
   }
 }
 
+/// Return the corresponding enum value for the mlir::acc::ReductionOperator
+/// from the parser representation.
+static mlir::acc::ReductionOperator
+getReductionOperator(const Fortran::parser::AccReductionOperator &op) {
+  switch (op.v) {
+  case Fortran::parser::AccReductionOperator::Operator::Plus:
+    return mlir::acc::ReductionOperator::AccAdd;
+  case Fortran::parser::AccReductionOperator::Operator::Multiply:
+    return mlir::acc::ReductionOperator::AccMul;
+  case Fortran::parser::AccReductionOperator::Operator::Max:
+    return mlir::acc::ReductionOperator::AccMax;
+  case Fortran::parser::AccReductionOperator::Operator::Min:
+    return mlir::acc::ReductionOperator::AccMin;
+  case Fortran::parser::AccReductionOperator::Operator::Iand:
+    return mlir::acc::ReductionOperator::AccIand;
+  case Fortran::parser::AccReductionOperator::Operator::Ior:
+    return mlir::acc::ReductionOperator::AccIor;
+  case Fortran::parser::AccReductionOperator::Operator::Ieor:
+    return mlir::acc::ReductionOperator::AccXor;
+  case Fortran::parser::AccReductionOperator::Operator::And:
+    return mlir::acc::ReductionOperator::AccLand;
+  case Fortran::parser::AccReductionOperator::Operator::Or:
+    return mlir::acc::ReductionOperator::AccLor;
+  case Fortran::parser::AccReductionOperator::Operator::Eqv:
+    return mlir::acc::ReductionOperator::AccEqv;
+  case Fortran::parser::AccReductionOperator::Operator::Neqv:
+    return mlir::acc::ReductionOperator::AccNeqv;
+  }
+  llvm_unreachable("unexpected reduction operator");
+}
+
+static mlir::Value genReductionInitValue(mlir::OpBuilder &builder,
+                                         mlir::Location loc, mlir::Type ty,
+                                         mlir::acc::ReductionOperator op) {
+  if (op != mlir::acc::ReductionOperator::AccAdd)
+    TODO(loc, "reduction operator");
+
+  unsigned initValue = 0;
+
+  if (ty.isIntOrIndex())
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, ty, builder.getIntegerAttr(ty, initValue));
+  if (mlir::isa<mlir::FloatType>(ty))
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, ty, builder.getFloatAttr(ty, initValue));
+  TODO(loc, "reduction type");
+}
+
+static mlir::Value genCombiner(mlir::OpBuilder &builder, mlir::Location loc,
+                               mlir::acc::ReductionOperator op, mlir::Type ty,
+                               mlir::Value value1, mlir::Value value2) {
+  if (op == mlir::acc::ReductionOperator::AccAdd) {
+    if (ty.isIntOrIndex())
+      return builder.create<mlir::arith::AddIOp>(loc, value1, value2);
+    if (mlir::isa<mlir::FloatType>(ty))
+      return builder.create<mlir::arith::AddFOp>(loc, value1, value2);
+    TODO(loc, "reduction add type");
+  }
+  TODO(loc, "reduction operator");
+}
+
+mlir::acc::ReductionRecipeOp Fortran::lower::createOrGetReductionRecipe(
+    mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
+    mlir::Type ty, mlir::acc::ReductionOperator op) {
+  mlir::ModuleOp mod =
+      builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+  if (auto recipe = mod.lookupSymbol<mlir::acc::ReductionRecipeOp>(recipeName))
+    return recipe;
+
+  auto crtPos = builder.saveInsertionPoint();
+  mlir::OpBuilder modBuilder(mod.getBodyRegion());
+  auto recipe =
+      modBuilder.create<mlir::acc::ReductionRecipeOp>(loc, recipeName, ty, op);
+  builder.createBlock(&recipe.getInitRegion(), recipe.getInitRegion().end(),
+                      {ty}, {loc});
+  builder.setInsertionPointToEnd(&recipe.getInitRegion().back());
+  mlir::Value initValue = genReductionInitValue(builder, loc, ty, op);
+  builder.create<mlir::acc::YieldOp>(loc, initValue);
+
+  builder.createBlock(&recipe.getCombinerRegion(),
+                      recipe.getCombinerRegion().end(), {ty, ty}, {loc, loc});
+  builder.setInsertionPointToEnd(&recipe.getCombinerRegion().back());
+  mlir::Value v1 = recipe.getCombinerRegion().front().getArgument(0);
+  mlir::Value v2 = recipe.getCombinerRegion().front().getArgument(1);
+  mlir::Value combinedValue = genCombiner(builder, loc, op, ty, v1, v2);
+  builder.create<mlir::acc::YieldOp>(loc, combinedValue);
+  builder.restoreInsertionPoint(crtPos);
+  return recipe;
+}
+
+static void
+genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
+              Fortran::lower::AbstractConverter &converter,
+              Fortran::semantics::SemanticsContext &semanticsContext,
+              Fortran::lower::StatementContext &stmtCtx,
+              llvm::SmallVectorImpl<mlir::Value> &reductionOperands,
+              llvm::SmallVector<mlir::Attribute> &reductionRecipes) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  const auto &objects = std::get<Fortran::parser::AccObjectList>(objectList.t);
+  const auto &op =
+      std::get<Fortran::parser::AccReductionOperator>(objectList.t);
+  mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
+  for (const auto &accObject : objects.v) {
+    llvm::SmallVector<mlir::Value> bounds;
+    std::stringstream asFortran;
+    mlir::Location operandLocation = genOperandLocation(converter, accObject);
+    mlir::Value baseAddr = gatherDataOperandAddrAndBounds(
+        converter, builder, semanticsContext, stmtCtx, accObject,
+        operandLocation, asFortran, bounds);
+
+    if (!fir::isa_trivial(fir::unwrapRefType(baseAddr.getType())))
+      TODO(operandLocation, "reduction with unsupported type");
+
+    mlir::Type ty = fir::unwrapRefType(baseAddr.getType());
+    std::string recipeName = fir::getTypeAsString(
+        ty, converter.getKindMap(),
+        ("reduction_" + stringifyReductionOperator(mlirOp)).str());
+    mlir::acc::ReductionRecipeOp recipe =
+        Fortran::lower::createOrGetReductionRecipe(builder, recipeName,
+                                                   operandLocation, ty, mlirOp);
+    reductionRecipes.push_back(mlir::SymbolRefAttr::get(
+        builder.getContext(), recipe.getSymName().str()));
+    reductionOperands.push_back(baseAddr);
+  }
+}
+
 static void
 addOperands(llvm::SmallVectorImpl<mlir::Value> &operands,
             llvm::SmallVectorImpl<int32_t> &operandSegments,
@@ -666,7 +791,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value gangStatic;
   llvm::SmallVector<mlir::Value, 2> tileOperands, privateOperands,
       reductionOperands;
-  llvm::SmallVector<mlir::Attribute> privatizations;
+  llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
   bool hasGang = false, hasVector = false, hasWorker = false;
 
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -735,10 +860,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
                        &clause.u)) {
       genPrivatizations(privateClause->v, converter, semanticsContext, stmtCtx,
                         privateOperands, privatizations);
-    } else if (std::get_if<Fortran::parser::AccClause::Reduction>(&clause.u)) {
-      // Reduction clause is left out for the moment as the clause will probably
-      // end up having its own operation.
-      TODO(clauseLocation, "OpenACC compute construct reduction lowering");
+    } else if (const auto *reductionClause =
+                   std::get_if<Fortran::parser::AccClause::Reduction>(
+                       &clause.u)) {
+      genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
+                    reductionOperands, reductionRecipes);
     }
   }
 
@@ -767,6 +893,10 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
     loopOp.setPrivatizationsAttr(
         mlir::ArrayAttr::get(builder.getContext(), privatizations));
 
+  if (!reductionRecipes.empty())
+    loopOp.setReductionRecipesAttr(
+        mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
+
   // Lower clauses mapped to attributes
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     if (const auto *collapseClause =

diff  --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90
new file mode 100644
index 0000000000000..4c95b40b8c9ea
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-reduction.f90
@@ -0,0 +1,51 @@
+! This test checks lowering of the OpenACC reduction clause.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+
+! CHECK-LABEL: acc.reduction.recipe @reduction_add_f32 : f32 reduction_operator <add> init {
+! CHECK: ^bb0(%{{.*}}: f32):
+! CHECK:   %[[INIT:.*]] = arith.constant 0.000000e+00 : f32
+! CHECK:   acc.yield %[[INIT]] : f32
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32):
+! CHECK:   %[[COMBINED:.*]] = arith.addf %[[ARG0]], %[[ARG1]] {{.*}} : f32
+! CHECK:   acc.yield %[[COMBINED]] : f32
+! CHECK: }
+
+! CHECK-LABEL: acc.reduction.recipe @reduction_add_i32 : i32 reduction_operator <add> init {
+! CHECK: ^bb0(%{{.*}}: i32):
+! CHECK:   %[[INIT:.*]] = arith.constant 0 : i32
+! CHECK:   acc.yield %[[INIT]] : i32
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+! CHECK:   %[[COMBINED:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+! CHECK:   acc.yield %[[COMBINED]] : i32
+! CHECK: }
+
+subroutine acc_reduction_add_int(a, b)
+  integer :: a(100)
+  integer :: i, b
+
+  !$acc loop reduction(+:b)
+  do i = 1, 100
+    b = b + a(i)
+  end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPacc_reduction_add_int(
+! CHECK-SAME:  %{{.*}}: !fir.ref<!fir.array<100xi32>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<i32> {fir.bindc_name = "b"})
+! CHECK:       acc.loop reduction(@reduction_add_i32 -> %[[B]] : !fir.ref<i32>) {
+
+subroutine acc_reduction_add_float(a, b)
+  real :: a(100), b
+  integer :: i
+
+  !$acc loop reduction(+:b)
+  do i = 1, 100
+    b = b + a(i)
+  end do
+end subroutine
+
+! CHECK-LABEL: func.func @_QPacc_reduction_add_float(
+! CHECK-SAME:  %{{.*}}: !fir.ref<!fir.array<100xf32>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<f32> {fir.bindc_name = "b"})
+! CHECK:       acc.loop reduction(@reduction_add_f32 -> %[[B]] : !fir.ref<f32>)

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index fab270a6f1730..da5a2856aec21 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -498,7 +498,7 @@ template <typename Op>
 static LogicalResult
 checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
                     mlir::OperandRange operands, llvm::StringRef operandName,
-                    llvm::StringRef symbolName) {
+                    llvm::StringRef symbolName, bool checkOperandType = true) {
   if (!operands.empty()) {
     if (!attributes || attributes->size() != operands.size())
       return op->emitOpError()
@@ -527,7 +527,7 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
              << "expected symbol reference " << symbolRef << " to point to a "
              << operandName << " declaration";
 
-    if (decl.getType() && decl.getType() != varType)
+    if (checkOperandType && decl.getType() && decl.getType() != varType)
       return op->emitOpError() << "expected " << operandName << " (" << varType
                                << ") to be the same type as " << operandName
                                << " declaration (" << decl.getType() << ")";
@@ -751,7 +751,7 @@ LogicalResult acc::LoopOp::verify() {
 
   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
-          "reductions")))
+          "reductions", false)))
     return failure();
 
   // Check non-empty body().


        


More information about the flang-commits mailing list