[flang-commits] [mlir] [flang] [flang][openacc] Support multiple device_type when lowering (PR #78634)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 18 13:58:19 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

The routine, data, parallel, serial, kernels, and loop constructs all support the device_type clause, which takes a list of device types. Previously, the lowering code assumed that this list contained a single item. This PR updates the lowering to handle any number of device types.

---

Patch is 30.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78634.diff


5 Files Affected:

- (modified) flang/lib/Lower/OpenACC.cpp (+157-105) 
- (modified) flang/test/Lower/OpenACC/acc-device-type.f90 (+4) 
- (modified) flang/test/Lower/OpenACC/acc-loop.f90 (+6) 
- (modified) flang/test/Lower/OpenACC/acc-routine.f90 (+5) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+3-2) 


``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index fd89d27db74dc05..682ca06cabd6f6b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1470,15 +1470,19 @@ genAsyncClause(Fortran::lower::AbstractConverter &converter,
                llvm::SmallVector<mlir::Value> &async,
                llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
                llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
-               mlir::acc::DeviceTypeAttr deviceTypeAttr,
+               llvm::SmallVector<mlir::Attribute> &deviceTypeAttrs,
                Fortran::lower::StatementContext &stmtCtx) {
   const auto &asyncClauseValue = asyncClause->v;
   if (asyncClauseValue) { // async has a value.
-    async.push_back(fir::getBase(converter.genExprValue(
-        *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
-    asyncDeviceTypes.push_back(deviceTypeAttr);
+    mlir::Value asyncValue = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
+    for (auto deviceTypeAttr : deviceTypeAttrs) {
+      async.push_back(asyncValue);
+      asyncDeviceTypes.push_back(deviceTypeAttr);
+    }
   } else {
-    asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
+    for (auto deviceTypeAttr : deviceTypeAttrs)
+      asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
   }
 }
 
@@ -1504,10 +1508,9 @@ getDeviceType(Fortran::common::OpenACCDeviceType device) {
 }
 
 static void gatherDeviceTypeAttrs(
-    fir::FirOpBuilder &builder, mlir::Location clauseLocation,
+    fir::FirOpBuilder &builder,
     const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
-    llvm::SmallVector<mlir::Attribute> &deviceTypes,
-    Fortran::lower::StatementContext &stmtCtx) {
+    llvm::SmallVector<mlir::Attribute> &deviceTypes) {
   const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
       deviceTypeClause->v;
   for (const auto &deviceTypeExpr : deviceTypeExprList.v)
@@ -1560,20 +1563,25 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
               llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
               llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
               llvm::SmallVector<int32_t> &waitOperandsSegments,
-              mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
+              mlir::Value &waitDevnum,
+              llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
               Fortran::lower::StatementContext &stmtCtx) {
   const auto &waitClauseValue = waitClause->v;
   if (waitClauseValue) { // wait has a value.
     const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
     const auto &waitList =
         std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-    auto crtWaitOperands = waitOperands.size();
+    llvm::SmallVector<mlir::Value> waitValues;
     for (const Fortran::parser::ScalarIntExpr &value : waitList) {
-      waitOperands.push_back(fir::getBase(converter.genExprValue(
+      waitValues.push_back(fir::getBase(converter.genExprValue(
           *Fortran::semantics::GetExpr(value), stmtCtx)));
     }
-    waitOperandsDeviceTypes.push_back(deviceTypeAttr);
-    waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+    for (auto deviceTypeAttr : deviceTypeAttrs) {
+      for (auto value : waitValues)
+        waitOperands.push_back(value);
+      waitOperandsDeviceTypes.push_back(deviceTypeAttr);
+      waitOperandsSegments.push_back(waitValues.size());
+    }
 
     // TODO: move to device_type model.
     const auto &waitDevnumValue =
@@ -1582,7 +1590,8 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
       waitDevnum = fir::getBase(converter.genExprValue(
           *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
   } else {
-    waitOnlyDeviceTypes.push_back(deviceTypeAttr);
+    for (auto deviceTypeAttr : deviceTypeAttrs)
+      waitOnlyDeviceTypes.push_back(deviceTypeAttr);
   }
 }
 
@@ -1610,91 +1619,112 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
 
   // device_type attribute is set to `none` until a device_type clause is
   // encountered.
-  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-      builder.getContext(), mlir::acc::DeviceType::None);
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None));
 
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *gangClause =
             std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
       if (gangClause->v) {
-        auto crtGangOperands = gangOperands.size();
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
+        mlir::SmallVector<mlir::Value> gangValues;
+        mlir::SmallVector<mlir::Attribute> gangArgs;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
           if (const auto *num =
                   std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
-            gangOperands.push_back(fir::getBase(converter.genExprValue(
+            gangValues.push_back(fir::getBase(converter.genExprValue(
                 *Fortran::semantics::GetExpr(num->v), stmtCtx)));
-            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                 builder.getContext(), mlir::acc::GangArgType::Num));
           } else if (const auto *staticArg =
                          std::get_if<Fortran::parser::AccGangArg::Static>(
                              &gangArg.u)) {
             const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
             if (sizeExpr.v) {
-              gangOperands.push_back(fir::getBase(converter.genExprValue(
+              gangValues.push_back(fir::getBase(converter.genExprValue(
                   *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
             } else {
               // * was passed as value and will be represented as a special
               // constant.
-              gangOperands.push_back(builder.createIntegerConstant(
+              gangValues.push_back(builder.createIntegerConstant(
                   clauseLocation, builder.getIndexType(), starCst));
             }
-            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                 builder.getContext(), mlir::acc::GangArgType::Static));
           } else if (const auto *dim =
                          std::get_if<Fortran::parser::AccGangArg::Dim>(
                              &gangArg.u)) {
-            gangOperands.push_back(fir::getBase(converter.genExprValue(
+            gangValues.push_back(fir::getBase(converter.genExprValue(
                 *Fortran::semantics::GetExpr(dim->v), stmtCtx)));
-            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                 builder.getContext(), mlir::acc::GangArgType::Dim));
           }
         }
-        gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands);
-        gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+          for (const auto &pair : llvm::zip(gangValues, gangArgs)) {
+            gangOperands.push_back(std::get<0>(pair));
+            gangArgTypes.push_back(std::get<1>(pair));
+          }
+          gangOperandsSegments.push_back(gangValues.size());
+          gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        }
       } else {
-        gangDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes)
+          gangDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (const auto *workerClause =
                    std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
       if (workerClause->v) {
-        workerNumOperands.push_back(fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)));
-        workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        mlir::Value workerNumValue = fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
+        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+          workerNumOperands.push_back(workerNumValue);
+          workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        }
       } else {
-        workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes)
+          workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (const auto *vectorClause =
                    std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
       if (vectorClause->v) {
-        vectorOperands.push_back(fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)));
-        vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        mlir::Value vectorValue = fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
+        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+          vectorOperands.push_back(vectorValue);
+          vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        }
       } else {
-        vectorDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes)
+          vectorDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (const auto *tileClause =
                    std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
       const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
-      auto crtTileOperands = tileOperands.size();
+      llvm::SmallVector<mlir::Value> tileValues;
       for (const auto &accTileExpr : accTileExprList.v) {
         const auto &expr =
             std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
                 accTileExpr.t);
         if (expr) {
-          tileOperands.push_back(fir::getBase(converter.genExprValue(
+          tileValues.push_back(fir::getBase(converter.genExprValue(
               *Fortran::semantics::GetExpr(*expr), stmtCtx)));
         } else {
           // * was passed as value and will be represented as a special
           // constant.
           mlir::Value tileStar = builder.createIntegerConstant(
               clauseLocation, builder.getIntegerType(32), starCst);
-          tileOperands.push_back(tileStar);
+          tileValues.push_back(tileStar);
         }
       }
-      tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
-      tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        for (auto value : tileValues)
+          tileOperands.push_back(value);
+        tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        tileOperandsSegments.push_back(tileValues.size());
+      }
     } else if (const auto *privateClause =
                    std::get_if<Fortran::parser::AccClause::Private>(
                        &clause.u)) {
@@ -1707,21 +1737,20 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
       genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
                     reductionOperands, reductionRecipes);
     } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      seqDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        seqDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Independent>(
                    &clause.u)) {
-      independentDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        independentDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
-      autoDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        autoDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
-          deviceTypeClause->v;
-      assert(deviceTypeExprList.v.size() == 1 &&
-             "expect only one device_type expr");
-      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     } else if (const auto *collapseClause =
                    std::get_if<Fortran::parser::AccClause::Collapse>(
                        &clause.u)) {
@@ -1729,14 +1758,18 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
       const auto &force = std::get<bool>(arg.t);
       if (force)
         TODO(clauseLocation, "OpenACC collapse force modifier");
+
       const auto &intExpr =
           std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
       const auto *expr = Fortran::semantics::GetExpr(intExpr);
       const std::optional<int64_t> collapseValue =
           Fortran::evaluate::ToInt64(*expr);
       assert(collapseValue && "expect integer value for the collapse clause");
-      collapseValues.push_back(*collapseValue);
-      collapseDeviceTypes.push_back(crtDeviceTypeAttr);
+
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        collapseValues.push_back(*collapseValue);
+        collapseDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     }
   }
 
@@ -1923,45 +1956,56 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
   // device_type attribute is set to `none` until a device_type clause is
   // encountered.
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
   auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
       builder.getContext(), mlir::acc::DeviceType::None);
+  crtDeviceTypes.push_back(crtDeviceTypeAttr);
 
-  // Lower clauses values mapped to operands.
-  // Keep track of each group of operands separatly as clauses can appear
+  // Lower clauses values mapped to operands and array attributes.
+  // Keep track of each group of operands separately as clauses can appear
   // more than once.
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *asyncClause =
             std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
       genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
-                     asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
+                     asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
       genWaitClause(converter, waitClause, waitOperands,
                     waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
-                    stmtCtx);
+                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
-      auto crtNumGangs = numGangs.size();
+      llvm::SmallVector<mlir::Value> numGangValues;
       for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
-        numGangs.push_back(fir::getBase(converter.genExprValue(
+        numGangValues.push_back(fir::getBase(converter.genExprValue(
             *Fortran::semantics::GetExpr(expr), stmtCtx)));
-      numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
-      numGangsSegments.push_back(numGangs.size() - crtNumGangs);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        for (auto value : numGangValues)
+          numGangs.push_back(value);
+        numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+        numGangsSegments.push_back(numGangValues.size());
+      }
     } else if (const auto *numWorkersClause =
                    std::get_if<Fortran::parser::AccClause::NumWorkers>(
                        &clause.u)) {
-      numWorkers.push_back(fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
-      numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+      mlir::Value numWorkerValue = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        numWorkers.push_back(numWorkerValue);
+        numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *vectorLengthClause =
                    std::get_if<Fortran::parser::AccClause::VectorLength>(
                        &clause.u)) {
-      vectorLength.push_back(fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
-      vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+      mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        vectorLength.push_back(vectorLengthValue);
+        vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *ifClause =
                    std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2115,12 +2159,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
-          deviceTypeClause->v;
-      assert(deviceTypeExprList.v.size() == 1 &&
-             "expect only one device_type expr");
-      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     }
   }
 
@@ -2239,10 +2279,11 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
 
   // device_type attribute is set to `none` until a device_type clause is
   // encountered.
-  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-      builder.getContext(), mlir::acc::DeviceType::None);
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None));
 
-  // Lower clauses values mapped to operands.
+  // Lower clauses values mapped to operands and array attributes.
   // Keep track of each group of operands separately as clauses can appear
   // more than once.
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2323,19 +2364,23 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
       genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
-                     asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
+                     asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
       genWaitClause(converter, waitClause, waitOperands,
                     waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
-                    stmtCtx);
+                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
     } else if(const auto *defaultClause = 
                   std::get_if<...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/78634


More information about the flang-commits mailing list