[flang-commits] [flang] [flang][mlir][openacc] Switch device_type representation to an enum (PR #70250)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Oct 25 15:26:51 PDT 2023


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/70250

>From 605777b63a74aebc1c50948764e4e9a3ac6341c7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 25 Oct 2023 13:26:17 -0700
Subject: [PATCH 1/2] [flang][mlir][openacc] Switch device_type representation
 to an enum

---
 flang/include/flang/Parser/dump-parse-tree.h  |  1 +
 flang/include/flang/Parser/parse-tree.h       |  4 +-
 flang/lib/Lower/OpenACC.cpp                   | 93 +++++++++++--------
 flang/lib/Parser/openacc-parsers.cpp          | 12 ++-
 flang/test/Lower/OpenACC/acc-init.f90         | 10 +-
 flang/test/Lower/OpenACC/acc-set.f90          |  8 +-
 flang/test/Lower/OpenACC/acc-shutdown.f90     |  6 +-
 flang/test/Lower/OpenACC/acc-update.f90       |  9 +-
 flang/test/Semantics/OpenACC/acc-data.f90     |  2 +-
 .../Semantics/OpenACC/acc-init-validity.f90   |  8 +-
 .../Semantics/OpenACC/acc-kernels-loop.f90    |  4 +-
 flang/test/Semantics/OpenACC/acc-kernels.f90  |  4 +-
 flang/test/Semantics/OpenACC/acc-parallel.f90 |  6 +-
 .../Semantics/OpenACC/acc-set-validity.f90    |  8 +-
 .../OpenACC/acc-shutdown-validity.f90         |  8 +-
 .../Semantics/OpenACC/acc-update-validity.f90 |  2 +-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        | 45 ++++++---
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       |  5 +-
 mlir/test/Dialect/OpenACC/invalid.mlir        |  2 +-
 mlir/test/Dialect/OpenACC/ops.mlir            | 20 ++--
 20 files changed, 144 insertions(+), 113 deletions(-)

diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 494e54faa64c841..7c479a2334ea555 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -101,6 +101,7 @@ class ParseTreeDumper {
   NODE(parser, AccSelfClause)
   NODE(parser, AccStandaloneDirective)
   NODE(parser, AccDeviceTypeExpr)
+  NODE_ENUM(parser::AccDeviceTypeExpr, Device)
   NODE(parser, AccDeviceTypeExprList)
   NODE(parser, AccTileExpr)
   NODE(parser, AccTileExprList)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 83c8db936934a03..4806fc49f3441de 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4061,9 +4061,9 @@ struct AccWaitArgument {
 };
 
 struct AccDeviceTypeExpr {
-  TUPLE_CLASS_BOILERPLATE(AccDeviceTypeExpr);
+  ENUM_CLASS(Device, Star, Default, Nvidia, Radeon, Host, Multicore)
+  WRAPPER_CLASS_BOILERPLATE(AccDeviceTypeExpr, Device); // Star denotes '*'
   CharBlock source;
-  std::tuple<std::optional<ScalarIntExpr>> t; // if null then *
 };
 
 struct AccDeviceTypeExprList {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 90644279d9e78ce..024c38c96277313 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -150,7 +150,7 @@ static void createDeclareAllocFuncWithArg(mlir::OpBuilder &modBuilder,
           builder, loc, registerFuncOp.getArgument(0), asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, descTy);
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -219,7 +219,7 @@ static void createDeclareDeallocFuncWithArg(
           builder, loc, loadOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, loadOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
@@ -1416,27 +1416,39 @@ static void genAsyncClause(Fortran::lower::AbstractConverter &converter,
   }
 }

-static void genDeviceTypeClause(
-    Fortran::lower::AbstractConverter &converter, mlir::Location clauseLocation,
+/// Map a parser device_type enumerator to its MLIR acc::DeviceType value.
+static mlir::acc::DeviceType
+getDeviceType(Fortran::parser::AccDeviceTypeExpr::Device device) {
+  switch (device) {
+  case Fortran::parser::AccDeviceTypeExpr::Device::Star:
+    return mlir::acc::DeviceType::Star;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Default:
+    return mlir::acc::DeviceType::Default;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Nvidia:
+    return mlir::acc::DeviceType::Nvidia;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Radeon:
+    return mlir::acc::DeviceType::Radeon;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Host:
+    return mlir::acc::DeviceType::Host;
+  case Fortran::parser::AccDeviceTypeExpr::Device::Multicore:
+    return mlir::acc::DeviceType::Multicore;
+  }
+  // Every enumerator is handled above; this return is only a fallback.
+  return mlir::acc::DeviceType::Default;
+}
+
+/// Append one acc::DeviceTypeAttr to \p deviceTypes for each entry of the
+/// device_type clause. \p clauseLocation and \p stmtCtx are currently unused.
+static void gatherDeviceTypeAttrs(
+    fir::FirOpBuilder &builder, mlir::Location clauseLocation,
     const Fortran::parser::AccClause::DeviceType *deviceTypeClause,
-    llvm::SmallVectorImpl<mlir::Value> &operands,
+    llvm::SmallVectorImpl<mlir::Attribute> &deviceTypes,
     Fortran::lower::StatementContext &stmtCtx) {
   const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
       deviceTypeClause->v;
-  for (const auto &deviceTypeExpr : deviceTypeExprList.v) {
-    const auto &expr = std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
-        deviceTypeExpr.t);
-    if (expr) {
-      operands.push_back(fir::getBase(converter.genExprValue(
-          *Fortran::semantics::GetExpr(expr), stmtCtx, &clauseLocation)));
-    } else {
-      // * was passed as value and will be represented as a special constant.
-      fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-      mlir::Value star = firOpBuilder.createIntegerConstant(
-          clauseLocation, firOpBuilder.getIndexType(), starCst);
-      operands.push_back(star);
-    }
-  }
+  for (const auto &deviceTypeExpr : deviceTypeExprList.v)
+    deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
+        builder.getContext(), getDeviceType(deviceTypeExpr.v)));
 }
 
 static void genIfClause(Fortran::lower::AbstractConverter &converter,
@@ -2428,10 +2436,10 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
                      mlir::Location currentLocation,
                      const Fortran::parser::AccClauseList &accClauseList) {
   mlir::Value ifCond, deviceNum;
-  llvm::SmallVector<mlir::Value> deviceTypeOperands;
 
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   Fortran::lower::StatementContext stmtCtx;
+  llvm::SmallVector<mlir::Attribute> deviceTypes;
 
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separately as clauses can appear
@@ -2449,19 +2457,23 @@ genACCInitShutdownOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
-                          deviceTypeOperands, stmtCtx);
+      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+                            deviceTypes, stmtCtx);
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value, 6> operands;
-  llvm::SmallVector<int32_t, 3> operandSegments;
-  addOperands(operands, operandSegments, deviceTypeOperands);
+  llvm::SmallVector<int32_t, 2> operandSegments;
+
   addOperand(operands, operandSegments, deviceNum);
   addOperand(operands, operandSegments, ifCond);
 
-  createSimpleOp<Op>(firOpBuilder, currentLocation, operands, operandSegments);
+  Op op =
+      createSimpleOp<Op>(builder, currentLocation, operands, operandSegments);
+  if (!deviceTypes.empty())
+    op.setDeviceTypesAttr(
+        mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
 }
 
 void genACCSetOp(Fortran::lower::AbstractConverter &converter,
@@ -2470,8 +2482,9 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value ifCond, deviceNum, defaultAsync;
   llvm::SmallVector<mlir::Value> deviceTypeOperands;
 
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   Fortran::lower::StatementContext stmtCtx;
+  llvm::SmallVector<mlir::Attribute> deviceTypes;
 
   // Lower clauses values mapped to operands.
   // Keep track of each group of operands separately as clauses can appear
@@ -2494,21 +2507,22 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
-                          deviceTypeOperands, stmtCtx);
+      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+                            deviceTypes, stmtCtx);
     }
   }
 
   // Prepare the operand segment size attribute and the operands value range.
   llvm::SmallVector<mlir::Value> operands;
-  llvm::SmallVector<int32_t, 4> operandSegments;
-  addOperands(operands, operandSegments, deviceTypeOperands);
+  llvm::SmallVector<int32_t, 3> operandSegments;
   addOperand(operands, operandSegments, defaultAsync);
   addOperand(operands, operandSegments, deviceNum);
   addOperand(operands, operandSegments, ifCond);
 
-  createSimpleOp<mlir::acc::SetOp>(firOpBuilder, currentLocation, operands,
-                                   operandSegments);
+  auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
+                                             operandSegments);
+  if (!deviceTypes.empty())
+    op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0]));
 }
 
 static void
@@ -2520,6 +2534,7 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value ifCond, async, waitDevnum;
   llvm::SmallVector<mlir::Value> dataClauseOperands, updateHostOperands,
       waitOperands, deviceTypeOperands;
+  llvm::SmallVector<mlir::Attribute> deviceTypes;
 
   // Async and wait clause have optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
@@ -2548,8 +2563,8 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *deviceTypeClause =
                    std::get_if<Fortran::parser::AccClause::DeviceType>(
                        &clause.u)) {
-      genDeviceTypeClause(converter, clauseLocation, deviceTypeClause,
-                          deviceTypeOperands, stmtCtx);
+      gatherDeviceTypeAttrs(builder, clauseLocation, deviceTypeClause,
+                            deviceTypes, stmtCtx);
     } else if (const auto *hostClause =
                    std::get_if<Fortran::parser::AccClause::Host>(&clause.u)) {
       genDataOperandOperations<mlir::acc::GetDevicePtrOp>(
@@ -2587,11 +2602,13 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   addOperand(operands, operandSegments, async);
   addOperand(operands, operandSegments, waitDevnum);
   addOperands(operands, operandSegments, waitOperands);
-  addOperands(operands, operandSegments, deviceTypeOperands);
   addOperands(operands, operandSegments, dataClauseOperands);
 
   mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(
       builder, currentLocation, operands, operandSegments);
+  if (!deviceTypes.empty())
+    updateOp.setDeviceTypesAttr(
+        mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
 
   genDataExitOperations<mlir::acc::GetDevicePtrOp, mlir::acc::UpdateHostOp>(
       builder, updateHostOperands, /*structured=*/false, /*implicit=*/false);
@@ -2772,7 +2789,7 @@ static void createDeclareAllocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortranDesc, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
 
@@ -2848,7 +2865,7 @@ static void createDeclareDeallocFunc(mlir::OpBuilder &modBuilder,
           builder, loc, addrOp, asFortran, bounds,
           /*structured=*/false, /*implicit=*/true,
           mlir::acc::DataClause::acc_update_device, addrOp.getType());
-  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 0, 1};
+  llvm::SmallVector<int32_t> operandSegments{0, 0, 0, 0, 1};
   llvm::SmallVector<mlir::Value> operands{updateDeviceOp.getResult()};
   createSimpleOp<mlir::acc::UpdateOp>(builder, loc, operands, operandSegments);
   modBuilder.setInsertionPointAfter(postDeallocOp);
diff --git a/flang/lib/Parser/openacc-parsers.cpp b/flang/lib/Parser/openacc-parsers.cpp
index 131f7332a69701a..5b9267e0e17c6db 100644
--- a/flang/lib/Parser/openacc-parsers.cpp
+++ b/flang/lib/Parser/openacc-parsers.cpp
@@ -53,9 +53,15 @@ TYPE_PARSER(construct<AccSizeExpr>(scalarIntExpr) ||
     construct<AccSizeExpr>("*" >> construct<std::optional<ScalarIntExpr>>()))
 TYPE_PARSER(construct<AccSizeExprList>(nonemptyList(Parser<AccSizeExpr>{})))
 
-TYPE_PARSER(construct<AccDeviceTypeExpr>(scalarIntExpr) ||
-    construct<AccDeviceTypeExpr>(
-        "*" >> construct<std::optional<ScalarIntExpr>>()))
+// '*' or a device name; ACC_DEVICE_NVIDIA is an alias for NVIDIA.
+TYPE_PARSER(sourced(construct<AccDeviceTypeExpr>(
+    first("*" >> pure(AccDeviceTypeExpr::Device::Star),
+        "DEFAULT" >> pure(AccDeviceTypeExpr::Device::Default),
+        "NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia),
+        "ACC_DEVICE_NVIDIA" >> pure(AccDeviceTypeExpr::Device::Nvidia),
+        "RADEON" >> pure(AccDeviceTypeExpr::Device::Radeon),
+        "HOST" >> pure(AccDeviceTypeExpr::Device::Host),
+        "MULTICORE" >> pure(AccDeviceTypeExpr::Device::Multicore)))))
 TYPE_PARSER(
     construct<AccDeviceTypeExprList>(nonemptyList(Parser<AccDeviceTypeExpr>{})))
 
diff --git a/flang/test/Lower/OpenACC/acc-init.f90 b/flang/test/Lower/OpenACC/acc-init.f90
index de940426b6f1c0b..d1fd638c7ac0e8a 100644
--- a/flang/test/Lower/OpenACC/acc-init.f90
+++ b/flang/test/Lower/OpenACC/acc-init.f90
@@ -4,6 +4,7 @@
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
 subroutine acc_init
+  implicit none
   logical :: ifCondition = .TRUE.
   integer :: ifInt = 1
 
@@ -23,15 +24,16 @@ subroutine acc_init
 !CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
 !CHECK: acc.init device_num([[DEVNUM]] : i32){{$}}
 
-  !$acc init device_num(1) device_type(1, 2)
+  !$acc init device_num(1) device_type(host, multicore)
 !CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-!CHECK: acc.init device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}}
+!CHECK: acc.init device_num([[DEVNUM]] : i32) attributes {device_types = [#acc.device_type<host>, #acc.device_type<multicore>]}
 
   !$acc init if(ifInt)
 !CHECK: %[[IFINT:.*]] = fir.load %{{.*}} : !fir.ref<i32>
 !CHECK: %[[CONV:.*]] = fir.convert %[[IFINT]] : (i32) -> i1
 !CHECK: acc.init if(%[[CONV]])
 
+  !$acc init device_type(nvidia)
+!CHECK: acc.init attributes {device_types = [#acc.device_type<nvidia>]}
+
 end subroutine acc_init
diff --git a/flang/test/Lower/OpenACC/acc-set.f90 b/flang/test/Lower/OpenACC/acc-set.f90
index 52baedeafecb2bb..39bf26e0072b7ca 100644
--- a/flang/test/Lower/OpenACC/acc-set.f90
+++ b/flang/test/Lower/OpenACC/acc-set.f90
@@ -14,7 +14,7 @@ program test_acc_set
 
 !$acc set device_type(*)
 
-!$acc set device_type(0)
+!$acc set device_type(multicore)
 
 end
 
@@ -34,10 +34,8 @@ program test_acc_set
 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
 ! CHECK: acc.set device_num(%[[C0]] : i32)
 
-! CHECK: %[[C_1:.*]] = arith.constant -1 : index
-! CHECK: acc.set device_type(%[[C_1]] : index)
+! CHECK: acc.set attributes {device_type = #acc.device_type<*>}
 
-! CHECK: %[[C0:.*]] = arith.constant 0 : i32
-! CHECK: acc.set device_type(%[[C0]] : i32)
+! CHECK: acc.set attributes {device_type = #acc.device_type<multicore>}
 
 
diff --git a/flang/test/Lower/OpenACC/acc-shutdown.f90 b/flang/test/Lower/OpenACC/acc-shutdown.f90
index 49e1acc546d900c..f63f5d62b4fe921 100644
--- a/flang/test/Lower/OpenACC/acc-shutdown.f90
+++ b/flang/test/Lower/OpenACC/acc-shutdown.f90
@@ -22,10 +22,8 @@ subroutine acc_shutdown
 !CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
 !CHECK: acc.shutdown device_num([[DEVNUM]] : i32){{$}}
 
-  !$acc shutdown device_num(1) device_type(1, 2)
+  !$acc shutdown device_num(1) device_type(default, nvidia)
 !CHECK: [[DEVNUM:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-!CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-!CHECK: acc.shutdown device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) device_num([[DEVNUM]] : i32){{$}}
+!CHECK: acc.shutdown device_num([[DEVNUM]] : i32) attributes {device_types = [#acc.device_type<default>, #acc.device_type<nvidia>]}
 
 end subroutine acc_shutdown
diff --git a/flang/test/Lower/OpenACC/acc-update.f90 b/flang/test/Lower/OpenACC/acc-update.f90
index f7343a69285f85a..5d5f5733ef7f1a6 100644
--- a/flang/test/Lower/OpenACC/acc-update.f90
+++ b/flang/test/Lower/OpenACC/acc-update.f90
@@ -145,20 +145,17 @@ subroutine acc_update
 ! FIR:   acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 ! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
-  !$acc update host(a) device_type(1, 2)
+  !$acc update host(a) device_type(default, host)
 ! FIR:   %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! HLFIR: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[DEVTYPE1:%.*]] = arith.constant 1 : i32
-! CHECK: [[DEVTYPE2:%.*]] = arith.constant 2 : i32
-! CHECK: acc.update device_type([[DEVTYPE1]], [[DEVTYPE2]] : i32, i32) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<default>, #acc.device_type<host>]}
 ! FIR:   acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 ! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
   !$acc update host(a) device_type(*)
 ! FIR:   %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
 ! HLFIR: %[[DEVPTR_A:.*]] = acc.getdeviceptr varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<10x10xf32>> {dataClause = #acc<data_clause acc_update_host>, name = "a", structured = false}
-! CHECK: [[DEVTYPE3:%.*]] = arith.constant -1 : index
-! CHECK: acc.update device_type([[DEVTYPE3]] : index) dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+! CHECK: acc.update dataOperands(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {device_types = [#acc.device_type<*>]}
 ! FIR:   acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 ! HLFIR: acc.update_host accPtr(%[[DEVPTR_A]] : !fir.ref<!fir.array<10x10xf32>>) bounds(%{{.*}}, %{{.*}}) to varPtr(%[[DECLA]]#1 : !fir.ref<!fir.array<10x10xf32>>) {name = "a", structured = false}
 
diff --git a/flang/test/Semantics/OpenACC/acc-data.f90 b/flang/test/Semantics/OpenACC/acc-data.f90
index 17e0624b8cf24d4..1a7a6f95f3d891e 100644
--- a/flang/test/Semantics/OpenACC/acc-data.f90
+++ b/flang/test/Semantics/OpenACC/acc-data.f90
@@ -184,7 +184,7 @@ program openacc_data_validity
   !$acc data copy(aa) wait
   !$acc end data
 
-  !$acc data copy(aa) device_type(1) wait
+  !$acc data copy(aa) device_type(default) wait
   !$acc end data
 
 end program openacc_data_validity
diff --git a/flang/test/Semantics/OpenACC/acc-init-validity.f90 b/flang/test/Semantics/OpenACC/acc-init-validity.f90
index f54898f73fdce28..3b594a25217c094 100644
--- a/flang/test/Semantics/OpenACC/acc-init-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-init-validity.f90
@@ -20,9 +20,9 @@ program openacc_init_validity
   !$acc init if(ifInt)
   !$acc init device_num(1)
   !$acc init device_num(i)
-  !$acc init device_type(i)
-  !$acc init device_type(2, i, j)
-  !$acc init device_num(i) device_type(i, j) if(ifCondition)
+  !$acc init device_type(default)
+  !$acc init device_type(nvidia, radeon)
+  !$acc init device_num(i) device_type(host, multicore) if(ifCondition)
 
   !$acc parallel
   !ERROR: Directive INIT may not be called within a compute region
@@ -94,7 +94,7 @@ program openacc_init_validity
   !$acc init device_num(1) device_num(i)
 
   !ERROR: At most one DEVICE_TYPE clause can appear on the INIT directive
-  !$acc init device_type(2) device_type(i, j)
+  !$acc init device_type(nvidia) device_type(default, *)
 
   !ERROR: Must have LOGICAL or INTEGER type
   !$acc init if(ifReal)
diff --git a/flang/test/Semantics/OpenACC/acc-kernels-loop.f90 b/flang/test/Semantics/OpenACC/acc-kernels-loop.f90
index 5facd4737788030..1a280f7c54f5cd3 100644
--- a/flang/test/Semantics/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Semantics/OpenACC/acc-kernels-loop.f90
@@ -264,12 +264,12 @@ program openacc_kernels_loop_validity
     a(i) = 3.14
   end do
 
-  !$acc kernels loop device_type(1)
+  !$acc kernels loop device_type(multicore)
   do i = 1, N
     a(i) = 3.14
   end do
 
-  !$acc kernels loop device_type(1, 3)
+  !$acc kernels loop device_type(host, multicore)
   do i = 1, N
     a(i) = 3.14
   end do
diff --git a/flang/test/Semantics/OpenACC/acc-kernels.f90 b/flang/test/Semantics/OpenACC/acc-kernels.f90
index a2c9c9e8be99b17..de220f7c7ddf7cf 100644
--- a/flang/test/Semantics/OpenACC/acc-kernels.f90
+++ b/flang/test/Semantics/OpenACC/acc-kernels.f90
@@ -122,10 +122,10 @@ program openacc_kernels_validity
   !$acc kernels device_type(*)
   !$acc end kernels
 
-  !$acc kernels device_type(1)
+  !$acc kernels device_type(default)
   !$acc end kernels
 
-  !$acc kernels device_type(1, 3)
+  !$acc kernels device_type(default, host)
   !$acc end kernels
 
   !$acc kernels device_type(*) async wait num_gangs(8) num_workers(8) vector_length(128)
diff --git a/flang/test/Semantics/OpenACC/acc-parallel.f90 b/flang/test/Semantics/OpenACC/acc-parallel.f90
index e85922e37c63e06..0e8d240d019983f 100644
--- a/flang/test/Semantics/OpenACC/acc-parallel.f90
+++ b/flang/test/Semantics/OpenACC/acc-parallel.f90
@@ -111,10 +111,10 @@ program openacc_parallel_validity
   !$acc parallel device_type(*)
   !$acc end parallel
 
-  !$acc parallel device_type(1)
+  !$acc parallel device_type(default)
   !$acc end parallel
 
-  !$acc parallel device_type(1, 3)
+  !$acc parallel device_type(default, host)
   !$acc end parallel
 
   !ERROR: Clause PRIVATE is not allowed after clause DEVICE_TYPE on the PARALLEL directive
@@ -131,7 +131,7 @@ program openacc_parallel_validity
   !$acc parallel device_type(*) num_gangs(8)
   !$acc end parallel
 
-  !$acc parallel device_type(1) async device_type(2) wait
+  !$acc parallel device_type(*) async device_type(host) wait
   !$acc end parallel
 
   !ERROR: Clause IF is not allowed after clause DEVICE_TYPE on the PARALLEL directive
diff --git a/flang/test/Semantics/OpenACC/acc-set-validity.f90 b/flang/test/Semantics/OpenACC/acc-set-validity.f90
index 896e39df6535c9a..74522b30d11bc49 100644
--- a/flang/test/Semantics/OpenACC/acc-set-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-set-validity.f90
@@ -90,17 +90,17 @@ program openacc_clause_validity
   !$acc set device_num(1) device_num(i)
 
   !ERROR: At most one DEVICE_TYPE clause can appear on the SET directive
-  !$acc set device_type(i) device_type(2)
+  !$acc set device_type(*) device_type(nvidia)
 
   !$acc set default_async(2)
   !$acc set default_async(i)
   !$acc set device_num(1)
   !$acc set device_num(i)
-  !$acc set device_type(i)
-  !$acc set device_num(1) default_async(2) device_type(2)
+  !$acc set device_type(default)
+  !$acc set device_num(1) default_async(2) device_type(*)
 
   !ERROR: The DEVICE_TYPE clause on the SET directive accepts only one value
-  !$acc set device_type(1, 2)
+  !$acc set device_type(*, default)
 
   !ERROR: At least one of DEFAULT_ASYNC, DEVICE_NUM, DEVICE_TYPE clause must appear on the SET directive
   !$acc set
diff --git a/flang/test/Semantics/OpenACC/acc-shutdown-validity.f90 b/flang/test/Semantics/OpenACC/acc-shutdown-validity.f90
index de40963f99e0486..43aed4fc98f42eb 100644
--- a/flang/test/Semantics/OpenACC/acc-shutdown-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-shutdown-validity.f90
@@ -80,9 +80,9 @@ program openacc_shutdown_validity
   !$acc shutdown if(ifCondition)
   !$acc shutdown device_num(1)
   !$acc shutdown device_num(i)
-  !$acc shutdown device_type(i)
-  !$acc shutdown device_type(2, i, j)
-  !$acc shutdown device_num(i) device_type(i, j) if(ifCondition)
+  !$acc shutdown device_type(*)
+  !$acc shutdown device_type(*, default, host)
+  !$acc shutdown device_num(i) device_type(default, host) if(ifCondition)
 
   !ERROR: At most one IF clause can appear on the SHUTDOWN directive
   !$acc shutdown if(.TRUE.) if(ifCondition)
@@ -91,6 +91,6 @@ program openacc_shutdown_validity
   !$acc shutdown device_num(1) device_num(i)
 
   !ERROR: At most one DEVICE_TYPE clause can appear on the SHUTDOWN directive
-  !$acc shutdown device_type(2) device_type(i, j)
+  !$acc shutdown device_type(*) device_type(host, default)
 
 end program openacc_shutdown_validity
diff --git a/flang/test/Semantics/OpenACC/acc-update-validity.f90 b/flang/test/Semantics/OpenACC/acc-update-validity.f90
index a409ba5ea549f80..1e75742e63e97b1 100644
--- a/flang/test/Semantics/OpenACC/acc-update-validity.f90
+++ b/flang/test/Semantics/OpenACC/acc-update-validity.f90
@@ -53,7 +53,7 @@ program openacc_update_validity
 
   !$acc update host(bb) device_type(*) wait
 
-  !$acc update self(cc) device_type(1,2) async device_type(3) wait
+  !$acc update self(cc) device_type(host,multicore) async device_type(*) wait
 
   !ERROR: At most one IF clause can appear on the UPDATE directive
   !$acc update device(aa) if(.true.) if(ifCondition)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 10018c9fc7e27e8..3c5173fdda7f66a 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -155,6 +155,30 @@ def DeclareActionAttr : OpenACC_Attr<"DeclareAction", "declare_action"> {
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+// Device type enumeration.
+def OpenACC_DeviceTypeStar      : I32EnumAttrCase<"Star", 0, "*">;
+def OpenACC_DeviceTypeDefault   : I32EnumAttrCase<"Default", 1, "default">;
+def OpenACC_DeviceTypeHost      : I32EnumAttrCase<"Host", 2, "host">;
+def OpenACC_DeviceTypeMulticore : I32EnumAttrCase<"Multicore", 3, "multicore">;
+def OpenACC_DeviceTypeNvidia    : I32EnumAttrCase<"Nvidia", 4, "nvidia">;
+def OpenACC_DeviceTypeRadeon    : I32EnumAttrCase<"Radeon", 5, "radeon">;
+
+def OpenACC_DeviceType : I32EnumAttr<"DeviceType",
+    "built-in device type supported by OpenACC",
+    [OpenACC_DeviceTypeStar, OpenACC_DeviceTypeDefault,
+     OpenACC_DeviceTypeHost, OpenACC_DeviceTypeMulticore,
+     OpenACC_DeviceTypeNvidia, OpenACC_DeviceTypeRadeon
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::acc";
+}
+// Attribute wrapper for DeviceType, printed as #acc.device_type<value>.
+def OpenACC_DeviceTypeAttr : EnumAttr<OpenACC_Dialect,
+                                      OpenACC_DeviceType,
+                                      "device_type"> {
+  let assemblyFormat = [{ ```<` $value `>` }];
+}
+
 // Used for data specification in data clauses (2.7.1).
 // Either (or both) extent and upperbound must be specified.
 def OpenACC_DataBoundsOp : OpenACC_Op<"bounds",
@@ -1624,14 +1648,12 @@ def OpenACC_InitOp : OpenACC_Op<"init", [AttrSizedOperandSegments]> {
     ```
   }];
 
-  let arguments = (ins Variadic<AnyInteger>:$deviceTypeOperands,
+  let arguments = (ins OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
                        Optional<IntOrIndex>:$deviceNumOperand,
                        Optional<I1>:$ifCond);
 
   let assemblyFormat = [{
-    oilist(
-        `device_type` `(` $deviceTypeOperands `:` type($deviceTypeOperands) `)`
-      | `device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
+    oilist(`device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
       | `if` `(` $ifCond `)`
     ) attr-dict-with-keyword
   }];
@@ -1657,13 +1679,12 @@ def OpenACC_ShutdownOp : OpenACC_Op<"shutdown", [AttrSizedOperandSegments]> {
     ```
   }];
 
-  let arguments = (ins Variadic<AnyInteger>:$deviceTypeOperands,
+  let arguments = (ins OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
                        Optional<IntOrIndex>:$deviceNumOperand,
                        Optional<I1>:$ifCond);
 
   let assemblyFormat = [{
-    oilist(`device_type` `(` $deviceTypeOperands `:` type($deviceTypeOperands) `)`
-    |`device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
+    oilist(`device_num` `(` $deviceNumOperand `:` type($deviceNumOperand) `)`
     |`if` `(` $ifCond `)`
     ) attr-dict-with-keyword
   }];
@@ -1687,15 +1708,13 @@ def OpenACC_SetOp : OpenACC_Op<"set", [AttrSizedOperandSegments]> {
     ```
   }];
 
-  let arguments = (ins Optional<IntOrIndex>:$deviceType,
+  let arguments = (ins OptionalAttr<OpenACC_DeviceTypeAttr>:$device_type,
                        Optional<IntOrIndex>:$defaultAsync,
                        Optional<IntOrIndex>:$deviceNum,
                        Optional<I1>:$ifCond);
 
   let assemblyFormat = [{
-    oilist(
-      `device_type` `(` $deviceType `:` type($deviceType) `)`
-    | `default_async` `(` $defaultAsync `:` type($defaultAsync) `)`
+    oilist(`default_async` `(` $defaultAsync `:` type($defaultAsync) `)`
     | `device_num` `(` $deviceNum `:` type($deviceNum) `)`
     | `if` `(` $ifCond `)`
     ) attr-dict-with-keyword
@@ -1729,7 +1748,7 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
                        Variadic<IntOrIndex>:$waitOperands,
                        UnitAttr:$async,
                        UnitAttr:$wait,
-                       Variadic<IntOrIndex>:$deviceTypeOperands,
+                       OptionalAttr<TypedArrayAttrBase<OpenACC_DeviceTypeAttr, "Device type attributes">>:$device_types,
                        Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
                        UnitAttr:$ifPresent);
 
@@ -1746,8 +1765,6 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
         `if` `(` $ifCond `)`
       | `async` `(` $asyncOperand `:` type($asyncOperand) `)`
       | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
-      | `device_type` `(` $deviceTypeOperands `:`
-          type($deviceTypeOperands) `)`
       | `wait` `(` $waitOperands `:` type($waitOperands) `)`
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index b7e2aec6a4e6a92..639f0b85450f591 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1238,7 +1238,7 @@ LogicalResult acc::SetOp::verify() {
   while ((currOp = currOp->getParentOp()))
     if (isComputeOperation(currOp))
       return emitOpError("cannot be nested in a compute operation");
-  if (!getDeviceType() && !getDefaultAsync() && !getDeviceNum())
+  if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
     return emitOpError("at least one default_async, device_num, or device_type "
                        "operand must appear");
   return success();
@@ -1283,8 +1283,7 @@ Value UpdateOp::getDataOperand(unsigned i) {
   unsigned numOptional = getAsyncOperand() ? 1 : 0;
   numOptional += getWaitDevnum() ? 1 : 0;
   numOptional += getIfCond() ? 1 : 0;
-  return getOperand(getWaitOperands().size() + getDeviceTypeOperands().size() +
-                    numOptional + i);
+  return getOperand(getWaitOperands().size() + numOptional + i);
 }
 
 void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index ff92eab478bb4f5..b5241a8e4dc47fa 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -475,7 +475,7 @@ acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i6
 %i64value = arith.constant 1 : i64
 acc.parallel {
 // expected-error at +1 {{'acc.set' op cannot be nested in a compute operation}}
-  acc.set device_type(%i64value : i64)
+  acc.set attributes {device_type = #acc.device_type<nvidia>}
   acc.yield
 }
 
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index d1950b1fb3f2916..cf7a838f55ef855 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -974,7 +974,7 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
   acc.update async(%idxValue: index) dataOperands(%0: memref<f32>)
   acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) dataOperands(%0: memref<f32>)
   acc.update if(%ifCond) dataOperands(%0: memref<f32>)
-  acc.update device_type(%i32Value : i32) dataOperands(%0: memref<f32>)
+  acc.update dataOperands(%0: memref<f32>) attributes {acc.device_types = [#acc.device_type<nvidia>]}
   acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
   acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {async}
   acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
@@ -993,7 +993,7 @@ func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> ()
 // CHECK:   acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) dataOperands(%{{.*}} : memref<f32>)
 // CHECK:   acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref<f32>)
-// CHECK:   acc.update device_type([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update dataOperands(%{{.*}} : memref<f32>) attributes {acc.device_types = [#acc.device_type<nvidia>]}
 // CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
 // CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {async}
 // CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
@@ -1047,8 +1047,7 @@ acc.wait if(%ifCond)
 %idxValue = arith.constant 1 : index
 %ifCond = arith.constant true
 acc.init
-acc.init device_type(%i32Value : i32)
-acc.init device_type(%i32Value, %i32Value2 : i32, i32)
+acc.init attributes {acc.device_types = [#acc.device_type<nvidia>]}
 acc.init device_num(%i64Value : i64)
 acc.init device_num(%i32Value : i32)
 acc.init device_num(%idxValue : index)
@@ -1062,8 +1061,7 @@ acc.init device_num(%idxValue : index) if(%ifCond)
 // CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index
 // CHECK: [[IFCOND:%.*]] = arith.constant true
 // CHECK: acc.init
-// CHECK: acc.init device_type([[I32VALUE]] : i32)
-// CHECK: acc.init device_type([[I32VALUE]], [[I32VALUE2]] : i32, i32)
+// CHECK: acc.init attributes {acc.device_types = [#acc.device_type<nvidia>]}
 // CHECK: acc.init device_num([[I64VALUE]] : i64)
 // CHECK: acc.init device_num([[I32VALUE]] : i32)
 // CHECK: acc.init device_num([[IDXVALUE]] : index)
@@ -1079,8 +1077,7 @@ acc.init device_num(%idxValue : index) if(%ifCond)
 %idxValue = arith.constant 1 : index
 %ifCond = arith.constant true
 acc.shutdown
-acc.shutdown device_type(%i32Value : i32)
-acc.shutdown device_type(%i32Value, %i32Value2 : i32, i32)
+acc.shutdown attributes {acc.device_types = [#acc.device_type<default>]}
 acc.shutdown device_num(%i64Value : i64)
 acc.shutdown device_num(%i32Value : i32)
 acc.shutdown device_num(%idxValue : index)
@@ -1094,8 +1091,7 @@ acc.shutdown device_num(%idxValue : index) if(%ifCond)
 // CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index
 // CHECK: [[IFCOND:%.*]] = arith.constant true
 // CHECK: acc.shutdown
-// CHECK: acc.shutdown device_type([[I32VALUE]] : i32)
-// CHECK: acc.shutdown device_type([[I32VALUE]], [[I32VALUE2]] : i32, i32)
+// CHECK: acc.shutdown attributes {acc.device_types = [#acc.device_type<default>]}
 // CHECK: acc.shutdown device_num([[I64VALUE]] : i64)
 // CHECK: acc.shutdown device_num([[I32VALUE]] : i32)
 // CHECK: acc.shutdown device_num([[IDXVALUE]] : index)
@@ -1718,7 +1714,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 %i32Value2 = arith.constant 2 : i32
 %idxValue = arith.constant 1 : index
 %ifCond = arith.constant true
-acc.set device_type(%i32Value : i32)
+acc.set attributes {device_type = #acc.device_type<nvidia>}
 acc.set device_num(%i64Value : i64)
 acc.set device_num(%i32Value : i32)
 acc.set device_num(%idxValue : index)
@@ -1730,7 +1726,7 @@ acc.set default_async(%i32Value : i32)
 // CHECK: [[I32VALUE2:%.*]] = arith.constant 2 : i32
 // CHECK: [[IDXVALUE:%.*]] = arith.constant 1 : index
 // CHECK: [[IFCOND:%.*]] = arith.constant true
-// CHECK: acc.set device_type([[I32VALUE]] : i32)
+// CHECK: acc.set attributes {device_type = #acc.device_type<nvidia>}
 // CHECK: acc.set device_num([[I64VALUE]] : i64)
 // CHECK: acc.set device_num([[I32VALUE]] : i32)
 // CHECK: acc.set device_num([[IDXVALUE]] : index)

>From 1c68f9b24534698a94371e44206fbe3c3aa634d0 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 25 Oct 2023 15:26:37 -0700
Subject: [PATCH 2/2] Add assert for acc.set device_type

---
 flang/lib/Lower/OpenACC.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 024c38c96277313..b5dadb1b3da0f51 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2521,8 +2521,10 @@ void genACCSetOp(Fortran::lower::AbstractConverter &converter,
 
   auto op = createSimpleOp<mlir::acc::SetOp>(builder, currentLocation, operands,
                                              operandSegments);
-  if (!deviceTypes.empty())
+  if (!deviceTypes.empty()) {
+    assert(deviceTypes.size() == 1 && "expect only one value for acc.set");
     op.setDeviceTypeAttr(mlir::cast<mlir::acc::DeviceTypeAttr>(deviceTypes[0]));
+  }
 }
 
 static void



More information about the flang-commits mailing list