[Mlir-commits] [mlir] [flang][mlir][openacc] Switch device_type representation to an enum (PR #70250)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 25 13:30:30 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Switch the representation from a scalar integer to an enumeration. The parser transforms the string in the input into the corresponding enumeration value.

---

Patch is 36.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70250.diff


20 Files Affected:

- (modified) flang/include/flang/Parser/dump-parse-tree.h (+1) 
- (modified) flang/include/flang/Parser/parse-tree.h (+2-2) 
- (modified) flang/lib/Lower/OpenACC.cpp (+55-38) 
- (modified) flang/lib/Parser/openacc-parsers.cpp (+9-3) 
- (modified) flang/test/Lower/OpenACC/acc-init.f90 (+6-4) 
- (modified) flang/test/Lower/OpenACC/acc-set.f90 (+3-5) 
- (modified) flang/test/Lower/OpenACC/acc-shutdown.f90 (+2-4) 
- (modified) flang/test/Lower/OpenACC/acc-update.f90 (+3-6) 
- (modified) flang/test/Semantics/OpenACC/acc-data.f90 (+1-1) 
- (modified) flang/test/Semantics/OpenACC/acc-init-validity.f90 (+4-4) 
- (modified) flang/test/Semantics/OpenACC/acc-kernels-loop.f90 (+2-2) 
- (modified) flang/test/Semantics/OpenACC/acc-kernels.f90 (+2-2) 
- (modified) flang/test/Semantics/OpenACC/acc-parallel.f90 (+3-3) 
- (modified) flang/test/Semantics/OpenACC/acc-set-validity.f90 (+4-4) 
- (modified) flang/test/Semantics/OpenACC/acc-shutdown-validity.f90 (+4-4) 
- (modified) flang/test/Semantics/OpenACC/acc-update-validity.f90 (+1-1) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+31-14) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+2-3) 
- (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+8-12) 


``````````diff
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 494e54faa64c841..7c479a2334ea555 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -101,6 +101,7 @@ class ParseTreeDumper {
   NODE(parser, AccSelfClause)
   NODE(parser, AccStandaloneDirective)
   NODE(parser, AccDeviceTypeExpr)
+  NODE_ENUM(parser::AccDeviceTypeExpr, Device)
   NODE(parser, AccDeviceTypeExprList)
   NODE(parser, AccTileExpr)
   NODE(parser, AccTileExprList)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 83c8db936934a03..4806fc49f3441de 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4061,9 +4061,9 @@ struct AccWaitArgument {
 };
 
 struct AccDeviceTypeExpr {
-  TUPLE_CLASS_BOILERPLATE(AccDeviceTypeExpr);
+  ENUM_CLASS(Device, Star, Default, Nvidia, Radeon, Host, Multicore)
+  WRAPPER_CLASS_BOILERPLATE(AccDeviceTypeExpr, Device);
   CharBlock source;
-  std::tuple<std::optional<ScalarIntExpr>> t; // if null then *
 };
 
 struct AccDeviceTypeExprList {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 90644279d9e78ce..024c38c96277313 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -150,7 +150,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
           builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, descTy);
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -219,7 +219,7 @@ static void createDeclareDeallocFuncWithArg(
           builder, loc, loadOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, loadOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1416,27 +1416,35 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
   }
 }
 
-static void genDeviceTypeClause(
-    Fortran::lower::AbstractConverter &converter, mlir::Location clauseLocation,
+static mlir::acc::DeviceType
+getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
+  switch (device) {
+  case Fortran::parser::AccDeviceTypeExpr::Device::Star:
+    return mlir::acc::DeviceType::Star;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Default:
+    return mlir::acc::DeviceType::Default;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Nvidia:
+    return mlir::acc::DeviceType::Nvidia;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Radeon:
+    return mlir::acc::DeviceType::Radeon;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Host:
+    return mlir::acc::DeviceType::Host;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
+    return mlir::acc::DeviceType::Multicore;
+  }
+  return mlir::acc::DeviceType::Default;
+}
+
+static void gatherDeviceTypeAttrs(
+    fir::FirOpBuilder &builder, mlir::Location clauseLocation,
     const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
-    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVector<mlir::Attribute> &deviceTypes,
     Fortran::lower::StatementContext &stmtCtx) {
   const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
       deviceTypeClause->v;
-  for (const auto &deviceTypeExpr : deviceTypeExprList.v) {
-    const auto &expr = std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
-        deviceTypeExpr.t);
-    if (expr) {
-      operands.push_back(fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(expr), stmtCtx, &clauseLocation)));
-    } else {
-      // * was passed as value and will be represented as a special constant.
-      fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-      mlir::Value star = firOpBuilder.createIntegerConstant(
-          clauseLocation, firOpBuilder.getIndexType(), starCst);
-      operands.push_back(star);
-    }
-  }
+  for (const auto &deviceTypeExpr : deviceTypeExprList.v)
+    deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+        builder.getContext(), getDeviceType(deviceTypeExpr.v)));
 }
 
 static void genIfClause(Fortran::lower::AbstractConverter &converter,
@@ -2428,10 +2436,10 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
                      mlir::Location currentLocation,
                      const Fortran::parser::AccClauseList &accClauseList) {
   mlir::Value ifCond, deviceNum;
-  llvm::SmallVector<mlir::Value> deviceTypeOperands;
 
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   Fortran::lower::StatementContext stmtCtx;
+  llvm::SmallVector<mlir::Attribute> deviceTypes;
 
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separately as clauses can appear
@@ -2449,19 +2457,23 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
-                          deviceTypeOperands, stmtCtx);
+      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+                            deviceTypes, stmtCtx);
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value, 6> operands;
-  llvm::SmallVector<int32_t, 3> operandSegments;
-  addOperands(operands, operandSegments, deviceTypeOperands);
+  llvm::SmallVector<int32_t, 2> operandSegments;
+
   addOperand(operands, operandSegments, deviceNum);
   addOperand(operands, operandSegments, ifCond);
 
-  createSimpleOp<Op>(firOpBuilder, currentLocation, operands, operandSegments);
+  Op op =
+      createSimpleOp<Op>(builder, currentLocation, operands, operandSegments);
+  if (!deviceTypes.empty())
+    op.setDeviceTypesAttr(
+        mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
 }
 
 void genACCSetOp(Fortran::lower::AbstractConverter &converter,
@@ -2470,8 +2482,9 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value ifCond, deviceNum, defaultAsync;
   llvm::SmallVector<mlir::Value> deviceTypeOperands;
 
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   Fortran::lower::StatementContext stmtCtx;
+  llvm::SmallVector<mlir::Attribute> deviceTypes;
 
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separately as clauses can appear
@@ -2494,21 +2507,22 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
-                          deviceTypeOperands, stmtCtx);
+      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+                            deviceTypes, stmtCtx);
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value> operands;
-  llvm::SmallVector<int32_t, 4> operandSegments;
-  addOperands(operands, operandSegments, deviceTypeOperands);
+  llvm::SmallVector<int32_t, 3> operandSegments;
   addOperand(operands, operandSegments, defaultAsync);
   addOperand(operands, operandSegments, deviceNum);
   addOperand(operands, operandSegments, ifCond);
 
-  createSimpleOp<mlir::acc::SetOp>(firOpBuilder, currentLocation, operands,
-                                   operandSegments);
+  auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
+                                             operandSegments);
+  if (!deviceTypes.empty())
+    op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0]));
 }
 
 static void
@@ -2520,6 +2534,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value ifCond, async, waitDevnum;
   llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
       waitOperands, deviceTypeOperands;
+  llvm::SmallVector<mlir::Attribute> deviceTypes;
 
   // Async and wait clause have optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
@@ -2548,8 +2563,8 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
-                          deviceTypeOperands, stmtCtx);
+      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+                            deviceTypes, stmtCtx);
     } else if (const auto *hostClause =
                    std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
       genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2587,11 +2602,13 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   addOperand(operands, operandSegments, async);
   addOperand(operands, operandSegments, waitDevnum);
   addOperands(operands, operandSegments, waitOperands);
-  addOperands(operands, operandSegments, deviceTypeOperands);
   addOperands(operands, operandSegments, dataClauseOperands);
 
   mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
       builder, currentLocation, operands, operandSegments);
+  if (!deviceTypes.empty())
+    updateOp.setDeviceTypesAttr(
+        mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
 
   genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
       builder, updateHostOperands, /*structured=*/false, /*implicit=*/false);
@@ -2772,7 +2789,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -2848,7 +2865,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/lib/Parser/openacc-parsers.cpp b/flang/lib/Parser/openacc-parsers.cpp
index 131f7332a69701a..5b9267e0e17c6db 100644
--- a/flang/lib/Parser/openacc-parsers.cpp
+++ b/flang/lib/Parser/openacc-parsers.cpp
@@ -53,9 +53,15 @@ TYPE_PARSER(construct<AccSizeExpr>(scalarIntExpr) ||
     construct<AccSizeExpr>("*" >> construct<std::optional<ScalarIntExpr>>()))
 TYPE_PARSER(construct<AccSizeExprList>(nonemptyList(Parser<AccSizeExpr>{})))
 
-TYPE_PARSER(construct<AccDeviceTypeExpr>(scalarIntExpr) ||
-    construct<AccDeviceTypeExpr>(
-        "*" >> construct<std::optional<ScalarIntExpr>>()))
+TYPE_PARSER(sourced(construct<AccDeviceTypeExpr>(
+    first("*" >> pure(AccDeviceTypeExpr::Device::Star),
+        "DEFAULT" >> pure(AccDeviceTypeExpr::Device::Default),
+        "NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia),
+        "ACC_DEVICE_NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia),
+        "RADEON" >> pure(AccDeviceTypeExpr::Device::Radeon),
+        "HOST" >> pure(AccDeviceTypeExpr::Device::Host),
+        "MULTICORE" >> pure(AccDeviceTypeExpr::Device::Multicore)))))
+
 TYPE_PARSER(
     construct<AccDeviceTypeExprList>(nonemptyList(Parser<AccDeviceTypeExpr>{})))
 
diff --git a/flang/test/Lower/OpenACC/acc-init.f90 b/flang/test/Lower/OpenACC/acc-init.f90
index de940426b6f1c0b..d1fd638c7ac0e8a 100644
--- a/flang/test/Lower/OpenACC/acc-init.f90
+++ b/flang/test/Lower/OpenACC/acc-init.f90
@@ -4,6 +4,7 @@
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
 subroutine acc_init
+  implicit none
   logical :: ifCondition = .TRUE.
   integer :: ifInt = 1
 
@@ -23,15 +24,16 @@ subroutine acc_init
 !CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
 !CHECK: acc.init device_num([[DEVNUM]] : i32){{$}}
 
-  !$acc init device_num(1) device_type(1, 2)
+  !$acc init device_num(1) device_type(host, multicore)
 !CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-!CHECK: acc.init device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}}
+!CHECK: acc.init device_num([[DEVNUM]] : i32) attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
 
   !$acc init if(ifInt)
 !CHECK: %[[IFINT:.*]] = fir.load %{{.*}} : !fir.ref<i32>
 !CHECK: %[[CONV:.*]] = fir.convert %[[IFINT]] : (i32) -> i1
 !CHECK: acc.init if(%[[CONV]])
 
+   !$acc init device_type(nvidia)
+!CHECK: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
+
 end subroutine acc_init
diff --git a/flang/test/Lower/OpenACC/acc-set.f90 b/flang/test/Lower/OpenACC/acc-set.f90
index 52baedeafecb2bb..39bf26e0072b7ca 100644
--- a/flang/test/Lower/OpenACC/acc-set.f90
+++ b/flang/test/Lower/OpenACC/acc-set.f90
@@ -14,7 +14,7 @@ program test_acc_set
 
 !$acc set device_type(*)
 
-!$acc set device_type(0)
+!$acc set device_type(multicore)
 
 end
 
@@ -34,10 +34,8 @@ program test_acc_set
 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
 ! CHECK: acc.set device_num(%[[C0]] : i32)
 
-! CHECK: %[[C_1:.*]] = arith.constant -1 : index
-! CHECK: acc.set device_type(%[[C_1]] : index)
+! CHECK: acc.set attributes {device_type = #acc.device_type<*>}
 
-! CHECK: %[[C0:.*]] = arith.constant 0 : i32
-! CHECK: acc.set device_type(%[[C0]] : i32)
+! CHECK: acc.set attributes {device_type = #acc.device_type<multicore>}
 
 
diff --git a/flang/test/Lower/OpenACC/acc-shutdown.f90 b/flang/test/Lower/OpenACC/acc-shutdown.f90
index 49e1acc546d900c..f63f5d62b4fe921 100644
--- a/flang/test/Lower/OpenACC/acc-shutdown.f90
+++ b/flang/test/Lower/OpenACC/acc-shutdown.f90
@@ -22,10 +22,8 @@ subroutine acc_shutdown
 !CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
 !CHECK: acc.shutdown device_num([[DEVNUM]] : i32){{$}}
 
-  !$acc shutdown device_num(1) device_type(1, 2)
+  !$acc shutdown device_num(1) device_type(default, nvidia)
 !CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-!CHECK: acc.shutdown device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}}
+!CHECK: acc.shutdown device_num([[DEVNUM]] : i32) attributes {device_types = [#acc.device_type<default>, #acc.device_type<nvidia>]} 
 
 end subroutine acc_shutdown
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index f7343a69285f85a..5d5f5733ef7f1a6 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -145,20 +145,17 @@ subroutine acc_update
 ! FIR:   acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 ! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
-  !$acc update host(a) device_type(1, 2)
+  !$acc update host(a) device_type(default, host)
 ! FIR:   %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! HLFIR: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-! CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]} 
 ! FIR:   acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 ! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) device_type(*)
 ! FIR:   %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! HLFIR: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[DEVTYPE3:%.*]] = arith.constant -1 : index
-! CHECK: acc.update device_type([[DEVTYPE3]] : index) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<*>]} 
 ! FIR:   acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 ! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
diff --git a/flang/test/Semantics/OpenACC/acc-data.f90 b/flang/test/Semantics/OpenACC/acc-data.f90
index 17e0624b8cf24d4..1a7a6f95f3d891e 100644
--- a/flang/test/Semantics/OpenACC/acc-data.f90
+++ b/flang/test/Semantics/OpenACC/acc-data.f90
@@ -184,7 +184,7 @@ program openacc_data_validity
   !$acc data copy(aa) wait
   !$acc end data
 
-  !$acc data copy(aa) device_type(1) wait
+  !$acc data copy(aa) device_type(default) wait
   !$acc end data
 
 end program openacc_data_validity
diff --git a/flang/test/Semantics/OpenACC/acc-init-validity.f90 b/flang/test/Semantics/OpenACC/acc-init-validity.f90
index f54898f73fdce28..3b594a25217c094 100644
--- a/flang/test/Semantics/OpenACC/acc-init-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-init-validity.f90
@@ -20,9 +20,9 @@ program openacc_init_validity
   !$acc init if(ifInt)
   !$acc init device_num(1)
   !$acc init device_num(i)
-  !$acc init device_type(i)
-  !$acc init device_type(2, i, j)
-  !$acc init device_num(i) device_type(i, j) if(ifCondition)
+  !$acc init device_type(default)
+  !$acc init device_type(nvidia, radeon)
+  !$acc init device_num(i) device_type(host, multicore) if(ifCondition)
 
   !$acc parallel...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/70250


More information about the Mlir-commits mailing list