[flang-commits] [flang] b8967e0 - [flang][openacc] Support multiple device_type when lowering (#78634)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 18 21:20:34 PST 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-01-18T21:20:28-08:00
New Revision: b8967e003e202cba1b77412478a1990c9dcccdca

URL: https://github.com/llvm/llvm-project/commit/b8967e003e202cba1b77412478a1990c9dcccdca
DIFF: https://github.com/llvm/llvm-project/commit/b8967e003e202cba1b77412478a1990c9dcccdca.diff

LOG: [flang][openacc] Support multiple device_type when lowering (#78634)

The routine, data, parallel, serial, kernels, and loop constructs all support
the device_type clause. This clause takes a list of device types.
Previously, the lowering code assumed that the list contained a single item.
This PR updates the lowering to handle any number of device types.

Added: 
    

Modified: 
    flang/lib/Lower/OpenACC.cpp
    flang/test/Lower/OpenACC/acc-device-type.f90
    flang/test/Lower/OpenACC/acc-loop.f90
    flang/test/Lower/OpenACC/acc-routine.f90
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index fd89d27db74dc05..682ca06cabd6f6b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1470,15 +1470,19 @@ genAsyncClause(Fortran::lower::AbstractConverter &converter,
                llvm::SmallVector<mlir::Value> &async,
                llvm::SmallVector<mlir::Attribute> &asyncDeviceTypes,
                llvm::SmallVector<mlir::Attribute> &asyncOnlyDeviceTypes,
-               mlir::acc::DeviceTypeAttr deviceTypeAttr,
+               llvm::SmallVector<mlir::Attribute> &deviceTypeAttrs,
                Fortran::lower::StatementContext &stmtCtx) {
   const auto &asyncClauseValue = asyncClause->v;
   if (asyncClauseValue) { // async has a value.
-    async.push_back(fir::getBase(converter.genExprValue(
-        *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx)));
-    asyncDeviceTypes.push_back(deviceTypeAttr);
+    mlir::Value asyncValue = fir::getBase(converter.genExprValue(
+        *Fortran::semantics::GetExpr(*asyncClauseValue), stmtCtx));
+    for (auto deviceTypeAttr : deviceTypeAttrs) {
+      async.push_back(asyncValue);
+      asyncDeviceTypes.push_back(deviceTypeAttr);
+    }
   } else {
-    asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
+    for (auto deviceTypeAttr : deviceTypeAttrs)
+      asyncOnlyDeviceTypes.push_back(deviceTypeAttr);
   }
 }
 
@@ -1504,10 +1508,9 @@ getDeviceType(Fortran::common::OpenACCDeviceType device) {
 }
 
 static void gatherDeviceTypeAttrs(
-    fir::FirOpBuilder &builder, mlir::Location clauseLocation,
+    fir::FirOpBuilder &builder,
     const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
-    llvm::SmallVector<mlir::Attribute> &deviceTypes,
-    Fortran::lower::StatementContext &stmtCtx) {
+    llvm::SmallVector<mlir::Attribute> &deviceTypes) {
   const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
       deviceTypeClause->v;
   for (const auto &deviceTypeExpr : deviceTypeExprList.v)
@@ -1560,20 +1563,25 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
               llvm::SmallVector<mlir::Attribute> &waitOperandsDeviceTypes,
               llvm::SmallVector<mlir::Attribute> &waitOnlyDeviceTypes,
               llvm::SmallVector<int32_t> &waitOperandsSegments,
-              mlir::Value &waitDevnum, mlir::acc::DeviceTypeAttr deviceTypeAttr,
+              mlir::Value &waitDevnum,
+              llvm::SmallVector<mlir::Attribute> deviceTypeAttrs,
               Fortran::lower::StatementContext &stmtCtx) {
   const auto &waitClauseValue = waitClause->v;
   if (waitClauseValue) { // wait has a value.
     const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
     const auto &waitList =
         std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
-    auto crtWaitOperands = waitOperands.size();
+    llvm::SmallVector<mlir::Value> waitValues;
     for (const Fortran::parser::ScalarIntExpr &value : waitList) {
-      waitOperands.push_back(fir::getBase(converter.genExprValue(
+      waitValues.push_back(fir::getBase(converter.genExprValue(
           *Fortran::semantics::GetExpr(value), stmtCtx)));
     }
-    waitOperandsDeviceTypes.push_back(deviceTypeAttr);
-    waitOperandsSegments.push_back(waitOperands.size() - crtWaitOperands);
+    for (auto deviceTypeAttr : deviceTypeAttrs) {
+      for (auto value : waitValues)
+        waitOperands.push_back(value);
+      waitOperandsDeviceTypes.push_back(deviceTypeAttr);
+      waitOperandsSegments.push_back(waitValues.size());
+    }
 
     // TODO: move to device_type model.
     const auto &waitDevnumValue =
@@ -1582,7 +1590,8 @@ genWaitClause(Fortran::lower::AbstractConverter &converter,
       waitDevnum = fir::getBase(converter.genExprValue(
           *Fortran::semantics::GetExpr(*waitDevnumValue), stmtCtx));
   } else {
-    waitOnlyDeviceTypes.push_back(deviceTypeAttr);
+    for (auto deviceTypeAttr : deviceTypeAttrs)
+      waitOnlyDeviceTypes.push_back(deviceTypeAttr);
   }
 }
 
@@ -1610,91 +1619,112 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
 
   // device_type attribute is set to `none` until a device_type clause is
   // encountered.
-  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-      builder.getContext(), mlir::acc::DeviceType::None);
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None));
 
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *gangClause =
             std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
       if (gangClause->v) {
-        auto crtGangOperands = gangOperands.size();
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
+        mlir::SmallVector<mlir::Value> gangValues;
+        mlir::SmallVector<mlir::Attribute> gangArgs;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
           if (const auto *num =
                   std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
-            gangOperands.push_back(fir::getBase(converter.genExprValue(
+            gangValues.push_back(fir::getBase(converter.genExprValue(
                 *Fortran::semantics::GetExpr(num->v), stmtCtx)));
-            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                 builder.getContext(), mlir::acc::GangArgType::Num));
           } else if (const auto *staticArg =
                          std::get_if<Fortran::parser::AccGangArg::Static>(
                              &gangArg.u)) {
             const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
             if (sizeExpr.v) {
-              gangOperands.push_back(fir::getBase(converter.genExprValue(
+              gangValues.push_back(fir::getBase(converter.genExprValue(
                   *Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
             } else {
               // * was passed as value and will be represented as a special
               // constant.
-              gangOperands.push_back(builder.createIntegerConstant(
+              gangValues.push_back(builder.createIntegerConstant(
                   clauseLocation, builder.getIndexType(), starCst));
             }
-            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                 builder.getContext(), mlir::acc::GangArgType::Static));
           } else if (const auto *dim =
                          std::get_if<Fortran::parser::AccGangArg::Dim>(
                              &gangArg.u)) {
-            gangOperands.push_back(fir::getBase(converter.genExprValue(
+            gangValues.push_back(fir::getBase(converter.genExprValue(
                 *Fortran::semantics::GetExpr(dim->v), stmtCtx)));
-            gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
+            gangArgs.push_back(mlir::acc::GangArgTypeAttr::get(
                 builder.getContext(), mlir::acc::GangArgType::Dim));
           }
         }
-        gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands);
-        gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+          for (const auto &pair : llvm::zip(gangValues, gangArgs)) {
+            gangOperands.push_back(std::get<0>(pair));
+            gangArgTypes.push_back(std::get<1>(pair));
+          }
+          gangOperandsSegments.push_back(gangValues.size());
+          gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        }
       } else {
-        gangDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes)
+          gangDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (const auto *workerClause =
                    std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
       if (workerClause->v) {
-        workerNumOperands.push_back(fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)));
-        workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        mlir::Value workerNumValue = fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
+        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+          workerNumOperands.push_back(workerNumValue);
+          workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        }
       } else {
-        workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes)
+          workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (const auto *vectorClause =
                    std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
       if (vectorClause->v) {
-        vectorOperands.push_back(fir::getBase(converter.genExprValue(
-            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)));
-        vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        mlir::Value vectorValue = fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
+        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+          vectorOperands.push_back(vectorValue);
+          vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        }
       } else {
-        vectorDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes)
+          vectorDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (const auto *tileClause =
                    std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
       const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
-      auto crtTileOperands = tileOperands.size();
+      llvm::SmallVector<mlir::Value> tileValues;
       for (const auto &accTileExpr : accTileExprList.v) {
         const auto &expr =
             std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
                 accTileExpr.t);
         if (expr) {
-          tileOperands.push_back(fir::getBase(converter.genExprValue(
+          tileValues.push_back(fir::getBase(converter.genExprValue(
               *Fortran::semantics::GetExpr(*expr), stmtCtx)));
         } else {
           // * was passed as value and will be represented as a special
           // constant.
           mlir::Value tileStar = builder.createIntegerConstant(
               clauseLocation, builder.getIntegerType(32), starCst);
-          tileOperands.push_back(tileStar);
+          tileValues.push_back(tileStar);
         }
       }
-      tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
-      tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        for (auto value : tileValues)
+          tileOperands.push_back(value);
+        tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
+        tileOperandsSegments.push_back(tileValues.size());
+      }
     } else if (const auto *privateClause =
                    std::get_if<Fortran::parser::AccClause::Private>(
                        &clause.u)) {
@@ -1707,21 +1737,20 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
       genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
                     reductionOperands, reductionRecipes);
     } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      seqDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        seqDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Independent>(
                    &clause.u)) {
-      independentDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        independentDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
-      autoDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        autoDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
-          deviceTypeClause->v;
-      assert(deviceTypeExprList.v.size() == 1 &&
-             "expect only one device_type expr");
-      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     } else if (const auto *collapseClause =
                    std::get_if<Fortran::parser::AccClause::Collapse>(
                        &clause.u)) {
@@ -1729,14 +1758,18 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
       const auto &force = std::get<bool>(arg.t);
       if (force)
         TODO(clauseLocation, "OpenACC collapse force modifier");
+
       const auto &intExpr =
           std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
       const auto *expr = Fortran::semantics::GetExpr(intExpr);
       const std::optional<int64_t> collapseValue =
           Fortran::evaluate::ToInt64(*expr);
       assert(collapseValue && "expect integer value for the collapse clause");
-      collapseValues.push_back(*collapseValue);
-      collapseDeviceTypes.push_back(crtDeviceTypeAttr);
+
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        collapseValues.push_back(*collapseValue);
+        collapseDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     }
   }
 
@@ -1923,45 +1956,56 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
   // device_type attribute is set to `none` until a device_type clause is
   // encountered.
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
   auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
       builder.getContext(), mlir::acc::DeviceType::None);
+  crtDeviceTypes.push_back(crtDeviceTypeAttr);
 
-  // Lower clauses values mapped to operands.
-  // Keep track of each group of operands separatly as clauses can appear
+  // Lower clauses values mapped to operands and array attributes.
+  // Keep track of each group of operands separately as clauses can appear
   // more than once.
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
     if (const auto *asyncClause =
             std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
       genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
-                     asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
+                     asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
       genWaitClause(converter, waitClause, waitOperands,
                     waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
-                    stmtCtx);
+                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
     } else if (const auto *numGangsClause =
                    std::get_if<Fortran::parser::AccClause::NumGangs>(
                        &clause.u)) {
-      auto crtNumGangs = numGangs.size();
+      llvm::SmallVector<mlir::Value> numGangValues;
       for (const Fortran::parser::ScalarIntExpr &expr : numGangsClause->v)
-        numGangs.push_back(fir::getBase(converter.genExprValue(
+        numGangValues.push_back(fir::getBase(converter.genExprValue(
             *Fortran::semantics::GetExpr(expr), stmtCtx)));
-      numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
-      numGangsSegments.push_back(numGangs.size() - crtNumGangs);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        for (auto value : numGangValues)
+          numGangs.push_back(value);
+        numGangsDeviceTypes.push_back(crtDeviceTypeAttr);
+        numGangsSegments.push_back(numGangValues.size());
+      }
     } else if (const auto *numWorkersClause =
                    std::get_if<Fortran::parser::AccClause::NumWorkers>(
                        &clause.u)) {
-      numWorkers.push_back(fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx)));
-      numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+      mlir::Value numWorkerValue = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(numWorkersClause->v), stmtCtx));
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        numWorkers.push_back(numWorkerValue);
+        numWorkersDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *vectorLengthClause =
                    std::get_if<Fortran::parser::AccClause::VectorLength>(
                        &clause.u)) {
-      vectorLength.push_back(fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx)));
-      vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+      mlir::Value vectorLengthValue = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(vectorLengthClause->v), stmtCtx));
+      for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+        vectorLength.push_back(vectorLengthValue);
+        vectorLengthDeviceTypes.push_back(crtDeviceTypeAttr);
+      }
     } else if (const auto *ifClause =
                    std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
       genIfClause(converter, clauseLocation, ifClause, ifCond, stmtCtx);
@@ -2115,12 +2159,8 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
-          deviceTypeClause->v;
-      assert(deviceTypeExprList.v.size() == 1 &&
-             "expect only one device_type expr");
-      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     }
   }
 
@@ -2239,10 +2279,11 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
 
   // device_type attribute is set to `none` until a device_type clause is
   // encountered.
-  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-      builder.getContext(), mlir::acc::DeviceType::None);
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None));
 
-  // Lower clauses values mapped to operands.
+  // Lower clauses values mapped to operands and array attributes.
   // Keep track of each group of operands separately as clauses can appear
   // more than once.
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
@@ -2323,19 +2364,23 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *asyncClause =
                    std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
       genAsyncClause(converter, asyncClause, async, asyncDeviceTypes,
-                     asyncOnlyDeviceTypes, crtDeviceTypeAttr, stmtCtx);
+                     asyncOnlyDeviceTypes, crtDeviceTypes, stmtCtx);
     } else if (const auto *waitClause =
                    std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
       genWaitClause(converter, waitClause, waitOperands,
                     waitOperandsDeviceTypes, waitOnlyDeviceTypes,
-                    waitOperandsSegments, waitDevnum, crtDeviceTypeAttr,
-                    stmtCtx);
+                    waitOperandsSegments, waitDevnum, crtDeviceTypes, stmtCtx);
     } else if(const auto *defaultClause = 
                   std::get_if<Fortran::parser::AccClause::Default>(&clause.u)) {
       if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_none)
         hasDefaultNone = true;
       else if ((defaultClause->v).v == llvm::acc::DefaultValue::ACC_Default_present)
         hasDefaultPresent = true;
+    } else if (const auto *deviceTypeClause =
+                   std::get_if<Fortran::parser::AccClause::DeviceType>(
+                       &clause.u)) {
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     }
   }
 
@@ -2727,8 +2772,7 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
-                            deviceTypes, stmtCtx);
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
     }
   }
 
@@ -2777,8 +2821,7 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
-                            deviceTypes, stmtCtx);
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
     }
   }
 
@@ -2835,8 +2878,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
-                            deviceTypes, stmtCtx);
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, deviceTypes);
     } else if (const auto *hostClause =
                    std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
       genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -3592,15 +3634,16 @@ void Fortran::lower::genOpenACCRoutineConstruct(
 
   // device_type attribute is set to `none` until a device_type clause is
   // encountered.
-  auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-      builder.getContext(), mlir::acc::DeviceType::None);
+  llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+  crtDeviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+      builder.getContext(), mlir::acc::DeviceType::None));
 
   for (const Fortran::parser::AccClause &clause : clauses.v) {
     if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      seqDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        seqDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (const auto *gangClause =
                    std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
-
       if (gangClause->v) {
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
@@ -3611,27 +3654,36 @@ void Fortran::lower::genOpenACCRoutineConstruct(
             if (!dimValue)
               mlir::emitError(loc,
                               "dim value must be a constant positive integer");
-            gangDimValues.push_back(
-                builder.getIntegerAttr(builder.getI64Type(), *dimValue));
-            gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
+            mlir::Attribute gangDimAttr =
+                builder.getIntegerAttr(builder.getI64Type(), *dimValue);
+            for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+              gangDimValues.push_back(gangDimAttr);
+              gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
+            }
           }
         }
       } else {
-        gangDeviceTypes.push_back(crtDeviceTypeAttr);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes)
+          gangDeviceTypes.push_back(crtDeviceTypeAttr);
       }
     } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
-      vectorDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        vectorDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
-      workerDeviceTypes.push_back(crtDeviceTypeAttr);
+      for (auto crtDeviceTypeAttr : crtDeviceTypes)
+        workerDeviceTypes.push_back(crtDeviceTypeAttr);
     } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
       hasNohost = true;
     } else if (const auto *bindClause =
                    std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
       if (const auto *name =
               std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
-        bindNames.push_back(
-            builder.getStringAttr(converter.mangleName(*name->symbol)));
-        bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
+        mlir::Attribute bindNameAttr =
+            builder.getStringAttr(converter.mangleName(*name->symbol));
+        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+          bindNames.push_back(bindNameAttr);
+          bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
+        }
       } else if (const auto charExpr =
                      std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
                          &bindClause->v.u)) {
@@ -3640,18 +3692,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
                                                           *charExpr);
         if (!name)
           mlir::emitError(loc, "Could not retrieve the bind name");
-        bindNames.push_back(builder.getStringAttr(*name));
-        bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
+
+        mlir::Attribute bindNameAttr = builder.getStringAttr(*name);
+        for (auto crtDeviceTypeAttr : crtDeviceTypes) {
+          bindNames.push_back(bindNameAttr);
+          bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
+        }
       }
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
-          deviceTypeClause->v;
-      assert(deviceTypeExprList.v.size() == 1 &&
-             "expect only one device_type expr");
-      crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
-          builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
+      crtDeviceTypes.clear();
+      gatherDeviceTypeAttrs(builder, deviceTypeClause, crtDeviceTypes);
     }
   }
 

diff  --git a/flang/test/Lower/OpenACC/acc-device-type.f90 b/flang/test/Lower/OpenACC/acc-device-type.f90
index 871dbc95f60fcba..ae01d0dc5fcde35 100644
--- a/flang/test/Lower/OpenACC/acc-device-type.f90
+++ b/flang/test/Lower/OpenACC/acc-device-type.f90
@@ -40,5 +40,9 @@ subroutine sub1()
 
 ! CHECK: acc.parallel num_gangs({%c2{{.*}} : i32}, {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
 
+  !$acc parallel device_type(nvidia, default) num_gangs(1, 1, 1)
+  !$acc end parallel
+
+! CHECK: acc.parallel num_gangs({%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<default>])
 
 end subroutine

diff  --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index 42e14afb35f522b..59c2513332a9768 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -326,4 +326,10 @@ program acc_loop
 
 ! CHECK: acc.loop gang([#acc.device_type<none>], {num=%c8{{.*}} : i32} [#acc.device_type<nvidia>])
 
+  !$acc loop device_type(nvidia, default) gang
+  DO i = 1, n
+  END DO
+
+! CHECK: acc.loop gang([#acc.device_type<nvidia>, #acc.device_type<default>]) {
+
 end program

diff  --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 2fe150e70b0cfbe..1170af18bc33410 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -2,6 +2,7 @@
 
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
+! CHECK: acc.routine @acc_routine_17 func(@_QPacc_routine19) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine17" [#acc.device_type<default>], "_QPacc_routine16" [#acc.device_type<multicore>])
 ! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
 ! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
 ! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
@@ -120,3 +121,7 @@ subroutine acc_routine17()
 subroutine acc_routine18()
   !$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16) 
 end subroutine
+
+subroutine acc_routine19()
+  !$acc routine device_type(host,default) bind(acc_routine17) dtype(multicore) bind(acc_routine16) 
+end subroutine

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 20465f6bb86ed1d..bc03adbcae64df7 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1449,7 +1449,8 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
                      std::optional<mlir::DenseI32ArrayAttr> segments,
                      std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
 
-  if (operands.begin() == operands.end() && gangOnlyDeviceTypes &&
+  if (operands.begin() == operands.end() &&
+      hasDeviceTypeValues(gangOnlyDeviceTypes) &&
       gangOnlyDeviceTypes->size() == 1) {
     auto deviceTypeAttr =
         mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
@@ -1464,7 +1465,7 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
       hasDeviceTypeValues(deviceTypes))
     p << ", ";
 
-  if (deviceTypes) {
+  if (hasDeviceTypeValues(deviceTypes)) {
     unsigned opIdx = 0;
     llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
       p << "{";


        


More information about the flang-commits mailing list