[flang-commits] [flang] f0e028f - [flang][openacc] Lower clauses on loop construct to OpenACC dialect

via flang-commits flang-commits at lists.llvm.org
Thu Sep 17 08:34:49 PDT 2020


Author: Valentin Clement
Date: 2020-09-17T11:34:43-04:00
New Revision: f0e028f4b32393676b5d3eb36d6598ec5a390180

URL: https://github.com/llvm/llvm-project/commit/f0e028f4b32393676b5d3eb36d6598ec5a390180
DIFF: https://github.com/llvm/llvm-project/commit/f0e028f4b32393676b5d3eb36d6598ec5a390180.diff

LOG: [flang][openacc] Lower clauses on loop construct to OpenACC dialect

Lower OpenACCLoopConstruct and most of the clauses to the OpenACC acc.loop operation in MLIR.
This patch reflects what can be upstreamed from PR flang-compiler/f18-llvm-project#419

Reviewed By: SouraVX

Differential Revision: https://reviews.llvm.org/D87389

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIRDialect.h
    flang/lib/Lower/OpenACC.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIRDialect.h b/flang/include/flang/Optimizer/Dialect/FIRDialect.h
index 9702c54367b8..a4b0e3f9aa7f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRDialect.h
+++ b/flang/include/flang/Optimizer/Dialect/FIRDialect.h
@@ -37,6 +37,7 @@ inline void registerFIRDialects(mlir::DialectRegistry &registry) {
   // clang-format off
   registry.insert<mlir::AffineDialect,
                   mlir::LLVM::LLVMDialect,
+                  mlir::acc::OpenACCDialect,
                   mlir::omp::OpenMPDialect,
                   mlir::scf::SCFDialect,
                   mlir::StandardOpsDialect,

diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 5c8c29e491d6..f91aff792cbd 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,16 +11,202 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/OpenACC.h"
+#include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/FIRBuilder.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/tools.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "llvm/Frontend/OpenACC/ACC.h.inc"
 
 #define TODO() llvm_unreachable("not yet implemented")
 
+static const Fortran::parser::Name *
+getDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
+  const auto *dataRef{std::get_if<Fortran::parser::DataRef>(&designator.u)};
+  return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
+}
+
+static void genObjectList(const Fortran::parser::AccObjectList &objectList,
+                          Fortran::lower::AbstractConverter &converter,
+                          std::int32_t &objectsCount,
+                          SmallVector<Value, 8> &operands) {
+  for (const auto &accObject : objectList.v) {
+    std::visit(
+        Fortran::common::visitors{
+            [&](const Fortran::parser::Designator &designator) {
+              if (const auto *name = getDesignatorNameIfDataRef(designator)) {
+                ++objectsCount;
+                const auto variable = converter.getSymbolAddress(*name->symbol);
+                operands.push_back(variable);
+              }
+            },
+            [&](const Fortran::parser::Name &name) {
+              ++objectsCount;
+              const auto variable = converter.getSymbolAddress(*name.symbol);
+              operands.push_back(variable);
+            }},
+        accObject.u);
+  }
+}
+
+static void genACC(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
+
+  const auto &beginLoopDirective =
+      std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
+  const auto &loopDirective =
+      std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
+
+  if (loopDirective.v == llvm::acc::ACCD_loop) {
+    auto &firOpBuilder = converter.getFirOpBuilder();
+    auto currentLocation = converter.getCurrentLocation();
+    llvm::ArrayRef<mlir::Type> argTy;
+
+    // Add attribute extracted from clauses.
+    const auto &accClauseList =
+        std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
+
+    mlir::Value workerNum;
+    mlir::Value vectorLength;
+    mlir::Value gangNum;
+    mlir::Value gangStatic;
+    std::int32_t tileOperands = 0;
+    std::int32_t privateOperands = 0;
+    std::int32_t reductionOperands = 0;
+    std::int64_t executionMapping = mlir::acc::OpenACCExecMapping::NONE;
+    SmallVector<Value, 8> operands;
+
+    // Lower clauses values mapped to operands.
+    for (const auto &clause : accClauseList.v) {
+      if (const auto *gangClause =
+              std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
+        if (gangClause->v) {
+          const Fortran::parser::AccGangArgument &x = *gangClause->v;
+          if (const auto &gangNumValue =
+                  std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
+                      x.t)) {
+            gangNum = converter.genExprValue(
+                *Fortran::semantics::GetExpr(gangNumValue.value()));
+            operands.push_back(gangNum);
+          }
+          if (const auto &gangStaticValue =
+                  std::get<std::optional<Fortran::parser::AccSizeExpr>>(x.t)) {
+            const auto &expr =
+                std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
+                    gangStaticValue.value().t);
+            if (expr) {
+              gangStatic =
+                  converter.genExprValue(*Fortran::semantics::GetExpr(*expr));
+            } else {
+              // * was passed as value and will be represented as a -1 constant
+              // integer.
+              gangStatic = firOpBuilder.createIntegerConstant(
+                  currentLocation, firOpBuilder.getIntegerType(32),
+                  /* STAR */ -1);
+            }
+            operands.push_back(gangStatic);
+          }
+        }
+        executionMapping |= mlir::acc::OpenACCExecMapping::GANG;
+      } else if (const auto *workerClause =
+                     std::get_if<Fortran::parser::AccClause::Worker>(
+                         &clause.u)) {
+        if (workerClause->v) {
+          workerNum = converter.genExprValue(
+              *Fortran::semantics::GetExpr(*workerClause->v));
+          operands.push_back(workerNum);
+        }
+        executionMapping |= mlir::acc::OpenACCExecMapping::WORKER;
+      } else if (const auto *vectorClause =
+                     std::get_if<Fortran::parser::AccClause::Vector>(
+                         &clause.u)) {
+        if (vectorClause->v) {
+          vectorLength = converter.genExprValue(
+              *Fortran::semantics::GetExpr(*vectorClause->v));
+          operands.push_back(vectorLength);
+        }
+        executionMapping |= mlir::acc::OpenACCExecMapping::VECTOR;
+      } else if (const auto *tileClause =
+                     std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
+        const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
+        for (const auto &accTileExpr : accTileExprList.v) {
+          const auto &expr =
+              std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
+                  accTileExpr.t);
+          ++tileOperands;
+          if (expr) {
+            operands.push_back(
+                converter.genExprValue(*Fortran::semantics::GetExpr(*expr)));
+          } else {
+            // * was passed as value and will be represented as a -1 constant
+            // integer.
+            mlir::Value tileStar = firOpBuilder.createIntegerConstant(
+                currentLocation, firOpBuilder.getIntegerType(32),
+                /* STAR */ -1);
+            operands.push_back(tileStar);
+          }
+        }
+      } else if (const auto *privateClause =
+                     std::get_if<Fortran::parser::AccClause::Private>(
+                         &clause.u)) {
+        const Fortran::parser::AccObjectList &accObjectList = privateClause->v;
+        genObjectList(accObjectList, converter, privateOperands, operands);
+      }
+      // Reduction clause is left out for the moment as the clause will probably
+      // end up having its own operation.
+    }
+
+    auto loopOp = firOpBuilder.create<mlir::acc::LoopOp>(currentLocation, argTy,
+                                                         operands);
+
+    firOpBuilder.createBlock(&loopOp.getRegion());
+    auto &block = loopOp.getRegion().back();
+    firOpBuilder.setInsertionPointToStart(&block);
+    // ensure the block is well-formed.
+    firOpBuilder.create<mlir::acc::YieldOp>(currentLocation);
+
+    loopOp.setAttr(mlir::acc::LoopOp::getOperandSegmentSizeAttr(),
+                   firOpBuilder.getI32VectorAttr(
+                       {gangNum ? 1 : 0, gangStatic ? 1 : 0, workerNum ? 1 : 0,
+                        vectorLength ? 1 : 0, tileOperands, privateOperands,
+                        reductionOperands}));
+
+    loopOp.setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(),
+                   firOpBuilder.getI64IntegerAttr(executionMapping));
+
+    // Lower clauses mapped to attributes
+    for (const auto &clause : accClauseList.v) {
+      if (const auto *collapseClause =
+              std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+        const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
+        const auto collapseValue = Fortran::evaluate::ToInt64(*expr);
+        if (collapseValue) {
+          loopOp.setAttr(mlir::acc::LoopOp::getCollapseAttrName(),
+                         firOpBuilder.getI64IntegerAttr(*collapseValue));
+        }
+      } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
+        loopOp.setAttr(mlir::acc::LoopOp::getSeqAttrName(),
+                       firOpBuilder.getUnitAttr());
+      } else if (std::get_if<Fortran::parser::AccClause::Independent>(
+                     &clause.u)) {
+        loopOp.setAttr(mlir::acc::LoopOp::getIndependentAttrName(),
+                       firOpBuilder.getUnitAttr());
+      } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
+        loopOp.setAttr(mlir::acc::LoopOp::getAutoAttrName(),
+                       firOpBuilder.getUnitAttr());
+      }
+    }
+
+    // Place the insertion point to the start of the first block.
+    firOpBuilder.setInsertionPointToStart(&block);
+  }
+}
+
 void Fortran::lower::genOpenACCConstruct(
-    Fortran::lower::AbstractConverter &absConv,
+    Fortran::lower::AbstractConverter &converter,
     Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OpenACCConstruct &accConstruct) {
 
@@ -32,7 +218,7 @@ void Fortran::lower::genOpenACCConstruct(
           [&](const Fortran::parser::OpenACCCombinedConstruct
                   &combinedConstruct) { TODO(); },
           [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
-            TODO();
+            genACC(converter, eval, loopConstruct);
           },
           [&](const Fortran::parser::OpenACCStandaloneConstruct
                   &standaloneConstruct) { TODO(); },


        


More information about the flang-commits mailing list